#include <stdio.h>
#include <stdlib.h>

#include "ast.h"
#include "parser.h"


/*
 * Allocate a node with two children.
 *
 * code:  node kind, e.g. '+', '*' or a T_* code; interpret() dispatches on it.
 * left:  first child (may be NULL).
 * right: second child (may be NULL).
 *
 * Returns the new node. Prints "Out of memory" to stderr and exits with
 * status 1 if the allocation fails.
 */
AST ast_node(int code, AST left, AST right)
{
  /* calloc rather than malloc so the fields this constructor does not set
     (var, num) are zeroed instead of left indeterminate. */
  AST n = calloc(1, sizeof *n);
  if (n == NULL) {
    fprintf(stderr, "Out of memory\n");
    exit(1);
  }
  n->code = code;
  n->left = left;
  n->right = right;
  return n;
}

/*
 * Allocate an assignment node (code T_let).
 *
 * var:  variable being assigned; interpret() uses it as an index into memory[].
 * expr: expression whose value is stored; kept as the left child.
 *
 * Returns the new node. Prints "Out of memory" to stderr and exits with
 * status 1 if the allocation fails.
 */
AST ast_let(variable var, AST expr)
{
  /* calloc zeroes the fields this constructor does not set (e.g. num). */
  AST n = calloc(1, sizeof *n);
  if (n == NULL) {
    fprintf(stderr, "Out of memory\n");
    exit(1);
  }
  n->code = T_let;
  n->var = var;
  n->left = expr;
  /* A let node has no second child; make that explicit rather than
     leaving the pointer indeterminate for any code that walks the tree. */
  n->right = NULL;
  return n;
}

/*
 * Allocate a constant leaf node (code T_const) holding num.
 *
 * Returns the new node. Prints "Out of memory" to stderr and exits with
 * status 1 if the allocation fails.
 */
AST ast_num(number num)
{
  /* calloc zeroes the fields this constructor does not set (e.g. var). */
  AST n = calloc(1, sizeof *n);
  if (n == NULL) {
    fprintf(stderr, "Out of memory\n");
    exit(1);
  }
  n->code = T_const;
  n->num = num;
  /* Leaves have no children; never leave these pointers indeterminate. */
  n->left = NULL;
  n->right = NULL;
  return n;
}

/*
 * Allocate a variable-reference leaf node (code T_var).
 *
 * var: variable to read; interpret() uses it as an index into memory[].
 *
 * Returns the new node. Prints "Out of memory" to stderr and exits with
 * status 1 if the allocation fails.
 */
AST ast_var(variable var)
{
  /* calloc zeroes the fields this constructor does not set (e.g. num). */
  AST n = calloc(1, sizeof *n);
  if (n == NULL) {
    fprintf(stderr, "Out of memory\n");
    exit(1);
  }
  n->code = T_var;
  n->var = var;
  /* Leaves have no children; never leave these pointers indeterminate. */
  n->left = NULL;
  n->right = NULL;
  return n;
}

/* Value returned by interpret() for statement nodes, which produce no
   value of their own, and for a NULL tree or an unrecognised node code. */
#define NO_RESULT 0

/* Variable storage: interpret() reads and writes memory[ast->var].
   NOTE(review): 26 slots presumably means one per letter a-z, with the
   parser mapping variable names to 0..25 -- confirm against the parser. */
number memory[26];

number interpret (AST ast)
{
  int i, n;

  if (ast != NULL) {
    switch (ast->code) {

    case T_begin:
      interpret(ast->left);
      interpret(ast->right);
      return NO_RESULT;

    case T_print:
      printf("%u\n", interpret(ast->left));
      return NO_RESULT;

    case T_let:
      memory[ast->var] = interpret(ast->left);
      return NO_RESULT;
      
    case T_for:
      n = interpret(ast->left);
      for (i=0; i<n; i++)
        interpret(ast->right);
      return NO_RESULT;

    case T_if:
      if (interpret(ast->left))
        interpret(ast->right);
      return NO_RESULT;
      
    case T_const:
      return ast->num;

    case T_var:
      return memory[ast->var];

    case '+':
      return interpret(ast->left) + interpret(ast->right);

    case '*':
      return interpret(ast->left) * interpret(ast->right);
    }
  }
  return NO_RESULT;
}
