#ifndef __AST_HPP__
#define __AST_HPP__

#include <iostream>
#include <vector>
#include <map>

/// Abstract base of every syntax-tree node.
///
/// Nodes are handled through base pointers (e.g. `Expr *`, `Stmt *`), so the
/// destructor is virtual: deleting a derived node through an `AST *` is
/// otherwise undefined behaviour.
class AST {
 public:
  AST() = default;
  AST(const AST &) = default;
  AST &operator=(const AST &) = default;
  AST(AST &&) = default;
  AST &operator=(AST &&) = default;
  virtual ~AST() = default;

  /// Writes a textual rendering of this node to `out`.
  virtual void printAST(std::ostream &out) const = 0;
  /// Emits assembly for this node on std::cout; the base version emits nothing.
  virtual void cgen() const {}
};

// Stream insertion for any node: forwards to the node's virtual printAST()
// and yields the same stream so insertions can be chained.
inline std::ostream &operator<<(std::ostream &out, const AST &ast) {
  std::ostream &sink = out;
  ast.printAST(sink);
  return sink;
}

// Marker base for statement nodes (Block, Let, For, If, Print); adds nothing to AST.
class Stmt : public AST {};

// Marker base for expression nodes (BinOp, Const, Id); adds nothing to AST.
class Expr : public AST {};

/// Binary arithmetic expression `expr1 op expr2` with op in {+, -, *, /, %}.
///
/// Code generation follows a stack-machine model: both operands are evaluated
/// (each pushes one 32-bit value), popped into registers, combined, and the
/// single result is pushed back.
class BinOp : public Expr {
 public:
  BinOp(Expr *e1, char o, Expr *e2) : expr1(e1), op(o), expr2(e2) {}
  void printAST(std::ostream &out) const override {
    out << op << "(" << *expr1 << ", " << *expr2 << ")";
  }
  void cgen() const override {
    expr1->cgen();
    expr2->cgen();
    std::cout << "  popl %ebx" << std::endl   // right operand (expr2)
              << "  popl %eax" << std::endl;  // left operand (expr1)

    switch (op) {
      case '+':
        std::cout << "  addl %ebx, %eax" << std::endl
                  << "  pushl %eax"      << std::endl;
        break;
      case '-':
        std::cout << "  subl %ebx, %eax" << std::endl
                  << "  pushl %eax"      << std::endl;
        break;
      case '*':
        // Two-operand imull: %eax = low 32 bits of %eax * %ebx (identical to
        // the low half mull produced) without clobbering %edx.
        std::cout << "  imull %ebx, %eax" << std::endl
                  << "  pushl %eax"       << std::endl;
        break;
      case '/':
        // cdq sign-extends %eax into %edx:%eax; idivl is the matching signed
        // divide (quotient -> %eax). Unsigned divl on a sign-extended negative
        // dividend overflows and traps, and mis-divides by negative divisors.
        std::cout << "  cdq"         << std::endl
                  << "  idivl %ebx"  << std::endl
                  << "  pushl %eax"  << std::endl;
        break;
      case '%':
        // Same signed division as '/'; the remainder is left in %edx.
        std::cout << "  cdq"         << std::endl
                  << "  idivl %ebx"  << std::endl
                  << "  pushl %edx"  << std::endl;
        break;
      // NOTE(review): any other op char emits no push, leaving the value stack
      // one slot short; presumably the parser only builds the five ops above.
    }
  }
 private:
  Expr *expr1;
  char op;
  Expr *expr2;
};

// Integer literal; code generation pushes it as an immediate operand.
class Const : public Expr {
 public:
  Const(int n): num(n) {}

  void printAST(std::ostream &out) const override {
    out << "Const(";
    out << num;
    out << ")";
  }

  void cgen() const override {
    std::cout << "  pushl $" << num;
    std::cout << std::endl;
  }

 private:
  int num;
};

/// Variable reference, named by a single character.
///
/// Variable `var` lives in the 4-byte slot at offset 4 * (var - 'a') from %edi.
/// This assumes `var` is a lowercase letter 'a'..'z'; nothing here checks that.
class Id : public Expr {
 public:
  Id(char x): var(x) {}
  void printAST(std::ostream &out) const override {
    out << "Id(" << var << ")";
  }
  /// Returns the variable's name character. Marked const because it does not
  /// modify the node, so it can be called through `const Id *` / `const Id &`.
  char getVar() const { return var; }
  void cgen() const override {
    // Push the variable's current value onto the evaluation stack.
    std::cout << "  pushl " << 4 * (var - 'a') << "(%edi)" << std::endl;
  }
 private:
  char var;
};

class Block : public Stmt {
 public:
  Block() : stmt_list() {}
  void append(Stmt *s) { stmt_list.push_back(s); }
  void printAST(std::ostream &out) const override {
    out << "Block(";
    bool first = true;
    for (const auto &s : stmt_list) {
      if (!first) out << ", ";
      first = false;
      out << *s;
    }
    out << ")";
  }
  void cgen() const override {
    for (Stmt *s : stmt_list) 
      s->cgen();
  }
 private:
  std::vector<Stmt *> stmt_list;
};

// Assignment `var := expr`: evaluates expr onto the stack, then pops it into
// the variable's slot at 4 * (letter - 'a') bytes past %edi.
class Let : public Stmt {
 public:
  Let(Id *lhs, Expr *rhs): var(lhs), expr(rhs) {}

