Black all the things

This commit is contained in:
Reid 'arrdem' McKenzie 2021-09-02 22:10:35 -06:00
parent 7170fd40a8
commit 2494211ef2
31 changed files with 395 additions and 294 deletions

View file

@ -100,10 +100,16 @@ def create_tables(queries: Queries, conn) -> None:
log.info("Created migrations table") log.info("Created migrations table")
# Insert the bootstrap 'fixup' record # Insert the bootstrap 'fixup' record
execute_migration(queries, conn, execute_migration(
queries,
conn,
MigrationDescriptor( MigrationDescriptor(
name="anosql_migrations_create_table", name="anosql_migrations_create_table",
sha256sum=sha256(queries.anosql_migrations_create_table.sql.encode("utf-8")).hexdigest())) sha256sum=sha256(
queries.anosql_migrations_create_table.sql.encode("utf-8")
).hexdigest(),
),
)
def committed_migrations(queries: Queries, conn) -> t.Iterable[MigrationDescriptor]: def committed_migrations(queries: Queries, conn) -> t.Iterable[MigrationDescriptor]:
@ -133,7 +139,8 @@ def available_migrations(queries: Queries, conn) -> t.Iterable[MigrationDescript
yield MigrationDescriptor( yield MigrationDescriptor(
name=query_name, name=query_name,
committed_at=None, committed_at=None,
sha256sum=sha256(query_fn.sql.encode("utf-8")).hexdigest()) sha256sum=sha256(query_fn.sql.encode("utf-8")).hexdigest(),
)
def execute_migration(queries: Queries, conn, migration: MigrationDescriptor): def execute_migration(queries: Queries, conn, migration: MigrationDescriptor):

View file

@ -15,7 +15,9 @@ CREATE TABLE kv (`id` INT, `key` TEXT, `value` TEXT);
def table_exists(conn, table_name): def table_exists(conn, table_name):
return list(conn.execute(f"""\ return list(
conn.execute(
f"""\
SELECT ( SELECT (
`name` `name`
) )
@ -23,7 +25,9 @@ def table_exists(conn, table_name):
WHERE WHERE
`type` = 'table' `type` = 'table'
AND `name` = '{table_name}' AND `name` = '{table_name}'
;""")) ;"""
)
)
@pytest.fixture @pytest.fixture
@ -36,7 +40,9 @@ def conn() -> sqlite3.Connection:
def test_connect(conn: sqlite3.Connection): def test_connect(conn: sqlite3.Connection):
"""Assert that the connection works and we can execute against it.""" """Assert that the connection works and we can execute against it."""
assert list(conn.execute("SELECT 1;")) == [(1, ), ] assert list(conn.execute("SELECT 1;")) == [
(1,),
]
@pytest.fixture @pytest.fixture
@ -66,7 +72,9 @@ def test_migrations_list(conn, queries):
"""Test that we can list out available migrations.""" """Test that we can list out available migrations."""
ms = list(anosql_migrations.available_migrations(queries, conn)) ms = list(anosql_migrations.available_migrations(queries, conn))
assert any(m.name == "migration_0000_create_kv" for m in ms), f"Didn't find in {ms!r}" assert any(
m.name == "migration_0000_create_kv" for m in ms
), f"Didn't find in {ms!r}"
def test_committed_migrations(conn, queries): def test_committed_migrations(conn, queries):
@ -96,4 +104,6 @@ def test_post_committed_migrations(migrated_conn, queries):
"""Assert that the create_kv migration has been committed.""" """Assert that the create_kv migration has been committed."""
ms = list(anosql_migrations.committed_migrations(queries, migrated_conn)) ms = list(anosql_migrations.committed_migrations(queries, migrated_conn))
assert any(m.name == "migration_0000_create_kv" for m in ms), "\n".join(migrated_conn.iterdump()) assert any(m.name == "migration_0000_create_kv" for m in ms), "\n".join(
migrated_conn.iterdump()
)

View file

@ -243,15 +243,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper'). # The paper size ('letterpaper' or 'a4paper').
# #
# 'papersize': 'letterpaper', # 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt'). # The font size ('10pt', '11pt' or '12pt').
# #
# 'pointsize': '10pt', # 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble. # Additional stuff for the LaTeX preamble.
# #
# 'preamble': '', # 'preamble': '',
# Latex figure (float) alignment # Latex figure (float) alignment
# #
# 'figure_align': 'htbp', # 'figure_align': 'htbp',
@ -261,8 +258,7 @@ latex_elements = {
# (source start file, target name, title, # (source start file, target name, title,
# author, documentclass [howto, manual, or own class]). # author, documentclass [howto, manual, or own class]).
latex_documents = [ latex_documents = [
(master_doc, "anosql.tex", u"anosql Documentation", (master_doc, "anosql.tex", u"anosql Documentation", u"Honza Pokorny", "manual"),
u"Honza Pokorny", "manual"),
] ]
# The name of an image file (relative to this directory) to place at the top of # The name of an image file (relative to this directory) to place at the top of
@ -302,10 +298,7 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [(master_doc, "anosql", u"anosql Documentation", [author], 1)]
(master_doc, "anosql", u"anosql Documentation",
[author], 1)
]
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
# #
@ -318,9 +311,15 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
(master_doc, "anosql", u"anosql Documentation", (
author, "anosql", "One line description of project.", master_doc,
"Miscellaneous"), "anosql",
u"anosql Documentation",
author,
"anosql",
"One line description of project.",
"Miscellaneous",
),
] ]
# Documents to append as an appendix to all manuals. # Documents to append as an appendix to all manuals.

View file

@ -2,4 +2,10 @@ from .core import from_path, from_str, SQLOperationType
from .exceptions import SQLLoadException, SQLParseException from .exceptions import SQLLoadException, SQLParseException
__all__ = ["from_path", "from_str", "SQLOperationType", "SQLLoadException", "SQLParseException"] __all__ = [
"from_path",
"from_str",
"SQLOperationType",
"SQLLoadException",
"SQLParseException",
]

View file

