这篇文章上次修改于 197 天前,可能其部分内容已经发生变化,如有疑问可询问作者。

ACM程序课算法笔记17——树链剖分

定义

将一个树分割成若干条链即树链剖分

树链剖分有多种形式,有重链剖分和长链剖分。至于如何剖分,实际上就是让每个点都选择一条边,并且这些边形成多个链条,重链剖分和长链剖分的区别在于选择哪一条边。

重链剖分

定义

  • 重子节点:对于一个节点,有多个子节点,或者说多个子树,对于子树而言,size最大的子树的点我们称为重子节点
  • 轻子节点:相反,size最小的子树就是轻子节点
  • 重边:父节点连接到重子节点的边。
  • 轻边:父节点连接到轻子节点的边。
  • 重链:多条首尾相连的重边称为重链。

实现

实现分成两部分,两次的dfs完成。

第一次dfs获取:深度depth,父节点parent,子树大小size,重子节点hson。

void buildTree(int node, int dep){
	hson[node] = -1;
	size[node] = 1;
	depth[node] = dep;
	for(int son: adj[node]){
		if(depth[son] != 0) continue;
		buildTree(son, dep + 1);
		size[son] += size[son];
		parent[son] = node;
		if(hson[node] == -1 or size[son] > size[hson[node]]){
			hson[node] = son;
		}
	}
} 

第二次dfs获取:重链顶端top,新的一个id的双映射,ori2new,new2ori

void dfs2(int node, int topNode) {
  top[node] = topNode;
  cnt++;
  ori2new[node] = cnt;
  new2ori[cnt] = node;
  
  if (hson[node] == -1) return;
 
  dfs2(hson[node], topNode);
  
  for (int j: adj[node])
    if (j != parent[node] and j != hson[node]) 
			dfs2(j, j);
}

性质

树链剖分的获得的id的双映射,使每个点获得了一个新的id,这个新的id有个特点,就是在一条重链上的点的id都是连续的,这样可以更好的用区间管理来批量修改一个重链上的点。一颗子树内的 新id 是连续的。这个特征源自于这个顺序是dfs序。

树上的每个节点都属于或者仅属于一条重链,并且开头节点不一定为重子节点。

树链剖分求LCA

不断向上跳重链,当跳到同一条重链上时,深度较小的结点即为 LCA。

向上跳重链时需要先跳所在重链顶端深度较大的那个。

int lca(int u, int v) {
  while (top[u] != top[v]) {
    if (depth[top[u]] > depth[top[v]])
      u = parent[top[u]];
    else
      v = parent[top[v]];
  }
  return depth[u] > depth[v] ? v : u;
}

树上维护

树上维护的需求可能有:对X子树的所有节点进行修改,对AB路径进行修改。

1.对X子树的所有节点进行修改:

已知我们的新id是dfs序,所以对于X子树来说,这个子树的所有节点新编号都是连续的。而且dfs来说,一个子树的根节点是这个子树中新序号最小的。再抓住我们已经知道的size数组。这个区间就是

[x, x + size[x] - 1]

2.对AB路径进行修改:

树链剖分可以快速的对一个路径上的所有点进行一个处理。如何获取这个路径上所有的点呢。

首先,如何获取这个路径呢?在剖分之后,一条路径可以分成多个重链+两个部分重链。只需要获取重链即可。

首先找到LCA,然后实际上就是两个点到LCA路径加在一起。

从A点到LCA,实际上就是跳过重链和度过几个轻边,

从B点到LCA,其实也同理。

class Tree
{
public:
    Tree(int n) : num(n), cnt(0), adj(nullptr), hson(n + 1, -1), size(n + 1), depth(n + 1), parent(n + 1),
                  ori2New(n + 1), new2Ori(n + 1), top(n + 1), segTree(n + 1) {}
    void build(vector<int> *adj_, int root)
    {
        adj = adj_;
        parent[root] = -1;
        build1(root, 1);
        build2(root, root);
        adj = nullptr;
    }
    void loadData(const vector<int> & arr){
        vector<int> newArr(num + 1);
        for(int i = 1; i <= num; i++){
            newArr[ori2New[i]] = arr[i];
        }
        segTree.loadData(newArr);
    }

    void pathAdd(int p1, int p2, int addValue){
        while(top[p1] != top[p2]){
            if (depth[top[p1]] < depth[top[p2]]){
                swap(p1, p2);
            }
            segTree.updateRange(ori2New[top[p1]], ori2New[p1], addValue);
            p1 = parent[top[p1]];
        }
        if(depth[p1] > depth[p2]) swap(p1, p2);
        segTree.updateRange(ori2New[p1], ori2New[p2], addValue);
    }

    int pathSum(int p1, int p2){
        int result = 0;
        while(top[p1] != top[p2]){
            if (depth[top[p1]] < depth[top[p2]]){
                swap(p1, p2);
            }
            result += segTree.queryRange(ori2New[top[p1]], ori2New[p1]);
            p1 = parent[top[p1]];
        }
        if(depth[p1] > depth[p2]) swap(p1, p2);
        result += segTree.queryRange(ori2New[p1], ori2New[p2]);
        return result;
    }

    void treeAdd(int treeRoot, int addValue){
        segTree.updateRange(ori2New[treeRoot], ori2New[treeRoot] + size[treeRoot] - 1, addValue);
    }

    int treeSum(int treeRoot){
        return segTree.queryRange(ori2New[treeRoot], ori2New[treeRoot] + size[treeRoot] - 1);
    }

private:
    void build1(int node, int dep)
    {
        hson[node] = -1;
        size[node] = 1;
        depth[node] = dep;
        for (int son : adj[node])
        {
            if (depth[son] != 0)
                continue;
            build1(son, dep + 1);
            size[node] += size[son];
            parent[son] = node;
            if (hson[node] == -1 or size[son] > size[hson[node]])
            {
                hson[node] = son;
            }
        }
    }

    void build2(int node, int topNode)
    {
        top[node] = topNode;
        cnt++;
        ori2New[node] = cnt;
        new2Ori[cnt] = node;

        if (hson[node] == -1)
            return;

        build2(hson[node], topNode);

        for (int j : adj[node])
            if (j != parent[node] and j != hson[node])
                build2(j, j);
    }

    int num;
    int cnt;
    vector<int> *adj;
    vector<int> hson;
    vector<int> size;
    vector<int> depth;
    vector<int> parent;
    vector<int> ori2New;
    vector<int> new2Ori;
    vector<int> top;
    SegmentTree segTree;
};