#ifndef __AST_HPP__
#define __AST_HPP__

#include <llvm/Pass.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Utils.h>

#include <iostream>
#include <map>
#include <memory>
#include <vector>

using namespace llvm;

extern std::map<char, int> global_vars;

/// Abstract base class of every node in the minibasic syntax tree.
///
/// Besides the pure-virtual pretty-printer and the per-node IR emitter, this
/// class holds the LLVM code-generation state shared by all nodes: context,
/// IR builder, module, function pass manager, handles to the emitted globals
/// and runtime functions, and cached integer types. All of it is static, so
/// only one module can be generated at a time.
class AST {
 public:
  // Nodes are used polymorphically through AST*/Stmt*/Expr*; a virtual
  // destructor makes destroying a node through a base pointer well-defined.
  virtual ~AST() = default;

  /// Writes a textual representation of this node to `out`.
  virtual void printAST(std::ostream &out) const = 0;

  /// Emits LLVM IR for this node at the builder's current insertion point.
  /// Expression nodes return the computed value; statement nodes return
  /// nullptr. The default implementation emits nothing and returns nullptr.
  virtual Value* igen() const { return nullptr; }

  /// Compiles this node as a whole program and prints the module's IR to
  /// stdout.
  ///
  /// Builds a fresh module containing:
  ///   - `vars`: private, zero-initialized [26 x i32] (one slot per variable),
  ///   - `nl`:   private constant [2 x i8] holding "\n\0",
  ///   - external declarations `void writeInteger(i64)` and
  ///     `void writeString(i8*)`,
  ///   - `i32 main()`, whose body is this node's igen() followed by `ret 0`.
  /// The module is verified before optimization; if verification fails, an
  /// error and the offending IR are written to stderr and the process exits
  /// with status 1.
  ///
  /// @param optimize when true, runs mem2reg, instcombine, reassociate, GVN
  ///                 and CFG simplification on `main` before printing.
  void LLVM_IR_gen(bool optimize = true) {
    // Initialize
    TheModule = std::make_unique<Module>("minibasic program", TheContext);
    TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
    if (optimize) {
      TheFPM->add(createPromoteMemoryToRegisterPass());
      TheFPM->add(createInstructionCombiningPass());
      TheFPM->add(createReassociatePass());
      TheFPM->add(createGVNPass());
      TheFPM->add(createCFGSimplificationPass());
    }
    TheFPM->doInitialization();

    // Initialize types
    i8  = IntegerType::get(TheContext, 8);
    i32 = IntegerType::get(TheContext, 32);
    i64 = IntegerType::get(TheContext, 64);

    vars_type = ArrayType::get(i32, 26);
    nl_type = ArrayType::get(i8, 2);

    // Initialize global variables
    TheVars = new GlobalVariable(
      *TheModule, vars_type, false, GlobalValue::PrivateLinkage,
      ConstantAggregateZero::get(vars_type), "vars");
    TheVars->setAlignment(MaybeAlign(16));
    TheNL = new GlobalVariable(
      *TheModule, nl_type, true, GlobalValue::PrivateLinkage,
      ConstantArray::get(nl_type, {c8('\n'), c8('\0')}), "nl");
    TheNL->setAlignment(MaybeAlign(1));

    // Initialize library functions
    FunctionType *writeInteger_type =
      FunctionType::get(Type::getVoidTy(TheContext), {i64}, false);
    TheWriteInteger =
      Function::Create(writeInteger_type, Function::ExternalLinkage,
                       "writeInteger", TheModule.get());
    FunctionType *writeString_type =
      FunctionType::get(Type::getVoidTy(TheContext),
                        {PointerType::get(i8, 0)}, false);
    TheWriteString =
      Function::Create(writeString_type, Function::ExternalLinkage,
                       "writeString", TheModule.get());

    // Define and start the main function.
    FunctionType *main_type = FunctionType::get(i32, {}, false);
    Function *main =
      Function::Create(main_type, Function::ExternalLinkage,
                       "main", TheModule.get());
    BasicBlock *BB = BasicBlock::Create(TheContext, "entry", main);
    Builder.SetInsertPoint(BB);

    // Emit the program code. Control-flow nodes always leave the builder in
    // a fresh, unterminated block, so the return can be appended directly.
    igen();
    Builder.CreateRet(c32(0));

    // Verify the IR.
    bool bad = verifyModule(*TheModule, &errs());
    if (bad) {
      std::cerr << "The IR is bad!" << std::endl;
      TheModule->print(errs(), nullptr);
      std::exit(1);
    }

    // Optimize! The legacy pass manager protocol is
    // doInitialization -> run -> doFinalization.
    TheFPM->run(*main);
    TheFPM->doFinalization();

    // Print out the IR.
    TheModule->print(outs(), nullptr);
  }

