From 3d9208f0f8b0651a4ba539e24d5486f16f3552d5 Mon Sep 17 00:00:00 2001 From: SpookyDervish <78246495+SpookyDervish@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:23:11 +1100 Subject: [PATCH] if statements work --- AST.py | 32 ++++++++++++++++++ compiler.py | 84 ++++++++++++++++++++++++++++++++++++++++++++++-- debug/ast.json | 61 ++++++++++++++++++++++++++--------- debug/ir.ll | 23 +++++++++---- lexer.py | 33 ++++++++++++++++++- lexer_token.py | 1 + plasma_parser.py | 56 ++++++++++++++++++++++++++------ 7 files changed, 256 insertions(+), 34 deletions(-) diff --git a/AST.py b/AST.py index 228a0bb..3b81ae2 100644 --- a/AST.py +++ b/AST.py @@ -12,6 +12,7 @@ class NodeType(Enum): FunctionStatement = "FunctionStatement" BlockStatement = "BlockStatement" ReturnStatement = "ReturnStatement" + IfStatement = "IfStatement" # Expressions InfixExpression = "InfixExpression" @@ -20,6 +21,7 @@ class NodeType(Enum): IntegerLiteral = "IntegerLiteral" FloatLiteral = "FloatLiteral" IdentifierLiteral = "IdentifierLiteral" + BooleanLiteral = "BooleanLiteral" class Node: @abstractmethod @@ -88,6 +90,19 @@ class IdentifierLiteral(Expression): "type": self.type().value, "value": self.value } + +class BooleanLiteral(Expression): + def __init__(self, value: bool = None) -> None: + self.value: bool = value + + def type(self) -> NodeType: + return NodeType.BooleanLiteral + + def json(self) -> dict: + return { + "type": self.type().value, + "value": self.value + } # endregion # region Statements @@ -180,6 +195,23 @@ class ReassignStatement(Statement): "ident": self.ident.json(), "right_value": self.right_value.json() } + +class IfStatement(Statement): + def __init__(self, condition: Expression = None, consequence: BlockStatement = None, alternative: BlockStatement = None) -> None: + self.condition = condition + self.consequence = consequence + self.alternative = alternative + + def type(self) -> NodeType: + return NodeType.IfStatement + + def json(self) -> dict: + return { + "type": self.type().value, + "condition": self.condition.json(), + "consequence": self.consequence.json(), + "alternative": self.alternative.json() if self.alternative is not None else None + } # endregion # region Expressions diff --git a/compiler.py b/compiler.py index cdfb5e5..8492bb2 100644 --- a/compiler.py +++ b/compiler.py @@ -1,9 +1,9 @@ from llvmlite import ir from AST import Node, NodeType, Program, Expression -from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement, ReassignStatement +from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement, ReassignStatement, IfStatement from AST import InfixExpression -from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral +from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral, BooleanLiteral from environment import Environment @@ -25,6 +25,26 @@ class Compiler: self.environment: Environment = Environment() self.errors: list[str] = [] + self.__initialize_builtins() + + def __initialize_builtins(self) -> None: + def __init_booleans() -> tuple[ir.GlobalVariable, ir.GlobalVariable]: + bool_type: ir.Type = self.type_map["Bool"] + + true_var = ir.GlobalVariable(self.module, bool_type, "true") + true_var.initializer = ir.Constant(bool_type, 1) + true_var.global_constant = True + + false_var = ir.GlobalVariable(self.module, bool_type, "false") + false_var.initializer = ir.Constant(bool_type, 0) + false_var.global_constant = True + + return true_var, false_var + + true_var, false_var = __init_booleans() + self.environment.define("true", true_var, true_var.type) + self.environment.define("false", false_var, false_var.type) + def compile(self, node: Node) -> None: match node.type(): case NodeType.Program: @@ -43,6 +63,8 @@ class Compiler: self.__visit_return_statement(node) case NodeType.ReassignStatement: self.__visit_reassign_statement(node) + case NodeType.IfStatement: + self.__visit_if_statement(node) # Expressions case NodeType.InfixExpression: @@ -129,6 +151,31 @@ class Compiler: else: ptr, _ = self.environment.lookup(name) self.builder.store(value, ptr) + + def __visit_if_statement(self, node: IfStatement) -> None: + condition = node.condition + consequence = node.consequence + alternative = node.alternative + + test, _ = self.__resolve_value(condition) + + if alternative is None: + with self.builder.if_then(test): + self.compile(consequence) + else: + with self.builder.if_else(test) as (true, otherwise): + # Creating a condition branch + # condition + # / \ + # true false + # / \ + # / \ + # if block else block + with true: + self.compile(consequence) + + with otherwise: + self.compile(alternative) # endregion # region Expressions @@ -156,6 +203,21 @@ class Compiler: case "^": # TODO pass + case "<": + value = self.builder.icmp_signed('<', left_value, right_value) + Type = ir.IntType(1) + case "<=": + value = self.builder.icmp_signed('<=', left_value, right_value) + Type = ir.IntType(1) + case ">": + value = self.builder.icmp_signed('>', left_value, right_value) + Type = ir.IntType(1) + case ">=": + value = self.builder.icmp_signed('>=', left_value, right_value) + Type = ir.IntType(1) + case "==": + value = self.builder.icmp_signed('==', left_value, right_value) + Type = ir.IntType(1) elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType): Type = self.type_map["Float"] match operator: @@ -172,6 +234,21 @@ class Compiler: case "^": # TODO pass + case "<": + value = self.builder.fcmp_ordered('<', left_value, right_value) + Type = ir.IntType(1) + case "<=": + value = self.builder.fcmp_ordered('<=', left_value, right_value) + Type = ir.IntType(1) + case ">": + value = self.builder.fcmp_ordered('>', left_value, right_value) + Type = ir.IntType(1) + case ">=": + value = self.builder.fcmp_ordered('>=', left_value, right_value) + Type = ir.IntType(1) + case "==": + value = self.builder.fcmp_ordered('==', left_value, right_value) + Type = ir.IntType(1) return value, Type # endregion @@ -193,6 +270,9 @@ class Compiler: node: IdentifierLiteral = node ptr, Type = self.environment.lookup(node.value) return self.builder.load(ptr), Type + case NodeType.BooleanLiteral: + node: BooleanLiteral = node + return ir.Constant(ir.IntType(1), 1 if node.value else 0), ir.IntType(1) # expression value case NodeType.InfixExpression: diff --git a/debug/ast.json b/debug/ast.json index b19efb1..98f9c22 100644 --- a/debug/ast.json +++ b/debug/ast.json @@ -21,26 +21,57 @@ }, "value": { "type": "IntegerLiteral", - "value": 0 + "value": 123 }, "value_type": "Int" }, { - "type": "ReassignStatement", - "ident": { - "type": "IdentifierLiteral", - "value": "x" - }, - "right_value": { - "type": "InfixExpression", - "left_node": { - "type": "IdentifierLiteral", - "value": "x" + "type": "ExpressionStatement", + "expr": { + "type": "IfStatement", + "condition": { + "type": "InfixExpression", + "left_node": { + "type": "IdentifierLiteral", + "value": "x" + }, + "operator": "==", + "right_node": { + "type": "IntegerLiteral", + "value": 123 + } }, - "operator": "*", - "right_node": { - "type": "IntegerLiteral", - "value": 2 + "consequence": { + "type": "BlockStatement", + "statements": [ + { + "type": "ReassignStatement", + "ident": { + "type": "IdentifierLiteral", + "value": "x" + }, + "right_value": { + "type": "IntegerLiteral", + "value": 0 + } + } + ] + }, + "alternative": { + "type": "BlockStatement", + "statements": [ + { + "type": "ReassignStatement", + "ident": { + "type": "IdentifierLiteral", + "value": "x" + }, + "right_value": { + "type": "IntegerLiteral", + "value": 1 + } + } + ] } } }, diff --git a/debug/ir.ll b/debug/ir.ll index b304946..9426d61 100644 --- a/debug/ir.ll +++ b/debug/ir.ll @@ -2,14 +2,23 @@ target triple = "x86_64-pc-windows-msvc" target datalayout = "" -define i32 @"test"() -{ -test_entry: - ret i32 0 -} - +@"true" = constant i1 1 +@"false" = constant i1 0 define i32 @"main"() { main_entry: - ret i32 123 + %".2" = alloca i32 + store i32 6, i32* %".2" + %".4" = load i32, i32* %".2" + %".5" = icmp eq i32 %".4", 5 + br i1 %".5", label %"main_entry.if", label %"main_entry.else" +main_entry.if: + store i32 0, i32* %".2" + br label %"main_entry.endif" +main_entry.else: + store i32 1, i32* %".2" + br label %"main_entry.endif" +main_entry.endif: + %".11" = load i32, i32* %".2" + ret i32 %".11" } diff --git a/lexer.py b/lexer.py index b4c71ca..f7111b7 100644 --- a/lexer.py +++ b/lexer.py @@ -95,8 +95,39 @@ class Lexer: tok = self.__new_token(TokenType.POW, self.current_char) case "%": tok = self.__new_token(TokenType.MODULUS, self.current_char) + case "<": + # Handle <= + if self.__peek_char() == "=": + ch = self.current_char + self.__read_char() + tok = self.__new_token(TokenType.LT_EQ, ch + self.current_char) + else: + tok = self.__new_token(TokenType.LT, self.current_char) + case ">": + # Handle >= + if self.__peek_char() == "=": + ch = self.current_char + self.__read_char() + tok = self.__new_token(TokenType.GT_EQ, ch + self.current_char) + else: + tok = self.__new_token(TokenType.GT, self.current_char) case "=": - tok = self.__new_token(TokenType.EQ, self.current_char) + # Handle == + if self.__peek_char() == "=": + ch = self.current_char + self.__read_char() + tok = self.__new_token(TokenType.EQ_EQ, ch + self.current_char) + else: + tok = self.__new_token(TokenType.EQ, self.current_char) + case "!": + # Handle != + if self.__peek_char() == "=": + ch = self.current_char + self.__read_char() + tok = self.__new_token(TokenType.NOT_EQ, ch + self.current_char) + else: + # TODO: handle BANG + tok = self.__new_token(TokenType.ILLEGAL, self.current_char) case "(": tok = self.__new_token(TokenType.LPAREN, self.current_char) case ")": diff --git a/lexer_token.py b/lexer_token.py index e15d6a5..e7693c0 100644 --- a/lexer_token.py +++ b/lexer_token.py @@ -29,6 +29,7 @@ class TokenType(Enum): EQ_EQ = "==" LT_EQ = "<=" GT_EQ = ">=" + NOT_EQ = "!=" # Symbols LPAREN = "LPAREN" diff --git a/plasma_parser.py b/plasma_parser.py index bd28d94..b638116 100644 --- a/plasma_parser.py +++ b/plasma_parser.py @@ -4,9 +4,9 @@ from typing import Callable from enum import Enum, auto from AST import Statement, Expression, Program -from AST import ExpressionStatement, AssignmentStatement, FunctionStatement, ReturnStatement, BlockStatement, ReassignStatement +from AST import ExpressionStatement, AssignmentStatement, FunctionStatement, ReturnStatement, BlockStatement, ReassignStatement, IfStatement from AST import InfixExpression -from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral +from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral, BooleanLiteral class PrecedenceType(Enum): P_LOWEST = 0 @@ -25,7 +25,13 @@ PRECEDENCES: dict[TokenType, PrecedenceType] = { TokenType.ASTERISK: PrecedenceType.P_PRODUCT, TokenType.SLASH: PrecedenceType.P_PRODUCT, TokenType.MODULUS: PrecedenceType.P_PRODUCT, - TokenType.POW: PrecedenceType.P_EXPONENT + TokenType.POW: PrecedenceType.P_EXPONENT, + TokenType.EQ_EQ: PrecedenceType.P_EQUALS, + TokenType.NOT_EQ: PrecedenceType.P_EQUALS, + TokenType.LT: PrecedenceType.P_LESSGREATER, + TokenType.GT: PrecedenceType.P_LESSGREATER, + TokenType.LT_EQ: PrecedenceType.P_LESSGREATER, + TokenType.GT_EQ: PrecedenceType.P_LESSGREATER } class Parser: @@ -38,22 +44,27 @@ class Parser: self.peek_token: Token = None self.prefix_parse_functions: dict[Token, Callable] = { # -1 - TokenType.IDENT: self.__parse_identifier, TokenType.INT: self.__parse_int_literal, TokenType.FLOAT: self.__parse_float_literal, - TokenType.LPAREN: self.__parse_grouped_expression - + TokenType.LPAREN: self.__parse_grouped_expression, + TokenType.IF: self.__parse_if_statement, + TokenType.TRUE: self.__parse_boolean, + TokenType.FALSE: self.__parse_boolean, } self.infix_parse_functions: dict[Token, Callable] = { # 5 + 5 - TokenType.PLUS: self.__parse_infix_expression, TokenType.MINUS: self.__parse_infix_expression, TokenType.SLASH: self.__parse_infix_expression, TokenType.ASTERISK: self.__parse_infix_expression, TokenType.POW: self.__parse_infix_expression, - TokenType.MODULUS: self.__parse_infix_expression - + TokenType.MODULUS: self.__parse_infix_expression, + TokenType.EQ_EQ: self.__parse_infix_expression, + TokenType.NOT_EQ: self.__parse_infix_expression, + TokenType.LT: self.__parse_infix_expression, + TokenType.GT: self.__parse_infix_expression, + TokenType.LT_EQ: self.__parse_infix_expression, + TokenType.GT_EQ: self.__parse_infix_expression, } self.__next_token() @@ -229,6 +240,30 @@ class Parser: self.__next_token() return block_stmt + + def __parse_if_statement(self) -> IfStatement: + condition: Expression = None + consequence: BlockStatement = None + alternative: BlockStatement = None + + self.__next_token() + + condition = self.__parse_expression(PrecedenceType.P_LOWEST) + + if not self.__expect_peek(TokenType.LBRACE): + return None + + consequence = self.__parse_block_statement() + + if self.__peek_token_is(TokenType.UNLESS): + self.__next_token() + + if not self.__expect_peek(TokenType.LBRACE): + return None + + alternative = self.__parse_block_statement() + + return IfStatement(condition, consequence, alternative) # endregion # region Expression Methods @@ -297,4 +332,7 @@ class Parser: return None return float_lit + + def __parse_boolean(self) -> BooleanLiteral: + return BooleanLiteral(value=self.__current_token_is(TokenType.TRUE)) # endregion \ No newline at end of file