#ifndef __AST_HPP__
#define __AST_HPP__

#include <llvm/Pass.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Utils.h>
#include "llvm-c/Core.h"
#include "llvm-c/TargetMachine.h"

#include <cstdlib>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include "symbol.hpp"

using namespace llvm;

class AST {
 public:
  virtual ~AST() = default;
  virtual void sem_analyze() {}
  virtual void compile() const {}
  virtual void printAST(std::ostream &out) const = 0;
  virtual Value* LLVM_IR_generate() const { return nullptr; }

  void compile_to_LLVM_IR(bool optimize = true) {
    // Initialize.
    TheModule = std::make_unique<Module>("minibasic program", TheContext);
    // Set the Target triple.
    TheModule->setTargetTriple(LLVMGetDefaultTargetTriple());
    TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
    if (optimize) {
      TheFPM->add(createPromoteMemoryToRegisterPass());
      TheFPM->add(createInstructionCombiningPass());
      TheFPM->add(createReassociatePass());
      TheFPM->add(createGVNPass());
      TheFPM->add(createCFGSimplificationPass());
    }
    TheFPM->doInitialization();

    // Initialize types.
    i8  = IntegerType::get(TheContext, 8);
    i32 = IntegerType::get(TheContext, 32);
    i64 = IntegerType::get(TheContext, 64);

    vars_type = ArrayType::get(i32, 26);
    nl_type = ArrayType::get(i8, 2);

    // Initialize global variables.
    TheVars = new GlobalVariable(
      *TheModule, vars_type, false, GlobalValue::PrivateLinkage,
      ConstantAggregateZero::get(vars_type), "vars");
    TheVars->setAlignment(MaybeAlign(16));
    TheNL = new GlobalVariable(
      *TheModule, nl_type, true, GlobalValue::PrivateLinkage,
      ConstantArray::get(nl_type, {c8('\n'), c8('\0')}), "nl");
    TheNL->setAlignment(MaybeAlign(1));

    // Initialize library functions.
    FunctionType *writeInteger_type =
      FunctionType::get(llvm::Type::getVoidTy(TheContext), {i64}, false);
    TheWriteInteger =
      Function::Create(writeInteger_type, Function::ExternalLinkage,
                       "writeInteger", TheModule.get());
    FunctionType *writeString_type =
      FunctionType::get(llvm::Type::getVoidTy(TheContext),
                        {PointerType::get(i8, 0)}, false);
    TheWriteString =
      Function::Create(writeString_type, Function::ExternalLinkage,
                       "writeString", TheModule.get());

    // Define and start the main function.
    FunctionType *main_type = FunctionType::get(i32, {}, false);
    Function *main =
      Function::Create(main_type, Function::ExternalLinkage,
                       "main", TheModule.get());
    BasicBlock *BB = BasicBlock::Create(TheContext, "entry", main);
    Builder.SetInsertPoint(BB);

    // Emit the program code
    LLVM_IR_generate();
    Builder.CreateRet(c32(0));

    // Verify the IR.
    bool bad = verifyModule(*TheModule, &errs());
    if (bad) {
      std::cerr << "The IR is bad!" << std::endl;
      TheModule->print(errs(), nullptr);
      std::exit(1);
    }

    // Optimize the IR!
    TheFPM->run(*main);

    // Print out the LLVM IR
    TheModule->print(outs(), nullptr);
  }

 protected:
  static LLVMContext TheContext;
  static IRBuilder<> Builder;
  static std::unique_ptr<Module> TheModule;
  static std::unique_ptr<legacy::FunctionPassManager> TheFPM;

  static GlobalVariable *TheVars;
  static GlobalVariable *TheNL;
  static Function *TheWriteInteger;
  static Function *TheWriteString;

  static llvm::Type *i8;
  static llvm::Type *i32;
  static llvm::Type *i64;

  static ArrayType *vars_type;
  static ArrayType *nl_type;

