# flake8: noqa: all

# A Python AST interpreter written in Python
#
# This module is part of the Pycopy https://github.com/pfalcon/pycopy project.
#
# Copyright (c) 2019 Paul Sokolovsky, published under the MIT License

import ast
import builtins
import logging
import os
import sys
from typing import Optional, Type

log = logging.getLogger(__name__)


class StrictNodeVisitor(ast.NodeVisitor):
    """Node visitor that raises instead of silently skipping unhandled nodes."""

    def generic_visit(self, node):
        """Raise NotImplementedError naming the node class that has no visitor."""
        n = node.__class__.__name__
        raise NotImplementedError("Visitor for node {} not implemented".format(n))


class ANamespace:
    """A dict-backed namespace linked to its parent namespace and to an AST node."""

    def __init__(self, node):
        # Mapping of names to values held directly in this namespace.
        self.d = {}
        # Enclosing namespace (an instance, not a class), or None when unlinked.
        self.parent: Optional["ANamespace"] = None
        # Cross-link namespace to AST node. Note that we can't do the
        # opposite, because for one node, there can be different namespaces.
        self.node = node

    def __getitem__(self, k):
        return self.d[k]

    def get(self, k, default=None):
        """Return the value bound to ``k`` in this namespace, or ``default``."""
        return self.d.get(k, default)

    def __setitem__(self, k, v):
        self.d[k] = v

    def __delitem__(self, k):
        del self.d[k]

    def __contains__(self, k):
        return k in self.d

    def __str__(self):
        return "<{} {}>".format(self.__class__.__name__, self.d)

    # Same rendering in containers and debuggers as with str().
    __repr__ = __str__


class ModuleNS(ANamespace):
    """Namespace of a module."""

    # parent: Optional["ModuleNS"] = None
    pass


class FunctionNS(ANamespace):
    """Namespace of a function invocation."""

    pass


class ClassNS(ANamespace):
    """Namespace of a class body."""

    # Class object built from this namespace; None until it has been assigned.
    cls: Optional[type] = None


# Pycopy by default doesn't support direct slice construction, use helper
# object to construct it.
class SliceGetter:
    """Helper whose subscription returns the index object it was given."""

    def __getitem__(self, idx):
        return idx


slice_getter = SliceGetter()


def arg_name(arg):
    """Return the ``arg`` attribute (the parameter name) of an argument node."""
    return arg.arg


def kwarg_defaults(args):
    """Return the ``kw_defaults`` attribute of an arguments node."""
    return args.kw_defaults


class TargetNonlocalFlow(Exception):
    """Base exception class to simulate non-local control flow transfers in a target application."""


class TargetBreak(TargetNonlocalFlow):
    pass


class TargetContinue(TargetNonlocalFlow):
    pass


class TargetReturn(TargetNonlocalFlow):
    pass


