package coins.backend.opt;

import java.io.*;
import java.util.*;
import coins.backend.*;
import coins.backend.util.*;
import coins.backend.sym.*;
import coins.backend.lir.*;
import coins.backend.cfg.*;
import coins.backend.ana.*;

/**
 * Constant propagation using the Sparse Simple Constants algorithm.
 *
 * <p>The transformation runs in two phases:
 * <ol>
 * <li>Analysis: an optimistic fixed-point iteration over the basic blocks
 *     in reverse postorder computes a lattice value for every register
 *     defined by a {@code SET} to a {@code REG} or by a {@code PHI}.  The
 *     code is not modified during this phase.</li>
 * <li>Rewrite: once the fixed point is reached, register uses whose value
 *     is known are replaced by that value and constant subexpressions are
 *     folded.</li>
 * </ol>
 * Rewriting has to wait for the fixed point: an optimistic value (for
 * example a loop variable whose back-edge definition has not been visited
 * yet) may later drop to "non-constant", and code already rewritten with
 * the optimistic value would then be wrong.
 */
public class Constants implements LocalTransform {
  // Not read by this class.
  private PrintWriter output;

  private Function function;

  /** Lattice value of each local symbol, indexed by symbol id. */
  private LirNode[] values;
  /** Sentinel node representing the "non-constant" (bottom) element. */
  private LirNode nonconst;

  // Lattice elements stored in values[]:
  // null      undef (top): not yet decided
  // a node    the register always holds this value; one of the leaf nodes
  //             INTCONST, FLOATCONST, STATIC, FRAME
  // nonconst  bottom: the value may change from time to time

  public Constants(PrintWriter out) {
    output = out;
  }

  /**
   * Propagates and folds constants in {@code f}.
   *
   * @param f the function to transform; its instructions are rewritten in
   *     place.
   */
  public void doIt(Function f) {
    function = f;
    FlowGraph flowGraph = f.flowGraph;
    BasicBlk[] blks = flowGraph.blkVectorByRPost();
    int nblks = flowGraph.maxDfn();
    nonconst = function.newLir.operator(Op.UNDEFINED, Type.UNKNOWN, new LirNode[0]);

    initLattice(blks, nblks);

    // Phase 1: iterate to a fixed point; the code is left untouched.
    boolean changed = true;
    while (changed) {
      changed = false;
      // Process basic blocks in reverse postorder (blks[1..nblks]).
      for (int j = 1; j <= nblks; j++) {
        for (BiLink p = blks[j].instrList().first(); !p.atEnd(); p = p.next()) {
          if (evaluate((LirNode)p.elem()))
            changed = true;
        }
      }
    }

    // Phase 2: rewrite the code using the final lattice values.
    for (int j = 1; j <= nblks; j++) {
      for (BiLink p = blks[j].instrList().first(); !p.atEnd(); p = p.next())
        rewrite((LirNode)p.elem());
    }
  }

  /**
   * Builds the initial lattice.  A register with exactly one definition,
   * where that definition is a top-level {@code SET} to a {@code REG} or a
   * {@code PHI}, starts at undef and is refined by the iteration.  Every
   * other symbol starts at (and stays) nonconst.  This covers registers
   * defined only by other instructions (presumably parameters and call
   * results), registers defined more than once, and registers with no
   * definition at all.  The iteration never computes a value for them, and
   * treating them as undef would let their uses be folded as if they were
   * constants.
   */
  private void initLattice(BasicBlk[] blks, int nblks) {
    int symBound = function.localSymtab.idBound();
    int[] defs = new int[symBound];
    for (int j = 1; j <= nblks; j++) {
      for (BiLink p = blks[j].instrList().first(); !p.atEnd(); p = p.next()) {
        LirNode ins = (LirNode)p.elem();
        if ((ins.opCode == Op.SET || ins.opCode == Op.PHI)
            && ins.src(0).opCode == Op.REG)
          defs[regId(ins.src(0))]++;
      }
    }
    values = new LirNode[symBound];
    for (int i = 0; i < symBound; i++)
      values[i] = (defs[i] == 1) ? null : nonconst;
  }

