题意

给定一棵$n$个点的树,每个点有点权$w_i$。有两种操作:修改某个点的点权;或者给定$u,v$,求$\left(\sum_{i=1}^n\sum_{j=1}^nf(i,j)\right) \bmod (10^9+7)$,其中当$i$到$j$的路径和$u$到$v$的路径有交集时$f(i,j)=w_iw_j$,否则$f(i,j)=0$。

题解

直接算很麻烦,我们考虑反着算。

首先算出$\sum_{i=1}^n\sum_{j=1}^nw_iw_j=(\sum_{i=1}^nw_i)^2$,然后我们把不满足条件的点对$i,j$的$w_iw_j$减去。

显然,如果把$u$到$v$这条路径上的点全部删除,这棵树会变为一个森林。对于所有不满足条件的点对$(i,j)$,$i$到$j$的路径一定完整地落在森林中的某一棵树里。所以答案要减去森林中每一棵树点权和的平方。

我们可以记录$A_i$表示以$i$为根的子树中所有点的点权之和,$B_i$表示$i$的所有轻儿子$j$的$A_j^2$之和,即$B_i=\sum_{j\text{ 是 }i\text{ 的轻儿子}}A_j^2$(代码中分别对应`Ws`和`xw`)。

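然后维护就很显然了:修改点权时,沿重链向上对经过的每一段重链上所有点的$A$做区间加;每经过一条轻边(从链顶跳到它的父亲),父亲的$B$中$A_{top}^2$一项会变为$(A_{top}+\Delta)^2$,需要单点修改。但是细节有点多。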

代码

/*
Author: CNYALI_LK
LANG: C++
PROG: 5405.cpp
Mail: cnyalilk@vip.qq.com
*/
#include<bits/stdc++.h>
#define debug(...) fprintf(stderr,__VA_ARGS__)
// Trace macros. __LINE__ is an int, so it must be printed with %d
// (passing an int where %lld expects a long long is undefined behaviour).
#define DEBUG printf("Passing [%s] in LINE %d\n",__FUNCTION__,__LINE__)
#define Debug debug("Passing [%s] in LINE %d\n",__FUNCTION__,__LINE__)
#define all(x) x.begin(),x.end()
using namespace std;
const double eps=1e-8;
const double PI=acos(-1.0);
typedef long long ll;
// chkmin/chkmax: overwrite a with b if b is smaller/larger; return 1 if a changed, 0 otherwise.
template<class T>ll chkmin(T &a,T b){return a>b?a=b,1:0;}
template<class T>ll chkmax(T &a,T b){return a<b?a=b,1:0;}
// sqr: a*a with no modular reduction -- the caller is responsible for avoiding overflow.
template<class T>T sqr(T a){return a*a;}
// Macro-free min/max/abs replacements (the names min/max/abs are remapped to them below).
template<class T>T mmin(T a,T b){return a<b?a:b;}
template<class T>T mmax(T a,T b){return a>b?a:b;}
template<class T>T aabs(T a){return a<0?-a:a;}
#define min mmin
#define max mmax
#define abs aabs
// Reads the next integer from stdin, skipping every non-digit character before it.
// Each '-' seen while skipping flips the sign of the result.
// If EOF is reached before any digit, the program terminates with exit(0);
// main() relies on this to stop at the end of the input.
ll read(){
    ll s=0,base=1;
    // Must be int, not char: getchar() returns int so that EOF is distinguishable
    // from every byte (with an unsigned char type, c==EOF would never be true), and
    // isdigit() is only defined for EOF or values representable as unsigned char.
    int c;
    while(!isdigit(c=getchar())){
        if(c==EOF)exit(0);
        if(c=='-')base=-base;
    }
    while(isdigit(c)){
        s=s*10+(c^48);
        c=getchar();
    }
    return s*base;
}
char WritellBuffer[1024];
template<class T>void write(T a,char end){
    ll cnt=0,fu=1;
    if(a<0){putchar('-');fu=-1;}
    do{WritellBuffer[++cnt]=fu*(a%10)+'0';a/=10;}while(a);
    while(cnt){putchar(WritellBuffer[cnt]);--cnt;}
    putchar(end);
}
ll beg[102424],to[233333],lst[233333],e,siz[102424];
void add(ll u,ll v){
    to[++e]=v;
    lst[e]=beg[u];
    beg[u]=e;
    to[++e]=u;
    lst[e]=beg[v];
    beg[v]=e;
}
// Per-vertex arrays, indexed by vertex id:
//   w[u]   - current weight of u (read in main, overwritten on updates)
//   Ws[u]  - subtree weight sum of u mod p, as computed by dfs1
//   hvy[u] - heavy child of u, 0 if u has no children (dfs1)
//   xw[u]  - sum of Ws[j]^2 over the light children j of u, mod p (dfs1)
ll w[102424];
ll Ws[102424];
ll hvy[102424];
ll xw[102424];

ll n;                    // number of vertices in the current test case
ll m;                    // number of remaining operations in the current test case
const ll p=1000000007;   // modulus of every stored sum and of the answer
// First DFS of the heavy-light decomposition: visits the subtree of x (whose parent is fa)
// and fills, for every vertex in it:
//   siz[x] - number of vertices in the subtree
//   hvy[x] - child with the largest subtree (the earliest one in adjacency order on ties),
//            0 if x has no children
//   Ws[x]  - sum of w over the subtree, mod p
//   xw[x]  - sum of Ws[j]^2 over the light (non-heavy) children j, mod p
void dfs1(ll x,ll fa){
    siz[x]=1;
    hvy[x]=0;
    Ws[x]=w[x];
    xw[x]=0;
    for(ll edge=beg[x];edge!=0;edge=lst[edge]){
        const ll child=to[edge];
        if(child==fa)continue;
        dfs1(child,x);
        siz[x]+=siz[child];
        if(siz[hvy[x]]<siz[child])hvy[x]=child;
        Ws[x]=(Ws[x]+Ws[child])%p;
        xw[x]=(xw[x]+Ws[child]*Ws[child])%p;
    }
    // xw[x] now holds the squares of *all* children; remove the heavy child's term.
    // For a leaf hvy[x]==0, and Ws[0] is never written here (presumably 0 from
    // zero-initialisation), so nothing is removed.
    const ll heavySq=Ws[hvy[x]]*Ws[hvy[x]]%p;
    xw[x]=xw[x]<heavySq ? xw[x]-heavySq+p : xw[x]-heavySq;
}
// Arrays filled by dfs2. Apart from ys, they are indexed by DFS position (1..n):
//   ys[u]   - position of vertex u (indexed by vertex)
//   fys[k]  - vertex at position k
//   fa[k]   - position of the parent of the vertex at position k (0 for the root)
//   t       - position counter used while dfs2 runs
//   sumw[k] - subtree sum Ws of the vertex at position k (initial values of s1)
//   squw[k] - light-children square sum xw of the vertex at position k (initial values of s2)
//   dis[k]  - number of light edges between the root and position k
ll ys[102424];
ll fys[102424];
ll fa[102424];
ll t;
ll sumw[102424];
ll squw[102424];
ll dis[102424];
struct SegmentTree{
    ll sum[666666];
    ll tag[666666];
    ll a[102424];
#define mid ((l+r)>>1)
    void buildtree(ll x,ll l,ll r){
        tag[x]=0;
        if(l==r)sum[x]=a[l];
        else{
            buildtree(x<<1,l,mid);
            buildtree(x<<1|1,mid+1,r);
            sum[x]=sum[x<<1]+sum[x<<1|1];
        }
    }
    void make(ll *w,ll n){
        for(ll i=1;i<=n;++i)a[i]=w[i];
        buildtree(1,1,n);
    }

    void pushdown(ll x,ll l,ll r){
        tag[x<<1]+=tag[x];
        if(tag[x<<1]>=p)tag[x<<1]-=p;
        tag[x<<1|1]+=tag[x];
        if(tag[x<<1|1]>=p)tag[x<<1|1]-=p;
        sum[x<<1]+=tag[x]*(mid-l+1);
        sum[x<<1]%=p;
        sum[x<<1|1]+=tag[x]*(r-mid);
        sum[x<<1|1]%=p;
        tag[x]=0;
    }
    void add(ll x,ll l,ll r,ll lx,ll rx,ll w){
        if(lx<=l&&r<=rx){
            tag[x]+=w;
            if(tag[x]>=p)tag[x]-=p;
            sum[x]=(sum[x]+(r-l+1)*w)%p;
            return;
        }
        if(r<lx||rx<l)return;
        pushdown(x,l,r);
        add(x<<1,l,mid,lx,rx,w);
        add(x<<1|1,mid+1,r,lx,rx,w);
        sum[x]=sum[x<<1]+sum[x<<1|1];
        if(sum[x]>=p)sum[x]-=p;
    }
    ll Sum(ll x,ll l,ll r,ll lx,ll rx){
        if(lx<=l&&r<=rx){
            return sum[x];
        }
        if(r<lx||rx<l)return 0;
        pushdown(x,l,r);
        return (Sum(x<<1,l,mid,lx,rx)+Sum(x<<1|1,mid+1,r,lx,rx))%p;
    }
};
// Segment trees indexed by dfs2 position:
//   s1 - built from sumw: subtree weight sum of the vertex at each position
//   s2 - built from squw: squared light-child subtree sums of the vertex at each position
SegmentTree s1;
SegmentTree s2;
// top[k]: position of the first vertex of the heavy chain that contains position k.
ll top[102424];
// Second DFS of the heavy-light decomposition: gives x (parent f) the next position
// and recurses into the heavy child first, so each heavy chain gets consecutive
// positions. For the new position k it records the vertex (fys), the parent's
// position (fa) and the initial tree values sumw=Ws[x], squw=xw[x]. Before
// descending into a child it fills in that child's chain top and light-edge depth.
// For the root, the caller sets top[1] (main sets top[1]=1). dis[1] is never
// written and is presumably 0 from zero-initialisation.
void dfs2(ll x,ll f){
    ++t;
    const ll k=t;
    ys[x]=k;
    fys[k]=x;
    fa[k]=ys[f];
    sumw[k]=Ws[x];
    squw[k]=xw[x];
    if(hvy[x]!=0){
        // The heavy child continues x's chain at the very next position.
        top[k+1]=top[k];
        dis[k+1]=dis[k];
        dfs2(hvy[x],x);
    }
    for(ll i=beg[x];i;i=lst[i]){
        const ll y=to[i];
        if(y==f||y==hvy[x])continue;
        // A light child starts a new chain at the next free position.
        const ll pos=t+1;
        top[pos]=pos;
        dis[pos]=dis[k]+1;
        dfs2(y,x);
    }
}
// Adds w to the weight of the vertex at position x and updates both trees.
// s1 gets +w on the subtree sum of x and of every ancestor, applied as one range
// add per heavy chain. Whenever a chain top is the light child of some vertex,
// that vertex's s2 entry contains S^2 (S = the chain top's subtree sum). It becomes
// (S+w)^2, i.e. it changes by (2S+w)*w.
void change(ll x,ll w){
    // Reduce the delta into [0,p) first. Every stored value is mod p anyway, and a
    // raw negative delta would otherwise push negative sums and tags into s1.
    w%=p;
    if(w<0)w+=p;
    while(x){
        const ll head=top[x];
        const ll up=fa[head];
        if(up){
            // S must be read before s1 is updated below.
            const ll S=s1.Sum(1,1,n,head,head);
            s2.add(1,1,n,up,up,((S*2+w)%p*w%p+p)%p);
        }
        s1.add(1,1,n,head,x,w);
        x=up;
    }
}
// Reduces v into [0,p).
static ll normMod(ll v){
    v%=p;
    return v<0?v+p:v;
}
// s1 value at position pos, reduced into [0,p). pos==0 (a missing heavy child) gives 0.
static ll pointSum(ll pos){
    return normMod(s1.Sum(1,1,n,pos,pos));
}
// Square of the s1 value at position pos, mod p.
static ll sqAt(ll pos){
    return sqr(pointSum(pos))%p;
}
// Returns sum_{i,j} f(i,j) mod p for the path between positions l and r.
// Starts from total^2 and subtracts the squared weight sum of every component
// left after deleting the path:
//  - light children of path vertices: s2 summed over each chain segment of the path,
//    adding back the square of each chain top that is itself on the path;
//  - heavy children off the path: subtracted for both endpoints and for every vertex
//    the path enters through a light child; the upper end of the final chain segment
//    (the LCA) then gets its heavy child added back once;
//  - everything outside the LCA's subtree: (total - Ws[lca])^2.
ll getans(ll l,ll r){
    const ll total=pointSum(1);
    ll cnt=normMod(total*total%p-sqAt(ys[hvy[fys[l]]])-sqAt(ys[hvy[fys[r]]]));
    while(top[l]!=top[r]){
        // Jump from the endpoint whose chain is more light edges away from the root.
        if(dis[l]<dis[r])swap(l,r);
        const ll head=top[l];
        const ll parent=fa[head];
        cnt-=s2.Sum(1,1,n,head,l);
        // head is on the path, but its square sits in parent's s2 entry: add it back.
        cnt+=sqAt(head);
        // The path reaches parent via a light child, so parent's heavy child may be off it.
        cnt-=sqAt(ys[hvy[fys[parent]]]);
        cnt=normMod(cnt);
        l=parent;
    }
    if(l>r)swap(l,r);
    cnt+=sqAt(ys[hvy[fys[l]]]);
    cnt-=s2.Sum(1,1,n,l,r);
    const ll outside=normMod(total-pointSum(l));
    cnt-=outside*outside%p;
    return normMod(cnt);
}
int main(){
#ifdef cnyali_lk
    freopen("5405.in","r",stdin);
    freopen("5405.out","w",stdout);
#endif
    while(n=read()){
        m=read();
        for(ll i=1;i<=n;++i)w[i]=read(),beg[i]=0;
        e=0;
        for(ll i=1;i<n;++i){
            add(read(),read());
        }
        dfs1(1,0);
        t=0;
        top[1]=1;
        dfs2(1,0);
        s1.make(sumw,n);
        s2.make(squw,n);
        while(m){
            --m;
            if(read()^2){

                ll u=read(),s=read();

                change(ys[u],s-w[u]);
                w[u]=s;
            }
            else{
                write(getans(ys[read()],ys[read()]),'\n');
            }
        }

    }
    return 0;
}

标签: 树链剖分

添加新评论