/* Bison grammar and command-line driver for `mbc` (C++ parser skeleton).
   Requires Bison >= 3.8; the parser is written to parser.cpp / parser.hpp
   and the location class to location.hpp. */
%require "3.8"
%language "C++"
/* Track source positions; yy::parser::error reports l.begin.line. */
%locations
%output "parser.cpp"
%defines "parser.hpp"
/* Also write a human-readable automaton report (useful when checking the
   conflict counted by %expect). */
%verbose
%define api.location.file "location.hpp"

%{
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>

extern FILE *yyin;
%}

/* Emitted into the generated header ahead of the semantic-value type, so the
   AST node types named in the %type declarations are visible there. */
%code requires {
#include "ast.hpp"
}

%{
#include "lexer.hpp"

// Declares the scanner entry point. Presumably lexer.hpp defines YY_DECL as
// the lexer function's signature -- confirm against lexer.hpp.
YY_DECL;

// Root of the most recently parsed program; assigned by the `program` rule
// and consumed by process().
static std::unique_ptr<Block> ast;
// Global state not used in this file. Presumably declared extern elsewhere
// (ast.hpp / lexer.hpp) and used by the AST passes -- confirm.
SymbolTable st;
// NOTE(review): looks like a runtime stack for the interpreter -- confirm.
std::vector<int> rt_stack;
%}

/* Store semantic values in a type-safe variant (required for the
   std::unique_ptr payloads below) and generate typed make_* symbol
   constructors for the scanner. */
%define api.value.type variant
%define api.token.constructor
/* Every $n is implicitly std::move'd, so actions take ownership of
   unique_ptr values without writing std::move. */
%define api.value.automove
/* Runtime checks that variant accesses use the stored type. */
%define parse.assert
/* Syntax errors list the unexpected and the expected tokens. */
%define parse.error verbose
/* Token kind enumerators get a T_ prefix (T_if, T_int, ...) so keyword
   tokens do not collide with C++ keywords. */
%define api.token.prefix {T_}

/* Keyword tokens; the string aliases let the rules spell them as "begin",
   "do", "if", ... */
%token begin  "begin"
%token do     "do"
%token else   "else"
%token end    "end"
%token for    "for"
%token if     "if"
%token let    "let"
%token print  "print"
%token then   "then"
%token var    "var"
%token int    "int"
%token bool   "bool"
%token true   "true"
%token false  "false"

/* Integer literal; carries its int value. */
%token <int> const
/* Identifier; its semantic value is a single char (presumably a one-letter
   variable name -- confirm against the lexer). */
%token <char> id

/* Operator precedence, lowest first. Comparisons are non-associative, so a
   chain such as a < b < c is a syntax error. The arithmetic operators are
   left-associative, with * / % binding tighter than + -. */
%nonassoc '=' '<' '>'
%left '+' '-'
%left '*' '/' '%'

/* Semantic value type of each nonterminal: an owned AST node. */
%type <std::unique_ptr<Block>>  block decl_list stmt_list program
%type <std::unique_ptr<Decl>>   decl
%type <std::unique_ptr<Stmt>>   stmt
%type <std::unique_ptr<Expr>>   expr
%type <std::unique_ptr<Type>>   type

/* Exactly one shift/reduce conflict is expected: the dangling `else` between
   the two `if` alternatives of `stmt`. Bison resolves it by shifting, so an
   else binds to the nearest unmatched if. */
%expect 1
%%

/* Start symbol: the whole input is one block. Its AST is stored in the
   file-level `ast` for the driver (process) to use. */
program :
  block               { ast = $1; }
;

/* A block is all of its declarations followed by all of its statements.
   The statement Block is merged into the Block built by decl_list. */
block :
  decl_list stmt_list { $$ = $1; $$->merge($2); }
;

/* Zero or more declarations, accumulated left-recursively into a fresh Block. */
decl_list :
  /* nothing */       { $$ = std::make_unique<Block>(); }
| decl_list decl      { $$ = $1; $$->append_decl($2); }
;

/* Variable declaration: var <id> : <type>. */
decl :
  "var" id ':' type   { $$ = std::make_unique<Decl>($2, $4); }
;

/* The two built-in types, mapped to the TYPE_int / TYPE_bool tags. */
type :
  "int"               { $$ = std::make_unique<Type>(TYPE_int); }
| "bool"              { $$ = std::make_unique<Type>(TYPE_bool); }
;

/* Zero or more statements, accumulated left-recursively into a fresh Block. */
stmt_list :
  /* nothing */       { $$ = std::make_unique<Block>(); }
| stmt_list stmt      { $$ = $1; $$->append_stmt($2); }
;

/* Statements.
   - begin ... end: a nested Block used as a statement. The assignment to a
     unique_ptr<Stmt> requires Block to be convertible to Stmt.
   - The two `if` alternatives form the dangling-else conflict counted by
     %expect 1; the default shift pairs `else` with the nearest `if`.
   - let <id> = <expr>: the identifier is wrapped in an Id node
     (presumably an assignment -- confirm against Let in ast.hpp). */
stmt :
  "begin" block "end"               { $$ = $2; }
