Files
Plasma/compiler.py
2025-10-19 10:32:33 +11:00

713 lines
24 KiB
Python

from llvmlite import ir
import os
from AST import Node, NodeType, Program, Expression
from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement, ReassignStatement, IfStatement
from AST import WhileStatement, BreakStatement, ContinueStatement, ForStatement, DependStatement
from AST import InfixExpression, CallExpression, PrefixExpression, PostfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral, BooleanLiteral, StringLiteral
from AST import FunctionParameter
from environment import Environment
from lexer import Lexer
from plasma_parser import Parser
class Compiler:
def __init__(self) -> None:
    """Create the LLVM module, IR builder, root scope and loop bookkeeping.

    All plain state is initialised first and the built-ins are registered
    last, so the registration helper always sees a fully built instance.
    """
    # Source-level type names mapped to the LLVM IR type they lower to.
    # NOTE: "Float" and "Double" both map to ir.DoubleType() here.
    self.type_map: dict[str, ir.Type] = {
        "Bool": ir.IntType(1),
        "Byte": ir.IntType(8),
        "Short": ir.IntType(16),
        "Int": ir.IntType(32),
        "Long": ir.IntType(64),
        "Float": ir.DoubleType(),
        "Double": ir.DoubleType(),
        "String": ir.PointerType(ir.IntType(8)),
        "Nil": ir.VoidType(),
    }
    self.module: ir.Module = ir.Module("main")
    self.builder: ir.IRBuilder = ir.IRBuilder()
    self.environment: Environment = Environment()
    # Human-readable compile errors collected by the visit methods.
    self.errors: list[str] = []
    # Monotonic counter used to build unique block / global names.
    self.counter: int = 0
    # Stacks of branch targets for `break` / `continue` of enclosing loops.
    self.breakpoints: list[ir.Block] = []
    self.continues: list[ir.Block] = []
    # Programs already pulled in through depend statements, by file path.
    self.global_parsed_pallets: dict[str, Program] = {}
    self.__initialize_builtins()
def __initialize_builtins(self) -> None:
    """Declare the built-in functions and boolean globals.

    Registers, in this order: the variadic C ``printf`` under the name
    ``print``, the constant globals ``true`` / ``false``, and the C ``pow``
    function. Each symbol is recorded in the root environment together with
    the type of the value it yields (the pointee type for globals, the
    return type for functions), matching how user variables and functions
    are recorded elsewhere in this class.
    """
    bool_type: ir.Type = self.type_map["Bool"]

    def declare_printf() -> ir.Function:
        # int printf(i8* fmt, ...)
        fnty = ir.FunctionType(
            self.type_map["Int"],
            [ir.IntType(8).as_pointer()],
            var_arg=True,
        )
        return ir.Function(self.module, fnty, "printf")

    def declare_pow() -> ir.Function:
        # double pow(double, double)
        fnty = ir.FunctionType(
            self.type_map["Float"],
            [ir.DoubleType(), ir.DoubleType()],
            var_arg=False,
        )
        return ir.Function(self.module, fnty, "pow")

    def define_bool_global(name: str, bit: int) -> ir.GlobalVariable:
        var = ir.GlobalVariable(self.module, bool_type, name)
        var.initializer = ir.Constant(bool_type, bit)
        var.global_constant = True
        return var

    self.environment.define("print", declare_printf(), self.type_map["Int"])
    for name, bit in (("true", 1), ("false", 0)):
        # Record the pointee type (i1); the global's own `.type` is a
        # pointer and would be reported as the type of the loaded value.
        self.environment.define(name, define_bool_global(name, bit), bool_type)
    # pow returns a double, so record the Float type rather than Int32.
    self.environment.define("pow", declare_pow(), self.type_map["Float"])
def __increment_counter(self) -> int:
    """Advance ``self.counter`` by one and return the new value."""
    bumped = self.counter + 1
    self.counter = bumped
    return bumped
