#ifndef __AST_HPP__
#define __AST_HPP__

#include <cstdlib>
#include <iostream>
#include <map>
#include <vector>

#include "symbol.hpp"

extern std::vector<int> rt_stack;

// Root of the AST hierarchy. Every node can print itself (printAST) and
// may perform semantic analysis (sem_analyze, a no-op by default).
class AST {
 public:
  // Virtual so that deleting a node through a base-class pointer (e.g. a
  // Stmt* owned by a Block) runs the concrete node's destructor instead of
  // invoking undefined behaviour.
  virtual ~AST() = default;
  // Writes a textual representation of this node to `out`.
  virtual void printAST(std::ostream &out) const = 0;
  // Type-checks / resolves names for this node; default does nothing.
  virtual void sem_analyze() {}
};

inline std::ostream &operator<<(std::ostream &out, const AST &ast) {
  ast.printAST(out);
  return out;
}

// Stream insertion for Type: prints the source-level keyword ("int" or
// "bool"). Any other Type value prints nothing.
inline std::ostream &operator<<(std::ostream &os, Type ty) {
  if (ty == TYPE_int) {
    os << "int";
  } else if (ty == TYPE_bool) {
    os << "bool";
  }
  return os;
}

class Decl : public AST {
 public:
  Decl(char v, Type t) : var(v), type(t) {}
  void printAST(std::ostream &out) const override {
    out << "Decl(" << var << ":" << type << ")";
  }
  void sem_analyze() {
    st.insert(var, type);
  }
 private:
  char var;
  Type type;
};

// Base class for executable statements.
class Stmt : public AST {
 public:
  // Performs this statement's runtime effect.
  virtual void execute() const = 0;
};

// Base class for expressions. Subclasses compute `type` in sem_analyze()
// and produce a runtime value via evaluate() (booleans are returned as ints).
class Expr : public AST {
 public:
  // Analyzes this expression, then reports "Type mismatch" through yyerror
  // when the inferred type is not `expected_type`.
  void check_type(Type expected_type) {
    sem_analyze();
    if (type == expected_type) return;
    yyerror("Type mismatch");
  }
  virtual int evaluate() const = 0;
 protected:
  Type type;  // filled in by the subclass's sem_analyze()
};

class Block : public Stmt {
 public:
  Block() : decl_list(), stmt_list() {}
  ~Block() {
    for (Decl *d : decl_list) delete d;
    for (Stmt *s : stmt_list) delete s;
  }
  void append_decl(Decl *d) { decl_list.push_back(d); }
  void append_stmt(Stmt *s) { stmt_list.push_back(s); }
  void merge(Block *b) {
    stmt_list = b->stmt_list;
    b -> stmt_list.clear();
    delete b;
  }
  void printAST(std::ostream &out) const override {
    out << "Block(";
    bool first = true;
    for (const auto &d : decl_list) {
      if (!first) out << ", ";
      first = false;
      out << *d;
    }
    for (const auto &s : stmt_list) {
      if (!first) out << ", ";
      first = false;
      out << *s;
    }
    out << ")";
  }
  void sem_analyze() override {
    st.enterScope();
    for (Decl *d : decl_list) d->sem_analyze();
    for (Stmt *s : stmt_list) s->sem_analyze();
    st.exitScope();
  }
  void execute() const override {
    for (Decl *d : decl_list) rt_stack.push_back(0);
    for (Stmt *s : stmt_list) s->execute();
    for (Decl *d : decl_list) rt_stack.pop_back();
  }
 private:
  std::vector<Decl *> decl_list;
  std::vector<Stmt *> stmt_list;
};

// Binary operator node. Both operands must be int.
// Arithmetic operators (+ - * / %) yield int; comparisons (< = >) yield bool
// (evaluated as 0/1).
class BinOp : public Expr {
 public:
  BinOp(Expr *e1, char o, Expr *e2) : expr1(e1), op(o), expr2(e2) {}
  void printAST(std::ostream &out) const override {
    out << op << "(" << *expr1 << ", " << *expr2 << ")";
  }
  void sem_analyze() override {
    expr1->check_type(TYPE_int);
    expr2->check_type(TYPE_int);
    switch (op) {
      case '+': case '-': case '*': case '/': case '%':
        type = TYPE_int;
        break;
      case '<': case '=': case '>':
        type = TYPE_bool;
        break;
    }
  }
  int evaluate() const override {
    switch (op) {
      case '+': return expr1->evaluate() + expr2->evaluate();
      case '-': return expr1->evaluate() - expr2->evaluate();
      case '*': return expr1->evaluate() * expr2->evaluate();
      case '/': case '%': {
        const int lhs = expr1->evaluate();
        const int rhs = expr2->evaluate();
        // Integer division/modulo by zero is undefined behaviour in C++
        // (usually a SIGFPE crash); report it as a runtime error instead.
        if (rhs == 0) {
          std::cerr << "Runtime error: division by zero" << std::endl;
          std::exit(1);
        }
        return op == '/' ? lhs / rhs : lhs % rhs;
      }
      case '<': return expr1->evaluate() < expr2->evaluate();
      case '=': return expr1->evaluate() == expr2->evaluate();
      case '>': return expr1->evaluate() > expr2->evaluate();
      default: std::cerr << "Case not handled" << std::endl; return 42;
    }
  }
 private:
  Expr *expr1;
  char op;
  Expr *expr2;
};