@ -53,7 +53,9 @@ class SQLite3DriverAdapter(object):
conn.execute(sql, parameters) conn.execute(sql, parameters)
@staticmethod @staticmethod
def insert_update_delete_many(conn: sqlite3.Connection, _query_name, sql, parameters): def insert_update_delete_many(
conn: sqlite3.Connection, _query_name, sql, parameters
):
log.debug({"sql": sql, "parameters": parameters}) log.debug({"sql": sql, "parameters": parameters})
conn.executemany(sql, parameters) conn.executemany(sql, parameters)

View file

@ -7,7 +7,7 @@ from .patterns import (
doc_comment_pattern, doc_comment_pattern,
empty_pattern, empty_pattern,
query_name_definition_pattern, query_name_definition_pattern,
valid_query_name_pattern valid_query_name_pattern,
) )
@ -89,8 +89,8 @@ def get_driver_adapter(driver_name):
class SQLOperationType(object): class SQLOperationType(object):
"""Enumeration (kind of) of anosql operation types """Enumeration (kind of) of anosql operation types"""
"""
INSERT_RETURNING = 0 INSERT_RETURNING = 0
INSERT_UPDATE_DELETE = 1 INSERT_UPDATE_DELETE = 1
INSERT_UPDATE_DELETE_MANY = 2 INSERT_UPDATE_DELETE_MANY = 2
@ -168,9 +168,13 @@ def _create_fns(query_name, docs, op_type, sql, driver_adapter):
if op_type == SQLOperationType.INSERT_RETURNING: if op_type == SQLOperationType.INSERT_RETURNING:
return driver_adapter.insert_returning(conn, query_name, sql, parameters) return driver_adapter.insert_returning(conn, query_name, sql, parameters)
elif op_type == SQLOperationType.INSERT_UPDATE_DELETE: elif op_type == SQLOperationType.INSERT_UPDATE_DELETE:
return driver_adapter.insert_update_delete(conn, query_name, sql, parameters) return driver_adapter.insert_update_delete(
conn, query_name, sql, parameters
)
elif op_type == SQLOperationType.INSERT_UPDATE_DELETE_MANY: elif op_type == SQLOperationType.INSERT_UPDATE_DELETE_MANY:
return driver_adapter.insert_update_delete_many(conn, query_name, sql, *parameters) return driver_adapter.insert_update_delete_many(
conn, query_name, sql, *parameters
)
elif op_type == SQLOperationType.SCRIPT: elif op_type == SQLOperationType.SCRIPT:
return driver_adapter.execute_script(conn, sql) return driver_adapter.execute_script(conn, sql)
elif op_type == SQLOperationType.SELECT_ONE_ROW: elif op_type == SQLOperationType.SELECT_ONE_ROW:
@ -351,5 +355,5 @@ def from_path(sql_path, driver_name):
else: else:
raise SQLLoadException( raise SQLLoadException(
"The sql_path must be a directory or file, got {}".format(sql_path), "The sql_path must be a directory or file, got {}".format(sql_path),
sql_path sql_path,
) )

View file

@ -102,10 +102,15 @@ def pg_conn(postgresql):
with postgresql.cursor() as cur: with postgresql.cursor() as cur:
with open(USERS_DATA_PATH) as fp: with open(USERS_DATA_PATH) as fp:
cur.copy_from(fp, "users", sep=",", columns=["username", "firstname", "lastname"]) cur.copy_from(
fp, "users", sep=",", columns=["username", "firstname", "lastname"]
)
with open(BLOGS_DATA_PATH) as fp: with open(BLOGS_DATA_PATH) as fp:
cur.copy_from( cur.copy_from(
fp, "blogs", sep=",", columns=["userid", "title", "content", "published"] fp,
"blogs",
sep=",",
columns=["userid", "title", "content", "published"],
) )
return postgresql return postgresql

View file

@ -29,17 +29,26 @@ def test_record_query(pg_conn, queries):
def test_parameterized_query(pg_conn, queries): def test_parameterized_query(pg_conn, queries):
actual = queries.blogs.get_user_blogs(pg_conn, userid=1) actual = queries.blogs.get_user_blogs(pg_conn, userid=1)
expected = [("How to make a pie.", date(2018, 11, 23)), ("What I did Today", date(2017, 7, 28))] expected = [
("How to make a pie.", date(2018, 11, 23)),
("What I did Today", date(2017, 7, 28)),
]
assert actual == expected assert actual == expected
def test_parameterized_record_query(pg_conn, queries): def test_parameterized_record_query(pg_conn, queries):
dsn = pg_conn.get_dsn_parameters() dsn = pg_conn.get_dsn_parameters()
with psycopg2.connect(cursor_factory=psycopg2.extras.RealDictCursor, **dsn) as conn: with psycopg2.connect(cursor_factory=psycopg2.extras.RealDictCursor, **dsn) as conn:
actual = queries.blogs.pg_get_blogs_published_after(conn, published=date(2018, 1, 1)) actual = queries.blogs.pg_get_blogs_published_after(
conn, published=date(2018, 1, 1)
)
expected = [ expected = [
{"title": "How to make a pie.", "username": "bobsmith", "published": "2018-11-23 00:00"}, {
"title": "How to make a pie.",
"username": "bobsmith",
"published": "2018-11-23 00:00",
},
{"title": "Testing", "username": "janedoe", "published": "2018-01-01 00:00"}, {"title": "Testing", "username": "janedoe", "published": "2018-01-01 00:00"},
] ]

View file