 protected:
  static LLVMContext TheContext;
  static IRBuilder<> Builder;
  static std::unique_ptr<Module> TheModule;
  static std::unique_ptr<legacy::FunctionPassManager> TheFPM;

  static GlobalVariable *TheVars;       // [26 x i32] variable storage
  static GlobalVariable *TheNL;         // "\n\0" constant
  static Function *TheWriteInteger;     // void writeInteger(i64)
  static Function *TheWriteString;      // void writeString(i8*)

  static Type *i8;
  static Type *i32;
  static Type *i64;

  static ArrayType *vars_type;
  static ArrayType *nl_type;

  /// Returns an i8 constant holding `c`.
  static ConstantInt* c8(char c) {
    return ConstantInt::get(TheContext, APInt(8, c, true));
  }
  /// Returns an i32 constant holding `n`.
  static ConstantInt* c32(int n) {
    return ConstantInt::get(TheContext, APInt(32, n, true));
  }
};

inline std::ostream &operator<<(std::ostream &out, const AST &ast) {
  ast.printAST(out);
  return out;
}


/// Common base for statement nodes (adds nothing to AST).
class Stmt : public AST {
};

/// Common base for expression nodes (adds nothing to AST).
class Expr : public AST {
};

class Block : public Stmt {
 public:
  Block() : stmt_list() {}
  void append(Stmt *s) { stmt_list.push_back(s); }
  void printAST(std::ostream &out) const override {
    out << "Block(";
    bool first = true;
    for (const auto &s : stmt_list) {
      if (!first) out << ", ";
      first = false;
      out << *s;
    }
    out << ")";
  }
  Value* igen() const override {
    for (Stmt *s : stmt_list)
      s->igen();
    return nullptr;
  }
 private:
  std::vector<Stmt *> stmt_list;
};

/// Binary arithmetic on i32 values: `expr1 op expr2`.
///
/// Supported operators are '+', '-', '*', '/' (signed division) and
/// '%' (signed remainder).
class BinOp : public Expr {
 public:
  BinOp(Expr *e1, char o, Expr *e2) : expr1(e1), op(o), expr2(e2) {}

  /// Prints `op(expr1, expr2)`.
  void printAST(std::ostream &out) const override {
    out << op << "(" << *expr1 << ", " << *expr2 << ")";
  }

  /// Emits the left operand, then the right operand, then the operation.
  Value* igen() const override {
    Value* l = expr1->igen();
    Value* r = expr2->igen();
    switch (op) {
      case '+': return Builder.CreateAdd(l, r, "addtmp");
      case '-': return Builder.CreateSub(l, r, "subtmp");
      case '*': return Builder.CreateMul(l, r, "multmp");
      case '/': return Builder.CreateSDiv(l, r, "divtmp");
      case '%': return Builder.CreateSRem(l, r, "modtmp");
    }
    // Returning nullptr here would be handed unchecked to IRBuilder by the
    // consumers of this value (stores, sign extensions, enclosing BinOps),
    // crashing without a diagnostic. Report the problem and stop instead,
    // the same way LLVM_IR_gen handles invalid IR.
    std::cerr << "Unknown binary operator '" << op << "'" << std::endl;
    std::exit(1);
  }

 private:
  Expr *expr1;
  char op;
  Expr *expr2;
};