def compile(self, node: Node) -> None:
    """Dispatch ``node`` to the visit method registered for its node type.

    Node types without an entry in the table are ignored. Any value a visit
    method returns is discarded.
    """
    handlers = (
        (NodeType.Program, self.__visit_program),
        # Statements
        (NodeType.ExpressionStatement, self.__visit_expression_statement),
        (NodeType.AssignmentStatement, self.__visit_assignment_statement),
        (NodeType.FunctionStatement, self.__visit_function_statement),
        (NodeType.BlockStatement, self.__visit_block_statement),
        (NodeType.ReturnStatement, self.__visit_return_statement),
        (NodeType.ReassignStatement, self.__visit_reassign_statement),
        (NodeType.IfStatement, self.__visit_if_statement),
        (NodeType.WhileStatement, self.__visit_while_statement),
        (NodeType.BreakStatement, self.__visit_break_statement),
        (NodeType.ContinueStatement, self.__visit_continue_statement),
        (NodeType.ForStatement, self.__visit_for_statement),
        (NodeType.DependStatement, self.__visit_depend_statement),
        # Expressions
        (NodeType.InfixExpression, self.__visit_infix_expression),
        (NodeType.CallExpression, self.__visit_call_expression),
        (NodeType.PostfixExpression, self.__visit_postfix_expression),
    )
    kind = node.type()
    for candidate, handler in handlers:
        # Equality test, first match wins.
        if kind == candidate:
            handler(node)
            return
# region Visit Methods
def __visit_program(self, node: Program) -> None:
    """Compile each statement of the program in order."""
    statements = node.statements
    for statement in statements:
        self.compile(statement)
# region Statements
def __visit_expression_statement(self, node: ExpressionStatement) -> None:
    """Compile the expression wrapped by an expression statement."""
    wrapped = node.expr
    self.compile(wrapped)
def __visit_assignment_statement(self, node: AssignmentStatement) -> None:
    """Declare a variable (or overwrite an already visible one) with a value.

    If the name is not yet known to the environment, stack space of the
    declared type is allocated and the name is defined. Otherwise the value
    is stored into the existing slot. In both cases an integer value is
    converted with ``sitofp`` when the target slot holds a double.

    Raises:
        KeyError: if the declared type name is not in ``self.type_map``.
    """
    name: str = node.name.value
    declared_type = self.type_map[node.value_type]
    value, value_type = self.__resolve_value(node.value)

    existing = self.environment.lookup(name)
    if existing is None:
        # Define and allocate the new variable.
        ptr = self.builder.alloca(declared_type)
        target_type = declared_type
    else:
        ptr, target_type = existing

    # Int -> Float promotion so the stored value matches the slot type.
    if isinstance(target_type, ir.DoubleType) and isinstance(value_type, ir.IntType):
        value = self.builder.sitofp(value, target_type)
    self.builder.store(value, ptr)

    if existing is None:
        # Add the variable to the environment.
        self.environment.define(name, ptr, declared_type)
def __visit_block_statement(self, node: BlockStatement) -> None:
    """Compile each statement of the block in order."""
    body = node.statements
    for statement in body:
        self.compile(statement)
def __visit_return_statement(self, node: ReturnStatement) -> None:
    """Resolve the returned expression and emit a ``ret`` carrying it."""
    result, _ = self.__resolve_value(node.return_value)
    self.builder.ret(result)
def __visit_function_statement(self, node: FunctionStatement) -> None:
    """Emit an LLVM function for a function declaration.

    Each parameter is copied into its own stack slot and defined in a
    child environment. The function is also defined inside that child
    environment (so the body can call itself) and, once the body compiled,
    in the enclosing environment. The builder and environment that were
    active on entry are restored even if compiling the body raises.

    Raises:
        KeyError: if a parameter or return type name is not in
            ``self.type_map``.
    """
    name: str = node.name.value
    params: list[FunctionParameter] = node.parameters
    param_types: list[ir.Type] = [self.type_map[p.value_type] for p in params]
    return_type: ir.Type = self.type_map[node.return_type]

    fnty: ir.FunctionType = ir.FunctionType(return_type, param_types)
    func: ir.Function = ir.Function(self.module, fnty, name)
    entry: ir.Block = func.append_basic_block(f"{name}_entry")

    previous_builder = self.builder
    previous_env = self.environment
    self.builder = ir.IRBuilder(entry)
    self.environment = Environment(parent=previous_env)
    try:
        # Spill every argument to the stack so it can be loaded/stored like
        # any other variable.
        for param, typ, arg in zip(params, param_types, func.args):
            slot = self.builder.alloca(typ)
            self.builder.store(arg, slot)
            self.environment.define(param.name, slot, typ)
        self.environment.define(name, func, return_type)

        self.compile(node.body)

        # A void function whose body falls off the end would otherwise
        # leave its last basic block without a terminator.
        if isinstance(return_type, ir.VoidType) and not self.builder.block.is_terminated:
            self.builder.ret_void()
    finally:
        self.environment = previous_env
        self.builder = previous_builder

    self.environment.define(name, func, return_type)
