if statements work

This commit is contained in:
SpookyDervish
2025-10-14 21:23:11 +11:00
parent 48e7488a63
commit 3d9208f0f8
7 changed files with 256 additions and 34 deletions

32
AST.py
View File

@@ -12,6 +12,7 @@ class NodeType(Enum):
FunctionStatement = "FunctionStatement" FunctionStatement = "FunctionStatement"
BlockStatement = "BlockStatement" BlockStatement = "BlockStatement"
ReturnStatement = "ReturnStatement" ReturnStatement = "ReturnStatement"
IfStatement = "IfStatement"
# Expressions # Expressions
InfixExpression = "InfixExpression" InfixExpression = "InfixExpression"
@@ -20,6 +21,7 @@ class NodeType(Enum):
IntegerLiteral = "IntegerLiteral" IntegerLiteral = "IntegerLiteral"
FloatLiteral = "FloatLiteral" FloatLiteral = "FloatLiteral"
IdentifierLiteral = "IdentifierLiteral" IdentifierLiteral = "IdentifierLiteral"
BooleanLiteral = "BooleanLiteral"
class Node: class Node:
@abstractmethod @abstractmethod
@@ -88,6 +90,19 @@ class IdentifierLiteral(Expression):
"type": self.type().value, "type": self.type().value,
"value": self.value "value": self.value
} }
class BooleanLiteral(Expression):
def __init__(self, value: bool = None) -> None:
self.value: bool = value
def type(self) -> NodeType:
return NodeType.BooleanLiteral
def json(self) -> dict:
return {
"type": self.type().value,
"value": self.value
}
# endregion # endregion
# region Statements # region Statements
@@ -180,6 +195,23 @@ class ReassignStatement(Statement):
"ident": self.ident.json(), "ident": self.ident.json(),
"right_value": self.right_value.json() "right_value": self.right_value.json()
} }
class IfStatement(Statement):
def __init__(self, condition: Expression = None, consequence: BlockStatement = None, alternative: BlockStatement = None) -> None:
self.condition = condition
self.consequence = consequence
self.alternative = alternative
def type(self) -> NodeType:
return NodeType.IfStatement
def json(self) -> dict:
return {
"type": self.type().value,
"condition": self.condition.json(),
"consequence": self.consequence.json(),
"alternative": self.alternative.json() if self.alternative is not None else None
}
# endregion # endregion
# region Expressions # region Expressions

View File

