sqrt function works... finally... ;-;

This commit is contained in:
SpookyDervish
2025-10-18 19:28:54 +11:00
parent 70bc672885
commit 4bd698e330
8 changed files with 83 additions and 71 deletions

View File

@@ -22,11 +22,15 @@ class Compiler:
"Short": ir.IntType(16),
"Int": ir.IntType(32),
"Long": ir.IntType(64),
"Float": ir.FloatType(),
"Float": ir.DoubleType(),
"Double": ir.DoubleType(),
"String": ir.PointerType(ir.IntType(8)),
"Nil": ir.VoidType()
}
self.py_type_map: dict[str, type] = {
"Int": int,
"Float": float
}
self.module: ir.Module = ir.Module("main")
self.builder: ir.IRBuilder = ir.IRBuilder()
@@ -50,6 +54,14 @@ class Compiler:
var_arg=True
)
return ir.Function(self.module, fnty, "printf")
def __init_sqrt() -> ir.Function:
fnty: ir.FunctionType = ir.FunctionType(
self.type_map["Float"],
[ir.DoubleType()],
var_arg=False
)
return ir.Function(self.module, fnty, "sqrt")
def __init_booleans() -> tuple[ir.GlobalVariable, ir.GlobalVariable]:
bool_type: ir.Type = self.type_map["Bool"]
@@ -65,6 +77,7 @@ class Compiler:
return true_var, false_var
self.environment.define("print", __init_print(), ir.IntType(32))
self.environment.define("sqrt", __init_sqrt(), ir.IntType(32))
true_var, false_var = __init_booleans()
self.environment.define("true", true_var, true_var.type)
@@ -221,6 +234,8 @@ class Compiler:
value = None
Type = None
match operator:
case "=":
value = right_value
@@ -246,8 +261,10 @@ class Compiler:
value = self.builder.fdiv(orig_value, right_value)
case _:
print("Unsupported assignment operator.")
return
ptr, _ = self.environment.lookup(name)
self.builder.store(value, ptr)
def __visit_if_statement(self, node: IfStatement) -> None:
@@ -281,9 +298,14 @@ class Compiler:
test, _ = self.__resolve_value(condition)
while_loop_entry = self.builder.append_basic_block(f"while_loop_entry_{self.__increment_counter()}")
while_loop_otherwise = self.builder.append_basic_block(f"while_loop_otherwise_{self.counter}")
self.breakpoints.append(while_loop_otherwise)
self.continues.append(while_loop_entry)
# Creating a condition branch
# condition
# / \
@@ -299,6 +321,9 @@ class Compiler:
self.builder.cbranch(test, while_loop_entry, while_loop_otherwise)
self.builder.position_at_start(while_loop_otherwise)
self.breakpoints.pop()
self.continues.pop()
def __visit_break_statement(self, node: BreakStatement) -> None:
self.builder.branch(self.breakpoints[-1])
@@ -458,7 +483,14 @@ class Compiler:
case "print":
ret = self.builtin_print(params=args, return_type=types[0])
ret_type = self.type_map["Int"]
case "sqrt":
ret = self.builtin_sqrt(params=args)
ret_type = self.type_map["Float"]
case _:
if not self.environment.lookup(name):
print(f"The function \"{name}\" is not defined.")
exit(1)
func, ret_type = self.environment.lookup(name)
ret = self.builder.call(func, args)
@@ -563,6 +595,7 @@ class Compiler:
return global_fmt, global_fmt.type
# region Builtins
def builtin_print(self, params: list[ir.Instruction], return_type: ir.Type) -> None:
func, _ = self.environment.lookup("print")
@@ -582,4 +615,19 @@ class Compiler:
# printing from a normal string
fmt_arg = self.builder.bitcast(self.module.get_global(f"__str_{self.counter}"), ir.IntType(8).as_pointer())
return self.builder.call(func, [fmt_arg, *rest_params])
def builtin_sqrt(self, params: list[ir.Instruction]) -> None:
func, _ = self.environment.lookup("sqrt")
c_float = self.builder.alloca(self.type_map["Float"])
self.builder.store(params[0], c_float)
if isinstance(params[0], ir.LoadInstr):
c_fmt: ir.LoadInstr = params[0]
g_var_ptr = c_fmt.operands[0]
float_val = self.builder.load(g_var_ptr)
return self.builder.call(func, [float_val])
else:
return self.builder.call(func, [params[0]])
# endregion
# endregion