def __visit_reassign_statement(self, node: ReassignStatement) -> None:
    """Compile ``=``, ``+=``, ``-=``, ``*=`` and ``/=`` on an existing variable.

    Integer arithmetic is used when both the variable and the right-hand
    side are integers; otherwise both sides are promoted to double. The
    result is converted back to the variable's recorded type before it is
    stored, so an Int variable combined with a Float right-hand side no
    longer stores a double into an integer slot.

    Problems (undeclared identifier, unsupported operator) are appended to
    ``self.errors`` and nothing is stored.
    """
    name: str = node.ident.value
    operator: str = node.operator

    entry = self.environment.lookup(name)
    if entry is None:
        self.errors.append(f"Identifier {name} has not been declared before it was re-assigned.")
        return
    var_ptr, var_type = entry

    int_ops = {
        "+=": self.builder.add,
        "-=": self.builder.sub,
        "*=": self.builder.mul,
        "/=": self.builder.sdiv,
    }
    float_ops = {
        "+=": self.builder.fadd,
        "-=": self.builder.fsub,
        "*=": self.builder.fmul,
        "/=": self.builder.fdiv,
    }
    if operator != "=" and operator not in int_ops:
        self.errors.append(f"Unsupported assignment operator '{operator}'.")
        return

    right_value, right_type = self.__resolve_value(node.right_value)
    double = ir.DoubleType()

    if operator == "=":
        result = right_value
    else:
        current = self.builder.load(var_ptr)
        if isinstance(current.type, ir.IntType) and isinstance(right_type, ir.IntType):
            result = int_ops[operator](current, right_value)
        else:
            # Mixed or floating operands: promote any integer side.
            if isinstance(current.type, ir.IntType):
                current = self.builder.sitofp(current, double)
            if isinstance(right_type, ir.IntType):
                right_value = self.builder.sitofp(right_value, double)
            result = float_ops[operator](current, right_value)

    # Make the stored value agree with the variable's recorded type.
    if isinstance(var_type, ir.IntType) and isinstance(result.type, ir.DoubleType):
        result = self.builder.fptosi(result, var_type)
    elif isinstance(var_type, ir.DoubleType) and isinstance(result.type, ir.IntType):
        result = self.builder.sitofp(result, var_type)

    self.builder.store(result, var_ptr)
def __visit_if_statement(self, node: IfStatement) -> None:
    """Compile an if statement, with or without an else branch.

    The condition is resolved once; the builder's ``if_then`` / ``if_else``
    context managers create the branch blocks:

        condition --true--> consequence
                  --false-> alternative (when present)
    """
    then_body = node.consequence
    else_body = node.alternative
    test, _ = self.__resolve_value(node.condition)

    if else_body is not None:
        with self.builder.if_else(test) as (then_branch, else_branch):
            with then_branch:
                self.compile(then_body)
            with else_branch:
                self.compile(else_body)
        return

    with self.builder.if_then(test):
        self.compile(then_body)
def __visit_while_statement(self, node: WhileStatement) -> None:
    """Compile a while loop as condition / body / exit blocks.

        cond --true--> body --> cond
             --false-> exit

    ``continue`` targets the condition block, so the condition is always
    re-evaluated before another iteration; ``break`` targets the exit
    block. The back edge is only emitted when the body's final block has
    no terminator yet (for example it did not end in break/continue).
    """
    loop_id = self.__increment_counter()
    cond_block = self.builder.append_basic_block(f"while_loop_cond_{loop_id}")
    body_block = self.builder.append_basic_block(f"while_loop_entry_{loop_id}")
    exit_block = self.builder.append_basic_block(f"while_loop_otherwise_{loop_id}")

    self.builder.branch(cond_block)
    self.builder.position_at_end(cond_block)
    test, _ = self.__resolve_value(node.condition)
    self.builder.cbranch(test, body_block, exit_block)

    self.breakpoints.append(exit_block)
    self.continues.append(cond_block)
    self.builder.position_at_end(body_block)
    try:
        self.compile(node.body)
    finally:
        # Keep the loop stacks balanced even if the body fails to compile.
        self.breakpoints.pop()
        self.continues.pop()

    if not self.builder.block.is_terminated:
        self.builder.branch(cond_block)
    self.builder.position_at_end(exit_block)