  static ConstantInt* c8(char c) {
    return ConstantInt::get(TheContext, APInt(8, c, true));
  }
  static ConstantInt* c32(int n) {
    return ConstantInt::get(TheContext, APInt(32, n, true));
  }

  void UNREACHABLE() const {
    std::cerr << "Unreachable was reached!" << std::endl;
    std::exit(1);
  }
};

// Stream insertion for any AST node: delegates to the node's printAST and
// hands back the same stream so insertions can be chained.
inline std::ostream &operator<<(std::ostream &out, const AST &ast) {
  std::ostream &sink = out;
  ast.printAST(sink);
  return sink;
}

// Writes the source-level spelling of a MiniBasic type ("int" or "bool").
// Any other enumerator value writes nothing.
inline std::ostream &operator<<(std::ostream &out, MBType t) {
  const char *spelling = nullptr;
  switch (t) {
    case TYPE_int:  spelling = "int";  break;
    case TYPE_bool: spelling = "bool"; break;
  }
  if (spelling != nullptr) out << spelling;
  return out;
}

// Base class for expressions. Besides the AST back ends, an expression can
// be evaluated directly by the interpreter.
class Expr : public AST {
 public:
  // Runs semantic analysis on this expression, then reports on stderr when
  // the inferred type differs from `expected_type`. The mismatch is only
  // diagnosed; control returns to the caller either way.
  void check_type(MBType expected_type) {
    sem_analyze();
    if (type == expected_type) return;
    std::cerr << "Type mismatch" << std::endl;
  }

  // Interpreter entry point: computes the expression's value as an int
  // (comparisons and boolean constants also produce ints).
  virtual int evaluate() const = 0;

 protected:
  // Type inferred by the subclass's sem_analyze().
  MBType type;
};

// Base class for statements. Besides the AST back ends, every statement can
// be run directly by the interpreter.
class Stmt : public AST {
 public:
  // Interpreter entry point: performs the statement's effect.
  virtual void execute() const = 0;
};

// Runtime slot stack used by the tree-walking interpreter (execute/evaluate).
// Decl::allocate/deallocate push and pop slots; Id and Let index into it with
// the offset obtained from the symbol table during semantic analysis.
// Only declared here; the definition lives outside this header.
extern std::vector<int> rt_stack;

// Declaration of a single-letter variable together with its static type.
class Decl : public AST {
 public:
  // Copies the pointed-to type out of `t`; `t` must be non-null.
  Decl(char v, std::unique_ptr<MBType> t) : var(v), type(*t) {}

  // Registers the variable and its type in the current symbol-table scope.
  void sem_analyze() override { st.insert(var, type); }

  // Interpreter: reserve a runtime slot for the variable, starting at 0.
  void allocate() { rt_stack.emplace_back(0); }

  // Interpreter: release the most recently reserved slot.
  void deallocate() { rt_stack.pop_back(); }

  // Prints as Decl(<name>:<type>).
  void printAST(std::ostream &out) const override {
    out << "Decl(" << var;
    out << ":" << type << ")";
  }

 private:
  char var;     // variable name, a single letter
  MBType type;  // declared type
};

// Binary operation `expr1 op expr2` on int operands.
// Arithmetic operators (+ - * / %) produce an int; comparisons (= < >)
// produce a bool, represented as the int 0 or 1 in every back end.
class BinOp : public Expr {
 public:
  BinOp(std::unique_ptr<Expr> e1, char o, std::unique_ptr<Expr> e2)
    : expr1(std::move(e1)), op(o), expr2(std::move(e2)) {}

  // Both operands must be int; the result type depends on the operator.
  void sem_analyze() override {
    expr1->check_type(TYPE_int);
    expr2->check_type(TYPE_int);
    switch (op) {
      case '+': case '-': case '*': case '/': case '%':
        type = TYPE_int;
        break;
      case '=': case '<': case '>':
        type = TYPE_bool;
        break;
    }
  }