@@ -1,9 +1,9 @@
from llvmlite import ir from llvmlite import ir
from AST import Node, NodeType, Program, Expression from AST import Node, NodeType, Program, Expression
from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement, ReassignStatement from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement, ReassignStatement, IfStatement
from AST import InfixExpression from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral, BooleanLiteral
from environment import Environment from environment import Environment
@@ -25,6 +25,26 @@ class Compiler:
self.environment: Environment = Environment() self.environment: Environment = Environment()
self.errors: list[str] = [] self.errors: list[str] = []
self.__initialize_builtins()
def __initialize_builtins(self) -> None:
def __init_booleans() -> tuple[ir.GlobalVariable, ir.GlobalVariable]:
bool_type: ir.Type = self.type_map["Bool"]
true_var = ir.GlobalVariable(self.module, bool_type, "true")
true_var.initializer = ir.Constant(bool_type, 1)
true_var.global_constant = True
false_var = ir.GlobalVariable(self.module, bool_type, "false")
false_var.initializer = ir.Constant(bool_type, 0)
false_var.global_constant = True
return true_var, false_var
true_var, false_var = __init_booleans()
self.environment.define("true", true_var, true_var.type)
self.environment.define("false", false_var, false_var.type)
def compile(self, node: Node) -> None: def compile(self, node: Node) -> None:
match node.type(): match node.type():
case NodeType.Program: case NodeType.Program:
@@ -43,6 +63,8 @@ class Compiler:
self.__visit_return_statement(node) self.__visit_return_statement(node)
case NodeType.ReassignStatement: case NodeType.ReassignStatement:
self.__visit_reassign_statement(node) self.__visit_reassign_statement(node)
case NodeType.IfStatement:
self.__visit_if_statement(node)
# Expressions # Expressions
case NodeType.InfixExpression: case NodeType.InfixExpression:
@@ -129,6 +151,31 @@ class Compiler:
else: else:
ptr, _ = self.environment.lookup(name) ptr, _ = self.environment.lookup(name)
self.builder.store(value, ptr) self.builder.store(value, ptr)
def __visit_if_statement(self, node: IfStatement) -> None:
condition = node.condition
consequence = node.consequence
alternative = node.alternative
test, _ = self.__resolve_value(condition)
if alternative is None:
with self.builder.if_then(test):
self.compile(consequence)
else:
with self.builder.if_else(test) as (true, otherwise):
# Creating a condition branch
# condition
# / \
# true false
# / \
# / \
# if block else block
with true:
self.compile(consequence)
with otherwise:
self.compile(alternative)
# endregion # endregion
# region Expressions # region Expressions
@@ -156,6 +203,21 @@ class Compiler:
case "^": case "^":
# TODO # TODO
pass pass
case "<":
value = self.builder.icmp_signed('<', left_value, right_value)
Type = ir.IntType(1)
case "<=":
value = self.builder.icmp_signed('<=', left_value, right_value)
Type = ir.IntType(1)
case ">":
value = self.builder.icmp_signed('>', left_value, right_value)
Type = ir.IntType(1)
case ">=":
value = self.builder.icmp_signed('>=', left_value, right_value)
Type = ir.IntType(1)
case "==":
value = self.builder.icmp_signed('==', left_value, right_value)
Type = ir.IntType(1)
elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType): elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType):
Type = self.type_map["Float"] Type = self.type_map["Float"]
match operator: match operator:
@@ -172,6 +234,21 @@ class Compiler:
case "^": case "^":
# TODO # TODO
pass pass
case "<":
value = self.builder.fcmp_ordered('<', left_value, right_value)
Type = ir.IntType(1)
case "<=":
value = self.builder.fcmp_ordered('<=', left_value, right_value)
Type = ir.IntType(1)
case ">":
value = self.builder.fcmp_ordered('>', left_value, right_value)
Type = ir.IntType(1)
case ">=":
value = self.builder.fcmp_ordered('>=', left_value, right_value)
Type = ir.IntType(1)
case "==":
value = self.builder.fcmp_ordered('==', left_value, right_value)
Type = ir.IntType(1)
return value, Type return value, Type
# endregion # endregion
@@ -193,6 +270,9 @@ class Compiler:
node: IdentifierLiteral = node node: IdentifierLiteral = node
ptr, Type = self.environment.lookup(node.value) ptr, Type = self.environment.lookup(node.value)
return self.builder.load(ptr), Type return self.builder.load(ptr), Type
case NodeType.BooleanLiteral:
node: BooleanLiteral = node
return ir.Constant(ir.IntType(1), 1 if node.value else 0), ir.IntType(1)
# expression value # expression value
case NodeType.InfixExpression: case NodeType.InfixExpression:

View File

@@ -21,26 +21,57 @@
}, },
"value": { "value": {
"type": "IntegerLiteral", "type": "IntegerLiteral",
"value": 0 "value": 123
}, },
"value_type": "Int" "value_type": "Int"
}, },
{ {
"type": "ReassignStatement", "type": "ExpressionStatement",
"ident": { "expr": {
"type": "IdentifierLiteral", "type": "IfStatement",
"value": "x" "condition": {
}, "type": "InfixExpression",
"right_value": { "left_node": {
"type": "InfixExpression", "type": "IdentifierLiteral",
"left_node": { "value": "x"
"type": "IdentifierLiteral", },
"value": "x" "operator": "==",
"right_node": {
"type": "IntegerLiteral",
"value": 123
}
}, },
"operator": "*", "consequence": {
"right_node": { "type": "BlockStatement",
"type": "IntegerLiteral", "statements": [
"value": 2 {
"type": "ReassignStatement",
"ident": {
"type": "IdentifierLiteral",
"value": "x"
},
"right_value": {
"type": "IntegerLiteral",
"value": 0
}
}
]
},
"alternative": {
"type": "BlockStatement",
"statements": [
{
"type": "ReassignStatement",
"ident": {
"type": "IdentifierLiteral",
"value": "x"
},
"right_value": {
"type": "IntegerLiteral",
"value": 1
}
}
]
} }
} }
}, },

