#!/usr/bin/env python3

"""The Ichor VM implementation.

The whole point of Shoggoth is that program executions are checkpointable and restartable. This requires that, rather
than using a traditional recursive interpreter (which is difficult to snapshot), interpretation in Shoggoth occur within
a context (a virtual machine) that DOES have an easily introspected and serialized representation.

"""


from dataclasses import dataclass
from functools import (
    singledispatch,
    update_wrapper,
)
from typing import Optional

from ichor import isa
from ichor.state import (
    Closure,
    FunctionRef,
    Identifier,
    Module,
    Stackframe,
    TypeRef,
    Variant,
    VariantRef,
)


@dataclass
class InterpreterState(object):
    """The state threaded through every opcode handler.

    Handlers mutate these fields (or, for calls and returns, replace the
    stackframe) and hand the state back to the run loop.
    """
    # The module being executed; ``Interpreter.run`` installs a copy of the
    # interpreter's bootstrap module here.
    module: Module
    # The active call frame; CALLF, CALLC and RETURN swap in a new frame.
    stackframe: Stackframe
    # Incremented by ``Interpreter.run`` after each instruction it executes.
    clock: int = 0


class InterpreterError(Exception):
    """An error raised by the interpreter when something goes awry.

    Carries the interpreter state at the time of the fault so callers can
    inspect it, plus an optional underlying exception.
    """

    def __init__(self, state: "InterpreterState", message: str, cause: Optional[Exception] = None):
        # ``args`` keeps its ``(message, cause)`` shape for callers that unpack it.
        super().__init__(message, cause)
        self.state = state
        self.message = message
        self.cause = cause
        if cause is not None:
            # Chain the underlying exception so tracebacks report it as the
            # direct cause of this fault.
            self.__cause__ = cause

    def __str__(self) -> str:
        # Without this, str() would render the whole args tuple,
        # e.g. "('Stack size violation', None)".
        return self.message


class InterpreterReturn(Exception):
    """Control-flow signal that interpretation has finished with a result."""

    def __init__(self, val):
        # The result that ``Interpreter.run`` hands back to its caller.
        self.val = val
        super().__init__()


class InterpreterRestart(Exception):
    """Control-flow signal telling the run loop to continue from ``state``."""

    def __init__(self, state: InterpreterState):
        # The replacement state ``Interpreter.run`` adopts before its next step.
        self.state = state
        super().__init__()


def handledispatch(func):
    """Make ``func`` an opcode handler that dispatches on the opcode's class.

    ``functools.singledispatch`` keys on the first positional argument, which
    for a method is ``self``. This wrapper instead picks the implementation
    registered for the class of the third argument, ``opcode``. ``func``
    itself is the fallback, and new handlers are attached with
    ``@<wrapper>.register(OpcodeClass)``.
    """
    registry = singledispatch(func)

    def dispatching_handler(self, state, opcode):
        assert isinstance(state, InterpreterState)
        assert isinstance(opcode, isa.Opcode)
        implementation = registry.dispatch(opcode.__class__)
        return implementation(self, state, opcode)

    dispatching_handler.register = registry.register
    update_wrapper(dispatching_handler, func)
    return dispatching_handler


