Count Smaller After Itself

This problem is a true classic. Carry extra information (augment the BST).

Solution 1: Merge Sort

Why is this question related to merge sort?

  1. Phase 1: divide
  2. Phase 2: combine

This is a very typical divide and conquer problem.

eg.      [4   1    3   2] 
        /             \
     [4  1]           [3  2]
     /   \           /      \
   [4,0]  [1,0]    [3,0]    [2,0]
     \      /         \       /
   [1,0] [4,1]      [2,0]   [3,1]
      \                      /
      [1,0] [2,0] [3,1] [4,3]

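Case 1: the right element is smaller than the left element. Place the right element first (swap); it counts as one more smaller element for the left element.

Case 2: the right element is larger. Merge the left element as usual and add to its count the number of right elements already placed before it.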

At first I used two arrays (one for the counts), but that was hard to read. It is better to declare a Pair class.

Key point: how to define the subproblem?

Always count the smaller elements in the right partition.

Loop Invariant:

Every element in the right partition originally sits after every element in the left partition. During the merge, count the right-partition elements that are smaller than a left element and add that number to the left element's count.

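    class Pair {
        int val;
        int count;
        public Pair(int val, int count) {
            this.val = val;
            this.count = count;
        }
    }
    // Returns array[left..right] sorted by value; each Pair carries the number of
    // smaller elements that originally appeared to its right. The array must be non-empty.
    public Pair[] vMerge(int[] array, int left, int right) {
        if (left >= right) {
            return new Pair[] { new Pair(array[left], 0) };
        }
        int mid = left + (right - left) / 2;
        Pair[] lRes = vMerge(array, left, mid);
        Pair[] rRes = vMerge(array, mid + 1, right);
        return combine(lRes, rRes);
    }
    public Pair[] combine(Pair[] lRes, Pair[] rRes) {
        // Declare a pair array for holding results.
        Pair[] res = new Pair[lRes.length + rRes.length];
        int i = 0, j = 0;
        int idx = 0;
        while (i < lRes.length && j < rRes.length) {
            // Use <= so that an equal right element is not counted as smaller.
            if (lRes[i].val <= rRes[j].val) {
                // The j right elements already merged are all smaller than lRes[i].
                res[idx++] = new Pair(lRes[i].val, lRes[i].count + j);
                i++;
            } else {
                res[idx++] = new Pair(rRes[j].val, rRes[j].count);
                j++;
            }
        }
        // Leftover left elements are larger than every right element (j == rRes.length here).
        while (i < lRes.length) {
            res[idx++] = new Pair(lRes[i].val, lRes[i].count + j);
            i++;
        }
        // Leftover right elements keep their counts unchanged.
        while (j < rRes.length) {
            res[idx++] = new Pair(rRes[j].val, rRes[j].count);
            j++;
        }
        return res;
    }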
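
The Pair version above returns its result sorted by value, so the counts are no longer in input order. A minimal sketch of one way to get the counts back in input order is to merge sort an array of original indices instead of the values. The class name `CountSmallerMergeSort` and the helper `sort` are hypothetical names chosen for this illustration; they are not part of the notes above.

```java
import java.util.Arrays;

// Hypothetical illustration class: merge sort over indices so counts stay in input order.
class CountSmallerMergeSort {
    // counts[i] = number of elements to the right of nums[i] that are strictly smaller.
    static int[] countSmaller(int[] nums) {
        int n = nums.length;
        int[] counts = new int[n];
        int[] idx = new int[n]; // idx[k] = original position of the k-th element in the current order
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        sort(nums, idx, new int[n], counts, 0, n - 1);
        return counts;
    }

    private static void sort(int[] nums, int[] idx, int[] tmp, int[] counts, int left, int right) {
        if (left >= right) {
            return;
        }
        int mid = left + (right - left) / 2;
        sort(nums, idx, tmp, counts, left, mid);
        sort(nums, idx, tmp, counts, mid + 1, right);
        int i = left, j = mid + 1, k = left;
        while (i <= mid || j <= right) {
            // Take from the left half when the right half is exhausted,
            // or when the left value is <= the right value (ties go left).
            if (j > right || (i <= mid && nums[idx[i]] <= nums[idx[j]])) {
                // Right elements merged so far are strictly smaller than this left element.
                counts[idx[i]] += j - (mid + 1);
                tmp[k++] = idx[i++];
            } else {
                tmp[k++] = idx[j++];
            }
        }
        System.arraycopy(tmp, left, idx, left, right - left + 1);
    }

    public static void main(String[] args) {
        // Same input as the diagram above.
        System.out.println(Arrays.toString(countSmaller(new int[] {4, 1, 3, 2}))); // prints [3, 0, 1, 0]
    }
}
```

The input array is never modified here; only the index array is reordered.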

Solution 2: Binary Indexed Tree
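
The notes give no code for this solution, so the following is only a sketch of how a Binary Indexed Tree (Fenwick tree) could be applied, assuming the usual approach: compress the values to ranks, scan from right to left, query how many smaller ranks have been seen, then record the current rank. The class name `CountSmallerBit` and the helpers `update` and `query` are hypothetical names for this illustration.

```java
import java.util.Arrays;

// Hypothetical illustration class: count smaller elements to the right with a Fenwick tree.
class CountSmallerBit {
    static int[] countSmaller(int[] nums) {
        int n = nums.length;
        // Coordinate compression: sorted[0..m) holds the distinct values in increasing order.
        int[] sorted = nums.clone();
        Arrays.sort(sorted);
        int m = 0;
        for (int i = 0; i < n; i++) {
            if (i == 0 || sorted[i] != sorted[i - 1]) {
                sorted[m++] = sorted[i];
            }
        }
        int[] tree = new int[m + 1]; // 1-based Fenwick tree over ranks 1..m
        int[] counts = new int[n];
        for (int i = n - 1; i >= 0; i--) {
            int rank = Arrays.binarySearch(sorted, 0, m, nums[i]) + 1;
            counts[i] = query(tree, rank - 1); // values already seen with a strictly smaller rank
            update(tree, rank);                // record the current value
        }
        return counts;
    }

    // Adds one occurrence at position pos.
    static void update(int[] tree, int pos) {
        for (; pos < tree.length; pos += pos & -pos) {
            tree[pos]++;
        }
    }

    // Returns the number of recorded occurrences at positions 1..pos.
    static int query(int[] tree, int pos) {
        int sum = 0;
        for (; pos > 0; pos -= pos & -pos) {
            sum += tree[pos];
        }
        return sum;
    }
}
```

Both `update` and `query` touch O(log m) cells, so the whole scan is O(n log n) regardless of input order.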

Solution 3 : Special Trick

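Implementation 1: a plain BST. This does not guarantee that the tree is balanced.

If the tree degenerates into a long linked list (for example, on sorted input), this is very inefficient: each insertion basically iterates over every element already in the tree, which is O(n^2) overall.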
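
To make the degenerate case concrete, here is a small sketch that inserts increasing values into a plain, unbalanced BST and counts how many nodes each insertion visits. The class `SkewedBstDemo` and its methods are hypothetical names for this illustration, and the tree is stripped of the counting fields to keep the point visible.

```java
// Hypothetical illustration class: cost of inserting sorted input into an unbalanced BST.
class SkewedBstDemo {
    static class Node {
        int val;
        Node left, right;
        Node(int v) { val = v; }
    }

    // Inserts val and returns the number of existing nodes visited on the way down.
    static int insert(Node root, int val) {
        int steps = 0;
        Node cur = root;
        while (true) {
            steps++;
            if (val < cur.val) {
                if (cur.left == null) { cur.left = new Node(val); return steps; }
                cur = cur.left;
            } else {
                if (cur.right == null) { cur.right = new Node(val); return steps; }
                cur = cur.right;
            }
        }
    }

    // Total nodes visited when inserting 0, 1, ..., n - 1 in increasing order (n >= 1).
    static long totalSteps(int n) {
        Node root = new Node(0);
        long total = 0;
        for (int v = 1; v < n; v++) {
            total += insert(root, v);
        }
        return total;
    }
}
```

With increasing input, the value v walks past all v nodes inserted before it, so the total is 1 + 2 + ... + (n - 1) = n(n - 1) / 2 visits.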

    import java.util.Arrays;
    import java.util.List;

    class Node {
        int val, leftCount, dupCount = 0;
        Node left, right;
        public Node(int x) {
            this.val = x;
        }
    }
    public List<Integer> countSmaller(int[] nums) {
        // Integer[] rather than int[]: Arrays.asList needs boxed elements to yield a List<Integer>.
        Integer[] counts = new Integer[nums.length];
        if (nums.length == 0) {
            return Arrays.asList(counts);
        }
        // The root holds the last value; the first loop iteration only bumps its dupCount.
        Node root = new Node(nums[nums.length - 1]);
        for (int i = nums.length - 1; i >= 0; i--) {
            counts[i] = insert(root, nums[i]);
        }
        return Arrays.asList(counts);
    }
    private int insert(Node node, int num) {
        int smaller = 0;
        while (node.val != num) {
            // Case 1: the inserted value is smaller than the current node; update the current node's leftCount++.
            if (node.val > num) {
                if (node.left == null) {
                    node.left = new Node(num);
                }
                node.leftCount++;
                node = node.left;
            // Case 2: the inserted value is larger; add the current node's left subtree count plus its own duplicates.
            } else {
                smaller += node.leftCount + node.dupCount;
                if (node.right == null) {
                    node.right = new Node(num);
                }
                node = node.right;
            }
        }
        node.dupCount++;
        return smaller + node.leftCount;
    }

Implementation 2: sort first and keep the tree balanced, using an inserted flag on each node (smart and tricky).

Limitation: the tree itself must not contain duplicate values, so the sorted array is de-duplicated before the tree is built.

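  1. Sort the array and construct a balanced BST from it without really inserting the values (every node starts as not inserted).
  2. Scanning from right to left, insert A[i] into the tree and count the already inserted elements < A[i].

Note: save a copy of the original array before sorting. I spent half an hour debugging this (TAT): the sort had modified the input.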

    import java.util.Arrays;

    class Node {
        int inserted;   // how many copies of val have been inserted so far
        int numLeft;    // how many inserted values live in the left subtree
        int val;
        Node left;
        Node right;
        public Node(int val) {
            this.val = val;
            this.inserted = 0;
        }
    }

    int[] countSmaller(int[] array) {
        int[] counts = new int[array.length];
        if (array.length == 0) {
            return counts;
        }
        // Save the original order first: construct() sorts array in place.
        int[] ori = Arrays.copyOf(array, array.length);
        Node root = construct(array);
        for (int i = ori.length - 1; i >= 0; i--) {
            counts[i] = insert(root, ori[i]);
        }
        return counts;
    }

    // Every value already has a node, so the walk never falls off the tree.
    public int insert(Node root, int val) {
        int smaller = 0;
        while (root.val != val) {
            // Update the target's smaller count.
            if (root.val < val) {
                smaller += root.numLeft + root.inserted;
                root = root.right;
            } else {
                // The target lands in the left subtree of this node.
                root.numLeft++;
                root = root.left;
            }
        }
        root.inserted++;
        return smaller + root.numLeft;
    }

    Node construct(int[] array) {
        Arrays.sort(array);
        array = dedup(array);
        return inorder(array, 0, array.length - 1);
    }

    // array must be sorted and non-empty.
    int[] dedup(int[] array) {
        // [0, s] distinct
        int s = 0;
        for (int f = 1; f < array.length; f++) {
            if (array[f] != array[s]) {
                array[++s] = array[f];
            }
        }
        return Arrays.copyOf(array, s + 1);
    }

    // Builds a height-balanced BST from the sorted, distinct values in array[left..right].
    Node inorder(int[] array, int left, int right) {
        if (left > right) {
            return null;
        }
        int mid = left + (right - left) / 2;
        Node root = new Node(array[mid]);
        root.left = inorder(array, left, mid - 1);
        root.right = inorder(array, mid + 1, right);
        return root;
    }
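
Since the input-modification bug above took a while to find, a brute-force reference is handy for cross-checking any of the solutions on small inputs. This is only a sketch; the class name `CountSmallerBruteForce` is a hypothetical name for this illustration.

```java
// Hypothetical illustration class: O(n^2) reference answer for cross-checking on small inputs.
class CountSmallerBruteForce {
    // counts[i] = number of indices j > i with nums[j] < nums[i]. Does not modify nums.
    static int[] countSmaller(int[] nums) {
        int[] counts = new int[nums.length];
        for (int i = 0; i < nums.length; i++) {
            for (int j = i + 1; j < nums.length; j++) {
                if (nums[j] < nums[i]) {
                    counts[i]++;
                }
            }
        }
        return counts;
    }
}
```

Comparing a faster solution against this on random small arrays, including arrays with duplicates, would catch both wrong counts and accidental changes to the input.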