View File

@@ -1,47 +0,0 @@
{
"type": "Program",
"statements": [
{
"DependStatement": {
"type": "DependStatement",
"file_path": "tests/math.pla"
}
},
{
"FunctionStatement": {
"type": "FunctionStatement",
"name": {
"type": "IdentifierLiteral",
"value": "main"
},
"return_type": "Int",
"parameters": [],
"body": {
"type": "BlockStatement",
"statements": [
{
"type": "ReturnStatement",
"return_value": {
"type": "CallExpression",
"function": {
"type": "IdentifierLiteral",
"value": "add"
},
"arguments": [
{
"type": "IntegerLiteral",
"value": 1
},
{
"type": "IntegerLiteral",
"value": 2
}
]
}
}
]
}
}
}
]
}

View File

@@ -1,13 +0,0 @@
; ModuleID = "main"
target triple = "x86_64-unknown-linux-gnu"
target datalayout = ""
declare i32 @"printf"(i8* %".1", ...)
@"true" = constant i1 1
@"false" = constant i1 0
define i32 @"main"()
{
main_entry:
ret i32 0
}

27
ir.ll Normal file
View File

@@ -0,0 +1,27 @@
; ModuleID = "main"
target triple = "x86_64-pc-windows-msvc"
target datalayout = ""
declare i32 @"printf"(i8* %".1", ...)
declare double @"sqrt"(double %".1")
@"true" = constant i1 1
@"false" = constant i1 0
define i32 @"main"()
{
main_entry:
%".2" = alloca double
store double 0x4010000000000000, double* %".2"
%".4" = call double @"sqrt"(double 0x4010000000000000)
%".5" = alloca double
store double %".4", double* %".5"
%".7" = load double, double* %".5"
%".8" = alloca [5 x i8]*
store [5 x i8]* @"__str_1", [5 x i8]** %".8"
%".10" = bitcast [5 x i8]* @"__str_1" to i8*
%".11" = call i32 (i8*, ...) @"printf"(i8* %".10", double %".7")
ret i32 0
}
@"__str_1" = internal constant [5 x i8] c"%f\0a\00\00"

View File

@@ -46,7 +46,7 @@ if __name__ == "__main__":
parse_st: float = time.time()
program: Program = p.parse_program()
parse_et: float = time.time()
print(f"Parsed in {round((parse_et - parse_st) * 1000, 6)} ms.")
#print(f"Parsed in {round((parse_et - parse_st) * 1000, 6)} ms.")
if len(p.errors) > 0:
for err in p.errors:
@@ -67,7 +67,7 @@ if __name__ == "__main__":
compiler_st: float = time.time()
c.compile(program)
compiler_et: float = time.time()
print(f"Compiled in {round((compiler_et - compiler_st) * 1000, 6)} ms.")
#print(f"Compiled in {round((compiler_et - compiler_st) * 1000, 6)} ms.")
module: ir.Module = c.module
module.triple = llvm.get_default_triple()
@@ -101,4 +101,5 @@ if __name__ == "__main__":
et = time.time()
print(f"\n\nProgram returned: {result}\n=== Executed in {round((et - st) * 1000, 6)} ms. ===")
#print(f"\n\nProgram returned: {result}\n=== Executed in {round((et - st) * 1000, 6)} ms. ===")
exit(result)

View File

@@ -381,7 +381,7 @@ class Parser:
self.__next_token() # skip ;
stmt.action = self.__parse_expression(PrecedenceType.P_LOWEST)
stmt.action = self.__parse_assignment_statement()
self.__next_token()

View File

@@ -1,3 +0,0 @@
add = Func(a: Int, b: Int): Int {
return a + b;
}

View File

@@ -1,5 +1,4 @@
depend "tests/math.pla";
main = Func(): Int {
return $add(1 ,2);
$print("%f\n", $sqrt(9.0));
return 0;
}