  void printAST(std::ostream &out) const override {
    out << "Let(" << *var << ", " << *expr << ")";
  }

  void cgen() const override {
    expr->cgen();
    const int slot = 4 * (var->getVar() - 'a');
    std::cout << "  popl %eax" << std::endl;
    std::cout << "  movl %eax, " << slot << "(%edi)" << std::endl;
  }

 private:
  Id   *var;
  Expr *expr;
};

// Counted loop: runs `stmt` as many times as `expr` evaluates to.
//
// The remaining count is kept on the stack. Each iteration pops it, exits when
// it is <= 0 (signed jle), otherwise decrements it and pushes it back before
// running the body.
class For : public Stmt {
 public:
  For(Expr *e, Stmt *s): expr(e), stmt(s) {}

  void printAST(std::ostream &out) const override {
    out << "For(" << *expr << ", " << *stmt << ")";
  }

  void cgen() const override {
    static int counter = 0;  // shared label numbering for all For nodes
    expr->cgen();
    const int head = counter++;
    const int exit = counter++;

    std::cout << "Lbegfor" << head << ":" << std::endl;
    std::cout << "  popl %eax" << std::endl;
    std::cout << "  cmpl $0, %eax" << std::endl;
    std::cout << "  jle Lendfor" << exit << std::endl;
    std::cout << "  decl %eax" << std::endl;
    std::cout << "  pushl %eax" << std::endl;

    stmt->cgen();

    std::cout << "  jmp Lbegfor" << head << std::endl;
    std::cout << "Lendfor" << exit << ":" << std::endl;
  }

 private:
  Expr *expr;
  Stmt *stmt;
};

/// Conditional: runs stmt1 when cond is non-zero, otherwise the optional stmt2.
class If : public Stmt {
 public:
  If(Expr *c, Stmt *s1, Stmt *s2 = nullptr) : cond(c), stmt1(s1), stmt2(s2) {}
  void printAST(std::ostream &out) const override {
    out << "If(" << *cond << ", " << *stmt1;
    if (stmt2 != nullptr) out << ", " << *stmt2;
    out << ")";
  }
  void cgen() const override {
    static int counter = 0;  // shared label numbering for all If nodes
    cond->cgen();
    const int lelse = counter++;
    // andl sets ZF exactly when the condition value is zero.
    std::cout << "  popl %eax"         << std::endl
              << "  andl %eax, %eax"   << std::endl
              << "  jz Lelse" << lelse << std::endl;
    stmt1->cgen();
    if (stmt2 == nullptr) {
      // No else branch: falling through to Lelse already reaches the end,
      // so the jump over an empty else block is omitted.
      std::cout << "Lelse" << lelse << ":" << std::endl;
      return;
    }
    const int lendif = counter++;
    std::cout << "  jmp Lendif" << lendif << std::endl
              << "Lelse" << lelse << ":"  << std::endl;
    stmt2->cgen();
    std::cout << "Lendif" << lendif << ":" << std::endl;
  }
 private:
  Expr *cond;
  Stmt *stmt1;
  Stmt *stmt2;
};

/// Prints the value of `expr` followed by a newline via the runtime routines
/// _writeInteger and _writeString.
///
/// NOTE(review): assumes both routines use the cdecl convention (first argument
/// at 0(%esp) at the call, i.e. 4(%esp) on entry); confirm against the runtime.
class Print : public Stmt {
 public:
  Print(Expr *e): expr(e) {}
  void printAST(std::ostream &out) const override {
    out << "Print(" << *expr << ")";
  }
  void cgen() const override {
    // Save %edi (base of the variable area) around the runtime calls, then pad
    // BEFORE pushing the argument so the argument is on top of the stack at
    // the call. Padding after the push left the value at 8(%esp) and garbage
    // where the callee reads its argument. Total adjustment stays 16 bytes
    // (4 saved + 8 pad + 4 argument).
    std::cout << "  pushl %edi"    << std::endl
              << "  subl $8, %esp" << std::endl;
    expr->cgen();  // pushes the integer to print
    std::cout << "  call _writeInteger" << std::endl
              << "  addl $12, %esp"     << std::endl  // drop argument + padding
              << "  subl $8, %esp"      << std::endl
              << "  pushl $NL"          << std::endl  // address of "\n" string
              << "  call _writeString"  << std::endl
              << "  addl $12, %esp"     << std::endl
              << "  popl %edi"          << std::endl;
  }
 private:
  Expr *expr;
};

// Emits the assembly preamble: text section, exported _start entry point, and
// %edi loaded with the address of the `var` data area.
inline void prologue() {
  const char *const preamble[] = {
      ".text",
      ".global _start",
      "",
      "_start:",
      "  movl $var, %edi",
      "",
  };
  for (const char *line : preamble) {
    std::cout << line << std::endl;
  }
}

// Emits the program tail: exit via `int $0x80` with %eax = 1 and %ebx = 0,
// then the data section with 26 zeroed longs at `var` and the newline string `NL`.
inline void epilogue() {
  const char *const trailer[] = {
      "",
      "  movl $1, %eax",
      "  movl $0, %ebx",
      "  int $0x80",
      "",
      ".data",
      "var:",
      ".rept 26",
      ".long 0",
      ".endr",
      "NL:",
      ".asciz \"\\n\"",
  };
  for (const char *line : trailer) {
    std::cout << line << std::endl;
  }
}

#endif
