Range Sum Query 2D Mutable – Java

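We can use a 2D Binary Indexed Tree (Fenwick Tree) to efficiently compute the sum of any rectangular region of a mutable 2D matrix. Both a point update and a rectangle-sum query take O(log m · log n) time for an m × n matrix.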
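The quantities the class maintains can be written out explicitly. The notation below is introduced here for illustration and is not from the original: a is the 0-based input matrix, lowbit(x) is the lowest set bit of x (what `x & (-x)` computes in Java), and S(i, j) is the prefix sum that `calcSum(i, j)` returns. Each 1-based tree cell `bit[r][c]` holds the sum of a block of lowbit(r) rows by lowbit(c) columns ending at (r, c), and a rectangle sum follows from four prefix sums by inclusion-exclusion:

```latex
\operatorname{lowbit}(x) = x \mathbin{\&} (-x)

\text{bit}[r][c] = \sum_{x = r - \operatorname{lowbit}(r) + 1}^{r} \;\; \sum_{y = c - \operatorname{lowbit}(c) + 1}^{c} a_{x-1,\, y-1}

S(i, j) = \sum_{x=0}^{i} \sum_{y=0}^{j} a_{x,y}, \qquad S(-1, j) = S(i, -1) = 0

\text{sum}(r_1, c_1, r_2, c_2) = S(r_2, c_2) - S(r_1 - 1, c_2) - S(r_2, c_1 - 1) + S(r_1 - 1, c_1 - 1)
```

For the 3 × 3 matrix holding 1..9 and the rectangle (1, 1)..(2, 2): S(2, 2) = 45, S(0, 2) = 6, S(2, 0) = 12, S(0, 0) = 1, so the sum is 45 − 6 − 12 + 1 = 28.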

```java
public class RangeSum {
    int[][] bit; // 1-based Fenwick table of size (rows + 1) x (cols + 1)
    int[][] inp; // current matrix values, used to compute update deltas

    public RangeSum(int[][] inp) {
        // Keeps a reference to the caller's array, so update() also modifies it.
        this.inp = inp;
        this.bit = new int[inp.length + 1][inp[0].length + 1];
        // Build the tree by adding each cell once: O(rows * cols * log rows * log cols).
        for (int i = 0; i < inp.length; i++)
            for (int j = 0; j < inp[0].length; j++)
                add(i, j, inp[i][j]);
    }

    // Adds delta to cell (i, j) (0-based) in every tree node that covers it.
    public void add(int i, int j, int delta) {
        // Move to the next covering node by adding the lowest set bit, row & (-row).
        // https://www.quora.com/In-programming-what-does-n-n-return
        // (A 0-based variant of the tree steps with i = i | (i + 1) instead.)
        for (int row = i + 1; row <= this.inp.length; row += row & (-row))
            for (int col = j + 1; col <= this.inp[0].length; col += col & (-col))
                this.bit[row][col] += delta;
    }

    // Sets cell (i, j) to value by adding the difference from its old value.
    public void update(int i, int j, int value) {
        int delta = value - this.inp[i][j];
        this.inp[i][j] = value;
        this.add(i, j, delta);
    }

    // Returns the sum of inp[0..i][0..j]; returns 0 when i or j is -1.
    public int calcSum(int i, int j) {
        int sum = 0;
        // Move to the next disjoint block by removing the lowest set bit.
        // (A 0-based variant steps with i = (i & (i + 1)) - 1 instead.)
        for (int row = i + 1; row > 0; row -= row & (-row))
            for (int col = j + 1; col > 0; col -= col & (-col))
                sum += this.bit[row][col];
        return sum;
    }

    // Sum of the rectangle (row1, col1)..(row2, col2), inclusive, via inclusion-exclusion.
    public int getSum(int row1, int col1, int row2, int col2) {
        return calcSum(row2, col2) - calcSum(row1 - 1, col2)
                - calcSum(row2, col1 - 1) + calcSum(row1 - 1, col1 - 1);
    }

    public static void main(String[] args) {
        int[][] arr = {
                        {1, 2, 3},
                        {4, 5, 6},
                        {7, 8, 9}
                      };
        RangeSum rangeSum = new RangeSum(arr);
        System.out.println(rangeSum.getSum(1, 1, 2, 2)); // 5 + 6 + 8 + 9 = 28
        rangeSum.update(1, 1, 10);                       // 5 -> 10
        System.out.println(rangeSum.getSum(0, 0, 2, 2)); // 45 - 5 + 10 = 50
        System.out.println(rangeSum.getSum(1, 1, 2, 2)); // 10 + 6 + 8 + 9 = 33
    }
}
```
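Each of the two nested loops in `add` and `calcSum` follows the ordinary one-dimensional Fenwick walk along its own axis. The small program below (the class `LowbitWalk` and its methods are illustrative names, not part of the original) prints the 1-based indices visited along a single axis of length 8: an update climbs by adding the lowest set bit, and a prefix query descends by removing it.

```java
import java.util.ArrayList;
import java.util.List;

// Shows which 1-based Fenwick indices an update and a prefix query touch on one axis.
public class LowbitWalk {
    // Lowest set bit of x, e.g. lowbit(6) == 2 (relies on two's complement negation).
    public static int lowbit(int x) {
        return x & (-x);
    }

    // Indices visited by an update: add the lowest set bit until past n. index must be >= 1.
    public static List<Integer> updatePath(int index, int n) {
        List<Integer> path = new ArrayList<>();
        for (int i = index; i <= n; i += lowbit(i)) path.add(i);
        return path;
    }

    // Indices visited by a prefix query: remove the lowest set bit until reaching 0.
    public static List<Integer> queryPath(int index) {
        List<Integer> path = new ArrayList<>();
        for (int i = index; i > 0; i -= lowbit(i)) path.add(i);
        return path;
    }

    public static void main(String[] args) {
        System.out.println(updatePath(3, 8)); // [3, 4, 8]
        System.out.println(queryPath(7));     // [7, 6, 4]
    }
}
```

Both paths have at most about log2(n) + 1 steps, and in the 2D tree the column walk runs once per row index, which is where the O(log m · log n) bound comes from.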

An alternative approach is to store the 2D matrix in a flat 1D array, mapping each cell (x, y) to a single index in row-major order. For example:

...

```java
public int getIndex(int x, int y) {
    // Row-major order; the flat array holds this.totalColumns * this.totalRows elements.
    return this.totalColumns * x + y;
}
```
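One way to combine the flattening idea with the Fenwick tree is to store the (rows + 1) × (cols + 1) tree table itself in a single `int[]`. The following is a hypothetical sketch, not a reconstruction of the elided code: the class name `FlatRangeSum` and the choice `totalRows = rows + 1`, `totalColumns = cols + 1` are assumptions, made so that `getIndex` keeps exactly the row-major formula `totalColumns * x + y`. The index walks and the inclusion-exclusion query are the same as in `RangeSum`.

```java
// Hypothetical sketch: a 2D Fenwick tree whose (rows + 1) x (cols + 1) table is
// stored in one flat int[] using row-major indexing.
public final class FlatRangeSum {
    private final int totalRows;    // rows + 1 (row 0 is unused padding)
    private final int totalColumns; // cols + 1 (column 0 is unused padding)
    private final int[] bit;        // flattened Fenwick table
    private final int[][] values;   // private copy of the current matrix values

    public FlatRangeSum(int[][] inp) {
        int rows = inp.length;
        int cols = rows == 0 ? 0 : inp[0].length;
        this.totalRows = rows + 1;
        this.totalColumns = cols + 1;
        this.bit = new int[totalRows * totalColumns];
        this.values = new int[rows][cols];
        for (int i = 0; i < rows; i++)
            for (int j = 0; j < cols; j++)
                update(i, j, inp[i][j]); // values start at 0, so delta == inp[i][j]
    }

    // Row-major mapping: tree cell (x, y) lives at totalColumns * x + y.
    private int getIndex(int x, int y) {
        return totalColumns * x + y;
    }

    // Adds delta to every tree node covering 0-based cell (i, j).
    private void add(int i, int j, int delta) {
        for (int row = i + 1; row < totalRows; row += row & (-row))
            for (int col = j + 1; col < totalColumns; col += col & (-col))
                bit[getIndex(row, col)] += delta;
    }

    // Sets cell (i, j) to value by adding the difference from its old value.
    public void update(int i, int j, int value) {
        int delta = value - values[i][j];
        values[i][j] = value;
        add(i, j, delta);
    }

    // Sum of cells (0, 0)..(i, j); returns 0 when i or j is -1.
    private int prefixSum(int i, int j) {
        int sum = 0;
        for (int row = i + 1; row > 0; row -= row & (-row))
            for (int col = j + 1; col > 0; col -= col & (-col))
                sum += bit[getIndex(row, col)];
        return sum;
    }

    // Sum of the rectangle (row1, col1)..(row2, col2), inclusive.
    public int getSum(int row1, int col1, int row2, int col2) {
        return prefixSum(row2, col2) - prefixSum(row1 - 1, col2)
                - prefixSum(row2, col1 - 1) + prefixSum(row1 - 1, col1 - 1);
    }

    public static void main(String[] args) {
        int[][] arr = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
        FlatRangeSum rs = new FlatRangeSum(arr);
        System.out.println(rs.getSum(1, 1, 2, 2)); // 28
        rs.update(1, 1, 10);
        System.out.println(rs.getSum(0, 0, 2, 2)); // 50
    }
}
```

Unlike `RangeSum`, this sketch copies the input values, so `update` does not modify the caller's array. A single flat array also avoids allocating one `int[]` object per row; whether that yields a measurable speedup depends on the JVM and the matrix size.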