  /**
   * Re-evaluates one instruction against the current lattice without
   * modifying the instruction.
   *
   * @return true if the lattice value of the register it defines changed.
   */
  private boolean evaluate(LirNode ins) {
    LirNode val;
    if (ins.opCode == Op.SET && ins.src(0).opCode == Op.REG)
      val = symbolicExecution(ins.src(1), false);
    else if (ins.opCode == Op.PHI)
      val = meetPhi(ins);
    else
      return false;  // defines no register tracked by the lattice

    int id = regId(ins.src(0));
    LirNode old = values[id];
    // nonconst is the bottom element: once there, a register stays there.
    // This also keeps untracked (multiply defined) registers at nonconst.
    if (old == nonconst || sameValue(old, val))
      return false;
    values[id] = val;
    return true;
  }

  /**
   * Meet operation over the operands of a PHI.
   * <ul>
   * <li>nonconst appeared: nonconst</li>
   * <li>different values appeared: nonconst</li>
   * <li>all undef: undef (null)</li>
   * <li>otherwise: the (common) value</li>
   * </ul>
   * UNDEFINED operands are ignored.
   *
   * @throws CantHappenException if an operand is neither REG nor UNDEFINED.
   */
  private LirNode meetPhi(LirNode phi) {
    LirNode result = null;
    int n = phi.nSrcs();
    for (int i = 1; i < n; i++) {
      LirNode opnd = phi.src(i);
      if (opnd.opCode == Op.UNDEFINED)
        continue;
      if (opnd.opCode != Op.REG)
        throw new CantHappenException("Bad PHI form");
      LirNode val = values[regId(opnd)];
      if (val == null)
        continue;
      if (val == nonconst)
        return nonconst;
      if (result == null)
        result = val;
      else if (!sameValue(val, result))
        return nonconst;
    }
    return result;
  }

  /**
   * Rewrites one instruction with the final lattice values.  The source of
   * a SET to a REG is replaced by its value when that value is constant.
   * For every other instruction except PHI, constant operands are
   * substituted and folded.  PHIs are left unchanged.
   */
  private void rewrite(LirNode ins) {
    if (ins.opCode == Op.SET && ins.src(0).opCode == Op.REG) {
      LirNode val = symbolicExecution(ins.src(1), true);
      if (isConstant(val) && val != ins.src(1))
        ins.setSrc(1, val);
    } else if (ins.opCode != Op.PHI) {
      symbolicExecution(ins, true);
    }
  }

  /**
   * Computes the lattice value of expression {@code ins}.
   *
   * @param rewrite when true, each operand whose value is a constant is
   *     replaced in place by that value.  The replacement may be the very
   *     leaf node held in values[], so leaf nodes can end up shared.
   * @return null (undef), {@code nonconst}, or a constant leaf node.
   */
  private LirNode symbolicExecution(LirNode ins, boolean rewrite) {
    int n = ins.nSrcs();
    LirNode[] opr = new LirNode[n];
    // Operands of SUBREG are not evaluated or rewritten (its value is
    // nonconst anyway); NOTE(review): presumably because a SUBREG may
    // appear as a store destination — confirm.
    if (ins.opCode != Op.SUBREG) {
      for (int i = 0; i < n; i++) {
        LirNode val = symbolicExecution(ins.src(i), rewrite);
        opr[i] = val;
        if (rewrite && isConstant(val) && val != ins.src(i))
          ins.setSrc(i, val);
      }
    }

    switch (ins.opCode) {
    case Op.REG:
      return values[regId(ins)];

    case Op.INTCONST:
    case Op.FLOATCONST:
    case Op.STATIC:
    case Op.FRAME:
      return ins;

    case Op.SUBREG:
    case Op.MEM:
    case Op.JUMP:
    case Op.JUMPC:
    case Op.JUMPN:
    case Op.CALL:
    case Op.SET:
    case Op.PROLOGUE:
    case Op.EPILOGUE:
    case Op.USE:
    case Op.CLOBBER:
    case Op.PARALLEL:
      return nonconst;
    }

    switch (n) {
    case 1:
      return foldUnary(ins, opr[0]);
    case 2:
      return foldBinary(ins, opr[0], opr[1]);
    default:
      return nonconst;
    }
  }