class VarScopeSentinel:
    """Marker object identified by name; instances are compared by identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        # Readable when a sentinel shows up in a namespace dump.
        return "<{} {}>".format(self.__class__.__name__, self.name)


NO_VAR = VarScopeSentinel("no_var")
GLOBAL = VarScopeSentinel("global")
NONLOCAL = VarScopeSentinel("nonlocal")


class InterpFuncWrap:
    "Callable wrapper for AST functions (FunctionDef nodes)."

    def __init__(self, node, interp):
        self.node = node
        self.interp = interp
        # The interpreter's current namespace at wrap time is remembered as the
        # function's lexical scope.
        self.lexical_scope = interp.ns

    def __call__(self, *args, **kwargs):
        return self.interp.call_func(self.node, self, *args, **kwargs)


# Python doesn't fully treat objects, even those defining the __call__() special method, as true
# callables. For example, such objects aren't automatically converted to bound methods if looked up
# as another object's attributes. As we want our "interpreted functions" to behave as closely as
# possible to real functions, we just wrap the function object with a real function. An alternative
# might have been to perform the needed checks and explicitly bind a method using
# types.MethodType() in visit_Attribute (but then maybe there would be still other cases of
# "callable object" vs "function" discrepancies).
def InterpFunc(fun):
    """Wrap callable ``fun`` in a real function that forwards all arguments to it."""

    def func(*args, **kwargs):
        return fun(*args, **kwargs)

    return func


class InterpWith:
    """Context manager proxy that delegates to the wrapped context manager ``ctx``."""

    def __init__(self, ctx):
        self.ctx = ctx

    def __enter__(self):
        return self.ctx.__enter__()

    def __exit__(self, tp, exc, tb):
        # Don't leak meta-level exceptions into target
        if isinstance(exc, TargetNonlocalFlow):
            tp = exc = tb = None
        return self.ctx.__exit__(tp, exc, tb)


class InterpModule:
    """Attribute-style view over a namespace object."""

    def __init__(self, ns):
        self.ns = ns

    def __getattr__(self, name):
        try:
            return self.ns[name]
        except KeyError:
            # Report the missing name and hide the internal KeyError, as for
            # a failed attribute lookup on an ordinary object.
            raise AttributeError(name) from None

    def __dir__(self):
        return list(self.ns.d.keys())


# TODO (arrdem 2023-03-08):
#   This interpreter works well enough to import `requests` and many other libraries and do some
#   work, but is unsuited to Flowmetal's needs for checkpointing. Because this interpreter uses
#   direct execution, there's really no way to jam breakpoints or checkpoints or resume points into
#   program execution. Which is kinda the goal of the whole project.
#
#   This interpreter, while complete, needs to get refactored into probably a `yield` based
#   coroutine structure wherein individual operations explicitly `yield` to an outer state
#   management loop which effectively trampolines single statements together with state management
#   logic.
#
#   The outer interpreter needs to be able to check the "step budget" and decide if it's time for
#   the program to suspend.
#
#   Individual steps (workflow calls/function calls) may also cause the program to suspend.
#
#   Suspending requires signaling the top level loop, and the top level loop needs both the
#   namespace tree and some sort of cursor or address into the AST under interpretation
#   representing where to resume. The logical equivalent of a program counter, but a tree path.
class ModuleInterpreter(StrictNodeVisitor):
    """An interpreter specific to a single module."""

    # Operator dispatch tables, built once rather than on every visit.
    _CMPOP_MAP = {
        ast.Eq: lambda x, y: x == y,
        ast.NotEq: lambda x, y: x != y,
        ast.Lt: lambda x, y: x < y,
        ast.LtE: lambda x, y: x <= y,
        ast.Gt: lambda x, y: x > y,
        ast.GtE: lambda x, y: x >= y,
        ast.Is: lambda x, y: x is y,
        ast.IsNot: lambda x, y: x is not y,
        ast.In: lambda x, y: x in y,
        ast.NotIn: lambda x, y: x not in y,
    }

    _BINOP_MAP = {
        ast.Add: lambda x, y: x + y,
        ast.Sub: lambda x, y: x - y,
        ast.Mult: lambda x, y: x * y,
        ast.MatMult: lambda x, y: x @ y,
        ast.Div: lambda x, y: x / y,
        ast.FloorDiv: lambda x, y: x // y,
        ast.Mod: lambda x, y: x % y,
        ast.Pow: lambda x, y: x**y,
        ast.LShift: lambda x, y: x << y,
        ast.RShift: lambda x, y: x >> y,
        ast.BitAnd: lambda x, y: x & y,
        ast.BitOr: lambda x, y: x | y,
        ast.BitXor: lambda x, y: x ^ y,
    }

    _UNOP_MAP = {
        ast.UAdd: lambda x: +x,
        ast.USub: lambda x: -x,
        ast.Invert: lambda x: ~x,
        ast.Not: lambda x: not x,
    }

    def __init__(self, system, fname, node):
        self.system = system
        self.fname = fname
        self.module_ns: ModuleNS = ModuleNS(node)
        self.ns: ANamespace = self.module_ns

        # Call stack (in terms of function AST nodes).
        self.call_stack = []

        # To implement a "store" operation, we need two arguments: location and value to store.
        # The operation itself is handled by a node visitor (e.g. visit_Name), and location is
        # represented by AST node, but there's no support to pass additional arguments to a visitor
        # (likely, because it would be a burden to explicitly pass such additional arguments thru
        # the chain of visitors). So instead, we store this value as a field. As interpretation
        # happens sequentially, there's no risk that it will be overwritten "concurrently".
        self.store_val = None

        # Current active exception, for bare "raise", which doesn't work across function boundaries
        # (and that's how we have it - exception would be caught in visit_Try, while re-raising
        # would happen in visit_Raise).
        self.cur_exc = []

    def push_ns(self, new_ns):
        """Make ``new_ns`` the current namespace, parented to the previous one."""
        new_ns.parent = self.ns
        self.ns = new_ns

    def pop_ns(self):
        """Make the parent of the current namespace current again."""
        assert self.ns is not None
        self.ns = self.ns.parent

    def stmt_list_visit(self, lst):
        """Visit each node of ``lst`` in order; return the last visit's result (None if empty)."""
        res = None
        for s in lst:
            res = self.visit(s)
        return res

    def wrap_decorators(self, obj, node):
        """Apply the decorators listed on ``node`` to ``obj`` and return the result."""
        # Decorator expressions are evaluated top to bottom, then applied
        # bottom to top, as for a natively executed definition.
        decorators = [self.visit(deco_n) for deco_n in node.decorator_list]
        for deco in reversed(decorators):
            obj = deco(obj)
        return obj

    def visit(self, node):
        val = super(StrictNodeVisitor, self).visit(node)
        return val

    def visit_Module(self, node):
        self.stmt_list_visit(node.body)

    def visit_Expression(self, node):
        return self.visit(node.body)

    def visit_ClassDef(self, node):
        """Execute a class body in a fresh ClassNS and bind the resulting class."""
        # Bases and keywords belong to the enclosing scope and are evaluated
        # before the class body runs.
        bases = tuple([self.visit(b) for b in node.bases])
        kwds = {}
        for kw in node.keywords:
            if kw.arg is None:
                kwds.update(self.visit(kw.value))
            else:
                kwds[kw.arg] = self.visit(kw.value)
        metaclass = kwds.pop("metaclass", type)

        ns: ClassNS = ClassNS(node)
        self.push_ns(ns)
        try:
            self.stmt_list_visit(node.body)
        finally:
            self.pop_ns()

        cls = metaclass(node.name, bases, ns.d, **kwds)
        cls = self.wrap_decorators(cls, node)
        self.ns[node.name] = cls
        # Store reference to class object in the namespace object
        ns.cls = cls

    def visit_Lambda(self, node):
        node.name = ""
        return self.prepare_func(node)

    def visit_FunctionDef(self, node):
        # Defaults are evaluated at function definition time, so we need to do that now.
        func = self.prepare_func(node)
        func = self.wrap_decorators(func, node)
        self.ns[node.name] = func

    def prepare_func(self, node):
        """Prepare function AST node for future interpretation: pre-calculate and cache useful
        information, etc."""
        func = InterpFuncWrap(node, self)
        args = node.args
        # Positional defaults line up with the tail of posonlyargs + args.
        positional = list(getattr(args, "posonlyargs", ())) + list(args.args)
        num_required = len(positional) - len(args.defaults)
        # Names that may be passed by keyword; positional-only names are excluded.
        all_args = {arg_name(a) for a in args.args}
        d = {}
        for i, a in enumerate(positional):
            if i >= num_required:
                d[arg_name(a)] = self.visit(args.defaults[i - num_required])
        for a, v in zip(getattr(args, "kwonlyargs", ()), kwarg_defaults(args)):
            all_args.add(arg_name(a))
            if v is not None:
                d[arg_name(a)] = self.visit(v)
        # We can store cached argument names of a function in its node - it's static.
        node.args.all_args = all_args
        # We can't store the values of default arguments - they're dynamic, may depend on the
        # lexical scope.
        func.defaults_dict = d

        return InterpFunc(func)

    def prepare_func_args(self, node, interp_func, *args, **kwargs):
        """Bind call arguments into the current namespace.

        Raises TypeError for too many positional arguments, duplicate or unexpected
        keyword arguments, and missing required arguments.
        """
        argspec = node.args
        positional = list(getattr(argspec, "posonlyargs", ())) + list(argspec.args)

        # If there's vararg, either offload surplus of args to it, or init it to empty tuple (all
        # in one statement). If no vararg, error on too many args.
        if argspec.vararg:
            self.ns[arg_name(argspec.vararg)] = args[len(positional) :]
        elif len(args) > len(positional):
            raise TypeError(
                "{}() takes {} positional arguments but {} were given".format(
                    node.name, len(positional), len(args)
                )
            )

        for param, value in zip(positional, args):
            self.ns[arg_name(param)] = value

        # Process incoming keyword arguments, putting them in namespace if actual arg exists by
        # that name, or offload to function's kwarg if any. All make needed checks and error out.
        func_kwarg = {}
        for k, v in kwargs.items():
            if k in argspec.all_args:
                if k in self.ns:
                    raise TypeError(
                        "{}() got multiple values for argument '{}'".format(
                            node.name, k
                        )
                    )
                self.ns[k] = v
            elif argspec.kwarg:
                func_kwarg[k] = v
            else:
                raise TypeError(
                    "{}() got an unexpected keyword argument '{}'".format(node.name, k)
                )
        if argspec.kwarg:
            self.ns[arg_name(argspec.kwarg)] = func_kwarg

        # Finally, overlay default values for arguments not yet initialized. We need to do this
        # last for "multiple values for the same arg" check to work.
        for k, v in interp_func.defaults_dict.items():
            if k not in self.ns:
                self.ns[k] = v

        # And now go thru and check for any missing arguments.
        for a in positional:
            if arg_name(a) not in self.ns:
                raise TypeError(
                    "{}() missing required positional argument: '{}'".format(
                        node.name, arg_name(a)
                    )
                )
        for a in getattr(argspec, "kwonlyargs", ()):
            if a.arg not in self.ns:
                raise TypeError(
                    "{}() missing required keyword-only argument: '{}'".format(
                        node.name, arg_name(a)
                    )
                )

    def call_func(self, node, interp_func, *args, **kwargs):
        """Run the function ``node`` with the given arguments and return its result."""
        self.call_stack.append(node)
        # We need to switch from dynamic execution scope to lexical scope in which function was
        # defined (then switch back on return).
        dyna_scope = self.ns
        self.ns = interp_func.lexical_scope
        self.push_ns(FunctionNS(node))
        try:
            self.prepare_func_args(node, interp_func, *args, **kwargs)
            # A lambda body is a single expression node rather than a statement list.
            if isinstance(node.body, list):
                res = self.stmt_list_visit(node.body)
            else:
                res = self.visit(node.body)
        except TargetReturn as e:
            res = e.args[0]
        finally:
            self.pop_ns()
            self.ns = dyna_scope
            self.call_stack.pop()
        return res

    def visit_Return(self, node):
        if not isinstance(self.ns, FunctionNS):
            raise SyntaxError("'return' outside function")
        raise TargetReturn(node.value and self.visit(node.value))

    def visit_With(self, node):
        self._visit_with_items(node.items, node.body)

    def _visit_with_items(self, items, body):
        """Enter ``items`` left to right, nesting each inside the previous, then run ``body``."""
        if not items:
            self.stmt_list_visit(body)
            return
        item = items[0]
        ctx = self.visit(item.context_expr)
        with InterpWith(ctx) as val:
            if item.optional_vars is not None:
                self.handle_assign(item.optional_vars, val)
            self._visit_with_items(items[1:], body)

    def visit_Try(self, node):
        try:
            self.stmt_list_visit(node.body)
        except TargetNonlocalFlow:
            # Meta-level control flow is never visible to target handlers.
            raise
        except Exception as e:
            self.cur_exc.append(e)
            try:
                for h in getattr(node, "handlers", ()):
                    if h.type is None or isinstance(e, self.visit(h.type)):
                        if h.name:
                            self.ns[h.name] = e
                        try:
                            self.stmt_list_visit(h.body)
                        finally:
                            # The handler name is unbound when the handler ends, also when
                            # it ends by raising or when the body already deleted the name.
                            if h.name:
                                self.ns.d.pop(h.name, None)
                        break
                else:
                    raise
            finally:
                self.cur_exc.pop()
        else:
            self.stmt_list_visit(node.orelse)
        finally:
            if getattr(node, "finalbody", None):
                self.stmt_list_visit(node.finalbody)

    def visit_TryExcept(self, node):
        # Py2k only; py3k merged all this into one node type.
        return self.visit_Try(node)

    def visit_TryFinally(self, node):
        # Py2k only; py3k merged all this into one node type.
        return self.visit_Try(node)

    def visit_For(self, node):
        iterable = self.visit(node.iter)
        for item in iterable:
            self.handle_assign(node.target, item)
            try:
                self.stmt_list_visit(node.body)
            except TargetBreak:
                break
            except TargetContinue:
                continue
        else:
            self.stmt_list_visit(node.orelse)

    def visit_While(self, node):
        while self.visit(node.test):
            try:
                self.stmt_list_visit(node.body)
            except TargetBreak:
                break
            except TargetContinue:
                continue
        else:
            self.stmt_list_visit(node.orelse)

    def visit_Break(self, node):
        raise TargetBreak

    def visit_Continue(self, node):
        raise TargetContinue

    def visit_If(self, node):
        test = self.visit(node.test)
        if test:
            self.stmt_list_visit(node.body)
        else:
            self.stmt_list_visit(node.orelse)

    def visit_Import(self, node):
        for alias in node.names:
            module = self.system.handle_import(alias.name)
            if alias.asname:
                self.ns[alias.asname] = module
                continue
            # "import a.b" binds the top-level package "a", not a name "a.b".
            top_level = alias.name.partition(".")[0]
            if top_level != alias.name:
                module = self.system.handle_import(top_level)
            self.ns[top_level] = module

    def visit_ImportFrom(self, node):
        mod = self.system.handle_import(
            node.module, None, None, [n.name for n in node.names], node.level
        )
        for alias in node.names:
            if alias.name == "*":
                # This is the special case of the wildcard import. Copy everything listed in
                # __all__, or every public name when the module defines no __all__.
                public = getattr(mod, "__all__", None)
                if public is None:
                    public = [k for k in dir(mod) if not k.startswith("_")]
                for k in public:
                    self.ns[k] = getattr(mod, k)
            else:
                self.ns[alias.asname or alias.name] = getattr(mod, alias.name)

    def visit_Raise(self, node):
        if node.exc is None:
            if not self.cur_exc:
                raise RuntimeError("No active exception to reraise")
            raise self.cur_exc[-1]
        exc = self.visit(node.exc)
        if node.cause is None:
            raise exc
        raise exc from self.visit(node.cause)

    def visit_AugAssign(self, node):
        assert isinstance(node.target.ctx, ast.Store)
        # Not functional style, oops. Node in AST has store context, but we need to read its value
        # first. To not construct a copy of the entire node with load context, we temporarily patch
        # it in-place. Note that the target is therefore visited twice (load, then store).
        save_ctx = node.target.ctx
        node.target.ctx = ast.Load()
        var_val = self.visit(node.target)
        node.target.ctx = save_ctx

        rval = self.visit(node.value)

        # As augmented assignment is statement, not operator, we can't put them all into map. We
        # could instead directly lookup special inplace methods (__iadd__ and friends) and use
        # them, with a fallback to normal binary operations, but from the point of view of this
        # interpreter, presence of such methods is an implementation detail of the object system,
        # it's not concerned with it.
        op = type(node.op)
        if op is ast.Add:
            var_val += rval
        elif op is ast.Sub:
            var_val -= rval
        elif op is ast.Mult:
            var_val *= rval
        elif op is ast.MatMult:
            var_val @= rval
        elif op is ast.Div:
            var_val /= rval
        elif op is ast.FloorDiv:
            var_val //= rval
        elif op is ast.Mod:
            var_val %= rval
        elif op is ast.Pow:
            var_val **= rval
        elif op is ast.LShift:
            var_val <<= rval
        elif op is ast.RShift:
            var_val >>= rval
        elif op is ast.BitAnd:
            var_val &= rval
        elif op is ast.BitOr:
            var_val |= rval
        elif op is ast.BitXor:
            var_val ^= rval
        else:
            raise NotImplementedError

        self.store_val = var_val
        self.visit(node.target)

    def visit_Assign(self, node):
        val = self.visit(node.value)
        for n in node.targets:
            self.handle_assign(n, val)

    def handle_assign(self, target, val):
        """Store ``val`` into ``target``, unpacking tuple/list targets recursively.

        Raises ValueError when the number of values doesn't fit an unpacking target.
        """
        if not isinstance(target, (ast.Tuple, ast.List)):
            self.store_val = val
            self.visit(target)
            return

        elts = target.elts
        it = iter(val)
        try:
            for elt_idx, t in enumerate(elts):
                if isinstance(t, ast.Starred):
                    # The starred target takes everything except what the targets after it
                    # need.
                    all_elts = list(it)
                    break_i = len(all_elts) - (len(elts) - elt_idx - 1)
                    if break_i < 0:
                        raise StopIteration
                    self.handle_assign(t.value, all_elts[:break_i])
                    it = iter(all_elts[break_i:])
                else:
                    self.handle_assign(t, next(it))
        except StopIteration:
            raise ValueError(
                "not enough values to unpack (expected {})".format(len(elts))
            ) from None

        try:
            next(it)
        except StopIteration:
            # Expected
            return
        raise ValueError("too many values to unpack (expected {})".format(len(elts)))

    def visit_Delete(self, node):
        for n in node.targets:
            self.visit(n)

    def visit_Pass(self, node):
        pass

    def visit_Assert(self, node):
        # Raise explicitly: a host-level "assert" statement would be compiled away when the
        # interpreter itself runs under "python -O".
        if self.visit(node.test):
            return
        if node.msg is None:
            raise AssertionError
        raise AssertionError(self.visit(node.msg))

    def visit_Expr(self, node):
        # Produced value is ignored
        self.visit(node.value)

    def enumerate_comps(self, iters):
        """Enumerate thru all possible values of comprehension clauses, including multiple "for"
        clauses, each optionally associated with multiple "if" clauses. Current result of the
        enumeration is stored in the namespace."""
        if not iters:
            yield
            return
        first, rest = iters[0], iters[1:]
        for el in self.visit(first.iter):
            # Targets may be tuples ("for k, v in ..."), so unpack like an assignment.
            self.handle_assign(first.target, el)
            # A clause's conditions guard everything nested inside it, so they are checked
            # before the following "for" clauses are evaluated.
            if all(self.visit(cond) for cond in first.ifs):
                yield from self.enumerate_comps(rest)

    def visit_ListComp(self, node):
        self.push_ns(FunctionNS(node))
        try:
            return [self.visit(node.elt) for _ in self.enumerate_comps(node.generators)]
        finally:
            self.pop_ns()

    def visit_SetComp(self, node):
        self.push_ns(FunctionNS(node))
        try:
            return {self.visit(node.elt) for _ in self.enumerate_comps(node.generators)}
        finally:
            self.pop_ns()

    def visit_DictComp(self, node):
        self.push_ns(FunctionNS(node))
        try:
            return {
                self.visit(node.key): self.visit(node.value)
                for _ in self.enumerate_comps(node.generators)
            }
        finally:
            self.pop_ns()

    def visit_IfExp(self, node):
        if self.visit(node.test):
            return self.visit(node.body)
        else:
            return self.visit(node.orelse)

    def visit_Call(self, node):
        func = self.visit(node.func)

        args = []
        for a in node.args:
            if isinstance(a, ast.Starred):
                args.extend(self.visit(a.value))
            else:
                args.append(self.visit(a))

        kwargs = {}
        for kw in node.keywords:
            val = self.visit(kw.value)
            # A keyword without a name is a "**mapping" argument.
            if kw.arg is None:
                kwargs.update(val)
            else:
                kwargs[kw.arg] = val

        if func is builtins.super and not args:
            if not self.ns.parent or not isinstance(self.ns.parent, ClassNS):
                raise RuntimeError("super(): no arguments")
            # As we're creating methods dynamically outside of class, super() without argument
            # won't work, as that requires __class__ cell. Creating that would be cumbersome
            # (Pycopy definitely lacks enough introspection for that), so we substitute 2 implied
            # args (which argumentless super() would take from cell and 1st arg to func). In our
            # case, we take them from prepared bookkeeping info.
            args = (self.ns.parent.cls, self.ns["self"])

        return func(*args, **kwargs)

    def visit_Compare(self, node):
        lv = self.visit(node.left)
        for op, r in zip(node.ops, node.comparators):
            rv = self.visit(r)
            if not self._CMPOP_MAP[type(op)](lv, rv):
                return False
            lv = rv
        return True

    def visit_BoolOp(self, node):
        if isinstance(node.op, ast.And):
            res = True
            for v in node.values:
                res = self.visit(v)
                if not res:
                    break
        elif isinstance(node.op, ast.Or):
            res = False
            for v in node.values:
                res = self.visit(v)
                if res:
                    break
        else:
            raise NotImplementedError
        return res

    def visit_BinOp(self, node):
        l = self.visit(node.left)
        r = self.visit(node.right)
        return self._BINOP_MAP[type(node.op)](l, r)

    def visit_UnaryOp(self, node):
        val = self.visit(node.operand)
        return self._UNOP_MAP[type(node.op)](val)

    def visit_Subscript(self, node):
        obj = self.visit(node.value)
        idx = self.visit(node.slice)
        if isinstance(node.ctx, ast.Load):
            return obj[idx]
        elif isinstance(node.ctx, ast.Store):
            obj[idx] = self.store_val
        elif isinstance(node.ctx, ast.Del):
            del obj[idx]
        else:
            raise NotImplementedError

    def visit_Index(self, node):
        return self.visit(node.value)

    def visit_Slice(self, node):
        # Any of these can be None
        lower = node.lower and self.visit(node.lower)
        upper = node.upper and self.visit(node.upper)
        step = node.step and self.visit(node.step)
        return slice_getter[lower:upper:step]

    def visit_Attribute(self, node):
        obj = self.visit(node.value)
        if isinstance(node.ctx, ast.Load):
            return getattr(obj, node.attr)
        elif isinstance(node.ctx, ast.Store):
            setattr(obj, node.attr, self.store_val)
        elif isinstance(node.ctx, ast.Del):
            delattr(obj, node.attr)
        else:
            raise NotImplementedError

    def visit_Global(self, node):
        for n in node.names:
            if n in self.ns and self.ns[n] is not GLOBAL:
                raise SyntaxError(
                    "name '{}' is assigned to before global declaration".format(n)
                )
            # Don't store GLOBAL in the top-level namespace
            if self.ns.parent:
                self.ns[n] = GLOBAL

    def visit_Nonlocal(self, node):
        if isinstance(self.ns, ModuleNS):
            raise SyntaxError("nonlocal declaration not allowed at module level")
        for n in node.names:
            self.ns[n] = NONLOCAL

    def resolve_nonlocal(self, id, ns):
        """Walk up from ``ns`` to the namespace holding the binding for nonlocal ``id``.

        Raises SyntaxError when only the module namespace (or none at all) binds it.
        """
        while ns:
            res = ns.get(id, NO_VAR)
            if res is GLOBAL:
                return self.module_ns
            if res is not NO_VAR and res is not NONLOCAL:
                if isinstance(ns, ModuleNS):
                    break
                return ns
            ns = ns.parent
        raise SyntaxError("no binding for nonlocal '{}' found".format(id))

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load):
            res = NO_VAR
            ns = self.ns
            # We always lookup in the current namespace (on the first iteration), but afterwards
            # we always skip class namespaces. Or put it another way, class code can look up in
            # its own namespace, but that's the only case when the class namespace is consulted.
            skip_classes = False
            while ns:
                if not (skip_classes and isinstance(ns, ClassNS)):
                    res = ns.get(node.id, NO_VAR)
                    if res is not NO_VAR:
                        break
                ns = ns.parent
                skip_classes = True

            if res is NONLOCAL:
                ns = self.resolve_nonlocal(node.id, ns.parent)
                return ns[node.id]
            if res is GLOBAL:
                res = self.module_ns.get(node.id, NO_VAR)
            if res is not NO_VAR:
                return res

            try:
                return getattr(builtins, node.id)
            except AttributeError:
                raise NameError("name '{}' is not defined".format(node.id)) from None
        elif isinstance(node.ctx, ast.Store):
            res = self.ns.get(node.id, NO_VAR)
            if res is GLOBAL:
                self.module_ns[node.id] = self.store_val
            elif res is NONLOCAL:
                ns = self.resolve_nonlocal(node.id, self.ns.parent)
                ns[node.id] = self.store_val
            else:
                self.ns[node.id] = self.store_val
        elif isinstance(node.ctx, ast.Del):
            res = self.ns.get(node.id, NO_VAR)
            if res is NO_VAR:
                raise NameError("name '{}' is not defined".format(node.id))
            elif res is GLOBAL:
                del self.module_ns[node.id]
            elif res is NONLOCAL:
                ns = self.resolve_nonlocal(node.id, self.ns.parent)
                del ns[node.id]
            else:
                del self.ns[node.id]
        else:
            raise NotImplementedError

    def visit_Dict(self, node):
        res = {}
        for k, v in zip(node.keys, node.values):
            if k is None:
                # A None key marks a "**mapping" entry in the dict display.
                res.update(self.visit(v))
            else:
                key = self.visit(k)
                res[key] = self.visit(v)
        return res

    def visit_Set(self, node):
        return {self.visit(e) for e in node.elts}

    def visit_List(self, node):
        return [self.visit(e) for e in node.elts]

    def visit_Tuple(self, node):
        return tuple([self.visit(e) for e in node.elts])

    def visit_Constant(self, node):
        # Current versions of the ast module emit this node for every literal
        # (numbers, strings, bytes, None/True/False, Ellipsis).
        return node.value

    def visit_NameConstant(self, node):
        return node.value

    def visit_Ellipsis(self, node):
        # In Py3k only
        return ...

    def visit_Print(self, node):
        # In Py2k only
        raise SyntaxError("Absolutely not. Use __future__.")

    def visit_Str(self, node):
        return node.s

    def visit_Bytes(self, node):
        return node.s

    def visit_Num(self, node):
        return node.n


