AST is accepting functions!!!

This commit is contained in:
SpookyDervish
2025-10-14 07:14:53 +11:00
parent f9cd1dba29
commit 518a19d3bf
9 changed files with 243 additions and 101 deletions

150
AST.py
View File

@@ -8,6 +8,9 @@ class NodeType(Enum):
# Statements
ExpressionStatement = "ExpressionStatement"
AssignmentStatement = "AssignmentStatement"
FunctionStatement = "FunctionStatement"
BlockStatement = "BlockStatement"
ReturnStatement = "ReturnStatement"
# Expressions
InfixExpression = "InfixExpression"
@@ -45,57 +48,6 @@ class Program(Node):
"statements": [{stmt.type().value: stmt.json()} for stmt in self.statements]
}
# region Statements
class ExpressionStatement(Statement):
    """Statement node that wraps a single expression."""

    def __init__(self, expr: Expression = None) -> None:
        # The wrapped expression node; stays None when nothing is supplied.
        self.expr: Expression = expr

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this statement."""
        return NodeType.ExpressionStatement

    def json(self) -> dict:
        """Return a dict holding this node's type tag and its expression's json()."""
        type_tag = self.type().value
        expr_json = self.expr.json()
        return {"type": type_tag, "expr": expr_json}
class AssignmentStatement(Statement):
    """Statement node holding a target name, a value and a declared type name."""

    def __init__(self, name: Expression = None, value: Expression = None, value_type: str = None) -> None:
        self.name = name              # node naming the assignment target
        self.value = value            # node producing the assigned value
        self.value_type = value_type  # declared type, kept as a plain string

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this statement."""
        return NodeType.AssignmentStatement

    def json(self) -> dict:
        """Return a dict with the type tag, the serialised name/value and the type name."""
        payload = {"type": self.type().value}
        payload["name"] = self.name.json()
        payload["value"] = self.value.json()
        payload["value_type"] = self.value_type
        return payload
# endregion
# region Expressions
class InfixExpression(Expression):
    """Expression node made of a left operand, an operator string and a right operand."""

    def __init__(self, left_node: Expression, operator: str, right_node: Expression = None) -> None:
        self.left_node: Expression = left_node
        self.operator: str = operator
        # The right operand may be attached after construction (defaults to None).
        self.right_node: Expression = right_node

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this expression."""
        return NodeType.InfixExpression

    def json(self) -> dict:
        """Return a dict with the type tag, both serialised operands and the operator."""
        payload = {"type": self.type().value}
        payload["left_node"] = self.left_node.json()
        payload["operator"] = self.operator
        payload["right_node"] = self.right_node.json()
        return payload
# endregion
# region Literals
class IntegerLiteral(Expression):
def __init__(self, value: int = None) -> None:
@@ -136,3 +88,99 @@ class IdentifierLiteral(Expression):
"value": self.value
}
# endregion
# region Statements
class ExpressionStatement(Statement):
    """Statement node that wraps a single expression."""

    def __init__(self, expr: Expression = None) -> None:
        # The wrapped expression node; stays None when nothing is supplied.
        self.expr: Expression = expr

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this statement."""
        return NodeType.ExpressionStatement

    def json(self) -> dict:
        """Return a dict holding this node's type tag and its expression's json()."""
        type_tag = self.type().value
        expr_json = self.expr.json()
        return {"type": type_tag, "expr": expr_json}
class AssignmentStatement(Statement):
    """Statement node holding a target name, a value and a declared type name."""

    def __init__(self, name: Expression = None, value: Expression = None, value_type: str = None) -> None:
        self.name = name              # node naming the assignment target
        self.value = value            # node producing the assigned value
        self.value_type = value_type  # declared type, kept as a plain string

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this statement."""
        return NodeType.AssignmentStatement

    def json(self) -> dict:
        """Return a dict with the type tag, the serialised name/value and the type name."""
        payload = {"type": self.type().value}
        payload["name"] = self.name.json()
        payload["value"] = self.value.json()
        payload["value_type"] = self.value_type
        return payload
class BlockStatement(Statement):
    """Statement node that groups an ordered list of child statements."""

    def __init__(self, statements: list[Statement] = None) -> None:
        # Fall back to a fresh list so instances never share one default list.
        if statements is None:
            statements = []
        self.statements: list[Statement] = statements

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this statement."""
        return NodeType.BlockStatement

    def json(self) -> dict:
        """Return a dict with the type tag and each child's json() in order."""
        payload = {"type": self.type().value}
        payload["statements"] = [child.json() for child in self.statements]
        return payload
