
package coins.backend;

import java.lang.*;
import java.io.*;
import java.util.*;
import coins.backend.*;
import coins.backend.tmd.*;
import coins.backend.util.*;
import coins.backend.sym.*;
import coins.backend.lir.*;
import coins.backend.cfg.*;
import coins.backend.ana.*;
import coins.backend.opt.*;


/**
 * Register allocation session object.
 *
 * <p>One instance is created per {@link #allocate} call. The function is
 * converted to SSA form, variables connected through PHI instructions are
 * merged into live ranges (and the PHIs removed), live variable analysis and
 * the interference graph are computed, and register variables are colored
 * by prune-and-color. Intermediate states and coloring decisions are traced
 * to {@code root.debOut}.
 **/
public class RegisterAllocation {

  /** Function being allocated. **/
  private final Function func;

  /** Target machine description (register numbers and register sets). **/
  private final TargetMachine rule;

  /** Back end session; supplies the debug output stream. **/
  private final BackEnd root;

  private LiveVariableAnalysis liveInfo;
  private InterferenceGraph idg;

  /** Construct register allocation object. **/
  private RegisterAllocation(Function func, TargetMachine rule, BackEnd root) {
    this.func = func;
    this.rule = rule;
    this.root = root;
  }

  /**
   * Do register allocation on a function.
   *
   * @param func function to allocate registers for
   * @param rule target machine description
   * @param root back end session (used for debug output)
   * @return true if successful; currently always true -- variables that
   *         could not be colored are only reported on the debug stream
   **/
  public static boolean allocate(Function func, TargetMachine rule, BackEnd root) {
    return new RegisterAllocation(func, rule, root).doIt();
  }

  /** Do register allocation and return true if successful. **/
  private boolean doIt() {

    // Convert to SSA form.
    func.apply(new Ssa(root));

    root.debOut.println();
    root.debOut.println("After SSA:");
    func.printIt(root.debOut);

    // Create live ranges and convert back to regular (non-SSA) form.
    liveRange();

    root.debOut.println();
    root.debOut.println("After live range:");
    func.printIt(root.debOut);

    // Do live variable analysis.
    liveInfo
      = (LiveVariableAnalysis)func.require(LiveVariableAnalysis.analyzer);

    // Create interference graph.
    idg = (InterferenceGraph)func.require(InterferenceGraph.analyzer);

    Symbol[] regvars = liveInfo.regvarVec();

    root.debOut.println();
    root.debOut.println("After live var. analysis/interference graph:");
    func.printIt(root.debOut,
                 new LocalAnalyzer[] {
                   LiveVariableAnalysis.analyzer,
                   InterferenceGraph.analyzer});

    // TODO: coalescing; loop-weighted spill costs (reference count * loop
    // factor); register pressure factor.

    // Coloring.
    coloring(regvars);

    return true;
  }


