From bbae5ef63f4f4e1f21f897519c4f3f632a835e8a Mon Sep 17 00:00:00 2001 From: Reid 'arrdem' McKenzie Date: Sat, 15 May 2021 11:34:32 -0600 Subject: [PATCH] And fmt --- projects/calf/setup.py | 1 - projects/calf/src/python/calf/cursedrepl.py | 20 +- projects/calf/src/python/calf/grammar.py | 80 ++++- projects/calf/src/python/calf/lexer.py | 1 - projects/calf/src/python/calf/packages.py | 43 ++- projects/calf/src/python/calf/parser.py | 43 ++- projects/calf/src/python/calf/reader.py | 23 +- projects/calf/tests/python/test_grammar.py | 37 +- projects/calf/tests/python/test_lexer.py | 65 +++- projects/calf/tests/python/test_parser.py | 295 ++++++++-------- projects/calf/tests/python/test_reader.py | 30 +- projects/datalog-shell/__main__.py | 328 +++++++++--------- projects/datalog-shell/setup.py | 5 +- projects/datalog/make_graph.py | 20 +- projects/datalog/setup.py | 4 +- projects/datalog/src/python/datalog/debris.py | 8 +- projects/datalog/src/python/datalog/easy.py | 54 +-- .../datalog/src/python/datalog/evaluator.py | 18 +- projects/datalog/src/python/datalog/reader.py | 20 +- projects/datalog/src/python/datalog/types.py | 18 +- .../test/python/test_datalog_evaluator.py | 37 +- projects/flowmetal/setup.py | 12 +- .../flowmetal/src/python/flowmetal/db/base.py | 2 +- .../src/python/flowmetal/frontend.py | 1 + .../src/python/flowmetal/interpreter.py | 1 + .../flowmetal/src/python/flowmetal/models.py | 2 - .../flowmetal/src/python/flowmetal/reaper.py | 1 + .../src/python/flowmetal/scheduler.py | 1 + projects/gandi/src/python/gandi/client.py | 148 ++++---- .../src/python/arrdem/updater/__main__.py | 232 +++++++------ projects/ratchet/setup.py | 16 +- .../src/python/ratchet/backend/sqlite.py | 34 +- projects/yamlschema/test_yamlschema.py | 187 +++++----- projects/yamlschema/yamlschema.py | 55 +-- 34 files changed, 956 insertions(+), 886 deletions(-) diff --git a/projects/calf/setup.py b/projects/calf/setup.py index 6edcf6d..08b87b5 100644 --- a/projects/calf/setup.py +++ b/projects/calf/setup.py @@ -25,7 +25,6 @@ setup( "calf-read = calf.reader:main", "calf-analyze = calf.analyzer:main", "calf-compile = calf.compiler:main", - # Client/server stuff "calf-client = calf.client:main", "calf-server = calf.server:main", diff --git a/projects/calf/src/python/calf/cursedrepl.py b/projects/calf/src/python/calf/cursedrepl.py index 447f7c2..ba2eaaf 100644 --- a/projects/calf/src/python/calf/cursedrepl.py +++ b/projects/calf/src/python/calf/cursedrepl.py @@ -7,7 +7,6 @@ from curses.textpad import Textbox, rectangle def curse_repl(handle_buffer): - def handle(buff, count): try: return list(handle_buffer(buff, count)), None @@ -24,22 +23,25 @@ def curse_repl(handle_buffer): maxy, maxx = stdscr.getmaxyx() stdscr.clear() - stdscr.addstr(0, 0, "Enter example: (hit Ctrl-G to execute, Ctrl-C to exit)", curses.A_BOLD) - editwin = curses.newwin(5, maxx - 4, - 2, 2) - rectangle(stdscr, - 1, 1, - 1 + 5 + 1, maxx - 2) + stdscr.addstr( + 0, + 0, + "Enter example: (hit Ctrl-G to execute, Ctrl-C to exit)", + curses.A_BOLD, + ) + editwin = curses.newwin(5, maxx - 4, 2, 2) + rectangle(stdscr, 1, 1, 1 + 5 + 1, maxx - 2) # Printing is part of the prompt cur = 8 + def putstr(str, x=0, attr=0): # ya rly. 
I know exactly what I'm doing here nonlocal cur # This is how we handle going off the bottom of the scren lol if cur < maxy: stdscr.addstr(cur, x, str, attr) - cur += (len(str.split("\n")) or 1) + cur += len(str.split("\n")) or 1 for ex, buff, vals, err in reversed(examples): putstr(f"Example {ex}:", attr=curses.A_BOLD) @@ -58,7 +60,7 @@ def curse_repl(handle_buffer): elif vals: putstr(" Values:") - for x, t in zip(range(1, 1<<64), vals): + for x, t in zip(range(1, 1 << 64), vals): putstr(f" {x:>3}) " + repr(t)) putstr("") diff --git a/projects/calf/src/python/calf/grammar.py b/projects/calf/src/python/calf/grammar.py index 2ea6a44..d11d204 100644 --- a/projects/calf/src/python/calf/grammar.py +++ b/projects/calf/src/python/calf/grammar.py @@ -28,34 +28,82 @@ COMMENT_PATTERN = r";(([^\n\r]*)(\n\r?)?)" TOKENS = [ # Paren (noral) lists - (r"\(", "PAREN_LEFT",), - (r"\)", "PAREN_RIGHT",), + ( + r"\(", + "PAREN_LEFT", + ), + ( + r"\)", + "PAREN_RIGHT", + ), # Bracket lists - (r"\[", "BRACKET_LEFT",), - (r"\]", "BRACKET_RIGHT",), + ( + r"\[", + "BRACKET_LEFT", + ), + ( + r"\]", + "BRACKET_RIGHT", + ), # Brace lists (maps) - (r"\{", "BRACE_LEFT",), - (r"\}", "BRACE_RIGHT",), - (r"\^", "META",), - (r"'", "SINGLE_QUOTE",), - (STRING_PATTERN, "STRING",), - (r"#", "MACRO_DISPATCH",), + ( + r"\{", + "BRACE_LEFT", + ), + ( + r"\}", + "BRACE_RIGHT", + ), + ( + r"\^", + "META", + ), + ( + r"'", + "SINGLE_QUOTE", + ), + ( + STRING_PATTERN, + "STRING", + ), + ( + r"#", + "MACRO_DISPATCH", + ), # Symbols - (SYMBOL_PATTERN, "SYMBOL",), + ( + SYMBOL_PATTERN, + "SYMBOL", + ), # Numbers - (SIMPLE_INTEGER, "INTEGER",), - (FLOAT_PATTERN, "FLOAT",), + ( + SIMPLE_INTEGER, + "INTEGER", + ), + ( + FLOAT_PATTERN, + "FLOAT", + ), # Keywords # # Note: this is a dirty f'n hack in that in order for keywords to work, ":" # has to be defined to be a valid keyword. - (r":" + SYMBOL_PATTERN + "?", "KEYWORD",), + ( + r":" + SYMBOL_PATTERN + "?", + "KEYWORD", + ), # Whitespace # # Note that the whitespace token will contain at most one newline - (r"(\n\r?|[,\t ]*)", "WHITESPACE",), + ( + r"(\n\r?|[,\t ]*)", + "WHITESPACE", + ), # Comment - (COMMENT_PATTERN, "COMMENT",), + ( + COMMENT_PATTERN, + "COMMENT", + ), # Strings (r'"(?P(?:[^\"]|\.)*)"', "STRING"), ] diff --git a/projects/calf/src/python/calf/lexer.py b/projects/calf/src/python/calf/lexer.py index e6d71d8..f779421 100644 --- a/projects/calf/src/python/calf/lexer.py +++ b/projects/calf/src/python/calf/lexer.py @@ -8,7 +8,6 @@ parsing, linting or other use. import io import re -import sys from calf.token import CalfToken from calf.io.reader import PeekPosReader diff --git a/projects/calf/src/python/calf/packages.py b/projects/calf/src/python/calf/packages.py index 15511f7..aae64d7 100644 --- a/projects/calf/src/python/calf/packages.py +++ b/projects/calf/src/python/calf/packages.py @@ -12,48 +12,47 @@ from collections import namedtuple class CalfLoaderConfig(namedtuple("CalfLoaderConfig", ["paths"])): - """ - """ + """""" class CalfDelayedPackage( namedtuple("CalfDelayedPackage", ["name", "version", "metadata", "path"]) ): """ - This structure represents the delay of loading a packaage. + This structure represents the delay of loading a packaage. - Rather than eagerly analyze packages, it may be profitable to use lazy loading / lazy resolution - of symbols. It may also be possible to cache analyzing some packages. - """ + Rather than eagerly analyze packages, it may be profitable to use lazy loading / lazy resolution + of symbols. 
It may also be possible to cache analyzing some packages. + """ class CalfPackage( namedtuple("CalfPackage", ["name", "version", "metadata", "modules"]) ): """ - This structure represents the result of forcing the load of a package, and is the product of - either loading a package directly, or a package becoming a direct dependency and being forced. - """ + This structure represents the result of forcing the load of a package, and is the product of + either loading a package directly, or a package becoming a direct dependency and being forced. + """ def parse_package_requirement(config, env, requirement): """ - :param config: - :param env: - :param requirement: - :returns: + :param config: + :param env: + :param requirement: + :returns: - - """ + + """ def analyze_package(config, env, package): """ - :param config: - :param env: - :param module: - :returns: + :param config: + :param env: + :param module: + :returns: - Given a loader configuration and an environment to load into, analyzes the requested package, - returning an updated environment. - """ + Given a loader configuration and an environment to load into, analyzes the requested package, + returning an updated environment. + """ diff --git a/projects/calf/src/python/calf/parser.py b/projects/calf/src/python/calf/parser.py index 2600525..9162976 100644 --- a/projects/calf/src/python/calf/parser.py +++ b/projects/calf/src/python/calf/parser.py @@ -2,11 +2,8 @@ The Calf parser. """ -from collections import namedtuple from itertools import tee import logging -import sys -from typing import NamedTuple, Callable from calf.lexer import CalfLexer, lex_buffer, lex_file from calf.grammar import MATCHING, WHITESPACE_TYPES @@ -45,17 +42,18 @@ def mk_dict(contents, open=None, close=None): close.start_position, ) + def mk_str(token): buff = token.value if buff.startswith('"""') and not buff.endswith('"""'): - raise ValueError('Unterminated tripple quote string') + raise ValueError("Unterminated tripple quote string") elif buff.startswith('"') and not buff.endswith('"'): - raise ValueError('Unterminated quote string') + raise ValueError("Unterminated quote string") elif not buff.startswith('"') or buff == '"' or buff == '"""': - raise ValueError('Illegal string') + raise ValueError("Illegal string") if buff.startswith('"""'): buff = buff[3:-3] @@ -114,15 +112,17 @@ class CalfMissingCloseParseError(CalfParseError): def __init__(self, expected_close_token, open_token): super(CalfMissingCloseParseError, self).__init__( f"expected {expected_close_token} starting from {open_token}, got end of file.", - open_token + open_token, ) self.expected_close_token = expected_close_token -def parse_stream(stream, - discard_whitespace: bool = True, - discard_comments: bool = True, - stack: list = None): +def parse_stream( + stream, + discard_whitespace: bool = True, + discard_comments: bool = True, + stack: list = None, +): """Parses a token stream, producing a lazy sequence of all read top level forms. 
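For context on the token stream parse_stream consumes, a minimal sketch of how a (pattern, type) table like TOKENS from calf.grammar above can drive lexing. First-match-wins ordering is an assumption here; the real calf.lexer is not shown in this patch, though test_lexer later in the patch pins the same pairings down.

import re

from calf.grammar import TOKENS

def first_token(buff):
    # The first pattern in the table that matches a non-empty prefix wins.
    for pattern, type_ in TOKENS:
        m = re.match(pattern, buff)
        if m and m.group(0):
            return type_, m.group(0)

assert first_token("(foo)") == ("PAREN_LEFT", "(")
assert first_token("# {}") == ("MACRO_DISPATCH", "#")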
If `discard_whitespace` is truthy, then no WHITESPACE tokens will be emitted @@ -134,11 +134,10 @@ def parse_stream(stream, stack = stack or [] - def recur(_stack = None): - yield from parse_stream(stream, - discard_whitespace, - discard_comments, - _stack or stack) + def recur(_stack=None): + yield from parse_stream( + stream, discard_whitespace, discard_comments, _stack or stack + ) for token in stream: # Whitespace discarding @@ -205,7 +204,9 @@ def parse_stream(stream, # Case of maybe matching something else, but definitely being wrong else: - matching = next(reversed([t[1] for t in stack if t[0] == token.type]), None) + matching = next( + reversed([t[1] for t in stack if t[0] == token.type]), None + ) raise CalfUnexpectedCloseParseError(token, matching) # Atoms @@ -216,18 +217,14 @@ def parse_stream(stream, yield token -def parse_buffer(buffer, - discard_whitespace=True, - discard_comments=True): +def parse_buffer(buffer, discard_whitespace=True, discard_comments=True): """ Parses a buffer, producing a lazy sequence of all parsed level forms. Propagates all errors. """ - yield from parse_stream(lex_buffer(buffer), - discard_whitespace, - discard_comments) + yield from parse_stream(lex_buffer(buffer), discard_whitespace, discard_comments) def parse_file(file): diff --git a/projects/calf/src/python/calf/reader.py b/projects/calf/src/python/calf/reader.py index 3a52758..efd47a1 100644 --- a/projects/calf/src/python/calf/reader.py +++ b/projects/calf/src/python/calf/reader.py @@ -13,6 +13,7 @@ from calf.parser import parse_stream from calf.token import * from calf.types import * + class CalfReader(object): def handle_keyword(self, t: CalfToken) -> Any: """Convert a token to an Object value for a symbol. @@ -79,8 +80,7 @@ class CalfReader(object): return Vector.of(self.read(t.value)) elif isinstance(t, CalfDictToken): - return Map.of([(self.read1(k), self.read1(v)) - for k, v in t.items()]) + return Map.of([(self.read1(k), self.read1(v)) for k, v in t.items()]) # Magical pairwise stuff elif isinstance(t, CalfQuoteToken): @@ -119,28 +119,21 @@ class CalfReader(object): yield self.read1(t) -def read_stream(stream, - reader: CalfReader = None): - """Read from a stream of parsed tokens. - - """ +def read_stream(stream, reader: CalfReader = None): + """Read from a stream of parsed tokens.""" reader = reader or CalfReader() yield from reader.read(stream) def read_buffer(buffer): - """Read from a buffer, producing a lazy sequence of all top level forms. - - """ + """Read from a buffer, producing a lazy sequence of all top level forms.""" yield from read_stream(parse_stream(lex_buffer(buffer))) def read_file(file): - """Read from a file, producing a lazy sequence of all top level forms. 
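As exercised by the test suite later in this patch, the reader composes lex_buffer, parse_stream and CalfReader; a minimal usage sketch, with an illustrative input form:

from calf.reader import read_buffer

# One value is produced per top level form, lazily.
for value in read_buffer("{:foo [1 2.0 'bar]}"):
    print(value)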
- - """ + """Read from a file, producing a lazy sequence of all top level forms.""" yield from read_stream(parse_stream(lex_file(file))) @@ -151,6 +144,8 @@ def main(): from calf.cursedrepl import curse_repl def handle_buffer(buff, count): - return list(read_stream(parse_stream(lex_buffer(buff, source=f"")))) + return list( + read_stream(parse_stream(lex_buffer(buff, source=f""))) + ) curse_repl(handle_buffer) diff --git a/projects/calf/tests/python/test_grammar.py b/projects/calf/tests/python/test_grammar.py index 71351bb..b4cf0fc 100644 --- a/projects/calf/tests/python/test_grammar.py +++ b/projects/calf/tests/python/test_grammar.py @@ -8,23 +8,24 @@ from calf import grammar as cg from conftest import parametrize -@parametrize('ex', [ - # Proper strings - '""', - '"foo bar"', - '"foo\n bar\n\r qux"', - '"foo\\"bar"', - - '""""""', - '"""foo bar baz"""', - '"""foo "" "" "" bar baz"""', - - # Unterminated string cases - '"', - '"f', - '"foo bar', - '"foo\\" bar', - '"""foo bar baz', -]) +@parametrize( + "ex", + [ + # Proper strings + '""', + '"foo bar"', + '"foo\n bar\n\r qux"', + '"foo\\"bar"', + '""""""', + '"""foo bar baz"""', + '"""foo "" "" "" bar baz"""', + # Unterminated string cases + '"', + '"f', + '"foo bar', + '"foo\\" bar', + '"""foo bar baz', + ], +) def test_match_string(ex): assert re.fullmatch(cg.STRING_PATTERN, ex) diff --git a/projects/calf/tests/python/test_lexer.py b/projects/calf/tests/python/test_lexer.py index 0c46d4c..7751859 100644 --- a/projects/calf/tests/python/test_lexer.py +++ b/projects/calf/tests/python/test_lexer.py @@ -20,23 +20,62 @@ def lex_single_token(buffer): @parametrize( "text,token_type", [ - ("(", "PAREN_LEFT",), - (")", "PAREN_RIGHT",), - ("[", "BRACKET_LEFT",), - ("]", "BRACKET_RIGHT",), - ("{", "BRACE_LEFT",), - ("}", "BRACE_RIGHT",), - ("^", "META",), - ("#", "MACRO_DISPATCH",), + ( + "(", + "PAREN_LEFT", + ), + ( + ")", + "PAREN_RIGHT", + ), + ( + "[", + "BRACKET_LEFT", + ), + ( + "]", + "BRACKET_RIGHT", + ), + ( + "{", + "BRACE_LEFT", + ), + ( + "}", + "BRACE_RIGHT", + ), + ( + "^", + "META", + ), + ( + "#", + "MACRO_DISPATCH", + ), ("'", "SINGLE_QUOTE"), - ("foo", "SYMBOL",), + ( + "foo", + "SYMBOL", + ), ("foo/bar", "SYMBOL"), - (":foo", "KEYWORD",), - (":foo/bar", "KEYWORD",), - (" ,,\t ,, \t", "WHITESPACE",), + ( + ":foo", + "KEYWORD", + ), + ( + ":foo/bar", + "KEYWORD", + ), + ( + " ,,\t ,, \t", + "WHITESPACE", + ), ("\n\r", "WHITESPACE"), ("\n", "WHITESPACE"), - (" , ", "WHITESPACE",), + ( + " , ", + "WHITESPACE", + ), ("; this is a sample comment\n", "COMMENT"), ('"foo"', "STRING"), ('"foo bar baz"', "STRING"), diff --git a/projects/calf/tests/python/test_parser.py b/projects/calf/tests/python/test_parser.py index e0f4a9e..1e87431 100644 --- a/projects/calf/tests/python/test_parser.py +++ b/projects/calf/tests/python/test_parser.py @@ -8,12 +8,15 @@ from conftest import parametrize import pytest -@parametrize("text", [ - '"', - '"foo bar', - '"""foo bar', - '"""foo bar"', -]) +@parametrize( + "text", + [ + '"', + '"foo bar', + '"""foo bar', + '"""foo bar"', + ], +) def test_bad_strings_raise(text): """Tests asserting we won't let obviously bad strings fly.""" # FIXME (arrdem 2021-03-13): @@ -22,81 +25,89 @@ def test_bad_strings_raise(text): next(cp.parse_buffer(text)) -@parametrize("text", [ - "[1.0", - "(1.0", - "{1.0", -]) +@parametrize( + "text", + [ + "[1.0", + "(1.0", + "{1.0", + ], +) def test_unterminated_raises(text): """Tests asserting that we don't let unterminated collections parse.""" with 
pytest.raises(cp.CalfMissingCloseParseError): next(cp.parse_buffer(text)) -@parametrize("text", [ - "[{]", - "[(]", - "({)", - "([)", - "{(}", - "{[}", -]) +@parametrize( + "text", + [ + "[{]", + "[(]", + "({)", + "([)", + "{(}", + "{[}", + ], +) def test_unbalanced_raises(text): """Tests asserting that we don't let missmatched collections parse.""" with pytest.raises(cp.CalfUnexpectedCloseParseError): next(cp.parse_buffer(text)) -@parametrize("buff, value", [ - ('"foo"', "foo"), - ('"foo\tbar"', "foo\tbar"), - ('"foo\n\rbar"', "foo\n\rbar"), - ('"foo\\"bar\\""', "foo\"bar\""), - ('"""foo"""', 'foo'), - ('"""foo"bar"baz"""', 'foo"bar"baz'), -]) +@parametrize( + "buff, value", + [ + ('"foo"', "foo"), + ('"foo\tbar"', "foo\tbar"), + ('"foo\n\rbar"', "foo\n\rbar"), + ('"foo\\"bar\\""', 'foo"bar"'), + ('"""foo"""', "foo"), + ('"""foo"bar"baz"""', 'foo"bar"baz'), + ], +) def test_strings_round_trip(buff, value): assert next(cp.parse_buffer(buff)) == value -@parametrize('text, element_types', [ - # Integers - ("(1)", ["INTEGER"]), - ("( 1 )", ["INTEGER"]), - ("(,1,)", ["INTEGER"]), - ("(1\n)", ["INTEGER"]), - ("(\n1\n)", ["INTEGER"]), - ("(1, 2, 3, 4)", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]), - # Floats - ("(1.0)", ["FLOAT"]), - ("(1.0e0)", ["FLOAT"]), - ("(1e0)", ["FLOAT"]), - ("(1e0)", ["FLOAT"]), - - # Symbols - ("(foo)", ["SYMBOL"]), - ("(+)", ["SYMBOL"]), - ("(-)", ["SYMBOL"]), - ("(*)", ["SYMBOL"]), - ("(foo-bar)", ["SYMBOL"]), - ("(+foo-bar+)", ["SYMBOL"]), - ("(+foo-bar+)", ["SYMBOL"]), - ("( foo bar )", ["SYMBOL", "SYMBOL"]), - - # Keywords - ("(:foo)", ["KEYWORD"]), - ("( :foo )", ["KEYWORD"]), - ("(\n:foo\n)", ["KEYWORD"]), - ("(,:foo,)", ["KEYWORD"]), - ("(:foo :bar)", ["KEYWORD", "KEYWORD"]), - ("(:foo :bar 1)", ["KEYWORD", "KEYWORD", "INTEGER"]), - - # Strings - ('("foo", "bar", "baz")', ["STRING", "STRING", "STRING"]), - - # Lists - ('([] [] ())', ["SQLIST", "SQLIST", "LIST"]), -]) +@parametrize( + "text, element_types", + [ + # Integers + ("(1)", ["INTEGER"]), + ("( 1 )", ["INTEGER"]), + ("(,1,)", ["INTEGER"]), + ("(1\n)", ["INTEGER"]), + ("(\n1\n)", ["INTEGER"]), + ("(1, 2, 3, 4)", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]), + # Floats + ("(1.0)", ["FLOAT"]), + ("(1.0e0)", ["FLOAT"]), + ("(1e0)", ["FLOAT"]), + ("(1e0)", ["FLOAT"]), + # Symbols + ("(foo)", ["SYMBOL"]), + ("(+)", ["SYMBOL"]), + ("(-)", ["SYMBOL"]), + ("(*)", ["SYMBOL"]), + ("(foo-bar)", ["SYMBOL"]), + ("(+foo-bar+)", ["SYMBOL"]), + ("(+foo-bar+)", ["SYMBOL"]), + ("( foo bar )", ["SYMBOL", "SYMBOL"]), + # Keywords + ("(:foo)", ["KEYWORD"]), + ("( :foo )", ["KEYWORD"]), + ("(\n:foo\n)", ["KEYWORD"]), + ("(,:foo,)", ["KEYWORD"]), + ("(:foo :bar)", ["KEYWORD", "KEYWORD"]), + ("(:foo :bar 1)", ["KEYWORD", "KEYWORD", "INTEGER"]), + # Strings + ('("foo", "bar", "baz")', ["STRING", "STRING", "STRING"]), + # Lists + ("([] [] ())", ["SQLIST", "SQLIST", "LIST"]), + ], +) def test_parse_list(text, element_types): """Test we can parse various lists of contents.""" l_t = next(cp.parse_buffer(text, discard_whitespace=True)) @@ -104,45 +115,43 @@ def test_parse_list(text, element_types): assert [t.type for t in l_t] == element_types -@parametrize('text, element_types', [ - # Integers - ("[1]", ["INTEGER"]), - ("[ 1 ]", ["INTEGER"]), - ("[,1,]", ["INTEGER"]), - ("[1\n]", ["INTEGER"]), - ("[\n1\n]", ["INTEGER"]), - ("[1, 2, 3, 4]", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]), - - # Floats - ("[1.0]", ["FLOAT"]), - ("[1.0e0]", ["FLOAT"]), - ("[1e0]", ["FLOAT"]), - ("[1e0]", ["FLOAT"]), - - # Symbols - ("[foo]", 
["SYMBOL"]), - ("[+]", ["SYMBOL"]), - ("[-]", ["SYMBOL"]), - ("[*]", ["SYMBOL"]), - ("[foo-bar]", ["SYMBOL"]), - ("[+foo-bar+]", ["SYMBOL"]), - ("[+foo-bar+]", ["SYMBOL"]), - ("[ foo bar ]", ["SYMBOL", "SYMBOL"]), - - # Keywords - ("[:foo]", ["KEYWORD"]), - ("[ :foo ]", ["KEYWORD"]), - ("[\n:foo\n]", ["KEYWORD"]), - ("[,:foo,]", ["KEYWORD"]), - ("[:foo :bar]", ["KEYWORD", "KEYWORD"]), - ("[:foo :bar 1]", ["KEYWORD", "KEYWORD", "INTEGER"]), - - # Strings - ('["foo", "bar", "baz"]', ["STRING", "STRING", "STRING"]), - - # Lists - ('[[] [] ()]', ["SQLIST", "SQLIST", "LIST"]), -]) +@parametrize( + "text, element_types", + [ + # Integers + ("[1]", ["INTEGER"]), + ("[ 1 ]", ["INTEGER"]), + ("[,1,]", ["INTEGER"]), + ("[1\n]", ["INTEGER"]), + ("[\n1\n]", ["INTEGER"]), + ("[1, 2, 3, 4]", ["INTEGER", "INTEGER", "INTEGER", "INTEGER"]), + # Floats + ("[1.0]", ["FLOAT"]), + ("[1.0e0]", ["FLOAT"]), + ("[1e0]", ["FLOAT"]), + ("[1e0]", ["FLOAT"]), + # Symbols + ("[foo]", ["SYMBOL"]), + ("[+]", ["SYMBOL"]), + ("[-]", ["SYMBOL"]), + ("[*]", ["SYMBOL"]), + ("[foo-bar]", ["SYMBOL"]), + ("[+foo-bar+]", ["SYMBOL"]), + ("[+foo-bar+]", ["SYMBOL"]), + ("[ foo bar ]", ["SYMBOL", "SYMBOL"]), + # Keywords + ("[:foo]", ["KEYWORD"]), + ("[ :foo ]", ["KEYWORD"]), + ("[\n:foo\n]", ["KEYWORD"]), + ("[,:foo,]", ["KEYWORD"]), + ("[:foo :bar]", ["KEYWORD", "KEYWORD"]), + ("[:foo :bar 1]", ["KEYWORD", "KEYWORD", "INTEGER"]), + # Strings + ('["foo", "bar", "baz"]', ["STRING", "STRING", "STRING"]), + # Lists + ("[[] [] ()]", ["SQLIST", "SQLIST", "LIST"]), + ], +) def test_parse_sqlist(text, element_types): """Test we can parse various 'square' lists of contents.""" l_t = next(cp.parse_buffer(text, discard_whitespace=True)) @@ -150,41 +159,21 @@ def test_parse_sqlist(text, element_types): assert [t.type for t in l_t] == element_types -@parametrize('text, element_pairs', [ - ("{}", - []), - - ("{:foo 1}", - [["KEYWORD", "INTEGER"]]), - - ("{:foo 1, :bar 2}", - [["KEYWORD", "INTEGER"], - ["KEYWORD", "INTEGER"]]), - - ("{foo 1, bar 2}", - [["SYMBOL", "INTEGER"], - ["SYMBOL", "INTEGER"]]), - - ("{foo 1, bar -2}", - [["SYMBOL", "INTEGER"], - ["SYMBOL", "INTEGER"]]), - - ("{foo 1, bar -2e0}", - [["SYMBOL", "INTEGER"], - ["SYMBOL", "FLOAT"]]), - - ("{foo ()}", - [["SYMBOL", "LIST"]]), - - ("{foo []}", - [["SYMBOL", "SQLIST"]]), - - ("{foo {}}", - [["SYMBOL", "DICT"]]), - - ('{"foo" {}}', - [["STRING", "DICT"]]) -]) +@parametrize( + "text, element_pairs", + [ + ("{}", []), + ("{:foo 1}", [["KEYWORD", "INTEGER"]]), + ("{:foo 1, :bar 2}", [["KEYWORD", "INTEGER"], ["KEYWORD", "INTEGER"]]), + ("{foo 1, bar 2}", [["SYMBOL", "INTEGER"], ["SYMBOL", "INTEGER"]]), + ("{foo 1, bar -2}", [["SYMBOL", "INTEGER"], ["SYMBOL", "INTEGER"]]), + ("{foo 1, bar -2e0}", [["SYMBOL", "INTEGER"], ["SYMBOL", "FLOAT"]]), + ("{foo ()}", [["SYMBOL", "LIST"]]), + ("{foo []}", [["SYMBOL", "SQLIST"]]), + ("{foo {}}", [["SYMBOL", "DICT"]]), + ('{"foo" {}}', [["STRING", "DICT"]]), + ], +) def test_parse_dict(text, element_pairs): """Test we can parse various mappings.""" d_t = next(cp.parse_buffer(text, discard_whitespace=True)) @@ -192,27 +181,25 @@ def test_parse_dict(text, element_pairs): assert [[t.type for t in pair] for pair in d_t.value] == element_pairs -@parametrize("text", [ - "{1}", - "{1, 2, 3}", - "{:foo}", - "{:foo :bar :baz}" -]) +@parametrize("text", ["{1}", "{1, 2, 3}", "{:foo}", "{:foo :bar :baz}"]) def test_parse_bad_dict(text): """Assert that dicts with missmatched pairs don't parse.""" with pytest.raises(Exception): next(cp.parse_buffer(text)) 
-@parametrize("text", [ - "()", - "(1 1.1 1e2 -2 foo :foo foo/bar :foo/bar [{},])", - "{:foo bar, :baz [:qux]}", - "'foo", - "'[foo bar :baz 'qux, {}]", - "#foo []", - "^{} bar", -]) +@parametrize( + "text", + [ + "()", + "(1 1.1 1e2 -2 foo :foo foo/bar :foo/bar [{},])", + "{:foo bar, :baz [:qux]}", + "'foo", + "'[foo bar :baz 'qux, {}]", + "#foo []", + "^{} bar", + ], +) def test_examples(text): """Shotgun examples showing we can parse some stuff.""" diff --git a/projects/calf/tests/python/test_reader.py b/projects/calf/tests/python/test_reader.py index 049618d..7516277 100644 --- a/projects/calf/tests/python/test_reader.py +++ b/projects/calf/tests/python/test_reader.py @@ -5,18 +5,22 @@ from conftest import parametrize from calf.reader import read_buffer -@parametrize('text', [ - "()", - "[]", - "[[[[[[[[[]]]]]]]]]", - "{1 {2 {}}}", - '"foo"', - "foo", - "'foo", - "^foo bar", - "^:foo bar", - "{\"foo\" '([:bar ^:foo 'baz 3.14159e0])}", - "[:foo bar 'baz lo/l, 1, 1.2. 1e-5 -1e2]", -]) + +@parametrize( + "text", + [ + "()", + "[]", + "[[[[[[[[[]]]]]]]]]", + "{1 {2 {}}}", + '"foo"', + "foo", + "'foo", + "^foo bar", + "^:foo bar", + "{\"foo\" '([:bar ^:foo 'baz 3.14159e0])}", + "[:foo bar 'baz lo/l, 1, 1.2. 1e-5 -1e2]", + ], +) def test_read(text): assert list(read_buffer(text)) diff --git a/projects/datalog-shell/__main__.py b/projects/datalog-shell/__main__.py index 7d7e453..c1ee170 100755 --- a/projects/datalog-shell/__main__.py +++ b/projects/datalog-shell/__main__.py @@ -58,13 +58,13 @@ from datalog.debris import Timing from datalog.evaluator import select from datalog.reader import pr_str, read_command, read_dataset from datalog.types import ( - CachedDataset, - Constant, - Dataset, - LVar, - PartlyIndexedDataset, - Rule, - TableIndexedDataset + CachedDataset, + Constant, + Dataset, + LVar, + PartlyIndexedDataset, + Rule, + TableIndexedDataset, ) from prompt_toolkit import print_formatted_text, prompt, PromptSession @@ -74,190 +74,204 @@ from prompt_toolkit.styles import Style from yaspin import Spinner, yaspin -STYLE = Style.from_dict({ - # User input (default text). - "": "", - "prompt": "ansigreen", - "time": "ansiyellow" -}) +STYLE = Style.from_dict( + { + # User input (default text). 
+ "": "", + "prompt": "ansigreen", + "time": "ansiyellow", + } +) SPINNER = Spinner(["|", "/", "-", "\\"], 200) class InterpreterInterrupt(Exception): - """An exception used to break the prompt or evaluation.""" + """An exception used to break the prompt or evaluation.""" def print_(fmt, **kwargs): - print_formatted_text(FormattedText(fmt), **kwargs) + print_formatted_text(FormattedText(fmt), **kwargs) def print_db(db): - """Render a database for debugging.""" + """Render a database for debugging.""" - for e in db.tuples(): - print(f"⇒ {pr_str(e)}") + for e in db.tuples(): + print(f"⇒ {pr_str(e)}") - for r in db.rules(): - print(f"⇒ {pr_str(r)}") + for r in db.rules(): + print(f"⇒ {pr_str(r)}") def main(args): - """REPL entry point.""" + """REPL entry point.""" - if args.db_cls == "simple": - db_cls = Dataset - elif args.db_cls == "cached": - db_cls = CachedDataset - elif args.db_cls == "table": - db_cls = TableIndexedDataset - elif args.db_cls == "partly": - db_cls = PartlyIndexedDataset + if args.db_cls == "simple": + db_cls = Dataset + elif args.db_cls == "cached": + db_cls = CachedDataset + elif args.db_cls == "table": + db_cls = TableIndexedDataset + elif args.db_cls == "partly": + db_cls = PartlyIndexedDataset - print(f"Using dataset type {db_cls}") + print(f"Using dataset type {db_cls}") - session = PromptSession(history=FileHistory(".datalog.history")) - db = db_cls([], []) + session = PromptSession(history=FileHistory(".datalog.history")) + db = db_cls([], []) - if args.dbs: - for db_file in args.dbs: - try: - with open(db_file, "r") as f: - db = db.merge(read_dataset(f.read())) - print(f"Loaded {db_file} ...") - except Exception as e: - print("Internal error - {e}") - print(f"Unable to load db {db_file}, skipping") + if args.dbs: + for db_file in args.dbs: + try: + with open(db_file, "r") as f: + db = db.merge(read_dataset(f.read())) + print(f"Loaded {db_file} ...") + except Exception as e: + print("Internal error - {e}") + print(f"Unable to load db {db_file}, skipping") - while True: - try: - line = session.prompt([("class:prompt", ">>> ")], style=STYLE) - except (InterpreterInterrupt, KeyboardInterrupt): - continue - except EOFError: - break + while True: + try: + line = session.prompt([("class:prompt", ">>> ")], style=STYLE) + except (InterpreterInterrupt, KeyboardInterrupt): + continue + except EOFError: + break - if line == ".all": - op = ".all" - elif line == ".dbg": - op = ".dbg" - elif line == ".quit": - break + if line == ".all": + op = ".all" + elif line == ".dbg": + op = ".dbg" + elif line == ".quit": + break - elif line in {".help", "help", "?", "??", "???"}: - print(__doc__) - continue - - elif line.split(" ")[0] == ".log": - op = ".log" - - else: - try: - op, val = read_command(line) - except Exception as e: - print(f"Got an unknown command or syntax error, can't tell which") - continue - - # Definition merges on the DB - if op == ".all": - print_db(db) - - # .dbg drops to a debugger shell so you can poke at the instance objects (database) - elif op == ".dbg": - import pdb - pdb.set_trace() - - # .log sets the log level - badly - elif op == ".log": - level = line.split(" ")[1].upper() - try: - ch.setLevel(getattr(logging, level)) - except BaseException: - print(f"Unknown log level {level}") - - elif op == ".": - # FIXME (arrdem 2019-06-15): - # Syntax rules the parser doesn't impose... 
- try: - for rule in val.rules(): - assert not rule.free_vars, f"Rule contains free variables {rule.free_vars!r}" - - for tuple in val.tuples(): - assert not any(isinstance(e, LVar) for e in tuple), f"Tuples cannot contain lvars - {tuple!r}" - - except BaseException as e: - print(f"Error: {e}") - continue - - db = db.merge(val) - print_db(val) - - # Queries execute - note that rules as queries have to be temporarily merged. - elif op == "?": - # In order to support ad-hoc rules (joins), we have to generate a transient "query" database - # by bolting the rule on as an overlay to the existing database. If of course we have a join. - # - # `val` was previously assumed to be the query pattern. Introduce `qdb`, now used as the - # database to query and "fix" `val` to be the temporary rule's pattern. - # - # We use a new db and db local so that the ephemeral rule doesn't persist unless the user - # later `.` defines it. - # - # Unfortunately doing this merge does nuke caches. - qdb = db - if isinstance(val, Rule): - qdb = db.merge(db_cls([], [val])) - val = val.pattern - - with yaspin(SPINNER) as spinner: - with Timing() as t: - try: - results = list(select(qdb, val)) - except KeyboardInterrupt: - print(f"Evaluation aborted after {t}") + elif line in {".help", "help", "?", "??", "???"}: + print(__doc__) continue - # It's kinda bogus to move sorting out but oh well - sorted(results) + elif line.split(" ")[0] == ".log": + op = ".log" - for _results, _bindings in results: - _result = _results[0] # select only selects one tuple at a time - print(f"⇒ {pr_str(_result)}") + else: + try: + op, val = read_command(line) + except Exception as e: + print(f"Got an unknown command or syntax error, can't tell which") + continue - # So we can report empty sets explicitly. - if not results: - print("⇒ Ø") + # Definition merges on the DB + if op == ".all": + print_db(db) - print_([("class:time", f"Elapsed time - {t}")], style=STYLE) + # .dbg drops to a debugger shell so you can poke at the instance objects (database) + elif op == ".dbg": + import pdb - # Retractions try to delete, but may fail. - elif op == "!": - if val in db.tuples() or val in [r.pattern for r in db.rules()]: - db = db_cls([u for u in db.tuples() if u != val], - [r for r in db.rules() if r.pattern != val]) - print(f"⇒ {pr_str(val)}") - else: - print("⇒ Ø") + pdb.set_trace() + + # .log sets the log level - badly + elif op == ".log": + level = line.split(" ")[1].upper() + try: + ch.setLevel(getattr(logging, level)) + except BaseException: + print(f"Unknown log level {level}") + + elif op == ".": + # FIXME (arrdem 2019-06-15): + # Syntax rules the parser doesn't impose... + try: + for rule in val.rules(): + assert ( + not rule.free_vars + ), f"Rule contains free variables {rule.free_vars!r}" + + for tuple in val.tuples(): + assert not any( + isinstance(e, LVar) for e in tuple + ), f"Tuples cannot contain lvars - {tuple!r}" + + except BaseException as e: + print(f"Error: {e}") + continue + + db = db.merge(val) + print_db(val) + + # Queries execute - note that rules as queries have to be temporarily merged. + elif op == "?": + # In order to support ad-hoc rules (joins), we have to generate a transient "query" database + # by bolting the rule on as an overlay to the existing database. If of course we have a join. + # + # `val` was previously assumed to be the query pattern. Introduce `qdb`, now used as the + # database to query and "fix" `val` to be the temporary rule's pattern. 
+ # + # We use a new db and db local so that the ephemeral rule doesn't persist unless the user + # later `.` defines it. + # + # Unfortunately doing this merge does nuke caches. + qdb = db + if isinstance(val, Rule): + qdb = db.merge(db_cls([], [val])) + val = val.pattern + + with yaspin(SPINNER) as spinner: + with Timing() as t: + try: + results = list(select(qdb, val)) + except KeyboardInterrupt: + print(f"Evaluation aborted after {t}") + continue + + # It's kinda bogus to move sorting out but oh well + sorted(results) + + for _results, _bindings in results: + _result = _results[0] # select only selects one tuple at a time + print(f"⇒ {pr_str(_result)}") + + # So we can report empty sets explicitly. + if not results: + print("⇒ Ø") + + print_([("class:time", f"Elapsed time - {t}")], style=STYLE) + + # Retractions try to delete, but may fail. + elif op == "!": + if val in db.tuples() or val in [r.pattern for r in db.rules()]: + db = db_cls( + [u for u in db.tuples() if u != val], + [r for r in db.rules() if r.pattern != val], + ) + print(f"⇒ {pr_str(val)}") + else: + print("⇒ Ø") parser = argparse.ArgumentParser() # Select which dataset type to use -parser.add_argument("--db-type", - choices=["simple", "cached", "table", "partly"], - help="Choose which DB to use (default partly)", - dest="db_cls", - default="partly") +parser.add_argument( + "--db-type", + choices=["simple", "cached", "table", "partly"], + help="Choose which DB to use (default partly)", + dest="db_cls", + default="partly", +) -parser.add_argument("--load-db", dest="dbs", action="append", - help="Datalog files to load first.") +parser.add_argument( + "--load-db", dest="dbs", action="append", help="Datalog files to load first." +) if __name__ == "__main__": - args = parser.parse_args(sys.argv[1:]) - logger = logging.getLogger("arrdem.datalog") - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - ch.setFormatter(formatter) - logger.addHandler(ch) - main(args) + args = parser.parse_args(sys.argv[1:]) + logger = logging.getLogger("arrdem.datalog") + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch.setFormatter(formatter) + logger.addHandler(ch) + main(args) diff --git a/projects/datalog-shell/setup.py b/projects/datalog-shell/setup.py index 2ef284a..00b77e8 100644 --- a/projects/datalog-shell/setup.py +++ b/projects/datalog-shell/setup.py @@ -23,10 +23,7 @@ setup( "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", ], - - scripts=[ - "bin/datalog" - ], + scripts=["bin/datalog"], install_requires=[ "arrdem.datalog~=2.0.0", "prompt_toolkit==2.0.9", diff --git a/projects/datalog/make_graph.py b/projects/datalog/make_graph.py index 744e47d..d6aac0a 100644 --- a/projects/datalog/make_graph.py +++ b/projects/datalog/make_graph.py @@ -9,17 +9,17 @@ from uuid import uuid4 as uuid with open("graph.dtl", "w") as f: - nodes = [] + nodes = [] - # Generate 10k edges - for i in range(10000): - if nodes: - from_node = choice(nodes) - else: - from_node = uuid() + # Generate 10k edges + for i in range(10000): + if nodes: + from_node = choice(nodes) + else: + from_node = uuid() - to_node = uuid() + to_node = uuid() - nodes.append(to_node) + nodes.append(to_node) - f.write(f"edge({str(from_node)!r}, {str(to_node)!r}).\n") + f.write(f"edge({str(from_node)!r}, {str(to_node)!r}).\n") diff --git 
a/projects/datalog/setup.py b/projects/datalog/setup.py index f9eff15..634003b 100644 --- a/projects/datalog/setup.py +++ b/projects/datalog/setup.py @@ -26,5 +26,7 @@ setup( ], # Package setup package_dir={"": "src/python"}, - packages=["datalog",], + packages=[ + "datalog", + ], ) diff --git a/projects/datalog/src/python/datalog/debris.py b/projects/datalog/src/python/datalog/debris.py index 32b4f75..7cdf126 100644 --- a/projects/datalog/src/python/datalog/debris.py +++ b/projects/datalog/src/python/datalog/debris.py @@ -16,8 +16,8 @@ def constexpr_p(expr): class Timing(object): """ - A context manager object which records how long the context took. - """ + A context manager object which records how long the context took. + """ def __init__(self): self.start = None @@ -36,8 +36,8 @@ class Timing(object): def __call__(self): """ - If the context is exited, return its duration. Otherwise return the duration "so far". - """ + If the context is exited, return its duration. Otherwise return the duration "so far". + """ from datetime import datetime diff --git a/projects/datalog/src/python/datalog/easy.py b/projects/datalog/src/python/datalog/easy.py index 666dd1c..b42a5b1 100644 --- a/projects/datalog/src/python/datalog/easy.py +++ b/projects/datalog/src/python/datalog/easy.py @@ -22,9 +22,9 @@ def read(text: str, db_cls=PartlyIndexedDataset): def q(t: Tuple[str]) -> LTuple: """Helper for writing terse queries. - Takes a tuple of strings, and interprets them as a logic tuple. - So you don't have to write the logic tuple out by hand. - """ + Takes a tuple of strings, and interprets them as a logic tuple. + So you don't have to write the logic tuple out by hand. + """ def _x(s: str): if s[0].isupper(): @@ -50,38 +50,38 @@ def __result(results_bindings): def select(db: Dataset, query: Tuple[str], bindings=None) -> Sequence[Tuple]: """Helper for interpreting tuples of strings as a query, and returning simplified results. - Executes your query, returning matching full tuples. - """ + Executes your query, returning matching full tuples. + """ return __mapv(__result, __select(db, q(query), bindings=bindings)) def join(db: Dataset, query: Sequence[Tuple[str]], bindings=None) -> Sequence[dict]: """Helper for interpreting a bunch of tuples of strings as a join query, and returning simplified -results. + results. - Executes the query clauses as a join, returning a sequence of tuples and binding mappings such - that the join constraints are simultaneously satisfied. + Executes the query clauses as a join, returning a sequence of tuples and binding mappings such + that the join constraints are simultaneously satisfied. - >>> db = read(''' - ... edge(a, b). - ... edge(b, c). - ... edge(c, d). - ... ''') - >>> join(db, [ - ... ('edge', 'A', 'B'), - ... ('edge', 'B', 'C') - ... ]) - [((('edge', 'a', 'b'), - ('edge', 'b', 'c')), - {'A': 'a', 'B': 'b', 'C': 'c'}), - ((('edge', 'b', 'c'), - ('edge', 'c', 'd')), - {'A': 'b', 'B': 'c', 'C': 'd'}), - ((('edge', 'c', 'd'), - ('edge', 'd', 'f')), - {'A': 'c', 'B': 'd', 'C': 'f'})] - """ + >>> db = read(''' + ... edge(a, b). + ... edge(b, c). + ... edge(c, d). + ... ''') + >>> join(db, [ + ... ('edge', 'A', 'B'), + ... ('edge', 'B', 'C') + ... 
]) + [((('edge', 'a', 'b'), + ('edge', 'b', 'c')), + {'A': 'a', 'B': 'b', 'C': 'c'}), + ((('edge', 'b', 'c'), + ('edge', 'c', 'd')), + {'A': 'b', 'B': 'c', 'C': 'd'}), + ((('edge', 'c', 'd'), + ('edge', 'd', 'f')), + {'A': 'c', 'B': 'd', 'C': 'f'})] + """ return __mapv(__result, __join(db, [q(c) for c in query], bindings=bindings)) diff --git a/projects/datalog/src/python/datalog/evaluator.py b/projects/datalog/src/python/datalog/evaluator.py index 5a66808..d902ebf 100644 --- a/projects/datalog/src/python/datalog/evaluator.py +++ b/projects/datalog/src/python/datalog/evaluator.py @@ -20,8 +20,8 @@ from datalog.types import ( def match(tuple, expr, bindings=None): """Attempt to construct lvar bindings from expr such that tuple and expr equate. - If the match is successful, return the binding map, otherwise return None. - """ + If the match is successful, return the binding map, otherwise return None. + """ bindings = bindings.copy() if bindings is not None else {} for a, b in zip(expr, tuple): @@ -43,9 +43,9 @@ def match(tuple, expr, bindings=None): def apply_bindings(expr, bindings, strict=True): """Given an expr which may contain lvars, substitute its lvars for constants returning the - simplified expr. + simplified expr. - """ + """ if strict: return tuple((bindings[e] if isinstance(e, LVar) else e) for e in expr) @@ -56,10 +56,10 @@ def apply_bindings(expr, bindings, strict=True): def select(db: Dataset, expr, bindings=None, _recursion_guard=None, _select_guard=None): """Evaluate an expression in a database, lazily producing a sequence of 'matching' tuples. - The dataset is a set of tuples and rules, and the expression is a single tuple containing lvars - and constants. Evaluates rules and tuples, returning + The dataset is a set of tuples and rules, and the expression is a single tuple containing lvars + and constants. Evaluates rules and tuples, returning - """ + """ def __select_tuples(): # As an opt. support indexed scans, which is optional. @@ -170,8 +170,8 @@ def select(db: Dataset, expr, bindings=None, _recursion_guard=None, _select_guar def join(db: Dataset, clauses, bindings, pattern=None, _recursion_guard=None): """Evaluate clauses over the dataset, joining (or antijoining) with the seed bindings. - Yields a sequence of tuples and LVar bindings for which all joins and antijoins were satisfied. - """ + Yields a sequence of tuples and LVar bindings for which all joins and antijoins were satisfied. + """ def __join(g, clause): for ts, bindings in g: diff --git a/projects/datalog/src/python/datalog/reader.py b/projects/datalog/src/python/datalog/reader.py index ec417d4..ceb1afb 100644 --- a/projects/datalog/src/python/datalog/reader.py +++ b/projects/datalog/src/python/datalog/reader.py @@ -27,13 +27,19 @@ class Actions(object): return self._db_cls(tuples, rules) def make_symbol(self, input, start, end, elements): - return LVar("".join(e.text for e in elements),) + return LVar( + "".join(e.text for e in elements), + ) def make_word(self, input, start, end, elements): - return Constant("".join(e.text for e in elements),) + return Constant( + "".join(e.text for e in elements), + ) def make_string(self, input, start, end, elements): - return Constant(elements[1].text,) + return Constant( + elements[1].text, + ) def make_comment(self, input, start, end, elements): return None @@ -81,11 +87,11 @@ class Actions(object): class Parser(Grammar): """Implementation detail. 
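An illustration of the matching contract documented on match above. That LVar and Constant compare by value is an assumption here, though the lvar-unification test later in this patch behaves consistently with it:

from datalog.evaluator import match
from datalog.types import Constant, LVar

pattern = (Constant("edge"), LVar("X"), LVar("X"))

# A repeated lvar must unify: the second and third slots have to agree.
assert match((Constant("edge"), Constant("c"), Constant("c")), pattern) is not None
assert match((Constant("edge"), Constant("b"), Constant("c")), pattern) is None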
- A slightly hacked version of the Parser class canopy generates, which lets us control what the - parsing entry point is. This lets me play games with having one parser and one grammar which is - used both for the command shell and for other things. + A slightly hacked version of the Parser class canopy generates, which lets us control what the + parsing entry point is. This lets me play games with having one parser and one grammar which is + used both for the command shell and for other things. - """ + """ def __init__(self, input, actions, types): self._input = input diff --git a/projects/datalog/src/python/datalog/types.py b/projects/datalog/src/python/datalog/types.py index 5d9c2b5..c4d3739 100644 --- a/projects/datalog/src/python/datalog/types.py +++ b/projects/datalog/src/python/datalog/types.py @@ -66,8 +66,8 @@ class Dataset(object): class CachedDataset(Dataset): """An extension of the dataset which features a cache of rule produced tuples. - Note that this cache is lost when merging datasets - which ensures correctness. - """ + Note that this cache is lost when merging datasets - which ensures correctness. + """ # Inherits tuples, rules, merge @@ -90,11 +90,11 @@ class CachedDataset(Dataset): class TableIndexedDataset(CachedDataset): """An extension of the Dataset type which features both a cache and an index by table & length. - The index allows more efficient scans by maintaining 'table' style partitions. - It does not support user-defined indexing schemes. + The index allows more efficient scans by maintaining 'table' style partitions. + It does not support user-defined indexing schemes. - Note that index building is delayed until an index is scanned. - """ + Note that index building is delayed until an index is scanned. + """ # From Dataset: # tuples, rules, merge @@ -126,11 +126,11 @@ class TableIndexedDataset(CachedDataset): class PartlyIndexedDataset(TableIndexedDataset): """An extension of the Dataset type which features both a cache and and a full index by table, - length, tuple index and value. + length, tuple index and value. - The index allows extremely efficient scans when elements of the tuple are known. + The index allows extremely efficient scans when elements of the tuple are known. - """ + """ # From Dataset: # tuples, rules, merge diff --git a/projects/datalog/test/python/test_datalog_evaluator.py b/projects/datalog/test/python/test_datalog_evaluator.py index ac5ab24..d4e1663 100644 --- a/projects/datalog/test/python/test_datalog_evaluator.py +++ b/projects/datalog/test/python/test_datalog_evaluator.py @@ -25,8 +25,19 @@ def test_id_query(db_cls): Constant("a"), Constant("b"), ) - assert not select(db_cls([], []), ("a", "b",)) - assert select(db_cls([ab], []), ("a", "b",)) == [((("a", "b"),), {},)] + assert not select( + db_cls([], []), + ( + "a", + "b", + ), + ) + assert select(db_cls([ab], []), ("a", "b",)) == [ + ( + (("a", "b"),), + {}, + ) + ] @pytest.mark.parametrize("db_cls,", DBCLS) @@ -47,7 +58,17 @@ def test_lvar_unification(db_cls): d = read("""edge(b, c). edge(c, c).""", db_cls=db_cls) - assert select(d, ("edge", "X", "X",)) == [((("edge", "c", "c"),), {"X": "c"})] + assert ( + select( + d, + ( + "edge", + "X", + "X", + ), + ) + == [((("edge", "c", "c"),), {"X": "c"})] + ) @pytest.mark.parametrize("db_cls,", DBCLS) @@ -105,12 +126,12 @@ no-b(X, Y) :- def test_nested_antijoin(db_cls): """Test a query which negates a subquery which uses an antijoin. 
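The tests in this file drive the evaluator with explicit Constant/LVar tuples; the datalog.easy wrappers shown earlier reduce the same queries to plain strings, with uppercase names read as lvars. A sketch whose result shapes follow the join doctest above:

from datalog.easy import read, select, join

db = read("""
edge(a, b).
edge(b, c).
""")

select(db, ("edge", "a", "B"))
# => [((("edge", "a", "b"),), {"B": "b"})]

join(db, [("edge", "A", "B"), ("edge", "B", "C")])
# => [((("edge", "a", "b"), ("edge", "b", "c")), {"A": "a", "B": "b", "C": "c"})]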
- Shouldn't exercise anything more than `test_antjoin` does, but it's an interesting case since you - actually can't capture the same semantics using a single query. Antijoins can't propagate positive - information (create lvar bindings) so I'm not sure you can express this another way without a - different evaluation strategy. + Shouldn't exercise anything more than `test_antjoin` does, but it's an interesting case since you + actually can't capture the same semantics using a single query. Antijoins can't propagate positive + information (create lvar bindings) so I'm not sure you can express this another way without a + different evaluation strategy. - """ + """ d = read( """ diff --git a/projects/flowmetal/setup.py b/projects/flowmetal/setup.py index fa67c3c..22feec0 100644 --- a/projects/flowmetal/setup.py +++ b/projects/flowmetal/setup.py @@ -3,7 +3,7 @@ from setuptools import setup setup( name="arrdem.flowmetal", # Package metadata - version='0.0.0', + version="0.0.0", license="MIT", description="A weird execution engine", long_description=open("README.md").read(), @@ -18,20 +18,16 @@ setup( "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", ], - # Package setup package_dir={"": "src/python"}, packages=[ "flowmetal", ], entry_points={ - 'console_scripts': [ - 'iflow=flowmetal.repl:main' - ], + "console_scripts": ["iflow=flowmetal.repl:main"], }, install_requires=[ - 'prompt-toolkit~=3.0.0', + "prompt-toolkit~=3.0.0", ], - extras_require={ - } + extras_require={}, ) diff --git a/projects/flowmetal/src/python/flowmetal/db/base.py b/projects/flowmetal/src/python/flowmetal/db/base.py index cf69279..529edd4 100644 --- a/projects/flowmetal/src/python/flowmetal/db/base.py +++ b/projects/flowmetal/src/python/flowmetal/db/base.py @@ -2,7 +2,7 @@ An abstract or base Flowmetal DB. """ -from abc import abc, abstractmethod, abstractproperty, abstractclassmethod, abstractstaticmethod +from abc import abstractclassmethod, abstractmethod class Db(ABC): diff --git a/projects/flowmetal/src/python/flowmetal/frontend.py b/projects/flowmetal/src/python/flowmetal/frontend.py index a3e854d..82475bb 100644 --- a/projects/flowmetal/src/python/flowmetal/frontend.py +++ b/projects/flowmetal/src/python/flowmetal/frontend.py @@ -3,6 +3,7 @@ import click + @click.group() def cli(): pass diff --git a/projects/flowmetal/src/python/flowmetal/interpreter.py b/projects/flowmetal/src/python/flowmetal/interpreter.py index a3e854d..82475bb 100644 --- a/projects/flowmetal/src/python/flowmetal/interpreter.py +++ b/projects/flowmetal/src/python/flowmetal/interpreter.py @@ -3,6 +3,7 @@ import click + @click.group() def cli(): pass diff --git a/projects/flowmetal/src/python/flowmetal/models.py b/projects/flowmetal/src/python/flowmetal/models.py index 59e54aa..ebf7bee 100644 --- a/projects/flowmetal/src/python/flowmetal/models.py +++ b/projects/flowmetal/src/python/flowmetal/models.py @@ -1,5 +1,3 @@ """ Somewhat generic models of Flowmetal programs. 
""" - -from typing import NamedTuple diff --git a/projects/flowmetal/src/python/flowmetal/reaper.py b/projects/flowmetal/src/python/flowmetal/reaper.py index a3e854d..82475bb 100644 --- a/projects/flowmetal/src/python/flowmetal/reaper.py +++ b/projects/flowmetal/src/python/flowmetal/reaper.py @@ -3,6 +3,7 @@ import click + @click.group() def cli(): pass diff --git a/projects/flowmetal/src/python/flowmetal/scheduler.py b/projects/flowmetal/src/python/flowmetal/scheduler.py index a3e854d..82475bb 100644 --- a/projects/flowmetal/src/python/flowmetal/scheduler.py +++ b/projects/flowmetal/src/python/flowmetal/scheduler.py @@ -3,6 +3,7 @@ import click + @click.group() def cli(): pass diff --git a/projects/gandi/src/python/gandi/client.py b/projects/gandi/src/python/gandi/client.py index 97d070f..ca85117 100644 --- a/projects/gandi/src/python/gandi/client.py +++ b/projects/gandi/src/python/gandi/client.py @@ -9,94 +9,102 @@ import requests class GandiAPI(object): - """An extremely incomplete Gandi REST API driver class. + """An extremely incomplete Gandi REST API driver class. - Exists to close over your API key, and make talking to the API slightly more idiomatic. + Exists to close over your API key, and make talking to the API slightly more idiomatic. - Note: In an effort to be nice, this API maintains a cache of the zones visible with your API - key. The cache is maintained when using this driver, but concurrent modifications of your zone(s) - via the web UI or other processes will obviously undermine it. This cache can be disabled by - setting the `use_cache` kwarg to `False`. + Note: In an effort to be nice, this API maintains a cache of the zones visible with your API + key. The cache is maintained when using this driver, but concurrent modifications of your zone(s) + via the web UI or other processes will obviously undermine it. This cache can be disabled by + setting the `use_cache` kwarg to `False`. - """ - - def __init__(self, key=None, use_cache=True): - self._base = "https://dns.api.gandi.net/api/v5" - self._key = key - self._use_cache = use_cache - self._zones = None + """ - # Helpers for making requests with the API key as required by the API. + def __init__(self, key=None, use_cache=True): + self._base = "https://dns.api.gandi.net/api/v5" + self._key = key + self._use_cache = use_cache + self._zones = None - def _do_request(self, method, path, headers=None, **kwargs): - headers = headers or {} - headers["X-Api-Key"] = self._key - resp = method(self._base + path, headers=headers, **kwargs) - if resp.status_code > 299: - print(resp.text) - resp.raise_for_status() + # Helpers for making requests with the API key as required by the API. 
- return resp + def _do_request(self, method, path, headers=None, **kwargs): + headers = headers or {} + headers["X-Api-Key"] = self._key + resp = method(self._base + path, headers=headers, **kwargs) + if resp.status_code > 299: + print(resp.text) + resp.raise_for_status() - def _get(self, path, **kwargs): - return self._do_request(requests.get, path, **kwargs) + return resp - def _post(self, path, **kwargs): - return self._do_request(requests.post, path, **kwargs) + def _get(self, path, **kwargs): + return self._do_request(requests.get, path, **kwargs) - def _put(self, path, **kwargs): - return self._do_request(requests.put, path, **kwargs) + def _post(self, path, **kwargs): + return self._do_request(requests.post, path, **kwargs) - # Intentional public API - - def domain_records(self, domain): - return self._get("/domains/{0}/records".format(domain)).json() + def _put(self, path, **kwargs): + return self._do_request(requests.put, path, **kwargs) - def get_zones(self): - if self._use_cache: - if not self._zones: - self._zones = self._get("/zones").json() - return self._zones - else: - return self._get("/zones").json() + # Intentional public API - def get_zone(self, zone_id): - return self._get("/zones/{}".format(zone_id)).json() + def domain_records(self, domain): + return self._get("/domains/{0}/records".format(domain)).json() - def get_zone_by_name(self, zone_name): - for zone in self.get_zones(): - if zone["name"] == zone_name: - return zone + def get_zones(self): + if self._use_cache: + if not self._zones: + self._zones = self._get("/zones").json() + return self._zones + else: + return self._get("/zones").json() - def create_zone(self, zone_name): - new_zone_id = self._post("/zones", - headers={"content-type": "application/json"}, - data=json.dumps({"name": zone_name}))\ - .headers["Location"]\ - .split("/")[-1] - new_zone = self.get_zone(new_zone_id) + def get_zone(self, zone_id): + return self._get("/zones/{}".format(zone_id)).json() - # Note: If the cache is active, update the cache. - if self._use_cache and self._zones is not None: - self._zones.append(new_zone) + def get_zone_by_name(self, zone_name): + for zone in self.get_zones(): + if zone["name"] == zone_name: + return zone - return new_zone + def create_zone(self, zone_name): + new_zone_id = ( + self._post( + "/zones", + headers={"content-type": "application/json"}, + data=json.dumps({"name": zone_name}), + ) + .headers["Location"] + .split("/")[-1] + ) + new_zone = self.get_zone(new_zone_id) - def replace_domain(self, domain, records): - date = datetime.now() - date = "{:%A, %d. %B %Y %I:%M%p}".format(date) - zone_name = f"updater generated domain - {domain} - {date}" + # Note: If the cache is active, update the cache. + if self._use_cache and self._zones is not None: + self._zones.append(new_zone) - zone = self.get_zone_by_name(zone_name) - if not zone: - zone = self.create_zone(zone_name) + return new_zone - print("Using zone", zone["uuid"]) + def replace_domain(self, domain, records): + date = datetime.now() + date = "{:%A, %d. 
%B %Y %I:%M%p}".format(date) + zone_name = f"updater generated domain - {domain} - {date}" - for r in records: - self._post("/zones/{0}/records".format(zone["uuid"]), - headers={"content-type": "application/json"}, - data=json.dumps(r)) + zone = self.get_zone_by_name(zone_name) + if not zone: + zone = self.create_zone(zone_name) - return self._post("/zones/{0}/domains/{1}".format(zone["uuid"], domain), - headers={"content-type": "application/json"}) + print("Using zone", zone["uuid"]) + + for r in records: + self._post( + "/zones/{0}/records".format(zone["uuid"]), + headers={"content-type": "application/json"}, + data=json.dumps(r), + ) + + return self._post( + "/zones/{0}/domains/{1}".format(zone["uuid"], domain), + headers={"content-type": "application/json"}, + ) diff --git a/projects/public-dns/src/python/arrdem/updater/__main__.py b/projects/public-dns/src/python/arrdem/updater/__main__.py index 4ba8da7..27493a2 100644 --- a/projects/public-dns/src/python/arrdem/updater/__main__.py +++ b/projects/public-dns/src/python/arrdem/updater/__main__.py @@ -20,165 +20,171 @@ import meraki RECORD_LINE_PATTERN = re.compile( - "^(?P\S+)\s+" - "(?P\S+)\s+" - "IN\s+" - "(?P\S+)\s+" - "(?P.+)$") + "^(?P\S+)\s+" + "(?P\S+)\s+" + "IN\s+" + "(?P\S+)\s+" + "(?P.+)$" +) def update(m, k, f, *args, **kwargs): - """clojure.core/update for Python's stateful maps.""" - if k in m: - m[k] = f(m[k], *args, **kwargs) - return m + """clojure.core/update for Python's stateful maps.""" + if k in m: + m[k] = f(m[k], *args, **kwargs) + return m def parse_zone_record(line): - if match := RECORD_LINE_PATTERN.search(line): - dat = match.groupdict() - dat = update(dat, "rrset_ttl", int) - dat = update(dat, "rrset_values", lambda x: [x]) - return dat + if match := RECORD_LINE_PATTERN.search(line): + dat = match.groupdict() + dat = update(dat, "rrset_ttl", int) + dat = update(dat, "rrset_values", lambda x: [x]) + return dat def same_record(lr, rr): - """ - A test to see if two records name the same zone entry. - """ + """ + A test to see if two records name the same zone entry. + """ - return lr["rrset_name"] == rr["rrset_name"] and \ - lr["rrset_type"] == rr["rrset_type"] + return lr["rrset_name"] == rr["rrset_name"] and lr["rrset_type"] == rr["rrset_type"] def records_equate(lr, rr): - """ - Equality, ignoring rrset_href which is generated by the API. - """ + """ + Equality, ignoring rrset_href which is generated by the API. 
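The record pattern at the top of this file depends on named groups; the names used here (rrset_name, rrset_ttl, rrset_type, rrset_values) are inferred from the dict keys this module manipulates. A sketch of the pattern against an illustrative zone line:

import re

RECORD_LINE_PATTERN = re.compile(
    r"^(?P<rrset_name>\S+)\s+"
    r"(?P<rrset_ttl>\S+)\s+"
    r"IN\s+"
    r"(?P<rrset_type>\S+)\s+"
    r"(?P<rrset_values>.+)$"
)

assert RECORD_LINE_PATTERN.search("www 300 IN A 192.0.2.10").groupdict() == {
    "rrset_name": "www",
    "rrset_ttl": "300",            # parse_zone_record coerces this to int
    "rrset_type": "A",
    "rrset_values": "192.0.2.10",  # ...and wraps this in a single-element list
}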
 def records_equate(lr, rr):
-  """
-  Equality, ignoring rrset_href which is generated by the API.
-  """
+    """
+    Equality, ignoring rrset_href which is generated by the API.
+    """
 
-  if not same_record(lr, rr):
-    return False
-  elif lr["rrset_ttl"] != rr["rrset_ttl"]:
-    return False
-  elif set(lr["rrset_values"]) != set(rr["rrset_values"]):
-    return False
-  else:
-    return True
+    if not same_record(lr, rr):
+        return False
+    elif lr["rrset_ttl"] != rr["rrset_ttl"]:
+        return False
+    elif set(lr["rrset_values"]) != set(rr["rrset_values"]):
+        return False
+    else:
+        return True
 
 
 def template_and_parse_zone(template_file, template_bindings):
-  assert template_file is not None
-  assert template_bindings is not None
+    assert template_file is not None
+    assert template_bindings is not None
 
-  with open(template_file) as f:
-    dat = jinja2.Template(f.read()).render(**template_bindings)
+    with open(template_file) as f:
+        dat = jinja2.Template(f.read()).render(**template_bindings)
 
-  uncommitted_records = []
-  for line in dat.splitlines():
-    if line and not line[0] == "#":
-      record = parse_zone_record(line)
-      if record:
-        uncommitted_records.append(record)
+    uncommitted_records = []
+    for line in dat.splitlines():
+        if line and not line[0] == "#":
+            record = parse_zone_record(line)
+            if record:
+                uncommitted_records.append(record)
 
-  records = []
+    records = []
 
-  for uncommitted_r in uncommitted_records:
-    flag = False
-    for committed_r in records:
-      if same_record(uncommitted_r, committed_r):
-        # Join the two records
-        committed_r["rrset_values"].extend(uncommitted_r["rrset_values"])
-        flag = True
+    for uncommitted_r in uncommitted_records:
+        flag = False
+        for committed_r in records:
+            if same_record(uncommitted_r, committed_r):
+                # Join the two records
+                committed_r["rrset_values"].extend(uncommitted_r["rrset_values"])
+                flag = True
 
-    if not flag:
-      records.append(uncommitted_r)
+        if not flag:
+            records.append(uncommitted_r)
 
-  sorted(records, key=lambda x: (x["rrset_type"], x["rrset_name"]))
+    # sorted() returns a new list and its result was discarded; sort in place
+    # so the ordering actually takes effect.
+    records.sort(key=lambda x: (x["rrset_type"], x["rrset_name"]))
 
-  return records
+    return records
 
 
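# A sketch of the merge step (template lines illustrative): two lines naming
# the same rrset are joined into one record carrying both values.
#
#   a template rendering to:
#       www 300 IN A 192.0.2.1
#       www 300 IN A 192.0.2.2
#   produces:
#       [{"rrset_name": "www", "rrset_ttl": 300, "rrset_type": "A",
#         "rrset_values": ["192.0.2.1", "192.0.2.2"]}]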
 def diff_zones(left_zone, right_zone):
-  """
-  Equality between unordered lists of records constituting a zone.
-  """
-
-  in_left_not_right = []
-  for lr in left_zone:
-    flag = False
-    for rr in right_zone:
-      if same_record(lr, rr) and records_equate(lr, rr):
-        flag = True
-        break
-
-    if not flag:
-      in_left_not_right.append(lr)
-
-  in_right_not_left = []
-  for rr in right_zone:
-    flag = False
-    for lr in left_zone:
-      if same_record(lr, rr) and records_equate(lr, rr):
-        flag = True
-        break
-
-    if not flag:
-      in_right_not_left.append(lr)
-
-  return in_left_not_right or in_right_not_left
+    """
+    A difference test between unordered lists of records constituting a zone.
+
+    Truthy when the two zones differ.
+    """
+
+    in_left_not_right = []
+    for lr in left_zone:
+        flag = False
+        for rr in right_zone:
+            if same_record(lr, rr) and records_equate(lr, rr):
+                flag = True
+                break
+
+        if not flag:
+            in_left_not_right.append(lr)
+
+    in_right_not_left = []
+    for rr in right_zone:
+        flag = False
+        for lr in left_zone:
+            if same_record(lr, rr) and records_equate(lr, rr):
+                flag = True
+                break
+
+        if not flag:
+            # Collect the right-hand record; appending lr here was a bug.
+            in_right_not_left.append(rr)
+
+    return in_left_not_right or in_right_not_left
 
 
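# The contract, worked through on hypothetical records: diff_zones returns a
# truthy list of differing records, or a falsy value when the zones agree, so
# main() below can use it directly as a condition.
#
#   a = {"rrset_name": "www", "rrset_type": "A",
#        "rrset_ttl": 300, "rrset_values": ["192.0.2.1"]}
#   b = dict(a, rrset_ttl=600)
#   diff_zones([a], [a])   # => [] (falsy): zones agree
#   diff_zones([a], [b])   # => [a] (truthy): TTLs differ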
-parser = argparse.ArgumentParser(description="\"Dynamic\" DNS updating for self-hosted services")
+parser = argparse.ArgumentParser(
+    description='"Dynamic" DNS updating for self-hosted services'
+)
 parser.add_argument("--config", dest="config_file", required=True)
 parser.add_argument("--templates", dest="template_dir", required=True)
 parser.add_argument("--dry-run", dest="dry", action="store_true", default=False)
 
+
 def main():
-  args = parser.parse_args()
-  config = yaml.safe_load(open(args.config_file, "r"))
+    args = parser.parse_args()
+    config = yaml.safe_load(open(args.config_file, "r"))
 
-  dashboard = meraki.DashboardAPI(config["meraki"]["key"], output_log=False)
-  org = config["meraki"]["organization"]
-  device = config["meraki"]["router_serial"]
+    dashboard = meraki.DashboardAPI(config["meraki"]["key"], output_log=False)
+    org = config["meraki"]["organization"]
+    device = config["meraki"]["router_serial"]
 
-  uplinks = dashboard.appliance.getOrganizationApplianceUplinkStatuses(
-    organizationId=org,
-    serials=[device]
-  )[0]["uplinks"]
+    uplinks = dashboard.appliance.getOrganizationApplianceUplinkStatuses(
+        organizationId=org, serials=[device]
+    )[0]["uplinks"]
 
-  template_bindings = {
-    "local": {
-      # One of the two
-      "public_v4s": [link.get("publicIp") for link in uplinks if link.get("publicIp")],
-    },
-    # Why isn't there a merge method
-    **config["bindings"]
-  }
+    template_bindings = {
+        "local": {
+            # One of the two
+            "public_v4s": [
+                link.get("publicIp") for link in uplinks if link.get("publicIp")
+            ],
+        },
+        # Why isn't there a merge method
+        **config["bindings"],
+    }
 
-  api = GandiAPI(config["gandi"]["key"])
+    api = GandiAPI(config["gandi"]["key"])
 
-  for task in config["tasks"]:
-    if isinstance(task, str):
-      task = {"template": task + ".j2",
-              "zones": [task]}
+    for task in config["tasks"]:
+        if isinstance(task, str):
+            task = {"template": task + ".j2", "zones": [task]}
 
-    computed_zone = template_and_parse_zone(os.path.join(args.template_dir, task["template"]), template_bindings)
+        computed_zone = template_and_parse_zone(
+            os.path.join(args.template_dir, task["template"]), template_bindings
+        )
 
-    for zone_name in task["zones"]:
-      try:
-        live_zone = api.domain_records(zone_name)
+        for zone_name in task["zones"]:
+            try:
+                live_zone = api.domain_records(zone_name)
 
-        if diff_zones(computed_zone, live_zone):
-          print("Zone {} differs, computed zone:".format(zone_name))
-          pprint(computed_zone)
-          if not args.dry:
-            print(api.replace_domain(zone_name, computed_zone))
-        else:
-          print("Zone {} up to date".format(zone_name))
+                if diff_zones(computed_zone, live_zone):
+                    print("Zone {} differs, computed zone:".format(zone_name))
+                    pprint(computed_zone)
+                    if not args.dry:
+                        print(api.replace_domain(zone_name, computed_zone))
+                else:
+                    print("Zone {} up to date".format(zone_name))
 
-      except Exception as e:
-        print("While processing zone {}".format(zone_name))
-        raise e
+            except Exception as e:
+                print("While processing zone {}".format(zone_name))
+                raise e
 
 
-if __name__ == "__main__" or 1:
-  main()
+# The always-true "or 1" debugging shim is dropped so importing no longer runs main().
+if __name__ == "__main__":
+    main()
diff --git a/projects/ratchet/setup.py b/projects/ratchet/setup.py
index ca50cb0..481956f 100644
--- a/projects/ratchet/setup.py
+++ b/projects/ratchet/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup
 setup(
     name="arrdem.ratchet",
     # Package metadata
-    version='0.0.0',
+    version="0.0.0",
     license="MIT",
     description="A 'ratcheting' message system",
     long_description=open("README.md").read(),
@@ -18,18 +18,12 @@ setup(
         "Programming Language :: Python :: 3",
         "Programming Language :: Python :: 3.8",
     ],
-    # Package setup
-    package_dir={
-        "": "src/python"
-    },
+    package_dir={"": "src/python"},
     packages=[
         "ratchet",
     ],
-    entry_points={
-    },
-    install_requires=[
-    ],
-    extras_require={
-    }
+    entry_points={},
+    install_requires=[],
+    extras_require={},
 )
diff --git a/projects/ratchet/src/python/ratchet/backend/sqlite.py b/projects/ratchet/src/python/ratchet/backend/sqlite.py
index efcf5bb..66f346a 100644
--- a/projects/ratchet/src/python/ratchet/backend/sqlite.py
+++ b/projects/ratchet/src/python/ratchet/backend/sqlite.py
@@ -98,14 +98,15 @@ VALUES (?, ?, ?, ?);
 """
 
-
 class SQLiteDriver:
-    def __init__(self,
-                 filename="~/.ratchet.sqlite3",
-                 sqlite_timeout=1000,
-                 message_ttl=60000,
-                 message_space="_",
-                 message_author=f"{os.getpid()}@{socket.gethostname()}"):
+    def __init__(
+        self,
+        filename="~/.ratchet.sqlite3",
+        sqlite_timeout=1000,
+        message_ttl=60000,
+        message_space="_",
+        message_author=f"{os.getpid()}@{socket.gethostname()}",
+    ):
         self._path = os.path.expanduser(filename)
         self._sqlite_timeout = sqlite_timeout
         self._message_ttl = message_ttl
@@ -120,14 +121,11 @@ class SQLiteDriver:
             conn.executescript(SCHEMA_SCRIPT)
 
     def _connection(self):
-        return sql.connect(self._filename,
-                           timeout=self._sqlite_timeout)
+        # __init__ stores the expanded path as self._path; self._filename was
+        # never set, so connect against self._path.
+        return sql.connect(self._path, timeout=self._sqlite_timeout)
 
-    def create_message(self,
-                       message: str,
-                       ttl: int = None,
-                       space: str = None,
-                       author: str = None):
+    def create_message(
+        self, message: str, ttl: int = None, space: str = None, author: str = None
+    ):
         """Create a single message."""
 
         ttl = ttl or self._message_ttl
@@ -138,11 +136,9 @@ class SQLiteDriver:
-            cursor.execute(CREATE_MESSAGE_SCRIPT, author, space, ttl, message)
+            # sqlite3 expects the bind parameters as a single sequence.
+            cursor.execute(CREATE_MESSAGE_SCRIPT, (author, space, ttl, message))
             return cursor.lastrowid
 
-    def create_event(self,
-                     timeout: int,
-                     ttl: int = None,
-                     space: str = None,
-                     author: str = None):
+    def create_event(
+        self, timeout: int, ttl: int = None, space: str = None, author: str = None
+    ):
         """Create a (pending) event."""
 
         ttl = ttl or self._message_ttl
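# A minimal usage sketch (path and values illustrative; schema setup shown
# only partially above): the driver expands its filename once in __init__ and
# opens a fresh connection per operation.
#
#   driver = SQLiteDriver(filename="/tmp/ratchet-demo.sqlite3")
#   msg_id = driver.create_message("hello world", ttl=1000)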
diff --git a/projects/yamlschema/test_yamlschema.py b/projects/yamlschema/test_yamlschema.py
index f67ce2e..2e50678 100644
--- a/projects/yamlschema/test_yamlschema.py
+++ b/projects/yamlschema/test_yamlschema.py
@@ -7,111 +7,94 @@ from yamlschema import lint_buffer
 
 import pytest
 
 
-@pytest.mark.parametrize('schema, obj', [
-  ({"type": "number"},
-   "---\n1.0"),
-  ({"type": "integer"},
-   "---\n3"),
-  ({"type": "string"},
-   "---\nfoo bar baz"),
-  ({"type": "string",
-    "maxLength": 15},
-   "---\nfoo bar baz"),
-  ({"type": "string",
-    "minLength": 10},
-   "---\nfoo bar baz"),
-  ({"type": "string",
-    "pattern": "^foo.*"},
-   "---\nfoo bar baz"),
-  ({"type": "object",
-    "additionalProperties": True},
-   "---\nfoo: bar\nbaz: qux"),
-  ({"type": "object",
-    "properties": {"foo": {"type": "string"}}},
-   "---\nfoo: bar\nbaz: qux"),
-  ({"type": "object",
-    "properties": {"foo": {"type": "string"}},
-    "additionalProperties": False},
-   "---\nfoo: bar"),
-  ({"type": "object",
-    "properties": {"foo": {"type": "object"}}},
-   "---\nfoo: {}"),
-  ({"type": "object",
-    "properties": {"foo": {
-      "type": "array",
-      "items": {"type": "object"}}}},
-   "---\nfoo: [{}, {}, {foo: bar}]"),
-])
+@pytest.mark.parametrize(
+    "schema, obj",
+    [
+        ({"type": "number"}, "---\n1.0"),
+        ({"type": "integer"}, "---\n3"),
+        ({"type": "string"}, "---\nfoo bar baz"),
+        ({"type": "string", "maxLength": 15}, "---\nfoo bar baz"),
+        ({"type": "string", "minLength": 10}, "---\nfoo bar baz"),
+        ({"type": "string", "pattern": "^foo.*"}, "---\nfoo bar baz"),
+        ({"type": "object", "additionalProperties": True}, "---\nfoo: bar\nbaz: qux"),
+        (
+            {"type": "object", "properties": {"foo": {"type": "string"}}},
+            "---\nfoo: bar\nbaz: qux",
+        ),
+        (
+            {
+                "type": "object",
+                "properties": {"foo": {"type": "string"}},
+                "additionalProperties": False,
+            },
+            "---\nfoo: bar",
+        ),
+        ({"type": "object", "properties": {"foo": {"type": "object"}}}, "---\nfoo: {}"),
+        (
+            {
+                "type": "object",
+                "properties": {"foo": {"type": "array", "items": {"type": "object"}}},
+            },
+            "---\nfoo: [{}, {}, {foo: bar}]",
+        ),
+    ],
+)
 def test_lint_document_ok(schema, obj):
     assert not list(lint_buffer(schema, obj))
 
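# For orientation (schema and document illustrative): lint_buffer yields a
# LintRecord per violation, so an empty listing means the document conforms,
# which is exactly what test_lint_document_ok asserts above.
#
#   list(lint_buffer({"type": "integer"}, "---\n3"))    # => []
#   list(lint_buffer({"type": "integer"}, "---\n1.0"))  # => [LintRecord(...)]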
 
-@pytest.mark.parametrize('msg, schema, obj', [
-  # Numerics
-  ("Floats are not ints",
-   {"type": "integer"},
-   "---\n1.0"),
-  ("Ints are not floats",
-   {"type": "number"},
-   "---\n1"),
-
-  # Numerics - range limits. Integer edition
-  ("1 is the limit of the range",
-   {"type": "integer",
-    "exclusiveMaximum": 1},
-   "---\n1"),
-  ("1 is the limit of the range",
-   {"type": "integer",
-    "exclusiveMinimum": 1},
-   "---\n1"),
-  ("1 is out of the range",
-   {"type": "integer",
-    "minimum": 2},
-   "---\n1"),
-  ("1 is out of the range",
-   {"type": "integer",
-    "maximum": 0},
-   "---\n1"),
-  ("1 is out of the range",
-   {"type": "integer",
-    "exclusiveMinimum": 1},
-   "---\n1"),
-
-  # Numerics - range limits. Number/Float edition
-  ("1 is the limit of the range",
-   {"type": "number",
-    "exclusiveMaximum": 1},
-   "---\n1.0"),
-  ("1 is the limit of the range",
-   {"type": "number",
-    "exclusiveMinimum": 1},
-   "---\n1.0"),
-  ("1 is out of the range",
-   {"type": "number",
-    "minimum": 2},
-   "---\n1.0"),
-  ("1 is out of the range",
-   {"type": "number",
-    "maximum": 0},
-   "---\n1.0"),
-  ("1 is out of the range",
-   {"type": "number",
-    "exclusiveMinimum": 1},
-   "---\n1.0"),
-
-  # String shit
-  ("String too short",
-   {"type": "string", "minLength": 1},
-   "---\n''"),
-  ("String too long",
-   {"type": "string", "maxLength": 1},
-   "---\nfoo"),
-  ("String does not match pattern",
-   {"type": "string", "pattern": "bar"},
-   "---\nfoo"),
-  ("String does not fully match pattern",
-   {"type": "string", "pattern": "foo"},
-   "---\nfooooooooo"),
-])
+@pytest.mark.parametrize(
+    "msg, schema, obj",
+    [
+        # Numerics
+        ("Floats are not ints", {"type": "integer"}, "---\n1.0"),
+        ("Ints are not floats", {"type": "number"}, "---\n1"),
+        # Numerics - range limits. Integer edition
+        (
+            "1 is the limit of the range",
+            {"type": "integer", "exclusiveMaximum": 1},
+            "---\n1",
+        ),
+        (
+            "1 is the limit of the range",
+            {"type": "integer", "exclusiveMinimum": 1},
+            "---\n1",
+        ),
+        ("1 is out of the range", {"type": "integer", "minimum": 2}, "---\n1"),
+        ("1 is out of the range", {"type": "integer", "maximum": 0}, "---\n1"),
+        ("1 is out of the range", {"type": "integer", "exclusiveMinimum": 1}, "---\n1"),
+        # Numerics - range limits. Number/Float edition
+        (
+            "1 is the limit of the range",
+            {"type": "number", "exclusiveMaximum": 1},
+            "---\n1.0",
+        ),
+        (
+            "1 is the limit of the range",
+            {"type": "number", "exclusiveMinimum": 1},
+            "---\n1.0",
+        ),
+        ("1 is out of the range", {"type": "number", "minimum": 2}, "---\n1.0"),
+        ("1 is out of the range", {"type": "number", "maximum": 0}, "---\n1.0"),
+        (
+            "1 is out of the range",
+            {"type": "number", "exclusiveMinimum": 1},
+            "---\n1.0",
+        ),
+        # String shit
+        ("String too short", {"type": "string", "minLength": 1}, "---\n''"),
+        ("String too long", {"type": "string", "maxLength": 1}, "---\nfoo"),
+        (
+            "String does not match pattern",
+            {"type": "string", "pattern": "bar"},
+            "---\nfoo",
+        ),
+        (
+            "String does not fully match pattern",
+            {"type": "string", "pattern": "foo"},
+            "---\nfooooooooo",
+        ),
+    ],
+)
 def test_lint_document_fails(msg, schema, obj):
     assert list(lint_buffer(schema, obj)), msg
diff --git a/projects/yamlschema/yamlschema.py b/projects/yamlschema/yamlschema.py
index 6351f11..0cb7c54 100644
--- a/projects/yamlschema/yamlschema.py
+++ b/projects/yamlschema/yamlschema.py
@@ -59,9 +59,7 @@ class YamlLinter(object):
         return schema
 
     def lint_mapping(self, schema, node: Node) -> t.Iterable[str]:
-        """FIXME.
-
-        """
+        """FIXME."""
 
         if schema["type"] != "object" or not isinstance(node, MappingNode):
             yield LintRecord(
@@ -71,9 +69,7 @@ class YamlLinter(object):
                 f"Expected {schema['type']}, got {node.id} {str(node.start_mark).lstrip()}",
             )
 
-        additional_type: t.Union[dict, bool] = (
-            schema.get("additionalProperties", True)
-        )
+        additional_type: t.Union[dict, bool] = schema.get("additionalProperties", True)
 
         properties: dict = schema.get("properties", {})
         required: t.Iterable[str] = schema.get("required", [])
@@ -135,37 +131,26 @@ class YamlLinter(object):
         elif schema["type"] == "number":
             yield from self.lint_number(schema, node)
         else:
-            raise NotImplementedError(
-                f"Scalar type {schema['type']} is not supported"
-            )
+            raise NotImplementedError(f"Scalar type {schema['type']} is not supported")
 
     def lint_string(self, schema, node: Node) -> t.Iterable[str]:
         """FIXME."""
 
         if node.tag != "tag:yaml.org,2002:str":
             yield LintRecord(
-                LintLevel.MISSMATCH,
-                node,
-                schema,
-                f"Expected a string, got a {node}"
+                LintLevel.MISSMATCH, node, schema, f"Expected a string, got a {node}"
             )
 
         if maxl := schema.get("maxLength"):
             if len(node.value) > maxl:
                 yield LintRecord(
-                    LintLevel.MISSMATCH,
-                    node,
-                    schema,
-                    f"Expected a shorter string"
+                    LintLevel.MISSMATCH, node, schema, f"Expected a shorter string"
                 )
 
         if minl := schema.get("minLength"):
             if len(node.value) < minl:
                 yield LintRecord(
-                    LintLevel.MISSMATCH,
-                    node,
-                    schema,
-                    f"Expected a longer string"
+                    LintLevel.MISSMATCH, node, schema, f"Expected a longer string"
                 )
 
         if pat := schema.get("pattern"):
@@ -174,7 +159,7 @@ class YamlLinter(object):
                     LintLevel.MISSMATCH,
                     node,
                     schema,
-                    f"Expected a string matching the pattern"
+                    f"Expected a string matching the pattern",
                 )
 
     def lint_integer(self, schema, node: Node) -> t.Iterable[str]:
@@ -184,10 +169,7 @@ class YamlLinter(object):
         else:
             yield LintRecord(
-                LintLevel.MISSMATCH,
-                node,
-                schema,
-                f"Expected an integer, got a {node}"
+                LintLevel.MISSMATCH, node, schema, f"Expected an integer, got a {node}"
             )
 
     def lint_number(self, schema, node: Node) -> t.Iterable[str]:
@@ -197,13 +179,9 @@ class YamlLinter(object):
 
         else:
             yield LintRecord(
-                LintLevel.MISSMATCH,
-                node,
-                schema,
-                f"Expected an integer, got a {node}"
+                # This is lint_number; report "a number", not "an integer".
+                LintLevel.MISSMATCH, node, schema, f"Expected a number, got a {node}"
             )
 
-
     def _lint_num_range(self, schema, node: Node, value) -> t.Iterable[str]:
         """FIXME."""
 
@@ -213,7 +191,7 @@ class YamlLinter(object):
                 LintLevel.MISSMATCH,
                 node,
                 schema,
-                f"Expected a multiple of {base}, got {value}"
+                f"Expected a multiple of {base}, got {value}",
             )
 
         if (max := schema.get("exclusiveMaximum")) is not None:
@@ -222,7 +200,7 @@ class YamlLinter(object):
                 LintLevel.MISSMATCH,
                 node,
                 schema,
-                f"Expected a value less than {max}, got {value}"
+                f"Expected a value less than {max}, got {value}",
             )
 
         if (max := schema.get("maximum")) is not None:
@@ -231,7 +209,7 @@ class YamlLinter(object):
                 LintLevel.MISSMATCH,
                 node,
                 schema,
-                f"Expected a value less than or equal to {max}, got {value}"
+                f"Expected a value less than or equal to {max}, got {value}",
             )
 
         if (min := schema.get("exclusiveMinimum")) is not None:
@@ -240,7 +218,7 @@ class YamlLinter(object):
                 LintLevel.MISSMATCH,
                 node,
                 schema,
-                f"Expected a value greater than {min}, got {value}"
+                f"Expected a value greater than {min}, got {value}",
             )
 
         if (min := schema.get("minimum")) is not None:
@@ -249,7 +227,7 @@ class YamlLinter(object):
                 LintLevel.MISSMATCH,
                 node,
                 schema,
-                f"Expected a value greater than or equal to {min}, got {value}"
+                f"Expected a value greater than or equal to {min}, got {value}",
             )
 
     def lint_document(self, node, schema=None) -> t.Iterable[str]:
@@ -271,10 +249,7 @@ class YamlLinter(object):
         # This is the schema that rejects everything.
         elif schema == False:
             yield LintRecord(
-                LintLevel.UNEXPECTED,
-                node,
-                schema,
-                "Received an unexpected value"
+                LintLevel.UNEXPECTED, node, schema, "Received an unexpected value"
             )
 
         # Walking the PyYAML node hierarchy
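# Design note, worked through (schema illustrative): the range checks bind
# with ":=" and then test "is not None" rather than bare truthiness so that
# a limit of 0 - which is falsy - is still enforced.
#
#   schema = {"type": "integer", "maximum": 0}
#   # A bare "if max := schema.get("maximum"):" would skip this check;
#   # the "is not None" form correctly flags value=1 against maximum=0.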