@ -5,6 +5,7 @@ import pytest
@pytest.fixture @pytest.fixture
def sqlite(request): def sqlite(request):
import sqlite3 import sqlite3
sqlconnection = sqlite3.connect(":memory:") sqlconnection = sqlite3.connect(":memory:")
def fin(): def fin():
@ -18,11 +19,13 @@ def sqlite(request):
def test_simple_query(sqlite): def test_simple_query(sqlite):
_test_create_insert = ("-- name: create-some-table#\n" _test_create_insert = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (a, b, c);\n\n" "CREATE TABLE foo (a, b, c);\n\n"
"-- name: insert-some-value!\n" "-- name: insert-some-value!\n"
"INSERT INTO foo (a, b, c) VALUES (1, 2, 3);\n") "INSERT INTO foo (a, b, c) VALUES (1, 2, 3);\n"
)
q = anosql.from_str(_test_create_insert, "sqlite3") q = anosql.from_str(_test_create_insert, "sqlite3")
q.create_some_table(sqlite) q.create_some_table(sqlite)
@ -30,11 +33,13 @@ def test_simple_query(sqlite):
def test_auto_insert_query(sqlite): def test_auto_insert_query(sqlite):
_test_create_insert = ("-- name: create-some-table#\n" _test_create_insert = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (a, b, c);\n\n" "CREATE TABLE foo (a, b, c);\n\n"
"-- name: insert-some-value<!\n" "-- name: insert-some-value<!\n"
"INSERT INTO foo (a, b, c) VALUES (1, 2, 3);\n") "INSERT INTO foo (a, b, c) VALUES (1, 2, 3);\n"
)
q = anosql.from_str(_test_create_insert, "sqlite3") q = anosql.from_str(_test_create_insert, "sqlite3")
q.create_some_table(sqlite) q.create_some_table(sqlite)
@ -44,13 +49,15 @@ def test_auto_insert_query(sqlite):
def test_parametrized_insert(sqlite): def test_parametrized_insert(sqlite):
_test_create_insert = ("-- name: create-some-table#\n" _test_create_insert = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (a, b, c);\n\n" "CREATE TABLE foo (a, b, c);\n\n"
"-- name: insert-some-value!\n" "-- name: insert-some-value!\n"
"INSERT INTO foo (a, b, c) VALUES (?, ?, ?);\n\n" "INSERT INTO foo (a, b, c) VALUES (?, ?, ?);\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT * FROM foo;\n") "SELECT * FROM foo;\n"
)
q = anosql.from_str(_test_create_insert, "sqlite3") q = anosql.from_str(_test_create_insert, "sqlite3")
q.create_some_table(sqlite) q.create_some_table(sqlite)
@ -59,13 +66,15 @@ def test_parametrized_insert(sqlite):
def test_parametrized_insert_named(sqlite): def test_parametrized_insert_named(sqlite):
_test_create_insert = ("-- name: create-some-table#\n" _test_create_insert = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (a, b, c);\n\n" "CREATE TABLE foo (a, b, c);\n\n"
"-- name: insert-some-value!\n" "-- name: insert-some-value!\n"
"INSERT INTO foo (a, b, c) VALUES (:a, :b, :c);\n\n" "INSERT INTO foo (a, b, c) VALUES (:a, :b, :c);\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT * FROM foo;\n") "SELECT * FROM foo;\n"
)
q = anosql.from_str(_test_create_insert, "sqlite3") q = anosql.from_str(_test_create_insert, "sqlite3")
q.create_some_table(sqlite) q.create_some_table(sqlite)
@ -74,23 +83,27 @@ def test_parametrized_insert_named(sqlite):
def test_one_row(sqlite): def test_one_row(sqlite):
_test_one_row = ("-- name: one-row?\n" _test_one_row = (
"-- name: one-row?\n"
"SELECT 1, 'hello';\n\n" "SELECT 1, 'hello';\n\n"
"-- name: two-rows?\n" "-- name: two-rows?\n"
"SELECT 1 UNION SELECT 2;\n") "SELECT 1 UNION SELECT 2;\n"
)
q = anosql.from_str(_test_one_row, "sqlite3") q = anosql.from_str(_test_one_row, "sqlite3")
assert q.one_row(sqlite) == (1, "hello") assert q.one_row(sqlite) == (1, "hello")
assert q.two_rows(sqlite) is None assert q.two_rows(sqlite) is None
def test_simple_query_pg(postgresql): def test_simple_query_pg(postgresql):
_queries = ("-- name: create-some-table#\n" _queries = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n" "CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n"
"-- name: insert-some-value!\n" "-- name: insert-some-value!\n"
"INSERT INTO foo (a, b, c) VALUES (1, 2, 3);\n\n" "INSERT INTO foo (a, b, c) VALUES (1, 2, 3);\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT a, b, c FROM foo;\n") "SELECT a, b, c FROM foo;\n"
)
q = anosql.from_str(_queries, "psycopg2") q = anosql.from_str(_queries, "psycopg2")
@ -101,13 +114,15 @@ def test_simple_query_pg(postgresql):
def test_auto_insert_query_pg(postgresql): def test_auto_insert_query_pg(postgresql):
_queries = ("-- name: create-some-table#\n" _queries = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n" "CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n"
"-- name: insert-some-value<!\n" "-- name: insert-some-value<!\n"
"INSERT INTO foo (a, b, c) VALUES (1, 2, 3) returning id;\n\n" "INSERT INTO foo (a, b, c) VALUES (1, 2, 3) returning id;\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT a, b, c FROM foo;\n") "SELECT a, b, c FROM foo;\n"
)
q = anosql.from_str(_queries, "psycopg2") q = anosql.from_str(_queries, "psycopg2")
@ -118,13 +133,15 @@ def test_auto_insert_query_pg(postgresql):
def test_parameterized_insert_pg(postgresql): def test_parameterized_insert_pg(postgresql):
_queries = ("-- name: create-some-table#\n" _queries = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n" "CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n"
"-- name: insert-some-value!\n" "-- name: insert-some-value!\n"
"INSERT INTO foo (a, b, c) VALUES (%s, %s, %s);\n\n" "INSERT INTO foo (a, b, c) VALUES (%s, %s, %s);\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT a, b, c FROM foo;\n") "SELECT a, b, c FROM foo;\n"
)
q = anosql.from_str(_queries, "psycopg2") q = anosql.from_str(_queries, "psycopg2")
@ -135,13 +152,15 @@ def test_parameterized_insert_pg(postgresql):
def test_auto_parameterized_insert_query_pg(postgresql): def test_auto_parameterized_insert_query_pg(postgresql):
_queries = ("-- name: create-some-table#\n" _queries = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n" "CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n"
"-- name: insert-some-value<!\n" "-- name: insert-some-value<!\n"
"INSERT INTO foo (a, b, c) VALUES (%s, %s, %s) returning id;\n\n" "INSERT INTO foo (a, b, c) VALUES (%s, %s, %s) returning id;\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT a, b, c FROM foo;\n") "SELECT a, b, c FROM foo;\n"
)
q = anosql.from_str(_queries, "psycopg2") q = anosql.from_str(_queries, "psycopg2")
@ -154,13 +173,15 @@ def test_auto_parameterized_insert_query_pg(postgresql):
def test_parameterized_select_pg(postgresql): def test_parameterized_select_pg(postgresql):
_queries = ("-- name: create-some-table#\n" _queries = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n" "CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n"
"-- name: insert-some-value!\n" "-- name: insert-some-value!\n"
"INSERT INTO foo (a, b, c) VALUES (1, 2, 3)\n\n" "INSERT INTO foo (a, b, c) VALUES (1, 2, 3)\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT a, b, c FROM foo WHERE a = %s;\n") "SELECT a, b, c FROM foo WHERE a = %s;\n"
)
q = anosql.from_str(_queries, "psycopg2") q = anosql.from_str(_queries, "psycopg2")
@ -171,13 +192,15 @@ def test_parameterized_select_pg(postgresql):
def test_parameterized_insert_named_pg(postgresql): def test_parameterized_insert_named_pg(postgresql):
_queries = ("-- name: create-some-table#\n" _queries = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n" "CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n"
"-- name: insert-some-value!\n" "-- name: insert-some-value!\n"
"INSERT INTO foo (a, b, c) VALUES (%(a)s, %(b)s, %(c)s)\n\n" "INSERT INTO foo (a, b, c) VALUES (%(a)s, %(b)s, %(c)s)\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT a, b, c FROM foo;\n") "SELECT a, b, c FROM foo;\n"
)
q = anosql.from_str(_queries, "psycopg2") q = anosql.from_str(_queries, "psycopg2")
@ -188,13 +211,15 @@ def test_parameterized_insert_named_pg(postgresql):
def test_parameterized_select_named_pg(postgresql): def test_parameterized_select_named_pg(postgresql):
_queries = ("-- name: create-some-table#\n" _queries = (
"-- name: create-some-table#\n"
"-- testing insertion\n" "-- testing insertion\n"
"CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n" "CREATE TABLE foo (id serial primary key, a int, b int, c int);\n\n"
"-- name: insert-some-value!\n" "-- name: insert-some-value!\n"
"INSERT INTO foo (a, b, c) VALUES (1, 2, 3)\n\n" "INSERT INTO foo (a, b, c) VALUES (1, 2, 3)\n\n"
"-- name: get-all-values\n" "-- name: get-all-values\n"
"SELECT a, b, c FROM foo WHERE a = %(a)s;\n") "SELECT a, b, c FROM foo WHERE a = %(a)s;\n"
)
q = anosql.from_str(_queries, "psycopg2") q = anosql.from_str(_queries, "psycopg2")
@ -208,7 +233,6 @@ def test_without_trailing_semi_colon_pg():
"""Make sure keywords ending queries are recognized even without """Make sure keywords ending queries are recognized even without
semi-colons. semi-colons.
""" """
_queries = ("-- name: get-by-a\n" _queries = "-- name: get-by-a\n" "SELECT a, b, c FROM foo WHERE a = :a\n"
"SELECT a, b, c FROM foo WHERE a = :a\n")
q = anosql.from_str(_queries, "psycopg2") q = anosql.from_str(_queries, "psycopg2")
assert q.get_by_a.sql == "SELECT a, b, c FROM foo WHERE a = %(a)s" assert q.get_by_a.sql == "SELECT a, b, c FROM foo WHERE a = %(a)s"

