nearly got functions to compile!

This commit is contained in:
SpookyDervish
2025-10-14 07:46:54 +11:00
parent 518a19d3bf
commit 5741a48e73
2 changed files with 63 additions and 30 deletions

View File

@@ -1,7 +1,7 @@
from llvmlite import ir
from AST import Node, NodeType, Program, Expression
from AST import ExpressionStatement, AssignmentStatement, BlockStatement
from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement
from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral
@@ -11,13 +11,13 @@ from environment import Environment
class Compiler:
def __init__(self) -> None:
self.type_map: dict[str, ir.type] = {
"bool": ir.IntType(1),
"byte": ir.IntType(8),
"short": ir.IntType(16),
"int": ir.IntType(32),
"long": ir.IntType(64),
"float": ir.FloatType(),
"double": ir.DoubleType()
"Bool": ir.IntType(1),
"Byte": ir.IntType(8),
"Short": ir.IntType(16),
"Int": ir.IntType(32),
"Long": ir.IntType(64),
"Float": ir.FloatType(),
"Double": ir.DoubleType()
}
self.module: ir.Module = ir.Module("main")
@@ -34,6 +34,12 @@ class Compiler:
self.__visit_expression_statement(node)
case NodeType.AssignmentStatement:
self.__visit_assignment_statement(node)
case NodeType.FunctionStatement:
self.__visit_function_statement(node)
case NodeType.BlockStatement:
self.__visit_block_statement(node)
case NodeType.ReturnStatement:
self.__visit_return_statement(node)
# Expressions
case NodeType.InfixExpression:
@@ -41,23 +47,9 @@ class Compiler:
# region Visit Methods
def __visit_program(self, node: Program) -> None:
func_main: str = "main"
param_types: list[ir.Type] = []
return_type = ir.Type = self.type_map["int"]
fnty = ir.FunctionType(return_type, param_types)
func = ir.Function(self.module, fnty, func_main)
block = func.append_basic_block(f"{func_main}_entry")
self.builder = ir.IRBuilder(block)
for stmt in node.statements:
self.compile(stmt)
return_value: ir.Constant = ir.Constant(self.type_map["int"], 123)
self.builder.ret(return_value)
# region Statements
def __visit_expression_statement(self, node: ExpressionStatement) -> None:
self.compile(node.expr)
@@ -77,11 +69,52 @@ class Compiler:
self.builder.store(value, ptr)
# Add the variable to the environment
self.environment.define(name, value, Type)
self.environment.define(name, ptr, Type)
else:
ptr, _ = self.environment.lookup(name)
self.builder.store(value, ptr)
# endregion
def __visit_block_statement(self, node: BlockStatement) -> None:
for stmt in node.statements:
self.compile(stmt)
def __visit_return_statement(self, node: ReturnStatement) -> None:
value: Expression = node.return_value
value, Type = self.__resolve_value(node)
self.builder.ret(value)
def __visit_function_statement(self, node: FunctionStatement) -> None:
name: str = node.name.value
body: BlockStatement = node.body
params: list[IdentifierLiteral] = node.parameters
param_types: list[ir.Type] = [] # TODO
return_type: ir.Type = self.type_map[node.return_type]
fnty: ir.FunctionType = ir.FunctionType(return_type, param_types)
func: ir.Function = ir.Function(self.module, fnty, name)
block: ir.Block = func.append_basic_block(f"{name}_entry")
previous_builder = self.builder
self.builder = ir.IRBuilder(block)
previous_env = self.environment
self.environment = Environment(parent=self.environment)
self.environment.define(name, func, return_type)
self.compile(body)
self.environment = previous_env
self.environment.define(name, func, return_type)
self.builder = previous_builder
# endregion
# region Expressions
def __visit_infix_expression(self, node: InfixExpression) -> None:
@@ -93,7 +126,7 @@ class Compiler:
value = None
Type = None
if isinstance(right_type, ir.IntType) and isinstance(left_type, ir.IntType):
Type = self.type_map["int"]
Type = self.type_map["Int"]
match operator:
case "+":
value = self.builder.add(left_value, right_value)
@@ -109,7 +142,7 @@ class Compiler:
# TODO
pass
elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType):
Type = self.type_map["float"]
Type = self.type_map["Float"]
match operator:
case "+":
value = self.builder.fadd(left_value, right_value)
@@ -135,11 +168,11 @@ class Compiler:
match node.type():
case NodeType.IntegerLiteral:
node: IntegerLiteral = node
value, Type = node.value, self.type_map['int']
value, Type = node.value, self.type_map['Int']
return ir.Constant(Type, value), Type
case NodeType.FloatLiteral:
node: FloatLiteral = node
value, Type = node.value, self.type_map['float']
value, Type = node.value, self.type_map['Float']
return ir.Constant(Type, value), Type
case NodeType.IdentifierLiteral:
node: IdentifierLiteral = node

View File

@@ -9,8 +9,8 @@ from llvmlite.binding import targets
from ctypes import CFUNCTYPE, c_int, c_float
LEXER_DEBUG: bool = False
PARSER_DEBUG: bool = True
COMPILER_DEBUG: bool = False
PARSER_DEBUG: bool = False
COMPILER_DEBUG: bool = True
if __name__ == "__main__":