def __visit_break_statement(self, node: BreakStatement) -> None:
    """Branch to the exit block of the innermost enclosing loop.

    When no loop is active an error is appended to ``self.errors`` instead
    of failing with an IndexError on the empty stack.
    """
    if not self.breakpoints:
        self.errors.append("'break' used outside of a loop.")
        return
    self.builder.branch(self.breakpoints[-1])
def __visit_continue_statement(self, node: ContinueStatement) -> None:
    """Branch to the continue target of the innermost enclosing loop.

    When no loop is active an error is appended to ``self.errors`` instead
    of failing with an IndexError on the empty stack.
    """
    if not self.continues:
        self.errors.append("'continue' used outside of a loop.")
        return
    # `branch` must be called; subscripting the bound method raised
    # "TypeError: 'method' object is not subscriptable".
    self.builder.branch(self.continues[-1])
def __visit_for_statement(self, node: ForStatement) -> None:
    """Compile a for loop as condition / body / step / exit blocks.

        init --> cond --true--> body --> step --> cond
                      --false-> exit

    The loop variable lives in a child environment that is discarded when
    the loop has been compiled, restoring the environment active on entry.
    The condition is tested before the first iteration. ``continue``
    targets the step block so the action and the condition still run;
    ``break`` targets the exit block.
    """
    previous_env = self.environment
    self.environment = Environment(parent=previous_env)
    try:
        self.compile(node.var_declaration)

        loop_id = self.__increment_counter()
        cond_block = self.builder.append_basic_block(f"for_loop_cond_{loop_id}")
        body_block = self.builder.append_basic_block(f"for_loop_entry_{loop_id}")
        step_block = self.builder.append_basic_block(f"for_loop_step_{loop_id}")
        exit_block = self.builder.append_basic_block(f"for_loop_otherwise_{loop_id}")

        self.builder.branch(cond_block)
        self.builder.position_at_end(cond_block)
        test, _ = self.__resolve_value(node.condition)
        self.builder.cbranch(test, body_block, exit_block)

        self.breakpoints.append(exit_block)
        self.continues.append(step_block)
        self.builder.position_at_end(body_block)
        try:
            self.compile(node.body)
        finally:
            # Keep the loop stacks balanced even if the body fails.
            self.breakpoints.pop()
            self.continues.pop()

        # Skip the fall-through edge if the body already ended in a
        # terminator such as break/continue.
        if not self.builder.block.is_terminated:
            self.builder.branch(step_block)

        self.builder.position_at_end(step_block)
        self.compile(node.action)
        self.builder.branch(cond_block)

        self.builder.position_at_end(exit_block)
    finally:
        # The loop scope must not leak into the statements after the loop.
        self.environment = previous_env
def __visit_depend_statement(self, node: DependStatement) -> None:
    """Lex, parse and compile another source file into the current module.

    A file path that was already imported only produces a warning. Parser
    errors are printed and terminate the process with exit status 1.

    The parsed program is recorded *before* it is compiled so that a
    dependency which (directly or indirectly) depends on itself hits the
    "already imported" guard instead of recursing without bound.

    Raises:
        OSError: if the file cannot be opened or read.
        SystemExit: if the dependency has parse errors.
    """
    file_path: str = node.file_path
    if self.global_parsed_pallets.get(file_path) is not None:
        print(f"warning: \"{file_path}\" is already imported globally!\n")
        return

    # Decode explicitly as UTF-8 instead of the platform default encoding.
    with open(os.path.abspath(file_path), "r", encoding="utf-8") as f:
        pallet_code: str = f.read()

    parser: Parser = Parser(Lexer(pallet_code))
    program: Program = parser.parse_program()
    if parser.errors:
        print(f"Error in dependency: {file_path}")
        for err in parser.errors:
            print(err)
        # Same exit status as before, without relying on the `exit`
        # helper that is only installed by the `site` module.
        raise SystemExit(1)

    self.global_parsed_pallets[file_path] = program
    self.compile(program)
