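题意:
给定一棵 N 个点的树,每个点有权值。求有多少个有序点对 (u,v)(u≠v)满足:u 到 v 的路径上所有点的权值中,最大值减最小值不超过给定的 K。
分析:
1、点分治时,递归前直接用上一次 getroot 得到的 siz[ep] 作为子连通块大小并不总是准确:当 ep 在那次 DFS 中是当前重心的父节点时,siz[ep] 偏大。但即使 size 不准,仍然能求出足够好的根,复杂度不受影响(可以证明)。也可以用更准确的写法:若 siz[ep] > siz[now],则该连通块大小为 total - siz[now],否则为 siz[ep],其中 total 是进入 divide 时保存下来的当前连通块大小(全局变量 sizz 会被递归调用改写,不能直接使用)。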
// Snippet of the centroid-decomposition step: mark `now` as used, count pairs whose
// path passes through it, subtract pairs lying inside one child subtree, then
// recurse into each remaining piece.
void divide(int now)
{
vis[now] = true;
// Count pairs among all paths starting at `now`; -1e9 / 1e9 are the initial max / min.
res += cal(now, -1e9, 1e9);
for(int i = head[now]; ~i; i = e[i].nex)
{
int ep = e[i].ep;
if(vis[ep]) continue;
// Inclusion-exclusion: paths from ep with v[now] folded in are exactly the paths from
// `now` into this subtree; pairs among them do not really pass through `now`.
res -= cal(ep, v[now], v[now]);
// siz[ep] is too large when ep was now's parent in the last getroot(), but this still works.
// Exact alternative: siz[ep] > siz[now] ? total - siz[now] : siz[ep], where total is the
// component size saved on entry to divide (the global sizz is overwritten by the recursion).
maxp[root = 0] = sizz = siz[ep];
getroot(ep, -1);
divide(root);
}
}
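2、以重心为中转统计点对时,两个端点位于重心同一棵子树内的点对也会被计入,而它们之间的路径并不经过重心。根据容斥原理,对重心的每棵子树,以重心的权值作为路径最大值/最小值的初值再统计一遍,并从答案中减去。cal 对每个无序点对只计一次,因此最后输出 res*2 得到有序点对的个数。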
代码:
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;            // wide type for pair counts
typedef pair<int, int> par;      // (min weight, max weight) along a path
#define mp make_pair

const int maxn = 200000;         // capacity of per-vertex arrays; vertices are numbered 1..n
const int inf = 0x3f3f3f3f;      // large sentinel value (not referenced elsewhere in this file)
// One directed arc of the array-based adjacency list.
struct edge
{
    int ep;   // vertex this arc points to
    int v;    // arc weight (stored, but never read by the algorithm)
    int nex;  // index of the next arc leaving the same vertex; -1 ends the list
    edge(int to = 0, int w = 0, int nxt = 0) : ep(to), v(w), nex(nxt) {}
};
edge e[maxn << 1];   // room for two arcs per undirected tree edge
int head[maxn], tot;
void addedge(int sp, int ep, int v = 1)
{
e[tot] = edge(ep, v, head[sp]);
head[sp] = tot++;
e[tot] = edge(sp, v, head[ep]);
head[ep] = tot++;
}
int v[maxn];      // vertex weights, indexed 1..n
int n;            // number of vertices
int k;            // allowed spread: max weight - min weight must not exceed k
int siz[maxn];    // subtree sizes computed by the latest getroot() pass
int maxp[maxn];   // largest piece left if the vertex were removed (maxp[0] is a sentinel)
int sizz;         // size of the component getroot() is currently searching
int root;         // best centroid candidate found so far
bool vis[maxn];   // true once a vertex has been used as a centroid
par path[maxn];   // (min, max) pairs collected by dfs_maxmin()
int cur;          // number of valid entries in path[]
ll res;           // running count of unordered valid pairs
void init()
{
memset(head, -1, sizeof(head)); tot = 0;
memset(vis, false, sizeof(vis)); res = 0;
}
// DFS over the unvisited component containing `now` (f = vertex we came from).
// Fills siz[] / maxp[] and moves `root` to the vertex whose largest remaining piece
// (a child subtree, or the rest of the sizz-vertex component) is smallest.
void getroot(int now, int f)
{
    int heaviest = 0;
    siz[now] = 1;
    for (int id = head[now]; id != -1; id = e[id].nex)
    {
        const int to = e[id].ep;
        if (to != f && !vis[to])
        {
            getroot(to, now);
            siz[now] += siz[to];
            if (siz[to] > heaviest)
                heaviest = siz[to];
        }
    }
    const int upward = sizz - siz[now];   // the piece on the parent's side
    if (upward > heaviest)
        heaviest = upward;
    maxp[now] = heaviest;
    if (heaviest < maxp[root])
        root = now;
}
// Walk the unvisited component below `now` (never stepping back into `f`), carrying
// the running max / min weight of the path from the start vertex. Every vertex whose
// path spread (max - min) is within k is appended to path[] as (min, max).
void dfs_maxmin(int now, int f, int maxx, int minn)
{
    maxx = max(maxx, v[now]), minn = min(minn, v[now]);
    // Along a path the max only grows and the min only shrinks, so once the spread
    // exceeds k no vertex further down can qualify: stop exploring this branch.
    if (maxx - minn > k)
        return;
    path[cur++] = mp(minn, maxx);
    for (int i = head[now]; ~i; i = e[i].nex)
    {
        int ep = e[i].ep;
        if (ep == f || vis[ep]) continue;
        dfs_maxmin(ep, now, maxx, minn);
    }
}
// Collect all valid paths starting at `now` (with maxx / minn as the initial running
// max / min) and return the number of unordered pairs of distinct collected entries
// whose joined path still has max - min <= k.
ll cal(int now, int maxx, int minn)
{
    ll ans = 0;
    cur = 0;
    dfs_maxmin(now, -1, maxx, minn);
    // Sort by path minimum (ties broken by maximum).
    sort(path, path + cur);
    for (int i = cur - 1; i >= 0; i--)
    {
        // For j < i we have min_j <= min_i, so the joined path has minimum min_j and
        // maximum max(max_i, max_j). Every entry already satisfies max - min <= k, so
        // the pair is valid iff min_j >= max_i - k. INT_MIN as the second key makes
        // lower_bound compare on the minimum only, so entries with min_j == max_i - k
        // are never skipped, whatever the sign of their maximum.
        int num = lower_bound(path, path + i, mp(path[i].second - k, INT_MIN)) - path;
        ans += i - num;
    }
    return ans;
}
// Centroid-decomposition step. `now` must be the centroid just chosen by getroot()
// for a component of sizz vertices. Adds the pairs whose path passes through `now`,
// removes the pairs lying inside a single child subtree, then recurses into each piece.
void divide(int now)
{
    // The global sizz is overwritten by the recursive calls below; keep this
    // component's size locally so every child piece can be sized exactly.
    const int total = sizz;
    vis[now] = true;
    // No bound yet: the first visited vertex sets both the running max and min.
    res += cal(now, INT_MIN, INT_MAX);
    for (int i = head[now]; ~i; i = e[i].nex)
    {
        int ep = e[i].ep;
        if (vis[ep]) continue;
        // Inclusion-exclusion: paths from ep with v[now] folded in are exactly the paths
        // from `now` into this subtree; pairs among them do not really pass through `now`.
        res -= cal(ep, v[now], v[now]);
        // siz[] still reflects the rooting of the getroot() call that chose `now`
        // (neither cal() nor recursion into other pieces touches these entries).
        // A neighbour with siz[ep] > siz[now] was now's parent there, so its piece
        // holds everything outside now's subtree.
        root = 0;
        sizz = siz[ep] > siz[now] ? total - siz[now] : siz[ep];
        maxp[0] = sizz;
        getroot(ep, -1);
        divide(root);
    }
}
// Read T test cases. Each has n, k, the n vertex weights and n - 1 edges.
// Print the number of ordered pairs (u, v), u != v, whose path spread is <= k.
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    int t = 0;   // stays 0 if the read fails, so the loop is simply skipped
    cin >> t;
    while (t--)
    {
        init();
        cin >> n >> k;
        for (int i = 1; i <= n; i++)
            cin >> v[i];
        for (int i = 1; i < n; i++)
        {
            int sp, ep;
            cin >> sp >> ep;
            addedge(sp, ep);
        }
        // Vertex 0 is unused and serves as the "no candidate yet" sentinel for getroot().
        root = 0;
        sizz = n;
        maxp[0] = n;
        getroot(1, -1);
        divide(root);
        // res counts each unordered pair once; ordered pairs are twice as many.
        // '\n' instead of endl: no need to flush after every test case.
        cout << res * 2 << '\n';
    }
    return 0;
}