HDU5405 Sometimes Naive
题意
给定一棵有 \(n\) 个点的树,每个点有点权 \(w_i\)。有 \(m\) 次操作:一种是修改某个点的点权;另一种是给定 \(u,v\),求 \(\left(\sum_{i=1}^n\sum_{j=1}^n f(i,j)\right)\bmod (10^9+7)\),其中当 \(i\) 到 \(j\) 的路径与 \(u\) 到 \(v\) 的路径有交集时 \(f(i,j)=w_iw_j\),否则 \(f(i,j)=0\)。
题解
直接算很麻烦,我们考虑反着算。
首先算出\(\sum_{i=1}^n\sum_{j=1}^nw_iw_j=(\sum_{i=1}^nw_i)^2\),然后我们把不满足条件的点对\(i,j\)的\(w_iw_j\)减去。
如果把 \(u\) 到 \(v\) 这条路径上的点全部删除,这棵树会变成一个森林。一个点对 \((i,j)\) 不满足条件,当且仅当 \(i\) 和 \(j\) 落在这个森林的同一棵树中,所以答案要减去森林中每一棵树的点权和的平方。
记 \(A_i\) 为以 \(i\) 为根的子树的点权和,\(B_i\) 为 \(i\) 的所有轻儿子 \(j\) 的 \(A_j^2\) 之和(代码中分别对应 Ws 和 xw)。
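修改一个点的点权,会改变它到根的路径上所有点的 \(A\) 值,以及这条路径上每条重链链顶的父亲的 \(B\) 值,用树链剖分加线段树维护即可;查询时沿 \(u\)、\(v\) 向上跳重链,累加路径上各点旁挂子树的贡献。思路不难,但细节较多。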
代码
/*
Author: CNYALI_LK
LANG: C++
PROG: 5405.cpp
Mail: cnyalilk@vip.qq.com
*/
#include<bits/stdc++.h>
// Debug-print helpers: DEBUG writes to stdout, Debug to stderr.
// __LINE__ expands to an int literal, so it must be printed with %d
// (passing it to %lld, as before, is undefined behaviour).
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define DEBUG printf("Passing [%s] in LINE %d\n",__FUNCTION__,__LINE__)
#define Debug debug("Passing [%s] in LINE %d\n",__FUNCTION__,__LINE__)
#define all(x) x.begin(),x.end()
using namespace std;
const double eps=1e-8;
const double PI=acos(-1.0);
typedef long long ll;
// chkmin/chkmax: replace a by b when b is smaller/larger; return 1 if a changed, 0 otherwise.
template<class T>ll chkmin(T &a,T b){return a>b?a=b,1:0;}
template<class T>ll chkmax(T &a,T b){return a<b?a=b,1:0;}
// Small generic helpers; min/max/abs are redirected to these templates by the macros below.
template<class T>T sqr(T a){return a*a;}
template<class T>T mmin(T a,T b){return a<b?a:b;}
template<class T>T mmax(T a,T b){return a>b?a:b;}
template<class T>T aabs(T a){return a<0?-a:a;}
#define min mmin
#define max mmax
#define abs aabs
// Read the next integer from stdin.
// Non-digit characters before the number are skipped; every '-' seen while skipping flips the sign.
// If end of input is reached before any digit, the whole program terminates with exit code 0.
ll read(){
    ll s=0,base=1;
    // getchar() returns an int: an unsigned-char value or EOF. Storing it in a char would make
    // EOF either never match (unsigned char) or collide with byte 0xFF (signed char), and would
    // pass negative values to isdigit(), which is undefined behaviour.
    int c;
    while(!isdigit(c=getchar())){
        if(c==EOF)exit(0);
        if(c=='-')base=-base;
    }
    while(isdigit(c)){
        s=s*10+(c^48);
        c=getchar();
    }
    return s*base;
}
char WritellBuffer[1024];
template<class T>void write(T a,char end){
ll cnt=0,fu=1;
if(a<0){putchar('-');fu=-1;}
do{WritellBuffer[++cnt]=fu*(a%10)+'0';a/=10;}while(a);
while(cnt){putchar(WritellBuffer[cnt]);--cnt;}
putchar(end);
}
// Adjacency lists stored as linked arcs: beg[u] is the most recently added arc out of u,
// lst[a] the next arc in the same list, to[a] the head of arc a, e the number of arcs used.
// siz[x] holds the subtree size of x (filled by dfs1).
ll beg[102424],to[233333],lst[233333],e,siz[102424];
// Insert the undirected edge (u,v) as two arcs, u->v first and then v->u,
// each pushed to the front of its source vertex's list.
void add(ll u,ll v){
    ll src[2]={u,v};
    ll dst[2]={v,u};
    for(int k=0;k<2;++k){
        ++e;
        to[e]=dst[k];
        lst[e]=beg[src[k]];
        beg[src[k]]=e;
    }
}
// Per-vertex arrays:
//   hvy[x] : heavy child of x (child with the largest subtree; 0 for a leaf), set by dfs1
//   Ws[x]  : sum of weights in x's subtree, mod p, set by dfs1
//   w[x]   : current weight of vertex x (read in main, overwritten by update operations)
//   xw[x]  : sum over the light children c of x of Ws[c]^2, mod p, set by dfs1
ll hvy[102424],Ws[102424],w[102424],xw[102424];
ll n,m;   // number of vertices and number of operations of the current test case
const ll p=1000000007;   // answer modulus, 1e9+7
// Post-order pass over the subtree of x (fa = parent of x, 0 for the root).
// Computes siz[x], the heavy child hvy[x], the subtree weight sum Ws[x] (mod p)
// and xw[x] = sum of Ws[c]^2 over the light children c of x (mod p).
void dfs1(ll x,ll fa){
    siz[x]=1;
    hvy[x]=0;
    Ws[x]=w[x];
    xw[x]=0;
    for(ll it=beg[x];it;it=lst[it]){
        const ll y=to[it];
        if(y==fa)continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(siz[hvy[x]]<siz[y])hvy[x]=y;
        Ws[x]=(Ws[x]+Ws[y])%p;
        xw[x]=(xw[x]+Ws[y]*Ws[y])%p;
    }
    // The loop summed the squares of all children; take the heavy child's square back out.
    const ll heavy=Ws[hvy[x]];
    xw[x]=(xw[x]-heavy*heavy%p+p)%p;
}
// Arrays filled by dfs2 (heavy-light numbering; the heavy child always gets the next position,
// so each heavy chain occupies a contiguous range of positions):
//   ys[x]  : DFS position of vertex x
//   fys[i] : vertex stored at position i (inverse of ys)
//   fa[i]  : position of the parent of the vertex at position i (0 for the root)
//   t      : running position counter
ll ys[102424],fys[102424],fa[102424],t;
//   sumw[i] : Ws of the vertex at position i (initial contents of segment tree s1)
//   squw[i] : xw of the vertex at position i (initial contents of segment tree s2)
//   dis[i]  : number of light edges on the path from the root to the vertex at position i
ll sumw[102424],squw[102424],dis[102424];
struct SegmentTree{
    // Range-add / range-sum segment tree over positions 1..n with every stored value kept in [0,p).
    // Keeping sum[] reduced matters: Sum() returns sum[x] unchanged for a node whose interval lies
    // entirely inside the query, so an unreduced node would leak out as a result >= p.
    ll sum[666666];  // sum[x]: total (mod p) of node x's interval, already including tag[x]
    ll tag[666666];  // tag[x]: pending per-position addition not yet pushed to x's children
    ll a[102424];    // initial values staged by make() and consumed by buildtree()
    #define mid ((l+r)>>1)
    // Build node x covering [l,r] from a[], clearing pending tags.
    void buildtree(ll x,ll l,ll r){
        tag[x]=0;
        if(l==r){
            sum[x]=(a[l]%p+p)%p;
        }else{
            buildtree(x<<1,l,mid);
            buildtree(x<<1|1,mid+1,r);
            // Reduce here; previously this was left unreduced and could grow to about n*p.
            sum[x]=(sum[x<<1]+sum[x<<1|1])%p;
        }
    }
    // Copy w[1..n] into a[] and build the tree over [1,n].
    void make(ll *w,ll n){
        for(ll i=1;i<=n;++i)a[i]=w[i];
        buildtree(1,1,n);
    }
    // Push node x's pending tag (node interval [l,r]) down to both children.
    void pushdown(ll x,ll l,ll r){
        tag[x<<1]+=tag[x];
        if(tag[x<<1]>=p)tag[x<<1]-=p;
        tag[x<<1|1]+=tag[x];
        if(tag[x<<1|1]>=p)tag[x<<1|1]-=p;
        sum[x<<1]+=tag[x]*(mid-l+1);
        sum[x<<1]%=p;
        sum[x<<1|1]+=tag[x]*(r-mid);
        sum[x<<1|1]%=p;
        tag[x]=0;
    }
    // Add w to every position of [lx,rx] that lies inside node x's interval [l,r].
    void add(ll x,ll l,ll r,ll lx,ll rx,ll w){
        // Callers pass negative deltas (change() does). Normalise so tags and sums stay in [0,p)
        // and the tag*(interval length) products in pushdown stay small.
        w%=p;
        if(w<0)w+=p;
        if(lx<=l&&r<=rx){
            tag[x]+=w;
            if(tag[x]>=p)tag[x]-=p;
            sum[x]=(sum[x]+(r-l+1)*w)%p;
            return;
        }
        if(r<lx||rx<l)return;
        pushdown(x,l,r);
        add(x<<1,l,mid,lx,rx,w);
        add(x<<1|1,mid+1,r,lx,rx,w);
        sum[x]=sum[x<<1]+sum[x<<1|1];
        if(sum[x]>=p)sum[x]-=p;
    }
    // Sum (mod p) of positions in [lx,rx] within node x's interval [l,r]; 0 if disjoint.
    ll Sum(ll x,ll l,ll r,ll lx,ll rx){
        if(lx<=l&&r<=rx){
            return sum[x];
        }
        if(r<lx||rx<l)return 0;
        pushdown(x,l,r);
        return (Sum(x<<1,l,mid,lx,rx)+Sum(x<<1|1,mid+1,r,lx,rx))%p;
    }
};
// s1 is built from sumw (per-position subtree weight sums Ws); s2 from squw (per-position xw).
SegmentTree s1,s2;
// top[i]: position of the head of the heavy chain that contains position i.
ll top[102424];
// Heavy-light numbering of the subtree of x (f = parent of x, 0 for the root).
// Gives x the next position and records its per-position data. It then visits the heavy
// child first, so it takes the following position and stays on x's chain. Every light child
// starts a new chain one light edge deeper.
// top[] and dis[] of a child's position are written before recursing into it.
void dfs2(ll x,ll f){
    const ll pos=++t;
    ys[x]=pos;
    fys[pos]=x;
    fa[pos]=ys[f];
    sumw[pos]=Ws[x];
    squw[pos]=xw[x];
    if(hvy[x]){
        const ll nxt=t+1;
        top[nxt]=top[pos];
        dis[nxt]=dis[pos];
        dfs2(hvy[x],x);
    }
    for(ll it=beg[x];it;it=lst[it]){
        const ll y=to[it];
        if(y==f||y==hvy[x])continue;
        const ll nxt=t+1;
        top[nxt]=nxt;
        dis[nxt]=dis[pos]+1;
        dfs2(y,x);
    }
}
// Add delta w to the weight of the vertex at position x.
// Every vertex on the root path gets w added to its subtree sum (s1, one chain segment at a time).
// For each chain head with a parent, that head is a light child of the parent, so the parent's s2
// entry changes from S^2 to (S+w)^2, i.e. by (2S+w)*w, where S is the head's old subtree sum.
void change(ll x,ll w){
    while(x){
        const ll head=top[x];
        const ll up=fa[head];
        if(up){
            const ll old=s1.Sum(1,1,n,head,head);
            s2.add(1,1,n,up,up,((old*2+w)*w%p+p)%p);
        }
        s1.add(1,1,n,head,x,w);
        x=up;
    }
}
// Reduce any value (negative, or far above p) into the canonical range [0,p).
static ll modNorm(ll v){
    v%=p;
    return v<0?v+p:v;
}
// Subtree weight sum (mod p) of the vertex at DFS position pos; position 0 ("no vertex") gives 0.
static ll posWs(ll pos){
    return modNorm(s1.Sum(1,1,n,pos,pos));
}
// Square (mod p) of the subtree sum of the heavy child of the vertex at position pos (0 for a leaf).
static ll heavySq(ll pos){
    ll s=posWs(ys[hvy[fys[pos]]]);
    return s*s%p;
}
// Answer a query for the path between DFS positions l and r:
//   total^2 - sum over the components left after deleting the path of (component weight sum)^2, mod p.
// The components hanging off a path vertex are its children that are not on the path. They are
// counted via its light-children squares (s2) plus its heavy child's square, and any on-path
// child's square is added back. The component above the LCA weighs total - Ws[lca].
// Every intermediate value is fully reduced, so the result lies in [0,p) even if a segment-tree
// sum comes back outside [0,p).
ll getans(ll l,ll r){
    const ll total=posWs(1);
    // The heavy children of both endpoints are off the path (corrected below if an endpoint is the LCA).
    ll cnt=modNorm(total*total%p-heavySq(l)-heavySq(r));
    while(top[l]!=top[r]){
        // Lift the endpoint whose chain lies more light edges below the root.
        if(dis[l]<dis[r])swap(l,r);
        const ll head=top[l];
        const ll up=fa[head];
        const ll hs=posWs(head);
        // Drop the light-children squares along [head,l]. head is a light child of up lying on the
        // path, so its square is added back. up's heavy child is off the path, so subtract its square.
        cnt=modNorm(cnt-modNorm(s2.Sum(1,1,n,head,l))+hs*hs%p-heavySq(up));
        l=up;
    }
    if(l>r)swap(l,r);
    // l is now the LCA. If l<r, its heavy child is on the path, which undoes the earlier subtraction.
    // If l==r, that square was subtracted once per side, and this leaves exactly one subtraction.
    const ll above=modNorm(total-posWs(l));
    cnt=modNorm(cnt+heavySq(l)-modNorm(s2.Sum(1,1,n,l,r))-above*above%p);
    return cnt;
}
int main(){
#ifdef cnyali_lk
freopen("5405.in","r",stdin);
freopen("5405.out","w",stdout);
#endif
while(n=read()){
m=read();
for(ll i=1;i<=n;++i)w[i]=read(),beg[i]=0;
e=0;
for(ll i=1;i<n;++i){
add(read(),read());
}
dfs1(1,0);
t=0;
top[1]=1;
dfs2(1,0);
s1.make(sumw,n);
s2.make(squw,n);
while(m){
--m;
if(read()^2){
ll u=read(),s=read();
change(ys[u],s-w[u]);
w[u]=s;
}
else{
write(getans(ys[read()],ys[read()]),'\n');
}
}
}
return 0;
}