/// Integer literal; compiles to an i32 constant.
class Const : public Expr {
  int num;  // the literal's value

 public:
  Const(int n) : num(n) {}

  /// Prints `Const(<value>)`.
  void printAST(std::ostream &out) const override {
    out << "Const(";
    out << num;
    out << ")";
  }

  /// No instructions are emitted; the constant itself is the value.
  Value* igen() const override { return c32(num); }
};

/// Reference to a single-letter variable.
///
/// The variable's value lives at index `var - 'a'` of the global `vars`
/// array (26 i32 slots). NOTE(review): assumes `var` is in 'a'..'z'; there
/// is no range check here.
class Id : public Expr {
 public:
  Id(char x) : var(x) {}

  /// Prints `Id(<letter>)`.
  void printAST(std::ostream &out) const override {
    out << "Id(" << var << ")";
  }

  /// Returns the variable's letter. Marked const so it can be called on
  /// const nodes as well.
  char getVar() const { return var; }

  /// Loads the variable's current i32 value.
  /// The address is named "<x>_ptr" and the loaded value "<x>".
  Value* igen() const override {
    char name[] = { var, '_', 'p', 't', 'r', '\0' };
    Value *v = Builder.CreateGEP(vars_type, TheVars,
                                 {c32(0), c32(var - 'a')}, name);
    name[1] = '\0';  // truncate "x_ptr" to "x" for the load
    return Builder.CreateLoad(i32, v, name);
  }

 private:
  char var;
};

/// Assignment statement: stores the i32 value of `expr` into the variable
/// named by `var` (its slot in the global `vars` array).
class Let : public Stmt {
 public:
  Let(Id *lhs, Expr *rhs) : var(lhs), expr(rhs) {}

  /// Prints `Let(<id>, <expr>)`.
  void printAST(std::ostream &out) const override {
    out << "Let(" << *var << ", " << *expr << ")";
  }

  /// Emits the slot address ("<x>_ptr") first, then the right-hand side,
  /// then the store. Produces no value.
  Value* igen() const override {
    const char letter = var->getVar();
    const char ptr_name[] = { letter, '_', 'p', 't', 'r', '\0' };
    Value *slot = Builder.CreateGEP(vars_type, TheVars,
                                    {c32(0), c32(letter - 'a')}, ptr_name);
    Builder.CreateStore(expr->igen(), slot);
    return nullptr;
  }

 private:
  Id   *var;
  Expr *expr;
};

/// Counted loop `for expr do stmt`: evaluates `expr` once and runs `stmt`
/// that many times (zero times if the count is not positive).
class For : public Stmt {
 public:
  For(Expr *e, Stmt *s) : expr(e), stmt(s) {}

  /// Prints `For(<expr>, <stmt>)`.
  void printAST(std::ostream &out) const override {
    out << "For(" << *expr << ", " << *stmt << ")";
  }

  /// Emitted IR shape:
  ///
  ///   <current>:
  ///     %n = <expr>
  ///     br label %loop
  ///   loop:
  ///     %iter = phi i32 [ %n, %<current> ], [ %remaining, %<body end> ]
  ///     %loop_cond = icmp sgt i32 %iter, 0
  ///     br i1 %loop_cond, label %body, label %endfor
  ///   body:
  ///     %remaining = sub i32 %iter, 1
  ///     <stmt>
  ///     br label %loop
  ///   endfor:
  ///     (builder is left here)
  Value* igen() const override {
    Value *count = expr->igen();
    BasicBlock *entry_bb = Builder.GetInsertBlock();
    Function *fn = entry_bb->getParent();
    BasicBlock *header_bb = BasicBlock::Create(TheContext, "loop", fn);
    BasicBlock *body_bb = BasicBlock::Create(TheContext, "body", fn);
    BasicBlock *exit_bb = BasicBlock::Create(TheContext, "endfor", fn);

    Builder.CreateBr(header_bb);

    // Header: the counter enters with `count` and is tested against zero.
    Builder.SetInsertPoint(header_bb);
    PHINode *counter = Builder.CreatePHI(i32, 2, "iter");
    counter->addIncoming(count, entry_bb);
    Value *keep_going = Builder.CreateICmpSGT(counter, c32(0), "loop_cond");
    Builder.CreateCondBr(keep_going, body_bb, exit_bb);

    // Body: decrement, then the loop statement.
    Builder.SetInsertPoint(body_bb);
    Value *next_count = Builder.CreateSub(counter, c32(1), "remaining");
    stmt->igen();
    // The statement may have opened new blocks (nested for/if), so the back
    // edge originates from whichever block is current after emitting it.
    counter->addIncoming(next_count, Builder.GetInsertBlock());
    Builder.CreateBr(header_bb);

    Builder.SetInsertPoint(exit_bb);
    return nullptr;
  }

