树链剖分

树链剖分,看它名字就知道,是将树搞成一条条链来做

不说了,上网baidu

给个模板

说实在的,真的不好打,打好了还卡我long long

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include<vector>
long long n,r,q,mod;
#define N 100010
long long bit1[N],bit2[N];
void addbit(long long *b,long long pos,long long w)
{
while(pos<=n)
{
b[pos]=(b[pos]+w+mod)%mod;
pos+=pos&(-pos);
}
}
long long querybit(long long *b,long long pos)
{
long long sum=0;
while(pos)
{
sum=(sum+b[pos]+mod)%mod;
pos-=pos&(-pos);
}
return sum;
}
void add(long long l,long long r,long long w)
{
w%=mod;
addbit(bit1,l,w);
addbit(bit1,r+1,-w);
addbit(bit2,l,(l-1)*w%mod);
addbit(bit2,r+1,-w*r%mod);
}
long long query(long long l,long long r)
{
long long ans=(r*querybit(bit1,r)%mod-(l-1)*querybit(bit1,l-1)%mod-querybit(bit2,r)+querybit(bit2,l-1))%mod;
while(ans<0)ans+=mod;
return ans;
}
//树状数组维护区间,详见我上一篇博客
vector<long long>t[N];
long long f[N],s[N],w[N],ff[N],num[N],cnt,d[N];
long long a[N];
void dfs1(long long x)
{
w[x]=1;
fr(i,0,t[x].size()-1)
if(t[x][i]!=f[x])
{
f[t[x][i]]=x;
d[t[x][i]]=d[x]+1;
dfs1(t[x][i]);
w[x]+=w[t[x][i]];
}
}//得出每个点的父亲,深度,子树大小
void dfs2(long long x,long long st)
{
num[x]=++cnt;
ff[x]=st;
if(w[x]==1)
return;
fr(i,0,t[x].size()-1)
if(t[x][i]!=f[x]&&w[t[x][i]]>w[s[x]])
s[x]=t[x][i];
dfs2(s[x],st);
fr(i,0,t[x].size()-1)
if(t[x][i]!=f[x]&&t[x][i]!=s[x])
dfs2(t[x][i],t[x][i]);
}//得出每个点的编号,主儿子,深度最小的点且满足可以沿着那个点的主儿子走到它
#define swap(x,y) {long long k=x;x=y;y=k;}
int main()
{
n=read();
q=read();
r=read();
mod=read();
fr(i,1,n)
a[i]=read();
fr(i,1,n-1)
{
long long u=read(),v=read();
t[u].push_back(v);
t[v].push_back(u);
}
dfs1(r);
dfs2(r,r);
fr(i,1,n)
add(num[i],num[i],a[i]);
while(q--)
{
long long opt=read();
if(opt==1)
{
long long u=read(),v=read(),w=read();
while(ff[u]!=ff[v])
{
if(d[ff[u]]<d[ff[v]])
swap(u,v);
add(num[ff[u]],num[u],w);
u=f[ff[u]];
}
if(num[u]<num[v])
add(num[u],num[v],w);
else
add(num[v],num[u],w);
}
if(opt==2)
{
long long u=read(),v=read(),ans=0;
while(ff[u]!=ff[v])
{
if(d[ff[u]]<d[ff[v]])
swap(u,v);
ans=(ans+query(num[ff[u]],num[u]))%mod;
u=f[ff[u]];
}
if(num[u]<num[v])
ans=(ans+query(num[u],num[v]))%mod;
else
ans=(ans+query(num[v],num[u]))%mod;
printf("%lld\n",ans);
}
if(opt==3)
{
long long u=read(),ww=read();
add(num[u],num[u]+w[u]-1,ww);
}
if(opt==4)
{
long long u=read();
printf("%lld\n",query(num[u],num[u]+w[u]-1));
}
}
return 0;
}