BZOJ5210 最大连通子块和 <树链剖分+树形DP>

Problem

最大连通子块和


Description

给出一棵 个点,以 为根的有根树,点有点权。
要求支持如下两种操作:

  1. :将点 的点权改为
  2. :求以 为根的子树的最大连通子块和

一棵子树的最大连通子块和指该子树所有子连通块的点权和中的最大值(本题中子连通块包括空连通块,点权和为 )。

Input

第一行两个整数 , ,表示树的点数以及操作的数目。
第二行 个整数,第 个整数 表示第 个点的点权。
接下来的 行,每行两个整数 ,表示 之间有一条边相连。
接下来的 行,每行输入一个操作,含义如题目所述。
保证操作为 之一。

Output

对于每个 操作输出一行一个整数,表示询问子树的最大连通子块和。

Sample Input

1
2
3
4
5
6
7
8
9
10
5 4
3 -2 0 3 -1
1 2
1 3
4 2
2 5
Q 1
M 4 1
Q 1
Q 2

Sample Output

1
2
3
4
3
1

HINT

,任意时刻

Source

CQzhangyu&GXZlegend原创

标签:树链剖分 树上DP

Solution

经典树链剖分维护树上 。以下解法源自出题人CQzhangyu的博客GXZlegend的博客

首先考虑暴力 ,令 表示 子树中包含 的连通块权值和最大值,那么 。维护 表示 子树中连通块权值和最大值,则 。每次修改后重新 ,可做到

注意到每次修改后不是所有的 都变化。用树链剖分维护树上 ,可以每次不修改所有的 值。然而直接维护 的值不方便,因为递推式中的和式在线段树上不便于计算。于是引入 ,其中 儿子。那么重链上的转移就变为 。这其实就是最大连续子段和的 方式,可以用线段树维护带修改最大连续子段和。于是对于每次修改,向上跳重链,在线段树上每条重链的区域内维护最大连续子段和即可。注意这里“每条重链的区域”指的是链顶到链底的距离,而非括号序列。

再考虑如何维护 。注意到 每次都直接取最值,带修改后,其实是可删除堆的形式。对每个点维护可删除堆来维护轻儿子的 值最大值,对于重儿子则在线段树上维护。在线段树上修改后 时,用每个结点的 堆顶元素更新最大子段和,即可动态维护

查询时,直接向上跳重链,将该重链对应区间的最大子段和取出来打擂即可。

这样修改复杂度 ,查询复杂度 ,总复杂度

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <bits/stdc++.h>
#define MAX_N 200000
#define mid ((s+t)>>1)
using namespace std;
typedef long long lnt;
template <class T> inline void read(T &x) {
x = 0; int c = getchar(), f = 1;
for (; !isdigit(c); c = getchar()) if (c == 45) f = -1;
for (; isdigit(c); c = getchar()) (x *= 10) += f*(c-'0');
}
int n, m, a[MAX_N+5]; vector <int> G[MAX_N+5];
int ind, dep[MAX_N+5], fa[MAX_N+5], son[MAX_N+5];
int sz[MAX_N+5], top[MAX_N+5], into[MAX_N+5], outo[MAX_N+5];
lnt f[MAX_N+5], g[MAX_N+5], mxf[MAX_N+5];
struct node {
lnt s, mx, lmx, rmx;
node () {s = mx = lmx = rmx = 0LL;}
inline friend node operator + (const node &a, const node &b) {
node ret;
ret.s = a.s+b.s, ret.mx = max(max(a.mx, b.mx), a.rmx+b.lmx);
ret.lmx = max(a.lmx, a.s+b.lmx), ret.rmx = max(b.rmx, a.rmx+b.s);
return ret;
}
} tr[MAX_N<<2];
struct heap {
priority_queue <lnt> i, o;
inline void push(lnt x) {i.push(x);}
inline void pop(lnt x) {o.push(x);}
inline lnt top() {
while (!o.empty() && i.top() == o.top())
i.pop(), o.pop();
return i.top();
}
} h[MAX_N+5];
void addedge(int u, int v) {G[u].push_back(v), G[v].push_back(u);}
void DFS(int u) {
sz[u] = 1;
for (int i = 0; i < (int)G[u].size(); i++) {
int v = G[u][i]; if (v == fa[u]) continue;
dep[v] = dep[u]+1, fa[v] = u, DFS(v), sz[u] += sz[v];
if (!son[u] || sz[son[u]] < sz[v]) son[u] = v;
}
}
void DFS(int u, int tp) {
top[u] = tp, g[into[u] = ++ind] = a[u];
if (son[u]) DFS(son[u], tp);
for (int i = 0, v; i < (int)G[u].size(); i++)
if (((v = G[u][i]) ^ fa[u]) && (v ^ son[u]))
DFS(v, v), g[into[u]] += f[v], h[into[u]].push(mxf[v]);
outo[u] = son[u] ? outo[son[u]] : ind;
f[u] = max(f[son[u]]+g[into[u]], 0LL);
mxf[u] = max(mxf[son[u]], max(f[u], h[into[u]].top()));
}
void build(int v, int s, int t) {
if (s == t) {
tr[v].s = g[s], tr[v].mx = max(g[s], h[s].top());
tr[v].lmx = tr[v].rmx = max(g[s], 0LL); return;
}
build(v<<1, s, mid), build(v<<1|1, mid+1, t);
tr[v] = tr[v<<1]+tr[v<<1|1];
}
void modify(int v, int s, int t, int p) {
if (s == t) {
tr[v].s = g[s], tr[v].mx = max(g[s], h[s].top());
tr[v].lmx = tr[v].rmx = max(g[s], 0LL); return;
}
if (p <= mid) modify(v<<1, s, mid, p);
if (p >= mid+1) modify(v<<1|1, mid+1, t, p);
tr[v] = tr[v<<1]+tr[v<<1|1];
}
node query(int v, int s, int t, int l, int r) {
if (s >= l && t <= r) return tr[v]; node ret;
if (l <= mid) ret = ret+query(v<<1, s, mid, l, r);
if (r >= mid+1) ret = ret+query(v<<1|1, mid+1, t, l, r);
return ret;
}
void change(int u, int val) {
node pr, cr;
pr.lmx = g[into[u]];
cr.lmx = g[into[u]]-a[u]+val;
for (int i = 0; u; i++, u = fa[top[u]]) {
g[into[u]] += cr.lmx-pr.lmx;
if (i) h[into[u]].pop(pr.mx), h[into[u]].push(cr.mx);
pr = query(1, 1, n, into[top[u]], outo[u]);
modify(1, 1, n, into[u]);
cr = query(1, 1, n, into[top[u]], outo[u]);
}
}
int main() {
read(n), read(m);
for (int i = 1; i <= n; i++) read(a[i]), h[i].push(0);
for (int i = 1, u, v; i < n; i++) read(u), read(v), addedge(u, v);
DFS(1), DFS(1, 1), build(1, 1, n);
while (m--) {
char opt[2]; int x, y; scanf("%s", opt);
if (opt[0] == 'M') read(x), read(y), change(x, y), a[x] = y;
else read(x), printf("%lld\n", query(1, 1, n, into[x], outo[x]).mx);
}
return 0;
}
------------- Thanks For Reading -------------
0%