  /** Folds a unary operator applied to the operand value {@code opnd}. */
  private LirNode foldUnary(LirNode ins, LirNode opnd) {
    if (opnd == null)
      return null;  // undef operand: the result is still undecided
    if (opnd.opCode == Op.INTCONST) {
      long ivalue = ((LirIconst)opnd).value;
      switch (ins.opCode) {
      case Op.NEG:
        return function.newLir.iconst(ins.type, -ivalue);
      case Op.BNOT:
        return function.newLir.iconst(ins.type, ~ivalue);
      default:
        // Conversions (CONVSX, CONVZX, CONVIT, CONVSF, CONVUF) are not
        // folded yet.
        return nonconst;
      }
    }
    // Conversions of FLOATCONST (CONVFX, CONVFT, CONVFI) are not folded
    // yet; any other operand is not a foldable constant.
    return nonconst;
  }

  /** Folds a binary operator applied to operand values {@code a}, {@code b}. */
  private LirNode foldBinary(LirNode ins, LirNode a, LirNode b) {
    // An undef operand keeps the result undef as long as the other operand
    // is undef or a numeric constant; otherwise the result is nonconst.
    if (a == null)
      return (b == null || isNumeric(b)) ? null : nonconst;
    if (b == null)
      return isNumeric(a) ? null : nonconst;

    if (a.opCode == Op.INTCONST && b.opCode == Op.INTCONST)
      return foldIntBinary(ins, ((LirIconst)a).value, ((LirIconst)b).value);
    if (a.opCode == Op.FLOATCONST && b.opCode == Op.FLOATCONST)
      return foldFloatBinary(ins, ((LirFconst)a).value, ((LirFconst)b).value);
    return nonconst;
  }

  /**
   * Folds an integer binary operator.  Returns nonconst for any case whose
   * result Java cannot compute faithfully: division by zero, unsigned
   * division of 64-bit values with the top bit set, or a shift count
   * outside [0, width).
   */
  private LirNode foldIntBinary(LirNode ins, long ival0, long ival1) {
    // Width of the operands.  For arithmetic this equals the result width;
    // the TSTxx operators produce a result whose type may differ from
    // their operands.
    int bits = Type.bits(ins.src(0).type);
    // Long shift: an int shift by 32 or 64 wraps around in Java and would
    // make the mask 0 for 32- and 64-bit types.
    long mask = (bits >= 64) ? -1L : (1L << bits) - 1;
    long u0 = ival0 & mask;  // operands reinterpreted as unsigned
    long u1 = ival1 & mask;
    // Flipping the sign bit maps unsigned order onto signed order, which
    // keeps the TSTxxU cases correct for 64-bit operands as well.
    long f0 = u0 ^ Long.MIN_VALUE;
    long f1 = u1 ^ Long.MIN_VALUE;
    // Java reduces shift counts modulo 64; do not fold such shifts.
    boolean badShift = ival1 < 0 || ival1 >= bits;
    long ivalue;

    switch (ins.opCode) {
    default:
      return nonconst;

    case Op.ADD: ivalue = ival0 + ival1; break;
    case Op.SUB: ivalue = ival0 - ival1; break;
    case Op.MUL: ivalue = ival0 * ival1; break;
    case Op.DIVS:
      if (ival1 == 0)
        return nonconst;  // would throw ArithmeticException here
      ivalue = ival0 / ival1;
      break;
    case Op.DIVU:
      if (u1 == 0 || u0 < 0 || u1 < 0)
        return nonconst;
      ivalue = u0 / u1;
      break;
    case Op.MODS:
      if (ival1 == 0)
        return nonconst;
      ivalue = ival0 % ival1;
      break;
    case Op.MODU:
      if (u1 == 0 || u0 < 0 || u1 < 0)
        return nonconst;
      ivalue = u0 % u1;
      break;
    case Op.BAND: ivalue = ival0 & ival1; break;
    case Op.BOR: ivalue = ival0 | ival1; break;
    case Op.BXOR: ivalue = ival0 ^ ival1; break;
    case Op.LSHS:
    case Op.LSHU:
      if (badShift)
        return nonconst;
      ivalue = ival0 << ival1;
      break;
    case Op.RSHS:
      if (badShift)
        return nonconst;
      ivalue = ival0 >> ival1;
      break;
    case Op.RSHU:
      if (badShift)
        return nonconst;
      ivalue = u0 >>> ival1;
      break;
    case Op.TSTEQ: ivalue = (ival0 == ival1) ? 1 : 0; break;
    case Op.TSTNE: ivalue = (ival0 != ival1) ? 1 : 0; break;
    case Op.TSTLTS: ivalue = (ival0 < ival1) ? 1 : 0; break;
    case Op.TSTLES: ivalue = (ival0 <= ival1) ? 1 : 0; break;
    case Op.TSTGTS: ivalue = (ival0 > ival1) ? 1 : 0; break;
    case Op.TSTGES: ivalue = (ival0 >= ival1) ? 1 : 0; break;
    case Op.TSTLTU: ivalue = (f0 < f1) ? 1 : 0; break;
    case Op.TSTLEU: ivalue = (f0 <= f1) ? 1 : 0; break;
    case Op.TSTGTU: ivalue = (f0 > f1) ? 1 : 0; break;
    case Op.TSTGEU: ivalue = (f0 >= f1) ? 1 : 0; break;
    }

    // NOTE(review): results are not truncated or sign-extended to the
    // result width here; presumably iconst() normalizes them — confirm.
    return function.newLir.iconst(ins.type, ivalue);
  }

