Reid 'arrdem' McKenzie 2021-05-15 11:34:32 -06:00
parent debc7996ee
commit bbae5ef63f
34 changed files with 956 additions and 886 deletions

View file

@@ -25,7 +25,6 @@ setup(
"calf-read = calf.reader:main",
"calf-analyze = calf.analyzer:main",
"calf-compile = calf.compiler:main",
# Client/server stuff
"calf-client = calf.client:main",
"calf-server = calf.server:main",

View file

@@ -7,7 +7,6 @@ from curses.textpad import Textbox, rectangle
def curse_repl(handle_buffer):
def handle(buff, count):
try:
return list(handle_buffer(buff, count)), None
@@ -24,22 +23,25 @@ def curse_repl(handle_buffer):
maxy, maxx = stdscr.getmaxyx()
stdscr.clear()
stdscr.addstr(0, 0, "Enter example: (hit Ctrl-G to execute, Ctrl-C to exit)", curses.A_BOLD)
editwin = curses.newwin(5, maxx - 4,
2, 2)
rectangle(stdscr,
1, 1,
1 + 5 + 1, maxx - 2)
stdscr.addstr(
0,
0,
"Enter example: (hit Ctrl-G to execute, Ctrl-C to exit)",
curses.A_BOLD,
)
editwin = curses.newwin(5, maxx - 4, 2, 2)
rectangle(stdscr, 1, 1, 1 + 5 + 1, maxx - 2)
# Printing is part of the prompt
cur = 8
def putstr(str, x=0, attr=0):
# ya rly. I know exactly what I'm doing here
nonlocal cur
# This is how we handle going off the bottom of the screen
if cur < maxy:
stdscr.addstr(cur, x, str, attr)
cur += (len(str.split("\n")) or 1)
cur += len(str.split("\n")) or 1
for ex, buff, vals, err in reversed(examples):
putstr(f"Example {ex}:", attr=curses.A_BOLD)
@@ -58,7 +60,7 @@ def curse_repl(handle_buffer):
elif vals:
putstr(" Values:")
for x, t in zip(range(1, 1<<64), vals):
for x, t in zip(range(1, 1 << 64), vals):
putstr(f" {x:>3}) " + repr(t))
putstr("")

View file

@@ -28,34 +28,82 @@ COMMENT_PATTERN = r";(([^\n\r]*)(\n\r?)?)"
TOKENS = [
# Paren (normal) lists
(r"\(", "PAREN_LEFT",),
(r"\)", "PAREN_RIGHT",),
(
r"\(",
"PAREN_LEFT",
),
(
r"\)",
"PAREN_RIGHT",
),
# Bracket lists
(r"\[", "BRACKET_LEFT",),
(r"\]", "BRACKET_RIGHT",),
(
r"\[",
"BRACKET_LEFT",
),
(
r"\]",
"BRACKET_RIGHT",
),
# Brace lists (maps)
(r"\{", "BRACE_LEFT",),
(r"\}", "BRACE_RIGHT",),
(r"\^", "META",),
(r"'", "SINGLE_QUOTE",),
(STRING_PATTERN, "STRING",),
(r"#", "MACRO_DISPATCH",),
(
r"\{",
"BRACE_LEFT",
),
(
r"\}",
"BRACE_RIGHT",
),
(
r"\^",
"META",
),
(
r"'",
"SINGLE_QUOTE",
),
(
STRING_PATTERN,
"STRING",
),
(
r"#",
"MACRO_DISPATCH",
),
# Symbols
(SYMBOL_PATTERN, "SYMBOL",),
(
SYMBOL_PATTERN,
"SYMBOL",
),
# Numbers
(SIMPLE_INTEGER, "INTEGER",),
(FLOAT_PATTERN, "FLOAT",),
(
SIMPLE_INTEGER,
"INTEGER",
),
(
FLOAT_PATTERN,
"FLOAT",
),
# Keywords
#
# Note: this is a dirty f'n hack in that in order for keywords to work, ":"
# has to be defined to be a valid keyword.
(r":" + SYMBOL_PATTERN + "?", "KEYWORD",),
(
r":" + SYMBOL_PATTERN + "?",
"KEYWORD",
),
# Whitespace
#
# Note that the whitespace token will contain at most one newline
(r"(\n\r?|[,\t ]*)", "WHITESPACE",),
(
r"(\n\r?|[,\t ]*)",
"WHITESPACE",
),
# Comment
(COMMENT_PATTERN, "COMMENT",),
(
COMMENT_PATTERN,
"COMMENT",
),
# Strings
(r'"(?P<body>(?:[^\"]|\.)*)"', "STRING"),
]
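For context on how a table like TOKENS is consumed: a regex lexer of this shape typically tries each pattern in order at the current offset and emits the first match, which is why the note above needs a bare ":" to be a valid KEYWORD. A minimal first-match-wins sketch, using a simplified two-entry table rather than calf's actual lexer:

import re

TABLE = [
    (r"\(", "PAREN_LEFT"),
    (r":[a-z]*", "KEYWORD"),  # like the ":" hack above, a bare colon is a valid KEYWORD
]

def lex(buff):
    pos = 0
    while pos < len(buff):
        for pattern, token_type in TABLE:
            m = re.compile(pattern).match(buff, pos)
            if m:
                yield token_type, m.group(0)
                pos = m.end()  # .end() is absolute even when matching from an offset
                break
        else:
            raise ValueError(f"no token matches at offset {pos}")

assert list(lex("(:foo")) == [("PAREN_LEFT", "("), ("KEYWORD", ":foo")]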

View file

@@ -8,7 +8,6 @@ parsing, linting or other use.
import io
import re
import sys
from calf.token import CalfToken
from calf.io.reader import PeekPosReader

View file

@@ -12,48 +12,47 @@ from collections import namedtuple
class CalfLoaderConfig(namedtuple("CalfLoaderConfig", ["paths"])):
"""
"""
""""""
class CalfDelayedPackage(
namedtuple("CalfDelayedPackage", ["name", "version", "metadata", "path"])
):
"""
This structure represents the delay of loading a package.
This structure represents the delay of loading a package.
Rather than eagerly analyze packages, it may be profitable to use lazy loading / lazy resolution
of symbols. It may also be possible to cache analyzing some packages.
"""
Rather than eagerly analyze packages, it may be profitable to use lazy loading / lazy resolution
of symbols. It may also be possible to cache analyzing some packages.
"""
class CalfPackage(
namedtuple("CalfPackage", ["name", "version", "metadata", "modules"])
):
"""
This structure represents the result of forcing the load of a package, and is the product of
either loading a package directly, or a package becoming a direct dependency and being forced.
"""
This structure represents the result of forcing the load of a package, and is the product of
either loading a package directly, or a package becoming a direct dependency and being forced.
"""
def parse_package_requirement(config, env, requirement):
"""
:param config:
:param env:
:param requirement:
:returns:
:param config:
:param env:
:param requirement:
:returns:
"""
"""
def analyze_package(config, env, package):
"""
:param config:
:param env:
:param module:
:returns:
:param config:
:param env:
:param module:
:returns:
Given a loader configuration and an environment to load into, analyzes the requested package,
returning an updated environment.
"""
Given a loader configuration and an environment to load into, analyzes the requested package,
returning an updated environment.
"""

View file

@@ -2,11 +2,8 @@
The Calf parser.
"""
from collections import namedtuple
from itertools import tee
import logging
import sys
from typing import NamedTuple, Callable
from calf.lexer import CalfLexer, lex_buffer, lex_file
from calf.grammar import MATCHING, WHITESPACE_TYPES
@@ -45,17 +42,18 @@ def mk_dict(contents, open=None, close=None):
close.start_position,
)
def mk_str(token):
buff = token.value
if buff.startswith('"""') and not buff.endswith('"""'):
raise ValueError('Unterminated triple quote string')
raise ValueError("Unterminated triple quote string")
elif buff.startswith('"') and not buff.endswith('"'):
raise ValueError('Unterminated quote string')
raise ValueError("Unterminated quote string")
elif not buff.startswith('"') or buff == '"' or buff == '"""':
raise ValueError('Illegal string')
raise ValueError("Illegal string")
if buff.startswith('"""'):
buff = buff[3:-3]
@@ -114,15 +112,17 @@ class CalfMissingCloseParseError(CalfParseError):
def __init__(self, expected_close_token, open_token):
super(CalfMissingCloseParseError, self).__init__(
f"expected {expected_close_token} starting from {open_token}, got end of file.",
open_token
open_token,
)
self.expected_close_token = expected_close_token
def parse_stream(stream,
discard_whitespace: bool = True,
discard_comments: bool = True,
stack: list = None):
def parse_stream(
stream,
discard_whitespace: bool = True,
discard_comments: bool = True,
stack: list = None,
):
"""Parses a token stream, producing a lazy sequence of all read top level forms.
If `discard_whitespace` is truthy, then no WHITESPACE tokens will be emitted
@@ -134,11 +134,10 @@ def parse_stream(stream,
stack = stack or []
def recur(_stack = None):
yield from parse_stream(stream,
discard_whitespace,
discard_comments,
_stack or stack)
def recur(_stack=None):
yield from parse_stream(
stream, discard_whitespace, discard_comments, _stack or stack
)
for token in stream:
# Whitespace discarding
@@ -205,7 +204,9 @@ def parse_stream(stream,
# Case of maybe matching something else, but definitely being wrong
else:
matching = next(reversed([t[1] for t in stack if t[0] == token.type]), None)
matching = next(
reversed([t[1] for t in stack if t[0] == token.type]), None
)
raise CalfUnexpectedCloseParseError(token, matching)
# Atoms
@@ -216,18 +217,14 @@ def parse_stream(stream,
yield token
def parse_buffer(buffer,
discard_whitespace=True,
discard_comments=True):
def parse_buffer(buffer, discard_whitespace=True, discard_comments=True):
"""
Parses a buffer, producing a lazy sequence of all parsed top level forms.
Propagates all errors.
"""
yield from parse_stream(lex_buffer(buffer),
discard_whitespace,
discard_comments)
yield from parse_stream(lex_buffer(buffer), discard_whitespace, discard_comments)
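As a hedged usage sketch (hand-written, assuming the module layout shown in this diff), the discard flags decide whether WHITESPACE and COMMENT tokens survive into the stream:

from calf.parser import parse_buffer

for form in parse_buffer("(+ 1 2) ; a comment"):
    print(form.type)  # prints LIST once; whitespace and the comment are discarded by default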
def parse_file(file):

View file

@@ -13,6 +13,7 @@ from calf.parser import parse_stream
from calf.token import *
from calf.types import *
class CalfReader(object):
def handle_keyword(self, t: CalfToken) -> Any:
"""Convert a token to an Object value for a symbol.
@@ -79,8 +80,7 @@ class CalfReader(object):
return Vector.of(self.read(t.value))
elif isinstance(t, CalfDictToken):
return Map.of([(self.read1(k), self.read1(v))
for k, v in t.items()])
return Map.of([(self.read1(k), self.read1(v)) for k, v in t.items()])
# Magical pairwise stuff
elif isinstance(t, CalfQuoteToken):
@@ -119,28 +119,21 @@ class CalfReader(object):
yield self.read1(t)
def read_stream(stream,
reader: CalfReader = None):
"""Read from a stream of parsed tokens.
"""
def read_stream(stream, reader: CalfReader = None):
"""Read from a stream of parsed tokens."""
reader = reader or CalfReader()
yield from reader.read(stream)
def read_buffer(buffer):
"""Read from a buffer, producing a lazy sequence of all top level forms.
"""
"""Read from a buffer, producing a lazy sequence of all top level forms."""
yield from read_stream(parse_stream(lex_buffer(buffer)))
def read_file(file):
"""Read from a file, producing a lazy sequence of all top level forms.
"""
"""Read from a file, producing a lazy sequence of all top level forms."""
yield from read_stream(parse_stream(lex_file(file)))
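A similarly hedged sketch of the reader entry points, which layer value construction on top of the parser:

from calf.reader import read_buffer

for value in read_buffer("[:foo 1 2.0]"):
    print(value)  # one Vector value, assembled by CalfReader's handle_* hooks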
@@ -151,6 +144,8 @@ def main():
from calf.cursedrepl import curse_repl
def handle_buffer(buff, count):
return list(read_stream(parse_stream(lex_buffer(buff, source=f"<Example {count}>"))))
return list(
read_stream(parse_stream(lex_buffer(buff, source=f"<Example {count}>")))
)
curse_repl(handle_buffer)

View file

@@ -8,23 +8,24 @@ from calf import grammar as cg
from conftest import parametrize
@parametrize('ex', [
# Proper strings
'""',
'"foo bar"',
'"foo\n bar\n\r qux"',
'"foo\\"bar"',
'""""""',
'"""foo bar baz"""',
'"""foo "" "" "" bar baz"""',
# Unterminated string cases
'"',
'"f',
'"foo bar',
'"foo\\" bar',
'"""foo bar baz',
])
@parametrize(
"ex",
[
# Proper strings
'""',
'"foo bar"',
'"foo\n bar\n\r qux"',
'"foo\\"bar"',
'""""""',
'"""foo bar baz"""',
'"""foo "" "" "" bar baz"""',
# Unterminated string cases
'"',
'"f',
'"foo bar',
'"foo\\" bar',
'"""foo bar baz',
],
)
def test_match_string(ex):
assert re.fullmatch(cg.STRING_PATTERN, ex)

View file

@@ -20,23 +20,62 @@ def lex_single_token(buffer):
@parametrize(
"text,token_type",
[
("(", "PAREN_LEFT",),
(")", "PAREN_RIGHT",),
("[", "BRACKET_LEFT",),
("]", "BRACKET_RIGHT",),
("{", "BRACE_LEFT",),
("}", "BRACE_RIGHT",),
("^", "META",),
("#", "MACRO_DISPATCH",),
(
"(",
"PAREN_LEFT",
),
(
")",
"PAREN_RIGHT",
),
(
"[",
"BRACKET_LEFT",
),
(
"]",
"BRACKET_RIGHT",
),
(
"{",
"BRACE_LEFT",
),
(
"}",
"BRACE_RIGHT",
),
(
"^",
"META",
),
(
"#",
"MACRO_DISPATCH",
),
("'", "SINGLE_QUOTE"),
("foo", "SYMBOL",),
(
"foo",
"SYMBOL",
),
("foo/bar", "SYMBOL"),
(":foo", "KEYWORD",),
(":foo/bar", "KEYWORD",),
(" ,,\t ,, \t", "WHITESPACE",),
(
":foo",
"KEYWORD",
),
(
":foo/bar",
"KEYWORD",
),
(
" ,,\t ,, \t",
"WHITESPACE",
),
("\n\r", "WHITESPACE"),
("\n", "WHITESPACE"),
(" , ", "WHITESPACE",),
(
" , ",
"WHITESPACE",
),
("; this is a sample comment\n", "COMMENT"),
('"foo"', "STRING"),
('"foo bar baz"', "STRING"),

View file

@@ -8,12 +8,15 @@ from conftest import parametrize
import pytest
@parametrize("text", [
'"',
'"foo bar',
'"""foo bar',
'"""foo bar"',
])
@parametrize(
"text",
[
'"',
'"foo bar',
'"""foo bar',
'"""foo bar"',
],
)
def test_bad_strings_raise(text):
"""Tests asserting we won't let obviously bad strings fly."""
# FIXME (arrdem 2021-03-13):
@@ -22,81 +25,89 @@ def test_bad_strings_raise(text):
next(cp.parse_buffer(text))
@parametrize("text", [
"[1.0",
"(1.0",
"{1.0",
])
@parametrize(
"text",
[
"[1.0",
"(1.0",
"{1.0",
],
)
def test_unterminated_raises(text):
"""Tests asserting that we don't let unterminated collections parse."""
with pytest.raises(cp.CalfMissingCloseParseError):
next(cp.parse_buffer(text))
@parametrize("text", [
"[{]",
"[(]",
"({)",
"([)",
"{(}",
"{[}",
])
@parametrize(
"text",
[
"[{]",
"[(]",
"({)",
"([)",
"{(}",
"{[}",
],
)
def test_unbalanced_raises(text):
"""Tests asserting that we don't let missmatched collections parse."""
with pytest.raises(cp.CalfUnexpectedCloseParseError):
next(cp.parse_buffer(text))
@parametrize("buff, value", [
('"foo"', "foo"),
('"foo\tbar"', "foo\tbar"),
('"foo\n\rbar"', "foo\n\rbar"),
('"foo\\"bar\\""', "foo\"bar\""),
('"""foo"""', 'foo'),
('"""foo"bar"baz"""', 'foo"bar"baz'),
])
@parametrize(
"buff, value",
[
('"foo"', "foo"),
('"foo\tbar"', "foo\tbar"),
('"foo\n\rbar"', "foo\n\rbar"),
('"foo\\"bar\\""', 'foo"bar"'),
('"""foo"""', "foo"),
('"""foo"bar"baz"""', 'foo"bar"baz'),
],
)
def test_strings_round_trip(buff, value):
assert next(cp.parse_buffer(buff)) == value
@parametrize('text, element_types', [
# Integers
("(1)", ["INTEGER"]),
("( 1 )", ["INTEGER"]),
("(,1,)", ["INTEGER"]),
("(1\n)", ["INTEGER"]),
("(\n1\n)", ["INTEGER"]),
("(1, 2, 3, 4)", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]),
# Floats
("(1.0)", ["FLOAT"]),
("(1.0e0)", ["FLOAT"]),
("(1e0)", ["FLOAT"]),
("(1e0)", ["FLOAT"]),
# Symbols
("(foo)", ["SYMBOL"]),
("(+)", ["SYMBOL"]),
("(-)", ["SYMBOL"]),
("(*)", ["SYMBOL"]),
("(foo-bar)", ["SYMBOL"]),
("(+foo-bar+)", ["SYMBOL"]),
("(+foo-bar+)", ["SYMBOL"]),
("( foo bar )", ["SYMBOL", "SYMBOL"]),
# Keywords
("(:foo)", ["KEYWORD"]),
("( :foo )", ["KEYWORD"]),
("(\n:foo\n)", ["KEYWORD"]),
("(,:foo,)", ["KEYWORD"]),
("(:foo :bar)", ["KEYWORD", "KEYWORD"]),
("(:foo :bar 1)", ["KEYWORD", "KEYWORD", "INTEGER"]),
# Strings
('("foo", "bar", "baz")', ["STRING", "STRING", "STRING"]),
# Lists
('([] [] ())', ["SQLIST", "SQLIST", "LIST"]),
])
@parametrize(
"text, element_types",
[
# Integers
("(1)", ["INTEGER"]),
("( 1 )", ["INTEGER"]),
("(,1,)", ["INTEGER"]),
("(1\n)", ["INTEGER"]),
("(\n1\n)", ["INTEGER"]),
("(1, 2, 3, 4)", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]),
# Floats
("(1.0)", ["FLOAT"]),
("(1.0e0)", ["FLOAT"]),
("(1e0)", ["FLOAT"]),
("(1e0)", ["FLOAT"]),
# Symbols
("(foo)", ["SYMBOL"]),
("(+)", ["SYMBOL"]),
("(-)", ["SYMBOL"]),
("(*)", ["SYMBOL"]),
("(foo-bar)", ["SYMBOL"]),
("(+foo-bar+)", ["SYMBOL"]),
("(+foo-bar+)", ["SYMBOL"]),
("( foo bar )", ["SYMBOL", "SYMBOL"]),
# Keywords
("(:foo)", ["KEYWORD"]),
("( :foo )", ["KEYWORD"]),
("(\n:foo\n)", ["KEYWORD"]),
("(,:foo,)", ["KEYWORD"]),
("(:foo :bar)", ["KEYWORD", "KEYWORD"]),
("(:foo :bar 1)", ["KEYWORD", "KEYWORD", "INTEGER"]),
# Strings
('("foo", "bar", "baz")', ["STRING", "STRING", "STRING"]),
# Lists
("([] [] ())", ["SQLIST", "SQLIST", "LIST"]),
],
)
def test_parse_list(text, element_types):
"""Test we can parse various lists of contents."""
l_t = next(cp.parse_buffer(text, discard_whitespace=True))
@@ -104,45 +115,43 @@ def test_parse_list(text, element_types):
assert [t.type for t in l_t] == element_types
@parametrize('text, element_types', [
# Integers
("[1]", ["INTEGER"]),
("[ 1 ]", ["INTEGER"]),
("[,1,]", ["INTEGER"]),
("[1\n]", ["INTEGER"]),
("[\n1\n]", ["INTEGER"]),
("[1, 2, 3, 4]", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]),
# Floats
("[1.0]", ["FLOAT"]),
("[1.0e0]", ["FLOAT"]),
("[1e0]", ["FLOAT"]),
("[1e0]", ["FLOAT"]),
# Symbols
("[foo]", ["SYMBOL"]),
("[+]", ["SYMBOL"]),
("[-]", ["SYMBOL"]),
("[*]", ["SYMBOL"]),
("[foo-bar]", ["SYMBOL"]),
("[+foo-bar+]", ["SYMBOL"]),
("[+foo-bar+]", ["SYMBOL"]),
("[ foo bar ]", ["SYMBOL", "SYMBOL"]),
# Keywords
("[:foo]", ["KEYWORD"]),
("[ :foo ]", ["KEYWORD"]),
("[\n:foo\n]", ["KEYWORD"]),
("[,:foo,]", ["KEYWORD"]),
("[:foo :bar]", ["KEYWORD", "KEYWORD"]),
("[:foo :bar 1]", ["KEYWORD", "KEYWORD", "INTEGER"]),
# Strings
('["foo", "bar", "baz"]', ["STRING", "STRING", "STRING"]),
# Lists
('[[] [] ()]', ["SQLIST", "SQLIST", "LIST"]),
])
@parametrize(
"text, element_types",
[
# Integers
("[1]", ["INTEGER"]),
("[ 1 ]", ["INTEGER"]),
("[,1,]", ["INTEGER"]),
("[1\n]", ["INTEGER"]),
("[\n1\n]", ["INTEGER"]),
("[1, 2, 3, 4]", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]),
# Floats
("[1.0]", ["FLOAT"]),
("[1.0e0]", ["FLOAT"]),
("[1e0]", ["FLOAT"]),
("[1e0]", ["FLOAT"]),
# Symbols
("[foo]", ["SYMBOL"]),
("[+]", ["SYMBOL"]),
("[-]", ["SYMBOL"]),
("[*]", ["SYMBOL"]),
("[foo-bar]", ["SYMBOL"]),
("[+foo-bar+]", ["SYMBOL"]),
("[+foo-bar+]", ["SYMBOL"]),
("[ foo bar ]", ["SYMBOL", "SYMBOL"]),
# Keywords
("[:foo]", ["KEYWORD"]),
("[ :foo ]", ["KEYWORD"]),
("[\n:foo\n]", ["KEYWORD"]),
("[,:foo,]", ["KEYWORD"]),
("[:foo :bar]", ["KEYWORD", "KEYWORD"]),
("[:foo :bar 1]", ["KEYWORD", "KEYWORD", "INTEGER"]),
# Strings
('["foo", "bar", "baz"]', ["STRING", "STRING", "STRING"]),
# Lists
("[[] [] ()]", ["SQLIST", "SQLIST", "LIST"]),
],
)
def test_parse_sqlist(text, element_types):
"""Test we can parse various 'square' lists of contents."""
l_t = next(cp.parse_buffer(text, discard_whitespace=True))
@@ -150,41 +159,21 @@ def test_parse_sqlist(text, element_types):
assert [t.type for t in l_t] == element_types
@parametrize('text, element_pairs', [
("{}",
[]),
("{:foo 1}",
[["KEYWORD", "INTEGER"]]),
("{:foo 1, :bar 2}",
[["KEYWORD", "INTEGER"],
["KEYWORD", "INTEGER"]]),
("{foo 1, bar 2}",
[["SYMBOL", "INTEGER"],
["SYMBOL", "INTEGER"]]),
("{foo 1, bar -2}",
[["SYMBOL", "INTEGER"],
["SYMBOL", "INTEGER"]]),
("{foo 1, bar -2e0}",
[["SYMBOL", "INTEGER"],
["SYMBOL", "FLOAT"]]),
("{foo ()}",
[["SYMBOL", "LIST"]]),
("{foo []}",
[["SYMBOL", "SQLIST"]]),
("{foo {}}",
[["SYMBOL", "DICT"]]),
('{"foo" {}}',
[["STRING", "DICT"]])
])
@parametrize(
"text, element_pairs",
[
("{}", []),
("{:foo 1}", [["KEYWORD", "INTEGER"]]),
("{:foo 1, :bar 2}", [["KEYWORD", "INTEGER"], ["KEYWORD", "INTEGER"]]),
("{foo 1, bar 2}", [["SYMBOL", "INTEGER"], ["SYMBOL", "INTEGER"]]),
("{foo 1, bar -2}", [["SYMBOL", "INTEGER"], ["SYMBOL", "INTEGER"]]),
("{foo 1, bar -2e0}", [["SYMBOL", "INTEGER"], ["SYMBOL", "FLOAT"]]),
("{foo ()}", [["SYMBOL", "LIST"]]),
("{foo []}", [["SYMBOL", "SQLIST"]]),
("{foo {}}", [["SYMBOL", "DICT"]]),
('{"foo" {}}', [["STRING", "DICT"]]),
],
)
def test_parse_dict(text, element_pairs):
"""Test we can parse various mappings."""
d_t = next(cp.parse_buffer(text, discard_whitespace=True))
@@ -192,27 +181,25 @@ def test_parse_dict(text, element_pairs):
assert [[t.type for t in pair] for pair in d_t.value] == element_pairs
@parametrize("text", [
"{1}",
"{1, 2, 3}",
"{:foo}",
"{:foo :bar :baz}"
])
@parametrize("text", ["{1}", "{1, 2, 3}", "{:foo}", "{:foo :bar :baz}"])
def test_parse_bad_dict(text):
"""Assert that dicts with missmatched pairs don't parse."""
with pytest.raises(Exception):
next(cp.parse_buffer(text))
@parametrize("text", [
"()",
"(1 1.1 1e2 -2 foo :foo foo/bar :foo/bar [{},])",
"{:foo bar, :baz [:qux]}",
"'foo",
"'[foo bar :baz 'qux, {}]",
"#foo []",
"^{} bar",
])
@parametrize(
"text",
[
"()",
"(1 1.1 1e2 -2 foo :foo foo/bar :foo/bar [{},])",
"{:foo bar, :baz [:qux]}",
"'foo",
"'[foo bar :baz 'qux, {}]",
"#foo []",
"^{} bar",
],
)
def test_examples(text):
"""Shotgun examples showing we can parse some stuff."""

View file

@@ -5,18 +5,22 @@ from conftest import parametrize
from calf.reader import read_buffer
@parametrize('text', [
"()",
"[]",
"[[[[[[[[[]]]]]]]]]",
"{1 {2 {}}}",
'"foo"',
"foo",
"'foo",
"^foo bar",
"^:foo bar",
"{\"foo\" '([:bar ^:foo 'baz 3.14159e0])}",
"[:foo bar 'baz lo/l, 1, 1.2. 1e-5 -1e2]",
])
@parametrize(
"text",
[
"()",
"[]",
"[[[[[[[[[]]]]]]]]]",
"{1 {2 {}}}",
'"foo"',
"foo",
"'foo",
"^foo bar",
"^:foo bar",
"{\"foo\" '([:bar ^:foo 'baz 3.14159e0])}",
"[:foo bar 'baz lo/l, 1, 1.2. 1e-5 -1e2]",
],
)
def test_read(text):
assert list(read_buffer(text))

View file

@@ -58,13 +58,13 @@ from datalog.debris import Timing
from datalog.evaluator import select
from datalog.reader import pr_str, read_command, read_dataset
from datalog.types import (
CachedDataset,
Constant,
Dataset,
LVar,
PartlyIndexedDataset,
Rule,
TableIndexedDataset
CachedDataset,
Constant,
Dataset,
LVar,
PartlyIndexedDataset,
Rule,
TableIndexedDataset,
)
from prompt_toolkit import print_formatted_text, prompt, PromptSession
@@ -74,190 +74,204 @@ from prompt_toolkit.styles import Style
from yaspin import Spinner, yaspin
STYLE = Style.from_dict({
# User input (default text).
"": "",
"prompt": "ansigreen",
"time": "ansiyellow"
})
STYLE = Style.from_dict(
{
# User input (default text).
"": "",
"prompt": "ansigreen",
"time": "ansiyellow",
}
)
SPINNER = Spinner(["|", "/", "-", "\\"], 200)
class InterpreterInterrupt(Exception):
"""An exception used to break the prompt or evaluation."""
"""An exception used to break the prompt or evaluation."""
def print_(fmt, **kwargs):
print_formatted_text(FormattedText(fmt), **kwargs)
print_formatted_text(FormattedText(fmt), **kwargs)
def print_db(db):
"""Render a database for debugging."""
"""Render a database for debugging."""
for e in db.tuples():
print(f"{pr_str(e)}")
for e in db.tuples():
print(f"{pr_str(e)}")
for r in db.rules():
print(f"{pr_str(r)}")
for r in db.rules():
print(f"{pr_str(r)}")
def main(args):
"""REPL entry point."""
"""REPL entry point."""
if args.db_cls == "simple":
db_cls = Dataset
elif args.db_cls == "cached":
db_cls = CachedDataset
elif args.db_cls == "table":
db_cls = TableIndexedDataset
elif args.db_cls == "partly":
db_cls = PartlyIndexedDataset
if args.db_cls == "simple":
db_cls = Dataset
elif args.db_cls == "cached":
db_cls = CachedDataset
elif args.db_cls == "table":
db_cls = TableIndexedDataset
elif args.db_cls == "partly":
db_cls = PartlyIndexedDataset
print(f"Using dataset type {db_cls}")
print(f"Using dataset type {db_cls}")
session = PromptSession(history=FileHistory(".datalog.history"))
db = db_cls([], [])
session = PromptSession(history=FileHistory(".datalog.history"))
db = db_cls([], [])
if args.dbs:
for db_file in args.dbs:
try:
with open(db_file, "r") as f:
db = db.merge(read_dataset(f.read()))
print(f"Loaded {db_file} ...")
except Exception as e:
print("Internal error - {e}")
print(f"Unable to load db {db_file}, skipping")
if args.dbs:
for db_file in args.dbs:
try:
with open(db_file, "r") as f:
db = db.merge(read_dataset(f.read()))
print(f"Loaded {db_file} ...")
except Exception as e:
print("Internal error - {e}")
print(f"Unable to load db {db_file}, skipping")
while True:
try:
line = session.prompt([("class:prompt", ">>> ")], style=STYLE)
except (InterpreterInterrupt, KeyboardInterrupt):
continue
except EOFError:
break
while True:
try:
line = session.prompt([("class:prompt", ">>> ")], style=STYLE)
except (InterpreterInterrupt, KeyboardInterrupt):
continue
except EOFError:
break
if line == ".all":
op = ".all"
elif line == ".dbg":
op = ".dbg"
elif line == ".quit":
break
if line == ".all":
op = ".all"
elif line == ".dbg":
op = ".dbg"
elif line == ".quit":
break
elif line in {".help", "help", "?", "??", "???"}:
print(__doc__)
continue
elif line.split(" ")[0] == ".log":
op = ".log"
else:
try:
op, val = read_command(line)
except Exception as e:
print(f"Got an unknown command or syntax error, can't tell which")
continue
# Definition merges on the DB
if op == ".all":
print_db(db)
# .dbg drops to a debugger shell so you can poke at the instance objects (database)
elif op == ".dbg":
import pdb
pdb.set_trace()
# .log sets the log level - badly
elif op == ".log":
level = line.split(" ")[1].upper()
try:
ch.setLevel(getattr(logging, level))
except BaseException:
print(f"Unknown log level {level}")
elif op == ".":
# FIXME (arrdem 2019-06-15):
# Syntax rules the parser doesn't impose...
try:
for rule in val.rules():
assert not rule.free_vars, f"Rule contains free variables {rule.free_vars!r}"
for tuple in val.tuples():
assert not any(isinstance(e, LVar) for e in tuple), f"Tuples cannot contain lvars - {tuple!r}"
except BaseException as e:
print(f"Error: {e}")
continue
db = db.merge(val)
print_db(val)
# Queries execute - note that rules as queries have to be temporarily merged.
elif op == "?":
# In order to support ad-hoc rules (joins), we have to generate a transient "query" database
# by bolting the rule on as an overlay to the existing database. If of course we have a join.
#
# `val` was previously assumed to be the query pattern. Introduce `qdb`, now used as the
# database to query and "fix" `val` to be the temporary rule's pattern.
#
# We use a new db and db local so that the ephemeral rule doesn't persist unless the user
# later `.` defines it.
#
# Unfortunately doing this merge does nuke caches.
qdb = db
if isinstance(val, Rule):
qdb = db.merge(db_cls([], [val]))
val = val.pattern
with yaspin(SPINNER) as spinner:
with Timing() as t:
try:
results = list(select(qdb, val))
except KeyboardInterrupt:
print(f"Evaluation aborted after {t}")
elif line in {".help", "help", "?", "??", "???"}:
print(__doc__)
continue
# It's kinda bogus to move sorting out but oh well
results = sorted(results)
elif line.split(" ")[0] == ".log":
op = ".log"
for _results, _bindings in results:
_result = _results[0] # select only selects one tuple at a time
print(f"{pr_str(_result)}")
else:
try:
op, val = read_command(line)
except Exception as e:
print(f"Got an unknown command or syntax error, can't tell which")
continue
# So we can report empty sets explicitly.
if not results:
print("⇒ Ø")
# Definition merges on the DB
if op == ".all":
print_db(db)
print_([("class:time", f"Elapsed time - {t}")], style=STYLE)
# .dbg drops to a debugger shell so you can poke at the instance objects (database)
elif op == ".dbg":
import pdb
# Retractions try to delete, but may fail.
elif op == "!":
if val in db.tuples() or val in [r.pattern for r in db.rules()]:
db = db_cls([u for u in db.tuples() if u != val],
[r for r in db.rules() if r.pattern != val])
print(f"{pr_str(val)}")
else:
print("⇒ Ø")
pdb.set_trace()
# .log sets the log level - badly
elif op == ".log":
level = line.split(" ")[1].upper()
try:
ch.setLevel(getattr(logging, level))
except BaseException:
print(f"Unknown log level {level}")
elif op == ".":
# FIXME (arrdem 2019-06-15):
# Syntax rules the parser doesn't impose...
try:
for rule in val.rules():
assert (
not rule.free_vars
), f"Rule contains free variables {rule.free_vars!r}"
for tuple in val.tuples():
assert not any(
isinstance(e, LVar) for e in tuple
), f"Tuples cannot contain lvars - {tuple!r}"
except BaseException as e:
print(f"Error: {e}")
continue
db = db.merge(val)
print_db(val)
# Queries execute - note that rules as queries have to be temporarily merged.
elif op == "?":
# In order to support ad-hoc rules (joins), we have to generate a transient "query" database
# by bolting the rule on as an overlay to the existing database. If of course we have a join.
#
# `val` was previously assumed to be the query pattern. Introduce `qdb`, now used as the
# database to query and "fix" `val` to be the temporary rule's pattern.
#
# We use a new db and db local so that the ephemeral rule doesn't persist unless the user
# later `.` defines it.
#
# Unfortunately doing this merge does nuke caches.
qdb = db
if isinstance(val, Rule):
qdb = db.merge(db_cls([], [val]))
val = val.pattern
with yaspin(SPINNER) as spinner:
with Timing() as t:
try:
results = list(select(qdb, val))
except KeyboardInterrupt:
print(f"Evaluation aborted after {t}")
continue
# It's kinda bogus to move sorting out but oh well
results = sorted(results)
for _results, _bindings in results:
_result = _results[0] # select only selects one tuple at a time
print(f"{pr_str(_result)}")
# So we can report empty sets explicitly.
if not results:
print("⇒ Ø")
print_([("class:time", f"Elapsed time - {t}")], style=STYLE)
# Retractions try to delete, but may fail.
elif op == "!":
if val in db.tuples() or val in [r.pattern for r in db.rules()]:
db = db_cls(
[u for u in db.tuples() if u != val],
[r for r in db.rules() if r.pattern != val],
)
print(f"{pr_str(val)}")
else:
print("⇒ Ø")
parser = argparse.ArgumentParser()
# Select which dataset type to use
parser.add_argument("--db-type",
choices=["simple", "cached", "table", "partly"],
help="Choose which DB to use (default partly)",
dest="db_cls",
default="partly")
parser.add_argument(
"--db-type",
choices=["simple", "cached", "table", "partly"],
help="Choose which DB to use (default partly)",
dest="db_cls",
default="partly",
)
parser.add_argument("--load-db", dest="dbs", action="append",
help="Datalog files to load first.")
parser.add_argument(
"--load-db", dest="dbs", action="append", help="Datalog files to load first."
)
if __name__ == "__main__":
args = parser.parse_args(sys.argv[1:])
logger = logging.getLogger("arrdem.datalog")
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
logger.addHandler(ch)
main(args)
args = parser.parse_args(sys.argv[1:])
logger = logging.getLogger("arrdem.datalog")
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)
logger.addHandler(ch)
main(args)
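To make the command dispatch above concrete, a hypothetical session (illustrative, not captured output):

>>> edge(a, b).        # "." merges a definition into the database and echoes it
>>> edge(a, X)?        # "?" runs a query; matching tuples print, "⇒ Ø" marks an empty result
>>> edge(a, b)!        # "!" retracts a matching tuple or rule, else prints "⇒ Ø"
>>> .all               # dumps every tuple and rule in the database
>>> .log DEBUG         # sets the log level - badly, per the comment above
>>> .quit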

View file

@@ -23,10 +23,7 @@ setup(
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
],
scripts=[
"bin/datalog"
],
scripts=["bin/datalog"],
install_requires=[
"arrdem.datalog~=2.0.0",
"prompt_toolkit==2.0.9",

View file

@@ -9,17 +9,17 @@ from uuid import uuid4 as uuid
with open("graph.dtl", "w") as f:
nodes = []
nodes = []
# Generate 10k edges
for i in range(10000):
if nodes:
from_node = choice(nodes)
else:
from_node = uuid()
# Generate 10k edges
for i in range(10000):
if nodes:
from_node = choice(nodes)
else:
from_node = uuid()
to_node = uuid()
to_node = uuid()
nodes.append(to_node)
nodes.append(to_node)
f.write(f"edge({str(from_node)!r}, {str(to_node)!r}).\n")
f.write(f"edge({str(from_node)!r}, {str(to_node)!r}).\n")

View file

@@ -26,5 +26,7 @@ setup(
],
# Package setup
package_dir={"": "src/python"},
packages=["datalog",],
packages=[
"datalog",
],
)

View file

@@ -16,8 +16,8 @@ def constexpr_p(expr):
class Timing(object):
"""
A context manager object which records how long the context took.
"""
A context manager object which records how long the context took.
"""
def __init__(self):
self.start = None
@@ -36,8 +36,8 @@ class Timing(object):
def __call__(self):
"""
If the context is exited, return its duration. Otherwise return the duration "so far".
"""
If the context is exited, return its duration. Otherwise return the duration "so far".
"""
from datetime import datetime
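Based on the docstrings above, a usage sketch of Timing (the sleep is a stand-in workload):

import time
from datalog.debris import Timing

with Timing() as t:
    time.sleep(0.1)  # stand-in for any slow call
print(t())           # duration of the block; called inside the block it is the duration "so far"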

View file

@@ -22,9 +22,9 @@ def read(text: str, db_cls=PartlyIndexedDataset):
def q(t: Tuple[str]) -> LTuple:
"""Helper for writing terse queries.
Takes a tuple of strings, and interprets them as a logic tuple.
So you don't have to write the logic tuple out by hand.
"""
Takes a tuple of strings, and interprets them as a logic tuple.
So you don't have to write the logic tuple out by hand.
"""
def _x(s: str):
if s[0].isupper():
@@ -50,38 +50,38 @@ def __result(results_bindings):
def select(db: Dataset, query: Tuple[str], bindings=None) -> Sequence[Tuple]:
"""Helper for interpreting tuples of strings as a query, and returning simplified results.
Executes your query, returning matching full tuples.
"""
Executes your query, returning matching full tuples.
"""
return __mapv(__result, __select(db, q(query), bindings=bindings))
def join(db: Dataset, query: Sequence[Tuple[str]], bindings=None) -> Sequence[dict]:
"""Helper for interpreting a bunch of tuples of strings as a join query, and returning simplified
results.
results.
Executes the query clauses as a join, returning a sequence of tuples and binding mappings such
that the join constraints are simultaneously satisfied.
Executes the query clauses as a join, returning a sequence of tuples and binding mappings such
that the join constraints are simultaneously satisfied.
>>> db = read('''
... edge(a, b).
... edge(b, c).
... edge(c, d).
... ''')
>>> join(db, [
... ('edge', 'A', 'B'),
... ('edge', 'B', 'C')
... ])
[((('edge', 'a', 'b'),
('edge', 'b', 'c')),
{'A': 'a', 'B': 'b', 'C': 'c'}),
((('edge', 'b', 'c'),
('edge', 'c', 'd')),
{'A': 'b', 'B': 'c', 'C': 'd'}),
((('edge', 'c', 'd'),
('edge', 'd', 'f')),
{'A': 'c', 'B': 'd', 'C': 'f'})]
"""
>>> db = read('''
... edge(a, b).
... edge(b, c).
... edge(c, d).
... ''')
>>> join(db, [
... ('edge', 'A', 'B'),
... ('edge', 'B', 'C')
... ])
[((('edge', 'a', 'b'),
('edge', 'b', 'c')),
{'A': 'a', 'B': 'b', 'C': 'c'}),
((('edge', 'b', 'c'),
('edge', 'c', 'd')),
{'A': 'b', 'B': 'c', 'C': 'd'}),
((('edge', 'c', 'd'),
('edge', 'd', 'f')),
{'A': 'c', 'B': 'd', 'C': 'f'})]
"""
return __mapv(__result, __join(db, [q(c) for c in query], bindings=bindings))

View file

@@ -20,8 +20,8 @@ from datalog.types import (
def match(tuple, expr, bindings=None):
"""Attempt to construct lvar bindings from expr such that tuple and expr equate.
If the match is successful, return the binding map, otherwise return None.
"""
If the match is successful, return the binding map, otherwise return None.
"""
bindings = bindings.copy() if bindings is not None else {}
for a, b in zip(expr, tuple):
@@ -43,9 +43,9 @@ def match(tuple, expr, bindings=None):
def apply_bindings(expr, bindings, strict=True):
"""Given an expr which may contain lvars, substitute its lvars for constants returning the
simplified expr.
simplified expr.
"""
"""
if strict:
return tuple((bindings[e] if isinstance(e, LVar) else e) for e in expr)
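To illustrate the contract described in these docstrings, a sketch assuming Constant/LVar tuples as used elsewhere in this diff:

from datalog.evaluator import apply_bindings, match
from datalog.types import Constant, LVar

expr = (Constant("edge"), LVar("X"), Constant("b"))
fact = (Constant("edge"), Constant("a"), Constant("b"))
bindings = match(fact, expr)   # a binding map on success, None on mismatch
if bindings is not None:
    # expr with X substituted out; should equate to fact
    print(apply_bindings(expr, bindings))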
@@ -56,10 +56,10 @@ def apply_bindings(expr, bindings, strict=True):
def select(db: Dataset, expr, bindings=None, _recursion_guard=None, _select_guard=None):
"""Evaluate an expression in a database, lazily producing a sequence of 'matching' tuples.
The dataset is a set of tuples and rules, and the expression is a single tuple containing lvars
and constants. Evaluates rules and tuples, returning
The dataset is a set of tuples and rules, and the expression is a single tuple containing lvars
and constants. Evaluates rules and tuples, returning
"""
"""
def __select_tuples():
# As an opt. support indexed scans, which is optional.
@@ -170,8 +170,8 @@ def select(db: Dataset, expr, bindings=None, _recursion_guard=None, _select_guar
def join(db: Dataset, clauses, bindings, pattern=None, _recursion_guard=None):
"""Evaluate clauses over the dataset, joining (or antijoining) with the seed bindings.
Yields a sequence of tuples and LVar bindings for which all joins and antijoins were satisfied.
"""
Yields a sequence of tuples and LVar bindings for which all joins and antijoins were satisfied.
"""
def __join(g, clause):
for ts, bindings in g:

View file

@@ -27,13 +27,19 @@ class Actions(object):
return self._db_cls(tuples, rules)
def make_symbol(self, input, start, end, elements):
return LVar("".join(e.text for e in elements),)
return LVar(
"".join(e.text for e in elements),
)
def make_word(self, input, start, end, elements):
return Constant("".join(e.text for e in elements),)
return Constant(
"".join(e.text for e in elements),
)
def make_string(self, input, start, end, elements):
return Constant(elements[1].text,)
return Constant(
elements[1].text,
)
def make_comment(self, input, start, end, elements):
return None
@@ -81,11 +87,11 @@ class Actions(object):
class Parser(Grammar):
"""Implementation detail.
A slightly hacked version of the Parser class canopy generates, which lets us control what the
parsing entry point is. This lets me play games with having one parser and one grammar which is
used both for the command shell and for other things.
A slightly hacked version of the Parser class canopy generates, which lets us control what the
parsing entry point is. This lets me play games with having one parser and one grammar which is
used both for the command shell and for other things.
"""
"""
def __init__(self, input, actions, types):
self._input = input

View file

@@ -66,8 +66,8 @@ class Dataset(object):
class CachedDataset(Dataset):
"""An extension of the dataset which features a cache of rule produced tuples.
Note that this cache is lost when merging datasets - which ensures correctness.
"""
Note that this cache is lost when merging datasets - which ensures correctness.
"""
# Inherits tuples, rules, merge
@@ -90,11 +90,11 @@ class CachedDataset(Dataset):
class TableIndexedDataset(CachedDataset):
"""An extension of the Dataset type which features both a cache and an index by table & length.
The index allows more efficient scans by maintaining 'table' style partitions.
It does not support user-defined indexing schemes.
The index allows more efficient scans by maintaining 'table' style partitions.
It does not support user-defined indexing schemes.
Note that index building is delayed until an index is scanned.
"""
Note that index building is delayed until an index is scanned.
"""
# From Dataset:
# tuples, rules, merge
@@ -126,11 +126,11 @@ class TableIndexedDataset(CachedDataset):
class PartlyIndexedDataset(TableIndexedDataset):
"""An extension of the Dataset type which features both a cache and and a full index by table,
length, tuple index and value.
length, tuple index and value.
The index allows extremely efficient scans when elements of the tuple are known.
The index allows extremely efficient scans when elements of the tuple are known.
"""
"""
# From Dataset:
# tuples, rules, merge

View file

@@ -25,8 +25,19 @@ def test_id_query(db_cls):
Constant("a"),
Constant("b"),
)
assert not select(db_cls([], []), ("a", "b",))
assert select(db_cls([ab], []), ("a", "b",)) == [((("a", "b"),), {},)]
assert not select(
db_cls([], []),
(
"a",
"b",
),
)
assert select(db_cls([ab], []), ("a", "b",)) == [
(
(("a", "b"),),
{},
)
]
@pytest.mark.parametrize("db_cls,", DBCLS)
@@ -47,7 +58,17 @@ def test_lvar_unification(db_cls):
d = read("""edge(b, c). edge(c, c).""", db_cls=db_cls)
assert select(d, ("edge", "X", "X",)) == [((("edge", "c", "c"),), {"X": "c"})]
assert (
select(
d,
(
"edge",
"X",
"X",
),
)
== [((("edge", "c", "c"),), {"X": "c"})]
)
@pytest.mark.parametrize("db_cls,", DBCLS)
@@ -105,12 +126,12 @@ no-b(X, Y) :-
def test_nested_antijoin(db_cls):
"""Test a query which negates a subquery which uses an antijoin.
Shouldn't exercise anything more than `test_antjoin` does, but it's an interesting case since you
actually can't capture the same semantics using a single query. Antijoins can't propagate positive
information (create lvar bindings) so I'm not sure you can express this another way without a
different evaluation strategy.
Shouldn't exercise anything more than `test_antjoin` does, but it's an interesting case since you
actually can't capture the same semantics using a single query. Antijoins can't propagate positive
information (create lvar bindings) so I'm not sure you can express this another way without a
different evaluation strategy.
"""
"""
d = read(
"""

View file

@@ -3,7 +3,7 @@ from setuptools import setup
setup(
name="arrdem.flowmetal",
# Package metadata
version='0.0.0',
version="0.0.0",
license="MIT",
description="A weird execution engine",
long_description=open("README.md").read(),
@@ -18,20 +18,16 @@ setup(
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
],
# Package setup
package_dir={"": "src/python"},
packages=[
"flowmetal",
],
entry_points={
'console_scripts': [
'iflow=flowmetal.repl:main'
],
"console_scripts": ["iflow=flowmetal.repl:main"],
},
install_requires=[
'prompt-toolkit~=3.0.0',
"prompt-toolkit~=3.0.0",
],
extras_require={
}
extras_require={},
)

View file

@@ -2,7 +2,7 @@
An abstract or base Flowmetal DB.
"""
from abc import abc, abstractmethod, abstractproperty, abstractclassmethod, abstractstaticmethod
from abc import ABC, abstractclassmethod, abstractmethod
class Db(ABC):

View file

@@ -3,6 +3,7 @@
import click
@click.group()
def cli():
pass

View file

@@ -3,6 +3,7 @@
import click
@click.group()
def cli():
pass

View file

@@ -1,5 +1,3 @@
"""
Somewhat generic models of Flowmetal programs.
"""
from typing import NamedTuple

View file

@@ -3,6 +3,7 @@
import click
@click.group()
def cli():
pass

View file

@@ -3,6 +3,7 @@
import click
@click.group()
def cli():
pass

View file

@@ -9,94 +9,102 @@ import requests
class GandiAPI(object):
"""An extremely incomplete Gandi REST API driver class.
"""An extremely incomplete Gandi REST API driver class.
Exists to close over your API key, and make talking to the API slightly more idiomatic.
Exists to close over your API key, and make talking to the API slightly more idiomatic.
Note: In an effort to be nice, this API maintains a cache of the zones visible with your API
key. The cache is maintained when using this driver, but concurrent modifications of your zone(s)
via the web UI or other processes will obviously undermine it. This cache can be disabled by
setting the `use_cache` kwarg to `False`.
Note: In an effort to be nice, this API maintains a cache of the zones visible with your API
key. The cache is maintained when using this driver, but concurrent modifications of your zone(s)
via the web UI or other processes will obviously undermine it. This cache can be disabled by
setting the `use_cache` kwarg to `False`.
"""
def __init__(self, key=None, use_cache=True):
self._base = "https://dns.api.gandi.net/api/v5"
self._key = key
self._use_cache = use_cache
self._zones = None
"""
# Helpers for making requests with the API key as required by the API.
def __init__(self, key=None, use_cache=True):
self._base = "https://dns.api.gandi.net/api/v5"
self._key = key
self._use_cache = use_cache
self._zones = None
def _do_request(self, method, path, headers=None, **kwargs):
headers = headers or {}
headers["X-Api-Key"] = self._key
resp = method(self._base + path, headers=headers, **kwargs)
if resp.status_code > 299:
print(resp.text)
resp.raise_for_status()
# Helpers for making requests with the API key as required by the API.
return resp
def _do_request(self, method, path, headers=None, **kwargs):
headers = headers or {}
headers["X-Api-Key"] = self._key
resp = method(self._base + path, headers=headers, **kwargs)
if resp.status_code > 299:
print(resp.text)
resp.raise_for_status()
def _get(self, path, **kwargs):
return self._do_request(requests.get, path, **kwargs)
return resp
def _post(self, path, **kwargs):
return self._do_request(requests.post, path, **kwargs)
def _get(self, path, **kwargs):
return self._do_request(requests.get, path, **kwargs)
def _put(self, path, **kwargs):
return self._do_request(requests.put, path, **kwargs)
def _post(self, path, **kwargs):
return self._do_request(requests.post, path, **kwargs)
# Intentional public API
def domain_records(self, domain):
return self._get("/domains/{0}/records".format(domain)).json()
def _put(self, path, **kwargs):
return self._do_request(requests.put, path, **kwargs)
def get_zones(self):
if self._use_cache:
if not self._zones:
self._zones = self._get("/zones").json()
return self._zones
else:
return self._get("/zones").json()
# Intentional public API
def get_zone(self, zone_id):
return self._get("/zones/{}".format(zone_id)).json()
def domain_records(self, domain):
return self._get("/domains/{0}/records".format(domain)).json()
def get_zone_by_name(self, zone_name):
for zone in self.get_zones():
if zone["name"] == zone_name:
return zone
def get_zones(self):
if self._use_cache:
if not self._zones:
self._zones = self._get("/zones").json()
return self._zones
else:
return self._get("/zones").json()
def create_zone(self, zone_name):
new_zone_id = self._post("/zones",
headers={"content-type": "application/json"},
data=json.dumps({"name": zone_name}))\
.headers["Location"]\
.split("/")[-1]
new_zone = self.get_zone(new_zone_id)
def get_zone(self, zone_id):
return self._get("/zones/{}".format(zone_id)).json()
# Note: If the cache is active, update the cache.
if self._use_cache and self._zones is not None:
self._zones.append(new_zone)
def get_zone_by_name(self, zone_name):
for zone in self.get_zones():
if zone["name"] == zone_name:
return zone
return new_zone
def create_zone(self, zone_name):
new_zone_id = (
self._post(
"/zones",
headers={"content-type": "application/json"},
data=json.dumps({"name": zone_name}),
)
.headers["Location"]
.split("/")[-1]
)
new_zone = self.get_zone(new_zone_id)
def replace_domain(self, domain, records):
date = datetime.now()
date = "{:%A, %d. %B %Y %I:%M%p}".format(date)
zone_name = f"updater generated domain - {domain} - {date}"
# Note: If the cache is active, update the cache.
if self._use_cache and self._zones is not None:
self._zones.append(new_zone)
zone = self.get_zone_by_name(zone_name)
if not zone:
zone = self.create_zone(zone_name)
return new_zone
print("Using zone", zone["uuid"])
def replace_domain(self, domain, records):
date = datetime.now()
date = "{:%A, %d. %B %Y %I:%M%p}".format(date)
zone_name = f"updater generated domain - {domain} - {date}"
for r in records:
self._post("/zones/{0}/records".format(zone["uuid"]),
headers={"content-type": "application/json"},
data=json.dumps(r))
zone = self.get_zone_by_name(zone_name)
if not zone:
zone = self.create_zone(zone_name)
return self._post("/zones/{0}/domains/{1}".format(zone["uuid"], domain),
headers={"content-type": "application/json"})
print("Using zone", zone["uuid"])
for r in records:
self._post(
"/zones/{0}/records".format(zone["uuid"]),
headers={"content-type": "application/json"},
data=json.dumps(r),
)
return self._post(
"/zones/{0}/domains/{1}".format(zone["uuid"], domain),
headers={"content-type": "application/json"},
)
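A hedged usage sketch of the driver above; the key is a placeholder and the record shape follows the rrset_* fields used by the updater below:

api = GandiAPI(key="<your-gandi-api-key>")
records = [{
    "rrset_name": "@",
    "rrset_type": "A",
    "rrset_ttl": 300,
    "rrset_values": ["203.0.113.7"],  # documentation IP, stands in for a real uplink address
}]
print(api.replace_domain("example.com", records))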

View file

@@ -20,165 +20,171 @@ import meraki
RECORD_LINE_PATTERN = re.compile(
"^(?P<rrset_name>\S+)\s+"
"(?P<rrset_ttl>\S+)\s+"
"IN\s+"
"(?P<rrset_type>\S+)\s+"
"(?P<rrset_values>.+)$")
"^(?P<rrset_name>\S+)\s+"
"(?P<rrset_ttl>\S+)\s+"
"IN\s+"
"(?P<rrset_type>\S+)\s+"
"(?P<rrset_values>.+)$"
)
def update(m, k, f, *args, **kwargs):
"""clojure.core/update for Python's stateful maps."""
if k in m:
m[k] = f(m[k], *args, **kwargs)
return m
"""clojure.core/update for Python's stateful maps."""
if k in m:
m[k] = f(m[k], *args, **kwargs)
return m
def parse_zone_record(line):
if match := RECORD_LINE_PATTERN.search(line):
dat = match.groupdict()
dat = update(dat, "rrset_ttl", int)
dat = update(dat, "rrset_values", lambda x: [x])
return dat
if match := RECORD_LINE_PATTERN.search(line):
dat = match.groupdict()
dat = update(dat, "rrset_ttl", int)
dat = update(dat, "rrset_values", lambda x: [x])
return dat
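For instance, a tiny sketch of the helper pair's contract:

record = {"rrset_ttl": "300", "rrset_values": "203.0.113.7"}
record = update(record, "rrset_ttl", int)               # {"rrset_ttl": 300, ...}
record = update(record, "rrset_values", lambda x: [x])  # values wrapped in a list
record = update(record, "missing", int)                 # absent keys are left untouched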
def same_record(lr, rr):
"""
A test to see if two records name the same zone entry.
"""
"""
A test to see if two records name the same zone entry.
"""
return lr["rrset_name"] == rr["rrset_name"] and \
lr["rrset_type"] == rr["rrset_type"]
return lr["rrset_name"] == rr["rrset_name"] and lr["rrset_type"] == rr["rrset_type"]
def records_equate(lr, rr):
"""
Equality, ignoring rrset_href which is generated by the API.
"""
"""
Equality, ignoring rrset_href which is generated by the API.
"""
if not same_record(lr, rr):
return False
elif lr["rrset_ttl"] != rr["rrset_ttl"]:
return False
elif set(lr["rrset_values"]) != set(rr["rrset_values"]):
return False
else:
return True
if not same_record(lr, rr):
return False
elif lr["rrset_ttl"] != rr["rrset_ttl"]:
return False
elif set(lr["rrset_values"]) != set(rr["rrset_values"]):
return False
else:
return True
def template_and_parse_zone(template_file, template_bindings):
assert template_file is not None
assert template_bindings is not None
assert template_file is not None
assert template_bindings is not None
with open(template_file) as f:
dat = jinja2.Template(f.read()).render(**template_bindings)
with open(template_file) as f:
dat = jinja2.Template(f.read()).render(**template_bindings)
uncommitted_records = []
for line in dat.splitlines():
if line and not line[0] == "#":
record = parse_zone_record(line)
if record:
uncommitted_records.append(record)
uncommitted_records = []
for line in dat.splitlines():
if line and not line[0] == "#":
record = parse_zone_record(line)
if record:
uncommitted_records.append(record)
records = []
records = []
for uncommitted_r in uncommitted_records:
flag = False
for committed_r in records:
if same_record(uncommitted_r, committed_r):
# Join the two records
committed_r["rrset_values"].extend(uncommitted_r["rrset_values"])
flag = True
for uncommitted_r in uncommitted_records:
flag = False
for committed_r in records:
if same_record(uncommitted_r, committed_r):
# Join the two records
committed_r["rrset_values"].extend(uncommitted_r["rrset_values"])
flag = True
if not flag:
records.append(uncommitted_r)
if not flag:
records.append(uncommitted_r)
sorted(records, key=lambda x: (x["rrset_type"], x["rrset_name"]))
sorted(records, key=lambda x: (x["rrset_type"], x["rrset_name"]))
return records
return records
def diff_zones(left_zone, right_zone):
"""
Equality between unordered lists of records constituting a zone.
"""
in_left_not_right = []
for lr in left_zone:
flag = False
for rr in right_zone:
if same_record(lr, rr) and records_equate(lr, rr):
flag = True
break
"""
Equality between unordered lists of records constituting a zone.
"""
if not flag:
in_left_not_right.append(lr)
in_right_not_left = []
for rr in right_zone:
flag = False
in_left_not_right = []
for lr in left_zone:
if same_record(lr, rr) and records_equate(lr, rr):
flag = True
break
flag = False
for rr in right_zone:
if same_record(lr, rr) and records_equate(lr, rr):
flag = True
break
if not flag:
in_right_not_left.append(lr)
if not flag:
in_left_not_right.append(lr)
return in_left_not_right or in_right_not_left
in_right_not_left = []
for rr in right_zone:
flag = False
for lr in left_zone:
if same_record(lr, rr) and records_equate(lr, rr):
flag = True
break
if not flag:
in_right_not_left.append(lr)
return in_left_not_right or in_right_not_left
parser = argparse.ArgumentParser(description="\"Dynamic\" DNS updating for self-hosted services")
parser = argparse.ArgumentParser(
description='"Dynamic" DNS updating for self-hosted services'
)
parser.add_argument("--config", dest="config_file", required=True)
parser.add_argument("--templates", dest="template_dir", required=True)
parser.add_argument("--dry-run", dest="dry", action="store_true", default=False)
def main():
args = parser.parse_args()
config = yaml.safe_load(open(args.config_file, "r"))
args = parser.parse_args()
config = yaml.safe_load(open(args.config_file, "r"))
dashboard = meraki.DashboardAPI(config["meraki"]["key"], output_log=False)
org = config["meraki"]["organization"]
device = config["meraki"]["router_serial"]
dashboard = meraki.DashboardAPI(config["meraki"]["key"], output_log=False)
org = config["meraki"]["organization"]
device = config["meraki"]["router_serial"]
uplinks = dashboard.appliance.getOrganizationApplianceUplinkStatuses(
organizationId=org,
serials=[device]
)[0]["uplinks"]
uplinks = dashboard.appliance.getOrganizationApplianceUplinkStatuses(
organizationId=org, serials=[device]
)[0]["uplinks"]
template_bindings = {
"local": {
# One of the two
"public_v4s": [link.get("publicIp") for link in uplinks if link.get("publicIp")],
},
# Why isn't there a merge method
**config["bindings"]
}
template_bindings = {
"local": {
# One of the two
"public_v4s": [
link.get("publicIp") for link in uplinks if link.get("publicIp")
],
},
# Why isn't there a merge method
**config["bindings"],
}
api = GandiAPI(config["gandi"]["key"])
api = GandiAPI(config["gandi"]["key"])
for task in config["tasks"]:
if isinstance(task, str):
task = {"template": task + ".j2",
"zones": [task]}
for task in config["tasks"]:
if isinstance(task, str):
task = {"template": task + ".j2", "zones": [task]}
computed_zone = template_and_parse_zone(os.path.join(args.template_dir, task["template"]), template_bindings)
computed_zone = template_and_parse_zone(
os.path.join(args.template_dir, task["template"]), template_bindings
)
for zone_name in task["zones"]:
try:
live_zone = api.domain_records(zone_name)
for zone_name in task["zones"]:
try:
live_zone = api.domain_records(zone_name)
if diff_zones(computed_zone, live_zone):
print("Zone {} differs, computed zone:".format(zone_name))
pprint(computed_zone)
if not args.dry:
print(api.replace_domain(zone_name, computed_zone))
else:
print("Zone {} up to date".format(zone_name))
if diff_zones(computed_zone, live_zone):
print("Zone {} differs, computed zone:".format(zone_name))
pprint(computed_zone)
if not args.dry:
print(api.replace_domain(zone_name, computed_zone))
else:
print("Zone {} up to date".format(zone_name))
except Exception as e:
print("While processing zone {}".format(zone_name))
raise e
except Exception as e:
print("While processing zone {}".format(zone_name))
raise e
if __name__ == "__main__" or 1:
main()
main()

View file

@@ -3,7 +3,7 @@ from setuptools import setup
setup(
name="arrdem.ratchet",
# Package metadata
version='0.0.0',
version="0.0.0",
license="MIT",
description="A 'ratcheting' message system",
long_description=open("README.md").read(),
@@ -18,18 +18,12 @@ setup(
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
],
# Package setup
package_dir={
"": "src/python"
},
package_dir={"": "src/python"},
packages=[
"ratchet",
],
entry_points={
},
install_requires=[
],
extras_require={
}
entry_points={},
install_requires=[],
extras_require={},
)

View file

@@ -98,14 +98,15 @@ VALUES (?, ?, ?, ?);
"""
class SQLiteDriver:
def __init__(self,
filename="~/.ratchet.sqlite3",
sqlite_timeout=1000,
message_ttl=60000,
message_space="_",
message_author=f"{os.getpid()}@{socket.gethostname()}"):
def __init__(
self,
filename="~/.ratchet.sqlite3",
sqlite_timeout=1000,
message_ttl=60000,
message_space="_",
message_author=f"{os.getpid()}@{socket.gethostname()}",
):
self._path = os.path.expanduser(filename)
self._sqlite_timeout = sqlite_timeout
self._message_ttl = message_ttl
@@ -120,14 +121,11 @@ class SQLiteDriver:
conn.executescript(SCHEMA_SCRIPT)
def _connection(self):
return sql.connect(self._filename,
timeout=self._sqlite_timeout)
return sql.connect(self._filename, timeout=self._sqlite_timeout)
def create_message(self,
message: str,
ttl: int = None,
space: str = None,
author: str = None):
def create_message(
self, message: str, ttl: int = None, space: str = None, author: str = None
):
"""Create a single message."""
ttl = ttl or self._message_ttl
@@ -138,11 +136,9 @@ class SQLiteDriver:
cursor.execute(CREATE_MESSAGE_SCRIPT, (author, space, ttl, message))
return cursor.lastrowid
def create_event(self,
timeout: int,
ttl: int = None,
space: str = None,
author: str = None):
def create_event(
self, timeout: int, ttl: int = None, space: str = None, author: str = None
):
"""Create a (pending) event."""
ttl = ttl or self._message_ttl

View file

@@ -7,111 +7,94 @@ from yamlschema import lint_buffer
import pytest
@pytest.mark.parametrize('schema, obj', [
({"type": "number"},
"---\n1.0"),
({"type": "integer"},
"---\n3"),
({"type": "string"},
"---\nfoo bar baz"),
({"type": "string",
"maxLength": 15},
"---\nfoo bar baz"),
({"type": "string",
"minLength": 10},
"---\nfoo bar baz"),
({"type": "string",
"pattern": "^foo.*"},
"---\nfoo bar baz"),
({"type": "object",
"additionalProperties": True},
"---\nfoo: bar\nbaz: qux"),
({"type": "object",
"properties": {"foo": {"type": "string"}}},
"---\nfoo: bar\nbaz: qux"),
({"type": "object",
"properties": {"foo": {"type": "string"}},
"additionalProperties": False},
"---\nfoo: bar"),
({"type": "object",
"properties": {"foo": {"type": "object"}}},
"---\nfoo: {}"),
({"type": "object",
"properties": {"foo": {
"type": "array",
"items": {"type": "object"}}}},
"---\nfoo: [{}, {}, {foo: bar}]"),
])
@pytest.mark.parametrize(
"schema, obj",
[
({"type": "number"}, "---\n1.0"),
({"type": "integer"}, "---\n3"),
({"type": "string"}, "---\nfoo bar baz"),
({"type": "string", "maxLength": 15}, "---\nfoo bar baz"),
({"type": "string", "minLength": 10}, "---\nfoo bar baz"),
({"type": "string", "pattern": "^foo.*"}, "---\nfoo bar baz"),
({"type": "object", "additionalProperties": True}, "---\nfoo: bar\nbaz: qux"),
(
{"type": "object", "properties": {"foo": {"type": "string"}}},
"---\nfoo: bar\nbaz: qux",
),
(
{
"type": "object",
"properties": {"foo": {"type": "string"}},
"additionalProperties": False,
},
"---\nfoo: bar",
),
({"type": "object", "properties": {"foo": {"type": "object"}}}, "---\nfoo: {}"),
(
{
"type": "object",
"properties": {"foo": {"type": "array", "items": {"type": "object"}}},
},
"---\nfoo: [{}, {}, {foo: bar}]",
),
],
)
def test_lint_document_ok(schema, obj):
assert not list(lint_buffer(schema, obj))
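A usage sketch of the function under test (the result note is hand-written, not captured output):

from yamlschema import lint_buffer

schema = {"type": "object", "properties": {"foo": {"type": "string"}}}
for record in lint_buffer(schema, "---\nfoo: 1"):
    print(record)  # one MISSMATCH-level LintRecord: 1 is not a string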
@pytest.mark.parametrize('msg, schema, obj', [
# Numerics
("Floats are not ints",
{"type": "integer"},
"---\n1.0"),
("Ints are not floats",
{"type": "number"},
"---\n1"),
# Numerics - range limits. Integer edition
("1 is the limit of the range",
{"type": "integer",
"exclusiveMaximum": 1},
"---\n1"),
("1 is the limit of the range",
{"type": "integer",
"exclusiveMinimum": 1},
"---\n1"),
("1 is out of the range",
{"type": "integer",
"minimum": 2},
"---\n1"),
("1 is out of the range",
{"type": "integer",
"maximum": 0},
"---\n1"),
("1 is out of the range",
{"type": "integer",
"exclusiveMinimum": 1},
"---\n1"),
# Numerics - range limits. Number/Float edition
("1 is the limit of the range",
{"type": "number",
"exclusiveMaximum": 1},
"---\n1.0"),
("1 is the limit of the range",
{"type": "number",
"exclusiveMinimum": 1},
"---\n1.0"),
("1 is out of the range",
{"type": "number",
"minimum": 2},
"---\n1.0"),
("1 is out of the range",
{"type": "number",
"maximum": 0},
"---\n1.0"),
("1 is out of the range",
{"type": "number",
"exclusiveMinimum": 1},
"---\n1.0"),
# String shit
("String too short",
{"type": "string", "minLength": 1},
"---\n''"),
("String too long",
{"type": "string", "maxLength": 1},
"---\nfoo"),
("String does not match pattern",
{"type": "string", "pattern": "bar"},
"---\nfoo"),
("String does not fully match pattern",
{"type": "string", "pattern": "foo"},
"---\nfooooooooo"),
])
@pytest.mark.parametrize(
"msg, schema, obj",
[
# Numerics
("Floats are not ints", {"type": "integer"}, "---\n1.0"),
("Ints are not floats", {"type": "number"}, "---\n1"),
# Numerics - range limits. Integer edition
(
"1 is the limit of the range",
{"type": "integer", "exclusiveMaximum": 1},
"---\n1",
),
(
"1 is the limit of the range",
{"type": "integer", "exclusiveMinimum": 1},
"---\n1",
),
("1 is out of the range", {"type": "integer", "minimum": 2}, "---\n1"),
("1 is out of the range", {"type": "integer", "maximum": 0}, "---\n1"),
("1 is out of the range", {"type": "integer", "exclusiveMinimum": 1}, "---\n1"),
# Numerics - range limits. Number/Float edition
(
"1 is the limit of the range",
{"type": "number", "exclusiveMaximum": 1},
"---\n1.0",
),
(
"1 is the limit of the range",
{"type": "number", "exclusiveMinimum": 1},
"---\n1.0",
),
("1 is out of the range", {"type": "number", "minimum": 2}, "---\n1.0"),
("1 is out of the range", {"type": "number", "maximum": 0}, "---\n1.0"),
(
"1 is out of the range",
{"type": "number", "exclusiveMinimum": 1},
"---\n1.0",
),
# String shit
("String too short", {"type": "string", "minLength": 1}, "---\n''"),
("String too long", {"type": "string", "maxLength": 1}, "---\nfoo"),
(
"String does not match pattern",
{"type": "string", "pattern": "bar"},
"---\nfoo",
),
(
"String does not fully match pattern",
{"type": "string", "pattern": "foo"},
"---\nfooooooooo",
),
],
)
def test_lint_document_fails(msg, schema, obj):
assert list(lint_buffer(schema, obj)), msg

View file

@@ -59,9 +59,7 @@ class YamlLinter(object):
return schema
def lint_mapping(self, schema, node: Node) -> t.Iterable[str]:
"""FIXME.
"""
"""FIXME."""
if schema["type"] != "object" or not isinstance(node, MappingNode):
yield LintRecord(
@@ -71,9 +69,7 @@ class YamlLinter(object):
f"Expected {schema['type']}, got {node.id} {str(node.start_mark).lstrip()}",
)
additional_type: t.Union[dict, bool] = (
schema.get("additionalProperties", True)
)
additional_type: t.Union[dict, bool] = schema.get("additionalProperties", True)
properties: dict = schema.get("properties", {})
required: t.Iterable[str] = schema.get("required", [])
@@ -135,37 +131,26 @@ class YamlLinter(object):
elif schema["type"] == "number":
yield from self.lint_number(schema, node)
else:
raise NotImplementedError(
f"Scalar type {schema['type']} is not supported"
)
raise NotImplementedError(f"Scalar type {schema['type']} is not supported")
def lint_string(self, schema, node: Node) -> t.Iterable[str]:
"""FIXME."""
if node.tag != "tag:yaml.org,2002:str":
yield LintRecord(
LintLevel.MISSMATCH,
node,
schema,
f"Expected a string, got a {node}"
LintLevel.MISSMATCH, node, schema, f"Expected a string, got a {node}"
)
if maxl := schema.get("maxLength"):
if len(node.value) > maxl:
yield LintRecord(
LintLevel.MISSMATCH,
node,
schema,
f"Expected a shorter string"
LintLevel.MISSMATCH, node, schema, f"Expected a shorter string"
)
if minl := schema.get("minLength"):
if len(node.value) < minl:
yield LintRecord(
LintLevel.MISSMATCH,
node,
schema,
f"Expected a longer string"
LintLevel.MISSMATCH, node, schema, f"Expected a longer string"
)
if pat := schema.get("pattern"):
@@ -174,7 +159,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH,
node,
schema,
f"Expected a string matching the pattern"
f"Expected a string matching the pattern",
)
def lint_integer(self, schema, node: Node) -> t.Iterable[str]:
@@ -184,10 +169,7 @@ class YamlLinter(object):
else:
yield LintRecord(
LintLevel.MISSMATCH,
node,
schema,
f"Expected an integer, got a {node}"
LintLevel.MISSMATCH, node, schema, f"Expected an integer, got a {node}"
)
def lint_number(self, schema, node: Node) -> t.Iterable[str]:
@@ -197,13 +179,9 @@ class YamlLinter(object):
else:
yield LintRecord(
LintLevel.MISSMATCH,
node,
schema,
f"Expected an integer, got a {node}"
LintLevel.MISSMATCH, node, schema, f"Expected an integer, got a {node}"
)
def _lint_num_range(self, schema, node: Node, value) -> t.Iterable[str]:
""""FIXME."""
@@ -213,7 +191,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH,
node,
schema,
f"Expected a multiple of {base}, got {value}"
f"Expected a multiple of {base}, got {value}",
)
if (max := schema.get("exclusiveMaximum")) is not None:
@@ -222,7 +200,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH,
node,
schema,
f"Expected a value less than {max}, got {value}"
f"Expected a value less than {max}, got {value}",
)
if (max := schema.get("maximum")) is not None:
@@ -231,7 +209,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH,
node,
schema,
f"Expected a value less than or equal to {max}, got {value}"
f"Expected a value less than or equal to {max}, got {value}",
)
if (min := schema.get("exclusiveMinimum")) is not None:
@@ -240,7 +218,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH,
node,
schema,
f"Expected a value greater than {min}, got {value}"
f"Expected a value greater than {min}, got {value}",
)
if (min := schema.get("minimum")) is not None:
@@ -249,7 +227,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH,
node,
schema,
f"Expected a value greater than or equal to {min}, got {value}"
f"Expected a value greater than or equal to {min}, got {value}",
)
def lint_document(self, node, schema=None) -> t.Iterable[str]:
@@ -271,10 +249,7 @@ class YamlLinter(object):
# This is the schema that rejects everything.
elif schema == False:
yield LintRecord(
LintLevel.UNEXPECTED,
node,
schema,
"Received an unexpected value"
LintLevel.UNEXPECTED, node, schema, "Received an unexpected value"
)
# Walking the PyYAML node hierarchy