Count on a tree SPOJ - COT | 主席树

主席树入门题

题意:
给你一棵树,以及每个节点的权值。询问:对于给定的结点u, v,求连线上的第k小权值(我感觉准确就应该叫第k小,虽然网上都叫第k大。。。)
分析:
这道题目跟主席树入门题目:求一段区间第k大很像,不同的是这道题目是求在树的一条链上的第k大。一开始的时候我只感觉到这道题目隐约的跟lca有点关系,但是不知道该如何去处理。然后我看了别人的博客,但是大多数都是直接给出了一个结论:对于查询区间[u,v],答案就是root[u]+root[v]-root[lca]-root[lca的父亲]上的第k大
一开始不是很理解,后来发现其实类似于前缀和,对于一个节点,它的每个儿子节点做一棵新版本的树。如果从整棵树的根节点开始做的话,最后得到的一个root[u],就表示了u这个节点到根节点的信息。如果要获得从u到v这条链上的信息,拿就要先算出这两个点的lca,u到lca的信息是通过root[u]-root[lca的父亲]获得的,那么另外的一段就是root[v]-root[lca]获得的。合起来就是那个结论了。
代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
//typedef __int128 lll;
#define close() ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(a) cout << "debug :" << a << endl
#define fi first
#define se second
const int maxn =  1e5 + 10;
const int inf = 0x3f3f3f3f;
struct edge
{
    int ep, nex;
    edge(int a = 0, int b = 0) : ep(a), nex(b){};
}e[maxn << 1];
struct node
{
    int l, r, sum;
}t[maxn << 5];

int tot, head[maxn];
int v[maxn], sorted[maxn], num;
int root[maxn], rootnum;
int depth[maxn], fa[maxn][20];

void init()
{
    mem(head, -1); tot = 0;
}
void addedge(int sp, int ep)
{
    e[tot] = edge(ep, head[sp]);
    head[sp] = tot++;
    e[tot] = edge(sp, head[ep]);
    head[ep] = tot++;
}

int getid(int x)
{
    return lower_bound(sorted + 1, sorted + 1 + num, x) - sorted;
}


void build(int l, int r, int &tar)
{
    tar = ++rootnum;
    t[tar].sum = 0;
    if(l == r) return;
    int mid = l + r >> 1;
    build(l, mid, t[tar].l);
    build(mid + 1, r, t[tar].r);
}

void update(int last, int &tar, int l, int r, int x)
{
    tar = ++rootnum;
    t[tar] = t[last];
    t[tar].sum++;
    if(l == r) return;
    int mid = l + r >> 1;
    if(x <= mid) update(t[last].l, t[tar].l, l, mid, x);
    else update(t[last].r, t[tar].r, mid + 1, r, x);
}

int query(int lson, int rson, int lca, int lcaf, int l, int r, int k)
{
    if(l == r) return l;
    int mid = l + r >> 1;
    int dif = t[t[lson].l].sum + t[t[rson].l].sum - t[t[lca].l].sum - t[t[lcaf].l].sum;
    if(dif >= k) return query(t[lson].l, t[rson].l, t[lca].l, t[lcaf].l, l, mid, k);
    else return query(t[lson].r, t[rson].r, t[lca].r, t[lcaf].r, mid + 1, r, k - dif);
}

void dfs(int now, int f)
{
    depth[now] = depth[f] + 1;
    fa[now][0] = f;
    for(int i = 1; i < 20; i++)
        fa[now][i] = fa[fa[now][i - 1]][i - 1];
    update(root[f], root[now], 1, num, getid(v[now]));
    for(int i = head[now]; ~i; i = e[i].nex)
    {
        int ep = e[i].ep;
        if(ep == f) continue;
        dfs(ep, now);
    }
}

int lca(int x, int y)
{
    if(depth[x] < depth[y]) swap(x, y);
    int dif = depth[x] - depth[y];
    for(int i = 0; i < 20; i++)
        if((1 << i) & dif)
            x = fa[x][i];
    if(x == y) return x;
    for(int i = 19; i >= 0; i--)
        if(fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

int main()
{
    int n, m;
    cin >> n >> m;
    init();
    for(int i = 1; i <= n; i++)
        scanf("%d", v + i), sorted[i] = v[i];
    sort(sorted + 1, sorted + 1 + n);
    num = unique(sorted + 1, sorted + 1 + n) - sorted - 1;

    for(int i = 1; i <= n - 1; i++)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        addedge(x, y);
    }
    dfs(1, 0);
    for(int i = 1; i <= m; i++)
    {
        int a, b, k;
        scanf("%d%d%d", &a, &b, &k);
        int f = lca(a, b);
        int id = query(root[a], root[b], root[f], root[fa[f][0]], 1, num, k);
        printf("%d\n", sorted[id]);
    }
}