这篇文章上次修改于 197 天前,可能其部分内容已经发生变化,如有疑问可询问作者。
ACM程序课算法笔记17——树链剖分
定义
将一个树分割成若干条链即树链剖分。
树链剖分有多种形式,有重链剖分和长链剖分。至于如何剖分,实际上就是让每个点都选择一条边,并且这些边形成多个链条,重链剖分和长链剖分的区别在于选择哪一条边。
重链剖分
定义
- 重子节点:对于一个节点,有多个子节点,或者说多个子树,对于子树而言,size最大的子树的点我们称为重子节点。
- 轻子节点:相反,size最小的子树就是轻子节点。
- 重边:父节点连接到重子节点的边。
- 轻边:父节点连接到轻子节点的边。
- 重链:多条首尾相连的重边称为重链。
实现
实现分成两部分,两次的dfs完成。
第一次dfs获取:深度depth,父节点parent,子树大小size,重子节点hson。
void buildTree(int node, int dep){
hson[node] = -1;
size[node] = 1;
depth[node] = dep;
for(int son: adj[node]){
if(depth[son] != 0) continue;
buildTree(son, dep + 1);
size[son] += size[son];
parent[son] = node;
if(hson[node] == -1 or size[son] > size[hson[node]]){
hson[node] = son;
}
}
}
第二次dfs获取:重链顶端top,新的一个id的双映射,ori2new,new2ori
void dfs2(int node, int topNode) {
top[node] = topNode;
cnt++;
ori2new[node] = cnt;
new2ori[cnt] = node;
if (hson[node] == -1) return;
dfs2(hson[node], topNode);
for (int j: adj[node])
if (j != parent[node] and j != hson[node])
dfs2(j, j);
}
性质
树链剖分的获得的id的双映射,使每个点获得了一个新的id,这个新的id有个特点,就是在一条重链上的点的id都是连续的,这样可以更好的用区间管理来批量修改一个重链上的点。一颗子树内的 新id 是连续的。这个特征源自于这个顺序是dfs序。
树上的每个节点都属于或者仅属于一条重链,并且开头节点不一定为重子节点。
树链剖分求LCA
不断向上跳重链,当跳到同一条重链上时,深度较小的结点即为 LCA。
向上跳重链时需要先跳所在重链顶端深度较大的那个。
int lca(int u, int v) {
while (top[u] != top[v]) {
if (depth[top[u]] > depth[top[v]])
u = parent[top[u]];
else
v = parent[top[v]];
}
return depth[u] > depth[v] ? v : u;
}
树上维护
树上维护的需求可能有:对X子树的所有节点进行修改,对AB路径进行修改。
1.对X子树的所有节点进行修改:
已知我们的新id是dfs序,所以对于X子树来说,这个子树的所有节点新编号都是连续的。而且dfs来说,一个子树的根节点是这个子树中新序号最小的。再抓住我们已经知道的size数组。这个区间就是
[x, x + size[x] - 1]
2.对AB路径进行修改:
树链剖分可以快速的对一个路径上的所有点进行一个处理。如何获取这个路径上所有的点呢。
首先,如何获取这个路径呢?在剖分之后,一条路径可以分成多个重链+两个部分重链。只需要获取重链即可。
首先找到LCA,然后实际上就是两个点到LCA路径加在一起。
从A点到LCA,实际上就是跳过重链和度过几个轻边,
从B点到LCA,其实也同理。
class Tree
{
public:
Tree(int n) : num(n), cnt(0), adj(nullptr), hson(n + 1, -1), size(n + 1), depth(n + 1), parent(n + 1),
ori2New(n + 1), new2Ori(n + 1), top(n + 1), segTree(n + 1) {}
void build(vector<int> *adj_, int root)
{
adj = adj_;
parent[root] = -1;
build1(root, 1);
build2(root, root);
adj = nullptr;
}
void loadData(const vector<int> & arr){
vector<int> newArr(num + 1);
for(int i = 1; i <= num; i++){
newArr[ori2New[i]] = arr[i];
}
segTree.loadData(newArr);
}
void pathAdd(int p1, int p2, int addValue){
while(top[p1] != top[p2]){
if (depth[top[p1]] < depth[top[p2]]){
swap(p1, p2);
}
segTree.updateRange(ori2New[top[p1]], ori2New[p1], addValue);
p1 = parent[top[p1]];
}
if(depth[p1] > depth[p2]) swap(p1, p2);
segTree.updateRange(ori2New[p1], ori2New[p2], addValue);
}
int pathSum(int p1, int p2){
int result = 0;
while(top[p1] != top[p2]){
if (depth[top[p1]] < depth[top[p2]]){
swap(p1, p2);
}
result += segTree.queryRange(ori2New[top[p1]], ori2New[p1]);
p1 = parent[top[p1]];
}
if(depth[p1] > depth[p2]) swap(p1, p2);
result += segTree.queryRange(ori2New[p1], ori2New[p2]);
return result;
}
void treeAdd(int treeRoot, int addValue){
segTree.updateRange(ori2New[treeRoot], ori2New[treeRoot] + size[treeRoot] - 1, addValue);
}
int treeSum(int treeRoot){
return segTree.queryRange(ori2New[treeRoot], ori2New[treeRoot] + size[treeRoot] - 1);
}
private:
void build1(int node, int dep)
{
hson[node] = -1;
size[node] = 1;
depth[node] = dep;
for (int son : adj[node])
{
if (depth[son] != 0)
continue;
build1(son, dep + 1);
size[node] += size[son];
parent[son] = node;
if (hson[node] == -1 or size[son] > size[hson[node]])
{
hson[node] = son;
}
}
}
void build2(int node, int topNode)
{
top[node] = topNode;
cnt++;
ori2New[node] = cnt;
new2Ori[cnt] = node;
if (hson[node] == -1)
return;
build2(hson[node], topNode);
for (int j : adj[node])
if (j != parent[node] and j != hson[node])
build2(j, j);
}
int num;
int cnt;
vector<int> *adj;
vector<int> hson;
vector<int> size;
vector<int> depth;
vector<int> parent;
vector<int> ori2New;
vector<int> new2Ori;
vector<int> top;
SegmentTree segTree;
};