View file

@ -32,16 +32,25 @@ def test_record_query(sqlite3_conn, queries):
def test_parameterized_query(sqlite3_conn, queries): def test_parameterized_query(sqlite3_conn, queries):
actual = queries.blogs.get_user_blogs(sqlite3_conn, userid=1) actual = queries.blogs.get_user_blogs(sqlite3_conn, userid=1)
expected = [("How to make a pie.", "2018-11-23"), ("What I did Today", "2017-07-28")] expected = [
("How to make a pie.", "2018-11-23"),
("What I did Today", "2017-07-28"),
]
assert actual == expected assert actual == expected
def test_parameterized_record_query(sqlite3_conn, queries): def test_parameterized_record_query(sqlite3_conn, queries):
sqlite3_conn.row_factory = dict_factory sqlite3_conn.row_factory = dict_factory
actual = queries.blogs.sqlite_get_blogs_published_after(sqlite3_conn, published="2018-01-01") actual = queries.blogs.sqlite_get_blogs_published_after(
sqlite3_conn, published="2018-01-01"
)
expected = [ expected = [
{"title": "How to make a pie.", "username": "bobsmith", "published": "2018-11-23 00:00"}, {
"title": "How to make a pie.",
"username": "bobsmith",
"published": "2018-11-23 00:00",
},
{"title": "Testing", "username": "janedoe", "published": "2018-01-01 00:00"}, {"title": "Testing", "username": "janedoe", "published": "2018-01-01 00:00"},
] ]
@ -51,7 +60,10 @@ def test_parameterized_record_query(sqlite3_conn, queries):
def test_select_cursor_context_manager(sqlite3_conn, queries): def test_select_cursor_context_manager(sqlite3_conn, queries):
with queries.blogs.get_user_blogs_cursor(sqlite3_conn, userid=1) as cursor: with queries.blogs.get_user_blogs_cursor(sqlite3_conn, userid=1) as cursor:
actual = cursor.fetchall() actual = cursor.fetchall()
expected = [("How to make a pie.", "2018-11-23"), ("What I did Today", "2017-07-28")] expected = [
("How to make a pie.", "2018-11-23"),
("What I did Today", "2017-07-28"),
]
assert actual == expected assert actual == expected

View file

@ -2,12 +2,15 @@ from damm import Damm
import pytest import pytest
@pytest.mark.parametrize("num", [ @pytest.mark.parametrize(
"num",
[
"0", # 0 itself is the start Damm state "0", # 0 itself is the start Damm state
"37", # [0, 3] => 7 "37", # [0, 3] => 7
"92", # [0, 9] => 2 "92", # [0, 9] => 2
"1234", # Amusingly, this is a 0-product. "1234", # Amusingly, this is a 0-product.
]) ],
)
def test_num_verifies(num): def test_num_verifies(num):
"""Assert that known-good Damm checks pass.""" """Assert that known-good Damm checks pass."""

View file

@ -61,7 +61,7 @@ from datalog.types import (
LVar, LVar,
PartlyIndexedDataset, PartlyIndexedDataset,
Rule, Rule,
TableIndexedDataset TableIndexedDataset,
) )
from prompt_toolkit import print_formatted_text, PromptSession from prompt_toolkit import print_formatted_text, PromptSession

