from llvmlite import ir from AST import Node, NodeType, Program, Expression from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement from AST import InfixExpression from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral from environment import Environment class Compiler: def __init__(self) -> None: self.type_map: dict[str, ir.type] = { "Bool": ir.IntType(1), "Byte": ir.IntType(8), "Short": ir.IntType(16), "Int": ir.IntType(32), "Long": ir.IntType(64), "Float": ir.FloatType(), "Double": ir.DoubleType() } self.module: ir.Module = ir.Module("main") self.builder: ir.IRBuilder = ir.IRBuilder() self.environment: Environment = Environment() def compile(self, node: Node) -> None: match node.type(): case NodeType.Program: self.__visit_program(node) # Statements case NodeType.ExpressionStatement: self.__visit_expression_statement(node) case NodeType.AssignmentStatement: self.__visit_assignment_statement(node) case NodeType.FunctionStatement: self.__visit_function_statement(node) case NodeType.BlockStatement: self.__visit_block_statement(node) case NodeType.ReturnStatement: self.__visit_return_statement(node) # Expressions case NodeType.InfixExpression: self.__visit_infix_expression(node) # region Visit Methods def __visit_program(self, node: Program) -> None: for stmt in node.statements: self.compile(stmt) # region Statements def __visit_expression_statement(self, node: ExpressionStatement) -> None: self.compile(node.expr) def __visit_assignment_statement(self, node: AssignmentStatement) -> None: name: str = node.name.value value: Expression = node.value value_type: str = node.value_type # TODO: implemented value, Type = self.__resolve_value(node=value) if self.environment.lookup(name) is None: # Define and allocate the new variable ptr = self.builder.alloca(Type) # Storing the value to the ptr self.builder.store(value, ptr) # Add the variable to the environment self.environment.define(name, ptr, Type) else: ptr, _ = self.environment.lookup(name) self.builder.store(value, ptr) def __visit_block_statement(self, node: BlockStatement) -> None: for stmt in node.statements: self.compile(stmt) def __visit_return_statement(self, node: ReturnStatement) -> None: value: Expression = node.return_value value, Type = self.__resolve_value(node) self.builder.ret(value) def __visit_function_statement(self, node: FunctionStatement) -> None: name: str = node.name.value body: BlockStatement = node.body params: list[IdentifierLiteral] = node.parameters param_types: list[ir.Type] = [] # TODO return_type: ir.Type = self.type_map[node.return_type] fnty: ir.FunctionType = ir.FunctionType(return_type, param_types) func: ir.Function = ir.Function(self.module, fnty, name) block: ir.Block = func.append_basic_block(f"{name}_entry") previous_builder = self.builder self.builder = ir.IRBuilder(block) previous_env = self.environment self.environment = Environment(parent=self.environment) self.environment.define(name, func, return_type) self.compile(body) self.environment = previous_env self.environment.define(name, func, return_type) self.builder = previous_builder # endregion # region Expressions def __visit_infix_expression(self, node: InfixExpression) -> None: operator: str = node.operator left_value, left_type = self.__resolve_value(node.left_node) right_value, right_type = self.__resolve_value(node.right_node) value = None Type = None if isinstance(right_type, ir.IntType) and isinstance(left_type, ir.IntType): Type = self.type_map["Int"] match operator: case "+": value = self.builder.add(left_value, right_value) case "-": value = self.builder.sub(left_value, right_value) case "*": value = self.builder.mul(left_value, right_value) case "/": value = self.builder.sdiv(left_value, right_value) case "%": value = self.builder.srem(left_value, right_value) case "^": # TODO pass elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType): Type = self.type_map["Float"] match operator: case "+": value = self.builder.fadd(left_value, right_value) case "-": value = self.builder.fsub(left_value, right_value) case "*": value = self.builder.fmul(left_value, right_value) case "/": value = self.builder.fdiv(left_value, right_value) case "%": value = self.builder.frem(left_value, right_value) case "^": # TODO pass return value, Type # endregion # endregion # region Helper Methods def __resolve_value(self, node: Expression) -> tuple[ir.Value, ir.Type]: match node.type(): case NodeType.IntegerLiteral: node: IntegerLiteral = node value, Type = node.value, self.type_map['Int'] return ir.Constant(Type, value), Type case NodeType.FloatLiteral: node: FloatLiteral = node value, Type = node.value, self.type_map['Float'] return ir.Constant(Type, value), Type case NodeType.IdentifierLiteral: node: IdentifierLiteral = node ptr, Type = self.environment.lookup(node.value) return self.builder.load(ptr), Type # expression value case NodeType.InfixExpression: return self.__visit_infix_expression(node) # endregion