这篇文章上次修改于 371 天前,可能其部分内容已经发生变化,如有疑问可询问作者。

ACM程序课算法笔记12——树状数组

引入

对于数据储存,我们通常有求区间的需求,因此产生了不同的用于储存数据的数据结构,例如求区间和问题,我们会使用的前缀和数组。

本次我们考虑的问题是:单点更新和区间和查询:

普通数组:查询区间和O(N),修改单点值O(1)

前缀和数组:查询区间和O(1),修改单点值O(N)

但对于一个树状数组,有二分的特点

树状数组:查询区间和O(logN),修改单点值O(logN)

存储方式

树状数组是一种数组的储存形式,利用了类二叉树的结构,

图片略

对于树状数组的思路,有树状数组C和原数组A

(0001) C[1] = A[1]

(0010) C[2] = A[1] + A[2]

(0011) C[3] = A[3]

(0100) C[4] = A[1] + A[2] + A[3] + A[4]

(0101) C[5] = A[5]

(0110) C[6] = A[5] + A[6]

(0111) C[7] = A[7]

(1000) C[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8]

对此储存方式,我们应当找到其中的规律,为了缩短篇幅,本笔记将线索写在了前面,也就是下标的二进制值,我们可以发现,结尾2的连续0数量次幂,也就是这个C元素包含原A中的个数,从本下标往前递减。

“结尾2的连续0数量次幂“这个描述只是为了方便观察规律,实际上个数也就等于最低位的1代表的值,例如1010,最低位1代表的值也就是2.

  • 最低位的获取方式也就是lowbit(x) = x & (-x) 利用了计算机补码的特性

单点更新

对于当我要修改原数组中的元素时,就需要修改不止一个树状数组里面的元素,要依据树的形状向上修改。例如,我要修改A[1],我们需要修改C[1],C[2],C[4],C[8]

如果修改A[3]我们就要修改C[3],C[4],C[8]

void update(int i, int val){ // 此处是让A[i] 增加val
	while(i <= n){
		C[i] += val;
		i += lowbit(i);
	}
}

从中可以发现单点更新的复杂度是O(LogN)

区间和查询

当要查询区间和时,对于树状数组,可以方便的先找到前缀和,例如找前缀和前7个

Sum[7] = C[4] + C[6] + C[7]

int preSum(int i){
	int res= 0;
	while (i > 0) {
	      res += c[i];
	      i -= lowbit(i);
	    }
	return res;
}

有以上代码,对此根据前缀和求区间和的方法,可以获得区间和为

int Sum(int a,int b) {
	return preSum(b) - preSum(a - 1);
}

因此,区间求和的复杂度也就是O(LogN)

例子:用树状数组求逆序对数量

基本思想:开一个数组c[n+1],初始化为0,记录前面数据的出现情况;当数据a出现时,就令c[a]=1。这样的话,若求a的逆序数,只需要算出在当前状态下c[a+1,n]中有多少个1,因为这些位置的数在a之前出现且比a大。

#include <iostream>
#include <vector>

using namespace std;

// 树状数组的实现
class FenwickTree {
private:
    vector<int> tree;

public:
    FenwickTree(int size) {
        tree.resize(size + 1, 0);
    }

    void update(int index, int delta) {
        while (index < tree.size()) {
            tree[index] += delta;
            index += index & (-index);  // 更新下一个节点
        }
    }

    int query(int index) {
        int sum = 0;
        while (index > 0) {
            sum += tree[index];
            index -= index & (-index);  // 计算前一个节点
        }
        return sum;
    }
};

int countInversions(vector<int>& nums) {
    int n = nums.size();
    FenwickTree fenwickTree(n);
    int inversions = 0;

    for (int i = n - 1; i >= 0; --i) {
        inversions += fenwickTree.query(nums[i] - 1); // 查询比当前数小的数的数量
        fenwickTree.update(nums[i], 1); // 将当前数加入树状数组
    }

    return inversions;
}

int main() {
    vector<int> nums = {3, 2, 1, 5, 4};
    cout << "Number of inversions: " << countInversions(nums) << endl;
    return 0;
}