Range Sum Query 2D – Mutable

  • My first, naive ideas were:
    • either make update() O(1) and sumRange() O(n): just store the values and loop over the range to compute the sum on every query
    • or make update() O(n) and sumRange() O(1): keep another array prefix[] where prefix[k] is the sum of the first k elements, so sumRange(i, j) = prefix[j+1] - prefix[i], but every update has to rewrite all the prefix sums after the changed element
  • Unsurprisingly, my naive approach got TLE (Time Limit Exceeded).
  • Time to learn more advanced data structures: Segment Tree and Binary Indexed Tree!
    • In a Segment Tree, the leaf nodes are the elements of the input array
    • each internal node represents the merge (here, the sum) of the leaf nodes beneath it
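A minimal sketch of a sum Segment Tree, assuming a 1D int array and using the array layout described in the next bullets (children of node i at 2*i+1 and 2*i+2). The class name SegmentTreeSketch is hypothetical; it only illustrates the leaf/internal-node idea, with update() and sumRange() both running in O(log n).

```java
// Hypothetical class name; minimal recursive sum Segment Tree over a 1D array.
class SegmentTreeSketch {
    private final int[] tree; // tree[node] = sum of the segment that node covers
    private final int n;

    SegmentTreeSketch(int[] nums) {
        n = nums.length;
        tree = new int[Math.max(1, 4 * n)]; // 4*n slots always suffice
        if (n > 0) build(nums, 0, 0, n - 1);
    }

    // leaves hold the input elements; each internal node sums its two children
    private void build(int[] nums, int node, int lo, int hi) {
        if (lo == hi) {
            tree[node] = nums[lo];
            return;
        }
        int mid = lo + (hi - lo) / 2;
        build(nums, 2 * node + 1, lo, mid);     // left child
        build(nums, 2 * node + 2, mid + 1, hi); // right child
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    void update(int index, int val) {
        if (n > 0) update(0, 0, n - 1, index, val);
    }

    private void update(int node, int lo, int hi, int index, int val) {
        if (lo == hi) {
            tree[node] = val;
            return;
        }
        int mid = lo + (hi - lo) / 2;
        if (index <= mid) update(2 * node + 1, lo, mid, index, val);
        else update(2 * node + 2, mid + 1, hi, index, val);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2]; // re-merge on the way up
    }

    int sumRange(int i, int j) {
        return n == 0 ? 0 : query(0, 0, n - 1, i, j);
    }

    private int query(int node, int lo, int hi, int i, int j) {
        if (j < lo || hi < i) return 0;            // segment entirely outside [i, j]
        if (i <= lo && hi <= j) return tree[node]; // segment entirely inside [i, j]
        int mid = lo + (hi - lo) / 2;
        return query(2 * node + 1, lo, mid, i, j) + query(2 * node + 2, mid + 1, hi, i, j);
    }
}
```

For example, new SegmentTreeSketch(new int[]{1, 3, 5}).sumRange(0, 2) gives 9.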

 

  • Segment Trees can be stored in an array: for the node at index i, the left child is at index 2*i+1, the right child is at 2*i+2, and the parent is at floor((i-1)/2).
  • In mathematics and computer science, the floor and ceiling functions map a real number to the largest previous or the smallest following integer, respectively. More precisely, floor(x) is the largest integer less than or equal to x, and ceiling(x) is the smallest integer greater than or equal to x.
  • This diagram is very helpful: http://www.geeksforgeeks.org/binary-indexed-tree-or-fenwick-tree-2/
    • Every node has two numbers: one is the index of this node, the other is the value at this node
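    • One important formula gives the parent index of the node at index i: parent(i) = i - (i & (-i)). The parentheses matter: in Java, & has lower precedence than -, so i-i&(-i) would evaluate as (i-i) & (-i), which is always 0
    • i & (-i) isolates the lowest set bit of i (-i is the two's complement of i), and subtracting it from i clears that bit. The following code snippet helps: given i = 12 (binary 1100), i & (-i) = 4 (binary 100), so parent(i) = 8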
      
      
    • System.out.println(Integer.toBinaryString(12));
      System.out.println(Integer.toBinaryString(-12));
      System.out.println(Integer.toBinaryString(12&(-12)));
      System.out.println(12 - (12&(-12)));
    • Output:
      1100
      11111111111111111111111111110100
      100
      8
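    • In the opposite direction (used by update()), the formula becomes parent(i) = i + (i & (-i)), which adds the lowest set bit instead of clearing it; for i = 12 this gives 16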
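To see both traversal directions together, here is a minimal 1D Binary Indexed Tree sketch, assuming a 0-indexed input array mapped onto a 1-indexed tree. The class and helper names (FenwickSketch, parentForQuery, parentForUpdate) are hypothetical; the 2D solution from the posts below nests these same two kinds of loops.

```java
// Hypothetical class name; 1D Binary Indexed Tree (Fenwick Tree) sketch.
class FenwickSketch {
    private final int[] nums; // current values, needed to compute deltas
    private final int[] tree; // 1-indexed; tree[0] is unused

    FenwickSketch(int[] input) {
        nums = new int[input.length];
        tree = new int[input.length + 1];
        for (int k = 0; k < input.length; k++) update(k, input[k]);
    }

    // query direction: clear the lowest set bit, e.g. 12 (1100) -> 8 (1000)
    static int parentForQuery(int i) {
        return i - (i & (-i));
    }

    // update direction: add the lowest set bit, e.g. 12 (1100) -> 16 (10000)
    static int parentForUpdate(int i) {
        return i + (i & (-i));
    }

    void update(int index, int val) {
        int delta = val - nums[index];
        nums[index] = val;
        for (int i = index + 1; i < tree.length; i = parentForUpdate(i)) {
            tree[i] += delta;
        }
    }

    // sum of the first count elements, i.e. input[0..count-1]
    int prefixSum(int count) {
        int sum = 0;
        for (int i = count; i > 0; i = parentForQuery(i)) {
            sum += tree[i];
        }
        return sum;
    }

    int sumRange(int i, int j) {
        return prefixSum(j + 1) - prefixSum(i);
    }
}
```

Both update() and sumRange() touch at most about log2(n) tree nodes, since each step clears or adds one bit.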

I looked at these two posts:

  1. https://discuss.leetcode.com/topic/30343/java-2d-binary-indexed-tree-solution-clean-and-short-17ms
  2. https://discuss.leetcode.com/topic/30935/share-my-java-2-d-binary-indexed-tree-solution

Key notes for me:

  • The biggest time cost is constructing the binary indexed tree: O(m*n*log m*log n). There is a nested for loop over all m*n cells, and inside it we call update(), which is O(log m*log n). However, the tree only needs to be built once, and after that both update() and sumRegion() run in O(log m*log n) time.
  • we need an additional array nums[][] holding a copy of matrix[][], so that update() can compute the delta from the old value
  • in the constructor, we build the tree by calling update(i, j, matrix[i][j]) for every cell
  • inside update(i, j, newVal): we compute the delta between newVal and the current value stored at nums[i][j], and then
    • store the newVal into this position
    • increment this node and all its parent nodes by delta
      • so, inside update(i, j, newVal), the for loops start from i+1 and j+1, because the tree is 1-indexed (index 0 is unused)
  • how the sum is computed from the constructed tree:
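    • sum(row, col) returns the sum of the rectangle from (0, 0) to (row-1, col-1). sumRegion() takes sum(row2+1, col2+1), subtracts sum(row1, col2+1) and sum(row2+1, col1), and adds back sum(row1, col1), because the rectangle ending at (row1-1, col1-1) is the overlap that got subtracted twice.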
class NumMatrix {
    int[][] nums; // copy of the current matrix values
    int[][] tree; // 1-indexed 2D Binary Indexed Tree
    int height;
    int width;

    public NumMatrix(int[][] matrix) {
        if (matrix.length == 0 || matrix[0].length == 0) return;
        height = matrix.length;
        width = matrix[0].length;
        this.nums = new int[height][width];
        this.tree = new int[height + 1][width + 1];
        // build the tree by inserting every cell once: O(m*n*log m*log n)
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                update(i, j, matrix[i][j]);
            }
        }
    }

    public void update(int rowIndex, int colIndex, int newVal) {
        if (height == 0 || width == 0) return;
        int delta = newVal - nums[rowIndex][colIndex];
        nums[rowIndex][colIndex] = newVal;
        // tree indices are 1-based, so start at rowIndex+1 and colIndex+1
        for (int i = rowIndex + 1; i <= height; i += i & (-i)) {
            for (int j = colIndex + 1; j <= width; j += j & (-j)) {
                tree[i][j] += delta; // just adding delta to the previous value is enough
            }
        }
    }

    public int sumRegion(int row1, int col1, int row2, int col2) {
        if (height == 0 || width == 0) return 0;
        // inclusion-exclusion: add back the (row1, col1) block that was subtracted twice
        return sum(row2 + 1, col2 + 1) + sum(row1, col1) - sum(row1, col2 + 1) - sum(row2 + 1, col1);
    }

    // sum of the rectangle from (0, 0) to (row-1, col-1)
    private int sum(int row, int col) {
        int sum = 0;
        for (int i = row; i > 0; i -= i & (-i)) {
            for (int j = col; j > 0; j -= j & (-j)) {
                sum += tree[i][j];
            }
        }
        return sum;
    }
}
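For comparison, the inclusion-exclusion in sumRegion() is the same one a static 2D prefix-sum table uses. The sketch below (the class name PrefixSum2DSketch is hypothetical) is the 2D version of the second naive idea from the top: queries are O(1), but an update would require rebuilding pre[][] in O(m*n), which is why the mutable problem needs the BIT.

```java
// Hypothetical class name: a static (immutable) 2D prefix-sum table.
class PrefixSum2DSketch {
    private final int[][] pre; // pre[r][c] = sum of matrix[0..r-1][0..c-1]

    PrefixSum2DSketch(int[][] matrix) {
        int h = matrix.length, w = h == 0 ? 0 : matrix[0].length;
        pre = new int[h + 1][w + 1];
        for (int r = 1; r <= h; r++) {
            for (int c = 1; c <= w; c++) {
                // same inclusion-exclusion, applied while building the table
                pre[r][c] = matrix[r - 1][c - 1] + pre[r - 1][c] + pre[r][c - 1] - pre[r - 1][c - 1];
            }
        }
    }

    // mirrors NumMatrix.sumRegion(): add back the top-left block subtracted twice
    int sumRegion(int row1, int col1, int row2, int col2) {
        return pre[row2 + 1][col2 + 1] - pre[row1][col2 + 1] - pre[row2 + 1][col1] + pre[row1][col1];
    }
}
```

For the matrix {{1, 2}, {3, 4}}, sumRegion(1, 1, 1, 1) gives 4 and sumRegion(0, 0, 1, 1) gives 10.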