class ReturnStatement(Statement):
    """Statement node carrying the expression whose value is returned."""

    def __init__(self, return_value: Expression = None) -> None:
        # Expression node for the returned value; may be assigned after construction.
        self.return_value = return_value

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this statement."""
        return NodeType.ReturnStatement

    def json(self) -> dict:
        """Return a dict with the type tag and the serialised return value.

        Annotated `-> dict` to match the json() signature of the other nodes.
        """
        return {
            "type": self.type().value,
            "return_value": self.return_value.json()
        }
class FunctionStatement(Statement):
    """Statement node for a named function definition.

    Stores the function's name node, its parameter nodes, the declared return
    type name and the block statement forming its body.
    """

    def __init__(self, parameters: list = None, body: BlockStatement = None, name: IdentifierLiteral = None, return_type: str = None) -> None:
        # A literal `[]` default is evaluated once and would be shared by every
        # FunctionStatement created without parameters, so appending to one
        # function's parameters would leak into all others. Build a fresh list.
        self.parameters = parameters if parameters is not None else []
        self.body = body
        self.name = name
        self.return_type = return_type

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this statement."""
        return NodeType.FunctionStatement

    def json(self) -> dict:
        """Return a dict with the type tag, name, return type, parameters and body.

        `name` and `body` must be set before calling this: their json()
        methods are invoked unconditionally.
        """
        return {
            "type": self.type().value,
            "name": self.name.json(),
            "return_type": self.return_type,
            "parameters": [p.json() for p in self.parameters],
            "body": self.body.json()
        }
# endregion
# region Expressions
class InfixExpression(Expression):
    """Expression node made of a left operand, an operator string and a right operand."""

    def __init__(self, left_node: Expression, operator: str, right_node: Expression = None) -> None:
        self.left_node: Expression = left_node
        self.operator: str = operator
        # The right operand may be attached after construction (defaults to None).
        self.right_node: Expression = right_node

    def type(self) -> NodeType:
        """Return the NodeType tag identifying this expression."""
        return NodeType.InfixExpression

    def json(self) -> dict:
        """Return a dict with the type tag, both serialised operands and the operator."""
        payload = {"type": self.type().value}
        payload["left_node"] = self.left_node.json()
        payload["operator"] = self.operator
        payload["right_node"] = self.right_node.json()
        return payload
# endregion

View File

@@ -1,7 +1,7 @@
from llvmlite import ir
from AST import Node, NodeType, Program, Expression
from AST import ExpressionStatement, AssignmentStatement
from AST import ExpressionStatement, AssignmentStatement, BlockStatement
from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral

View File