  // Interprets the operation with signed int semantics. Division or
  // remainder by zero is reported and terminates the program instead of
  // invoking undefined behaviour.
  int evaluate() const override {
    const int l = expr1->evaluate();
    const int r = expr2->evaluate();
    switch (op) {
      case '+': return l + r;
      case '-': return l - r;
      case '*': return l * r;
      case '/': case '%':
        if (r == 0) {
          std::cerr << "Division by zero" << std::endl;
          std::exit(1);
        }
        return op == '/' ? l / r : l % r;
      case '=': return l == r;
      case '<': return l < r;
      case '>': return l > r;
    }
    UNREACHABLE(); return 42;
  }

  // Emits stack-machine x86 code: both operands are pushed, then popped
  // into %eax (expr1) and %ebx (expr2), and the result is pushed back.
  void compile() const override {
    expr1->compile();
    expr2->compile();
    std::cout << "  popl %ebx" << std::endl   // expr2
              << "  popl %eax" << std::endl;  // expr1
    switch (op) {
      case '+':
        std::cout << "  addl %ebx, %eax" << std::endl
                  << "  pushl %eax"      << std::endl;
        break;
      case '-':
        std::cout << "  subl %ebx, %eax" << std::endl
                  << "  pushl %eax"      << std::endl;
        break;
      case '*':
        // The low 32 bits of the product are the same signed or unsigned.
        std::cout << "  mull %ebx"       << std::endl
                  << "  pushl %eax"      << std::endl;
        break;
      case '/':
        // Signed division: cltd sign-extends %eax into %edx:%eax for idivl.
        std::cout << "  cltd"            << std::endl
                  << "  idivl %ebx"      << std::endl
                  << "  pushl %eax"      << std::endl;
        break;
      case '%':
        // Same as '/', but the remainder is left in %edx.
        std::cout << "  cltd"            << std::endl
                  << "  idivl %ebx"      << std::endl
                  << "  pushl %edx"      << std::endl;
        break;
      case '=': emit_compare("sete"); break;
      case '<': emit_compare("setl"); break;
      case '>': emit_compare("setg"); break;
      default: UNREACHABLE();
    }
  }

  // Emits IR for the operation and returns an i32 value.
  Value* LLVM_IR_generate() const override {
    Value* l = expr1->LLVM_IR_generate();
    Value* r = expr2->LLVM_IR_generate();
    switch (op) {
      case '+': return Builder.CreateAdd(l, r, "addtmp");
      case '-': return Builder.CreateSub(l, r, "subtmp");
      case '*': return Builder.CreateMul(l, r, "multmp");
      case '/': return Builder.CreateSDiv(l, r, "divtmp");
      case '%': return Builder.CreateSRem(l, r, "remtmp");
      // icmp yields i1. Widen it to i32 so booleans share the representation
      // that If tests (icmp ne i32 %v, 0) and that variables store.
      case '=': return to_i32(Builder.CreateICmpEQ(l, r, "eqtmp"));
      case '<': return to_i32(Builder.CreateICmpSLT(l, r, "lttmp"));
      case '>': return to_i32(Builder.CreateICmpSGT(l, r, "gttmp"));
    }
    UNREACHABLE(); return nullptr;
  }

  void printAST(std::ostream &out) const override {
    out << op << "(" << *expr1 << ", " << *expr2 << ")";
  }

 private:
  // Compares %eax (expr1) with %ebx (expr2) as signed ints and pushes 1 if
  // the condition tested by the given setcc instruction holds, else 0.
  static void emit_compare(const char *setcc) {
    std::cout << "  cmpl %ebx, %eax"        << std::endl
              << "  " << setcc << " %al"    << std::endl
              << "  movzbl %al, %eax"       << std::endl
              << "  pushl %eax"             << std::endl;
  }

  // Zero-extends an i1 comparison result to an i32 holding 0 or 1.
  static Value* to_i32(Value *flag) {
    return Builder.CreateZExt(flag, i32, "booltmp");
  }