# endregion
# region Expressions
def __visit_infix_expression(self, node: InfixExpression) -> tuple[ir.Value | None, ir.Type | None]:
    """Compile a binary arithmetic or comparison expression.

    Supported operators: ``+ - * / %`` and ``< <= > >= == !=``.
    When both operands are integers, signed integer instructions are used.
    When at least one operand is a double, any integer operand is promoted
    with ``sitofp`` and floating-point instructions (ordered comparisons)
    are used.

    Returns:
        ``(value, type)`` where ``type`` is the IR type of ``value`` (``i1``
        for comparisons). ``(None, None)`` is returned, and a message is
        appended to ``self.errors``, when the operand types or the operator
        are not supported.
    """
    operator: str = node.operator
    left_value, left_type = self.__resolve_value(node.left_node)
    right_value, right_type = self.__resolve_value(node.right_node)

    numeric = (ir.IntType, ir.DoubleType)
    if not (isinstance(left_type, numeric) and isinstance(right_type, numeric)):
        self.errors.append(
            f"Operator '{operator}' is not supported between {left_type} and {right_type}."
        )
        return None, None

    integer_math = isinstance(left_type, ir.IntType) and isinstance(right_type, ir.IntType)
    if integer_math:
        arithmetic = {
            "+": self.builder.add,
            "-": self.builder.sub,
            "*": self.builder.mul,
            "/": self.builder.sdiv,
            "%": self.builder.srem,
        }
        compare = self.builder.icmp_signed
    else:
        arithmetic = {
            "+": self.builder.fadd,
            "-": self.builder.fsub,
            "*": self.builder.fmul,
            "/": self.builder.fdiv,
            "%": self.builder.frem,
        }
        compare = self.builder.fcmp_ordered

    is_comparison = operator in ("<", "<=", ">", ">=", "==", "!=")
    if not is_comparison and operator not in arithmetic:
        self.errors.append(f"Unsupported infix operator '{operator}'.")
        return None, None

    if not integer_math:
        # Promote only after the operator is known to be valid, so no
        # stray conversion is emitted for a rejected expression.
        double = ir.DoubleType()
        if isinstance(left_type, ir.IntType):
            left_value = self.builder.sitofp(left_value, double)
        if isinstance(right_type, ir.IntType):
            right_value = self.builder.sitofp(right_value, double)

    if is_comparison:
        value = compare(operator, left_value, right_value)
    else:
        value = arithmetic[operator](left_value, right_value)

    # Report the type the instruction actually produces rather than a
    # hard-coded Int32 / Float.
    return value, value.type
def __visit_call_expression(self, node: CallExpression) -> tuple[ir.Instruction, ir.Type]:
name: str = node.function.value
params: list[Expression] = node.arguments
args = []
types = []
if len(params) > 0:
for x in params:
p_val, p_type = self.__resolve_value(x)
args.append(p_val)
types.append(p_type)
match name:
case "print":
ret = self.builtin_print(params=args, return_type=types[0])
ret_type = self.type_map["Int"]
case "pow":
ret = self.builtin_pow(params=args)
ret_type = self.type_map["Float"]
case _:
if not self.environment.lookup(name):
print(f"The function \"{name}\" is not defined.")
exit(1)
func, ret_type = self.environment.lookup(name)
ret = self.builder.call(func, args)
return ret, ret_type
def __visit_prefix_expression(self, node: PrefixExpression) -> tuple[ir.Value, ir.Type]:
operator: str = node.operator
right_node: Expression = node.right_node
right_value, right_type = self.__resolve_value(right_node)
Type = None
value = None
if isinstance(right_type, ir.DoubleType):
Type = ir.DoubleType
match operator:
case "-":
value = self.builder.fmul(right_value, ir.Constant(ir.DoubleType(), -1.0))
case "!":
value = ir.Constant(ir.IntType(1), 0)
elif isinstance(right_type, ir.IntType):
Type = ir.IntType(32)
match operator:
case "-":
value = self.builder.mul(right_value, ir.Constant(ir.IntType(32), -1))
case "!":
value = self.builder.not_(right_value)
return value, Type
def __visit_postfix_expression(self, node: PostfixExpression) -> None:
    """Compile ``x++`` / ``x--`` as a load, add/sub of one, and store.

    The constant ``1`` is built with the variable's own type, so integer
    variables that are not 32 bits wide are handled as well. An undeclared
    identifier, an unknown operator or a non-numeric variable is reported
    through ``self.errors`` and nothing is stored.
    """
    left_node: IdentifierLiteral = node.left_node
    operator: str = node.operator

    entry = self.environment.lookup(left_node.value)
    if entry is None:
        self.errors.append(f"Identifier {left_node.value} has not been declared before it was re-assigned.")
        return
    if operator not in ("++", "--"):
        self.errors.append(f"Unsupported postfix operator '{operator}'.")
        return

    var_ptr, _ = entry
    orig_value = self.builder.load(var_ptr)
    increment = operator == "++"

    if isinstance(orig_value.type, ir.IntType):
        step = ir.Constant(orig_value.type, 1)
        apply = self.builder.add if increment else self.builder.sub
    elif isinstance(orig_value.type, ir.DoubleType):
        step = ir.Constant(orig_value.type, 1.0)
        apply = self.builder.fadd if increment else self.builder.fsub
    else:
        self.errors.append(
            f"Operator '{operator}' cannot be applied to {left_node.value} of type {orig_value.type}."
        )
        return

    self.builder.store(apply(orig_value, step), var_ptr)