 private:
  Expr *expr;
  Stmt *stmt;
};

/// Conditional `if cond then stmt1 [else stmt2]`; a non-zero condition
/// value selects `stmt1`.
class If : public Stmt {
 public:
  If(Expr *c, Stmt *s1, Stmt *s2 = nullptr) : cond(c), stmt1(s1), stmt2(s2) {}

  /// Prints `If(<cond>, <stmt1>)` or `If(<cond>, <stmt1>, <stmt2>)`.
  void printAST(std::ostream &out) const override {
    out << "If(" << *cond << ", " << *stmt1;
    if (stmt2 != nullptr) out << ", " << *stmt2;
    out << ")";
  }

  /// Emitted IR shape (the else block is always created; it is left empty
  /// apart from its branch when there is no else-statement):
  ///
  ///   <current>:
  ///     %v = <cond>
  ///     %cond = icmp ne i32 %v, 0
  ///     br i1 %cond, label %then, label %else
  ///   then:
  ///     <stmt1>
  ///     br label %endif
  ///   else:
  ///     <stmt2, if any>
  ///     br label %endif
  ///   endif:
  ///     (builder is left here)
  Value* igen() const override {
    Value *v = cond->igen();
    // Local deliberately not named `cond`: that would shadow the member
    // expression of the same name. The IR value keeps the name "cond".
    Value *is_true = Builder.CreateICmpNE(v, c32(0), "cond");
    Function *TheFunction = Builder.GetInsertBlock()->getParent();
    BasicBlock *ThenBB =
      BasicBlock::Create(TheContext, "then", TheFunction);
    BasicBlock *ElseBB =
      BasicBlock::Create(TheContext, "else", TheFunction);
    BasicBlock *EndIfBB =
      BasicBlock::Create(TheContext, "endif", TheFunction);
    Builder.CreateCondBr(is_true, ThenBB, ElseBB);
    Builder.SetInsertPoint(ThenBB);
    stmt1->igen();
    Builder.CreateBr(EndIfBB);
    Builder.SetInsertPoint(ElseBB);
    if (stmt2 != nullptr) stmt2->igen();
    Builder.CreateBr(EndIfBB);
    Builder.SetInsertPoint(EndIfBB);
    return nullptr;
  }

 private:
  Expr *cond;
  Stmt *stmt1;
  Stmt *stmt2;  // nullptr when there is no else branch
};

/// Output statement: prints the value of `expr` followed by a newline,
/// using the runtime functions writeInteger and writeString.
class Print : public Stmt {
 public:
  Print(Expr *e) : expr(e) {}

  /// Prints `Print(<expr>)`.
  void printAST(std::ostream &out) const override {
    out << "Print(" << *expr << ")";
  }

  /// Emits `writeInteger(sext(expr))` then `writeString(&nl[0])`.
  Value* igen() const override {
    // writeInteger takes an i64, so sign-extend the i32 value first.
    Value *wide = Builder.CreateSExt(expr->igen(), i64, "ext");
    Builder.CreateCall(TheWriteInteger, {wide});
    // Address of the first character of the "\n" global.
    Value *newline = Builder.CreateGEP(nl_type, TheNL, {c32(0), c32(0)}, "nl");
    Builder.CreateCall(TheWriteString, {newline});
    return nullptr;
  }

 private:
  Expr *expr;
};

#endif