View File

@@ -2,14 +2,23 @@
target triple = "x86_64-pc-windows-msvc" target triple = "x86_64-pc-windows-msvc"
target datalayout = "" target datalayout = ""
define i32 @"test"() @"true" = constant i1 1
{ @"false" = constant i1 0
test_entry:
ret i32 0
}
define i32 @"main"() define i32 @"main"()
{ {
main_entry: main_entry:
ret i32 123 %".2" = alloca i32
store i32 6, i32* %".2"
%".4" = load i32, i32* %".2"
%".5" = icmp eq i32 %".4", 5
br i1 %".5", label %"main_entry.if", label %"main_entry.else"
main_entry.if:
store i32 0, i32* %".2"
br label %"main_entry.endif"
main_entry.else:
store i32 1, i32* %".2"
br label %"main_entry.endif"
main_entry.endif:
%".11" = load i32, i32* %".2"
ret i32 %".11"
} }

View File

@@ -95,8 +95,39 @@ class Lexer:
tok = self.__new_token(TokenType.POW, self.current_char) tok = self.__new_token(TokenType.POW, self.current_char)
case "%": case "%":
tok = self.__new_token(TokenType.MODULUS, self.current_char) tok = self.__new_token(TokenType.MODULUS, self.current_char)
case "<":
# Handle <=
if self.__peek_char() == "=":
ch = self.current_char
self.__read_char()
tok = self.__new_token(TokenType.LT_EQ, ch + self.current_char)
else:
tok = self.__new_token(TokenType.LT, self.current_char)
case ">":
# Handle >=
if self.__peek_char() == "=":
ch = self.current_char
self.__read_char()
tok = self.__new_token(TokenType.GT_EQ, ch + self.current_char)
else:
tok = self.__new_token(TokenType.GT, self.current_char)
case "=": case "=":
tok = self.__new_token(TokenType.EQ, self.current_char) # Handle ==
if self.__peek_char() == "=":
ch = self.current_char
self.__read_char()
tok = self.__new_token(TokenType.EQ_EQ, ch + self.current_char)
else:
tok = self.__new_token(TokenType.EQ, self.current_char)
case "!":
# Handle !=
if self.__peek_char() == "=":
ch = self.current_char
self.__read_char()
tok = self.__new_token(TokenType.NOT_EQ, ch + self.current_char)
else:
# TODO: handle BANG
tok = self.__new_token(TokenType.ILLEGAL, self.current_char)
case "(": case "(":
tok = self.__new_token(TokenType.LPAREN, self.current_char) tok = self.__new_token(TokenType.LPAREN, self.current_char)
case ")": case ")":

View File

@@ -29,6 +29,7 @@ class TokenType(Enum):
EQ_EQ = "==" EQ_EQ = "=="
LT_EQ = "<=" LT_EQ = "<="
GT_EQ = ">=" GT_EQ = ">="
NOT_EQ = "!="
# Symbols # Symbols
LPAREN = "LPAREN" LPAREN = "LPAREN"

View File