# endregion
# endregion
# region Helper Methods
def __resolve_value(self, node: Expression) -> tuple[ir.Value, ir.Type]:
match node.type():
case NodeType.IntegerLiteral:
node: IntegerLiteral = node
value, Type = node.value, self.type_map['Int']
return ir.Constant(Type, value), Type
case NodeType.FloatLiteral:
node: FloatLiteral = node
value, Type = node.value, self.type_map['Float']
return ir.Constant(Type, value), Type
case NodeType.IdentifierLiteral:
node: IdentifierLiteral = node
idk = self.environment.lookup(node.value)
if not idk:
print(f"\"{node.value}\" is not defined in the current scope.")
exit(1)
ptr, Type = idk
return self.builder.load(ptr), Type
case NodeType.BooleanLiteral:
node: BooleanLiteral = node
return ir.Constant(ir.IntType(1), 1 if node.value else 0), ir.IntType(1)
case NodeType.StringLiteral:
node: StringLiteral = node
string, Type = self.__convert_string(node.value)
return string, Type
# expression value
case NodeType.InfixExpression:
return self.__visit_infix_expression(node)
case NodeType.CallExpression:
return self.__visit_call_expression(node)
case NodeType.PrefixExpression:
return self.__visit_prefix_expression(node)
def __convert_string(self, string: str) -> tuple[ir.GlobalVariable, ir.PointerType]:
    """Emit ``string`` as an internal, constant, NUL-terminated i8 array.

    The two-character sequence backslash + ``n`` is turned into a real
    newline. The global is named ``__str_<counter>`` using a freshly
    incremented counter.

    Returns:
        The global variable and its type (a pointer to the array).
    """
    # Only the newline is substituted. Appending a NUL after every newline
    # ended the C string at the first "\n" and dropped the rest of the text.
    text = string.replace("\\n", "\n")
    # Size the array from the encoded bytes: non-ASCII characters occupy
    # more than one byte in UTF-8, so len(text) would be too small.
    data = bytearray(f"{text}\0".encode("utf-8"))

    c_fmt: ir.Constant = ir.Constant(ir.ArrayType(ir.IntType(8), len(data)), data)
    global_fmt = ir.GlobalVariable(self.module, c_fmt.type, name=f"__str_{self.__increment_counter()}")
    global_fmt.linkage = "internal"
    global_fmt.global_constant = True
    global_fmt.initializer = c_fmt
    return global_fmt, global_fmt.type
# region Builtins
def builtin_print(self, params: list[ir.Instruction], return_type: ir.Type) -> ir.Instruction:
    """Emit a call to the function registered as ``print``.

    ``params[0]`` is used as the format string and is bitcast to ``i8*``;
    the remaining values are forwarded unchanged as variadic arguments.

    The format string is taken from the first argument itself. It used to
    be looked up as ``__str_<counter>``, i.e. the most recently created
    string global, which is the wrong string whenever a later argument is
    also a string literal.

    Args:
        params: resolved argument values; must contain at least one item.
        return_type: unused; kept so existing callers keep working.

    Returns:
        The call instruction.
    """
    func, _ = self.environment.lookup("print")
    fmt_arg = self.builder.bitcast(params[0], ir.IntType(8).as_pointer())
    return self.builder.call(func, [fmt_arg, *params[1:]])
def builtin_pow(self, params: list[ir.Instruction]) -> ir.Instruction:
    """Emit a call to the function registered as ``pow``.

    Every argument is passed in its original position. Integer arguments
    are promoted to double with ``sitofp`` because ``pow`` is declared as
    taking two doubles.

    The earlier special cases for loaded variables indexed ``operands[1]``
    of a load instruction (which has a single operand) and, when only the
    exponent was a variable, passed the exponent twice and dropped the
    base.

    Returns:
        The call instruction.
    """
    func, _ = self.environment.lookup("pow")
    double = ir.DoubleType()
    args = [
        self.builder.sitofp(param, double) if isinstance(param.type, ir.IntType) else param
        for param in params
    ]
    return self.builder.call(func, args)
# endregion
# endregion