  std::unique_ptr<Expr> expr1;
  char op;
  std::unique_ptr<Expr> expr2;
};

class BoolConst : public Expr {
 public:
  BoolConst(int b) : bv(b) {}

  void sem_analyze() override {
    type = TYPE_bool;
  }

  int evaluate() const override {
    return bv;
  }

  void compile() const override {
    std::cout << "  pushl $" << bv << std::endl;
  }

  void printAST(std::ostream &out) const override {
    out << "BoolConst(" << bv << ")";
  }

 private:
  int bv;
};

// Integer literal.
class IntConst : public Expr {
 public:
  IntConst(int n) : num(n) {}

  // Literals are always int-typed.
  void sem_analyze() override { type = TYPE_int; }

  int evaluate() const override { return num; }

  // Pushes the literal as an immediate onto the evaluation stack.
  void compile() const override {
    const int immediate = num;
    std::cout << "  pushl $" << immediate << std::endl;
  }

  // The literal as an i32 constant.
  Value* LLVM_IR_generate() const override { return c32(num); }

  // Prints as IntConst(<value>).
  void printAST(std::ostream &out) const override {
    out << "IntConst(" << num;
    out << ")";
  }

 private:
  int num;
};

class Id : public Expr {
 public:
  Id(char v) : var(v) {}
  char name() const { return var; }

  void sem_analyze() override {
    SymTabEntry *e = st.lookup(var);
    type = e->type;
    offset = e->offset;
  }

  int evaluate() const override {
    return rt_stack[offset];
  }

  void compile() const override {
    std::cout << "  pushl " << 4 * (var - 'a') << "(%edi)" << std::endl;
  }

  Value* LLVM_IR_generate() const override {
    char name[] = { var, '_', 'p', 't', 'r', '\0' };
    Value *v = Builder.CreateGEP(vars_type, TheVars,
                                 {c32(0), c32(var - 'a')}, name);
    name[1] = '\0';
    return Builder.CreateLoad(i32, v, name);
  }

  void printAST(std::ostream &out) const override {
    out << "Id(" << var << ")";
  }

 private:
  char var;
  int offset;
};

// Sequence of declarations followed by statements, forming one scope.
class Block : public Stmt {
 public:
  Block() : decl_list(), stmt_list() {}

  void append_decl(std::unique_ptr<Decl> d) {
    decl_list.push_back(std::move(d));
  }

  void append_stmt(std::unique_ptr<Stmt> s) {
    stmt_list.push_back(std::move(s));
  }

  // Absorbs the contents of `b`: its declarations and statements are
  // appended, in order, after the ones this block already holds. (The
  // previous version overwrote this block's statements and dropped `b`'s
  // declarations.) A null `b` is ignored.
  void merge(std::unique_ptr<Block> b) {
    if (!b) return;
    for (auto &d : b->decl_list) decl_list.push_back(std::move(d));
    for (auto &s : b->stmt_list) stmt_list.push_back(std::move(s));
  }

  // Opens a scope, registers the declarations, checks the statements, and
  // closes the scope again.
  void sem_analyze() override {
    st.enter_scope();
    for (const auto &d : decl_list) d->sem_analyze();
    for (const auto &s : stmt_list) s->sem_analyze();
    st.exit_scope();
  }

  // Interpreter: the declared variables' slots exist only while the
  // statements run.
  void execute() const override {
    for (const auto &decl : decl_list) decl->allocate();
    for (const auto &stmt : stmt_list) stmt->execute();
    for (const auto &decl : decl_list) decl->deallocate();
  }

  // The compiled back ends address variables in a fixed 26-cell area, so
  // declarations emit no code; only the statements are lowered.
  void compile() const override {
    for (const auto &stmt : stmt_list) stmt->compile();
  }

  Value* LLVM_IR_generate() const override {
    for (const auto &stmt : stmt_list) stmt->LLVM_IR_generate();
    return nullptr;
  }

