B - Happy King HDU - 5314 | 点分治 + 容斥原理

题意:
给定一棵 N 个点的树,每个点有权值,求有多少个有序点对 (u,v)(u≠v)满足 u 到 v 的路径上点权的最大值减最小值不大于给定的 K。

分析:
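1、分治时直接用 siz[ep] 作为子连通块大小并不准确:siz[] 来自找到当前重心 now 的那次 getroot,那次 DFS 的根不一定是 now,若 ep 在那次 DFS 中是 now 的父亲,siz[ep] 会把 now 这一侧也算进去(偏大)。用这个偏大的值仍然可以找重心,分治复杂度不受影响(可以证明);也可以用更准确的方法计算子连通块大小,写法如下: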

    // Centroid-decomposition step: `now` is the centroid that getroot just found
    // for a component of size `sizz`.
    void divide(int now)
    {
        vis[now] = true;
        // sizz is overwritten for every child component below (and by the
        // recursive calls), so the size of *this* component must be saved first.
        int total = sizz;
        res += cal(now, -1e9, 1e9);
        for(int i = head[now]; ~i; i = e[i].nex)
        {
            int ep = e[i].ep;
            if(vis[ep]) continue;
            res -= cal(ep, v[now], v[now]);
            // Simple version: sizz = siz[ep] (overestimates when ep was now's parent
            // in the getroot DFS, but still finds a good enough centroid).
            // Exact version: ep was a child of now in that DFS iff siz[ep] < siz[now];
            // otherwise its component is everything outside now's DFS subtree.
            maxp[root = 0] = sizz = siz[ep] < siz[now] ? siz[ep] : total - siz[now];
            getroot(ep, -1);
            divide(root);
        }
    }

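2、cal(now) 把所有从 now 出发的路径两两配对计数,其中两个端点来自 now 的同一棵子树的点对,其真实路径并不经过 now,需要扣除。根据容斥原理,对每棵子树 ep 以 v[now] 作为初始最大/最小值再求一遍 cal(ep, v[now], v[now]),从答案中减去即可。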
代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> par;
#define mp make_pair
const int maxn = 200000;   // node capacity; the arc array holds twice as many entries
const int inf = 0x3f3f3f3f;

/// One directed arc of a forward-star adjacency list: target `ep`, weight `v`,
/// and `nex`, the index of the next arc leaving the same source node.
struct edge
{
    int ep, v, nex;
    edge(int to = 0, int w = 0, int next_id = 0)
    {
        ep = to;
        v = w;
        nex = next_id;
    }
} e[2 * maxn];

int head[maxn], tot;   // head[x]: newest arc out of x; tot: number of arcs stored

/// Prepend the arc from -> to (weight w) to the list of `from`.
static void add_arc(int from, int to, int w)
{
    e[tot] = edge(to, w, head[from]);
    head[from] = tot;
    ++tot;
}

/// Store the undirected edge {sp, ep} as two opposite arcs sharing weight v.
/// head[] must already be reset to -1 (as init() does) before the first insertion.
void addedge(int sp, int ep, int v = 1)
{
    add_arc(sp, ep, v);
    add_arc(ep, sp, v);
}

int v[maxn], n, k;
int siz[maxn], maxp[maxn], sizz, root;
bool vis[maxn];
par path[maxn];
int cur;
ll res;

void init()
{
    memset(head, -1, sizeof(head)); tot = 0;
    memset(vis, false, sizeof(vis)); res = 0;
}

/// DFS over the component containing `now` (skipping nodes already used as
/// centroids), filling siz[] and maxp[] and updating `root` to the node whose
/// largest remaining piece is smallest. `f` is the DFS parent (-1 at the start);
/// `sizz` is read as the size of the whole component.
void getroot(int now, int f)
{
    siz[now] = 1;
    maxp[now] = 0;
    for(int id = head[now]; id != -1; id = e[id].nex)
    {
        const int to = e[id].ep;
        if(to != f && !vis[to])
        {
            getroot(to, now);
            siz[now] += siz[to];
            if(siz[to] > maxp[now]) maxp[now] = siz[to];
        }
    }
    // Removing `now` also leaves the part of the component outside its subtree.
    const int upper = sizz - siz[now];
    if(upper > maxp[now]) maxp[now] = upper;
    if(maxp[now] < maxp[root]) root = now;
}

/// DFS from `now` (parent `f`, -1 at the start) inside the current component,
/// carrying the running max/min of node values along the path. For every
/// reached node whose path span (max - min) is <= k, appends (min, max) to
/// path[cur++].
void dfs_maxmin(int now, int f, int maxx, int minn)
{
    maxx = max(maxx, v[now]), minn = min(minn, v[now]);
    // Extending a path can only raise its max and lower its min, so once the
    // span exceeds k no node below `now` can qualify either: prune the subtree.
    if(maxx - minn > k) return;
    path[cur++] = mp(minn, maxx);
    for(int i = head[now]; ~i; i = e[i].nex)
    {
        int ep = e[i].ep;
        if(ep == f || vis[ep]) continue;
        dfs_maxmin(ep, now, maxx, minn);
    }
}

/// Collect paths from `now` (running max/min seeded with `maxx`/`minn`) via
/// dfs_maxmin, then count unordered pairs of distinct recorded paths whose
/// combined span max - min is <= k.
ll cal(int now, int maxx, int minn)
{
    cur = 0;
    dfs_maxmin(now, -1, maxx, minn);
    sort(path, path + cur);   // ascending by min, ties by max
    ll ans = 0;
    for(int i = 0; i < cur; i++)
    {
        // For j < i, min_j <= min_i, so the combined min is min_j. Each recorded
        // path already has max_j - min_j <= k, so the pair is valid iff
        // max_i - min_j <= k, i.e. min_j >= max_i - k.
        // INT_MIN as tie-breaker keeps every entry with min == max_i - k
        // (whatever its max, even negative) at or after the search position.
        int first = lower_bound(path, path + i, mp(path[i].second - k, INT_MIN)) - path;
        ans += i - first;
    }
    return ans;
}

/// Centroid-decomposition driver. `now` must be the centroid just found by
/// getroot for a component of size `sizz`, with siz[] still holding that
/// pass's subtree sizes. Adds to `res` the valid unordered pairs whose path
/// passes through `now`, then recurses into every remaining component.
void divide(int now)
{
    vis[now] = true;
    // sizz is overwritten for each child component below, so keep ours.
    const int total = sizz;
    // All valid pairs among paths that start at now (now itself included).
    res += cal(now, INT_MIN, INT_MAX);
    for(int i = head[now]; ~i; i = e[i].nex)
    {
        int ep = e[i].ep;
        if(vis[ep]) continue;
        // Inclusion-exclusion: pairs with both ends in ep's subtree were counted
        // above though their real path avoids now; recount them with the same
        // v[now]-seeded min/max and subtract.
        res -= cal(ep, v[now], v[now]);
        // siz[] comes from a getroot DFS that may have been rooted elsewhere.
        // If ep was now's child there, siz[ep] is its component size; otherwise
        // ep was now's parent and its component is everything outside now's subtree.
        maxp[root = 0] = sizz = siz[ep] < siz[now] ? siz[ep] : total - siz[now];
        getroot(ep, -1);
        divide(root);
    }
}

/// Reads T test cases (n, k, n node values, n-1 edges) and prints, for each,
/// the number of ordered pairs (u, v), u != v, whose path has max - min <= k.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    int t = 0;
    if(!(cin >> t)) return 0;   // empty/invalid input: nothing to do
    while(t--)
    {
        init();
        cin >> n >> k;
        for(int i = 1; i <= n; i++)
            cin >> v[i];
        for(int i = 1; i < n; i++)
        {
            int sp, ep;
            cin >> sp >> ep;
            addedge(sp, ep);
        }
        maxp[root = 0] = sizz = n;
        getroot(1, -1);
        divide(root);
        // res counts unordered pairs; each yields two ordered pairs.
        // '\n' instead of endl avoids flushing after every test case.
        cout << res * 2 << '\n';
    }
    return 0;
}