  /** Folds a floating-point binary operator (arithmetic or comparison). */
  private LirNode foldFloatBinary(LirNode ins, double dval0, double dval1) {
    switch (ins.opCode) {
    case Op.ADD: return function.newLir.fconst(ins.type, dval0 + dval1);
    case Op.SUB: return function.newLir.fconst(ins.type, dval0 - dval1);
    case Op.MUL: return function.newLir.fconst(ins.type, dval0 * dval1);
    case Op.DIVS: return function.newLir.fconst(ins.type, dval0 / dval1);

    case Op.TSTEQ: return function.newLir.iconst(ins.type, (dval0 == dval1) ? 1 : 0);
    case Op.TSTNE: return function.newLir.iconst(ins.type, (dval0 != dval1) ? 1 : 0);
    case Op.TSTLTS: return function.newLir.iconst(ins.type, (dval0 < dval1) ? 1 : 0);
    case Op.TSTLES: return function.newLir.iconst(ins.type, (dval0 <= dval1) ? 1 : 0);
    case Op.TSTGTS: return function.newLir.iconst(ins.type, (dval0 > dval1) ? 1 : 0);
    case Op.TSTGES: return function.newLir.iconst(ins.type, (dval0 >= dval1) ? 1 : 0);

    default:
      return nonconst;
    }
  }

  /**
   * Lattice equality.  Folded constants are fresh nodes on every pass, so
   * reference identity alone would never reach a fixed point.  INTCONST
   * and FLOATCONST are compared by value; FLOATCONST is compared by bit
   * pattern so that NaN equals itself and 0.0 differs from -0.0.
   */
  private boolean sameValue(LirNode a, LirNode b) {
    if (a == b)
      return true;
    if (a == null || b == null || a == nonconst || b == nonconst)
      return false;
    if (a.opCode != b.opCode || a.type != b.type)
      return false;
    if (a.opCode == Op.INTCONST)
      return ((LirIconst)a).value == ((LirIconst)b).value;
    if (a.opCode == Op.FLOATCONST)
      return Double.doubleToLongBits(((LirFconst)a).value)
          == Double.doubleToLongBits(((LirFconst)b).value);
    return a.equals(b);
  }

  /** True if {@code v} is a known value (neither undef nor nonconst). */
  private boolean isConstant(LirNode v) {
    return v != null && v != nonconst;
  }

  /** True if {@code v} is an INTCONST or FLOATCONST node. */
  private static boolean isNumeric(LirNode v) {
    return v.opCode == Op.INTCONST || v.opCode == Op.FLOATCONST;
  }

  /** Returns the symbol id of the REG node {@code reg}. */
  private static int regId(LirNode reg) {
    return ((LirSymRef)reg).symbol.id;
  }
}
