if statements work

This commit is contained in:
SpookyDervish
2025-10-14 21:23:11 +11:00
parent 48e7488a63
commit 3d9208f0f8
7 changed files with 256 additions and 34 deletions

32
AST.py
View File

@@ -12,6 +12,7 @@ class NodeType(Enum):
FunctionStatement = "FunctionStatement"
BlockStatement = "BlockStatement"
ReturnStatement = "ReturnStatement"
IfStatement = "IfStatement"
# Expressions
InfixExpression = "InfixExpression"
@@ -20,6 +21,7 @@ class NodeType(Enum):
IntegerLiteral = "IntegerLiteral"
FloatLiteral = "FloatLiteral"
IdentifierLiteral = "IdentifierLiteral"
BooleanLiteral = "BooleanLiteral"
class Node:
@abstractmethod
@@ -88,6 +90,19 @@ class IdentifierLiteral(Expression):
"type": self.type().value,
"value": self.value
}
class BooleanLiteral(Expression):
    """AST expression node that holds a boolean literal value."""

    def __init__(self, value: bool = None) -> None:
        # The literal's Python-side value; None when constructed empty.
        self.value: bool = value

    def type(self) -> NodeType:
        """Return the node-type tag for this node."""
        return NodeType.BooleanLiteral

    def json(self) -> dict:
        """Return a dict with this node's type name and its value."""
        payload = {"type": self.type().value}
        payload["value"] = self.value
        return payload
# endregion
# region Statements
@@ -180,6 +195,23 @@ class ReassignStatement(Statement):
"ident": self.ident.json(),
"right_value": self.right_value.json()
}
class IfStatement(Statement):
    """AST statement node for a conditional.

    Holds the condition expression, the block run when the condition
    holds (consequence) and an optional block run otherwise (alternative).
    """

    def __init__(self, condition: Expression = None, consequence: BlockStatement = None, alternative: BlockStatement = None) -> None:
        self.condition = condition
        self.consequence = consequence
        # Stays None when the statement has no alternative branch.
        self.alternative = alternative

    def type(self) -> NodeType:
        """Return the node-type tag for this node."""
        return NodeType.IfStatement

    def json(self) -> dict:
        """Return a dict describing this node and its children.

        The "alternative" entry is None when no alternative block is set.
        """
        payload = {
            "type": self.type().value,
            "condition": self.condition.json(),
            "consequence": self.consequence.json(),
        }
        if self.alternative is None:
            payload["alternative"] = None
        else:
            payload["alternative"] = self.alternative.json()
        return payload
# endregion
# region Expressions

View File

@@ -1,9 +1,9 @@
from llvmlite import ir
from AST import Node, NodeType, Program, Expression
from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement, ReassignStatement
from AST import ExpressionStatement, AssignmentStatement, BlockStatement, ReturnStatement, FunctionStatement, ReassignStatement, IfStatement
from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral, BooleanLiteral
from environment import Environment
@@ -25,6 +25,26 @@ class Compiler:
self.environment: Environment = Environment()
self.errors: list[str] = []
self.__initialize_builtins()
def __initialize_builtins(self) -> None:
    """Register the built-in `true` and `false` boolean constants.

    Creates one constant global variable per name in the module, typed
    with the `Bool` entry of the type map and initialised to 1 / 0, then
    defines both names in the environment.
    """
    bool_type: ir.Type = self.type_map["Bool"]

    # Build both globals first, then register them, in true -> false order.
    created = []
    for name, bit in (("true", 1), ("false", 0)):
        global_var = ir.GlobalVariable(self.module, bool_type, name)
        global_var.initializer = ir.Constant(bool_type, bit)
        global_var.global_constant = True
        created.append((name, global_var))

    for name, global_var in created:
        # NOTE(review): `.type` on a global variable looks like the pointer
        # type rather than the Bool value type -- confirm this is what
        # Environment.define expects.
        self.environment.define(name, global_var, global_var.type)
def compile(self, node: Node) -> None:
match node.type():
case NodeType.Program:
@@ -43,6 +63,8 @@ class Compiler:
self.__visit_return_statement(node)
case NodeType.ReassignStatement:
self.__visit_reassign_statement(node)
case NodeType.IfStatement:
self.__visit_if_statement(node)
# Expressions
case NodeType.InfixExpression:
@@ -129,6 +151,31 @@ class Compiler:
else:
ptr, _ = self.environment.lookup(name)
self.builder.store(value, ptr)
def __visit_if_statement(self, node: IfStatement) -> None:
    """Emit the branching code for an if statement.

    Resolves the condition to a value, then compiles the consequence
    inside the builder's `if_then` context, or -- when an alternative
    block exists -- compiles both blocks inside the two contexts
    yielded by `if_else`.
    """
    test, _ = self.__resolve_value(node.condition)

    # No alternative: a single conditional block is enough.
    if node.alternative is None:
        with self.builder.if_then(test):
            self.compile(node.consequence)
        return

    # Two-way branch:
    #
    #          condition
    #          /       \
    #       true       false
    #        |           |
    #    if block    else block
    with self.builder.if_else(test) as (then_branch, else_branch):
        with then_branch:
            self.compile(node.consequence)
        with else_branch:
            self.compile(node.alternative)
