source/projects/datalog-shell/__main__.py
2021-05-15 11:34:32 -06:00

277 lines
7.8 KiB
Python
Executable file

#!/usr/bin/env python3
__doc__ = f"""
Datalog (py)
============
An interactive datalog interpreter with commands and persistence
Commands
~~~~~~~~
.help (this message)
.all display all tuples
.quit to exit the REPL
To exit, use .quit or control-d (control-c only abandons the current input or query)
The interpreter
~~~~~~~~~~~~~~~
The interpreter reads one line at a time from stdin.
Lines are either
- definitions (ending in .),
- queries (ending in ?)
- retractions (ending in !)
A definition may contain arbitrarily many datalog tuples and rules.
edge(a, b). edge(b, c). % A pair of definitions
⇒ edge(a, b). % The REPL's response that it has been committed
⇒ edge(b, c).
A query may contain definitions, but they exist only for the duration of the query.
edge(X, Y)? % A query which will enumerate all 2-edges
⇒ edge(a, b).
⇒ edge(b, c).
edge(c, d). edge(X, Y)? % A query with a local tuple
⇒ edge(a, b).
⇒ edge(b, c).
⇒ edge(c, d).
A retraction may contain only one tuple or clause, which will be expunged.
edge(a, b)! % This tuple is in our dataset
⇒ edge(a, b) % So deletion succeeds
edge(a, b)! % This tuple is no longer in our dataset
⇒ Ø % So deletion fails
"""
import argparse
import logging
import sys
from datalog.debris import Timing
from datalog.evaluator import select
from datalog.reader import pr_str, read_command, read_dataset
from datalog.types import (
CachedDataset,
Constant,
Dataset,
LVar,
PartlyIndexedDataset,
Rule,
TableIndexedDataset,
)
from prompt_toolkit import print_formatted_text, prompt, PromptSession
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.history import FileHistory
from prompt_toolkit.styles import Style
from yaspin import Spinner, yaspin
# Style sheet handed to prompt_toolkit: maps a style class name to a style string.
STYLE = Style.from_dict(
    {
        # The empty class name styles the user's own input (default text).
        "": "",
        # ">>> " prompt marker.
        "prompt": "ansigreen",
        # "Elapsed time" report printed after each query.
        "time": "ansiyellow",
    }
)

# Four-frame ASCII spinner displayed by yaspin while a query is evaluated.
# NOTE(review): the interval is presumably in milliseconds - confirm against yaspin.
SPINNER = Spinner(frames=["|", "/", "-", "\\"], interval=200)
class InterpreterInterrupt(Exception):
    """Signal that the current prompt or evaluation should be abandoned.

    The REPL loop catches this alongside ``KeyboardInterrupt`` and simply
    shows a fresh prompt.
    """
def print_(fmt, **kwargs):
    """Print ``fmt`` through prompt_toolkit.

    ``fmt`` is wrapped in ``FormattedText`` and handed to
    ``print_formatted_text``; all keyword arguments (e.g. ``style=``) are
    forwarded to that call unchanged.
    """
    formatted = FormattedText(fmt)
    print_formatted_text(formatted, **kwargs)
def print_db(db):
    """Write every tuple of ``db``, then every rule, to stdout (one per line)."""

    def emit(entries):
        # Each entry is rendered with pr_str and printed on its own line.
        for entry in entries:
            print(pr_str(entry))

    emit(db.tuples())
    emit(db.rules())
def _dataset_class(name):
    """Return the dataset class implementing the ``--db-type`` choice ``name``.

    Raises:
        ValueError: if ``name`` is not one of the known dataset types.
    """
    classes = {
        "simple": Dataset,
        "cached": CachedDataset,
        "table": TableIndexedDataset,
        "partly": PartlyIndexedDataset,
    }
    try:
        return classes[name]
    except KeyError:
        raise ValueError(f"Unknown dataset type {name!r}") from None


def _check_definitions(val):
    """Enforce the syntax rules on a definition that the parser doesn't impose.

    FIXME (arrdem 2019-06-15): these really belong in the parser.

    Raises:
        ValueError: if a rule has free variables or a tuple contains an LVar.
            Raised explicitly rather than via ``assert`` so that the checks
            still run under ``python -O``.
    """
    for rule in val.rules():
        if rule.free_vars:
            raise ValueError(f"Rule contains free variables {rule.free_vars!r}")

    for tup in val.tuples():
        if any(isinstance(e, LVar) for e in tup):
            raise ValueError(f"Tuples cannot contain lvars - {tup!r}")


def _set_log_level(line):
    """Handle a ``.log <level>`` command line.

    The level is applied to the ``arrdem.datalog`` logger *and* to every
    handler attached to it.  Setting only a handler's level is not enough: a
    logger without its own level inherits the root logger's threshold, so
    records below that never reach the handler.
    """
    parts = line.split()
    if len(parts) < 2:
        print("Usage: .log <level>")
        return

    level_name = parts[1].upper()
    level = getattr(logging, level_name, None)
    # Only the level constants (DEBUG, INFO, ...) are ints; anything else that
    # happens to be an attribute of the logging module is not a level.
    if not isinstance(level, int) or isinstance(level, bool):
        print(f"Unknown log level {level_name}")
        return

    logger = logging.getLogger("arrdem.datalog")
    logger.setLevel(level)
    for handler in logger.handlers:
        handler.setLevel(level)