@@ -4,9 +4,9 @@ from typing import Callable
from enum import Enum, auto from enum import Enum, auto
from AST import Statement, Expression, Program from AST import Statement, Expression, Program
from AST import ExpressionStatement, AssignmentStatement, FunctionStatement, ReturnStatement, BlockStatement, ReassignStatement from AST import ExpressionStatement, AssignmentStatement, FunctionStatement, ReturnStatement, BlockStatement, ReassignStatement, IfStatement
from AST import InfixExpression from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral, BooleanLiteral
class PrecedenceType(Enum): class PrecedenceType(Enum):
P_LOWEST = 0 P_LOWEST = 0
@@ -25,7 +25,13 @@ PRECEDENCES: dict[TokenType, PrecedenceType] = {
TokenType.ASTERISK: PrecedenceType.P_PRODUCT, TokenType.ASTERISK: PrecedenceType.P_PRODUCT,
TokenType.SLASH: PrecedenceType.P_PRODUCT, TokenType.SLASH: PrecedenceType.P_PRODUCT,
TokenType.MODULUS: PrecedenceType.P_PRODUCT, TokenType.MODULUS: PrecedenceType.P_PRODUCT,
TokenType.POW: PrecedenceType.P_EXPONENT TokenType.POW: PrecedenceType.P_EXPONENT,
TokenType.EQ_EQ: PrecedenceType.P_EQUALS,
TokenType.NOT_EQ: PrecedenceType.P_EQUALS,
TokenType.LT: PrecedenceType.P_LESSGREATER,
TokenType.GT: PrecedenceType.P_LESSGREATER,
TokenType.LT_EQ: PrecedenceType.P_LESSGREATER,
TokenType.GT_EQ: PrecedenceType.P_LESSGREATER
} }
class Parser: class Parser:
@@ -38,22 +44,27 @@ class Parser:
self.peek_token: Token = None self.peek_token: Token = None
self.prefix_parse_functions: dict[Token, Callable] = { # -1 self.prefix_parse_functions: dict[Token, Callable] = { # -1
TokenType.IDENT: self.__parse_identifier, TokenType.IDENT: self.__parse_identifier,
TokenType.INT: self.__parse_int_literal, TokenType.INT: self.__parse_int_literal,
TokenType.FLOAT: self.__parse_float_literal, TokenType.FLOAT: self.__parse_float_literal,
TokenType.LPAREN: self.__parse_grouped_expression TokenType.LPAREN: self.__parse_grouped_expression,
TokenType.IF: self.__parse_if_statement,
TokenType.TRUE: self.__parse_boolean,
TokenType.FALSE: self.__parse_boolean,
} }
self.infix_parse_functions: dict[Token, Callable] = { # 5 + 5 self.infix_parse_functions: dict[Token, Callable] = { # 5 + 5
TokenType.PLUS: self.__parse_infix_expression, TokenType.PLUS: self.__parse_infix_expression,
TokenType.MINUS: self.__parse_infix_expression, TokenType.MINUS: self.__parse_infix_expression,
TokenType.SLASH: self.__parse_infix_expression, TokenType.SLASH: self.__parse_infix_expression,
TokenType.ASTERISK: self.__parse_infix_expression, TokenType.ASTERISK: self.__parse_infix_expression,
TokenType.POW: self.__parse_infix_expression, TokenType.POW: self.__parse_infix_expression,
TokenType.MODULUS: self.__parse_infix_expression TokenType.MODULUS: self.__parse_infix_expression,
TokenType.EQ_EQ: self.__parse_infix_expression,
TokenType.NOT_EQ: self.__parse_infix_expression,
TokenType.LT: self.__parse_infix_expression,
TokenType.GT: self.__parse_infix_expression,
TokenType.LT_EQ: self.__parse_infix_expression,
TokenType.GT_EQ: self.__parse_infix_expression,
} }
self.__next_token() self.__next_token()
@@ -229,6 +240,30 @@ class Parser:
self.__next_token() self.__next_token()
return block_stmt return block_stmt
def __parse_if_statement(self) -> IfStatement:
condition: Expression = None
consequence: BlockStatement = None
alternative: BlockStatement = None
self.__next_token()
condition = self.__parse_expression(PrecedenceType.P_LOWEST)
if not self.__expect_peek(TokenType.LBRACE):
return None
consequence = self.__parse_block_statement()
if self.__peek_token_is(TokenType.UNLESS):
self.__next_token()
if not self.__expect_peek(TokenType.LBRACE):
return None
alternative = self.__parse_block_statement()
return IfStatement(condition, consequence, alternative)
# endregion # endregion
# region Expression Methods # region Expression Methods
@@ -297,4 +332,7 @@ class Parser:
return None return None
return float_lit return float_lit
def __parse_boolean(self) -> BooleanLiteral:
return BooleanLiteral(value=self.__current_token_is(TokenType.TRUE))
# endregion # endregion