
package coins.backend;

import java.lang.*;
import java.io.*;
import java.util.*;
import coins.backend.*;
import coins.backend.tmd.*;
import coins.backend.util.*;
import coins.backend.sym.*;
import coins.backend.lir.*;
import coins.backend.cfg.*;
import coins.backend.ana.*;
import coins.backend.opt.*;

public class RegisterAllocation {

  /** Function whose registers are being allocated. */
  Function func;
  /** Writer that receives the debugging dumps printed between phases. */
  PrintWriter debOut;

  /**
   * Construct register allocation object.
   *
   * @param f   the function to allocate registers for
   * @param out writer that receives intermediate debugging dumps
   */
  public RegisterAllocation(Function f, PrintWriter out) {
    func = f;
    debOut = out;
  }

  /**
   * Do register allocation, return true if succeed to allocate.
   *
   * <p>Currently this converts the function to SSA form, merges the
   * variables connected by PHI instructions into live ranges (which also
   * removes the PHIs), and then requests live variable analysis and the
   * interference graph, dumping the function to {@code debOut} after each
   * phase. The remaining steps (coalescing, spill cost, ...) are not
   * implemented yet, so the method always returns {@code true}.
   *
   * @return {@code true} in the current implementation
   */
  public boolean allocate() {
    // Convert to SSA form.
    func.apply(new Ssa(debOut));

    debOut.println();
    debOut.println("After SSA:");
    func.printIt(debOut);

    // Create live ranges and convert back to regular (PHI-free) form.
    liveRange();

    debOut.println();
    debOut.println("After live range:");
    func.printIt(debOut);

    // Do live variable analysis. The result is not consumed yet; it is
    // kept for the allocation steps that are still to be written.
    LiveVariableAnalysis liveInfo
      = (LiveVariableAnalysis)func.require(LiveVariableAnalysis.analyzer);

    // Create interference graph (likewise not consumed yet).
    InterferenceGraph ig
      = (InterferenceGraph)func.require(InterferenceGraph.analyzer);

    debOut.println();
    debOut.println("After live var. analysis/interference graph:");
    func.printIt(debOut,
                 new LocalAnalyzer[] {
                   LiveVariableAnalysis.analyzer,
                   InterferenceGraph.analyzer});

    // TODO: coalescing.
    // TODO: compute spill cost.
    // TODO: remaining allocation steps.

    return true;
  }


  /**
   * Create live ranges.
   *
   * <p>The destination of every PHI instruction is unioned with each of its
   * REG operands, so that all names connected through PHIs end up in one
   * disjoint set. Every register is then renamed to the symbol of its set's
   * root, and the PHI instructions are removed.
   */
  private void liveRange() {
    int symBound = func.localSymtab.idBound();
    // Disjoint-set forest over symbol ids: parent[i] < 0 marks a root, and
    // count[i] is the size of the set rooted at i (meaningful only at roots).
    int[] parent = new int[symBound];
    int[] count = new int[symBound];
    for (int i = 0; i < symBound; i++) {
      parent[i] = -1;
      count[i] = 1;
    }

    // NOTE(review): renaming indexes this array by symbol id, which assumes
    // sortedSymbols() returns symbols ordered so that symVec[id].id == id.
    Symbol[] symVec = func.localSymtab.sortedSymbols();

    unionPhiOperands(parent, count);
    renameAndRemovePhis(parent, symVec);
  }


  /** Union the destination of every PHI instruction with its REG operands. */
  private void unionPhiOperands(int[] parent, int[] count) {
    for (BiLink p = func.flowGraph.basicBlkList.first(); !p.atEnd(); p = p.next()) {
      BasicBlk blk = (BasicBlk)p.elem();

      for (BiLink q = blk.instrList().first(); !q.atEnd(); q = q.next()) {
        LirNode ins = (LirNode)q.elem();
        if (ins.opCode != Op.PHI)
          continue;

        // src(0) is the PHI destination; for each remaining operand the
        // incoming value is that operand's src(0).
        Symbol dst = ((LirSymRef)ins.src(0)).symbol;
        int n = ins.nSrcs();
        for (int i = 1; i < n; i++) {
          LirNode value = ins.src(i).src(0);
          // Only register operands are merged; other operands are left alone.
          if (value.opCode == Op.REG)
            bindRange(dst.id, ((LirSymRef)value).symbol.id, parent, count);
        }
      }
    }
  }


  /** Rename registers to their live-range representative and drop all PHIs. */
  private void renameAndRemovePhis(int[] parent, Symbol[] symVec) {
    for (BiLink p = func.flowGraph.basicBlkList.first(); !p.atEnd(); p = p.next()) {
      BasicBlk blk = (BasicBlk)p.elem();

      BiLink q = blk.instrList().first();
      while (!q.atEnd()) {
        LirNode ins = (LirNode)q.elem();
        // Fetch the successor before q may be unlinked from the list.
        BiLink next = q.next();

        if (ins.opCode == Op.PHI)
          q.unlink();
        else
          renameRegvars(ins, parent, symVec);

        q = next;
      }
    }
  }


  /**
   * Rename variables.
   *
   * <p>Recursively walks the operands of {@code tree}. Every REG operand
   * whose symbol is not the root of its disjoint set is replaced by a
   * reference to the root's symbol, so that all members of one live range
   * share a single name. Root registers are left unchanged.
   *
   * @param tree   instruction (or sub-expression) to rewrite in place
   * @param parent disjoint-set forest over symbol ids
   * @param symVec symbols indexed by id (see the note in liveRange)
   */
  private void renameRegvars(LirNode tree, int[] parent, Symbol[] symVec) {
    int n = tree.nSrcs();
    for (int i = 0; i < n; i++) {
      LirNode kid = tree.src(i);
      if (kid.opCode == Op.REG) {
        int id = ((LirSymRef)kid).symbol.id;
        // Use the set root, not parent[id]: after union-by-size parent[id]
        // can be an interior node, which would give one live range several
        // different names.
        int root = find(parent, id);
        if (root != id)
          tree.setSrc(i, func.newLir.symRef(symVec[root]));
      }
      else
        renameRegvars(kid, parent, symVec);
    }
  }


  /* Disjoint set union - Tarjan's Path Compression */

  /**
   * Return the root of the set containing {@code x}, pointing every node on
   * the path from {@code x} directly at that root (path compression).
   */
  private int find(int[] parent, int x)
  {
    int root = x;
    while (parent[root] >= 0)
      root = parent[root];

    // Compress: re-point each node on the path straight at the root.
    int w;
    for (; (w = parent[x]) >= 0; x = w)
      parent[x] = root;
    return root;
  }


  /**
   * Link two distinct roots, hanging the smaller set under the larger one
   * (ties keep {@code x} as the root).
   *
   * @return the root of the merged set
   */
  private int union(int[] parent, int[] count, int x, int y)
  {
    if (count[x] < count[y]) {
      int w = x; x = y; y = w;
    }
    parent[y] = x;
    count[x] += count[y];
    return x;
  }


  /** Merge the live ranges that contain symbol ids {@code x} and {@code y}. */
  private void bindRange(int x, int y, int[] parent, int[] count) {
    int rx = find(parent, x);
    int ry = find(parent, y);
    if (rx != ry) // otherwise already joined
      union(parent, count, rx, ry);
  }

}