  // Prints as Block(<decls...>, <stmts...>).
  void printAST(std::ostream &out) const override {
    out << "Block(";
    bool first = true;
    for (const auto &d : decl_list) {
      if (!first) out << ", ";
      first = false;
      out << *d;
    }
    for (const auto &s : stmt_list) {
      if (!first) out << ", ";
      first = false;
      out << *s;
    }
    out << ")";
  }

 private:
  std::vector<std::unique_ptr<Decl>> decl_list;
  std::vector<std::unique_ptr<Stmt>> stmt_list;
};

// Conditional `if expr then stmt1 [else stmt2]`.
class If : public Stmt {
 public:
  If(std::unique_ptr<Expr> e, std::unique_ptr<Stmt> s1,
     std::unique_ptr<Stmt> s2 = {})
     : expr(std::move(e)), stmt1(std::move(s1)), stmt2(std::move(s2)) {}

  // The condition must be bool; both branches are analyzed (else is optional).
  void sem_analyze() override {
    expr->check_type(TYPE_bool);
    stmt1->sem_analyze();
    if (stmt2) stmt2->sem_analyze();
  }

  // Interpreter: any non-zero condition value selects the then-branch.
  void execute() const override {
    const bool taken = expr->evaluate() != 0;
    if (taken) {
      stmt1->execute();
    } else if (stmt2) {
      stmt2->execute();
    }
  }

  // Assembly: pop the condition, jump to Lelse<N> when it is zero, and jump
  // over the else part to Lendif<M> at the end of the then part. Label
  // numbers come from one counter shared by every If.
  void compile() const override {
    static int label_seq = 0;
    expr->compile();
    const int else_id = label_seq++;
    std::cout << "  popl %eax"            << std::endl
              << "  andl %eax, %eax"      << std::endl
              << "  jz Lelse" << else_id  << std::endl;
    stmt1->compile();
    const int endif_id = label_seq++;
    std::cout << "  jmp Lendif" << endif_id << std::endl
              << "Lelse" << else_id << ":"  << std::endl;
    if (stmt2) stmt2->compile();
    std::cout << "Lendif" << endif_id << ":" << std::endl;
  }

  // IR lowering:
  //   <pre>:  %v = <expr>
  //           %cond = icmp ne i32 %v, 0
  //           br i1 %cond, label %then, label %else
  //   then:   <stmt1>
  //           br label %endif
  //   else:   <stmt2, if present>
  //           br label %endif
  //   endif:  (code generation continues here)
  Value* LLVM_IR_generate() const override {
    Value *v = expr->LLVM_IR_generate();
    Value *cond = Builder.CreateICmpNE(v, c32(0), "cond");

    Function *fn = Builder.GetInsertBlock()->getParent();
    BasicBlock *then_bb = BasicBlock::Create(TheContext, "then", fn);
    BasicBlock *else_bb = BasicBlock::Create(TheContext, "else", fn);
    BasicBlock *endif_bb = BasicBlock::Create(TheContext, "endif", fn);
    Builder.CreateCondBr(cond, then_bb, else_bb);

    // A branch body may open blocks of its own, so each closing branch is
    // emitted wherever the builder ended up.
    Builder.SetInsertPoint(then_bb);
    stmt1->LLVM_IR_generate();
    Builder.CreateBr(endif_bb);

    Builder.SetInsertPoint(else_bb);
    if (stmt2) stmt2->LLVM_IR_generate();
    Builder.CreateBr(endif_bb);

    Builder.SetInsertPoint(endif_bb);
    return nullptr;
  }

  // Prints as If(<cond>, <then>[, <else>]).
  void printAST(std::ostream &out) const override {
    out << "If(" << *expr << ", " << *stmt1;
    if (stmt2 != nullptr) out << ", " << *stmt2;
    out << ")";
  }

