from typing import List, Dict, NamedTuple, Union
import re

from console import console


class Instruction(NamedTuple):
    """A parsed instruction: an opcode followed by its operand strings."""

    opcode: str
    operands: List[str]


class PeepholeRule(NamedTuple):
    """A rewrite rule mapping a window of instruction patterns to a replacement.

    While matching ``match``, any pattern operand for which ``str.isidentifier()``
    is true acts as a wildcard; every other operand must match literally.
    Wildcards bound during matching are substituted into ``replace``.
    """

    match: List[Instruction]
    replace: List[Instruction]


class Optimizer:
    """Peephole optimizer that rewrites instruction text lines using ``self.rules``.

    Lines are expected in the form ``OPCODE operand1, operand2`` (see
    ``parse_instruction``). Rules are tried in list order at each position;
    the first rule whose pattern matches is applied.
    """

    def __init__(self, lines: list[str], window_size: int = 6):
        """Store the input lines and start with an empty rule list.

        Args:
            lines: Instruction lines to optimize.
            window_size: Stored as ``self.window_size``. NOTE(review): no
                method in this module reads it; each rule's own ``match``
                length determines the window actually examined.
        """
        self.lines = lines
        # NOTE(review): not read by any method in this module.
        self.final_str = ""
        # sliding window
        self.window_size = window_size
        self.rules: list[PeepholeRule] = []

    def match_instruction(self, pattern: Instruction, instr: Instruction,
                          bindings: Dict[str, str]) -> bool:
        """Return True if ``instr`` matches ``pattern``, extending ``bindings``.

        The opcode and operand count must be equal. A pattern operand that is a
        valid identifier is a wildcard: on first use it is bound to the
        instruction's operand, and later uses must see the same operand. Other
        pattern operands (e.g. numeric literals) must match exactly.

        NOTE(review): because the wildcard test is ``str.isidentifier()``,
        register-like names (``eax``, ``r1``) in a pattern are wildcards, not
        literals. ``bindings`` is mutated even when the match ultimately fails.
        """
        if pattern.opcode != instr.opcode:
            return False
        if len(pattern.operands) != len(instr.operands):
            return False
        for p_op, i_op in zip(pattern.operands, instr.operands):
            if p_op.isidentifier():  # wildcard like 'x'
                bound = bindings.get(p_op)
                if bound is None:
                    bindings[p_op] = i_op
                elif bound != i_op:
                    return False
            elif p_op != i_op:
                return False
        return True

    def match_window(self, patterns: List[Instruction],
                     window: List[Instruction]) -> Union[Dict[str, str], None]:
        """Match a sequence of patterns against an equally long instruction window.

        Returns:
            The wildcard bindings if every instruction matches its pattern, or
            None if the lengths differ or any instruction fails to match.
        """
        if len(patterns) != len(window):
            return None
        bindings: Dict[str, str] = {}
        for p, i in zip(patterns, window):
            if not self.match_instruction(p, i, bindings):
                return None
        return bindings

    def parse_instruction(self, line: str) -> Instruction:
        """Parse ``OPCODE operand1, operand2`` into an ``Instruction``.

        Surrounding whitespace is stripped. The opcode is the first
        whitespace-delimited token; the rest is split on commas (with optional
        surrounding whitespace). A line with only an opcode has no operands.

        Raises:
            IndexError: if ``line`` is empty or whitespace-only.
        """
        parts = line.strip().split(None, 1)
        opcode = parts[0]
        operands = re.split(r'\s*,\s*', parts[1]) if len(parts) > 1 else []
        return Instruction(opcode, operands)

    def apply_replacement(self, replacements: List[Instruction],
                          bindings: dict[str, str]) -> List[Instruction]:
        """Instantiate replacement patterns by substituting bound wildcards.

        Operands that have no entry in ``bindings`` are copied through unchanged.
        """
        return [
            Instruction(pattern.opcode, [bindings.get(op, op) for op in pattern.operands])
            for pattern in replacements
        ]

    @staticmethod
    def _format_instruction(instr: Instruction) -> str:
        """Render an instruction as ``OPCODE a, b\\n`` (just ``OPCODE\\n`` if it has no operands)."""
        if instr.operands:
            return f"{instr.opcode} {', '.join(instr.operands)}\n"
        return f"{instr.opcode}\n"

    def optimize_until_stable(self):
        """Run ``peephole`` passes until a pass leaves the lines unchanged.

        The first pass is compared against the original input lines. If the
        input is not already in the canonical ``OPCODE a, b\\n`` form, that
        first pass counts as a change even when no rule fires.

        Returns:
            The final list of formatted lines, which is also stored in
            ``self.lines``.

        NOTE(review): rules that undo each other will make this loop forever.
        """
        passes = 0
        previous = self.lines
        while True:
            current = self.peephole()
            passes += 1
            if current == previous:
                print(f"Optimized in {passes} passes.")
                return current
            previous = current

    def peephole(self):
        """Apply one left-to-right rewrite pass over ``self.lines``.

        Blank lines are dropped. At each position, rules are tried in order and
        the first one that matches and actually changes the window is applied.
        The position is not advanced after a rewrite, so the new instructions
        can be rewritten again immediately. The position only moves on when no
        rule applies.

        Returns:
            The rewritten lines, formatted as text with trailing newlines. They
            are also stored in ``self.lines``.
        """
        # Work on a local list so self.lines is never left half-converted.
        instrs = [self.parse_instruction(line) for line in self.lines if line.strip()]
        i = 0
        while i < len(instrs):
            for rule in self.rules:
                width = len(rule.match)
                if width == 0:
                    # An empty pattern matches everywhere without consuming input.
                    continue
                window = instrs[i:i + width]
                bindings = self.match_window(rule.match, window)
                if bindings is None:
                    continue
                replacement = self.apply_replacement(rule.replace, bindings)
                if replacement == window:
                    # No-op rewrite; applying it would re-match here forever.
                    continue
                instrs[i:i + width] = replacement
                break
            else:
                # No rule changed anything at this position.
                i += 1
        self.lines = [self._format_instruction(instr) for instr in instrs]
        return self.lines