| "for" expr "do" stmt              { $$ = std::make_unique<For>($2, $4); }
| "if" expr "then" stmt             { $$ = std::make_unique<If>($2, $4); }
| "if" expr "then" stmt "else" stmt { $$ = std::make_unique<If>($2, $4, $6); }
| "let" id '=' expr                 { $$ = std::make_unique<Let>(std::make_unique<Id>($2), $4); }
| "print" expr                      { $$ = std::make_unique<Print>($2); }
;

/* Expressions.
   - Every binary operator builds a BinOp node tagged with the operator
     character.
   - Precedence and associativity come from the %nonassoc / %left
     declarations.
   - The literals true / false become BoolConst nodes. */
expr :
  const               { $$ = std::make_unique<IntConst>($1); }
| id                  { $$ = std::make_unique<Id>($1); }
| '(' expr ')'        { $$ = $2; }
| expr '+' expr       { $$ = std::make_unique<BinOp>($1, '+', $3); }
| expr '-' expr       { $$ = std::make_unique<BinOp>($1, '-', $3); }
| expr '*' expr       { $$ = std::make_unique<BinOp>($1, '*', $3); }
| expr '/' expr       { $$ = std::make_unique<BinOp>($1, '/', $3); }
| expr '%' expr       { $$ = std::make_unique<BinOp>($1, '%', $3); }
| expr '=' expr       { $$ = std::make_unique<BinOp>($1, '=', $3); }
| expr '<' expr       { $$ = std::make_unique<BinOp>($1, '<', $3); }
| expr '>' expr       { $$ = std::make_unique<BinOp>($1, '>', $3); }
| "true"              { $$ = std::make_unique<BoolConst>(true); }
| "false"             { $$ = std::make_unique<BoolConst>(false); }
;

%%

void yy::parser::error(const location_type& l, const std::string& m) {
  std::cerr << "\033[91mError: " << m << " at line "
            << l.begin.line << "\033[0m" << std::endl;
  std::exit(1);
}

/// Print the command-line help text to stdout, one line at a time.
void usage() {
  static const char* const kUsageLines[] = {
      "Usage: mbc [<options>] [<source-file>...]",
      "Options: ",
      "  -h : help (this message)",
      "  -p : print the AST",
      "  -x : execute",
      "  -c : compile (default)",
  };
  for (const char* line : kUsageLines) {
    std::cout << line << std::endl;
  }
}

// What the driver does with each successfully parsed program (set by -p/-x/-c).
enum Mode {
  PRINT,    // -p: print the AST
  EVAL,     // -x: semantic analysis, then execute
  COMPILE,  // -c: semantic analysis, then compile (default)
};

/// Parse one source and act on the resulting AST according to `mode`.
///
/// @param mode      PRINT writes the AST to stdout.
///                  EVAL calls sem_analyze() then execute().
///                  COMPILE calls sem_analyze(), then prologue(), compile()
///                  and epilogue().
/// @param filename  path of the source file, or nullptr / "-" for stdin.
/// @return 0 on success, 1 if the file cannot be opened, otherwise the
///         nonzero value returned by yy::parser::parse().
int process(Mode mode, const char* filename) {
  // Owns a file opened below and closes it on every exit path, including an
  // exception escaping one of the AST passes. It stays empty for stdin,
  // which must not be closed.
  struct FileCloser {
    void operator()(FILE* f) const { std::fclose(f); }
  };
  std::unique_ptr<FILE, FileCloser> input;

  if (filename == nullptr || std::strcmp(filename, "-") == 0) {
    yyin = stdin;
  } else {
    input.reset(std::fopen(filename, "r"));
    if (!input) {
      std::perror(filename);
      return 1;
    }
    yyin = input.get();
    // Point the flex scanner's current buffer at the newly opened file.
    yyrestart(yyin);
  }

  yy::parser p;
  const int result = p.parse();
  if (result == 0) {
    switch (mode) {
      case PRINT:
        std::cout << "AST: " << *ast << std::endl;
        break;
      case EVAL:
        ast->sem_analyze();
        ast->execute();
        break;
      case COMPILE:
        ast->sem_analyze();
        prologue();
        ast->compile();
        epilogue();
        break;
    }
  }
  return result;
}

/// Command-line driver.
///
/// Arguments are handled left to right:
/// - -h prints help and exits.
/// - -p / -x / -c select the mode for the source files that follow. Only the
///   character after '-' is examined.
/// - Any other argument is a source file to process; a lone "-" names
///   standard input.
/// If no source file is named at all, standard input is processed.
///
/// @return 0 on success; 1 for an unknown option; otherwise the first
///         nonzero status returned by process().
int main(int argc, char **argv) {
  Mode mode = COMPILE;
  int named_files = 0;
  for (int i = 1; i < argc; ++i) {
    const char* const arg = argv[i];
    // "-" on its own is a file name (stdin), not an option.
    if (arg[0] == '-' && arg[1] != '\0') {
      switch (arg[1]) {
        case 'h':
          usage();
          return 0;
        case 'p':
          mode = PRINT;
          continue;
        case 'x':
          mode = EVAL;
          continue;
        case 'c':
          mode = COMPILE;
          continue;
        default:
          // Reject unknown options instead of trying to open them as files.
          std::cerr << "Unknown option: " << arg << std::endl;
          usage();
          return 1;
      }
    }
    ++named_files;
    const int result = process(mode, arg);
    if (result != 0) return result;
  }
  // Propagate the status of the implicit stdin run instead of discarding it.
  if (named_files == 0) return process(mode, nullptr);
  return 0;
}