@@ -2,17 +2,46 @@
"type": "Program",
"statements": [
{
"AssignmentStatement": {
"FunctionStatement": {
"type": "FunctionStatement",
"name": {
"type": "IdentifierLiteral",
"value": "main"
},
"return_type": "Int",
"parameters": [],
"body": {
"type": "BlockStatement",
"statements": [
{
"type": "AssignmentStatement",
"name": {
"type": "IdentifierLiteral",
"value": "myVar"
"value": "x"
},
"value": {
"type": "IntegerLiteral",
"value": 1
"value": 123
},
"value_type": "Bool"
"value_type": "Int"
},
{
"type": "ReturnStatement",
"return_value": {
"type": "InfixExpression",
"left_node": {
"type": "IdentifierLiteral",
"value": "x"
},
"operator": "+",
"right_node": {
"type": "IntegerLiteral",
"value": 5
}
}
}
]
}
}
}
]

View File

@@ -1,13 +0,0 @@
; ModuleID = "main"
target triple = "x86_64-pc-windows-msvc"
target datalayout = ""
define i32 @"main"()
{
main_entry:
%".2" = alloca float
store float 0x3ff3ae1480000000, float* %".2"
%".4" = alloca i32
store i32 456, i32* %".4"
ret i32 123
}

View File

@@ -23,6 +23,12 @@ class Lexer:
self.position = self.read_position
self.read_position += 1
def __peek_char(self) -> str | None:
    """Return the character at `read_position` without changing lexer state.

    Returns None when `read_position` is at or beyond the end of `source`.
    """
    in_bounds = self.read_position < len(self.source)
    return self.source[self.read_position] if in_bounds else None
def __skip_whitespace(self) -> None:
while self.current_char in [' ', '\t', '\n', '\r']:
if self.current_char == "\n":
@@ -100,9 +106,9 @@ class Lexer:
case "]":
tok = self.__new_token(TokenType.RBRACKET, self.current_char)
case "{":
tok = self.__new_token(TokenType.LCURLY, self.current_char)
tok = self.__new_token(TokenType.LBRACE, self.current_char)
case "}":
tok = self.__new_token(TokenType.RCURLY, self.current_char)
tok = self.__new_token(TokenType.RBRACE, self.current_char)
case ";":
tok = self.__new_token(TokenType.SEMICOLON, self.current_char)
case ":":

View File

@@ -28,12 +28,13 @@ class TokenType(Enum):
RPAREN = "RPAREN"
LBRACKET = "LBRACKET"
RBRACKET = "RBRACKET"
LCURLY = "LCURLY"
RCURLY = "RCURLY"
LBRACE = "LBRACE"
RBRACE = "RBRACE"
COLON = "COLON"
SEMICOLON = "SEMICOLON"
# Keywords
RETURN = "RETURN"
# Typing
TYPE = "TYPE"
@@ -53,14 +54,14 @@ class Token:
KEYWORDS: dict[str, TokenType] = {
"return": TokenType.RETURN
}
ALT_KEYWORDS: dict[str, TokenType] = {
}
TYPE_KEYWORDS: list[str] = ["Int", "Float", "String", "Bool", "List", "Nil"]
TYPE_KEYWORDS: list[str] = ["Int", "Float", "String", "Bool", "List", "Nil", "Func"]
def lookup_ident(ident: str) -> TokenType:
tt: TokenType | None = KEYWORDS.get(ident)

View File

@@ -9,8 +9,8 @@ from llvmlite.binding import targets
from ctypes import CFUNCTYPE, c_int, c_float
LEXER_DEBUG: bool = False
PARSER_DEBUG: bool = False
COMPILER_DEBUG: bool = True
PARSER_DEBUG: bool = True
COMPILER_DEBUG: bool = False
if __name__ == "__main__":

View File

@@ -4,7 +4,7 @@ from typing import Callable
from enum import Enum, auto
from AST import Statement, Expression, Program
from AST import ExpressionStatement, AssignmentStatement
from AST import ExpressionStatement, AssignmentStatement, FunctionStatement, ReturnStatement, BlockStatement
from AST import InfixExpression
from AST import IntegerLiteral, FloatLiteral, IdentifierLiteral
@@ -39,6 +39,7 @@ class Parser:
self.prefix_parse_functions: dict[Token, Callable] = { # -1
TokenType.IDENT: self.__parse_identifier,
TokenType.INT: self.__parse_int_literal,
TokenType.FLOAT: self.__parse_float_literal,
TokenType.LPAREN: self.__parse_grouped_expression
@@ -114,6 +115,8 @@ class Parser:
match self.current_token.type:
case TokenType.IDENT:
return self.__parse_assignment_statement()
case TokenType.RETURN:
return self.__parse_return_statement()
case _:
return self.__parse_expression_statement()
@@ -131,6 +134,43 @@ class Parser:
# x: Int = 10;
stmt: AssignmentStatement = AssignmentStatement(name=IdentifierLiteral(self.current_token.literal))
if self.__peek_token_is(TokenType.EQ): # function definition
# x = Func(): Int { return 10; }
self.__next_token()
func_stmt: FunctionStatement = FunctionStatement(name=stmt.name)
if not self.__expect_peek(TokenType.TYPE): # Func word
return None
if self.current_token.literal != "Func":
self.errors.append(f"Expected next token to be \"Func\", got {self.current_token.literal} instead.")
return None
if not self.__expect_peek(TokenType.LPAREN):
return None
func_stmt.parameters = []
if not self.__expect_peek(TokenType.RPAREN):
return None
if not self.__expect_peek(TokenType.COLON):
return None
if not self.__expect_peek(TokenType.TYPE):
return None
func_stmt.return_type = self.current_token.literal
if not self.__expect_peek(TokenType.LBRACE):
return None
func_stmt.body = self.__parse_block_statement()
return func_stmt
else:
if not self.__expect_peek(TokenType.COLON):
return None
@@ -150,6 +190,32 @@ class Parser:
self.__next_token()
return stmt
def __parse_return_statement(self) -> ReturnStatement:
    """Parse a return statement and build its ReturnStatement node.

    Returns None when the expression is not followed by a SEMICOLON token.
    """
    # Presumably steps off the `return` keyword onto the expression's first
    # token -- confirm against __next_token.
    self.__next_token()

    value = self.__parse_expression(PrecedenceType.P_LOWEST)

    if self.__expect_peek(TokenType.SEMICOLON):
        return ReturnStatement(return_value=value)
    return None
def __parse_block_statement(self) -> BlockStatement:
    """Collect statements into a BlockStatement until RBRACE or EOF is current.

    Statements for which __parse_statement returns None are left out of the block.
    """
    block = BlockStatement()

    # Presumably moves past the opening brace -- confirm against the caller.
    self.__next_token()

    while True:
        if self.__current_token_is(TokenType.RBRACE) or self.__current_token_is(TokenType.EOF):
            break

        parsed: Statement = self.__parse_statement()
        if parsed is not None:
            block.statements.append(parsed)

        self.__next_token()

    return block
# endregion
# region Expression Methods
@@ -194,6 +260,9 @@ class Parser:
# endregion
# region Prefix Methods
def __parse_identifier(self) -> IdentifierLiteral:
    """Build an IdentifierLiteral whose value is the current token's literal."""
    literal = self.current_token.literal
    return IdentifierLiteral(value=literal)
def __parse_int_literal(self) -> Expression:
int_lit: IntegerLiteral = IntegerLiteral()

View File

@@ -1,2 +1,4 @@
a: Float = 1.23;
b: Int = 456;
main = Func(): Int {
x: Int = 123;
return x + 5;
}