class InterpreterSystem(object):
    """A bag of shared state."""

    def __init__(self, path=None):
        self.modules = {}
        self.path = path or sys.path

    def handle_import(self, name, globals=None, locals=None, fromlist=(), level=0):
        """Return the module called ``name``, importing it on first use.

        Lookup order: this system's cache, the host's ``sys.modules``, a ``.py`` source
        found on ``self.path`` (which is interpreted), and finally a native
        ``__import__``.
        """
        log.debug(" Attempting to import '{}'".format(name))
        if name is None:
            raise ImportError("relative imports without a module name are not supported")

        if name in self.modules:
            return self.modules[name]

        if name in sys.modules:
            log.debug(" Short-circuited from bootstrap sys.modules")
            self.modules[name] = sys.modules[name]
            return self.modules[name]

        fname = self._find_source(name)
        if fname is not None:
            interp, tree = self._parse(fname)
            interp.ns["__name__"] = name
            module = InterpModule(interp.ns)
            # Register before executing so that circular imports see the partially
            # initialised module; drop the entry again if execution fails.
            self.modules[name] = module
            try:
                interp.visit(tree)
            except BaseException:
                self.modules.pop(name, None)
                raise
            return module

        mod = __import__(name, globals, locals, fromlist, level)
        if level == 0:
            # With an empty fromlist __import__ returns the top-level package; prefer the
            # module that was actually asked for.
            mod = sys.modules.get(name, mod)
        self.modules[name] = mod
        return mod

    def _find_source(self, name):
        """Return the path of the source file for module ``name`` on ``self.path``, or None."""
        relpath = name.replace(".", os.path.sep)
        for e in self.path:
            # An empty path entry stands for the current directory.
            e = e or os.curdir
            if os.path.isdir(e):
                for ext in [
                    # ".flow",
                    ".py",
                ]:
                    f = os.path.join(e, relpath + ext)
                    log.debug(" Checking {}".format(f))
                    if os.path.exists(f):
                        return f
            elif os.path.isfile(e):
                # FIXME (arrdem 2021-05-31)
                raise RuntimeError(
                    "Import from .zip/.whl/.egg archives aren't supported yet"
                )
        return None

    def _parse(self, fname):
        """Parse ``fname`` and return ``(interpreter, tree)`` without executing anything."""
        with open(fname, encoding="utf-8") as f:
            tree = ast.parse(f.read(), fname)
        return ModuleInterpreter(self, fname, tree), tree

    def load(self, fname):
        """Interpret the file ``fname`` and return its ModuleInterpreter."""
        interp, tree = self._parse(fname)
        interp.visit(tree)
        return interp

    def execute(self, fname):
        """Interpret the file ``fname`` as the ``__main__`` module and return its interpreter."""
        interp, tree = self._parse(fname)
        interp.ns["__name__"] = "__main__"
        self.modules["__main__"] = InterpModule(interp.ns)
        interp.visit(tree)
        return interp


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    InterpreterSystem().execute(sys.argv[1])