Files
Plasma/compiler.py

152 lines
4.4 KiB
Python
Raw Normal View History

2025-10-13 21:05:03 +11:00
from llvmlite import ir
from AST import Node, NodeType, Program, Expression
2025-10-14 07:14:53 +11:00
from AST import ExpressionStatement, AssignmentStatement, BlockStatement
2025-10-13 21:05:03 +11:00
from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral
from environment import Environment
class Compiler:
def __init__(self) -> None:
self.type_map: dict[str, ir.type] = {
"bool": ir.IntType(1),
"byte": ir.IntType(8),
"short": ir.IntType(16),
"int": ir.IntType(32),
"long": ir.IntType(64),
"float": ir.FloatType(),
"double": ir.DoubleType()
}
self.module: ir.Module = ir.Module("main")
self.builder: ir.IRBuilder = ir.IRBuilder()
self.environment: Environment = Environment()
def compile(self, node: Node) -> None:
match node.type():
case NodeType.Program:
self.__visit_program(node)
# Statements
case NodeType.ExpressionStatement:
self.__visit_expression_statement(node)
case NodeType.AssignmentStatement:
self.__visit_assignment_statement(node)
# Expressions
case NodeType.InfixExpression:
self.__visit_infix_expression(node)
# region Visit Methods
def __visit_program(self, node: Program) -> None:
func_main: str = "main"
param_types: list[ir.Type] = []
return_type = ir.Type = self.type_map["int"]
fnty = ir.FunctionType(return_type, param_types)
func = ir.Function(self.module, fnty, func_main)
block = func.append_basic_block(f"{func_main}_entry")
self.builder = ir.IRBuilder(block)
for stmt in node.statements:
self.compile(stmt)
return_value: ir.Constant = ir.Constant(self.type_map["int"], 123)
self.builder.ret(return_value)
# region Statements
def __visit_expression_statement(self, node: ExpressionStatement) -> None:
self.compile(node.expr)
def __visit_assignment_statement(self, node: AssignmentStatement) -> None:
name: str = node.name.value
value: Expression = node.value
value_type: str = node.value_type # TODO: implemented
value, Type = self.__resolve_value(node=value)
if self.environment.lookup(name) is None:
# Define and allocate the new variable
ptr = self.builder.alloca(Type)
# Storing the value to the ptr
self.builder.store(value, ptr)
# Add the variable to the environment
self.environment.define(name, value, Type)
else:
ptr, _ = self.environment.lookup(name)
self.builder.store(value, ptr)
# endregion
# region Expressions
def __visit_infix_expression(self, node: InfixExpression) -> None:
operator: str = node.operator
left_value, left_type = self.__resolve_value(node.left_node)
right_value, right_type = self.__resolve_value(node.right_node)
value = None
Type = None
if isinstance(right_type, ir.IntType) and isinstance(left_type, ir.IntType):
Type = self.type_map["int"]
match operator:
case "+":
value = self.builder.add(left_value, right_value)
case "-":
value = self.builder.sub(left_value, right_value)
case "*":
value = self.builder.mul(left_value, right_value)
case "/":
value = self.builder.sdiv(left_value, right_value)
case "%":
value = self.builder.srem(left_value, right_value)
case "^":
# TODO
pass
elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType):
Type = self.type_map["float"]
match operator:
case "+":
value = self.builder.fadd(left_value, right_value)
case "-":
value = self.builder.fsub(left_value, right_value)
case "*":
value = self.builder.fmul(left_value, right_value)
case "/":
value = self.builder.fdiv(left_value, right_value)
case "%":
value = self.builder.frem(left_value, right_value)
case "^":
# TODO
pass
return value, Type
# endregion
# endregion
# region Helper Methods
def __resolve_value(self, node: Expression) -> tuple[ir.Value, ir.Type]:
match node.type():
case NodeType.IntegerLiteral:
node: IntegerLiteral = node
value, Type = node.value, self.type_map['int']
return ir.Constant(Type, value), Type
case NodeType.FloatLiteral:
node: FloatLiteral = node
value, Type = node.value, self.type_map['float']
return ir.Constant(Type, value), Type
case NodeType.IdentifierLiteral:
node: IdentifierLiteral = node
ptr, Type = self.environment.lookup(node.value)
return self.builder.load(ptr), Type
# expression value
case NodeType.InfixExpression:
return self.__visit_infix_expression(node)
# endregion