Description
834. Sum of Distances in Tree (Hard)
Solution
To find the sum of distances from a single node (e.g., node $0$) to all other nodes, denoted $dp[0]$, a single DFS suffices, taking $O(n)$ time. However, repeating this for each of the $n$ nodes would take $O(n^2)$ time, which would time out on large trees.
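Fortunately, there is a recurrence between $dp[j]$ for a parent node $j$ and $dp[i]$ for its child node $i$: $dp[i] = dp[j] - cnt[i] + (n - cnt[i]) = dp[j] + n - 2\,cnt[i]$. Because $i$ and $j$ are directly connected, moving the reference point from $j$ to $i$ brings each of the $cnt[i]$ nodes in $i$'s subtree one step closer, and pushes each of the remaining $n - cnt[i]$ nodes one step farther away.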
So the remaining problem is how to compute $cnt[i]$, the number of nodes in the subtree rooted at node $i$ when the tree (stored as an undirected graph) is rooted at node $0$. Please refer to the article "Tree Organized as an Undirected Graph" for more details.
Code
/// LeetCode 834 "Sum of Distances in Tree": for every node, the sum of its
/// distances to all other nodes, computed in O(n) by rerooting from node 0.
class Solution {
public:
    /// Sizes the subtree rooted at `pa` and records depths below it.
    ///
    /// The traversal moves away from `grandpa` (pass -1 when `pa` is the root).
    /// An explicit stack replaces recursion so that a path-shaped tree with
    /// many nodes cannot overflow the call stack.
    ///
    /// @param tree    adjacency lists; expected to be acyclic, because only the
    ///                edge back to a node's immediate parent is skipped.
    /// @param dis     dis[pa] is read; on return dis[v] == dis[parent(v)] + 1
    ///                for every descendant v of pa.
    /// @param cnt     on return cnt[v] is the size of v's subtree for every v
    ///                in the subtree of pa (pa included); cnt[grandpa] untouched.
    /// @param pa      root of the subtree to measure.
    /// @param grandpa neighbor of `pa` treated as its parent and excluded, or -1.
    /// @return        cnt[pa], the number of nodes in the subtree.
    int count(vector<vector<int>> &tree, vector<int> &dis, vector<int> &cnt, int pa, int grandpa) {
        // Preorder list of (node, parent): every parent precedes its children.
        vector<pair<int, int>> order;
        vector<pair<int, int>> pending{{pa, grandpa}};
        while (!pending.empty()) {
            auto [node, parent] = pending.back();
            pending.pop_back();
            order.push_back({node, parent});
            cnt[node] = 1;
            for (int child : tree[node]) {
                if (child == parent) { // don't walk back up the edge we came from
                    continue;
                }
                dis[child] = dis[node] + 1;
                pending.push_back({child, node});
            }
        }
        // Reverse preorder finishes every child before its parent, so each
        // subtree size is final by the time it is folded into its parent.
        // Index 0 is `pa` itself, whose parent lies outside the subtree.
        for (int i = static_cast<int>(order.size()) - 1; i > 0; --i) {
            const auto [node, parent] = order[i];
            cnt[parent] += cnt[node];
        }
        return cnt[pa];
    }

    /// Returns answer[v] = sum of the distances from v to every other node.
    ///
    /// @param n     number of nodes, labelled 0..n-1.
    /// @param edges undirected edges {a, b}; expected to form a tree on n nodes.
    /// @return      per-node distance sums; empty when n <= 0.
    ///
    /// NOTE: results are stored as int. A path of n nodes gives a sum of
    /// n(n-1)/2 at its ends, so this assumes n stays well below ~65'000.
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>> &edges) {
        if (n <= 0) { // no node 0 to root the tree at; avoid indexing tree[0]
            return {};
        }
        vector<vector<int>> tree(n);
        for (const auto &edge : edges) {
            // Insert both directions to obtain an undirected adjacency list.
            tree[edge[0]].push_back(edge[1]);
            tree[edge[1]].push_back(edge[0]);
        }
        vector<int> cnt(n); // subtree sizes with the tree rooted at 0
        vector<int> dis(n); // depth of each node with the tree rooted at 0
        vector<int> dp(n);
        count(tree, dis, cnt, 0, -1);
        // The distance sum from the root is simply the sum of all depths.
        for (int depth : dis) {
            dp[0] += depth;
        }
        // Reroot top-down: stepping from a parent to `child` brings cnt[child]
        // nodes one step closer and the other n - cnt[child] one step farther.
        queue<pair<int, int>> q;
        q.push({0, -1}); // (node, parent)
        while (!q.empty()) {
            auto [pa, grandpa] = q.front();
            q.pop();
            for (int child : tree[pa]) {
                if (child == grandpa) { // don't revisit the parent
                    continue;
                }
                dp[child] = dp[pa] + n - 2 * cnt[child];
                q.push({child, pa});
            }
        }
        return dp;
    }
};