View file

@ -6,7 +6,7 @@ from datalog.types import (
Constant, Constant,
Dataset, Dataset,
PartlyIndexedDataset, PartlyIndexedDataset,
TableIndexedDataset TableIndexedDataset,
) )
import pytest import pytest
@ -194,7 +194,9 @@ def test_alternate_rule_lrec(db_cls):
"""Testing that both recursion and alternation work.""" """Testing that both recursion and alternation work."""
if db_cls == Dataset: if db_cls == Dataset:
pytest.xfail("left-recursive rules aren't supported with a trivial store and no planner") pytest.xfail(
"left-recursive rules aren't supported with a trivial store and no planner"
)
d = read( d = read(
""" """

View file

@ -91,4 +91,8 @@ if __name__ == "__main__":
with open(filename) as f: with open(filename) as f:
root = ast.parse(f.read(), filename) root = ast.parse(f.read(), filename)
print(yaml.dump(YAMLTreeDumper().visit(root), default_flow_style=False, sort_keys=False)) print(
yaml.dump(
YAMLTreeDumper().visit(root), default_flow_style=False, sort_keys=False
)
)

View file

@ -791,8 +791,9 @@ class ModuleInterpreter(StrictNodeVisitor):
for n in node.names: for n in node.names:
if n in self.ns and self.ns[n] is not GLOBAL: if n in self.ns and self.ns[n] is not GLOBAL:
raise SyntaxError( raise SyntaxError(
"SyntaxError: name '{}' is assigned to before global declaration" "SyntaxError: name '{}' is assigned to before global declaration".format(
.format(n) n
)
) )
# Don't store GLOBAL in the top-level namespace # Don't store GLOBAL in the top-level namespace
if self.ns.parent: if self.ns.parent:
@ -895,6 +896,7 @@ class ModuleInterpreter(StrictNodeVisitor):
def visit_Ellipsis(self, node): def visit_Ellipsis(self, node):
# In Py3k only # In Py3k only
from ast import Ellipsis from ast import Ellipsis
return Ellipsis return Ellipsis
def visit_Print(self, node): def visit_Print(self, node):
@ -942,10 +944,14 @@ class InterpreterSystem(object):
elif os.path.isfile(e): elif os.path.isfile(e):
# FIXME (arrdem 2021-05-31) # FIXME (arrdem 2021-05-31)
raise RuntimeError("Import from .zip/.whl/.egg archives aren't supported yet") raise RuntimeError(
"Import from .zip/.whl/.egg archives aren't supported yet"
)
else: else:
self.modules[name] = __import__(name, globals, locals, fromlist, level) self.modules[name] = __import__(
name, globals, locals, fromlist, level
)
return self.modules[name] return self.modules[name]

View file

@ -27,6 +27,7 @@ for _ in range(10):
def bar(a, b, **bs): def bar(a, b, **bs):
pass pass
import requests import requests

View file

@ -60,21 +60,20 @@ def bench(callable, reps):
with timing() as t: with timing() as t:
callable() callable()
timings.append(t.duration) timings.append(t.duration)
print(f"""Ran {callable.__name__!r} {reps} times, total time {timer(run_t.duration)} print(
f"""Ran {callable.__name__!r} {reps} times, total time {timer(run_t.duration)}
mean: {timer(mean(timings))} mean: {timer(mean(timings))}
median: {timer(median(timings))} median: {timer(median(timings))}
stddev: {timer(stdev(timings))} stddev: {timer(stdev(timings))}
test overhead: {timer((run_t.duration - sum(timings)) / reps)} test overhead: {timer((run_t.duration - sum(timings)) / reps)}
""") """
)
def test_reference_json(reps): def test_reference_json(reps):
"""As a reference benchmark, test just appending to a file.""" """As a reference benchmark, test just appending to a file."""
jobs = [ jobs = [{"user_id": randint(0, 1 << 32), "msg": randstr(256)} for _ in range(reps)]
{"user_id": randint(0, 1<<32), "msg": randstr(256)}
for _ in range(reps)
]
jobs_i = iter(jobs) jobs_i = iter(jobs)
def naive_serialize(): def naive_serialize():
@ -86,15 +85,13 @@ def test_reference_json(reps):
def test_reference_fsync(reps): def test_reference_fsync(reps):
"""As a reference benchmark, test just appending to a file.""" """As a reference benchmark, test just appending to a file."""
jobs = [ jobs = [{"user_id": randint(0, 1 << 32), "msg": randstr(256)} for _ in range(reps)]
{"user_id": randint(0, 1<<32), "msg": randstr(256)}
for _ in range(reps)
]
jobs_i = iter(jobs) jobs_i = iter(jobs)
handle, path = tempfile.mkstemp() handle, path = tempfile.mkstemp()
os.close(handle) os.close(handle)
with open(path, "w") as fd: with open(path, "w") as fd:
def naive_fsync(): def naive_fsync():
fd.write(json.dumps([next(jobs_i), ["CREATED"]])) fd.write(json.dumps([next(jobs_i), ["CREATED"]]))
fd.flush() fd.flush()
@ -106,10 +103,7 @@ def test_reference_fsync(reps):
def test_insert(q, reps): def test_insert(q, reps):
"""Benchmark insertion time to a given SQLite DB.""" """Benchmark insertion time to a given SQLite DB."""
jobs = [ jobs = [{"user_id": randint(0, 1 << 32), "msg": randstr(256)} for _ in range(reps)]
{"user_id": randint(0, 1<<32), "msg": randstr(256)}
for _ in range(reps)
]
jobs_i = iter(jobs) jobs_i = iter(jobs)
def insert(): def insert():

View file

@ -160,7 +160,9 @@ def compile_query(query):
elif isinstance(query, str): elif isinstance(query, str):
terms = [query] terms = [query]
assert not any(keyword in query.lower() for keyword in ["select", "update", "delete", ";"]) assert not any(
keyword in query.lower() for keyword in ["select", "update", "delete", ";"]
)
return " AND ".join(terms) return " AND ".join(terms)
@ -173,7 +175,6 @@ class Job(NamedTuple):
class JobQueue(object): class JobQueue(object):
def __init__(self, path): def __init__(self, path):
self._db = sqlite3.connect(path) self._db = sqlite3.connect(path)
self._queries = anosql.from_str(_SQL, "sqlite3") self._queries = anosql.from_str(_SQL, "sqlite3")
@ -196,7 +197,7 @@ class JobQueue(object):
json.loads(payload), json.loads(payload),
json.loads(events), json.loads(events),
json.loads(state), json.loads(state),
datetime.fromtimestamp(int(modified)) datetime.fromtimestamp(int(modified)),
) )
def _from_result(self, result) -> Job: def _from_result(self, result) -> Job:
@ -227,6 +228,7 @@ class JobQueue(object):
if limit: if limit:
limit = int(limit) limit = int(limit)
def lf(iterable): def lf(iterable):
iterable = iter(iterable) iterable = iter(iterable)
for i in range(limit): for i in range(limit):
@ -234,6 +236,7 @@ class JobQueue(object):
yield next(iterable) yield next(iterable)
except StopIteration: except StopIteration:
break break
jobs = lf(jobs) jobs = lf(jobs)
return self._from_results(jobs) return self._from_results(jobs)
@ -265,9 +268,7 @@ class JobQueue(object):
"""Fetch all available data about a given job by ID.""" """Fetch all available data about a given job by ID."""
with self._db as db: with self._db as db:
return self._from_result( return self._from_result(self._queries.job_get(db, id=job_id))
self._queries.job_get(db, id=job_id)
)
def cas_state(self, job_id, old_state, new_state): def cas_state(self, job_id, old_state, new_state):
"""CAS update a job's state, returning the updated job or indicating a conflict.""" """CAS update a job's state, returning the updated job or indicating a conflict."""
@ -287,11 +288,7 @@ class JobQueue(object):
with self._db as db: with self._db as db:
return self._from_result( return self._from_result(
self._queries.job_append_event( self._queries.job_append_event(db, id=job_id, event=json.dumps(event))
db,
id=job_id,
event=json.dumps(event)
)
) )
def delete_job(self, job_id): def delete_job(self, job_id):