class Interpreter:
    """A simple instruction-pointer based interpreter.

    Opcode handlers are registered on ``handle_opcode`` by opcode class. Each
    handler takes the current ``InterpreterState`` and the opcode being
    executed, updates the state's stackframe (advancing ``_ip``, jumping, or
    swapping frames for calls and returns) and returns the state. Errors are
    reported through ``handle_fault``, which raises ``InterpreterError`` unless
    a subclass overrides it.
    """

    def __init__(self, bootstrap_module: Module):
        # Template module; ``run`` executes against ``self.bootstrap.copy()``.
        self.bootstrap = bootstrap_module

    def pre_instr(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState:
        """Hook called by ``run`` before each instruction; a no-op by default."""
        return state

    def post_instr(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState:
        """Hook called by ``run`` after each instruction; a no-op by default."""
        return state

    def step(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState:
        """Execute a single ``opcode`` against ``state`` (no run-loop hooks)."""
        return self.handle_opcode(state, opcode)

    def handle_unknown(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState:
        """Fault on an opcode that has no registered handler."""
        # Routed through handle_fault like every other fault so subclass
        # overrides see it too.
        return self.handle_fault(state, opcode, f"Unsupported operation: {opcode}")

    def handle_fault(self, state, opcode, message, cause=None) -> InterpreterState:
        """Report a fault raised while executing ``opcode``.

        Handlers ``return`` whatever this returns, so a subclass may override
        it. The default implementation never returns.

        Raises:
            InterpreterError: carrying ``state``, ``message`` and ``cause``.
        """
        raise InterpreterError(state, message, cause)

    @handledispatch
    def handle_opcode(self, state: InterpreterState, opcode: isa.Opcode) -> InterpreterState:
        """Dispatch ``opcode`` to the handler registered for its class.

        This body is the fallback used when no handler is registered.
        """
        return self.handle_unknown(state, opcode)

    @handle_opcode.register(isa.IDENTIFIERC)
    def _handle_identifierc(self, state: InterpreterState, opcode: isa.IDENTIFIERC) -> InterpreterState:
        """Push ``Identifier(opcode.val)`` if it names a function, type or constructor."""
        name = opcode.val
        module = state.module
        known = (
            name in module.functions
            or name in module.types
            or any(name in t.constructors for t in module.types.values())
        )
        if not known:
            return self.handle_fault(state, opcode, "IDENTIFIERC references unknown entity")

        state.stackframe.push(Identifier(name))
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.TYPEREF)
    def _handle_typeref(self, state: InterpreterState, opcode: isa.TYPEREF) -> InterpreterState:
        """Replace a popped ``Identifier`` naming a module type with a ``TypeRef``."""
        ident = state.stackframe.pop()
        if not isinstance(ident, Identifier):
            return self.handle_fault(state, opcode, "TYPEREF consumes an identifier")

        if ident.name not in state.module.types:
            return self.handle_fault(state, opcode, "TYPEREF must be given a valid type identifier")

        state.stackframe.push(TypeRef(ident.name))
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.ARMREF)
    def _handle_armref(self, state: InterpreterState, opcode: isa.ARMREF) -> InterpreterState:
        """Pop an ``Identifier`` then a ``TypeRef``; push a ``VariantRef`` to that constructor."""
        ident = state.stackframe.pop()
        if not isinstance(ident, Identifier):
            return self.handle_fault(state, opcode, "VARIANTREF consumes an identifier and a typeref")

        typeref = state.stackframe.pop()
        if not isinstance(typeref, TypeRef):
            return self.handle_fault(state, opcode, "VARIANTREF consumes an identifier and a typeref")

        typedef = state.module.types[typeref.name]
        if ident.name not in typedef.constructors:
            return self.handle_fault(state, opcode, f"VARIANTREF given {ident.name!r} which does not name a constructor within {typedef!r}")

        state.stackframe.push(VariantRef(typeref, ident.name))
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.ARM)
    def _handle_arm(self, state: InterpreterState, opcode: isa.ARM) -> InterpreterState:
        """Build a ``Variant`` from a popped ``VariantRef`` and ``opcode.nargs`` stack values."""
        armref = state.stackframe.pop()
        if not isinstance(armref, VariantRef):
            return self.handle_fault(state, opcode, "VARIANT must be given a valid constructor reference")

        nargs = opcode.nargs
        ctor = state.module.types[armref.type.name].constructors[armref.arm]
        if nargs != len(ctor):
            return self.handle_fault(state, opcode, "VARIANT given n-args inconsistent with the type constructor")

        if nargs > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        # FIXME: Where does type variable to type binding occur?
        # Certainly needs to be AT LEAST here, where we also need to be doing some typechecking
        fields = tuple(state.stackframe[:nargs])
        state.stackframe.drop(nargs)
        state.stackframe.push(Variant(armref.type.name, armref.arm, fields))
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.ATEST)
    def _handle_atest(self, state: InterpreterState, opcode: isa.ATEST) -> InterpreterState:
        """Jump to ``opcode.target`` if the popped variant matches the popped ``VariantRef``."""
        armref = state.stackframe.pop()
        if not isinstance(armref, VariantRef):
            return self.handle_fault(state, opcode, "VTEST must be given a variant reference")

        inst = state.stackframe.pop()
        if not isinstance(inst, Variant):
            return self.handle_fault(state, opcode, "VTEST must be given an instance of a variant")

        if inst.type == armref.type.name and inst.variant == armref.arm:
            state.stackframe.goto(opcode.target)
        else:
            state.stackframe._ip += 1

        return state

    @handle_opcode.register(isa.GOTO)
    def _handle_goto(self, state: InterpreterState, opcode: isa.GOTO) -> InterpreterState:
        """Unconditionally jump to ``opcode.target``."""
        if opcode.target < 0:
            return self.handle_fault(state, opcode, "Illegal branch target")
        state.stackframe.goto(opcode.target)
        return state

    @handle_opcode.register(isa.DUP)
    def _handle_dupe(self, state: InterpreterState, opcode: isa.DUP) -> InterpreterState:
        """Delegate to ``Stackframe.dup(opcode.nargs)`` after a stack-depth check."""
        if opcode.nargs > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        state.stackframe.dup(opcode.nargs)
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.ROT)
    def _handle_rot(self, state: InterpreterState, opcode: isa.ROT) -> InterpreterState:
        """Delegate to ``Stackframe.rot(opcode.nargs)`` after a stack-depth check."""
        if opcode.nargs > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        state.stackframe.rot(opcode.nargs)
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.DROP)
    def _handle_drop(self, state: InterpreterState, opcode: isa.DROP) -> InterpreterState:
        """Delegate to ``Stackframe.drop(opcode.nargs)`` after a stack-depth check."""
        if opcode.nargs > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        state.stackframe.drop(opcode.nargs)
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.SLOT)
    def _handle_slot(self, state: InterpreterState, opcode: isa.SLOT) -> InterpreterState:
        """Delegate to ``Stackframe.slot(opcode.target)`` for an in-range index."""
        if opcode.target < 0:
            return self.handle_fault(state, opcode, "SLOT must have a positive reference")

        if opcode.target >= len(state.stackframe):
            return self.handle_fault(state, opcode, "SLOT reference out of range")

        state.stackframe.slot(opcode.target)
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.FUNREF)
    def _handle_funref(self, state: InterpreterState, opcode: isa.FUNREF) -> InterpreterState:
        """Replace a popped ``Identifier`` with ``FunctionRef.parse`` of its name."""
        ident = state.stackframe.pop()
        if not isinstance(ident, Identifier):
            return self.handle_fault(state, opcode, "FUNREF consumes an IDENTIFIER")
        try:
            # FIXME: Verify this statically
            ref = FunctionRef.parse(ident.name)
        except Exception as e:
            # Parse failures become interpreter faults, keeping the original
            # error as the cause (a bare except would also trap SystemExit etc.).
            return self.handle_fault(state, opcode, "Invalid function ref", e)

        state.stackframe.push(ref)
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.CALLF)
    def _handle_callf(self, state: InterpreterState, opcode: isa.CALLF) -> InterpreterState:
        """Call the function named by a popped ``FunctionRef`` with ``opcode.nargs`` arguments."""
        sig = state.stackframe.pop()
        if not isinstance(sig, FunctionRef):
            return self.handle_fault(state, opcode, "CALLF requires a funref at top of stack")

        if sig.name not in state.module.functions:
            return self.handle_fault(state, opcode, "Unknown FUNREF target")
        fun = state.module.functions[sig.name]
        if opcode.nargs != len(fun.arguments):
            return self.handle_fault(state, opcode, "CALLF target violation; argument count missmatch")

        if opcode.nargs > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        # Only the label lookup is guarded; errors from ``call`` itself must
        # not be misreported as an unknown target.
        try:
            ip = state.module.labels[fun.signature]
        except KeyError as e:
            return self.handle_fault(state, opcode, "Unknown FUNREF target", e)

        state.stackframe = state.stackframe.call(fun, ip)
        return state

    @handle_opcode.register(isa.RETURN)
    def _handle_return(self, state: InterpreterState, opcode: isa.RETURN) -> InterpreterState:
        """Return a single value from the current frame.

        At depth 0 this raises ``InterpreterReturn`` with ``stackframe[:1]``
        instead of popping a frame.
        """
        n = 1  # FIXME: clean this up
        if n > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        if state.stackframe.depth == 0:
            raise InterpreterReturn(state.stackframe[:n])

        if len(state.stackframe._fun.returns) != n:
            return self.handle_fault(state, opcode, "Signature violation")

        state.stackframe = state.stackframe.ret(n)
        return state

    @handle_opcode.register(isa.CLOSUREF)
    def _handle_closuref(self, state: InterpreterState, opcode: isa.CLOSUREF) -> InterpreterState:
        """Capture ``opcode.nargs`` stack values with a popped ``FunctionRef`` as a ``Closure``."""
        n = opcode.nargs

        sig = state.stackframe.pop()
        if not isinstance(sig, FunctionRef):
            return self.handle_fault(state, opcode, "CLOSUREF requires a funref at top of stack")

        if sig.name not in state.module.functions:
            return self.handle_fault(state, opcode, "CLOSUREF references an unknown function")
        fun = state.module.functions[sig.name]
        if n > len(fun.arguments):
            return self.handle_fault(state, opcode, "CLOSUREF target violation; too many parameters provided")

        if n > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        closure = Closure(sig, state.stackframe[:n])
        state.stackframe.drop(n)
        state.stackframe.push(closure)
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.CLOSUREC)
    def _handle_closurec(self, state: InterpreterState, opcode: isa.CLOSUREC) -> InterpreterState:
        """Extend a popped ``Closure`` with ``opcode.nargs`` more captured stack values."""
        n = opcode.nargs

        c = state.stackframe.pop()
        if not isinstance(c, Closure):
            return self.handle_fault(state, opcode, "CLOSUREC requires a closure at top of stack")

        if c.funref.name not in state.module.functions:
            return self.handle_fault(state, opcode, "CLOSUREC references an unknown function")
        fun = state.module.functions[c.funref.name]
        if n + len(c.frag) > len(fun.arguments):
            return self.handle_fault(state, opcode, "CLOSUREC target violation; too many parameters provided")

        if n > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        extended = Closure(c.funref, state.stackframe[:n] + c.frag)
        state.stackframe.drop(n)
        state.stackframe.push(extended)
        state.stackframe._ip += 1
        return state

    @handle_opcode.register(isa.CALLC)
    def _handle_callc(self, state: InterpreterState, opcode: isa.CALLC) -> InterpreterState:
        """Call a popped ``Closure`` with its captured values plus ``opcode.nargs`` arguments."""
        n = opcode.nargs
        c = state.stackframe.pop()
        if not isinstance(c, Closure):
            return self.handle_fault(state, opcode, "CALLC requires a closure at top of stack")

        if c.funref.name not in state.module.functions:
            return self.handle_fault(state, opcode, "Unknown target")
        fun = state.module.functions[c.funref.name]
        if n + len(c.frag) != len(fun.arguments):
            return self.handle_fault(state, opcode, "CALLC target vionation; argument count missmatch")
        if n > len(state.stackframe):
            return self.handle_fault(state, opcode, "Stack size violation")

        # Resolve the target before touching the stack so a fault leaves the
        # caller's stack exactly as it was.
        try:
            ip = state.module.labels[fun.signature]
        except KeyError as e:
            return self.handle_fault(state, opcode, "Unknown target", e)

        # Splice the closure's captured fragment onto the front of the raw
        # stack list (the same end ``stackframe[:n]`` reads from), then perform
        # a "normal" funref call.
        state.stackframe._stack = c.frag + state.stackframe._stack
        state.stackframe = state.stackframe.call(fun, ip)
        return state

    @handle_opcode.register(isa.BREAK)
    def _handle_break(self, state: InterpreterState, _) -> InterpreterState:
        """Stop interpretation, raising ``InterpreterReturn`` with the current raw stack."""
        raise InterpreterReturn(state.stackframe._stack)

    def run(self, opcodes, stack=None):
        """Directly interpret some opcodes in the configured environment.

        The opcodes are installed as a synthetic ``;<main>;;`` function in a
        copy of the bootstrap module. They are executed until an
        ``InterpreterReturn`` is raised, either by RETURN at depth 0 or by
        BREAK, and that exception's value is returned.

        ``stack`` is the initial stack for ``<main>``. When omitted, a fresh
        empty list is used on every call.

        Raises:
            InterpreterError: when an instruction faults (default
                ``handle_fault``).
        """
        if stack is None:
            # A fresh list per call: a shared ``[]`` default would carry state
            # between runs if Stackframe mutates the list it is given.
            stack = []

        mod = self.bootstrap.copy()
        main = mod.define_function(";<main>;;", opcodes)
        main_fun = mod.functions[main]
        main_ip = mod.labels[main]

        state = InterpreterState(mod, Stackframe(main_fun, main_ip, stack))

        while True:
            try:
                opcode = state.module.codepage[state.stackframe._ip]
                # NOTE(review): the hooks' returned states are currently discarded.
                self.pre_instr(state, opcode)
                state = self.handle_opcode(state, opcode)
                self.post_instr(state, opcode)
                state.clock += 1

            except InterpreterReturn as r:
                # Normal termination: RETURN at depth 0 or BREAK.
                return r.val

            except InterpreterRestart as r:
                # Continue execution from the state carried by the signal.
                state = r.state