// Tree path majority queries.
//
// Input : n q, then the colours a_1..a_n, then n-1 undirected edges, then q
//         queries "x y" (vertices are 1-based).
// Output: for every query, the colour occurring strictly more than half of
//         the time on the simple path x..y (both endpoints included), or -1.
//
// A candidate colour is found with mergeable Boyer-Moore votes stored in a
// binary-lifting table; it is then verified exactly by counting it on the
// path using per-colour sorted Euler entry/exit times (binary searches),
// which needs only O(n) memory in total.
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

constexpr int N = 250000 + 10;  // vertex tables hold indices 0..N-1
constexpr int LOG = 18;         // jumps 2^0..2^(LOG-1) sum to 2^LOG-1 >= N-1
static_assert((1 << LOG) > N, "LOG too small for N");

// Boyer-Moore vote {candidate colour id, surplus count}. If some colour is a
// strict majority of the summarised multiset, it is the candidate.
using Vote = std::array<int, 2>;

int n, q;
int colour_id[N];               // compressed colour of each vertex
std::vector<int> colour_value;  // compressed id -> original colour
std::vector<int> adj[N];        // adjacency lists
int up[N][LOG];                 // up[v][k]: 2^k-th ancestor of v, 0 above root
Vote vote[N][LOG];              // vote over v and its next 2^k - 1 ancestors
int depth[N];                   // root has depth 1; sentinel vertex 0 has 0
int tin[N];                     // Euler entry time of each vertex
std::vector<int> enter_at[N];   // per colour id: entry times, ascending
std::vector<int> leave_at[N];   // per colour id: exit times, non-decreasing

// Combines the votes of two disjoint multisets; a strict majority of the
// union (if any) survives as the resulting candidate.
Vote merge_votes(const Vote& a, const Vote& b) {
  if (a[0] == b[0]) return {a[0], a[1] + b[1]};
  if (a[1] > b[1]) return {a[0], a[1] - b[1]};
  return {b[0], b[1] - a[1]};
}

// Fills v's lifting rows, depth and entry time. The parent's rows must
// already be built (guaranteed by preorder). Vertex 0 acts as an empty
// sentinel: up[0][*] == 0, vote[0][*] == {0, 0}, depth[0] == 0.
void enter_vertex(int v, int parent, int& timer) {
  up[v][0] = parent;
  depth[v] = depth[parent] + 1;
  vote[v][0] = {colour_id[v], 1};
  for (int k = 1; k < LOG; ++k) {
    const int mid = up[v][k - 1];
    up[v][k] = up[mid][k - 1];
    vote[v][k] = merge_votes(vote[v][k - 1], vote[mid][k - 1]);
  }
  tin[v] = timer++;
  enter_at[colour_id[v]].push_back(tin[v]);
}

// Iterative DFS from root. An explicit stack avoids recursion depth n on
// path-shaped trees. Records entry/exit times per colour in sorted order.
void build_tables(int root) {
  int timer = 0;
  // Each frame: (vertex, index of the next neighbour to examine).
  std::vector<std::pair<int, std::size_t>> stack;
  enter_vertex(root, 0, timer);
  stack.emplace_back(root, 0);
  while (!stack.empty()) {
    const int v = stack.back().first;
    const std::size_t next = stack.back().second;
    if (next < adj[v].size()) {
      ++stack.back().second;
      const int child = adj[v][next];
      if (child != up[v][0]) {
        enter_vertex(child, v, timer);
        stack.emplace_back(child, 0);
      }
    } else {
      // Exit time = vertices entered so far; v's subtree is [tin[v], timer).
      leave_at[colour_id[v]].push_back(timer);
      stack.pop_back();
    }
  }
}

// Number of vertices with colour id c on the root..v path (v included):
// c-coloured vertices entered at or before tin[v], minus those already left.
int count_on_root_path(int v, int c) {
  const std::vector<int>& in = enter_at[c];
  const std::vector<int>& out = leave_at[c];
  const auto entered = std::upper_bound(in.begin(), in.end(), tin[v]) - in.begin();
  const auto left = std::upper_bound(out.begin(), out.end(), tin[v]) - out.begin();
  return static_cast<int>(entered - left);
}

// Returns {majority candidate colour id, lca(x, y)} for the path x..y.
std::array<int, 2> path_candidate(int x, int y) {
  if (depth[x] < depth[y]) std::swap(x, y);
  Vote acc{0, 0};
  // Absorb the 2^k vertices starting at v into acc, then jump v above them.
  auto climb = [&](int& v, int k) {
    acc = merge_votes(acc, vote[v][k]);
    v = up[v][k];
  };
  for (int k = LOG - 1; k >= 0; --k)
    if (depth[up[x][k]] >= depth[y]) climb(x, k);
  if (x != y) {
    for (int k = LOG - 1; k >= 0; --k)
      if (up[x][k] != up[y][k]) {
        climb(x, k);
        climb(y, k);
      }
    climb(x, 0);
    climb(y, 0);
  }
  // x is now the LCA, which has not been counted yet.
  acc = merge_votes(acc, vote[x][0]);
  return {acc[0], x};
}

}  // namespace

int main() {
  if (std::scanf("%d%d", &n, &q) != 2) return 0;
  if (n < 1 || n >= N) return 1;  // static tables cannot hold this tree

  std::vector<int> raw(n + 1, 0);
  for (int v = 1; v <= n; ++v)
    if (std::scanf("%d", &raw[v]) != 1) return 1;

  // Compress colours so per-colour storage is indexed 0..distinct-1,
  // regardless of the magnitude or sign of the input colour values.
  colour_value.assign(raw.begin() + 1, raw.end());
  std::sort(colour_value.begin(), colour_value.end());
  colour_value.erase(std::unique(colour_value.begin(), colour_value.end()),
                     colour_value.end());
  for (int v = 1; v <= n; ++v)
    colour_id[v] = static_cast<int>(
        std::lower_bound(colour_value.begin(), colour_value.end(), raw[v]) -
        colour_value.begin());

  for (int i = 1; i < n; ++i) {
    int x = 0;
    int y = 0;
    if (std::scanf("%d%d", &x, &y) != 2) return 1;
    adj[x].push_back(y);
    adj[y].push_back(x);
  }

  build_tables(1);

  while (q-- > 0) {
    int x = 0;
    int y = 0;
    if (std::scanf("%d%d", &x, &y) != 2) break;
    const auto [c, lca] = path_candidate(x, y);
    // Occurrences on x..y = root..x + root..y - root..lca - root..parent(lca).
    int cnt = count_on_root_path(x, c) + count_on_root_path(y, c) -
              count_on_root_path(lca, c);
    if (up[lca][0] != 0) cnt -= count_on_root_path(up[lca][0], c);
    const int path_len = depth[x] + depth[y] - 2 * depth[lca] + 1;
    std::printf("%d\n", cnt > path_len / 2 ? colour_value[c] : -1);
  }
  return 0;
}