Motivation
Building on the ideas from Part 1, this part explores some lesser-known extensions and applications of the BIT.
This is by no means a finished post, and I'll keep updating whenever I find new BIT applications/extensions.
Variation 1: Range update, Point query
We want a data structure that supports the following operations on an input array \(a[1..n]\):
- range_update(L, R, v): increase every element of \(a[L..R]\) by \(v\)
- point_query(i): return the current \(a[i]\)
Recall from Part 1 that a normal BIT supports only point updates and prefix-sum queries. How can we implement range update and point query with those existing operations?
Consider the diff array d[1..n], where d[1] = a[1] and d[i] = a[i]-a[i-1] for i=2,...,n. Summing these definitions telescopes, so
a[i] = d[1] + d[2] + ... + d[i] for i=1,...,n.
The key idea is to use the basic point-update, prefix-query BIT on the diff array \(d[1..n]\) instead of the original array \(a[1..n]\).
Clearly, query(i) = sum(d[1..i]) = a[i].
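As a quick sanity check of the identity above, here is a minimal Python sketch. It uses 0-indexed lists rather than the 1-indexed arrays of the text, and `diff_array` is a hypothetical helper name introduced only for this illustration.

```python
from itertools import accumulate


def diff_array(a):
    """Return d with d[0] = a[0] and d[i] = a[i] - a[i-1] (0-indexed)."""
    if not a:
        return []
    return [a[0]] + [a[i] - a[i - 1] for i in range(1, len(a))]


a = [3, 1, 4, 1, 5]
d = diff_array(a)                # [3, -2, 3, -3, 4]
recovered = list(accumulate(d))  # prefix sums of d give back a
```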
Now consider the range update u(L,R,v), which sets \(a[i] \leftarrow a[i]+v\) for all \(L \leq i \leq R\). How does the diff array change with the update?
- No change to \(d[i]\) for any \(i < L\), \(L+1 \leq i \leq R\), or \(i > R+1\)
- \(d[L] \leftarrow d[L] + v\)
- \(d[R+1] \leftarrow d[R+1] - v\) (only when \(R < n\); otherwise there is no \(d[R+1]\) to change)
Thus we can achieve the range update u(L,R,v) with two point updates:
- update(L, v)
- update(R+1, -v)
See also my reference implementation in C++ if you're interested.
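For readers who prefer Python, the idea can be sketched as below. This is a minimal illustration, not the reference implementation: the class name `RangeUpdatePointQueryBIT` and its method names are assumptions made here, and the array is assumed to start as all zeros.

```python
class RangeUpdatePointQueryBIT:
    """Sketch: a plain BIT stored over the diff array d[1..n] (1-indexed)."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # assumes a[1..n] starts as all zeros

    def _update(self, i, v):
        # Standard BIT point update: d[i] += v.
        # If i == n + 1 the loop body never runs, so R == n needs no special case.
        while i <= self.n:
            self.tree[i] += v
            i += i & -i

    def range_update(self, left, right, v):
        # a[left..right] += v becomes two point updates on d.
        self._update(left, v)
        self._update(right + 1, -v)

    def point_query(self, i):
        # a[i] = d[1] + ... + d[i], i.e. a standard BIT prefix sum.
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & -i
        return total


bit = RangeUpdatePointQueryBIT(5)
bit.range_update(2, 4, 3)
bit.range_update(1, 2, 1)
values = [bit.point_query(i) for i in range(1, 6)]  # [1, 4, 3, 3, 0]
```

Both operations run in \(O(\log n)\), the same as the underlying BIT operations.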
Variation 2: Range update, Range query
We want a data structure that supports the following operations on an input array \(a[1..n]\):
- range_update(L, R, v): increase every element of \(a[L..R]\) by \(v\)
- range_query(i): return the prefix sum up to \(i\), i.e., \(\sum_{j=1}^i a[j]\)
We've seen above how range update can be implemented. Let's see how the prefix sum at each \(i\) changes after a range_update(L,R,v). Denote the change to sum(a[1..i]) by \(\delta(i)\):
- \(i < L\), \(\delta(i) = 0 = 0 * i - 0\)
- \(L \leq i \leq R\), \(\delta(i) = (i-L+1) * v = v * i - (L-1) * v\)
- \(i > R\), \(\delta(i) = (R-L+1) * v = 0 * i + R * v - (L-1) * v = 0 * i - ((L-1) * v - R * v)\)
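Note that in every case \(\delta(i) = b(i) * i - c(i)\). We can use 2 range update/point query BITs (Variation 1 above) to implement the range query and range update. The first BIT computes \(b(i)\):
- \(i < L\): \(b(i) = 0\)
- \(L \leq i \leq R\): \(b(i) = v\)
- \(i > R\): \(b(i) = 0\)
The second BIT computes \(c(i)\):
- \(i < L\): \(c(i) = 0\)
- \(L \leq i \leq R\): \(c(i) = (L-1) * v\)
- \(i > R\): \(c(i) = (L-1) * v - R * v\)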
range_update(L,R,v) can be implemented as:
# bit1
bit1.point_update(L, v)
bit1.point_update(R+1,-v)
# bit2
bit2.point_update(L, (L-1)*v)
bit2.point_update(R+1, -R*v)
Also,
range_query(i) = bit1.query(i) * i - bit2.query(i)
See also my reference implementation in C++ if you're interested.
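The two-BIT scheme above can be sketched in Python as follows. This is an illustration under assumptions, not the reference implementation: the class names `Fenwick` and `RangeUpdateRangeQueryBIT` are made up here, the array is assumed to start as all zeros, and `range_sum(left, right)` is an extra hypothetical helper built from two prefix queries.

```python
class Fenwick:
    """Plain 1-indexed BIT: point update, prefix-sum query."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)

    def update(self, i, v):
        while i <= self.n:  # a no-op when i == n + 1
            self.tree[i] += v
            i += i & -i

    def query(self, i):
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & -i
        return total


class RangeUpdateRangeQueryBIT:
    def __init__(self, n):
        self.bit1 = Fenwick(n)  # bit1.query(i) gives b(i)
        self.bit2 = Fenwick(n)  # bit2.query(i) gives c(i)

    def range_update(self, left, right, v):
        self.bit1.update(left, v)
        self.bit1.update(right + 1, -v)
        self.bit2.update(left, (left - 1) * v)
        self.bit2.update(right + 1, -right * v)

    def range_query(self, i):
        # Prefix sum a[1] + ... + a[i] = b(i) * i - c(i).
        return self.bit1.query(i) * i - self.bit2.query(i)

    def range_sum(self, left, right):
        # Hypothetical helper: sum of a[left..right] from two prefix sums.
        return self.range_query(right) - self.range_query(left - 1)


ds = RangeUpdateRangeQueryBIT(5)
ds.range_update(2, 4, 3)     # a = [0, 3, 3, 3, 0]
prefix3 = ds.range_query(3)  # 6
prefix5 = ds.range_query(5)  # 9
```

Because the contributions of successive updates add up linearly, the same two BITs handle any sequence of updates.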
2D BIT
Consider a 2D grid \(g\) of size \(H \times W\). We want to support the following operations on the grid in \(O(\log H * \log W)\) time:
- point_update(r,c,v): increase \(g[r][c]\) by \(v\)
- range_query(r1,r2,c1,c2): return the sum of the elements in \(g[r1..r2][c1..c2]\), where \(r1 \leq r2, c1 \leq c2\).
As you may expect, the point_update is a straightforward extension of the normal 1D BIT into 2D. The range query can be computed as:
range_query(r1,r2,c1,c2) = range_query(1,r2,1,c2) - range_query(1,r2,1,c1-1) - range_query(1,r1-1,1,c2) + range_query(1,r1-1,1,c1-1)
See also my reference implementation in C++.
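A minimal Python sketch of the 2D structure is given below. The class name `BIT2D` and the helper `prefix_query(r, c)` (the sum of \(g[1..r][1..c]\)) are assumptions introduced for this illustration, and the grid is assumed to start as all zeros.

```python
class BIT2D:
    """Sketch of a 1-indexed 2D BIT over an H x W grid of zeros."""

    def __init__(self, h, w):
        self.h, self.w = h, w
        self.tree = [[0] * (w + 1) for _ in range(h + 1)]

    def point_update(self, r, c, v):
        # Same walk as the 1D update, nested over rows and columns.
        i = r
        while i <= self.h:
            j = c
            while j <= self.w:
                self.tree[i][j] += v
                j += j & -j
            i += i & -i

    def prefix_query(self, r, c):
        # Sum of g[1..r][1..c]; returns 0 when r == 0 or c == 0.
        total = 0
        i = r
        while i > 0:
            j = c
            while j > 0:
                total += self.tree[i][j]
                j -= j & -j
            i -= i & -i
        return total

    def range_query(self, r1, r2, c1, c2):
        # Inclusion-exclusion over four prefix rectangles.
        return (self.prefix_query(r2, c2)
                - self.prefix_query(r2, c1 - 1)
                - self.prefix_query(r1 - 1, c2)
                + self.prefix_query(r1 - 1, c1 - 1))


grid = BIT2D(3, 4)
grid.point_update(1, 1, 5)
grid.point_update(2, 3, 2)
grid.point_update(3, 4, 7)
whole = grid.range_query(1, 3, 1, 4)   # 14
corner = grid.range_query(2, 3, 3, 4)  # 9
```

Each of the two nested loops visits at most a logarithmic number of indices, which gives the \(O(\log H * \log W)\) bound.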
Inversion Counting
Let \(a[1..n]\) be a permutation of \(\{1,2,\ldots,n\}\). An inversion is a pair \((i,j)\) with \(1 \leq i < j \leq n\) such that \(a[i] > a[j]\). The task is to count the number of inversions in \(a[1..n]\).
The key idea is to use the values \(a[i]\) as indices in the BIT.
Below is an overly commented Python snippet:
from typing import List

def count_inversion(a: List[int]) -> int:
    # initialize a BIT of size len(a): bit
    inv_cnt = 0
    for i, x in enumerate(a):
        # bit.query(x) returns how many of the values added so far are smaller than x.
        # Thus i - bit.query(x) is the number of values added so far that are bigger than x.
        inv_cnt += (i - bit.query(x))
        # It's important to update after the query.
        bit.update(x, 1)
    return inv_cnt
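The snippet above leaves the BIT itself unspecified. A self-contained version might look like the sketch below, where the `Fenwick` class is a minimal stand-in assumed here for the BIT from Part 1, and the function is named `count_inversions` only to keep it distinct from the snippet above.

```python
from typing import List


class Fenwick:
    """Plain 1-indexed BIT: point update, prefix-sum query."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [0] * (n + 1)

    def update(self, i: int, v: int) -> None:
        while i <= self.n:
            self.tree[i] += v
            i += i & -i

    def query(self, i: int) -> int:
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & -i
        return total


def count_inversions(a: List[int]) -> int:
    """Count inversions of a permutation of 1..n in O(n log n)."""
    bit = Fenwick(len(a))
    inv_cnt = 0
    for i, x in enumerate(a):
        # i values were added before x; bit.query(x) of them are smaller.
        inv_cnt += i - bit.query(x)
        bit.update(x, 1)
    return inv_cnt


example = count_inversions([2, 4, 1, 3, 5])  # 3: the pairs (2,1), (4,1), (4,3)
```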
Order Statistic Tree
We need a set-like data structure that supports the following operations:
- add(x): add \(x\) to the set
- delete(x): remove \(x\) from the set if it exists
- order_of(x): count the number of elements in the set that are strictly less than \(x\)
- find_by_order(k): return the \(k^{th}\) lowest element in the set
Constraint: every value \(x\) passed to these operations is in the range \([1..N]\) for some integer \(N\).
Again, the trick here is to use the element value as the index into the BIT array, making use of the constraint that all element values are in the range \([1..N]\).
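One way to realize this is sketched below in Python. Several details are assumptions of this sketch rather than something stated above: the class name `OrderStatisticSet`, the `present` array used to enforce set semantics (no duplicates), 1-based \(k\) in `find_by_order`, returning `None` when the set has fewer than \(k\) elements, and the binary-lifting descent over the BIT used to answer `find_by_order` in \(O(\log N)\).

```python
class OrderStatisticSet:
    """Sketch: a set over values 1..n backed by a BIT of 0/1 counts."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)
        self.present = [False] * (n + 1)  # assumption: tracks membership
        self.top = 1                      # largest power of two <= max(n, 1)
        while self.top * 2 <= n:
            self.top *= 2

    def _update(self, i, v):
        while i <= self.n:
            self.tree[i] += v
            i += i & -i

    def _prefix(self, i):
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & -i
        return total

    def add(self, x):
        if not self.present[x]:
            self.present[x] = True
            self._update(x, 1)

    def delete(self, x):
        if self.present[x]:
            self.present[x] = False
            self._update(x, -1)

    def order_of(self, x):
        # Number of elements strictly less than x.
        return self._prefix(x - 1)

    def find_by_order(self, k):
        # k-th lowest element, k is 1-based; None if it does not exist.
        if k < 1 or k > self._prefix(self.n):
            return None
        pos, step = 0, self.top
        while step > 0:
            nxt = pos + step
            # Keep the largest pos whose prefix count is still below k.
            if nxt <= self.n and self.tree[nxt] < k:
                pos = nxt
                k -= self.tree[nxt]
            step //= 2
        return pos + 1


s = OrderStatisticSet(10)
for x in (5, 2, 8, 2):        # the second 2 is ignored
    s.add(x)
rank_of_5 = s.order_of(5)     # 1, since only 2 is smaller
second = s.find_by_order(2)   # 5
```

A simpler alternative for `find_by_order` is a binary search over prefix sums, at the cost of \(O(\log^2 N)\) per call.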