nearly got functions to compile!

This commit is contained in:
SpookyDervish
2025-10-14 07:46:54 +11:00
parent 518a19d3bf
commit 5741a48e73
2 changed files with 63 additions and 30 deletions

View File

@@ -1,7 +1,7 @@
from llvmlite import ir from llvmlite import ir
from AST import Node, NodeType, Program, Expression from AST import Node, NodeType, Program, Expression
from AST import ExpressionStatement, AssignmentStatement, BlockStatement from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement
from AST import InfixExpression from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral
@@ -11,13 +11,13 @@ from environment import Environment
class Compiler: class Compiler:
def __init__(self) -> None: def __init__(self) -> None:
self.type_map: dict[str, ir.type] = { self.type_map: dict[str, ir.type] = {
"bool": ir.IntType(1), "Bool": ir.IntType(1),
"byte": ir.IntType(8), "Byte": ir.IntType(8),
"short": ir.IntType(16), "Short": ir.IntType(16),
"int": ir.IntType(32), "Int": ir.IntType(32),
"long": ir.IntType(64), "Long": ir.IntType(64),
"float": ir.FloatType(), "Float": ir.FloatType(),
"double": ir.DoubleType() "Double": ir.DoubleType()
} }
self.module: ir.Module = ir.Module("main") self.module: ir.Module = ir.Module("main")
@@ -34,6 +34,12 @@ class Compiler:
self.__visit_expression_statement(node) self.__visit_expression_statement(node)
case NodeType.AssignmentStatement: case NodeType.AssignmentStatement:
self.__visit_assignment_statement(node) self.__visit_assignment_statement(node)
case NodeType.FunctionStatement:
self.__visit_function_statement(node)
case NodeType.BlockStatement:
self.__visit_block_statement(node)
case NodeType.ReturnStatement:
self.__visit_return_statement(node)
# Expressions # Expressions
case NodeType.InfixExpression: case NodeType.InfixExpression:
@@ -41,23 +47,9 @@ class Compiler:
# region Visit Methods # region Visit Methods
def __visit_program(self, node: Program) -> None: def __visit_program(self, node: Program) -> None:
func_main: str = "main"
param_types: list[ir.Type] = []
return_type = ir.Type = self.type_map["int"]
fnty = ir.FunctionType(return_type, param_types)
func = ir.Function(self.module, fnty, func_main)
block = func.append_basic_block(f"{func_main}_entry")
self.builder = ir.IRBuilder(block)
for stmt in node.statements: for stmt in node.statements:
self.compile(stmt) self.compile(stmt)
return_value: ir.Constant = ir.Constant(self.type_map["int"], 123)
self.builder.ret(return_value)
# region Statements # region Statements
def __visit_expression_statement(self, node: ExpressionStatement) -> None: def __visit_expression_statement(self, node: ExpressionStatement) -> None:
self.compile(node.expr) self.compile(node.expr)
@@ -77,10 +69,51 @@ class Compiler:
self.builder.store(value, ptr) self.builder.store(value, ptr)
# Add the variable to the environment # Add the variable to the environment
self.environment.define(name, value, Type) self.environment.define(name, ptr, Type)
else: else:
ptr, _ = self.environment.lookup(name) ptr, _ = self.environment.lookup(name)
self.builder.store(value, ptr) self.builder.store(value, ptr)
def __visit_block_statement(self, node: BlockStatement) -> None:
for stmt in node.statements:
self.compile(stmt)
def __visit_return_statement(self, node: ReturnStatement) -> None:
value: Expression = node.return_value
value, Type = self.__resolve_value(node)
self.builder.ret(value)
def __visit_function_statement(self, node: FunctionStatement) -> None:
name: str = node.name.value
body: BlockStatement = node.body
params: list[IdentifierLiteral] = node.parameters
param_types: list[ir.Type] = [] # TODO
return_type: ir.Type = self.type_map[node.return_type]
fnty: ir.FunctionType = ir.FunctionType(return_type, param_types)
func: ir.Function = ir.Function(self.module, fnty, name)
block: ir.Block = func.append_basic_block(f"{name}_entry")
previous_builder = self.builder
self.builder = ir.IRBuilder(block)
previous_env = self.environment
self.environment = Environment(parent=self.environment)
self.environment.define(name, func, return_type)
self.compile(body)
self.environment = previous_env
self.environment.define(name, func, return_type)
self.builder = previous_builder
# endregion # endregion
# region Expressions # region Expressions
@@ -93,7 +126,7 @@ class Compiler:
value = None value = None
Type = None Type = None
if isinstance(right_type, ir.IntType) and isinstance(left_type, ir.IntType): if isinstance(right_type, ir.IntType) and isinstance(left_type, ir.IntType):
Type = self.type_map["int"] Type = self.type_map["Int"]
match operator: match operator:
case "+": case "+":
value = self.builder.add(left_value, right_value) value = self.builder.add(left_value, right_value)
@@ -109,7 +142,7 @@ class Compiler:
# TODO # TODO
pass pass
elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType): elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType):
Type = self.type_map["float"] Type = self.type_map["Float"]
match operator: match operator:
case "+": case "+":
value = self.builder.fadd(left_value, right_value) value = self.builder.fadd(left_value, right_value)
@@ -135,11 +168,11 @@ class Compiler:
match node.type(): match node.type():
case NodeType.IntegerLiteral: case NodeType.IntegerLiteral:
node: IntegerLiteral = node node: IntegerLiteral = node
value, Type = node.value, self.type_map['int'] value, Type = node.value, self.type_map['Int']
return ir.Constant(Type, value), Type return ir.Constant(Type, value), Type
case NodeType.FloatLiteral: case NodeType.FloatLiteral:
node: FloatLiteral = node node: FloatLiteral = node
value, Type = node.value, self.type_map['float'] value, Type = node.value, self.type_map['Float']
return ir.Constant(Type, value), Type return ir.Constant(Type, value), Type
case NodeType.IdentifierLiteral: case NodeType.IdentifierLiteral:
node: IdentifierLiteral = node node: IdentifierLiteral = node

View File

@@ -9,8 +9,8 @@ from llvmlite.binding import targets
from ctypes import CFUNCTYPE, c_int, c_float from ctypes import CFUNCTYPE, c_int, c_float
LEXER_DEBUG: bool = False LEXER_DEBUG: bool = False
PARSER_DEBUG: bool = True PARSER_DEBUG: bool = False
COMPILER_DEBUG: bool = False COMPILER_DEBUG: bool = True
if __name__ == "__main__": if __name__ == "__main__":