# endregion
# region Expressions
@@ -156,6 +203,21 @@ class Compiler:
case "^":
# TODO
pass
case "<":
value = self.builder.icmp_signed('<', left_value, right_value)
Type = ir.IntType(1)
case "<=":
value = self.builder.icmp_signed('<=', left_value, right_value)
Type = ir.IntType(1)
case ">":
value = self.builder.icmp_signed('>', left_value, right_value)
Type = ir.IntType(1)
case ">=":
value = self.builder.icmp_signed('>=', left_value, right_value)
Type = ir.IntType(1)
case "==":
value = self.builder.icmp_signed('==', left_value, right_value)
Type = ir.IntType(1)
elif isinstance(right_type, ir.FloatType) and isinstance(left_type, ir.FloatType):
Type = self.type_map["Float"]
match operator:
@@ -172,6 +234,21 @@ class Compiler:
case "^":
# TODO
pass
case "<":
value = self.builder.fcmp_ordered('<', left_value, right_value)
Type = ir.IntType(1)
case "<=":
value = self.builder.fcmp_ordered('<=', left_value, right_value)
Type = ir.IntType(1)
case ">":
value = self.builder.fcmp_ordered('>', left_value, right_value)
Type = ir.IntType(1)
case ">=":
value = self.builder.fcmp_ordered('>=', left_value, right_value)
Type = ir.IntType(1)
case "==":
value = self.builder.fcmp_ordered('==', left_value, right_value)
Type = ir.IntType(1)
return value, Type
# endregion
@@ -193,6 +270,9 @@ class Compiler:
node: IdentifierLiteral = node
ptr, Type = self.environment.lookup(node.value)
return self.builder.load(ptr), Type
case NodeType.BooleanLiteral:
node: BooleanLiteral = node
return ir.Constant(ir.IntType(1), 1 if node.value else 0), ir.IntType(1)
# expression value
case NodeType.InfixExpression:

View File