  /**
   * Create live ranges and squeeze unused variables.
   *
   * <p>Each PHI destination is merged (union-find) with every PHI operand
   * whose first sub-node is a register. All register references outside
   * PHIs are then renamed to their set representative, the PHI instructions
   * are deleted, and REG-storage symbols that are no longer referenced are
   * removed from the local symbol table.
   **/
  private void liveRange() {
    int symBound = func.localSymtab.idBound();
    // parent[i] < 0: i is a set representative; otherwise a link towards it.
    int[] parent = new int[symBound];
    // count[i]: size of the set rooted at i (meaningful for roots only).
    int[] count = new int[symBound];
    for (int i = 0; i < symBound; i++) {
      parent[i] = -1;
      count[i] = 1;
    }

    BitSet used = new BitSet(symBound);

    // NOTE(review): indexed by symbol id in renameRegvars; assumes
    // sortedSymbols() places each symbol at index == its id -- confirm.
    Symbol[] symVec = func.localSymtab.sortedSymbols();

    // Merge each PHI destination with its register operands.
    for (BiLink p = func.flowGraph.basicBlkList.first(); !p.atEnd(); p = p.next()) {
      BasicBlk blk = (BasicBlk)p.elem();

      for (BiLink q = blk.instrList().first(); !q.atEnd(); q = q.next()) {
        LirNode ins = (LirNode)q.elem();

        if (ins.opCode == Op.PHI) {
          Symbol dst = ((LirSymRef)ins.src(0)).symbol;
          int n = ins.nSrcs();
          for (int i = 1; i < n; i++) {
            if (ins.src(i).src(0).opCode == Op.REG) {
              Symbol src = ((LirSymRef)ins.src(i).src(0)).symbol;
              bindRange(dst.id, src.id, parent, count);
            }
          }
        }
      }
    }

    // Rename to representatives; remove PHI instructions.
    for (BiLink p = func.flowGraph.basicBlkList.first(); !p.atEnd(); p = p.next()) {
      BasicBlk blk = (BasicBlk)p.elem();

      for (BiLink q = blk.instrList().first(); !q.atEnd(); ) {
        LirNode ins = (LirNode)q.elem();
        // Fetch the successor before q may be unlinked.
        BiLink next = q.next();

        if (ins.opCode == Op.PHI)
          q.unlink();
        else
          renameRegvars(ins, parent, symVec, used);

        q = next;
      }
    }

    // Squeeze unused register variables.
    for (BiLink p = func.localSymtab.symbols().first(); !p.atEnd(); ) {
      Symbol sym = (Symbol)p.elem();
      // Fetch the successor first: remove() may unlink p from the list.
      BiLink next = p.next();

      if (sym.storage == Storage.REG && !used.get(sym.id))
        func.localSymtab.remove(sym);

      p = next;
    }
  }


  /**
   * Rename register references in {@code tree} (recursively) to the
   * representative of their live range and record the representative in
   * {@code used}.
   *
   * <p>The representative is obtained with {@link #find}: after
   * union-by-size a direct {@code parent[id]} link may point at an
   * intermediate node rather than the set root.
   **/
  private void renameRegvars(LirNode tree, int[] parent, Symbol[] symVec,
                             BitSet used) {
    int n = tree.nSrcs();
    for (int i = 0; i < n; i++) {
      LirNode kid = tree.src(i);
      if (kid.opCode == Op.REG) {
        int id = ((LirSymRef)kid).symbol.id;
        int rep = find(parent, id);
        if (rep != id)
          tree.setSrc(i, func.newLir.symRef(symVec[rep]));
        used.set(rep);
      }
      else
        renameRegvars(kid, parent, symVec, used);
    }
  }


  /* Disjoint set union - Tarjan's path compression. */

  /** Return the representative of x's set, compressing the path to it. **/
  private int find(int[] parent, int x) {
    int rep = x;
    while (parent[rep] >= 0)
      rep = parent[rep];

    // Point every node on the path from x directly at the representative.
    while (parent[x] >= 0) {
      int next = parent[x];
      parent[x] = rep;
      x = next;
    }
    return rep;
  }


  /**
   * Link roots x and y, attaching the smaller set under the larger one
   * (ties: y under x). Return the new root.
   **/
  private int union(int[] parent, int[] count, int x, int y) {
    if (count[x] < count[y]) {
      int w = x; x = y; y = w;
    }
    parent[y] = x;
    count[x] += count[y];
    return x;
  }


  /** Merge the sets containing x and y (no-op if already merged). **/
  private void bindRange(int x, int y, int[] parent, int[] count) {
    x = find(parent, x);
    y = find(parent, y);
    if (x == y)
      return; /* already joined */

    union(parent, count, x, y);
  }


