from llvmlite import ir import os from AST import Node, NodeType, Program, Expression from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement, ReassignStatement, IfStatement from AST import WhileStatement, BreakStatement, ContinueStatement, ForStatement, DependStatement from AST import InfixExpression, CallExpression, PrefixExpression, PostfixExpression from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral, BooleanLiteral, StringLiteral from AST import FunctionParameter from environment import Environment from lexer import Lexer from plasma_parser import Parser class Compiler: def __init__(self) -> None: self.type_map: dict[str, ir.type] = { "Bool": ir.IntType(1), "Byte": ir.IntType(8), "Short": ir.IntType(16), "Int": ir.IntType(32), "Long": ir.IntType(64), "Float": ir.FloatType(), "Double": ir.DoubleType(), "String": ir.PointerType(ir.IntType(8)), "Nil": ir.VoidType() } self.module: ir.Module = ir.Module("main") self.builder: ir.IRBuilder = ir.IRBuilder() self.environment: Environment = Environment() self.errors: list[str] = [] self.counter = 0 self.__initialize_builtins() self.breakpoints: list[ir.Block] = [] self.continues: list[ir.Block] = [] self.global_parsed_pallets: dict[str, Program] = {} def __initialize_builtins(self) -> None: def __init_print() -> ir.Function: fnty: ir.FunctionType = ir.FunctionType( self.type_map["Int"], [ir.IntType(8).as_pointer()], var_arg=True ) return ir.Function(self.module, fnty, "printf") def __init_booleans() -> tuple[ir.GlobalVariable, ir.GlobalVariable]: bool_type: ir.Type = self.type_map["Bool"] true_var = ir.GlobalVariable(self.module, bool_type, "true") true_var.initializer = ir.Constant(bool_type, 1) true_var.global_constant = True false_var = ir.GlobalVariable(self.module, bool_type, "false") false_var.initializer = ir.Constant(bool_type, 0) false_var.global_constant = True return true_var, false_var self.environment.define("print", __init_print(), ir.IntType(32)) true_var, false_var = __init_booleans() self.environment.define("true", true_var, true_var.type) self.environment.define("false", false_var, false_var.type) def __increment_counter(self) -> int: self.counter += 1 return self.counter def compile(self, node: Node) -> None: match node.type(): case NodeType.Program: self.__visit_program(node) # Statements case NodeType.ExpressionStatement: self.__visit_expression_statement(node) case NodeType.AssignmentStatement: self.__visit_assignment_statement(node) case NodeType.FunctionStatement: self.__visit_function_statement(node) case NodeType.BlockStatement: self.__visit_block_statement(node) case NodeType.ReturnStatement: self.__visit_return_statement(node) case NodeType.ReassignStatement: self.__visit_reassign_statement(node) case NodeType.IfStatement: self.__visit_if_statement(node) case NodeType.WhileStatement: self.__visit_while_statement(node) case NodeType.BreakStatement: self.__visit_break_statement(node) case NodeType.ContinueStatement: self.__visit_continue_statement(node) case NodeType.ForStatement: self.__visit_for_statement(node) case NodeType.DependStatement: self.__visit_depend_statement(node) # Expressions case NodeType.InfixExpression: self.__visit_infix_expression(node) case NodeType.CallExpression: self.__visit_call_expression(node) case NodeType.PostfixExpression: self.__visit_postfix_expression(node) # region Visit Methods def __visit_program(self, node: Program) -> None: for stmt in node.statements: self.compile(stmt) # region Statements def __visit_expression_statement(self, node: ExpressionStatement) -> None: self.compile(node.expr) def __visit_assignment_statement(self, node: AssignmentStatement) -> None: name: str = node.name.value value: Expression = node.value value_type: str = node.value_type # TODO: implemented value, Type = self.__resolve_value(node=value) if self.environment.lookup(name) is None: # Define and allocate the new variable ptr = self.builder.alloca(Type) # Storing the value to the ptr self.builder.store(value, ptr) # Add the variable to the environment self.environment.define(name, ptr, Type) else: ptr, _ = self.environment.lookup(name) self.builder.store(value, ptr) def __visit_block_statement(self, node: BlockStatement) -> None: for stmt in node.statements: self.compile(stmt) def __visit_return_statement(self, node: ReturnStatement) -> None: value: Expression = node.return_value value, Type = self.__resolve_value(value) self.builder.ret(value) def __visit_function_statement(self, node: FunctionStatement) -> None: name: str = node.name.value body: BlockStatement = node.body params: list[FunctionParameter] = node.parameters param_names: list[str] = [p.name for p in params] param_types: list[ir.Type] = [self.type_map[p.value_type] for p in params] return_type: ir.Type = self.type_map[node.return_type] fnty: ir.FunctionType = ir.FunctionType(return_type, param_types) func: ir.Function = ir.Function(self.module, fnty, name) block: ir.Block = func.append_basic_block(f"{name}_entry") previous_builder = self.builder self.builder = ir.IRBuilder(block) # storing the pointers to each parameter params_ptr = [] for i, typ in enumerate(param_types): ptr = self.builder.alloca(typ) self.builder.store(func.args[i], ptr) params_ptr.append(ptr) # adding the params to the environment previous_env = self.environment self.environment = Environment(parent=self.environment) for i, x in enumerate(zip(param_types, param_names)): typ = param_types[i] ptr = params_ptr[i] self.environment.define(x[1], ptr, typ) self.environment.define(name, func, return_type) self.compile(body) self.environment = previous_env self.environment.define(name, func, return_type) self.builder = previous_builder def __visit_reassign_statement(self, node: ReassignStatement) -> None: name: str = node.ident.value operator: str = node.operator value: Expression = node.right_value if self.environment.lookup(name) is None: self.errors.append(f"Identifier {name} has not been declared before it was re-assigned.") return right_value, right_type = self.__resolve_value(value) var_ptr, _ = self.environment.lookup(name) orig_value = self.builder.load(var_ptr) if isinstance(orig_value.type, ir.IntType) and isinstance(right_type, ir.FloatType): orig_value = self.builder.sitofp(orig_value, ir.FloatType()) if isinstance(orig_value.type, ir.FloatType) and isinstance(right_type, ir.IntType): right_value = self.builder.sitofp(right_value, ir.FloatType()) value = None Type = None match operator: case "=": value = right_value case "+=": if isinstance(orig_value.type, ir.IntType) and isinstance(right_type, ir.IntType): value = self.builder.add(orig_value, right_value) else: value = self.builder.fadd(orig_value, right_value) case "-=": if isinstance(orig_value.type, ir.IntType) and isinstance(right_type, ir.IntType): value = self.builder.sub(orig_value, right_value) else: value = self.builder.fsub(orig_value, right_value) case "*=": if isinstance(orig_value.type, ir.IntType) and isinstance(right_type, ir.IntType): value = self.builder.mul(orig_value, right_value) else: value = self.builder.fmul(orig_value, right_value) case "/=": if isinstance(orig_value.type, ir.IntType) and isinstance(right_type, ir.IntType): value = self.builder.sdiv(orig_value, right_value) else: value = self.builder.fdiv(orig_value, right_value) case _: print("Unsupported assignment operator.") ptr, _ = self.environment.lookup(name) self.builder.store(value, ptr) def __visit_if_statement(self, node: IfStatement) -> None: condition = node.condition consequence = node.consequence alternative = node.alternative test, _ = self.__resolve_value(condition) if alternative is None: with self.builder.if_then(test): self.compile(consequence) else: with self.builder.if_else(test) as (true, otherwise): # Creating a condition branch # condition # / \ # true false # / \ # / \ # if block else block with true: self.compile(consequence) with otherwise: self.compile(alternative) def __visit_while_statement(self, node: WhileStatement) -> None: condition: Expression = node.condition body: BlockStatement = node.body test, _ = self.__resolve_value(condition) while_loop_entry = self.builder.append_basic_block(f"while_loop_entry_{self.__increment_counter()}") while_loop_otherwise = self.builder.append_basic_block(f"while_loop_otherwise_{self.counter}") # Creating a condition branch # condition # / \ # true false # / \ # / \ # if block else block self.builder.cbranch(test, while_loop_entry, while_loop_otherwise) self.builder.position_at_start(while_loop_entry) self.compile(body) test, _ = self.__resolve_value(condition) self.builder.cbranch(test, while_loop_entry, while_loop_otherwise) self.builder.position_at_start(while_loop_otherwise) def __visit_break_statement(self, node: BreakStatement) -> None: self.builder.branch(self.breakpoints[-1]) def __visit_continue_statement(self, node: ContinueStatement) -> None: self.builder.branch[self.continues[-1]] def __visit_for_statement(self, node: ForStatement) -> None: var_declaration: AssignmentStatement = node.var_declaration condition: Expression = node.condition action: ReassignStatement = node.action body: BlockStatement = node.body previous_env = self.environment self.environment = Environment(parent=previous_env) self.compile(var_declaration) for_loop_entry = self.builder.append_basic_block(f"for_loop_entry_{self.__increment_counter()}") for_loop_otherwise = self.builder.append_basic_block(f"for_loop_otherwise_{self.counter}") self.breakpoints.append(for_loop_otherwise) self.continues.append(for_loop_entry) self.builder.branch(for_loop_entry) self.builder.position_at_start(for_loop_entry) self.compile(body) self.compile(action) test, _ = self.__resolve_value(condition) self.builder.cbranch(test, for_loop_entry, for_loop_otherwise) self.builder.position_at_start(for_loop_otherwise) self.breakpoints.pop() self.continues.pop() def __visit_depend_statement(self, node: DependStatement) -> None: file_path: str = node.file_path if self.global_parsed_pallets.get(file_path) is not None: print(f"warning: \"{file_path}\" is already imported globally!\n") return with open(os.path.abspath(file_path), "r") as f: pallet_code: str = f.read() l: Lexer = Lexer(pallet_code) p: Parser = Parser(l) program: Program = p.parse_program() if len(p.errors) > 0: print(f"Error in dependency: {file_path}") for err in p.errors: print(err) exit(1) self.compile(program) self.global_parsed_pallets[file_path] = program # endregion # region Expressions def __visit_infix_expression(self, node: InfixExpression) -> None: operator: str = node.operator left_value, left_type = self.__resolve_value(node.left_node) right_value, right_type = self.__resolve_value(node.right_node) value = None Type = None if isinstance(right_type, ir.IntType) and isinstance(left_type, ir.IntType): Type = self.type_map["Int"] match operator: case "+": value = self.builder.add(left_value, right_value) case "-": value = self.builder.sub(left_value, right_value) case "*": value = self.builder.mul(left_value, right_value) case "/": value = self.builder.sdiv(left_value, right_value) case "%": value = self.builder.srem(left_value, right_value) case "^": # TODO pass case "<": value = self.builder.icmp_signed('<', left_value, right_value) Type = ir.IntType(1) case "<=": value = self.builder.icmp_signed('<=', left_value, right_value) Type = ir.IntType(1) case ">": value = self.builder.icmp_signed('>', left_value, right_value) Type = ir.IntType(1) case ">=": value = self.builder.icmp_signed('>=', left_value, right_value) Type = ir.IntType(1) case "==": value = self.builder.icmp_signed('==', left_value, right_value) Type = ir.IntType(1) case "!=": value = self.builder.icmp_signed('!=', left_value, right_value) Type = ir.IntType(1) elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType): Type = self.type_map["Float"] match operator: case "+": value = self.builder.fadd(left_value, right_value) case "-": value = self.builder.fsub(left_value, right_value) case "*": value = self.builder.fmul(left_value, right_value) case "/": value = self.builder.fdiv(left_value, right_value) case "%": value = self.builder.frem(left_value, right_value) case "^": # TODO pass case "<": value = self.builder.fcmp_ordered('<', left_value, right_value) Type = ir.IntType(1) case "<=": value = self.builder.fcmp_ordered('<=', left_value, right_value) Type = ir.IntType(1) case ">": value = self.builder.fcmp_ordered('>', left_value, right_value) Type = ir.IntType(1) case ">=": value = self.builder.fcmp_ordered('>=', left_value, right_value) Type = ir.IntType(1) case "==": value = self.builder.fcmp_ordered('==', left_value, right_value) Type = ir.IntType(1) case "!=": value = self.builder.fcmp_ordered('!=', left_value, right_value) Type = ir.IntType(1) return value, Type def __visit_call_expression(self, node: CallExpression) -> tuple[ir.Instruction, ir.Type]: name: str = node.function.value params: list[Expression] = node.arguments args = [] types = [] if len(params) > 0: for x in params: p_val, p_type = self.__resolve_value(x) args.append(p_val) types.append(p_type) match name: case "print": ret = self.builtin_print(params=args, return_type=types[0]) ret_type = self.type_map["Int"] case _: func, ret_type = self.environment.lookup(name) ret = self.builder.call(func, args) return ret, ret_type def __visit_prefix_expression(self, node: PrefixExpression) -> tuple[ir.Value, ir.Type]: operator: str = node.operator right_node: Expression = node.right_node right_value, right_type = self.__resolve_value(right_node) Type = None value = None if isinstance(right_type, ir.FloatType): Type = ir.FloatType match operator: case "-": value = self.builder.fmul(right_value, ir.Constant(ir.FloatType(), -1.0)) case "!": value = ir.Constant(ir.IntType(1), 0) elif isinstance(right_type, ir.IntType): Type = ir.IntType(32) match operator: case "-": value = self.builder.mul(right_value, ir.Constant(ir.IntType(32), -1)) case "!": value = self.builder.not_(right_value) return value, Type def __visit_postfix_expression(self, node: PostfixExpression) -> None: left_node: IdentifierLiteral = node.left_node operator: str = node.operator if self.environment.lookup(left_node.value) is None: self.errors.append(f"Identifier {left_node.value} has not been declared before it was re-assigned.") return var_ptr, _ = self.environment.lookup(left_node.value) orig_value = self.builder.load(var_ptr) value = None match operator: case "++": if isinstance(orig_value.type, ir.IntType): value = self.builder.add(orig_value, ir.Constant(ir.IntType(32), 1)) elif isinstance(orig_value.type, ir.FloatType): value = self.builder.fadd(orig_value, ir.Constant(ir.FloatType(), 1.0)) case "--": if isinstance(orig_value.type, ir.IntType): value = self.builder.sub(orig_value, ir.Constant(ir.IntType(32), 1)) elif isinstance(orig_value.type, ir.FloatType): value = self.builder.fsub(orig_value, ir.Constant(ir.FloatType(), 1.0)) self.builder.store(value, var_ptr) # endregion # endregion # region Helper Methods def __resolve_value(self, node: Expression) -> tuple[ir.Value, ir.Type]: match node.type(): case NodeType.IntegerLiteral: node: IntegerLiteral = node value, Type = node.value, self.type_map['Int'] return ir.Constant(Type, value), Type case NodeType.FloatLiteral: node: FloatLiteral = node value, Type = node.value, self.type_map['Float'] return ir.Constant(Type, value), Type case NodeType.IdentifierLiteral: node: IdentifierLiteral = node ptr, Type = self.environment.lookup(node.value) return self.builder.load(ptr), Type case NodeType.BooleanLiteral: node: BooleanLiteral = node return ir.Constant(ir.IntType(1), 1 if node.value else 0), ir.IntType(1) case NodeType.StringLiteral: node: StringLiteral = node string, Type = self.__convert_string(node.value) return string, Type # expression value case NodeType.InfixExpression: return self.__visit_infix_expression(node) case NodeType.CallExpression: return self.__visit_call_expression(node) case NodeType.PrefixExpression: return self.__visit_prefix_expression(node) def __convert_string(self, string: str) -> tuple[ir.Constant, ir.ArrayType]: string = string.replace("\\n", "\n\0") fmt: str = f"{string}\0" c_fmt: ir.Constant = ir.Constant(ir.ArrayType(ir.IntType(8), len(fmt)), bytearray(fmt.encode("utf-8"))) global_fmt = ir.GlobalVariable(self.module, c_fmt.type, name=f"__str_{self.__increment_counter()}") global_fmt.linkage = "internal" global_fmt.global_constant = True global_fmt.initializer = c_fmt return global_fmt, global_fmt.type def builtin_print(self, params: list[ir.Instruction], return_type: ir.Type) -> None: func, _ = self.environment.lookup("print") c_str = self.builder.alloca(return_type) self.builder.store(params[0], c_str) rest_params = params[1:] if isinstance(params[0], ir.LoadInstr): # printing from a variable load instruction c_fmt: ir.LoadInstr = params[0] g_var_ptr = c_fmt.operands[0] string_val = self.builder.load(g_var_ptr) fmt_arg = self.builder.bitcast(string_val, ir.IntType(8).as_pointer()) return self.builder.call(func, [fmt_arg, *rest_params]) else: # printing from a normal string fmt_arg = self.builder.bitcast(self.module.get_global(f"__str_{self.counter}"), ir.IntType(8).as_pointer()) return self.builder.call(func, [fmt_arg, *rest_params]) # endregion