@@ -21,26 +21,57 @@
},
"value": {
"type": "IntegerLiteral",
"value": 0
"value": 123
},
"value_type": "Int"
},
{
"type": "ReassignStatement",
"ident": {
"type": "IdentifierLiteral",
"value": "x"
},
"right_value": {
"type": "InfixExpression",
"left_node": {
"type": "IdentifierLiteral",
"value": "x"
"type": "ExpressionStatement",
"expr": {
"type": "IfStatement",
"condition": {
"type": "InfixExpression",
"left_node": {
"type": "IdentifierLiteral",
"value": "x"
},
"operator": "==",
"right_node": {
"type": "IntegerLiteral",
"value": 123
}
},
"operator": "*",
"right_node": {
"type": "IntegerLiteral",
"value": 2
"consequence": {
"type": "BlockStatement",
"statements": [
{
"type": "ReassignStatement",
"ident": {
"type": "IdentifierLiteral",
"value": "x"
},
"right_value": {
"type": "IntegerLiteral",
"value": 0
}
}
]
},
"alternative": {
"type": "BlockStatement",
"statements": [
{
"type": "ReassignStatement",
"ident": {
"type": "IdentifierLiteral",
"value": "x"
},
"right_value": {
"type": "IntegerLiteral",
"value": 1
}
}
]
}
}
},

View File

@@ -2,14 +2,23 @@
target triple = "x86_64-pc-windows-msvc"
target datalayout = ""
define i32 @"test"()
{
test_entry:
ret i32 0
}
@"true" = constant i1 1
@"false" = constant i1 0
define i32 @"main"()
{
main_entry:
ret i32 123
%".2" = alloca i32
store i32 6, i32* %".2"
%".4" = load i32, i32* %".2"
%".5" = icmp eq i32 %".4", 5
br i1 %".5", label %"main_entry.if", label %"main_entry.else"
main_entry.if:
store i32 0, i32* %".2"
br label %"main_entry.endif"
main_entry.else:
store i32 1, i32* %".2"
br label %"main_entry.endif"
main_entry.endif:
%".11" = load i32, i32* %".2"
ret i32 %".11"
}

View File

@@ -95,8 +95,39 @@ class Lexer:
tok = self.__new_token(TokenType.POW, self.current_char)
case "%":
tok = self.__new_token(TokenType.MODULUS, self.current_char)
case "<":
# Handle <=
if self.__peek_char() == "=":
ch = self.current_char
self.__read_char()
tok = self.__new_token(TokenType.LT_EQ, ch + self.current_char)
else:
tok = self.__new_token(TokenType.LT, self.current_char)
case ">":
# Handle >=
if self.__peek_char() == "=":
ch = self.current_char
self.__read_char()
tok = self.__new_token(TokenType.GT_EQ, ch + self.current_char)
else:
tok = self.__new_token(TokenType.GT, self.current_char)
case "=":
tok = self.__new_token(TokenType.EQ, self.current_char)
# Handle ==
if self.__peek_char() == "=":
ch = self.current_char
self.__read_char()
tok = self.__new_token(TokenType.EQ_EQ, ch + self.current_char)
else:
tok = self.__new_token(TokenType.EQ, self.current_char)
case "!":
# Handle !=
if self.__peek_char() == "=":
ch = self.current_char
self.__read_char()
tok = self.__new_token(TokenType.NOT_EQ, ch + self.current_char)
else:
# TODO: handle BANG
tok = self.__new_token(TokenType.ILLEGAL, self.current_char)
case "(":
tok = self.__new_token(TokenType.LPAREN, self.current_char)
case ")":

View File

@@ -29,6 +29,7 @@ class TokenType(Enum):
EQ_EQ = "=="
LT_EQ = "<="
GT_EQ = ">="
NOT_EQ = "!="
# Symbols
LPAREN = "LPAREN"

View File

@@ -4,9 +4,9 @@ from typing import Callable
from enum import Enum, auto
from AST import Statement, Expression, Program
from AST import ExpressionStatement, AssignmentStatement, FunctionStatement, ReturnStatement, BlockStatement, ReassignStatement
from AST import ExpressionStatement, AssignmentStatement, FunctionStatement, ReturnStatement, BlockStatement, ReassignStatement, IfStatement
from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral, BooleanLiteral
class PrecedenceType(Enum):
P_LOWEST = 0
@@ -25,7 +25,13 @@ PRECEDENCES: dict[TokenType, PrecedenceType] = {
TokenType.ASTERISK: PrecedenceType.P_PRODUCT,
TokenType.SLASH: PrecedenceType.P_PRODUCT,
TokenType.MODULUS: PrecedenceType.P_PRODUCT,
TokenType.POW: PrecedenceType.P_EXPONENT
TokenType.POW: PrecedenceType.P_EXPONENT,
TokenType.EQ_EQ: PrecedenceType.P_EQUALS,
TokenType.NOT_EQ: PrecedenceType.P_EQUALS,
TokenType.LT: PrecedenceType.P_LESSGREATER,
TokenType.GT: PrecedenceType.P_LESSGREATER,
TokenType.LT_EQ: PrecedenceType.P_LESSGREATER,
TokenType.GT_EQ: PrecedenceType.P_LESSGREATER
}
class Parser:
@@ -38,22 +44,27 @@ class Parser:
self.peek_token: Token = None
self.prefix_parse_functions: dict[Token, Callable] = { # -1
TokenType.IDENT: self.__parse_identifier,
TokenType.INT: self.__parse_int_literal,
TokenType.FLOAT: self.__parse_float_literal,
TokenType.LPAREN: self.__parse_grouped_expression
TokenType.LPAREN: self.__parse_grouped_expression,
TokenType.IF: self.__parse_if_statement,
TokenType.TRUE: self.__parse_boolean,
TokenType.FALSE: self.__parse_boolean,
}
self.infix_parse_functions: dict[Token, Callable] = { # 5 + 5
TokenType.PLUS: self.__parse_infix_expression,
TokenType.MINUS: self.__parse_infix_expression,
TokenType.SLASH: self.__parse_infix_expression,
TokenType.ASTERISK: self.__parse_infix_expression,
TokenType.POW: self.__parse_infix_expression,
TokenType.MODULUS: self.__parse_infix_expression
TokenType.MODULUS: self.__parse_infix_expression,
TokenType.EQ_EQ: self.__parse_infix_expression,
TokenType.NOT_EQ: self.__parse_infix_expression,
TokenType.LT: self.__parse_infix_expression,
TokenType.GT: self.__parse_infix_expression,
TokenType.LT_EQ: self.__parse_infix_expression,
TokenType.GT_EQ: self.__parse_infix_expression,
}
self.__next_token()
@@ -229,6 +240,30 @@ class Parser:
self.__next_token()
return block_stmt
def __parse_if_statement(self) -> IfStatement:
    """Parse an if statement into an IfStatement node.

    Advances past the current token, parses the condition expression,
    then requires an LBRACE token and parses the consequence block. If
    the following token is UNLESS, a second LBRACE-introduced block is
    parsed as the alternative.

    Returns None when an expected LBRACE token is not found.
    """
    self.__next_token()
    test_expr: Expression = self.__parse_expression(PrecedenceType.P_LOWEST)

    if not self.__expect_peek(TokenType.LBRACE):
        return None
    then_block: BlockStatement = self.__parse_block_statement()

    # The alternative branch is optional and stays None when absent.
    else_block: BlockStatement = None
    if self.__peek_token_is(TokenType.UNLESS):
        self.__next_token()
        if not self.__expect_peek(TokenType.LBRACE):
            return None
        else_block = self.__parse_block_statement()

    return IfStatement(test_expr, then_block, else_block)
# endregion
# region Expression Methods
@@ -297,4 +332,7 @@ class Parser:
return None
return float_lit
def __parse_boolean(self) -> BooleanLiteral:
    """Build a BooleanLiteral whose value is True only when the current token is TRUE."""
    is_true = self.__current_token_is(TokenType.TRUE)
    return BooleanLiteral(value=is_true)
# endregion