主席树入门题
题意:
给你一棵树,以及每个节点的权值。询问:对于给定的结点u, v,求连线上的第k小权值(我感觉准确就应该叫第k小,虽然网上都叫第k大。。。)
分析:
这道题目跟主席树入门题目:求一段区间第k大很像,不同的是这道题目是求在树的一条链上的第k大。一开始的时候我只感觉到这道题目隐约的跟lca有点关系,但是不知道该如何去处理。然后我看了别人的博客,但是大多数都是直接给出了一个结论:对于查询区间[u,v],答案就是root[u]+root[v]-root[lca]-root[lca的父亲]上的第k大
一开始不是很理解,后来发现其实类似于前缀和,对于一个节点,它的每个儿子节点做一棵新版本的树。如果从整棵树的根节点开始做的话,最后得到的一个root[u],就表示了u这个节点到根节点的信息。如果要获得从u到v这条链上的信息,拿就要先算出这两个点的lca,u到lca的信息是通过root[u]-root[lca的父亲]获得的,那么另外的一段就是root[v]-root[lca]获得的。合起来就是那个结论了。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
//typedef __int128 lll;
#define close() ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(a) cout << "debug :" << a << endl
#define fi first
#define se second
const int maxn = 1e5 + 10;
const int inf = 0x3f3f3f3f;
struct edge
{
int ep, nex;
edge(int a = 0, int b = 0) : ep(a), nex(b){};
}e[maxn << 1];
struct node
{
int l, r, sum;
}t[maxn << 5];
int tot, head[maxn];
int v[maxn], sorted[maxn], num;
int root[maxn], rootnum;
int depth[maxn], fa[maxn][20];
void init()
{
mem(head, -1); tot = 0;
}
void addedge(int sp, int ep)
{
e[tot] = edge(ep, head[sp]);
head[sp] = tot++;
e[tot] = edge(sp, head[ep]);
head[ep] = tot++;
}
int getid(int x)
{
return lower_bound(sorted + 1, sorted + 1 + num, x) - sorted;
}
void build(int l, int r, int &tar)
{
tar = ++rootnum;
t[tar].sum = 0;
if(l == r) return;
int mid = l + r >> 1;
build(l, mid, t[tar].l);
build(mid + 1, r, t[tar].r);
}
void update(int last, int &tar, int l, int r, int x)
{
tar = ++rootnum;
t[tar] = t[last];
t[tar].sum++;
if(l == r) return;
int mid = l + r >> 1;
if(x <= mid) update(t[last].l, t[tar].l, l, mid, x);
else update(t[last].r, t[tar].r, mid + 1, r, x);
}
int query(int lson, int rson, int lca, int lcaf, int l, int r, int k)
{
if(l == r) return l;
int mid = l + r >> 1;
int dif = t[t[lson].l].sum + t[t[rson].l].sum - t[t[lca].l].sum - t[t[lcaf].l].sum;
if(dif >= k) return query(t[lson].l, t[rson].l, t[lca].l, t[lcaf].l, l, mid, k);
else return query(t[lson].r, t[rson].r, t[lca].r, t[lcaf].r, mid + 1, r, k - dif);
}
void dfs(int now, int f)
{
depth[now] = depth[f] + 1;
fa[now][0] = f;
for(int i = 1; i < 20; i++)
fa[now][i] = fa[fa[now][i - 1]][i - 1];
update(root[f], root[now], 1, num, getid(v[now]));
for(int i = head[now]; ~i; i = e[i].nex)
{
int ep = e[i].ep;
if(ep == f) continue;
dfs(ep, now);
}
}
int lca(int x, int y)
{
if(depth[x] < depth[y]) swap(x, y);
int dif = depth[x] - depth[y];
for(int i = 0; i < 20; i++)
if((1 << i) & dif)
x = fa[x][i];
if(x == y) return x;
for(int i = 19; i >= 0; i--)
if(fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
int main()
{
int n, m;
cin >> n >> m;
init();
for(int i = 1; i <= n; i++)
scanf("%d", v + i), sorted[i] = v[i];
sort(sorted + 1, sorted + 1 + n);
num = unique(sorted + 1, sorted + 1 + n) - sorted - 1;
for(int i = 1; i <= n - 1; i++)
{
int x, y;
scanf("%d%d", &x, &y);
addedge(x, y);
}
dfs(1, 0);
for(int i = 1; i <= m; i++)
{
int a, b, k;
scanf("%d%d%d", &a, &b, &k);
int f = lca(a, b);
int id = query(root[a], root[b], root[f], root[fa[f][0]], 1, num, k);
printf("%d\n", sorted[id]);
}
}