From a484137be24af1316c03aa5abedf9032a5b96884 Mon Sep 17 00:00:00 2001 From: Reid 'arrdem' McKenzie Date: Fri, 12 Aug 2022 23:26:42 -0600 Subject: [PATCH] Rewrite using handlers, singledispatch, hooks --- .../shoggoth/src/python/ichor/assembler.py | 2 +- .../shoggoth/src/python/ichor/interpreter.py | 564 +++++++++++------- .../shoggoth/test/python/ichor/fixtures.py | 23 +- .../test/python/ichor/test_interpreter.py | 5 +- 4 files changed, 355 insertions(+), 239 deletions(-) diff --git a/projects/shoggoth/src/python/ichor/assembler.py b/projects/shoggoth/src/python/ichor/assembler.py index 9941e98..93bd599 100644 --- a/projects/shoggoth/src/python/ichor/assembler.py +++ b/projects/shoggoth/src/python/ichor/assembler.py @@ -45,7 +45,7 @@ class FuncBuilder(object): pass case _: - self._stack -= getattr(op, 'nargs', 0) + self._stack -= getattr(op, "nargs", 0) self._stack += 1 def write(self, op: Union[isa.Opcode, isa.Label, Sequence[isa.Opcode]]): diff --git a/projects/shoggoth/src/python/ichor/interpreter.py b/projects/shoggoth/src/python/ichor/interpreter.py index ed372bb..d0eb2b3 100644 --- a/projects/shoggoth/src/python/ichor/interpreter.py +++ b/projects/shoggoth/src/python/ichor/interpreter.py @@ -9,8 +9,12 @@ context (a virtual machine) which DOES have an easily introspected and serialize """ -from copy import deepcopy -from textwrap import indent +from dataclasses import dataclass +from functools import ( + singledispatch, + update_wrapper, +) +from typing import Optional from ichor import isa from ichor.state import ( @@ -25,13 +29,42 @@ from ichor.state import ( ) +@dataclass +class InterpreterState(object): + module: Module + stackframe: Stackframe + clock: int = 0 + + class InterpreterError(Exception): """An error raised by the interpreter when something goes awry.""" - def __init__(self, module, stack, message=None): - self.module = module - self.stack = stack - super().__init__(message) + def __init__(self, state: InterpreterState, message: str, cause: Optional[Exception] = None): + super().__init__(message, cause) + self.state = state + + +class InterpreterReturn(Exception): + def __init__(self, val): + super().__init__() + self.val = val + + +class InterpreterRestart(Exception): + def __init__(self, state: InterpreterState): + super().__init__() + self.state = state + + +def handledispatch(func): + dispatcher = singledispatch(func) + def wrapper(self, state, opcode): + assert isinstance(state, InterpreterState) + assert isinstance(opcode, isa.Opcode) + return dispatcher.dispatch(opcode.__class__)(self, state, opcode) + wrapper.register = dispatcher.register + update_wrapper(wrapper, func) + return wrapper class Interpreter(object): @@ -39,238 +72,305 @@ class Interpreter(object): def __init__(self, bootstrap_module: Module): self.bootstrap = bootstrap_module + def pre_instr(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState: + return state + + def post_instr(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState: + return state + + def step(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState: + return self.handle_opcode(state, opcode) + + def handle_unknown(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState: + raise InterpreterError(state, "Unsupported operation: {opcode}") + + def handle_fault(self, state, opcode, message, cause=None) -> InterpreterState: + raise InterpreterError(state, message, cause) + + @handledispatch + def handle_opcode(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState: + return self.handle_unknown(state, opcode) + + @handle_opcode.register(isa.IDENTIFIERC) + def _handle_identifierc(self, state: InterpreterState, opcode: isa.IDENTIFIERC) -> InterpreterState: + name = opcode.val + if not (name in state.module.functions + or name in state.module.types + or any(name in t.constructors for t in state.module.types.values())): + return self.handle_fault(state, opcode, "IDENTIFIERC references unknown entity") + + state.stackframe.push(Identifier(name)) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.TYPEREF) + def _handle_typeref(self, state, opcode) -> InterpreterState: + id = state.stackframe.pop() + if not isinstance(id, Identifier): + return self.handle_fault(state, opcode, "TYPEREF consumes an identifier") + + if not id.name in state.module.types: + return self.handle_fault(state, opcode, "TYPEREF must be given a valid type identifier") + + state.stackframe.push(TypeRef(id.name)) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.ARMREF) + def _handle_armref(self, state, opcode) -> InterpreterState: + id: Identifier = state.stackframe.pop() + if not isinstance(id, Identifier): + return self.handle_fault(state, opcode, "VARIANTREF consumes an identifier and a typeref") + + t: TypeRef = state.stackframe.pop() + if not isinstance(t, TypeRef): + return self.handle_fault(state, opcode, "VARIANTREF consumes an identifier and a typeref") + + type = state.module.types[t.name] + if id.name not in type.constructors: + return self.handle_fault(state, opcode, f"VARIANTREF given {id.name!r} which does not name a constructor within {type!r}") + + state.stackframe.push(VariantRef(t, id.name)) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.ARM) + def _handle_arm(self, state: InterpreterState, opcode: isa.ARM) -> InterpreterState: + armref: VariantRef = state.stackframe.pop() + if not isinstance(armref, VariantRef): + return self.handle_fault(state, opcode, "VARIANT must be given a valid constructor reference") + + ctor = state.module.types[armref.type.name].constructors[armref.arm] + if opcode.nargs != len(ctor): + return self.handle_fault(state, opcode, "VARIANT given n-args inconsistent with the type constructor") + + if opcode.nargs > len(state.stackframe): + return self.handle_fault(state, opcode, "Stack size violation") + + # FIXME: Where does type variable to type binding occur? + # Certainly needs to be AT LEAST here, where we also need to be doing some typechecking + v = Variant(armref.type.name, armref.arm, tuple(state.stackframe[:opcode.nargs])) + state.stackframe.drop(opcode.nargs) + state.stackframe.push(v) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.ATEST) + def _handle_atest(self, state: InterpreterState, opcode: isa.ATEST) -> InterpreterState: + armref: VariantRef = state.stackframe.pop() + if not isinstance(armref, VariantRef): + return self.handle_fault(state, opcode, "VTEST must be given a variant reference") + + inst: Variant = state.stackframe.pop() + if not isinstance(inst, Variant): + return self.handle_fault(state, opcode, "VTEST must be given an instance of a variant") + + if inst.type == armref.type.name and inst.variant == armref.arm: + state.stackframe.goto(opcode.target) + else: + state.stackframe._ip += 1 + + return state + + @handle_opcode.register(isa.GOTO) + def _handle_goto(self, state, opcode: isa.GOTO) -> InterpreterState: + if (opcode.target < 0): + return self.handle_fault(state, opcode, "Illegal branch target") + state.stackframe.goto(opcode.target) + return state + + @handle_opcode.register(isa.DUP) + def _handle_dupe(self, state, opcode: isa.DUP) -> InterpreterState: + if (opcode.nargs > len(state.stackframe)): + return self.handle_fault(state, opcode, "Stack size violation") + + state.stackframe.dup(opcode.nargs) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.ROT) + def _handle_rot(self, state, opcode: isa.DUP) -> InterpreterState: + if (opcode.nargs > len(state.stackframe)): + return self.handle_fault(state, opcode, "Stack size violation") + + state.stackframe.rot(opcode.nargs) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.DROP) + def _handle_drop(self, state, opcode: isa.DROP) -> InterpreterState: + if (opcode.nargs > len(state.stackframe)): + return self.handle_fault(state, opcode, "Stack size violation") + + state.stackframe.drop(opcode.nargs) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.SLOT) + def _handle_slot(self, state, opcode: isa.SLOT) -> InterpreterState: + if (opcode.target < 0): + return self.handle_fault(state, opcode, "SLOT must have a positive reference") + + if (opcode.target > len(state.stackframe) - 1): + return self.handle_fault(state, opcode, "SLOT reference out of range") + + state.stackframe.slot(opcode.target) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.FUNREF) + def _handle_funref(self, state, opcode) -> InterpreterState: + id = state.stackframe.pop() + if not isinstance(id, Identifier): + return self.handle_fault(state, opcode, "FUNREF consumes an IDENTIFIER") + try: + # FIXME: Verify this statically + state.stackframe.push(FunctionRef.parse(id.name)) + except: + return self.handle_fault(state, opcode, "Invalid function ref") + + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.CALLF) + def _handle_callf(self, state, opcode: isa.CALLF) -> InterpreterState: + sig = state.stackframe.pop() + if not isinstance(sig, FunctionRef): + return self.handle_fault(state, opcode, "CALLF requires a funref at top of stack") + + fun = state.module.functions[sig.name] + if opcode.nargs != len(fun.arguments): + return self.handle_fault(state, opcode, "CALLF target violation; argument count missmatch") + + if opcode.nargs > len(state.stackframe): + return self.handle_fault(state, opcode, "Stack size violation") + + try: + ip = state.module.labels[fun.signature] + state.stackframe = state.stackframe.call(fun, ip) + except KeyError: + return self.handle_fault(state, opcode, "Unknown FUNREF target") + + return state + + @handle_opcode.register(isa.RETURN) + def _handle_return(self, state, opcode: isa.RETURN) -> InterpreterState: + n = 1 # FIXME: clean this up + if (n > len(state.stackframe)): + return self.handle_fault(state, opcode, "Stack size violation") + + if state.stackframe.depth == 0: + raise InterpreterReturn(state.stackframe[:n]) + + if (len(state.stackframe._fun.returns) != n): + return self.handle_fault(state, opcode, "Signature violation") + + state.stackframe = state.stackframe.ret(n) + return state + + @handle_opcode.register(isa.CLOSUREF) + def _handle_closuref(self, state: InterpreterState, opcode: isa.CLOSUREF) -> InterpreterState: + n = opcode.nargs + + sig = state.stackframe.pop() + if not isinstance(sig, FunctionRef): + return self.handle_fault(state, opcode, "CLOSUREF requires a funref at top of stack") + + fun = state.module.functions[sig.name] + if not n <= len(fun.arguments): + return self.handle_fault(state, opcode, "CLOSUREF target violation; too many parameters provided") + + if n > len(state.stackframe): + return self.handle_fault(state, opcode, "Stack size violation") + + c = Closure( + sig, + state.stackframe[:n] + ) + state.stackframe.drop(n) + state.stackframe.push(c) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.CLOSUREC) + def _handle_closurec(self, state, opcode: isa.CLOSUREC) -> InterpreterState: + n = opcode.nargs + + c = state.stackframe.pop() + if not isinstance(c, Closure): + return self.handle_fault(state, opcode, "CLOSUREC requires a closure at top of stack") + + fun = state.module.functions[c.funref.name] + if n + len(c.frag) > len(fun.arguments): + return self.handle_fault(state, opcode, "CLOSUREC target violation; too many parameters provided") + + if n > len(state.stackframe): + return self.handle_fault(state, opcode, "Stack size violation") + + c = Closure( + c.funref, + state.stackframe[:n] + c.frag + ) + state.stackframe.drop(n) + state.stackframe.push(c) + state.stackframe._ip += 1 + return state + + @handle_opcode.register(isa.CALLC) + def _handle_callc(self, state, opcode: isa.CALLC) -> InterpreterState: + n = opcode.nargs + c = state.stackframe.pop() + if not isinstance(c, Closure): + return self.handle_fault(state, opcode, "CALLC requires a closure at top of stack") + fun = state.module.functions[c.funref.name] + if n + len(c.frag) != len(fun.arguments): + return self.handle_fault(state, opcode, "CALLC target vionation; argument count missmatch") + if n > len(state.stackframe): + return self.handle_fault(state, opcode, "Stack size violation") + + # Push the closure's stack fragment + state.stackframe._stack = c.frag + state.stackframe._stack + + # Perform a "normal" funref call + try: + ip = state.module.labels[fun.signature] + except KeyError: + return self.handle_fault(state, opcode, "Unknown target") + + state.stackframe = state.stackframe.call(fun, ip) + return state + + @handle_opcode.register(isa.BREAK) + def _handle_break(self, state, _) -> InterpreterState: + raise InterpreterReturn(state.stackframe._stack) + def run(self, opcodes, stack=[]): """Directly interpret some opcodes in the configured environment.""" - mod = self.bootstrap.copy() - main = mod.define_function(";
;;", opcodes) - main_fun = mod.functions[main] - main_ip = mod.labels[main] - stackframe = Stackframe(main_fun, main_ip, stack) - clock: int = 0 - - print(mod) - - def _error(msg=None): - # Note this is pretty expensive because we have to snapshot the stack BEFORE we do anything - # And the stack object isn't immutable or otherwise designed for cheap snapshotting - raise InterpreterError(mod, deepcopy(stackframe), msg) - - def _debug(): - b = [] - b.append(f"clock {clock}:") - b.append(" stack:") - for offset, it in zip(range(0, len(stackframe), 1), stackframe): - b.append(f" {offset: <3} {it}") - b.append(f" op: {op}") - print(indent("\n".join(b), " " * stackframe.depth)) + _mod = self.bootstrap.copy() + _main = _mod.define_function(";
;;", opcodes) + _main_fun = _mod.functions[_main] + _main_ip = _mod.labels[_main] + state = InterpreterState( + _mod, Stackframe(_main_fun, _main_ip, stack) + ) while True: - op = mod.codepage[stackframe._ip] - _debug() - clock += 1 + try: + opcode = state.module.codepage[state.stackframe._ip] + self.pre_instr(state, opcode) + state = self.handle_opcode(state, opcode) + self.post_instr(state, opcode) + state.clock += 1 - match op: - case isa.IDENTIFIERC(name): - if not (name in mod.functions - or name in mod.types - or any(name in t.constructors for t in mod.types.values())): - _error("IDENTIFIERC references unknown entity") + # FIXME: This case analysis isn't super obvious. + except InterpreterReturn as r: + return r.val - stackframe.push(Identifier(name)) + except InterpreterRestart as r: + state = r.state + continue - case isa.TYPEREF(): - id = stackframe.pop() - if not isinstance(id, Identifier): - _error("TYPEREF consumes an identifier") - if not id.name in mod.types: - _error("TYPEREF must be given a valid type identifier") - - stackframe.push(TypeRef(id.name)) - - case isa.ARMREF(): - id: Identifier = stackframe.pop() - if not isinstance(id, Identifier): - _error("VARIANTREF consumes an identifier and a typeref") - - t: TypeRef = stackframe.pop() - if not isinstance(t, TypeRef): - _error("VARIANTREF consumes an identifier and a typeref") - - type = mod.types[t.name] - if id.name not in type.constructors: - _error(f"VARIANTREF given {id.name!r} which does not name a constructor within {type!r}") - - stackframe.push(VariantRef(t, id.name)) - - case isa.ARM(n): - armref: VariantRef = stackframe.pop() - if not isinstance(armref, VariantRef): - _error("VARIANT must be given a valid constructor reference") - - ctor = mod.types[armref.type.name].constructors[armref.arm] - if n != len(ctor): - _error("VARIANT given n-args inconsistent with the type constructor") - - if n > len(stackframe): - _error("Stack size violation") - - # FIXME: Where does type variable to type binding occur? - # Certainly needs to be AT LEAST here, where we also need to be doing some typechecking - v = Variant(armref.type.name, armref.arm, tuple(stackframe[:n])) - stackframe.drop(n) - stackframe.push(v) - - case isa.ATEST(n): - armref: VariantRef = stackframe.pop() - if not isinstance(armref, VariantRef): - _error("VTEST must be given a variant reference") - - inst: Variant = stackframe.pop() - if not isinstance(inst, Variant): - _error("VTEST must be given an instance of a variant") - - if inst.type == armref.type.name and inst.variant == armref.arm: - stackframe.goto(n) - continue - - case isa.GOTO(n): - if (n < 0): - _error("Illegal branch target") - stackframe.goto(n) - continue - - case isa.DUP(n): - if (n > len(stackframe)): - _error("Stack size violation") - - stackframe.dup(n) - - case isa.ROT(n): - if (n > len(stackframe)): - _error("Stack size violation") - - stackframe.rot(n) - - case isa.DROP(n): - if (n > len(stackframe)): - _error("Stack size violation") - - stackframe.drop(n) - - case isa.SLOT(n): - if (n < 0): - _error("SLOT must have a positive reference") - if (n > len(stackframe) - 1): - _error("SLOT reference out of range") - stackframe.slot(n) - - case isa.FUNREF(): - id = stackframe.pop() - if not isinstance(id, Identifier): - _error("FUNREF consumes an IDENTIFIER") - try: - # FIXME: Verify this statically - stackframe.push(FunctionRef.parse(id.name)) - except: - _error("Invalid function ref") - - case isa.CALLF(n): - sig = stackframe.pop() - if not isinstance(sig, FunctionRef): - _error("CALLF requires a funref at top of stack") - fun = mod.functions[sig.name] - if n != len(fun.arguments): - _error("CALLF target violation; argument count missmatch") - if n > len(stackframe): - _error("Stack size violation") - - try: - ip = mod.labels[fun.signature] - except KeyError: - _error("Unknown target") - - stackframe = stackframe.call(fun, ip) - continue - - case isa.RETURN(): - n = 1 # FIXME: clean this up - if (n > len(stackframe)): - _error("Stack size violation") - - if stackframe.depth == 0: - return stackframe[:n] - - if (len(stackframe._fun.returns) != n): - _error("Signature violation") - - stackframe = stackframe.ret(n) - continue - - case isa.CLOSUREF(n): - sig = stackframe.pop() - if not isinstance(sig, FunctionRef): - _error("CLOSUREF requires a funref at top of stack") - fun = mod.functions[sig.name] - if not n <= len(fun.arguments): - _error("CLOSUREF target violation; too many parameters provided") - if n > len(stackframe): - _error("Stack size violation") - - c = Closure( - sig, - stackframe[:n] - ) - stackframe.drop(n) - stackframe.push(c) - - case isa.CLOSUREC(n): - c = stackframe.pop() - if not isinstance(c, Closure): - _error("CLOSUREC requires a closure at top of stack") - fun = mod.functions[c.funref.name] - if n + len(c.frag) > len(fun.arguments): - _error("CLOSUREC target violation; too many parameters provided") - if n > len(stackframe): - _error("Stack size violation") - - c = Closure( - c.funref, - stackframe[:n] + c.frag - ) - stackframe.drop(n) - stackframe.push(c) - - case isa.CALLC(n): - c = stackframe.pop() - if not isinstance(c, Closure): - _error("CALLC requires a closure at top of stack") - fun = mod.functions[c.funref.name] - if n + len(c.frag) != len(fun.arguments): - _error("CALLC target vionation; argument count missmatch") - if n > len(stackframe): - _error("Stack size violation") - - # Extract the function signature - - # Push the closure's stack fragment - stackframe._stack = c.frag + stackframe._stack - - # Perform a "normal" funref call - try: - ip = mod.labels[fun.signature] - except KeyError: - _error("Unknown target") - - stackframe = stackframe.call(fun, ip) - continue - - case isa.BREAK(): - # FIXME: let users override this / set custom handlers - return stackframe._stack - - case _: - raise Exception(f"Unhandled interpreter state {op}") - - stackframe._ip += 1 + except Exception as e: + raise e diff --git a/projects/shoggoth/test/python/ichor/fixtures.py b/projects/shoggoth/test/python/ichor/fixtures.py index 84f9119..32c449a 100644 --- a/projects/shoggoth/test/python/ichor/fixtures.py +++ b/projects/shoggoth/test/python/ichor/fixtures.py @@ -1,10 +1,29 @@ #!/usr/bin/env python3 +from textwrap import indent + +from ichor import isa from ichor.bootstrap import BOOTSTRAP -from ichor.interpreter import Interpreter +from ichor.interpreter import ( + Interpreter, + InterpreterState, +) import pytest +class LoggingInterpreter(Interpreter): + def pre_instr(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState: + b = [] + b.append(f"clock {state.clock}:") + b.append(" stack:") + for offset, it in zip(range(0, len(state.stackframe), 1), state.stackframe): + b.append(f" {offset: <3} {it}") + b.append(f" op: {opcode}") + print(indent("\n".join(b), " " * state.stackframe.depth)) + + return state + + @pytest.fixture def vm(): - return Interpreter(BOOTSTRAP) + return LoggingInterpreter(BOOTSTRAP) diff --git a/projects/shoggoth/test/python/ichor/test_interpreter.py b/projects/shoggoth/test/python/ichor/test_interpreter.py index 60460e9..d181914 100644 --- a/projects/shoggoth/test/python/ichor/test_interpreter.py +++ b/projects/shoggoth/test/python/ichor/test_interpreter.py @@ -4,11 +4,8 @@ Tests coverign the VM interpreter from .fixtures import * # noqa -from ichor.bootstrap import ( - FALSE, - TRUE, -) from ichor import isa +from ichor.bootstrap import FALSE, TRUE from ichor.interpreter import InterpreterError import pytest