 private:
  std::unique_ptr<Expr> expr;
  std::unique_ptr<Stmt> stmt1;
  std::unique_ptr<Stmt> stmt2;  // null when there is no else-branch
};

// Counted loop `for expr do stmt`: runs stmt expr times (zero times when
// expr <= 0). The count is evaluated once, before the loop starts.
class For : public Stmt {
 public:
  For(std::unique_ptr<Expr> e, std::unique_ptr<Stmt> s)
    : expr(std::move(e)), stmt(std::move(s)) {}

  void sem_analyze() override {
    expr->check_type(TYPE_int);
    stmt->sem_analyze();
  }

  void execute() const override {
    const int times = expr->evaluate();
    for (int left = times; left > 0; --left)
      stmt->execute();
  }

  // Assembly: the remaining count lives on the stack. Each iteration pops
  // it, exits when it is <= 0, and pushes back count-1 before the body.
  void compile() const override {
    static int label_seq = 0;
    expr->compile();
    const int begin_id = label_seq++;
    const int end_id = label_seq++;
    std::cout << "Lbegfor" << begin_id << ":" << std::endl
              << "  popl %eax"              << std::endl
              << "  cmp $0, %eax"           << std::endl
              << "  jle Lendfor" << end_id  << std::endl
              << "  decl %eax"              << std::endl
              << "  pushl %eax"             << std::endl;
    stmt->compile();
    std::cout << "  jmp Lbegfor" << begin_id << std::endl
              << "Lendfor" << end_id << ":"  << std::endl;
  }

  // IR lowering:
  //   <pre>:  %n = <expr>
  //           br label %loop
  //   loop:   %iter = phi i32 [%n, %<pre>], [%remaining, %<latch>]
  //           %loop_cond = icmp sgt i32 %iter, 0
  //           br i1 %loop_cond, label %body, label %endfor
  //   body:   %remaining = sub i32 %iter, 1
  //           <stmt>
  //           br label %loop        ; <latch> = block where <stmt> ended
  //   endfor: (code generation continues here)
  Value* LLVM_IR_generate() const override {
    Value *trip_count = expr->LLVM_IR_generate();
    BasicBlock *pre_bb = Builder.GetInsertBlock();
    Function *fn = pre_bb->getParent();
    BasicBlock *loop_bb = BasicBlock::Create(TheContext, "loop", fn);
    BasicBlock *body_bb = BasicBlock::Create(TheContext, "body", fn);
    BasicBlock *exit_bb = BasicBlock::Create(TheContext, "endfor", fn);
    Builder.CreateBr(loop_bb);

    // Header: merge the initial count with the decremented one.
    Builder.SetInsertPoint(loop_bb);
    PHINode *counter = Builder.CreatePHI(i32, 2, "iter");
    counter->addIncoming(trip_count, pre_bb);
    Value *keep_going = Builder.CreateICmpSGT(counter, c32(0), "loop_cond");
    Builder.CreateCondBr(keep_going, body_bb, exit_bb);

    // Body: decrement, then run the statement. The statement may open new
    // blocks, so the back edge leaves from the builder's current block.
    Builder.SetInsertPoint(body_bb);
    Value *next = Builder.CreateSub(counter, c32(1), "remaining");
    stmt->LLVM_IR_generate();
    counter->addIncoming(next, Builder.GetInsertBlock());
    Builder.CreateBr(loop_bb);

    Builder.SetInsertPoint(exit_bb);
    return nullptr;
  }

  // Prints as For(<count>, <body>).
  void printAST(std::ostream &out) const override {
    out << "For(" << *expr << ", " << *stmt;
    out << ")";
  }

 private:
  std::unique_ptr<Expr> expr;
  std::unique_ptr<Stmt> stmt;
};

class Let : public Stmt {
 public:
  Let(std::unique_ptr<Id> lhs, std::unique_ptr<Expr> rhs)
    : var(std::move(lhs)), expr(std::move(rhs)) {}

