官方题解
2026-03-30 16:13:23
发布于:浙江
题目大意
给定一棵以 1 号节点为根的树,每个节点有一个整数能量值。要求通过任意重排每个节点的子树顺序,生成字典序最小的先序遍历序列。
题解思路
对任意节点 u,设它的“最稳定子序列” S(u) 是在允许重排 u 的孩子后,u 的子树先序序列中字典序最小的那条。显然整棵树的答案就是 S(1)。
如果 u 的孩子子树最稳定子序列分别是 S(v1),S(v2),…,那么 S(u) 一定形如:
S(u) = [a_u] + S(按某顺序排列的孩子 1) + S(孩子 2) + …
为了让拼接后的整体字典序最小,只需把孩子按各自的 S(v) 的字典序从小到大排序后依次拼接即可:因为先序序列在进入第一个孩子子树前已经固定为 a_u,之后最先产生差异的位置一定落在“第一个不同孩子的子序列”里,所以把最小的子序列放在最前是最优的。
因此问题变成:需要能高效比较两个子树的 S(u) 与 S(v)。直接把 S(u) 展开成数组会导致总长度平方级。
做法是把每个 S(u) 当成一条“序列对象”,用一棵隐式 Treap 维护:
- 长度 len;
- 两个模数下的多项式哈希值;
- 支持 O(log n) 取第 k 个元素、取前缀哈希。
比较两个序列时,用二分 + 前缀哈希求它们的最长公共前缀 LCP,再比较下一位元素即可完成字典序比较。对子树排序后,用 Treap 进行序列拼接(merge)即可得到父节点的序列对象。
整体流程:先把树以 1 为根,得到父子关系与后序顺序;按后序从叶到根处理,每个节点对孩子排序并合并生成自己的序列对象,同时把排序结果存下来。最后按存下来的孩子顺序做一次先序遍历输出能量值。
复杂度:设比较一次序列为 O(log^2 n),则总复杂度约为 O((n + 总排序比较次数)·log^2 n),在本题规模内可通过;空间 O(n)。
参考代码
#include <bits/stdc++.h>
using namespace std;
static const long long MOD1 = 1000000007LL;
static const long long MOD2 = 1000000009LL;
static const long long BASE = 911382323LL;
long long add_mod(long long a, long long b, long long mod) {
a += b;
if (a >= mod) a -= mod;
return a;
}
long long mul_mod(long long a, long long b, long long mod) {
return (long long)((__int128)a * b % mod);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
vector<long long> a(n + 1);
for (int i = 1; i <= n; i++) cin >> a[i];
vector<vector<int>> adj(n + 1);
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
vector<long long> all;
all.reserve(n);
for (int i = 1; i <= n; i++) all.push_back(a[i]);
sort(all.begin(), all.end());
all.erase(unique(all.begin(), all.end()), all.end());
vector<int> tok(n + 1, 0);
for (int i = 1; i <= n; i++) {
int id = (int)(lower_bound(all.begin(), all.end(), a[i]) - all.begin());
tok[i] = id + 1;
}
vector<int> parent(n + 1, 0);
vector<vector<int>> child(n + 1);
vector<int> order;
order.reserve(n);
vector<int> st;
st.reserve(n);
st.push_back(1);
parent[1] = -1;
while (!st.empty()) {
int u = st.back();
st.pop_back();
order.push_back(u);
for (int v : adj[u]) {
if (v == parent[u]) continue;
parent[v] = u;
st.push_back(v);
}
}
for (int v = 2; v <= n; v++) {
child[parent[v]].push_back(v);
}
vector<long long> pw1(n + 2, 1), pw2(n + 2, 1);
for (int i = 1; i <= n + 1; i++) {
pw1[i] = mul_mod(pw1[i - 1], BASE, MOD1);
pw2[i] = mul_mod(pw2[i - 1], BASE, MOD2);
}
vector<int> L(n + 2, 0), R(n + 2, 0), SZ(n + 2, 0), VAL(n + 2, 0);
vector<unsigned int> PRI(n + 2, 0);
vector<long long> H1(n + 2, 0), H2(n + 2, 0);
int node_cnt = 0;
auto rng_next = [&]() -> unsigned int {
static unsigned long long x = 88172645463325252ull;
x ^= x << 7;
x ^= x >> 9;
return (unsigned int)(x & 0xffffffffu);
};
auto pull = [&](int x) {
int ls = L[x], rs = R[x];
int lenL = ls ? SZ[ls] : 0;
int lenR = rs ? SZ[rs] : 0;
SZ[x] = lenL + 1 + lenR;
long long left1 = ls ? H1[ls] : 0;
long long right1 = rs ? H1[rs] : 0;
long long left2 = ls ? H2[ls] : 0;
long long right2 = rs ? H2[rs] : 0;
long long t1 = mul_mod(left1, pw1[1 + lenR], MOD1);
t1 = add_mod(t1, mul_mod((long long)VAL[x], pw1[lenR], MOD1), MOD1);
t1 = add_mod(t1, right1, MOD1);
H1[x] = t1;
long long t2 = mul_mod(left2, pw2[1 + lenR], MOD2);
t2 = add_mod(t2, mul_mod((long long)VAL[x], pw2[lenR], MOD2), MOD2);
t2 = add_mod(t2, right2, MOD2);
H2[x] = t2;
};
function<int(int,int)> merge = [&](int a, int b) -> int {
if (!a) return b;
if (!b) return a;
if (PRI[a] < PRI[b]) {
R[a] = merge(R[a], b);
pull(a);
return a;
} else {
L[b] = merge(a, L[b]);
pull(b);
return b;
}
};
auto new_node = [&](int v) -> int {
++node_cnt;
VAL[node_cnt] = v;
PRI[node_cnt] = rng_next();
L[node_cnt] = R[node_cnt] = 0;
SZ[node_cnt] = 1;
H1[node_cnt] = v % MOD1;
H2[node_cnt] = v % MOD2;
return node_cnt;
};
function<pair<long long,long long>(int,int)> prefix_hash = [&](int x, int len) -> pair<long long,long long> {
if (!x || len <= 0) return {0, 0};
int ls = L[x], rs = R[x];
int lenL = ls ? SZ[ls] : 0;
if (len <= lenL) return prefix_hash(ls, len);
if (len == lenL + 1) {
long long left1 = ls ? H1[ls] : 0;
long long left2 = ls ? H2[ls] : 0;
long long h1 = add_mod(mul_mod(left1, pw1[1], MOD1), VAL[x] % MOD1, MOD1);
long long h2 = add_mod(mul_mod(left2, pw2[1], MOD2), VAL[x] % MOD2, MOD2);
return {h1, h2};
}
int takeR = len - lenL - 1;
auto rightPref = prefix_hash(rs, takeR);
long long left1 = ls ? H1[ls] : 0;
long long left2 = ls ? H2[ls] : 0;
long long h1 = mul_mod(left1, pw1[1 + takeR], MOD1);
h1 = add_mod(h1, mul_mod((long long)VAL[x], pw1[takeR], MOD1), MOD1);
h1 = add_mod(h1, rightPref.first, MOD1);
long long h2 = mul_mod(left2, pw2[1 + takeR], MOD2);
h2 = add_mod(h2, mul_mod((long long)VAL[x], pw2[takeR], MOD2), MOD2);
h2 = add_mod(h2, rightPref.second, MOD2);
return {h1, h2};
};
function<int(int,int)> kth = [&](int x, int k) -> int {
int ls = L[x];
int lenL = ls ? SZ[ls] : 0;
if (k < lenL) return kth(ls, k);
if (k == lenL) return VAL[x];
return kth(R[x], k - lenL - 1);
};
auto less_seq = [&](int ra, int rb, int firstA, int firstB) -> bool {
if (firstA != firstB) return firstA < firstB;
int la = ra ? SZ[ra] : 0;
int lb = rb ? SZ[rb] : 0;
int lim = min(la, lb);
int lo = 0, hi = lim;
while (lo < hi) {
int mid = (lo + hi + 1) >> 1;
auto ha = prefix_hash(ra, mid);
auto hb = prefix_hash(rb, mid);
if (ha == hb) lo = mid;
else hi = mid - 1;
}
int lcp = lo;
if (lcp == lim) return la < lb;
int va = kth(ra, lcp);
int vb = kth(rb, lcp);
return va < vb;
};
vector<int> root_id(n + 1, 0);
vector<int> first_tok(n + 1, 0);
for (int i = 1; i <= n; i++) first_tok[i] = tok[i];
for (int idx = n - 1; idx >= 0; idx--) {
int u = order[idx];
auto &ch = child[u];
if (!ch.empty()) {
sort(ch.begin(), ch.end(), [&](int x, int y) {
return less_seq(root_id[x], root_id[y], first_tok[x], first_tok[y]);
});
}
int cur = new_node(tok[u]);
for (int v : ch) {
cur = merge(cur, root_id[v]);
}
root_id[u] = cur;
}
vector<long long> ans;
ans.reserve(n);
vector<int> stack2;
stack2.reserve(n);
stack2.push_back(1);
while (!stack2.empty()) {
int u = stack2.back();
stack2.pop_back();
ans.push_back(a[u]);
auto &ch = child[u];
for (int i = (int)ch.size() - 1; i >= 0; i--) stack2.push_back(ch[i]);
}
for (int i = 0; i < n; i++) {
if (i) cout << ' ';
cout << ans[i];
}
cout << "\n";
return 0;
}
全部评论 1
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 2e5 + 10;
vector<int> adj[MAXN];
int a[MAXN], n;
bool cmp(int x, int y) {
if (a[x] != a[y]) return a[x] < a[y];
// 双指针遍历孩子列表,比较字典序
auto &vx = adj[x], &vy = adj[y];
int i = 0, j = 0, sx = vx.size(), sy = vy.size();
while (i < sx && j < sy) {
if (cmp(vx[i], vy[j])) return true;
if (cmp(vy[j], vx[i])) return false;
i++, j++;
}
return sx < sy;
}
void dfs_sort(int u, int fa) {
vector<int> sons;
for (int v : adj[u]) {
if (v != fa) {
dfs_sort(v, u);
sons.push_back(v);
}
}
sort(sons.begin(), sons.end(), cmp);
adj[u] = move(sons);
}
void dfs_print(int u) {
cout << a[u] << ' ';
for (int v : adj[u]) {
dfs_print(v);
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; ++i)cin >> a[i];
for (int i = 1; i < n; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs_sort(1, -1);
dfs_print(1);
return 0;
}我47行就搞定了<->昨天 来自 浙江
0












有帮助,赞一个