View file

@ -52,9 +52,7 @@ def get_jobs():
query = blob.get("query", "true") query = blob.get("query", "true")
return jsonify({ return jsonify({"jobs": [job_as_json(j) for j in request.q.query(query)]}), 200
"jobs": [job_as_json(j) for j in request.q.query(query)]
}), 200
@app.route("/api/v0/job/create", methods=["POST"]) @app.route("/api/v0/job/create", methods=["POST"])
@ -64,9 +62,7 @@ def create_job():
blob = request.get_json(force=True) blob = request.get_json(force=True)
payload = blob["payload"] payload = blob["payload"]
state = blob.get("state", None) state = blob.get("state", None)
job = request.q.create( job = request.q.create(payload, state)
payload, state
)
return jsonify(job_as_json(job)), 200 return jsonify(job_as_json(job)), 200

View file

@ -20,7 +20,7 @@ class Job(t.NamedTuple):
payload=obj["payload"], payload=obj["payload"],
events=obj["events"], events=obj["events"],
state=obj["state"], state=obj["state"],
modified=datetime.fromtimestamp(obj["modified"]) modified=datetime.fromtimestamp(obj["modified"]),
) )
@ -32,63 +32,62 @@ class JobqClient(object):
def jobs(self, query=None, limit=10) -> t.Iterable[Job]: def jobs(self, query=None, limit=10) -> t.Iterable[Job]:
"""Enumerate jobs on the queue.""" """Enumerate jobs on the queue."""
for job in self._session.post(self._url + "/api/v0/job", for job in (
json={"query": query or [], self._session.post(
"limit": limit})\ self._url + "/api/v0/job", json={"query": query or [], "limit": limit}
.json()\ )
.get("jobs"): .json()
.get("jobs")
):
yield Job.from_json(job) yield Job.from_json(job)
def poll(self, query, state) -> Job: def poll(self, query, state) -> Job:
"""Poll the job queue for the first job matching the given query, atomically advancing it to the given state and returning the advanced Job.""" """Poll the job queue for the first job matching the given query, atomically advancing it to the given state and returning the advanced Job."""
return Job.from_json( return Job.from_json(
self._session self._session.post(
.post(self._url + "/api/v0/job/poll", self._url + "/api/v0/job/poll", json={"query": query, "state": state}
json={"query": query, ).json()
"state": state}) )
.json())
def create(self, payload: object, state=None) -> Job: def create(self, payload: object, state=None) -> Job:
"""Create a new job in the system.""" """Create a new job in the system."""
return Job.from_json( return Job.from_json(
self._session self._session.post(
.post(self._url + "/api/v0/job/create", self._url + "/api/v0/job/create",
json={"payload": payload, json={"payload": payload, "state": state},
"state": state}) ).json()
.json()) )
def fetch(self, job: Job) -> Job: def fetch(self, job: Job) -> Job:
"""Fetch the current state of a job.""" """Fetch the current state of a job."""
return Job.from_json( return Job.from_json(
self._session self._session.get(self._url + "/api/v0/job/" + job.id).json()
.get(self._url + "/api/v0/job/" + job.id) )
.json())
def advance(self, job: Job, state: object) -> Job: def advance(self, job: Job, state: object) -> Job:
"""Attempt to advance a job to a subsequent state.""" """Attempt to advance a job to a subsequent state."""
return Job.from_json( return Job.from_json(
self._session self._session.post(
.post(job.url + "/state", job.url + "/state", json={"old": job.state, "new": state}
json={"old": job.state, ).json()
"new": state}) )
.json())
def event(self, job: Job, event: object) -> Job: def event(self, job: Job, event: object) -> Job:
"""Attempt to record an event against a job.""" """Attempt to record an event against a job."""
return Job.from_json( return Job.from_json(
self._session self._session.post(
.post(self._url + f"/api/v0/job/{job.id}/event", self._url + f"/api/v0/job/{job.id}/event", json=event
json=event) ).json()
.json()) )
def delete(self, job: Job) -> None: def delete(self, job: Job) -> None:
"""Delete a remote job.""" """Delete a remote job."""
return (self._session return self._session.delete(
.delete(self._url + f"/api/v0/job/{job.id}") self._url + f"/api/v0/job/{job.id}"
.raise_for_status()) ).raise_for_status()

View file