  void sem_analyze() override {
    SymTabEntry *e = st.lookup(var->name());
    expr->check_type(e->type);
    offset = e->offset;
  }

  void execute() const override {
    rt_stack[offset] = expr->evaluate();
  }

  void compile() const override {
    expr->compile();
    std::cout << "  popl %eax" << std::endl
              << "  movl %eax, " << 4 * (var->name() - 'a') << "(%edi)" << std::endl;
  }

  Value* LLVM_IR_generate() const override {
    char name[] = { var->name(), '_', 'p', 't', 'r', '\0' };
    Value *lhs = Builder.CreateGEP(vars_type, TheVars,
                                   {c32(0), c32(var->name() - 'a')}, name);
    Value* rhs = expr->LLVM_IR_generate();
    Builder.CreateStore(rhs, lhs);
    return nullptr;
  }

  void printAST(std::ostream &out) const override {
    out << "Let(" << *var << ", " << *expr << ")";
  }

 private:
  std::unique_ptr<Id> var;
  std::unique_ptr<Expr> expr;
  int offset;
};

// `print expr`: writes an int followed by a newline.
class Print : public Stmt {
 public:
  Print(std::unique_ptr<Expr> e) : expr(std::move(e)) {}

  // Only ints can be printed.
  void sem_analyze() override { expr->check_type(TYPE_int); }

  void execute() const override {
    const int value = expr->evaluate();
    std::cout << value << std::endl;
  }

  // Assembly: save %edi (the variable base register), push the value, call
  // _writeInteger, then call _writeString with NL, and restore %edi. The
  // extra 8 bytes around each call pad the outgoing arguments.
  void compile() const override {
    std::cout << "  pushl %edi" << std::endl;
    expr->compile();
    static const char *const kCallSequence[] = {
      "  subl $8, %esp",
      "  call _writeInteger",
      "  addl $12, %esp",
      "  movl $NL, %eax",
      "  pushl %eax",
      "  subl $8, %esp",
      "  call _writeString",
      "  addl $12, %esp",
      "  popl %edi",
    };
    for (const char *insn : kCallSequence) std::cout << insn << std::endl;
  }

  // IR: sign-extend the i32 to i64 for writeInteger, then pass a pointer to
  // the first byte of the newline constant to writeString.
  Value* LLVM_IR_generate() const override {
    Value *value = expr->LLVM_IR_generate();
    Value *wide = Builder.CreateSExt(value, i64, "ext");
    Builder.CreateCall(TheWriteInteger, {wide});
    Value *newline = Builder.CreateGEP(nl_type, TheNL, {c32(0), c32(0)}, "nl");
    Builder.CreateCall(TheWriteString, {newline});
    return nullptr;
  }

  // Prints as Print(<expr>).
  void printAST(std::ostream &out) const override {
    out << "Print(" << *expr;
    out << ")";
  }

 private:
  std::unique_ptr<Expr> expr;
};

// Writes the assembly program header to stdout: the text section, the
// global `_start` entry point, and the load of the variable area's address
// (`var`) into %edi.
inline void prologue() {
  static const char *const kLines[] = {
    ".text",
    ".global _start",
    "",
    "_start:",
    "  movl $var, %edi",
    "",
  };
  for (const char *line : kLines) std::cout << line << std::endl;
}

// Writes the assembly program trailer to stdout. It issues the exit(0)
// system call through int $0x80 (eax=1, ebx=0; i386 Linux convention), then
// defines the data section: the 26-long zeroed `var` area and the `NL`
// newline string.
inline void epilogue() {
  static const char *const kLines[] = {
    "",
    "  movl $1, %eax",
    "  movl $0, %ebx",
    "  int $0x80",
    "",
    ".data",
    "var:",
    ".rept 26",
    ".long 0",
    ".endr",
    "NL:",
    ".asciz \"\\n\"",
  };
  for (const char *line : kLines) std::cout << line << std::endl;
}

#endif