  /**
   * Return number of 1's in x (all 64 bits).
   * Uses unsigned shifts so that a set sign bit does not smear into the
   * partial sums.
   **/
  private int count1s(long x) {
    long x1 = ((x  & 0xaaaaaaaaaaaaaaaaL) >>>  1) + (x  & 0x5555555555555555L);
    long x2 = ((x1 & 0xccccccccccccccccL) >>>  2) + (x1 & 0x3333333333333333L);
    long x3 = ((x2 & 0xf0f0f0f0f0f0f0f0L) >>>  4) + (x2 & 0x0f0f0f0f0f0f0f0fL);
    long x4 = ((x3 & 0xff00ff00ff00ff00L) >>>  8) + (x3 & 0x00ff00ff00ff00ffL);
    long x5 = ((x4 & 0xffff0000ffff0000L) >>> 16) + (x4 & 0x0000ffff0000ffffL);
    long x6 = ((x5 & 0xffffffff00000000L) >>> 32) + (x5 & 0x00000000ffffffffL);
    return (int)x6;
  }


  /**
   * Return the lowest-numbered 1 bit of x (bit 0 being the least
   * significant), or -1 if x is 0.
   **/
  private int leftmost1(long x) {
    for (int i = 0; i < 64; i++)
      if (((x >> i) & 1) != 0)
        return i;
    return -1;
  }


  /**
   * Get the available register set (bit mask) for a variable: the set named
   * by its {@code &regset} option if present, otherwise all
   * {@code rule.nRegisters()} registers.
   **/
  private long getRegset(Symbol regvar) {
    for (ImList p = (ImList)regvar.opt; !p.atEnd(); p = p.next()) {
      if ("&regset".equals(p.elem()))
        return rule.getRegset((String)p.elem2nd());
    }
    int n = rule.nRegisters();
    // Long shift: an int shift would wrap for n >= 32.
    return n >= 64 ? -1L : (1L << n) - 1;
  }


  /* Coloring registers. Arrays below are indexed like regvars; entries
     below nPhyRegs are physical registers with a fixed assignedReg. */

  private long[] regsets;
  private int[] nregsets;
  private int[] neighbors;
  private boolean[] removed;
  private int nPhyRegs;
  private Symbol[] regvars;
  private int[] assignedReg;

  private int[] spillCosts;
  private int[] disturbFactor;


  /** Initial value for the minimum-spill-cost search. **/
  private static final int INFINITY = 9999999;

  /**
   * Compute spill costs: the number of definitions plus uses of each
   * register variable over all non-PHI instructions (no loop weighting).
   **/
  private void computeSpillCost() {

    // Check variable definitions/references.

    DefUseHandler handler = new DefUseHandler() {
        public void defined(LirNode node) {
          if (node.opCode == Op.SUBREG)
            node = node.src(0);
          spillCosts[liveInfo.regvarIndex(((LirSymRef)node).symbol)]++;
        }

        public void used(LirNode node) {
          if (node.opCode == Op.SUBREG)
            node = node.src(0);
          spillCosts[liveInfo.regvarIndex(((LirSymRef)node).symbol)]++;
        }
      };

    for (BiLink p = func.flowGraph.basicBlkList.first(); !p.atEnd(); p = p.next()) {
      BasicBlk blk = (BasicBlk)p.elem();

      for (BiLink q = blk.instrList().first(); !q.atEnd(); q = q.next()) {
        LirNode ins = (LirNode)q.elem();

        if (ins.opCode != Op.PHI)
          ins.pickupDefUseReg(handler);
      }
    }

    for (int i = nPhyRegs; i < regvars.length; i++)
      root.debOut.println("var " + regvars[i].name
                          + ": spill costs " + spillCosts[i]);
  }


