From 5741a48e73e8cfa37880ba15b7a1775ffc44741f Mon Sep 17 00:00:00 2001 From: SpookyDervish <78246495+SpookyDervish@users.noreply.github.com> Date: Tue, 14 Oct 2025 07:46:54 +1100 Subject: [PATCH] nearly got functions to compile! --- compiler.py | 89 ++++++++++++++++++++++++++++++++++++----------------- main.py | 4 +-- 2 files changed, 63 insertions(+), 30 deletions(-) diff --git a/compiler.py b/compiler.py index 8f20b25..5c74811 100644 --- a/compiler.py +++ b/compiler.py @@ -1,7 +1,7 @@ from llvmlite import ir from AST import Node, NodeType, Program, Expression -from AST import ExpressionStatement, AssignmentStatement, BlockStatement +from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement from AST import InfixExpression from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral @@ -11,13 +11,13 @@ from environment import Environment class Compiler: def __init__(self) -> None: self.type_map: dict[str, ir.type] = { - "bool": ir.IntType(1), - "byte": ir.IntType(8), - "short": ir.IntType(16), - "int": ir.IntType(32), - "long": ir.IntType(64), - "float": ir.FloatType(), - "double": ir.DoubleType() + "Bool": ir.IntType(1), + "Byte": ir.IntType(8), + "Short": ir.IntType(16), + "Int": ir.IntType(32), + "Long": ir.IntType(64), + "Float": ir.FloatType(), + "Double": ir.DoubleType() } self.module: ir.Module = ir.Module("main") @@ -34,6 +34,12 @@ class Compiler: self.__visit_expression_statement(node) case NodeType.AssignmentStatement: self.__visit_assignment_statement(node) + case NodeType.FunctionStatement: + self.__visit_function_statement(node) + case NodeType.BlockStatement: + self.__visit_block_statement(node) + case NodeType.ReturnStatement: + self.__visit_return_statement(node) # Expressions case NodeType.InfixExpression: @@ -41,23 +47,9 @@ class Compiler: # region Visit Methods def __visit_program(self, node: Program) -> None: - func_main: str = "main" - param_types: list[ir.Type] = [] - return_type = ir.Type = self.type_map["int"] - - fnty = ir.FunctionType(return_type, param_types) - func = ir.Function(self.module, fnty, func_main) - - block = func.append_basic_block(f"{func_main}_entry") - - self.builder = ir.IRBuilder(block) - for stmt in node.statements: self.compile(stmt) - return_value: ir.Constant = ir.Constant(self.type_map["int"], 123) - self.builder.ret(return_value) - # region Statements def __visit_expression_statement(self, node: ExpressionStatement) -> None: self.compile(node.expr) @@ -77,11 +69,52 @@ class Compiler: self.builder.store(value, ptr) # Add the variable to the environment - self.environment.define(name, value, Type) + self.environment.define(name, ptr, Type) else: ptr, _ = self.environment.lookup(name) self.builder.store(value, ptr) - # endregion + + def __visit_block_statement(self, node: BlockStatement) -> None: + for stmt in node.statements: + self.compile(stmt) + + def __visit_return_statement(self, node: ReturnStatement) -> None: + value: Expression = node.return_value + value, Type = self.__resolve_value(node) + + self.builder.ret(value) + + def __visit_function_statement(self, node: FunctionStatement) -> None: + name: str = node.name.value + body: BlockStatement = node.body + + params: list[IdentifierLiteral] = node.parameters + param_types: list[ir.Type] = [] # TODO + + return_type: ir.Type = self.type_map[node.return_type] + + fnty: ir.FunctionType = ir.FunctionType(return_type, param_types) + func: ir.Function = ir.Function(self.module, fnty, name) + + block: ir.Block = func.append_basic_block(f"{name}_entry") + + previous_builder = self.builder + + self.builder = ir.IRBuilder(block) + + previous_env = self.environment + + self.environment = Environment(parent=self.environment) + self.environment.define(name, func, return_type) + + self.compile(body) + + self.environment = previous_env + self.environment.define(name, func, return_type) + + self.builder = previous_builder + + # endregion # region Expressions def __visit_infix_expression(self, node: InfixExpression) -> None: @@ -93,7 +126,7 @@ class Compiler: value = None Type = None if isinstance(right_type, ir.IntType) and isinstance(left_type, ir.IntType): - Type = self.type_map["int"] + Type = self.type_map["Int"] match operator: case "+": value = self.builder.add(left_value, right_value) @@ -109,7 +142,7 @@ class Compiler: # TODO pass elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType): - Type = self.type_map["float"] + Type = self.type_map["Float"] match operator: case "+": value = self.builder.fadd(left_value, right_value) @@ -135,11 +168,11 @@ class Compiler: match node.type(): case NodeType.IntegerLiteral: node: IntegerLiteral = node - value, Type = node.value, self.type_map['int'] + value, Type = node.value, self.type_map['Int'] return ir.Constant(Type, value), Type case NodeType.FloatLiteral: node: FloatLiteral = node - value, Type = node.value, self.type_map['float'] + value, Type = node.value, self.type_map['Float'] return ir.Constant(Type, value), Type case NodeType.IdentifierLiteral: node: IdentifierLiteral = node diff --git a/main.py b/main.py index a7582ee..31fd052 100644 --- a/main.py +++ b/main.py @@ -9,8 +9,8 @@ from llvmlite.binding import targets from ctypes import CFUNCTYPE, c_int, c_float LEXER_DEBUG: bool = False -PARSER_DEBUG: bool = True -COMPILER_DEBUG: bool = False +PARSER_DEBUG: bool = False +COMPILER_DEBUG: bool = True if __name__ == "__main__":