树上启发式合并
归档于 2019-07-30 21:44
树上启发式合并,一种美妙的黑科技,可以用普通的优化让你$n^2$变成严格$n log$,解决一些类似(树上数颜色,树上查众数)这样的问题
首先你要知道暴力为什么是$n^2$的
以这个图为例

每次你从一个节点开始向下搜,你从1节点搜到3,搜完这个子树然后你需要把3存的col等信息删去再遍历另一个子树才是正确的
那么我们每次遍历这个节点一个子树,每次搜完这棵子树都要清空当前子树储存信息这样(最差)复杂度$n^2$
我们可以发现清空最后一个遍历的子树是没有意义的,那么我们人为把最后一个子树放到最后不就是最优的吗
所以,首先我们先找出来重链,轻链,对于轻链我们求出子树答案,再清除子树贡献,.然后求出重链上子树答案,不清除贡献.最后我们再算一遍子树对当前节点贡献即可
你可能会认为,这不就是一个简单的优化吗,怎么就是$n log$了
我不知道
它并没有优化最优复杂度而是避免了最差复杂度
以给一棵根为1的树,每次询问子树颜色种类数为例
代码大致如下
#include<bits/stdc++.h>
using namespace std;
#define ll int
#define r register
#define A 1001010
ll head[A],nxt[A],ver[A],size[A],col[A],cnt[A],ans[A],son[A];
ll tot=0,num,sum,nowson,n,m,xx,yy;
inline void add(ll x,ll y){
nxt[++tot]=head[x],head[x]=tot,ver[tot]=y;
}
inline ll read(){
ll f=1,x=0;char c=getchar();
while(!isdigit(c)){
if(c=='-') f=-1;
c=getchar();
}
while(isdigit(c))
x=(x<<1)+(x<<3)+(c^48),c=getchar();
return f*x;
}
void dfs(ll x,ll fa){
size[x]=1;
for(ll i=head[x];i;i=nxt[i]){
ll y=ver[i];
if(y==fa) continue;
dfs(y,x);
size[x]+=size[y];
if(size[son[x]]<size[y])
son[x]=y;
}
}
void cal(ll x,ll fa,ll val){
if(!cnt[col[x]]) ++sum;
cnt[col[x]]+=val;
for(ll i=head[x];i;i=nxt[i]){
ll y=ver[i];
if(y==fa||y==nowson) continue;
cal(y,x,val);
}
}
void dsu(ll x,ll fa,bool op){
for(ll i=head[x];i;i=nxt[i]){
ll y=ver[i];
if(y==fa||y==son[x])
continue;
dsu(y,x,0);
//从轻儿子出发
}
if(son[x])
dsu(son[x],x,1),nowson=son[x];
cal(x,fa,1);nowson=0;
ans[x]=sum;
if(!op){
cal(x,fa,-1);
sum=0;
}
}
int main(){
n=read();
for(ll i=1;i<=n-1;i++){
xx=read(),yy=read();
add(xx,yy),add(yy,xx);
}
for(ll i=1;i<=n;i++)
col[i]=read();
dfs(1,0);
dsu(1,0,1);
m=read();
for(ll i=1;i<=m;i++){
xx=read();
printf("%d\n",ans[xx]);
}
}
另一种打法
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
#define R register
#define ll long long
inline ll read(){
ll aa=0;R int bb=1;char cc=getchar();
while(cc<'0'||cc>'9')
{if(cc=='-')bb=-1;cc=getchar();}
while(cc>='0'&&cc<='9')
{aa=(aa<<1)+(aa<<3)+(cc^48);cc=getchar();}
return aa*bb;
}
const int N=1e5+3;
struct edge{
int v,last;
}ed[N<<1];
int first[N],tot;
inline void add(int x,int y)
{
ed[++tot].v=y;
ed[tot].last=first[x];
first[x]=tot;
}
int n,m,c[N],son[N],cnt[N],ans[N],siz[N];
void dfsi(int x,int fa)
{
siz[x]=1;
for(R int i=first[x],v;i;i=ed[i].last){
v=ed[i].v;
if(v==fa)continue;
dfsi(v,x);
siz[x]+=siz[v];
if(siz[v]>siz[son[x]])son[x]=v;
}
return;
}
int dfsj(int x,int fa,int bs,int kep)
{
if(kep){
for(R int i=first[x],v;i;i=ed[i].last){
v=ed[i].v;
if(v!=fa&&v!=son[x])
dfsj(v,x,0,1);
}
}
int res=0;
if(son[x])res+=dfsj(son[x],x,1,kep);
for(R int i=first[x],v;i;i=ed[i].last){
v=ed[i].v;
if(v!=fa&&v!=son[x])
res+=dfsj(v,x,0,0);
}
if(!cnt[c[x]])res++;
cnt[c[x]]++;
if(kep){
ans[x]=res;
if(!bs)memset(cnt,0,sizeof(cnt));
}
return res;
}
int main()
{
n=read();
for(R int i=1,x,y;i<n;++i){
x=read();y=read();
add(x,y);add(y,x);
}
for(R int i=1;i<=n;++i)c[i]=read();
dfsi(1,0); dfsj(1,0,1,1);
m=read();
for(R int i=1,x;i<=m;++i){
x=read();
printf("%d\n",ans[x]);
}
return 0;
}
虽然好像没什么区别
然后再看一道例题
有一棵 n 个节点的以 1 号节点为根的树,每个节点上有一个小桶,节点u上的小桶可以容纳${k_u}$
个小球,ljh每次可以给一个节点到根路径上的所有节点的小桶内放一个小球,如果这个节点的小桶满了则不能放进这个节点,最后多次询问某个节点值
首先暴力不能过
直接权值线段树+线段树合并很难维护,树链剖分也难以维护,但我们直接树上启发式合并+线段树暴力修改可以维护。
首先单纯线段树暴力修改可以维护,但这会超时。于是我们用启发式合并作为时间复杂度保证,莫名奇妙AC了这个题
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define A 1001010
ll head[A],nxt[A],ver[A],size[A],son[A],tong[A],col[A],getfa[A],isbigson[A],ans[A],al[A];
vector<pair<ll,ll> >v[A];
map<ll,ll>mp;
ll n,m,tot=0,Q,wwb=0;
struct tree{
ll l,r,f,x,t,c;
}tr[A];
void add(ll x,ll y){
nxt[++tot]=head[x],head[x]=tot,ver[tot]=y;
}
void prdfs(ll x,ll fa){
size[x]=v[x].size()+1;
for(ll i=head[x];i;i=nxt[i]){
ll y=ver[i];
if(y==fa) continue;
prdfs(y,x);
size[x]+=size[y];
if(size[son[x]]<size[y])
isbigson[son[x]]=0,son[x]=y,isbigson[y]=1;
}
}
void built(ll p,ll l,ll r){
tr[p].l=l,tr[p].r=r;
if(tr[p].l==tr[p].r){
return ;
}
ll mid=(l+r)>>1;
built(p<<1,l,mid);
built(p<<1|1,mid+1,r);
}
ll ask(ll p,ll pos){
if(pos>=tr[p].t) return tr[p].c;
return (pos>=tr[p<<1].t?tr[p<<1].c+ask(p<<1|1,pos-tr[p<<1].t):ask(p<<1,pos));
}
void insert(ll p,ll pos,ll t,ll c){
if(tr[p].l==tr[p].r)
{tr[p].t+=t;tr[p].c+=c;return;}
if(pos<=tr[p<<1].r)
insert(p<<1,pos,t,c);
else
insert(p<<1|1,pos,t,c);
tr[p].t=tr[p<<1].t+tr[p<<1|1].t;
tr[p].c=tr[p<<1].c+tr[p<<1|1].c;
}
void up(ll x,ll fa){
if(v[getfa[x]].size()<v[getfa[fa]].size()){
for(ll i=0;i<v[getfa[x]].size();i++)
v[getfa[fa]].push_back(v[getfa[x]][i]);
v[getfa[x]].clear();
getfa[x]=getfa[fa];
}
else{
for(ll i=0;i<v[getfa[fa]].size();i++)
v[getfa[x]].push_back(v[getfa[fa]][i]);
v[getfa[fa]].clear();
getfa[fa]=getfa[x];
}
}
void dfs(ll x,ll fa){
for(ll i=head[x];i;i=nxt[i]){
ll y=ver[i];
if(y==fa||y==son[x]) continue;
dfs(y,x);
}
if(son[x]) dfs(son[x],x);
for(ll i=0;i<v[getfa[x]].size();i++){
ll tim=v[getfa[x]][i].first,col=v[getfa[x]][i].second;
if(!al[col]) al[col]=tim,insert(1,tim,1,1);
else if(al[col]>tim){
insert(1,al[col],0,-1);
insert(1,tim,1,1);
al[col]=tim;
}
else insert(1,tim,1,0);
}
// printf("t=%lld tong=%lld\n",tr[1].t,tong[x]);
ans[x]=ask(1,min(tr[1].t,tong[x]));
if(son[x])
up(son[x],x);
if(!isbigson[x]){
for(ll i=0;i<v[getfa[x]].size();i++){
ll tim=v[getfa[x]][i].first,col=v[getfa[x]][i].second;
if(al[col]==tim)
insert(1,tim,-1,-1),al[col]=0;
else
insert(1,tim,-1,0);
}
up(x,fa);
}
/* for(ll i=1;i<=5;i++){
printf("ans=%lld ",ans[i]);
}
*//* cout<<endl;*/
}
int main(){
scanf("%lld",&n);
for(ll i=1;i<n;i++){
ll xx,yy;
scanf("%lld%lld",&xx,&yy);
add(xx,yy),add(yy,xx);
}
for(ll i=1;i<=n;i++){
scanf("%lld",&tong[i]);
getfa[i]=i;
}
prdfs(1,0);
scanf("%lld",&m);built(1,1,m);
for(ll i=1,x,c;i<=m;i++){
scanf("%lld%lld",&x,&c);
if(!mp[c])
mp[c]=++wwb;
//离散化
v[x].push_back(make_pair(i,mp[c]));
}
dfs(1,0);
scanf("%lld",&Q);
for(ll i=1,x;i<=Q;i++){
scanf("%lld",&x);
printf("%lld\n",ans[x]);
}
}