This commit is contained in:
Reid 'arrdem' McKenzie 2021-05-15 11:34:32 -06:00
parent debc7996ee
commit bbae5ef63f
34 changed files with 956 additions and 886 deletions

View file

@ -25,7 +25,6 @@ setup(
"calf-read = calf.reader:main", "calf-read = calf.reader:main",
"calf-analyze = calf.analyzer:main", "calf-analyze = calf.analyzer:main",
"calf-compile = calf.compiler:main", "calf-compile = calf.compiler:main",
# Client/server stuff # Client/server stuff
"calf-client = calf.client:main", "calf-client = calf.client:main",
"calf-server = calf.server:main", "calf-server = calf.server:main",

View file

@ -7,7 +7,6 @@ from curses.textpad import Textbox, rectangle
def curse_repl(handle_buffer): def curse_repl(handle_buffer):
def handle(buff, count): def handle(buff, count):
try: try:
return list(handle_buffer(buff, count)), None return list(handle_buffer(buff, count)), None
@ -24,22 +23,25 @@ def curse_repl(handle_buffer):
maxy, maxx = stdscr.getmaxyx() maxy, maxx = stdscr.getmaxyx()
stdscr.clear() stdscr.clear()
stdscr.addstr(0, 0, "Enter example: (hit Ctrl-G to execute, Ctrl-C to exit)", curses.A_BOLD) stdscr.addstr(
editwin = curses.newwin(5, maxx - 4, 0,
2, 2) 0,
rectangle(stdscr, "Enter example: (hit Ctrl-G to execute, Ctrl-C to exit)",
1, 1, curses.A_BOLD,
1 + 5 + 1, maxx - 2) )
editwin = curses.newwin(5, maxx - 4, 2, 2)
rectangle(stdscr, 1, 1, 1 + 5 + 1, maxx - 2)
# Printing is part of the prompt # Printing is part of the prompt
cur = 8 cur = 8
def putstr(str, x=0, attr=0): def putstr(str, x=0, attr=0):
# ya rly. I know exactly what I'm doing here # ya rly. I know exactly what I'm doing here
nonlocal cur nonlocal cur
# This is how we handle going off the bottom of the scren lol # This is how we handle going off the bottom of the scren lol
if cur < maxy: if cur < maxy:
stdscr.addstr(cur, x, str, attr) stdscr.addstr(cur, x, str, attr)
cur += (len(str.split("\n")) or 1) cur += len(str.split("\n")) or 1
for ex, buff, vals, err in reversed(examples): for ex, buff, vals, err in reversed(examples):
putstr(f"Example {ex}:", attr=curses.A_BOLD) putstr(f"Example {ex}:", attr=curses.A_BOLD)
@ -58,7 +60,7 @@ def curse_repl(handle_buffer):
elif vals: elif vals:
putstr(" Values:") putstr(" Values:")
for x, t in zip(range(1, 1<<64), vals): for x, t in zip(range(1, 1 << 64), vals):
putstr(f" {x:>3}) " + repr(t)) putstr(f" {x:>3}) " + repr(t))
putstr("") putstr("")

View file

@ -28,34 +28,82 @@ COMMENT_PATTERN = r";(([^\n\r]*)(\n\r?)?)"
TOKENS = [ TOKENS = [
# Paren (noral) lists # Paren (noral) lists
(r"\(", "PAREN_LEFT",), (
(r"\)", "PAREN_RIGHT",), r"\(",
"PAREN_LEFT",
),
(
r"\)",
"PAREN_RIGHT",
),
# Bracket lists # Bracket lists
(r"\[", "BRACKET_LEFT",), (
(r"\]", "BRACKET_RIGHT",), r"\[",
"BRACKET_LEFT",
),
(
r"\]",
"BRACKET_RIGHT",
),
# Brace lists (maps) # Brace lists (maps)
(r"\{", "BRACE_LEFT",), (
(r"\}", "BRACE_RIGHT",), r"\{",
(r"\^", "META",), "BRACE_LEFT",
(r"'", "SINGLE_QUOTE",), ),
(STRING_PATTERN, "STRING",), (
(r"#", "MACRO_DISPATCH",), r"\}",
"BRACE_RIGHT",
),
(
r"\^",
"META",
),
(
r"'",
"SINGLE_QUOTE",
),
(
STRING_PATTERN,
"STRING",
),
(
r"#",
"MACRO_DISPATCH",
),
# Symbols # Symbols
(SYMBOL_PATTERN, "SYMBOL",), (
SYMBOL_PATTERN,
"SYMBOL",
),
# Numbers # Numbers
(SIMPLE_INTEGER, "INTEGER",), (
(FLOAT_PATTERN, "FLOAT",), SIMPLE_INTEGER,
"INTEGER",
),
(
FLOAT_PATTERN,
"FLOAT",
),
# Keywords # Keywords
# #
# Note: this is a dirty f'n hack in that in order for keywords to work, ":" # Note: this is a dirty f'n hack in that in order for keywords to work, ":"
# has to be defined to be a valid keyword. # has to be defined to be a valid keyword.
(r":" + SYMBOL_PATTERN + "?", "KEYWORD",), (
r":" + SYMBOL_PATTERN + "?",
"KEYWORD",
),
# Whitespace # Whitespace
# #
# Note that the whitespace token will contain at most one newline # Note that the whitespace token will contain at most one newline
(r"(\n\r?|[,\t ]*)", "WHITESPACE",), (
r"(\n\r?|[,\t ]*)",
"WHITESPACE",
),
# Comment # Comment
(COMMENT_PATTERN, "COMMENT",), (
COMMENT_PATTERN,
"COMMENT",
),
# Strings # Strings
(r'"(?P<body>(?:[^\"]|\.)*)"', "STRING"), (r'"(?P<body>(?:[^\"]|\.)*)"', "STRING"),
] ]

View file

@ -8,7 +8,6 @@ parsing, linting or other use.
import io import io
import re import re
import sys
from calf.token import CalfToken from calf.token import CalfToken
from calf.io.reader import PeekPosReader from calf.io.reader import PeekPosReader

View file

@ -12,48 +12,47 @@ from collections import namedtuple
class CalfLoaderConfig(namedtuple("CalfLoaderConfig", ["paths"])): class CalfLoaderConfig(namedtuple("CalfLoaderConfig", ["paths"])):
""" """"""
"""
class CalfDelayedPackage( class CalfDelayedPackage(
namedtuple("CalfDelayedPackage", ["name", "version", "metadata", "path"]) namedtuple("CalfDelayedPackage", ["name", "version", "metadata", "path"])
): ):
""" """
This structure represents the delay of loading a packaage. This structure represents the delay of loading a packaage.
Rather than eagerly analyze packages, it may be profitable to use lazy loading / lazy resolution Rather than eagerly analyze packages, it may be profitable to use lazy loading / lazy resolution
of symbols. It may also be possible to cache analyzing some packages. of symbols. It may also be possible to cache analyzing some packages.
""" """
class CalfPackage( class CalfPackage(
namedtuple("CalfPackage", ["name", "version", "metadata", "modules"]) namedtuple("CalfPackage", ["name", "version", "metadata", "modules"])
): ):
""" """
This structure represents the result of forcing the load of a package, and is the product of This structure represents the result of forcing the load of a package, and is the product of
either loading a package directly, or a package becoming a direct dependency and being forced. either loading a package directly, or a package becoming a direct dependency and being forced.
""" """
def parse_package_requirement(config, env, requirement): def parse_package_requirement(config, env, requirement):
""" """
:param config: :param config:
:param env: :param env:
:param requirement: :param requirement:
:returns: :returns:
""" """
def analyze_package(config, env, package): def analyze_package(config, env, package):
""" """
:param config: :param config:
:param env: :param env:
:param module: :param module:
:returns: :returns:
Given a loader configuration and an environment to load into, analyzes the requested package, Given a loader configuration and an environment to load into, analyzes the requested package,
returning an updated environment. returning an updated environment.
""" """

View file

@ -2,11 +2,8 @@
The Calf parser. The Calf parser.
""" """
from collections import namedtuple
from itertools import tee from itertools import tee
import logging import logging
import sys
from typing import NamedTuple, Callable
from calf.lexer import CalfLexer, lex_buffer, lex_file from calf.lexer import CalfLexer, lex_buffer, lex_file
from calf.grammar import MATCHING, WHITESPACE_TYPES from calf.grammar import MATCHING, WHITESPACE_TYPES
@ -45,17 +42,18 @@ def mk_dict(contents, open=None, close=None):
close.start_position, close.start_position,
) )
def mk_str(token): def mk_str(token):
buff = token.value buff = token.value
if buff.startswith('"""') and not buff.endswith('"""'): if buff.startswith('"""') and not buff.endswith('"""'):
raise ValueError('Unterminated tripple quote string') raise ValueError("Unterminated tripple quote string")
elif buff.startswith('"') and not buff.endswith('"'): elif buff.startswith('"') and not buff.endswith('"'):
raise ValueError('Unterminated quote string') raise ValueError("Unterminated quote string")
elif not buff.startswith('"') or buff == '"' or buff == '"""': elif not buff.startswith('"') or buff == '"' or buff == '"""':
raise ValueError('Illegal string') raise ValueError("Illegal string")
if buff.startswith('"""'): if buff.startswith('"""'):
buff = buff[3:-3] buff = buff[3:-3]
@ -114,15 +112,17 @@ class CalfMissingCloseParseError(CalfParseError):
def __init__(self, expected_close_token, open_token): def __init__(self, expected_close_token, open_token):
super(CalfMissingCloseParseError, self).__init__( super(CalfMissingCloseParseError, self).__init__(
f"expected {expected_close_token} starting from {open_token}, got end of file.", f"expected {expected_close_token} starting from {open_token}, got end of file.",
open_token open_token,
) )
self.expected_close_token = expected_close_token self.expected_close_token = expected_close_token
def parse_stream(stream, def parse_stream(
discard_whitespace: bool = True, stream,
discard_comments: bool = True, discard_whitespace: bool = True,
stack: list = None): discard_comments: bool = True,
stack: list = None,
):
"""Parses a token stream, producing a lazy sequence of all read top level forms. """Parses a token stream, producing a lazy sequence of all read top level forms.
If `discard_whitespace` is truthy, then no WHITESPACE tokens will be emitted If `discard_whitespace` is truthy, then no WHITESPACE tokens will be emitted
@ -134,11 +134,10 @@ def parse_stream(stream,
stack = stack or [] stack = stack or []
def recur(_stack = None): def recur(_stack=None):
yield from parse_stream(stream, yield from parse_stream(
discard_whitespace, stream, discard_whitespace, discard_comments, _stack or stack
discard_comments, )
_stack or stack)
for token in stream: for token in stream:
# Whitespace discarding # Whitespace discarding
@ -205,7 +204,9 @@ def parse_stream(stream,
# Case of maybe matching something else, but definitely being wrong # Case of maybe matching something else, but definitely being wrong
else: else:
matching = next(reversed([t[1] for t in stack if t[0] == token.type]), None) matching = next(
reversed([t[1] for t in stack if t[0] == token.type]), None
)
raise CalfUnexpectedCloseParseError(token, matching) raise CalfUnexpectedCloseParseError(token, matching)
# Atoms # Atoms
@ -216,18 +217,14 @@ def parse_stream(stream,
yield token yield token
def parse_buffer(buffer, def parse_buffer(buffer, discard_whitespace=True, discard_comments=True):
discard_whitespace=True,
discard_comments=True):
""" """
Parses a buffer, producing a lazy sequence of all parsed level forms. Parses a buffer, producing a lazy sequence of all parsed level forms.
Propagates all errors. Propagates all errors.
""" """
yield from parse_stream(lex_buffer(buffer), yield from parse_stream(lex_buffer(buffer), discard_whitespace, discard_comments)
discard_whitespace,
discard_comments)
def parse_file(file): def parse_file(file):

View file

@ -13,6 +13,7 @@ from calf.parser import parse_stream
from calf.token import * from calf.token import *
from calf.types import * from calf.types import *
class CalfReader(object): class CalfReader(object):
def handle_keyword(self, t: CalfToken) -> Any: def handle_keyword(self, t: CalfToken) -> Any:
"""Convert a token to an Object value for a symbol. """Convert a token to an Object value for a symbol.
@ -79,8 +80,7 @@ class CalfReader(object):
return Vector.of(self.read(t.value)) return Vector.of(self.read(t.value))
elif isinstance(t, CalfDictToken): elif isinstance(t, CalfDictToken):
return Map.of([(self.read1(k), self.read1(v)) return Map.of([(self.read1(k), self.read1(v)) for k, v in t.items()])
for k, v in t.items()])
# Magical pairwise stuff # Magical pairwise stuff
elif isinstance(t, CalfQuoteToken): elif isinstance(t, CalfQuoteToken):
@ -119,28 +119,21 @@ class CalfReader(object):
yield self.read1(t) yield self.read1(t)
def read_stream(stream, def read_stream(stream, reader: CalfReader = None):
reader: CalfReader = None): """Read from a stream of parsed tokens."""
"""Read from a stream of parsed tokens.
"""
reader = reader or CalfReader() reader = reader or CalfReader()
yield from reader.read(stream) yield from reader.read(stream)
def read_buffer(buffer): def read_buffer(buffer):
"""Read from a buffer, producing a lazy sequence of all top level forms. """Read from a buffer, producing a lazy sequence of all top level forms."""
"""
yield from read_stream(parse_stream(lex_buffer(buffer))) yield from read_stream(parse_stream(lex_buffer(buffer)))
def read_file(file): def read_file(file):
"""Read from a file, producing a lazy sequence of all top level forms. """Read from a file, producing a lazy sequence of all top level forms."""
"""
yield from read_stream(parse_stream(lex_file(file))) yield from read_stream(parse_stream(lex_file(file)))
@ -151,6 +144,8 @@ def main():
from calf.cursedrepl import curse_repl from calf.cursedrepl import curse_repl
def handle_buffer(buff, count): def handle_buffer(buff, count):
return list(read_stream(parse_stream(lex_buffer(buff, source=f"<Example {count}>")))) return list(
read_stream(parse_stream(lex_buffer(buff, source=f"<Example {count}>")))
)
curse_repl(handle_buffer) curse_repl(handle_buffer)

View file

@ -8,23 +8,24 @@ from calf import grammar as cg
from conftest import parametrize from conftest import parametrize
@parametrize('ex', [ @parametrize(
# Proper strings "ex",
'""', [
'"foo bar"', # Proper strings
'"foo\n bar\n\r qux"', '""',
'"foo\\"bar"', '"foo bar"',
'"foo\n bar\n\r qux"',
'""""""', '"foo\\"bar"',
'"""foo bar baz"""', '""""""',
'"""foo "" "" "" bar baz"""', '"""foo bar baz"""',
'"""foo "" "" "" bar baz"""',
# Unterminated string cases # Unterminated string cases
'"', '"',
'"f', '"f',
'"foo bar', '"foo bar',
'"foo\\" bar', '"foo\\" bar',
'"""foo bar baz', '"""foo bar baz',
]) ],
)
def test_match_string(ex): def test_match_string(ex):
assert re.fullmatch(cg.STRING_PATTERN, ex) assert re.fullmatch(cg.STRING_PATTERN, ex)

View file

@ -20,23 +20,62 @@ def lex_single_token(buffer):
@parametrize( @parametrize(
"text,token_type", "text,token_type",
[ [
("(", "PAREN_LEFT",), (
(")", "PAREN_RIGHT",), "(",
("[", "BRACKET_LEFT",), "PAREN_LEFT",
("]", "BRACKET_RIGHT",), ),
("{", "BRACE_LEFT",), (
("}", "BRACE_RIGHT",), ")",
("^", "META",), "PAREN_RIGHT",
("#", "MACRO_DISPATCH",), ),
(
"[",
"BRACKET_LEFT",
),
(
"]",
"BRACKET_RIGHT",
),
(
"{",
"BRACE_LEFT",
),
(
"}",
"BRACE_RIGHT",
),
(
"^",
"META",
),
(
"#",
"MACRO_DISPATCH",
),
("'", "SINGLE_QUOTE"), ("'", "SINGLE_QUOTE"),
("foo", "SYMBOL",), (
"foo",
"SYMBOL",
),
("foo/bar", "SYMBOL"), ("foo/bar", "SYMBOL"),
(":foo", "KEYWORD",), (
(":foo/bar", "KEYWORD",), ":foo",
(" ,,\t ,, \t", "WHITESPACE",), "KEYWORD",
),
(
":foo/bar",
"KEYWORD",
),
(
" ,,\t ,, \t",
"WHITESPACE",
),
("\n\r", "WHITESPACE"), ("\n\r", "WHITESPACE"),
("\n", "WHITESPACE"), ("\n", "WHITESPACE"),
(" , ", "WHITESPACE",), (
" , ",
"WHITESPACE",
),
("; this is a sample comment\n", "COMMENT"), ("; this is a sample comment\n", "COMMENT"),
('"foo"', "STRING"), ('"foo"', "STRING"),
('"foo bar baz"', "STRING"), ('"foo bar baz"', "STRING"),

View file

@ -8,12 +8,15 @@ from conftest import parametrize
import pytest import pytest
@parametrize("text", [ @parametrize(
'"', "text",
'"foo bar', [
'"""foo bar', '"',
'"""foo bar"', '"foo bar',
]) '"""foo bar',
'"""foo bar"',
],
)
def test_bad_strings_raise(text): def test_bad_strings_raise(text):
"""Tests asserting we won't let obviously bad strings fly.""" """Tests asserting we won't let obviously bad strings fly."""
# FIXME (arrdem 2021-03-13): # FIXME (arrdem 2021-03-13):
@ -22,81 +25,89 @@ def test_bad_strings_raise(text):
next(cp.parse_buffer(text)) next(cp.parse_buffer(text))
@parametrize("text", [ @parametrize(
"[1.0", "text",
"(1.0", [
"{1.0", "[1.0",
]) "(1.0",
"{1.0",
],
)
def test_unterminated_raises(text): def test_unterminated_raises(text):
"""Tests asserting that we don't let unterminated collections parse.""" """Tests asserting that we don't let unterminated collections parse."""
with pytest.raises(cp.CalfMissingCloseParseError): with pytest.raises(cp.CalfMissingCloseParseError):
next(cp.parse_buffer(text)) next(cp.parse_buffer(text))
@parametrize("text", [ @parametrize(
"[{]", "text",
"[(]", [
"({)", "[{]",
"([)", "[(]",
"{(}", "({)",
"{[}", "([)",
]) "{(}",
"{[}",
],
)
def test_unbalanced_raises(text): def test_unbalanced_raises(text):
"""Tests asserting that we don't let missmatched collections parse.""" """Tests asserting that we don't let missmatched collections parse."""
with pytest.raises(cp.CalfUnexpectedCloseParseError): with pytest.raises(cp.CalfUnexpectedCloseParseError):
next(cp.parse_buffer(text)) next(cp.parse_buffer(text))
@parametrize("buff, value", [ @parametrize(
('"foo"', "foo"), "buff, value",
('"foo\tbar"', "foo\tbar"), [
('"foo\n\rbar"', "foo\n\rbar"), ('"foo"', "foo"),
('"foo\\"bar\\""', "foo\"bar\""), ('"foo\tbar"', "foo\tbar"),
('"""foo"""', 'foo'), ('"foo\n\rbar"', "foo\n\rbar"),
('"""foo"bar"baz"""', 'foo"bar"baz'), ('"foo\\"bar\\""', 'foo"bar"'),
]) ('"""foo"""', "foo"),
('"""foo"bar"baz"""', 'foo"bar"baz'),
],
)
def test_strings_round_trip(buff, value): def test_strings_round_trip(buff, value):
assert next(cp.parse_buffer(buff)) == value assert next(cp.parse_buffer(buff)) == value
@parametrize('text, element_types', [
# Integers
("(1)", ["INTEGER"]),
("( 1 )", ["INTEGER"]),
("(,1,)", ["INTEGER"]),
("(1\n)", ["INTEGER"]),
("(\n1\n)", ["INTEGER"]),
("(1, 2, 3, 4)", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]),
# Floats @parametrize(
("(1.0)", ["FLOAT"]), "text, element_types",
("(1.0e0)", ["FLOAT"]), [
("(1e0)", ["FLOAT"]), # Integers
("(1e0)", ["FLOAT"]), ("(1)", ["INTEGER"]),
("( 1 )", ["INTEGER"]),
# Symbols ("(,1,)", ["INTEGER"]),
("(foo)", ["SYMBOL"]), ("(1\n)", ["INTEGER"]),
("(+)", ["SYMBOL"]), ("(\n1\n)", ["INTEGER"]),
("(-)", ["SYMBOL"]), ("(1, 2, 3, 4)", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]),
("(*)", ["SYMBOL"]), # Floats
("(foo-bar)", ["SYMBOL"]), ("(1.0)", ["FLOAT"]),
("(+foo-bar+)", ["SYMBOL"]), ("(1.0e0)", ["FLOAT"]),
("(+foo-bar+)", ["SYMBOL"]), ("(1e0)", ["FLOAT"]),
("( foo bar )", ["SYMBOL", "SYMBOL"]), ("(1e0)", ["FLOAT"]),
# Symbols
# Keywords ("(foo)", ["SYMBOL"]),
("(:foo)", ["KEYWORD"]), ("(+)", ["SYMBOL"]),
("( :foo )", ["KEYWORD"]), ("(-)", ["SYMBOL"]),
("(\n:foo\n)", ["KEYWORD"]), ("(*)", ["SYMBOL"]),
("(,:foo,)", ["KEYWORD"]), ("(foo-bar)", ["SYMBOL"]),
("(:foo :bar)", ["KEYWORD", "KEYWORD"]), ("(+foo-bar+)", ["SYMBOL"]),
("(:foo :bar 1)", ["KEYWORD", "KEYWORD", "INTEGER"]), ("(+foo-bar+)", ["SYMBOL"]),
("( foo bar )", ["SYMBOL", "SYMBOL"]),
# Strings # Keywords
('("foo", "bar", "baz")', ["STRING", "STRING", "STRING"]), ("(:foo)", ["KEYWORD"]),
("( :foo )", ["KEYWORD"]),
# Lists ("(\n:foo\n)", ["KEYWORD"]),
('([] [] ())', ["SQLIST", "SQLIST", "LIST"]), ("(,:foo,)", ["KEYWORD"]),
]) ("(:foo :bar)", ["KEYWORD", "KEYWORD"]),
("(:foo :bar 1)", ["KEYWORD", "KEYWORD", "INTEGER"]),
# Strings
('("foo", "bar", "baz")', ["STRING", "STRING", "STRING"]),
# Lists
("([] [] ())", ["SQLIST", "SQLIST", "LIST"]),
],
)
def test_parse_list(text, element_types): def test_parse_list(text, element_types):
"""Test we can parse various lists of contents.""" """Test we can parse various lists of contents."""
l_t = next(cp.parse_buffer(text, discard_whitespace=True)) l_t = next(cp.parse_buffer(text, discard_whitespace=True))
@ -104,45 +115,43 @@ def test_parse_list(text, element_types):
assert [t.type for t in l_t] == element_types assert [t.type for t in l_t] == element_types
@parametrize('text, element_types', [ @parametrize(
# Integers "text, element_types",
("[1]", ["INTEGER"]), [
("[ 1 ]", ["INTEGER"]), # Integers
("[,1,]", ["INTEGER"]), ("[1]", ["INTEGER"]),
("[1\n]", ["INTEGER"]), ("[ 1 ]", ["INTEGER"]),
("[\n1\n]", ["INTEGER"]), ("[,1,]", ["INTEGER"]),
("[1, 2, 3, 4]", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]), ("[1\n]", ["INTEGER"]),
("[\n1\n]", ["INTEGER"]),
# Floats ("[1, 2, 3, 4]", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]),
("[1.0]", ["FLOAT"]), # Floats
("[1.0e0]", ["FLOAT"]), ("[1.0]", ["FLOAT"]),
("[1e0]", ["FLOAT"]), ("[1.0e0]", ["FLOAT"]),
("[1e0]", ["FLOAT"]), ("[1e0]", ["FLOAT"]),
("[1e0]", ["FLOAT"]),
# Symbols # Symbols
("[foo]", ["SYMBOL"]), ("[foo]", ["SYMBOL"]),
("[+]", ["SYMBOL"]), ("[+]", ["SYMBOL"]),
("[-]", ["SYMBOL"]), ("[-]", ["SYMBOL"]),
("[*]", ["SYMBOL"]), ("[*]", ["SYMBOL"]),
("[foo-bar]", ["SYMBOL"]), ("[foo-bar]", ["SYMBOL"]),
("[+foo-bar+]", ["SYMBOL"]), ("[+foo-bar+]", ["SYMBOL"]),
("[+foo-bar+]", ["SYMBOL"]), ("[+foo-bar+]", ["SYMBOL"]),
("[ foo bar ]", ["SYMBOL", "SYMBOL"]), ("[ foo bar ]", ["SYMBOL", "SYMBOL"]),
# Keywords
# Keywords ("[:foo]", ["KEYWORD"]),
("[:foo]", ["KEYWORD"]), ("[ :foo ]", ["KEYWORD"]),
("[ :foo ]", ["KEYWORD"]), ("[\n:foo\n]", ["KEYWORD"]),
("[\n:foo\n]", ["KEYWORD"]), ("[,:foo,]", ["KEYWORD"]),
("[,:foo,]", ["KEYWORD"]), ("[:foo :bar]", ["KEYWORD", "KEYWORD"]),
("[:foo :bar]", ["KEYWORD", "KEYWORD"]), ("[:foo :bar 1]", ["KEYWORD", "KEYWORD", "INTEGER"]),
("[:foo :bar 1]", ["KEYWORD", "KEYWORD", "INTEGER"]), # Strings
('["foo", "bar", "baz"]', ["STRING", "STRING", "STRING"]),
# Strings # Lists
('["foo", "bar", "baz"]', ["STRING", "STRING", "STRING"]), ("[[] [] ()]", ["SQLIST", "SQLIST", "LIST"]),
],
# Lists )
('[[] [] ()]', ["SQLIST", "SQLIST", "LIST"]),
])
def test_parse_sqlist(text, element_types): def test_parse_sqlist(text, element_types):
"""Test we can parse various 'square' lists of contents.""" """Test we can parse various 'square' lists of contents."""
l_t = next(cp.parse_buffer(text, discard_whitespace=True)) l_t = next(cp.parse_buffer(text, discard_whitespace=True))
@ -150,41 +159,21 @@ def test_parse_sqlist(text, element_types):
assert [t.type for t in l_t] == element_types assert [t.type for t in l_t] == element_types
@parametrize('text, element_pairs', [ @parametrize(
("{}", "text, element_pairs",
[]), [
("{}", []),
("{:foo 1}", ("{:foo 1}", [["KEYWORD", "INTEGER"]]),
[["KEYWORD", "INTEGER"]]), ("{:foo 1, :bar 2}", [["KEYWORD", "INTEGER"], ["KEYWORD", "INTEGER"]]),
("{foo 1, bar 2}", [["SYMBOL", "INTEGER"], ["SYMBOL", "INTEGER"]]),
("{:foo 1, :bar 2}", ("{foo 1, bar -2}", [["SYMBOL", "INTEGER"], ["SYMBOL", "INTEGER"]]),
[["KEYWORD", "INTEGER"], ("{foo 1, bar -2e0}", [["SYMBOL", "INTEGER"], ["SYMBOL", "FLOAT"]]),
["KEYWORD", "INTEGER"]]), ("{foo ()}", [["SYMBOL", "LIST"]]),
("{foo []}", [["SYMBOL", "SQLIST"]]),
("{foo 1, bar 2}", ("{foo {}}", [["SYMBOL", "DICT"]]),
[["SYMBOL", "INTEGER"], ('{"foo" {}}', [["STRING", "DICT"]]),
["SYMBOL", "INTEGER"]]), ],
)
("{foo 1, bar -2}",
[["SYMBOL", "INTEGER"],
["SYMBOL", "INTEGER"]]),
("{foo 1, bar -2e0}",
[["SYMBOL", "INTEGER"],
["SYMBOL", "FLOAT"]]),
("{foo ()}",
[["SYMBOL", "LIST"]]),
("{foo []}",
[["SYMBOL", "SQLIST"]]),
("{foo {}}",
[["SYMBOL", "DICT"]]),
('{"foo" {}}',
[["STRING", "DICT"]])
])
def test_parse_dict(text, element_pairs): def test_parse_dict(text, element_pairs):
"""Test we can parse various mappings.""" """Test we can parse various mappings."""
d_t = next(cp.parse_buffer(text, discard_whitespace=True)) d_t = next(cp.parse_buffer(text, discard_whitespace=True))
@ -192,27 +181,25 @@ def test_parse_dict(text, element_pairs):
assert [[t.type for t in pair] for pair in d_t.value] == element_pairs assert [[t.type for t in pair] for pair in d_t.value] == element_pairs
@parametrize("text", [ @parametrize("text", ["{1}", "{1, 2, 3}", "{:foo}", "{:foo :bar :baz}"])
"{1}",
"{1, 2, 3}",
"{:foo}",
"{:foo :bar :baz}"
])
def test_parse_bad_dict(text): def test_parse_bad_dict(text):
"""Assert that dicts with missmatched pairs don't parse.""" """Assert that dicts with missmatched pairs don't parse."""
with pytest.raises(Exception): with pytest.raises(Exception):
next(cp.parse_buffer(text)) next(cp.parse_buffer(text))
@parametrize("text", [ @parametrize(
"()", "text",
"(1 1.1 1e2 -2 foo :foo foo/bar :foo/bar [{},])", [
"{:foo bar, :baz [:qux]}", "()",
"'foo", "(1 1.1 1e2 -2 foo :foo foo/bar :foo/bar [{},])",
"'[foo bar :baz 'qux, {}]", "{:foo bar, :baz [:qux]}",
"#foo []", "'foo",
"^{} bar", "'[foo bar :baz 'qux, {}]",
]) "#foo []",
"^{} bar",
],
)
def test_examples(text): def test_examples(text):
"""Shotgun examples showing we can parse some stuff.""" """Shotgun examples showing we can parse some stuff."""

View file

@ -5,18 +5,22 @@ from conftest import parametrize
from calf.reader import read_buffer from calf.reader import read_buffer
@parametrize('text', [
"()", @parametrize(
"[]", "text",
"[[[[[[[[[]]]]]]]]]", [
"{1 {2 {}}}", "()",
'"foo"', "[]",
"foo", "[[[[[[[[[]]]]]]]]]",
"'foo", "{1 {2 {}}}",
"^foo bar", '"foo"',
"^:foo bar", "foo",
"{\"foo\" '([:bar ^:foo 'baz 3.14159e0])}", "'foo",
"[:foo bar 'baz lo/l, 1, 1.2. 1e-5 -1e2]", "^foo bar",
]) "^:foo bar",
"{\"foo\" '([:bar ^:foo 'baz 3.14159e0])}",
"[:foo bar 'baz lo/l, 1, 1.2. 1e-5 -1e2]",
],
)
def test_read(text): def test_read(text):
assert list(read_buffer(text)) assert list(read_buffer(text))

View file

@ -58,13 +58,13 @@ from datalog.debris import Timing
from datalog.evaluator import select from datalog.evaluator import select
from datalog.reader import pr_str, read_command, read_dataset from datalog.reader import pr_str, read_command, read_dataset
from datalog.types import ( from datalog.types import (
CachedDataset, CachedDataset,
Constant, Constant,
Dataset, Dataset,
LVar, LVar,
PartlyIndexedDataset, PartlyIndexedDataset,
Rule, Rule,
TableIndexedDataset TableIndexedDataset,
) )
from prompt_toolkit import print_formatted_text, prompt, PromptSession from prompt_toolkit import print_formatted_text, prompt, PromptSession
@ -74,190 +74,204 @@ from prompt_toolkit.styles import Style
from yaspin import Spinner, yaspin from yaspin import Spinner, yaspin
STYLE = Style.from_dict({ STYLE = Style.from_dict(
# User input (default text). {
"": "", # User input (default text).
"prompt": "ansigreen", "": "",
"time": "ansiyellow" "prompt": "ansigreen",
}) "time": "ansiyellow",
}
)
SPINNER = Spinner(["|", "/", "-", "\\"], 200) SPINNER = Spinner(["|", "/", "-", "\\"], 200)
class InterpreterInterrupt(Exception): class InterpreterInterrupt(Exception):
"""An exception used to break the prompt or evaluation.""" """An exception used to break the prompt or evaluation."""
def print_(fmt, **kwargs): def print_(fmt, **kwargs):
print_formatted_text(FormattedText(fmt), **kwargs) print_formatted_text(FormattedText(fmt), **kwargs)
def print_db(db): def print_db(db):
"""Render a database for debugging.""" """Render a database for debugging."""
for e in db.tuples(): for e in db.tuples():
print(f"{pr_str(e)}") print(f"{pr_str(e)}")
for r in db.rules(): for r in db.rules():
print(f"{pr_str(r)}") print(f"{pr_str(r)}")
def main(args): def main(args):
"""REPL entry point.""" """REPL entry point."""
if args.db_cls == "simple": if args.db_cls == "simple":
db_cls = Dataset db_cls = Dataset
elif args.db_cls == "cached": elif args.db_cls == "cached":
db_cls = CachedDataset db_cls = CachedDataset
elif args.db_cls == "table": elif args.db_cls == "table":
db_cls = TableIndexedDataset db_cls = TableIndexedDataset
elif args.db_cls == "partly": elif args.db_cls == "partly":
db_cls = PartlyIndexedDataset db_cls = PartlyIndexedDataset
print(f"Using dataset type {db_cls}") print(f"Using dataset type {db_cls}")
session = PromptSession(history=FileHistory(".datalog.history")) session = PromptSession(history=FileHistory(".datalog.history"))
db = db_cls([], []) db = db_cls([], [])
if args.dbs: if args.dbs:
for db_file in args.dbs: for db_file in args.dbs:
try: try:
with open(db_file, "r") as f: with open(db_file, "r") as f:
db = db.merge(read_dataset(f.read())) db = db.merge(read_dataset(f.read()))
print(f"Loaded {db_file} ...") print(f"Loaded {db_file} ...")
except Exception as e: except Exception as e:
print("Internal error - {e}") print("Internal error - {e}")
print(f"Unable to load db {db_file}, skipping") print(f"Unable to load db {db_file}, skipping")
while True: while True:
try: try:
line = session.prompt([("class:prompt", ">>> ")], style=STYLE) line = session.prompt([("class:prompt", ">>> ")], style=STYLE)
except (InterpreterInterrupt, KeyboardInterrupt): except (InterpreterInterrupt, KeyboardInterrupt):
continue continue
except EOFError: except EOFError:
break break
if line == ".all": if line == ".all":
op = ".all" op = ".all"
elif line == ".dbg": elif line == ".dbg":
op = ".dbg" op = ".dbg"
elif line == ".quit": elif line == ".quit":
break break
elif line in {".help", "help", "?", "??", "???"}: elif line in {".help", "help", "?", "??", "???"}:
print(__doc__) print(__doc__)
continue
elif line.split(" ")[0] == ".log":
op = ".log"
else:
try:
op, val = read_command(line)
except Exception as e:
print(f"Got an unknown command or syntax error, can't tell which")
continue
# Definition merges on the DB
if op == ".all":
print_db(db)
# .dbg drops to a debugger shell so you can poke at the instance objects (database)
elif op == ".dbg":
import pdb
pdb.set_trace()
# .log sets the log level - badly
elif op == ".log":
level = line.split(" ")[1].upper()
try:
ch.setLevel(getattr(logging, level))
except BaseException:
print(f"Unknown log level {level}")
elif op == ".":
# FIXME (arrdem 2019-06-15):
# Syntax rules the parser doesn't impose...
try:
for rule in val.rules():
assert not rule.free_vars, f"Rule contains free variables {rule.free_vars!r}"
for tuple in val.tuples():
assert not any(isinstance(e, LVar) for e in tuple), f"Tuples cannot contain lvars - {tuple!r}"
except BaseException as e:
print(f"Error: {e}")
continue
db = db.merge(val)
print_db(val)
# Queries execute - note that rules as queries have to be temporarily merged.
elif op == "?":
# In order to support ad-hoc rules (joins), we have to generate a transient "query" database
# by bolting the rule on as an overlay to the existing database. If of course we have a join.
#
# `val` was previously assumed to be the query pattern. Introduce `qdb`, now used as the
# database to query and "fix" `val` to be the temporary rule's pattern.
#
# We use a new db and db local so that the ephemeral rule doesn't persist unless the user
# later `.` defines it.
#
# Unfortunately doing this merge does nuke caches.
qdb = db
if isinstance(val, Rule):
qdb = db.merge(db_cls([], [val]))
val = val.pattern
with yaspin(SPINNER) as spinner:
with Timing() as t:
try:
results = list(select(qdb, val))
except KeyboardInterrupt:
print(f"Evaluation aborted after {t}")
continue continue
# It's kinda bogus to move sorting out but oh well elif line.split(" ")[0] == ".log":
sorted(results) op = ".log"
for _results, _bindings in results: else:
_result = _results[0] # select only selects one tuple at a time try:
print(f"{pr_str(_result)}") op, val = read_command(line)
except Exception as e:
print(f"Got an unknown command or syntax error, can't tell which")
continue
# So we can report empty sets explicitly. # Definition merges on the DB
if not results: if op == ".all":
print("⇒ Ø") print_db(db)
print_([("class:time", f"Elapsed time - {t}")], style=STYLE) # .dbg drops to a debugger shell so you can poke at the instance objects (database)
elif op == ".dbg":
import pdb
# Retractions try to delete, but may fail. pdb.set_trace()
elif op == "!":
if val in db.tuples() or val in [r.pattern for r in db.rules()]: # .log sets the log level - badly
db = db_cls([u for u in db.tuples() if u != val], elif op == ".log":
[r for r in db.rules() if r.pattern != val]) level = line.split(" ")[1].upper()
print(f"{pr_str(val)}") try:
else: ch.setLevel(getattr(logging, level))
print("⇒ Ø") except BaseException:
print(f"Unknown log level {level}")
elif op == ".":
# FIXME (arrdem 2019-06-15):
# Syntax rules the parser doesn't impose...
try:
for rule in val.rules():
assert (
not rule.free_vars
), f"Rule contains free variables {rule.free_vars!r}"
for tuple in val.tuples():
assert not any(
isinstance(e, LVar) for e in tuple
), f"Tuples cannot contain lvars - {tuple!r}"
except BaseException as e:
print(f"Error: {e}")
continue
db = db.merge(val)
print_db(val)
# Queries execute - note that rules as queries have to be temporarily merged.
elif op == "?":
# In order to support ad-hoc rules (joins), we have to generate a transient "query" database
# by bolting the rule on as an overlay to the existing database. If of course we have a join.
#
# `val` was previously assumed to be the query pattern. Introduce `qdb`, now used as the
# database to query and "fix" `val` to be the temporary rule's pattern.
#
# We use a new db and db local so that the ephemeral rule doesn't persist unless the user
# later `.` defines it.
#
# Unfortunately doing this merge does nuke caches.
qdb = db
if isinstance(val, Rule):
qdb = db.merge(db_cls([], [val]))
val = val.pattern
with yaspin(SPINNER) as spinner:
with Timing() as t:
try:
results = list(select(qdb, val))
except KeyboardInterrupt:
print(f"Evaluation aborted after {t}")
continue
# It's kinda bogus to move sorting out but oh well
sorted(results)
for _results, _bindings in results:
_result = _results[0] # select only selects one tuple at a time
print(f"{pr_str(_result)}")
# So we can report empty sets explicitly.
if not results:
print("⇒ Ø")
print_([("class:time", f"Elapsed time - {t}")], style=STYLE)
# Retractions try to delete, but may fail.
elif op == "!":
if val in db.tuples() or val in [r.pattern for r in db.rules()]:
db = db_cls(
[u for u in db.tuples() if u != val],
[r for r in db.rules() if r.pattern != val],
)
print(f"{pr_str(val)}")
else:
print("⇒ Ø")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Select which dataset type to use # Select which dataset type to use
parser.add_argument("--db-type", parser.add_argument(
choices=["simple", "cached", "table", "partly"], "--db-type",
help="Choose which DB to use (default partly)", choices=["simple", "cached", "table", "partly"],
dest="db_cls", help="Choose which DB to use (default partly)",
default="partly") dest="db_cls",
default="partly",
)
parser.add_argument("--load-db", dest="dbs", action="append", parser.add_argument(
help="Datalog files to load first.") "--load-db", dest="dbs", action="append", help="Datalog files to load first."
)
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
logger = logging.getLogger("arrdem.datalog") logger = logging.getLogger("arrdem.datalog")
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") formatter = logging.Formatter(
ch.setFormatter(formatter) "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logger.addHandler(ch) )
main(args) ch.setFormatter(formatter)
logger.addHandler(ch)
main(args)

View file

@ -23,10 +23,7 @@ setup(
"Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
], ],
scripts=["bin/datalog"],
scripts=[
"bin/datalog"
],
install_requires=[ install_requires=[
"arrdem.datalog~=2.0.0", "arrdem.datalog~=2.0.0",
"prompt_toolkit==2.0.9", "prompt_toolkit==2.0.9",

View file

@ -9,17 +9,17 @@ from uuid import uuid4 as uuid
with open("graph.dtl", "w") as f: with open("graph.dtl", "w") as f:
nodes = [] nodes = []
# Generate 10k edges # Generate 10k edges
for i in range(10000): for i in range(10000):
if nodes: if nodes:
from_node = choice(nodes) from_node = choice(nodes)
else: else:
from_node = uuid() from_node = uuid()
to_node = uuid() to_node = uuid()
nodes.append(to_node) nodes.append(to_node)
f.write(f"edge({str(from_node)!r}, {str(to_node)!r}).\n") f.write(f"edge({str(from_node)!r}, {str(to_node)!r}).\n")

View file

@ -26,5 +26,7 @@ setup(
], ],
# Package setup # Package setup
package_dir={"": "src/python"}, package_dir={"": "src/python"},
packages=["datalog",], packages=[
"datalog",
],
) )

View file

@ -16,8 +16,8 @@ def constexpr_p(expr):
class Timing(object): class Timing(object):
""" """
A context manager object which records how long the context took. A context manager object which records how long the context took.
""" """
def __init__(self): def __init__(self):
self.start = None self.start = None
@ -36,8 +36,8 @@ class Timing(object):
def __call__(self): def __call__(self):
""" """
If the context is exited, return its duration. Otherwise return the duration "so far". If the context is exited, return its duration. Otherwise return the duration "so far".
""" """
from datetime import datetime from datetime import datetime

View file

@ -22,9 +22,9 @@ def read(text: str, db_cls=PartlyIndexedDataset):
def q(t: Tuple[str]) -> LTuple: def q(t: Tuple[str]) -> LTuple:
"""Helper for writing terse queries. """Helper for writing terse queries.
Takes a tuple of strings, and interprets them as a logic tuple. Takes a tuple of strings, and interprets them as a logic tuple.
So you don't have to write the logic tuple out by hand. So you don't have to write the logic tuple out by hand.
""" """
def _x(s: str): def _x(s: str):
if s[0].isupper(): if s[0].isupper():
@ -50,38 +50,38 @@ def __result(results_bindings):
def select(db: Dataset, query: Tuple[str], bindings=None) -> Sequence[Tuple]: def select(db: Dataset, query: Tuple[str], bindings=None) -> Sequence[Tuple]:
"""Helper for interpreting tuples of strings as a query, and returning simplified results. """Helper for interpreting tuples of strings as a query, and returning simplified results.
Executes your query, returning matching full tuples. Executes your query, returning matching full tuples.
""" """
return __mapv(__result, __select(db, q(query), bindings=bindings)) return __mapv(__result, __select(db, q(query), bindings=bindings))
def join(db: Dataset, query: Sequence[Tuple[str]], bindings=None) -> Sequence[dict]: def join(db: Dataset, query: Sequence[Tuple[str]], bindings=None) -> Sequence[dict]:
"""Helper for interpreting a bunch of tuples of strings as a join query, and returning simplified """Helper for interpreting a bunch of tuples of strings as a join query, and returning simplified
results. results.
Executes the query clauses as a join, returning a sequence of tuples and binding mappings such Executes the query clauses as a join, returning a sequence of tuples and binding mappings such
that the join constraints are simultaneously satisfied. that the join constraints are simultaneously satisfied.
>>> db = read(''' >>> db = read('''
... edge(a, b). ... edge(a, b).
... edge(b, c). ... edge(b, c).
... edge(c, d). ... edge(c, d).
... ''') ... ''')
>>> join(db, [ >>> join(db, [
... ('edge', 'A', 'B'), ... ('edge', 'A', 'B'),
... ('edge', 'B', 'C') ... ('edge', 'B', 'C')
... ]) ... ])
[((('edge', 'a', 'b'), [((('edge', 'a', 'b'),
('edge', 'b', 'c')), ('edge', 'b', 'c')),
{'A': 'a', 'B': 'b', 'C': 'c'}), {'A': 'a', 'B': 'b', 'C': 'c'}),
((('edge', 'b', 'c'), ((('edge', 'b', 'c'),
('edge', 'c', 'd')), ('edge', 'c', 'd')),
{'A': 'b', 'B': 'c', 'C': 'd'}), {'A': 'b', 'B': 'c', 'C': 'd'}),
((('edge', 'c', 'd'), ((('edge', 'c', 'd'),
('edge', 'd', 'f')), ('edge', 'd', 'f')),
{'A': 'c', 'B': 'd', 'C': 'f'})] {'A': 'c', 'B': 'd', 'C': 'f'})]
""" """
return __mapv(__result, __join(db, [q(c) for c in query], bindings=bindings)) return __mapv(__result, __join(db, [q(c) for c in query], bindings=bindings))

View file

@ -20,8 +20,8 @@ from datalog.types import (
def match(tuple, expr, bindings=None): def match(tuple, expr, bindings=None):
"""Attempt to construct lvar bindings from expr such that tuple and expr equate. """Attempt to construct lvar bindings from expr such that tuple and expr equate.
If the match is successful, return the binding map, otherwise return None. If the match is successful, return the binding map, otherwise return None.
""" """
bindings = bindings.copy() if bindings is not None else {} bindings = bindings.copy() if bindings is not None else {}
for a, b in zip(expr, tuple): for a, b in zip(expr, tuple):
@ -43,9 +43,9 @@ def match(tuple, expr, bindings=None):
def apply_bindings(expr, bindings, strict=True): def apply_bindings(expr, bindings, strict=True):
"""Given an expr which may contain lvars, substitute its lvars for constants returning the """Given an expr which may contain lvars, substitute its lvars for constants returning the
simplified expr. simplified expr.
""" """
if strict: if strict:
return tuple((bindings[e] if isinstance(e, LVar) else e) for e in expr) return tuple((bindings[e] if isinstance(e, LVar) else e) for e in expr)
@ -56,10 +56,10 @@ def apply_bindings(expr, bindings, strict=True):
def select(db: Dataset, expr, bindings=None, _recursion_guard=None, _select_guard=None): def select(db: Dataset, expr, bindings=None, _recursion_guard=None, _select_guard=None):
"""Evaluate an expression in a database, lazily producing a sequence of 'matching' tuples. """Evaluate an expression in a database, lazily producing a sequence of 'matching' tuples.
The dataset is a set of tuples and rules, and the expression is a single tuple containing lvars The dataset is a set of tuples and rules, and the expression is a single tuple containing lvars
and constants. Evaluates rules and tuples, returning and constants. Evaluates rules and tuples, returning
""" """
def __select_tuples(): def __select_tuples():
# As an opt. support indexed scans, which is optional. # As an opt. support indexed scans, which is optional.
@ -170,8 +170,8 @@ def select(db: Dataset, expr, bindings=None, _recursion_guard=None, _select_guar
def join(db: Dataset, clauses, bindings, pattern=None, _recursion_guard=None): def join(db: Dataset, clauses, bindings, pattern=None, _recursion_guard=None):
"""Evaluate clauses over the dataset, joining (or antijoining) with the seed bindings. """Evaluate clauses over the dataset, joining (or antijoining) with the seed bindings.
Yields a sequence of tuples and LVar bindings for which all joins and antijoins were satisfied. Yields a sequence of tuples and LVar bindings for which all joins and antijoins were satisfied.
""" """
def __join(g, clause): def __join(g, clause):
for ts, bindings in g: for ts, bindings in g:

View file

@ -27,13 +27,19 @@ class Actions(object):
return self._db_cls(tuples, rules) return self._db_cls(tuples, rules)
def make_symbol(self, input, start, end, elements): def make_symbol(self, input, start, end, elements):
return LVar("".join(e.text for e in elements),) return LVar(
"".join(e.text for e in elements),
)
def make_word(self, input, start, end, elements): def make_word(self, input, start, end, elements):
return Constant("".join(e.text for e in elements),) return Constant(
"".join(e.text for e in elements),
)
def make_string(self, input, start, end, elements): def make_string(self, input, start, end, elements):
return Constant(elements[1].text,) return Constant(
elements[1].text,
)
def make_comment(self, input, start, end, elements): def make_comment(self, input, start, end, elements):
return None return None
@ -81,11 +87,11 @@ class Actions(object):
class Parser(Grammar): class Parser(Grammar):
"""Implementation detail. """Implementation detail.
A slightly hacked version of the Parser class canopy generates, which lets us control what the A slightly hacked version of the Parser class canopy generates, which lets us control what the
parsing entry point is. This lets me play games with having one parser and one grammar which is parsing entry point is. This lets me play games with having one parser and one grammar which is
used both for the command shell and for other things. used both for the command shell and for other things.
""" """
def __init__(self, input, actions, types): def __init__(self, input, actions, types):
self._input = input self._input = input

View file

@ -66,8 +66,8 @@ class Dataset(object):
class CachedDataset(Dataset): class CachedDataset(Dataset):
"""An extension of the dataset which features a cache of rule produced tuples. """An extension of the dataset which features a cache of rule produced tuples.
Note that this cache is lost when merging datasets - which ensures correctness. Note that this cache is lost when merging datasets - which ensures correctness.
""" """
# Inherits tuples, rules, merge # Inherits tuples, rules, merge
@ -90,11 +90,11 @@ class CachedDataset(Dataset):
class TableIndexedDataset(CachedDataset): class TableIndexedDataset(CachedDataset):
"""An extension of the Dataset type which features both a cache and an index by table & length. """An extension of the Dataset type which features both a cache and an index by table & length.
The index allows more efficient scans by maintaining 'table' style partitions. The index allows more efficient scans by maintaining 'table' style partitions.
It does not support user-defined indexing schemes. It does not support user-defined indexing schemes.
Note that index building is delayed until an index is scanned. Note that index building is delayed until an index is scanned.
""" """
# From Dataset: # From Dataset:
# tuples, rules, merge # tuples, rules, merge
@ -126,11 +126,11 @@ class TableIndexedDataset(CachedDataset):
class PartlyIndexedDataset(TableIndexedDataset): class PartlyIndexedDataset(TableIndexedDataset):
"""An extension of the Dataset type which features both a cache and and a full index by table, """An extension of the Dataset type which features both a cache and and a full index by table,
length, tuple index and value. length, tuple index and value.
The index allows extremely efficient scans when elements of the tuple are known. The index allows extremely efficient scans when elements of the tuple are known.
""" """
# From Dataset: # From Dataset:
# tuples, rules, merge # tuples, rules, merge

View file

@ -25,8 +25,19 @@ def test_id_query(db_cls):
Constant("a"), Constant("a"),
Constant("b"), Constant("b"),
) )
assert not select(db_cls([], []), ("a", "b",)) assert not select(
assert select(db_cls([ab], []), ("a", "b",)) == [((("a", "b"),), {},)] db_cls([], []),
(
"a",
"b",
),
)
assert select(db_cls([ab], []), ("a", "b",)) == [
(
(("a", "b"),),
{},
)
]
@pytest.mark.parametrize("db_cls,", DBCLS) @pytest.mark.parametrize("db_cls,", DBCLS)
@ -47,7 +58,17 @@ def test_lvar_unification(db_cls):
d = read("""edge(b, c). edge(c, c).""", db_cls=db_cls) d = read("""edge(b, c). edge(c, c).""", db_cls=db_cls)
assert select(d, ("edge", "X", "X",)) == [((("edge", "c", "c"),), {"X": "c"})] assert (
select(
d,
(
"edge",
"X",
"X",
),
)
== [((("edge", "c", "c"),), {"X": "c"})]
)
@pytest.mark.parametrize("db_cls,", DBCLS) @pytest.mark.parametrize("db_cls,", DBCLS)
@ -105,12 +126,12 @@ no-b(X, Y) :-
def test_nested_antijoin(db_cls): def test_nested_antijoin(db_cls):
"""Test a query which negates a subquery which uses an antijoin. """Test a query which negates a subquery which uses an antijoin.
Shouldn't exercise anything more than `test_antjoin` does, but it's an interesting case since you Shouldn't exercise anything more than `test_antjoin` does, but it's an interesting case since you
actually can't capture the same semantics using a single query. Antijoins can't propagate positive actually can't capture the same semantics using a single query. Antijoins can't propagate positive
information (create lvar bindings) so I'm not sure you can express this another way without a information (create lvar bindings) so I'm not sure you can express this another way without a
different evaluation strategy. different evaluation strategy.
""" """
d = read( d = read(
""" """

View file

@ -3,7 +3,7 @@ from setuptools import setup
setup( setup(
name="arrdem.flowmetal", name="arrdem.flowmetal",
# Package metadata # Package metadata
version='0.0.0', version="0.0.0",
license="MIT", license="MIT",
description="A weird execution engine", description="A weird execution engine",
long_description=open("README.md").read(), long_description=open("README.md").read(),
@ -18,20 +18,16 @@ setup(
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
], ],
# Package setup # Package setup
package_dir={"": "src/python"}, package_dir={"": "src/python"},
packages=[ packages=[
"flowmetal", "flowmetal",
], ],
entry_points={ entry_points={
'console_scripts': [ "console_scripts": ["iflow=flowmetal.repl:main"],
'iflow=flowmetal.repl:main'
],
}, },
install_requires=[ install_requires=[
'prompt-toolkit~=3.0.0', "prompt-toolkit~=3.0.0",
], ],
extras_require={ extras_require={},
}
) )

View file

@ -2,7 +2,7 @@
An abstract or base Flowmetal DB. An abstract or base Flowmetal DB.
""" """
from abc import abc, abstractmethod, abstractproperty, abstractclassmethod, abstractstaticmethod from abc import abstractclassmethod, abstractmethod
class Db(ABC): class Db(ABC):

View file

@ -3,6 +3,7 @@
import click import click
@click.group() @click.group()
def cli(): def cli():
pass pass

View file

@ -3,6 +3,7 @@
import click import click
@click.group() @click.group()
def cli(): def cli():
pass pass

View file

@ -1,5 +1,3 @@
""" """
Somewhat generic models of Flowmetal programs. Somewhat generic models of Flowmetal programs.
""" """
from typing import NamedTuple

View file

@ -3,6 +3,7 @@
import click import click
@click.group() @click.group()
def cli(): def cli():
pass pass

View file

@ -3,6 +3,7 @@
import click import click
@click.group() @click.group()
def cli(): def cli():
pass pass

View file

@ -9,94 +9,102 @@ import requests
class GandiAPI(object): class GandiAPI(object):
"""An extremely incomplete Gandi REST API driver class. """An extremely incomplete Gandi REST API driver class.
Exists to close over your API key, and make talking to the API slightly more idiomatic. Exists to close over your API key, and make talking to the API slightly more idiomatic.
Note: In an effort to be nice, this API maintains a cache of the zones visible with your API Note: In an effort to be nice, this API maintains a cache of the zones visible with your API
key. The cache is maintained when using this driver, but concurrent modifications of your zone(s) key. The cache is maintained when using this driver, but concurrent modifications of your zone(s)
via the web UI or other processes will obviously undermine it. This cache can be disabled by via the web UI or other processes will obviously undermine it. This cache can be disabled by
setting the `use_cache` kwarg to `False`. setting the `use_cache` kwarg to `False`.
""" """
def __init__(self, key=None, use_cache=True): def __init__(self, key=None, use_cache=True):
self._base = "https://dns.api.gandi.net/api/v5" self._base = "https://dns.api.gandi.net/api/v5"
self._key = key self._key = key
self._use_cache = use_cache self._use_cache = use_cache
self._zones = None self._zones = None
# Helpers for making requests with the API key as required by the API. # Helpers for making requests with the API key as required by the API.
def _do_request(self, method, path, headers=None, **kwargs): def _do_request(self, method, path, headers=None, **kwargs):
headers = headers or {} headers = headers or {}
headers["X-Api-Key"] = self._key headers["X-Api-Key"] = self._key
resp = method(self._base + path, headers=headers, **kwargs) resp = method(self._base + path, headers=headers, **kwargs)
if resp.status_code > 299: if resp.status_code > 299:
print(resp.text) print(resp.text)
resp.raise_for_status() resp.raise_for_status()
return resp return resp
def _get(self, path, **kwargs): def _get(self, path, **kwargs):
return self._do_request(requests.get, path, **kwargs) return self._do_request(requests.get, path, **kwargs)
def _post(self, path, **kwargs): def _post(self, path, **kwargs):
return self._do_request(requests.post, path, **kwargs) return self._do_request(requests.post, path, **kwargs)
def _put(self, path, **kwargs): def _put(self, path, **kwargs):
return self._do_request(requests.put, path, **kwargs) return self._do_request(requests.put, path, **kwargs)
# Intentional public API # Intentional public API
def domain_records(self, domain): def domain_records(self, domain):
return self._get("/domains/{0}/records".format(domain)).json() return self._get("/domains/{0}/records".format(domain)).json()
def get_zones(self): def get_zones(self):
if self._use_cache: if self._use_cache:
if not self._zones: if not self._zones:
self._zones = self._get("/zones").json() self._zones = self._get("/zones").json()
return self._zones return self._zones
else: else:
return self._get("/zones").json() return self._get("/zones").json()
def get_zone(self, zone_id): def get_zone(self, zone_id):
return self._get("/zones/{}".format(zone_id)).json() return self._get("/zones/{}".format(zone_id)).json()
def get_zone_by_name(self, zone_name): def get_zone_by_name(self, zone_name):
for zone in self.get_zones(): for zone in self.get_zones():
if zone["name"] == zone_name: if zone["name"] == zone_name:
return zone return zone
def create_zone(self, zone_name): def create_zone(self, zone_name):
new_zone_id = self._post("/zones", new_zone_id = (
headers={"content-type": "application/json"}, self._post(
data=json.dumps({"name": zone_name}))\ "/zones",
.headers["Location"]\ headers={"content-type": "application/json"},
.split("/")[-1] data=json.dumps({"name": zone_name}),
new_zone = self.get_zone(new_zone_id) )
.headers["Location"]
.split("/")[-1]
)
new_zone = self.get_zone(new_zone_id)
# Note: If the cache is active, update the cache. # Note: If the cache is active, update the cache.
if self._use_cache and self._zones is not None: if self._use_cache and self._zones is not None:
self._zones.append(new_zone) self._zones.append(new_zone)
return new_zone return new_zone
def replace_domain(self, domain, records): def replace_domain(self, domain, records):
date = datetime.now() date = datetime.now()
date = "{:%A, %d. %B %Y %I:%M%p}".format(date) date = "{:%A, %d. %B %Y %I:%M%p}".format(date)
zone_name = f"updater generated domain - {domain} - {date}" zone_name = f"updater generated domain - {domain} - {date}"
zone = self.get_zone_by_name(zone_name) zone = self.get_zone_by_name(zone_name)
if not zone: if not zone:
zone = self.create_zone(zone_name) zone = self.create_zone(zone_name)
print("Using zone", zone["uuid"]) print("Using zone", zone["uuid"])
for r in records: for r in records:
self._post("/zones/{0}/records".format(zone["uuid"]), self._post(
headers={"content-type": "application/json"}, "/zones/{0}/records".format(zone["uuid"]),
data=json.dumps(r)) headers={"content-type": "application/json"},
data=json.dumps(r),
)
return self._post("/zones/{0}/domains/{1}".format(zone["uuid"], domain), return self._post(
headers={"content-type": "application/json"}) "/zones/{0}/domains/{1}".format(zone["uuid"], domain),
headers={"content-type": "application/json"},
)

View file

@ -20,165 +20,171 @@ import meraki
RECORD_LINE_PATTERN = re.compile( RECORD_LINE_PATTERN = re.compile(
"^(?P<rrset_name>\S+)\s+" "^(?P<rrset_name>\S+)\s+"
"(?P<rrset_ttl>\S+)\s+" "(?P<rrset_ttl>\S+)\s+"
"IN\s+" "IN\s+"
"(?P<rrset_type>\S+)\s+" "(?P<rrset_type>\S+)\s+"
"(?P<rrset_values>.+)$") "(?P<rrset_values>.+)$"
)
def update(m, k, f, *args, **kwargs): def update(m, k, f, *args, **kwargs):
"""clojure.core/update for Python's stateful maps.""" """clojure.core/update for Python's stateful maps."""
if k in m: if k in m:
m[k] = f(m[k], *args, **kwargs) m[k] = f(m[k], *args, **kwargs)
return m return m
def parse_zone_record(line): def parse_zone_record(line):
if match := RECORD_LINE_PATTERN.search(line): if match := RECORD_LINE_PATTERN.search(line):
dat = match.groupdict() dat = match.groupdict()
dat = update(dat, "rrset_ttl", int) dat = update(dat, "rrset_ttl", int)
dat = update(dat, "rrset_values", lambda x: [x]) dat = update(dat, "rrset_values", lambda x: [x])
return dat return dat
def same_record(lr, rr): def same_record(lr, rr):
""" """
A test to see if two records name the same zone entry. A test to see if two records name the same zone entry.
""" """
return lr["rrset_name"] == rr["rrset_name"] and \ return lr["rrset_name"] == rr["rrset_name"] and lr["rrset_type"] == rr["rrset_type"]
lr["rrset_type"] == rr["rrset_type"]
def records_equate(lr, rr): def records_equate(lr, rr):
""" """
Equality, ignoring rrset_href which is generated by the API. Equality, ignoring rrset_href which is generated by the API.
""" """
if not same_record(lr, rr): if not same_record(lr, rr):
return False return False
elif lr["rrset_ttl"] != rr["rrset_ttl"]: elif lr["rrset_ttl"] != rr["rrset_ttl"]:
return False return False
elif set(lr["rrset_values"]) != set(rr["rrset_values"]): elif set(lr["rrset_values"]) != set(rr["rrset_values"]):
return False return False
else: else:
return True return True
def template_and_parse_zone(template_file, template_bindings): def template_and_parse_zone(template_file, template_bindings):
assert template_file is not None assert template_file is not None
assert template_bindings is not None assert template_bindings is not None
with open(template_file) as f: with open(template_file) as f:
dat = jinja2.Template(f.read()).render(**template_bindings) dat = jinja2.Template(f.read()).render(**template_bindings)
uncommitted_records = [] uncommitted_records = []
for line in dat.splitlines(): for line in dat.splitlines():
if line and not line[0] == "#": if line and not line[0] == "#":
record = parse_zone_record(line) record = parse_zone_record(line)
if record: if record:
uncommitted_records.append(record) uncommitted_records.append(record)
records = [] records = []
for uncommitted_r in uncommitted_records: for uncommitted_r in uncommitted_records:
flag = False flag = False
for committed_r in records: for committed_r in records:
if same_record(uncommitted_r, committed_r): if same_record(uncommitted_r, committed_r):
# Join the two records # Join the two records
committed_r["rrset_values"].extend(uncommitted_r["rrset_values"]) committed_r["rrset_values"].extend(uncommitted_r["rrset_values"])
flag = True flag = True
if not flag: if not flag:
records.append(uncommitted_r) records.append(uncommitted_r)
sorted(records, key=lambda x: (x["rrset_type"], x["rrset_name"])) sorted(records, key=lambda x: (x["rrset_type"], x["rrset_name"]))
return records return records
def diff_zones(left_zone, right_zone): def diff_zones(left_zone, right_zone):
""" """
Equality between unordered lists of records constituting a zone. Equality between unordered lists of records constituting a zone.
""" """
in_left_not_right = [] in_left_not_right = []
for lr in left_zone:
flag = False
for rr in right_zone:
if same_record(lr, rr) and records_equate(lr, rr):
flag = True
break
if not flag:
in_left_not_right.append(lr)
in_right_not_left = []
for rr in right_zone:
flag = False
for lr in left_zone: for lr in left_zone:
if same_record(lr, rr) and records_equate(lr, rr): flag = False
flag = True for rr in right_zone:
break if same_record(lr, rr) and records_equate(lr, rr):
flag = True
break
if not flag: if not flag:
in_right_not_left.append(lr) in_left_not_right.append(lr)
return in_left_not_right or in_right_not_left in_right_not_left = []
for rr in right_zone:
flag = False
for lr in left_zone:
if same_record(lr, rr) and records_equate(lr, rr):
flag = True
break
if not flag:
in_right_not_left.append(lr)
return in_left_not_right or in_right_not_left
parser = argparse.ArgumentParser(description="\"Dynamic\" DNS updating for self-hosted services") parser = argparse.ArgumentParser(
description='"Dynamic" DNS updating for self-hosted services'
)
parser.add_argument("--config", dest="config_file", required=True) parser.add_argument("--config", dest="config_file", required=True)
parser.add_argument("--templates", dest="template_dir", required=True) parser.add_argument("--templates", dest="template_dir", required=True)
parser.add_argument("--dry-run", dest="dry", action="store_true", default=False) parser.add_argument("--dry-run", dest="dry", action="store_true", default=False)
def main(): def main():
args = parser.parse_args() args = parser.parse_args()
config = yaml.safe_load(open(args.config_file, "r")) config = yaml.safe_load(open(args.config_file, "r"))
dashboard = meraki.DashboardAPI(config["meraki"]["key"], output_log=False) dashboard = meraki.DashboardAPI(config["meraki"]["key"], output_log=False)
org = config["meraki"]["organization"] org = config["meraki"]["organization"]
device = config["meraki"]["router_serial"] device = config["meraki"]["router_serial"]
uplinks = dashboard.appliance.getOrganizationApplianceUplinkStatuses( uplinks = dashboard.appliance.getOrganizationApplianceUplinkStatuses(
organizationId=org, organizationId=org, serials=[device]
serials=[device] )[0]["uplinks"]
)[0]["uplinks"]
template_bindings = { template_bindings = {
"local": { "local": {
# One of the two # One of the two
"public_v4s": [link.get("publicIp") for link in uplinks if link.get("publicIp")], "public_v4s": [
}, link.get("publicIp") for link in uplinks if link.get("publicIp")
# Why isn't there a merge method ],
**config["bindings"] },
} # Why isn't there a merge method
**config["bindings"],
}
api = GandiAPI(config["gandi"]["key"]) api = GandiAPI(config["gandi"]["key"])
for task in config["tasks"]: for task in config["tasks"]:
if isinstance(task, str): if isinstance(task, str):
task = {"template": task + ".j2", task = {"template": task + ".j2", "zones": [task]}
"zones": [task]}
computed_zone = template_and_parse_zone(os.path.join(args.template_dir, task["template"]), template_bindings) computed_zone = template_and_parse_zone(
os.path.join(args.template_dir, task["template"]), template_bindings
)
for zone_name in task["zones"]: for zone_name in task["zones"]:
try: try:
live_zone = api.domain_records(zone_name) live_zone = api.domain_records(zone_name)
if diff_zones(computed_zone, live_zone): if diff_zones(computed_zone, live_zone):
print("Zone {} differs, computed zone:".format(zone_name)) print("Zone {} differs, computed zone:".format(zone_name))
pprint(computed_zone) pprint(computed_zone)
if not args.dry: if not args.dry:
print(api.replace_domain(zone_name, computed_zone)) print(api.replace_domain(zone_name, computed_zone))
else: else:
print("Zone {} up to date".format(zone_name)) print("Zone {} up to date".format(zone_name))
except Exception as e:
print("While processing zone {}".format(zone_name))
raise e
except Exception as e:
print("While processing zone {}".format(zone_name))
raise e
if __name__ == "__main__" or 1: if __name__ == "__main__" or 1:
main() main()

View file

@ -3,7 +3,7 @@ from setuptools import setup
setup( setup(
name="arrdem.ratchet", name="arrdem.ratchet",
# Package metadata # Package metadata
version='0.0.0', version="0.0.0",
license="MIT", license="MIT",
description="A 'ratcheting' message system", description="A 'ratcheting' message system",
long_description=open("README.md").read(), long_description=open("README.md").read(),
@ -18,18 +18,12 @@ setup(
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
], ],
# Package setup # Package setup
package_dir={ package_dir={"": "src/python"},
"": "src/python"
},
packages=[ packages=[
"ratchet", "ratchet",
], ],
entry_points={ entry_points={},
}, install_requires=[],
install_requires=[ extras_require={},
],
extras_require={
}
) )

View file

@ -98,14 +98,15 @@ VALUES (?, ?, ?, ?);
""" """
class SQLiteDriver: class SQLiteDriver:
def __init__(self, def __init__(
filename="~/.ratchet.sqlite3", self,
sqlite_timeout=1000, filename="~/.ratchet.sqlite3",
message_ttl=60000, sqlite_timeout=1000,
message_space="_", message_ttl=60000,
message_author=f"{os.getpid()}@{socket.gethostname()}"): message_space="_",
message_author=f"{os.getpid()}@{socket.gethostname()}",
):
self._path = os.path.expanduser(filename) self._path = os.path.expanduser(filename)
self._sqlite_timeout = sqlite_timeout self._sqlite_timeout = sqlite_timeout
self._message_ttl = message_ttl self._message_ttl = message_ttl
@ -120,14 +121,11 @@ class SQLiteDriver:
conn.executescript(SCHEMA_SCRIPT) conn.executescript(SCHEMA_SCRIPT)
def _connection(self): def _connection(self):
return sql.connect(self._filename, return sql.connect(self._filename, timeout=self._sqlite_timeout)
timeout=self._sqlite_timeout)
def create_message(self, def create_message(
message: str, self, message: str, ttl: int = None, space: str = None, author: str = None
ttl: int = None, ):
space: str = None,
author: str = None):
"""Create a single message.""" """Create a single message."""
ttl = ttl or self._message_ttl ttl = ttl or self._message_ttl
@ -138,11 +136,9 @@ class SQLiteDriver:
cursor.execute(CREATE_MESSAGE_SCRIPT, author, space, ttl, message) cursor.execute(CREATE_MESSAGE_SCRIPT, author, space, ttl, message)
return cursor.lastrowid return cursor.lastrowid
def create_event(self, def create_event(
timeout: int, self, timeout: int, ttl: int = None, space: str = None, author: str = None
ttl: int = None, ):
space: str = None,
author: str = None):
"""Create a (pending) event.""" """Create a (pending) event."""
ttl = ttl or self._message_ttl ttl = ttl or self._message_ttl

View file

@ -7,111 +7,94 @@ from yamlschema import lint_buffer
import pytest import pytest
@pytest.mark.parametrize('schema, obj', [ @pytest.mark.parametrize(
({"type": "number"}, "schema, obj",
"---\n1.0"), [
({"type": "integer"}, ({"type": "number"}, "---\n1.0"),
"---\n3"), ({"type": "integer"}, "---\n3"),
({"type": "string"}, ({"type": "string"}, "---\nfoo bar baz"),
"---\nfoo bar baz"), ({"type": "string", "maxLength": 15}, "---\nfoo bar baz"),
({"type": "string", ({"type": "string", "minLength": 10}, "---\nfoo bar baz"),
"maxLength": 15}, ({"type": "string", "pattern": "^foo.*"}, "---\nfoo bar baz"),
"---\nfoo bar baz"), ({"type": "object", "additionalProperties": True}, "---\nfoo: bar\nbaz: qux"),
({"type": "string", (
"minLength": 10}, {"type": "object", "properties": {"foo": {"type": "string"}}},
"---\nfoo bar baz"), "---\nfoo: bar\nbaz: qux",
({"type": "string", ),
"pattern": "^foo.*"}, (
"---\nfoo bar baz"), {
({"type": "object", "type": "object",
"additionalProperties": True}, "properties": {"foo": {"type": "string"}},
"---\nfoo: bar\nbaz: qux"), "additionalProperties": False,
({"type": "object", },
"properties": {"foo": {"type": "string"}}}, "---\nfoo: bar",
"---\nfoo: bar\nbaz: qux"), ),
({"type": "object", ({"type": "object", "properties": {"foo": {"type": "object"}}}, "---\nfoo: {}"),
"properties": {"foo": {"type": "string"}}, (
"additionalProperties": False}, {
"---\nfoo: bar"), "type": "object",
({"type": "object", "properties": {"foo": {"type": "array", "items": {"type": "object"}}},
"properties": {"foo": {"type": "object"}}}, },
"---\nfoo: {}"), "---\nfoo: [{}, {}, {foo: bar}]",
({"type": "object", ),
"properties": {"foo": { ],
"type": "array", )
"items": {"type": "object"}}}},
"---\nfoo: [{}, {}, {foo: bar}]"),
])
def test_lint_document_ok(schema, obj): def test_lint_document_ok(schema, obj):
assert not list(lint_buffer(schema, obj)) assert not list(lint_buffer(schema, obj))
@pytest.mark.parametrize('msg, schema, obj', [ @pytest.mark.parametrize(
# Numerics "msg, schema, obj",
("Floats are not ints", [
{"type": "integer"}, # Numerics
"---\n1.0"), ("Floats are not ints", {"type": "integer"}, "---\n1.0"),
("Ints are not floats", ("Ints are not floats", {"type": "number"}, "---\n1"),
{"type": "number"}, # Numerics - range limits. Integer edition
"---\n1"), (
"1 is the limit of the range",
# Numerics - range limits. Integer edition {"type": "integer", "exclusiveMaximum": 1},
("1 is the limit of the range", "---\n1",
{"type": "integer", ),
"exclusiveMaximum": 1}, (
"---\n1"), "1 is the limit of the range",
("1 is the limit of the range", {"type": "integer", "exclusiveMinimum": 1},
{"type": "integer", "---\n1",
"exclusiveMinimum": 1}, ),
"---\n1"), ("1 is out of the range", {"type": "integer", "minimum": 2}, "---\n1"),
("1 is out of the range", ("1 is out of the range", {"type": "integer", "maximum": 0}, "---\n1"),
{"type": "integer", ("1 is out of the range", {"type": "integer", "exclusiveMinimum": 1}, "---\n1"),
"minimum": 2}, # Numerics - range limits. Number/Float edition
"---\n1"), (
("1 is out of the range", "1 is the limit of the range",
{"type": "integer", {"type": "number", "exclusiveMaximum": 1},
"maximum": 0}, "---\n1.0",
"---\n1"), ),
("1 is out of the range", (
{"type": "integer", "1 is the limit of the range",
"exclusiveMinimum": 1}, {"type": "number", "exclusiveMinimum": 1},
"---\n1"), "---\n1.0",
),
# Numerics - range limits. Number/Float edition ("1 is out of the range", {"type": "number", "minimum": 2}, "---\n1.0"),
("1 is the limit of the range", ("1 is out of the range", {"type": "number", "maximum": 0}, "---\n1.0"),
{"type": "number", (
"exclusiveMaximum": 1}, "1 is out of the range",
"---\n1.0"), {"type": "number", "exclusiveMinimum": 1},
("1 is the limit of the range", "---\n1.0",
{"type": "number", ),
"exclusiveMinimum": 1}, # String shit
"---\n1.0"), ("String too short", {"type": "string", "minLength": 1}, "---\n''"),
("1 is out of the range", ("String too long", {"type": "string", "maxLength": 1}, "---\nfoo"),
{"type": "number", (
"minimum": 2}, "String does not match pattern",
"---\n1.0"), {"type": "string", "pattern": "bar"},
("1 is out of the range", "---\nfoo",
{"type": "number", ),
"maximum": 0}, (
"---\n1.0"), "String does not fully match pattern",
("1 is out of the range", {"type": "string", "pattern": "foo"},
{"type": "number", "---\nfooooooooo",
"exclusiveMinimum": 1}, ),
"---\n1.0"), ],
)
# String shit
("String too short",
{"type": "string", "minLength": 1},
"---\n''"),
("String too long",
{"type": "string", "maxLength": 1},
"---\nfoo"),
("String does not match pattern",
{"type": "string", "pattern": "bar"},
"---\nfoo"),
("String does not fully match pattern",
{"type": "string", "pattern": "foo"},
"---\nfooooooooo"),
])
def test_lint_document_fails(msg, schema, obj): def test_lint_document_fails(msg, schema, obj):
assert list(lint_buffer(schema, obj)), msg assert list(lint_buffer(schema, obj)), msg

View file

@ -59,9 +59,7 @@ class YamlLinter(object):
return schema return schema
def lint_mapping(self, schema, node: Node) -> t.Iterable[str]: def lint_mapping(self, schema, node: Node) -> t.Iterable[str]:
"""FIXME. """FIXME."""
"""
if schema["type"] != "object" or not isinstance(node, MappingNode): if schema["type"] != "object" or not isinstance(node, MappingNode):
yield LintRecord( yield LintRecord(
@ -71,9 +69,7 @@ class YamlLinter(object):
f"Expected {schema['type']}, got {node.id} {str(node.start_mark).lstrip()}", f"Expected {schema['type']}, got {node.id} {str(node.start_mark).lstrip()}",
) )
additional_type: t.Union[dict, bool] = ( additional_type: t.Union[dict, bool] = schema.get("additionalProperties", True)
schema.get("additionalProperties", True)
)
properties: dict = schema.get("properties", {}) properties: dict = schema.get("properties", {})
required: t.Iterable[str] = schema.get("required", []) required: t.Iterable[str] = schema.get("required", [])
@ -135,37 +131,26 @@ class YamlLinter(object):
elif schema["type"] == "number": elif schema["type"] == "number":
yield from self.lint_number(schema, node) yield from self.lint_number(schema, node)
else: else:
raise NotImplementedError( raise NotImplementedError(f"Scalar type {schema['type']} is not supported")
f"Scalar type {schema['type']} is not supported"
)
def lint_string(self, schema, node: Node) -> t.Iterable[str]: def lint_string(self, schema, node: Node) -> t.Iterable[str]:
"""FIXME.""" """FIXME."""
if node.tag != "tag:yaml.org,2002:str": if node.tag != "tag:yaml.org,2002:str":
yield LintRecord( yield LintRecord(
LintLevel.MISSMATCH, LintLevel.MISSMATCH, node, schema, f"Expected a string, got a {node}"
node,
schema,
f"Expected a string, got a {node}"
) )
if maxl := schema.get("maxLength"): if maxl := schema.get("maxLength"):
if len(node.value) > maxl: if len(node.value) > maxl:
yield LintRecord( yield LintRecord(
LintLevel.MISSMATCH, LintLevel.MISSMATCH, node, schema, f"Expected a shorter string"
node,
schema,
f"Expected a shorter string"
) )
if minl := schema.get("minLength"): if minl := schema.get("minLength"):
if len(node.value) < minl: if len(node.value) < minl:
yield LintRecord( yield LintRecord(
LintLevel.MISSMATCH, LintLevel.MISSMATCH, node, schema, f"Expected a longer string"
node,
schema,
f"Expected a longer string"
) )
if pat := schema.get("pattern"): if pat := schema.get("pattern"):
@ -174,7 +159,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH, LintLevel.MISSMATCH,
node, node,
schema, schema,
f"Expected a string matching the pattern" f"Expected a string matching the pattern",
) )
def lint_integer(self, schema, node: Node) -> t.Iterable[str]: def lint_integer(self, schema, node: Node) -> t.Iterable[str]:
@ -184,10 +169,7 @@ class YamlLinter(object):
else: else:
yield LintRecord( yield LintRecord(
LintLevel.MISSMATCH, LintLevel.MISSMATCH, node, schema, f"Expected an integer, got a {node}"
node,
schema,
f"Expected an integer, got a {node}"
) )
def lint_number(self, schema, node: Node) -> t.Iterable[str]: def lint_number(self, schema, node: Node) -> t.Iterable[str]:
@ -197,13 +179,9 @@ class YamlLinter(object):
else: else:
yield LintRecord( yield LintRecord(
LintLevel.MISSMATCH, LintLevel.MISSMATCH, node, schema, f"Expected an integer, got a {node}"
node,
schema,
f"Expected an integer, got a {node}"
) )
def _lint_num_range(self, schema, node: Node, value) -> t.Iterable[str]: def _lint_num_range(self, schema, node: Node, value) -> t.Iterable[str]:
""""FIXME.""" """"FIXME."""
@ -213,7 +191,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH, LintLevel.MISSMATCH,
node, node,
schema, schema,
f"Expected a multiple of {base}, got {value}" f"Expected a multiple of {base}, got {value}",
) )
if (max := schema.get("exclusiveMaximum")) is not None: if (max := schema.get("exclusiveMaximum")) is not None:
@ -222,7 +200,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH, LintLevel.MISSMATCH,
node, node,
schema, schema,
f"Expected a value less than {max}, got {value}" f"Expected a value less than {max}, got {value}",
) )
if (max := schema.get("maximum")) is not None: if (max := schema.get("maximum")) is not None:
@ -231,7 +209,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH, LintLevel.MISSMATCH,
node, node,
schema, schema,
f"Expected a value less than or equal to {max}, got {value}" f"Expected a value less than or equal to {max}, got {value}",
) )
if (min := schema.get("exclusiveMinimum")) is not None: if (min := schema.get("exclusiveMinimum")) is not None:
@ -240,7 +218,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH, LintLevel.MISSMATCH,
node, node,
schema, schema,
f"Expected a value greater than {min}, got {value}" f"Expected a value greater than {min}, got {value}",
) )
if (min := schema.get("minimum")) is not None: if (min := schema.get("minimum")) is not None:
@ -249,7 +227,7 @@ class YamlLinter(object):
LintLevel.MISSMATCH, LintLevel.MISSMATCH,
node, node,
schema, schema,
f"Expected a value greater than or equal to {min}, got {value}" f"Expected a value greater than or equal to {min}, got {value}",
) )
def lint_document(self, node, schema=None) -> t.Iterable[str]: def lint_document(self, node, schema=None) -> t.Iterable[str]:
@ -271,10 +249,7 @@ class YamlLinter(object):
# This is the schema that rejects everything. # This is the schema that rejects everything.
elif schema == False: elif schema == False:
yield LintRecord( yield LintRecord(
LintLevel.UNEXPECTED, LintLevel.UNEXPECTED, node, schema, "Received an unexpected value"
node,
schema,
"Received an unexpected value"
) )
# Walking the PyYAML node hierarchy # Walking the PyYAML node hierarchy