@ -12,7 +12,12 @@ from kazoo.exceptions import NodeExistsError
from kazoo.protocol.states import ZnodeStat from kazoo.protocol.states import ZnodeStat
from kazoo.recipe.lock import Lock, ReadLock, WriteLock from kazoo.recipe.lock import Lock, ReadLock, WriteLock
from kook.config import current_config, KookConfig from kook.config import current_config, KookConfig
from toolz.dicttoolz import assoc as _assoc, dissoc as _dissoc, merge as _merge, update_in from toolz.dicttoolz import (
assoc as _assoc,
dissoc as _dissoc,
merge as _merge,
update_in,
)
def assoc(m, k, v): def assoc(m, k, v):

View file

@ -14,10 +14,12 @@ setup(
packages=[ packages=[
"lilith", "lilith",
], ],
package_data={"": [ package_data={
"": [
"src/python/lilith/*.lark", "src/python/lilith/*.lark",
"src/python/lilith/*.lil", "src/python/lilith/*.lil",
]}, ]
},
include_package_data=True, include_package_data=True,
install_requires=requirements, install_requires=requirements,
entry_points={ entry_points={

View file

@ -89,16 +89,18 @@ class Proquint(object):
val = n << 8 | m val = n << 8 | m
# This is slightly un-idiomatic, but it precisely captures the coding definition # This is slightly un-idiomatic, but it precisely captures the coding definition
yield "".join([ yield "".join(
[
dict[val >> shift & mask] dict[val >> shift & mask]
for dict, shift, mask in [ for dict, shift, mask in [
(cls.CONSONANTS, 0xC, 0xf), (cls.CONSONANTS, 0xC, 0xF),
(cls.VOWELS, 0xA, 0x3), (cls.VOWELS, 0xA, 0x3),
(cls.CONSONANTS, 0x6, 0xf), (cls.CONSONANTS, 0x6, 0xF),
(cls.VOWELS, 0x4, 0x3), (cls.VOWELS, 0x4, 0x3),
(cls.CONSONANTS, 0x0, 0xf) (cls.CONSONANTS, 0x0, 0xF),
] ]
]) ]
)
# Core methods # Core methods
################################################################################################ ################################################################################################

View file

@ -10,7 +10,9 @@ from proquint import Proquint
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
g = parser.add_mutually_exclusive_group() g = parser.add_mutually_exclusive_group()
g.add_argument("-g", "--generate", dest="generate", default=False, action="store_true") g.add_argument("-g", "--generate", dest="generate", default=False, action="store_true")
g.add_argument("-p", "--predictable", dest="predictable", default=False, action="store_true") g.add_argument(
"-p", "--predictable", dest="predictable", default=False, action="store_true"
)
g.add_argument("-d", "--decode", dest="decode", default=False, action="store_true") g.add_argument("-d", "--decode", dest="decode", default=False, action="store_true")
g.add_argument("-e", "--encode", dest="encode", default=False, action="store_true") g.add_argument("-e", "--encode", dest="encode", default=False, action="store_true")
parser.add_argument("-w", "--width", dest="width", type=int, default=32) parser.add_argument("-w", "--width", dest="width", type=int, default=32)

View file

@ -38,7 +38,6 @@ examples = [
(536870912, 32, "fabab-babab"), (536870912, 32, "fabab-babab"),
(1073741824, 32, "habab-babab"), (1073741824, 32, "habab-babab"),
(2147483648, 32, "mabab-babab"), (2147483648, 32, "mabab-babab"),
# A random value # A random value
(3232235536, 32, "safom-babib"), (3232235536, 32, "safom-babib"),
] ]
@ -53,4 +52,6 @@ def test_decode_examples(val, width, qint):
def test_encode_examples(val, width, qint): def test_encode_examples(val, width, qint):
encoded_qint = proquint.Proquint.encode(val, width) encoded_qint = proquint.Proquint.encode(val, width)
decoded_val = proquint.Proquint.decode(encoded_qint) decoded_val = proquint.Proquint.decode(encoded_qint)
assert encoded_qint == qint, f"did not encode {val} to {qint}; got {encoded_qint} ({decoded_val})" assert (
encoded_qint == qint
), f"did not encode {val} to {qint}; got {encoded_qint} ({decoded_val})"

View file

@ -7,23 +7,19 @@ import proquint
@given(integers(min_value=0, max_value=1 << 16)) @given(integers(min_value=0, max_value=1 << 16))
def test_round_trip_16(val): def test_round_trip_16(val):
assert proquint.Proquint.decode( assert proquint.Proquint.decode(proquint.Proquint.encode(val, 16)) == val
proquint.Proquint.encode(val, 16)) == val
@given(integers(min_value=0, max_value=1 << 32)) @given(integers(min_value=0, max_value=1 << 32))
def test_round_trip_32(val): def test_round_trip_32(val):
assert proquint.Proquint.decode( assert proquint.Proquint.decode(proquint.Proquint.encode(val, 32)) == val
proquint.Proquint.encode(val, 32)) == val
@given(integers(min_value=0, max_value=1 << 64)) @given(integers(min_value=0, max_value=1 << 64))
def test_round_trip_64(val): def test_round_trip_64(val):
assert proquint.Proquint.decode( assert proquint.Proquint.decode(proquint.Proquint.encode(val, 64)) == val
proquint.Proquint.encode(val, 64)) == val
@given(integers(min_value=0, max_value=1 << 512)) @given(integers(min_value=0, max_value=1 << 512))
def test_round_trip_512(val): def test_round_trip_512(val):
assert proquint.Proquint.decode( assert proquint.Proquint.decode(proquint.Proquint.encode(val, 512)) == val
proquint.Proquint.encode(val, 512)) == val

View file

@ -58,7 +58,9 @@ def sort_key(requirement: str) -> str:
def _bq(query): def _bq(query):
"""Enumerate the PyPi package names from a Bazel query.""" """Enumerate the PyPi package names from a Bazel query."""
unused = subprocess.check_output(["bazel", "query", query, "--output=package"]).decode("utf-8") unused = subprocess.check_output(
["bazel", "query", query, "--output=package"]
).decode("utf-8")
for line in unused.split("\n"): for line in unused.split("\n"):
if line: if line:
yield line.replace("@arrdem_source_pypi//pypi__", "") yield line.replace("@arrdem_source_pypi//pypi__", "")
@ -67,7 +69,9 @@ def _bq(query):
def _unused(): def _unused():
"""Find unused requirements.""" """Find unused requirements."""
return set(_bq("@arrdem_source_pypi//...")) - set(_bq("filter('//pypi__', deps(//...))")) return set(_bq("@arrdem_source_pypi//...")) - set(
_bq("filter('//pypi__', deps(//...))")
)
def _load(fname): def _load(fname):

View file

@ -5,6 +5,8 @@ attrs==20.3.0
autoflake==1.4 autoflake==1.4
Babel==2.9.0 Babel==2.9.0
beautifulsoup4==4.9.3 beautifulsoup4==4.9.3
black==21.8b0
bleach==4.0.0
certifi==2020.12.5 certifi==2020.12.5
chardet==4.0.0 chardet==4.0.0
click==7.1.2 click==7.1.2
@ -40,8 +42,10 @@ openapi-schema-validator==0.1.5
openapi-spec-validator==0.3.0 openapi-spec-validator==0.3.0
packaging==20.9 packaging==20.9
parso==0.8.2 parso==0.8.2
pathspec==0.8.1 pathspec==0.9.0
pep517==0.11.0
pip-tools==6.2.0 pip-tools==6.2.0
platformdirs==2.3.0
pluggy==0.13.1 pluggy==0.13.1
port-for==0.6.1 port-for==0.6.1
prompt-toolkit==3.0.18 prompt-toolkit==3.0.18
@ -63,6 +67,7 @@ PyYAML==5.4.1
readme-renderer==29.0 readme-renderer==29.0
recommonmark==0.7.1 recommonmark==0.7.1
redis==3.5.3 redis==3.5.3
regex==2021.8.28
requests==2.25.1 requests==2.25.1
requests-toolbelt==0.9.1 requests-toolbelt==0.9.1
requirements-parser==0.2.0 requirements-parser==0.2.0
@ -81,15 +86,18 @@ sphinxcontrib-programoutput==0.17
sphinxcontrib-qthelp==1.0.3 sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.4 sphinxcontrib-serializinghtml==1.1.4
toml==0.10.2 toml==0.10.2
tomli==1.2.1
tornado==6.1 tornado==6.1
typed-ast==1.4.2 typed-ast==1.4.2
typing-extensions==3.7.4.3 typing-extensions==3.10.0.2
unify==0.5 unify==0.5
untokenize==0.1.1 untokenize==0.1.1
urllib3==1.26.4 urllib3==1.26.4
urwid==2.1.2 urwid==2.1.2
wcwidth==0.2.5 wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==2.0.1 Werkzeug==2.0.1
yamllint==1.26.1 yamllint==1.26.1
yarl==1.6.3 yarl==1.6.3
yaspin==1.5.0 yaspin==1.5.0
zipp==3.5.0

View file

@ -36,23 +36,17 @@ LICENSES_BY_LOWERNAME = {
"apache 2.0": APACHE, "apache 2.0": APACHE,
"apache": APACHE, "apache": APACHE,
"http://www.apache.org/licenses/license-2.0": APACHE, "http://www.apache.org/licenses/license-2.0": APACHE,
"bsd 3": BSD, "bsd 3": BSD,
"bsd": BSD, "bsd": BSD,
"gpl": GPL1, "gpl": GPL1,
"gpl2": GPL2, "gpl2": GPL2,
"gpl3": GPL3, "gpl3": GPL3,
"lgpl": LGPL, "lgpl": LGPL,
"lgpl3": LGPL3, "lgpl3": LGPL3,
"isc": ISCL, "isc": ISCL,
"mit": MIT, "mit": MIT,
"mpl": MPL10, "mpl": MPL10,
"mpl 2.0": MPL20, "mpl 2.0": MPL20,
"psf": PSFL, "psf": PSFL,
} }
@ -75,7 +69,9 @@ with open("tools/python/requirements.txt") as fd:
def bash_license(ln): def bash_license(ln):
while True: while True:
lnn = re.sub(r"[(),]|( version)|( license)|( ?v(?=\d))|([ -]clause)", "", ln.lower()) lnn = re.sub(
r"[(),]|( version)|( license)|( ?v(?=\d))|([ -]clause)", "", ln.lower()
)
if ln != lnn: if ln != lnn:
ln = lnn ln = lnn
else: else:
@ -85,7 +81,9 @@ def bash_license(ln):
return ln return ln
@pytest.mark.parametrize("a,b", [ @pytest.mark.parametrize(
"a,b",
[
("MIT", MIT), ("MIT", MIT),
("mit", MIT), ("mit", MIT),
("BSD", BSD), ("BSD", BSD),
@ -94,7 +92,8 @@ def bash_license(ln):
("GPL3", GPL3), ("GPL3", GPL3),
("GPL v3", GPL3), ("GPL v3", GPL3),
("GPLv3", GPL3), ("GPLv3", GPL3),
]) ],
)
def test_bash_license(a, b): def test_bash_license(a, b):
assert bash_license(a) == b assert bash_license(a) == b
@ -117,7 +116,7 @@ def licenses(package: Requirement):
if not version: if not version:
blob = requests.get( blob = requests.get(
f"https://pypi.org/pypi/{package.name}/json", f"https://pypi.org/pypi/{package.name}/json",
headers={"Accept": "application/json"} headers={"Accept": "application/json"},
).json() ).json()
if ln := bash_license(blob.get("license")): if ln := bash_license(blob.get("license")):
lics.append(ln) lics.append(ln)
@ -131,13 +130,15 @@ def licenses(package: Requirement):
if version: if version:
blob = requests.get( blob = requests.get(
f"https://pypi.org/pypi/{package.name}/{version}/json", f"https://pypi.org/pypi/{package.name}/{version}/json",
headers={"Accept": "application/json"} headers={"Accept": "application/json"},
).json() ).json()
lics.extend([ lics.extend(
[
c c
for c in blob.get("info", {}).get("classifiers", []) for c in blob.get("info", {}).get("classifiers", [])
if c.startswith("License") if c.startswith("License")
]) ]
)
ln = blob.get("info", {}).get("license") ln = blob.get("info", {}).get("license")
if ln and not lics: if ln and not lics:
lics.append(bash_license(ln)) lics.append(bash_license(ln))