// Integer literal; always of type int and evaluates to its stored value.
class IntConst : public Expr {
 public:
  IntConst(int v) : value(v) {}
  void printAST(std::ostream &out) const override {
    out << "IntConst(" << value << ")";
  }
  void sem_analyze() override { type = TYPE_int; }
  int evaluate() const override { return value; }
 private:
  int value;
};

// Boolean literal stored as an int (non-zero means true); always of type bool.
class BoolConst : public Expr {
 public:
  BoolConst(int b) : flag(b) {}
  void printAST(std::ostream &out) const override {
    const char *text = flag ? "true" : "false";
    out << "BoolConst(" << text << ")";
  }
  void sem_analyze() override { type = TYPE_bool; }
  int evaluate() const override { return flag; }
 private:
  int flag;
};

// Variable reference. sem_analyze() resolves the name through the symbol
// table, recording its type and runtime-stack offset; evaluate() then reads
// rt_stack at that offset (so sem_analyze must run before evaluate).
class Id : public Expr {
 public:
  Id(char x) : var(x) {}
  void printAST(std::ostream &out) const override {
    out << "Id(" << var << ")";
  }
  void sem_analyze() override {
    // NOTE(review): assumes st.lookup() never returns nullptr (e.g. that it
    // reports undeclared names itself) — confirm against symbol.hpp.
    STEntry *e = st.lookup(var);
    type = e->type;
    offset = e->offset;
  }
  int evaluate() const override {
    return rt_stack[offset];
  }
  // Returns the variable's name. Const-qualified since it only reads state,
  // so it is also callable on const Id objects.
  char getVar() const { return var; }
 private:
  char var;
  int offset;  // set by sem_analyze()
};

class Let : public Stmt {
 public:
  Let(Id *lhs, Expr *rhs) : var(lhs), expr(rhs) {}
  void printAST(std::ostream &out) const override {
    out << "Let(" << *var << ", " << *expr << ")";
  }
  void sem_analyze() override {
    STEntry *lhs = st.lookup(var->getVar());
    expr->check_type(lhs->type);
    offset = lhs->offset;
  }
  void execute() const override {
    rt_stack[offset] = expr->evaluate();
  }
 private:
  Id   *var;
  Expr *expr;
  int offset;
};

// Counted loop: evaluates `expr` once, then runs `stmt` that many times
// (zero times if the count is not positive).
class For : public Stmt {
 public:
  For(Expr *e, Stmt *s) : expr(e), stmt(s) {}
  void printAST(std::ostream &out) const override {
    out << "For(" << *expr << ", " << *stmt << ")";
  }
  void sem_analyze() override {
    expr->check_type(TYPE_int);
    stmt->sem_analyze();
  }
  void execute() const override {
    for (int remaining = expr->evaluate(); remaining > 0; --remaining) {
      stmt->execute();
    }
  }
 private:
  Expr *expr;
  Stmt *stmt;
};

// Conditional statement with an optional else branch (`stmt2` may be null).
class If : public Stmt {
 public:
  If(Expr *c, Stmt *s1, Stmt *s2 = nullptr) : cond(c), stmt1(s1), stmt2(s2) {}
  void printAST(std::ostream &out) const override {
    out << "If(" << *cond << ", " << *stmt1;
    if (stmt2) out << ", " << *stmt2;
    out << ")";
  }
  void sem_analyze() override {
    cond->check_type(TYPE_bool);
    stmt1->sem_analyze();
    if (!stmt2) return;
    stmt2->sem_analyze();
  }
  void execute() const override {
    if (cond->evaluate()) {
      stmt1->execute();
      return;
    }
    if (stmt2) stmt2->execute();
  }
 private:
  Expr *cond;
  Stmt *stmt1;
  Stmt *stmt2;
};

// Print statement: writes the value of an int expression to stdout,
// followed by a newline and a flush.
class Print : public Stmt {
 public:
  Print(Expr *e) : expr(e) {}
  void printAST(std::ostream &out) const override {
    out << "Print(" << *expr << ")";
  }
  void sem_analyze() override { expr->check_type(TYPE_int); }
  void execute() const override {
    const int value = expr->evaluate();
    std::cout << value << std::endl;
  }
 private:
  Expr *expr;
};

#endif
