Segment tree in Java

In this tutorial, we will see how a segment tree is implemented in Java: how to build a segment tree, update a value in it, and query it.

A segment tree is a data structure that answers range queries over an array (in this tutorial, the sum of A[l..r]) and supports updates to individual elements. Both operations take O(log N) time.
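To make the two operations concrete, here is a naive sketch that works directly on the array, with no tree at all. The class `NaiveRangeSum` is a hypothetical baseline for comparison only. A range-sum query loops over every element of A[left..right], which costs O(N), while a point update touches a single element in O(1). A segment tree makes updates slightly more expensive (O(log N)) in exchange for much faster queries (O(log N) instead of O(N)).

```java
// Hypothetical baseline without a segment tree, for comparison only.
class NaiveRangeSum {
    // Sum of arr[left..right], inclusive: O(N) per query.
    static int query(int[] arr, int left, int right) {
        int sum = 0;
        for (int i = left; i <= right; i++) {
            sum += arr[i];
        }
        return sum;
    }

    // Adds val to arr[index]: O(1) per update.
    static void update(int[] arr, int index, int val) {
        arr[index] += val;
    }

    public static void main(String[] args) {
        int[] arr = {11, 22, 33, 4, 5, 6, 45, 74, 8};
        System.out.println(query(arr, 0, 1)); // 33
        update(arr, 1, 10);
        System.out.println(query(arr, 0, 1)); // 43
    }
}
```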

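Let us consider an array ‘A’ of size ‘N’ and the segment tree ‘T’ built over it.

  • The root node of T represents the whole array, A[0:N-1].
  • The root is split into two halves, A[0:(N-1)/2] and A[((N-1)/2)+1:N-1] (integer division, so the halves differ in size by at most one), which become its left and right children.
  • Each child is split into halves again, recursively, until every leaf represents a single element.
  • A segment tree over N elements therefore has exactly 2N-1 nodes (N leaves and N-1 internal nodes). When it is stored in an array where node i has children 2i+1 and 2i+2, some slots stay unused, so the array generally needs more than 2N-1 entries; 4N entries are always enough.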
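The halving rule in the list above can be traced with a short sketch. It assumes the same layout as the program in this tutorial (node i has children 2i+1 and 2i+2, split point mid = (start + end) / 2). The class `SegmentLayout` and its methods are hypothetical helpers, and `String.repeat` needs Java 11 or later. For N = 6 it counts 11 nodes, yet the highest index used is 12, so the array needs 13 slots.

```java
// Hypothetical helper that traces how a segment tree splits A[0:N-1].
// Layout assumption: node i has children 2*i+1 and 2*i+2; mid = (start + end) / 2.
class SegmentLayout {
    // Number of nodes produced by recursively halving [start, end].
    static int countNodes(int start, int end) {
        if (start == end) {
            return 1; // leaf: a single element
        }
        int mid = (start + end) / 2;
        return 1 + countNodes(start, mid) + countNodes(mid + 1, end);
    }

    // Highest array index reached by any node in the subtree rooted at 'node'.
    static int maxIndex(int node, int start, int end) {
        if (start == end) {
            return node;
        }
        int mid = (start + end) / 2;
        return Math.max(maxIndex(2 * node + 1, start, mid),
                        maxIndex(2 * node + 2, mid + 1, end));
    }

    // Prints each node's index and segment, indented by depth (String.repeat needs Java 11+).
    static void print(int node, int start, int end, int depth) {
        System.out.println("  ".repeat(depth) + "node " + node + ": A[" + start + ":" + end + "]");
        if (start == end) {
            return;
        }
        int mid = (start + end) / 2;
        print(2 * node + 1, start, mid, depth + 1);
        print(2 * node + 2, mid + 1, end, depth + 1);
    }

    public static void main(String[] args) {
        int n = 6;
        print(0, 0, n - 1, 0);
        System.out.println("nodes = " + countNodes(0, n - 1));              // 11 = 2*6 - 1
        System.out.println("array slots = " + (maxIndex(0, 0, n - 1) + 1)); // 13: slots 9 and 10 stay empty
    }
}
```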

Once the segment tree is built, you can update a value in the array and query the sum over any range of it.
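Both operations stay cheap because they only walk a few paths from the root toward the leaves. The sketch below rebuilds a sum tree with the same recursion and records the largest number of nodes any range query visits; the class `QueryCost` and its `visits` counter are hypothetical instrumentation, and the 4N array size is an assumption taken from the bound above. A query visits at most four nodes per level, and the tree has ceil(log2 N) + 1 levels, so the count grows logarithmically rather than linearly with N.

```java
// Hypothetical instrumented sum segment tree: counts how many nodes a query visits.
class QueryCost {
    static int[] tree;
    static int visits;

    static void build(int[] arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        int mid = (start + end) / 2;
        build(arr, 2 * node + 1, start, mid);
        build(arr, 2 * node + 2, mid + 1, end);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    static int query(int node, int start, int end, int left, int right) {
        visits++; // each call counts as one visited node
        if (right < start || end < left) {
            return 0; // no overlap
        }
        if (left <= start && end <= right) {
            return tree[node]; // full overlap
        }
        int mid = (start + end) / 2;
        return query(2 * node + 1, start, mid, left, right)
             + query(2 * node + 2, mid + 1, end, left, right);
    }

    // Largest number of visited nodes over every range [left, right] of an array of size n.
    static int worstVisits(int n) {
        int[] arr = new int[n]; // values do not affect the visit count
        tree = new int[4 * n];
        build(arr, 0, 0, n - 1);
        int worst = 0;
        for (int left = 0; left < n; left++) {
            for (int right = left; right < n; right++) {
                visits = 0;
                query(0, 0, n - 1, left, right);
                worst = Math.max(worst, visits);
            }
        }
        return worst;
    }

    public static void main(String[] args) {
        for (int n : new int[] {16, 256, 1024}) {
            System.out.println("n = " + n + ", worst-case visits = " + worstVisits(n));
        }
    }
}
```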

A simple Java program to build, update and query a segment tree

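class SegmentTree {

    // Array-backed tree: node i has children 2*i+1 and 2*i+2,
    // and every node stores the sum of the segment it covers.
    int[] tree;

    SegmentTree(int n) {
        tree = new int[n];
    }

    // Builds the subtree rooted at 'node', which covers arr[start..end].
    void build(int[] arr, int node, int start, int end) {
        if (start == end) {
            // Leaf: holds a single array element.
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            build(arr, 2 * node + 1, start, mid);     // left half
            build(arr, 2 * node + 2, mid + 1, end);   // right half
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    // Adds 'val' (a delta, not a new value) to arr[index], then
    // recomputes every sum on the path from that leaf back up to 'node'.
    void update(int[] arr, int node, int index, int val, int start, int end) {
        if (start == end) {
            arr[index] += val;
            tree[node] += val;
        } else {
            int mid = (start + end) / 2;
            if (start <= index && index <= mid) {
                update(arr, 2 * node + 1, index, val, start, mid);
            } else {
                update(arr, 2 * node + 2, index, val, mid + 1, end);
            }
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    // Returns the sum of arr[left..right]; 'node' covers arr[start..end].
    int query(int node, int start, int end, int left, int right) {
        // No overlap with the query range: contributes nothing to the sum.
        if (right < start || end < left) {
            return 0;
        }
        // Segment lies completely inside the query range.
        if (left <= start && end <= right) {
            return tree[node];
        }
        // Partial overlap: combine the answers of both children.
        int mid = (start + end) / 2;
        int p1 = query(2 * node + 1, start, mid, left, right);
        int p2 = query(2 * node + 2, mid + 1, end, left, right);
        return p1 + p2;
    }
}

public class Tasktest {
    public static void main(String[] args) {
        int[] arr = {11, 22, 33, 4, 5, 6, 45, 74, 8}; // user can give any values
        int n = arr.length;
        // For n >= 1, floor(log2 n) + 1 >= ceil(log2 n), so 2^(height+1) slots are always enough.
        int height = (int) (Math.log(n) / Math.log(2)) + 1;
        int treeNodes = (int) Math.pow(2, height + 1);
        SegmentTree ob = new SegmentTree(treeNodes);
        ob.build(arr, 0, 0, n - 1);
        for (int i = 0; i < treeNodes; i++) {
            System.out.print(ob.tree[i] + " ");
        }
        System.out.println();
        System.out.println(ob.query(0, 0, n - 1, 0, 1));  // 11 + 22
        ob.update(arr, 0, 1, 10, 0, n - 1);               // arr[1] becomes 32
        System.out.println(ob.query(0, 0, n - 1, 0, 1));  // 11 + 32
    }
}
Output:
208 75 133 66 9 51 82 33 33 4 5 6 45 74 8 11 22 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
33
43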
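The `return 0` in `query` is not arbitrary: 0 is the identity of addition, so a segment outside the query range adds nothing to the sum. The same recursion answers other kinds of range queries if both the combining operation and that neutral value change. As a sketch (the class `MinSegmentTree` is hypothetical and not part of the program above), a range-minimum tree combines children with `Math.min` and returns `Integer.MAX_VALUE` for non-overlapping segments.

```java
// Hypothetical range-minimum variant: same recursion, different combine and neutral value.
class MinSegmentTree {
    private final int[] tree;
    private final int n;

    MinSegmentTree(int[] arr) {
        n = arr.length;
        tree = new int[4 * n]; // 4n slots always suffice for the 2i+1 / 2i+2 layout
        build(arr, 0, 0, n - 1);
    }

    private void build(int[] arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        int mid = (start + end) / 2;
        build(arr, 2 * node + 1, start, mid);
        build(arr, 2 * node + 2, mid + 1, end);
        tree[node] = Math.min(tree[2 * node + 1], tree[2 * node + 2]);
    }

    // Minimum of arr[left..right], inclusive.
    int query(int left, int right) {
        return query(0, 0, n - 1, left, right);
    }

    private int query(int node, int start, int end, int left, int right) {
        if (right < start || end < left) {
            return Integer.MAX_VALUE; // identity of min, playing the role of 0 for sums
        }
        if (left <= start && end <= right) {
            return tree[node];
        }
        int mid = (start + end) / 2;
        return Math.min(query(2 * node + 1, start, mid, left, right),
                        query(2 * node + 2, mid + 1, end, left, right));
    }

    public static void main(String[] args) {
        int[] arr = {11, 22, 33, 4, 5, 6, 45, 74, 8};
        MinSegmentTree t = new MinSegmentTree(arr);
        System.out.println(t.query(0, 2)); // 11
        System.out.println(t.query(2, 6)); // 4
        System.out.println(t.query(6, 8)); // 8
    }
}
```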