def main(args):
    """REPL entry point.

    Args:
        args: parsed command line options.  ``args.db_cls`` names the dataset
            implementation ("simple", "cached", "table" or "partly") and
            ``args.dbs`` is a list of datalog files to preload, or ``None``.

    Reads lines from an interactive prompt until ``.quit`` or end-of-file.
    Each line is either a dot-command (``.all``, ``.dbg``, ``.log``, ``.help``)
    or a datalog definition (``.``), query (``?``) or retraction (``!``).
    """
    db_cls = _dataset_class(args.db_cls)
    print(f"Using dataset type {db_cls}")

    session = PromptSession(history=FileHistory(".datalog.history"))
    db = db_cls([], [])

    # Preloading is best-effort: a file that fails to load is reported and skipped.
    for db_file in args.dbs or []:
        try:
            with open(db_file, "r") as f:
                db = db.merge(read_dataset(f.read()))
            print(f"Loaded {db_file} ...")
        except Exception as e:
            print(f"Internal error - {e}")
            print(f"Unable to load db {db_file}, skipping")

    while True:
        try:
            line = session.prompt([("class:prompt", ">>> ")], style=STYLE)
        except (InterpreterInterrupt, KeyboardInterrupt):
            # Interrupting the prompt just discards the current input.
            continue
        except EOFError:
            break

        if line == ".all":
            op = ".all"
        elif line == ".dbg":
            op = ".dbg"
        elif line == ".quit":
            break
        elif line in {".help", "help", "?", "??", "???"}:
            print(__doc__)
            continue
        elif line.split(" ")[0] == ".log":
            op = ".log"
        else:
            try:
                op, val = read_command(line)
            except Exception:
                print("Got an unknown command or syntax error, can't tell which")
                continue

        # .all dumps the whole database
        if op == ".all":
            print_db(db)

        # .dbg drops to a debugger shell so you can poke at the instance objects (database)
        elif op == ".dbg":
            import pdb

            pdb.set_trace()

        # .log sets the log level
        elif op == ".log":
            _set_log_level(line)

        # Definitions merge into the DB
        elif op == ".":
            try:
                _check_definitions(val)
            except Exception as e:
                print(f"Error: {e}")
                continue

            db = db.merge(val)
            print_db(val)

        # Queries execute - note that rules as queries have to be temporarily merged.
        elif op == "?":
            # In order to support ad-hoc rules (joins), we have to generate a transient "query"
            # database by bolting the rule on as an overlay to the existing database.
            #
            # `qdb` is the database actually queried, and `val` is rebound to the temporary
            # rule's pattern.  Using a separate db keeps the ephemeral rule from persisting
            # unless the user later `.` defines it.
            #
            # Unfortunately doing this merge does nuke caches.
            qdb = db
            if isinstance(val, Rule):
                qdb = db.merge(db_cls([], [val]))
                val = val.pattern

            with yaspin(SPINNER):
                with Timing() as t:
                    try:
                        results = list(select(qdb, val))
                    except KeyboardInterrupt:
                        print(f"Evaluation aborted after {t}")
                        continue

            # Sort in place so output order is stable.  Keying on the matched tuples
            # avoids ever comparing the bindings, which may not be orderable.
            results.sort(key=lambda row: row[0])

            for _results, _bindings in results:
                _result = _results[0]  # select only selects one tuple at a time
                print(f"{pr_str(_result)}")

            # So we can report empty sets explicitly.
            if not results:
                print("⇒ Ø")

            print_([("class:time", f"Elapsed time - {t}")], style=STYLE)

        # Retractions try to delete, but may fail.
        elif op == "!":
            if val in db.tuples() or val in [r.pattern for r in db.rules()]:
                db = db_cls(
                    [u for u in db.tuples() if u != val],
                    [r for r in db.rules() if r.pattern != val],
                )
                print(f"{pr_str(val)}")
            else:
                print("⇒ Ø")
# Command line interface for the REPL.
parser = argparse.ArgumentParser()

# --db-type picks the dataset implementation; stored on the namespace as `db_cls`.
parser.add_argument(
    "--db-type",
    dest="db_cls",
    default="partly",
    choices=["simple", "cached", "table", "partly"],
    help="Choose which DB to use (default partly)",
)

# --load-db may be repeated; the paths accumulate into the `dbs` list
# (which stays None when the flag is never given).
parser.add_argument(
    "--load-db",
    action="append",
    dest="dbs",
    help="Datalog files to load first.",
)
if __name__ == "__main__":
    # With no explicit argument list, argparse reads sys.argv[1:].
    args = parser.parse_args()

    # Attach a stream handler with a timestamped format to the
    # "arrdem.datalog" logger; the handler itself passes INFO and above.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger = logging.getLogger("arrdem.datalog")
    logger.addHandler(ch)

    main(args)