  /**
   * Prepare per-variable coloring data (neighbor counts, disturb counts,
   * allowed register sets, pre-colored physical registers, spill costs) and
   * run prune-and-color.
   **/
  private void coloring(Symbol[] regvars) {

    this.regvars = regvars;

    // count number of neighbors
    neighbors = new int[regvars.length];
    for (int i = 0; i < regvars.length; i++)
      neighbors[i] = idg.interfereList(regvars[i]).length();

    // count number of disturbing variables.
    disturbFactor = new int[regvars.length];
    for (int i = 0; i < regvars.length; i++)
      disturbFactor[i] = idg.disturbList(regvars[i]).length();

    // count number of registers in each variable's register set.
    regsets = new long[regvars.length];
    nregsets = new int[regvars.length];
    for (int i = 0; i < regvars.length; i++) {
      regsets[i] = getRegset(regvars[i]);
      nregsets[i] = count1s(regsets[i]);
    }
    nPhyRegs = liveInfo.nPhyRegs();
    assignedReg = new int[regvars.length];
    for (int i = 0; i < nPhyRegs; i++)
      assignedReg[i] = rule.regNumber(regvars[i]);
    removed = new boolean[regvars.length];
    spillCosts = new int[regvars.length];

    // compute spill costs.
    computeSpillCost();

    pruneAndColor();
  }


  /**
   * Prune-and-color over variables nPhyRegs .. regvars.length-1.
   *
   * <p>Prune phase: repeatedly remove a variable (see
   * {@link #choosePruneCandidate}) from the graph, decrementing the neighbor
   * count of its remaining neighbors. Select phase: color the variables in
   * reverse pruning order. Uses an explicit stack instead of recursion so
   * that large functions cannot overflow the call stack.
   **/
  private void pruneAndColor() {
    int[] stack = new int[regvars.length];
    int sp = 0;

    int prune;
    while ((prune = choosePruneCandidate()) >= 0) {
      removed[prune] = true;
      for (BiLink p = idg.interfereList(regvars[prune]).first(); !p.atEnd();
           p = p.next()) {
        int k = liveInfo.regvarIndex((Symbol)p.elem());
        if (!removed[k])
          neighbors[k]--;
      }
      stack[sp++] = prune;
    }

    while (sp > 0) {
      int v = stack[--sp];
      assignColor(v);
      removed[v] = false;
    }
  }


  /**
   * Choose the next variable to prune, or -1 if none remain.
   * Prefers the first variable x with |neighbors(x)| < |regset(x)|;
   * otherwise the cheapest-to-spill variable whose disturb factor is at
   * least |regset(x)|; if no variable qualifies, the cheapest remaining one
   * (previously this case indexed removed[-1] and threw).
   **/
  private int choosePruneCandidate() {
    boolean anyLeft = false;
    for (int i = nPhyRegs; i < regvars.length; i++) {
      if (!removed[i]) {
        anyLeft = true;
        if (neighbors[i] < nregsets[i])
          return i;
      }
    }
    if (!anyLeft)
      return -1;

    // No such node, choose register variable to spill.
    int mincost = INFINITY;
    int best = -1;
    for (int i = nPhyRegs; i < regvars.length; i++) {
      if (!removed[i] && disturbFactor[i] >= nregsets[i]
          && spillCosts[i] < mincost) {
        mincost = spillCosts[i];
        best = i;
      }
    }

    if (best < 0) {
      for (int i = nPhyRegs; i < regvars.length; i++) {
        if (!removed[i] && (best < 0 || spillCosts[i] < spillCosts[best]))
          best = i;
      }
    }
    return best;
  }


  /**
   * Assign variable v the lowest-numbered register of its register set that
   * no already-colored neighbor holds; assignedReg[v] becomes -1 (reported
   * as spilled) when none is available.
   **/
  private void assignColor(int v) {
    long avail = regsets[v];
    for (BiLink p = idg.interfereList(regvars[v]).first(); !p.atEnd();
         p = p.next()) {
      int k = liveInfo.regvarIndex((Symbol)p.elem());
      // Neighbors still marked removed are not colored yet; a negative
      // assignedReg means the neighbor holds no register.
      if (!removed[k] && assignedReg[k] >= 0)
        avail &= ~(1L << assignedReg[k]);
    }
    if ((assignedReg[v] = leftmost1(avail)) < 0)
      root.debOut.println("var " + regvars[v].name + " spilled.");
    else
      root.debOut.println("var " + regvars[v].name + " assigned to: "
                          + rule.numberReg(assignedReg[v]).name);
  }
}
