source/projects/anosql/src/python/anosql/adapters/sqlite3.py
2022-01-08 23:53:24 -07:00

96 lines
2.8 KiB
Python

"""
An anosql driver adapter implementing support for SQLite3 via the standard-library ``sqlite3`` module.
"""
from contextlib import contextmanager
import logging
import re
import sqlite3
log = logging.getLogger(__name__)

# Lexemes that ``process_sql`` must either copy through verbatim (``literal``)
# or drop (``comment``). Because matching scans left to right, a quote inside a
# comment, or ``--`` inside a string, is consumed by the enclosing lexeme and is
# never misread. The literal forms are SQLite's string literal plus its three
# identifier-quoting styles.
_SQL_LEXEME = re.compile(
    r"(?P<literal>'(?:[^']|'')*'"  # 'string literal'; '' is an escaped quote
    r'|"(?:[^"]|"")*"'  # "quoted identifier"
    r"|`(?:[^`]|``)*`"  # `quoted identifier`
    r"|\[[^\]]*\])"  # [quoted identifier]
    r"|(?P<comment>--[^\n]*|/\*.*?\*/)",  # -- line comment, /* block comment */
    re.DOTALL,
)

# The RETURNING keyword as a whole word in any letter case, so identifiers such
# as ``returning_users`` do not match.
_RETURNING = re.compile(r"\breturning\b", re.IGNORECASE)


def _squeeze_whitespace(text):
    """Collapse syntactically irrelevant whitespace in quote-free SQL text.

    Each run of whitespace becomes a single space, and the space after ``(``
    or before ``)``, ``,`` and ``;`` is removed. Only pass text that contains
    no quoted literals; those must be preserved byte-for-byte.
    """
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"\(\s+", "(", text)
    text = re.sub(r"\s+\)", ")", text)
    text = re.sub(r"\s+,", ",", text)
    return re.sub(r"\s+;", ";", text)


class SQLite3DriverAdapter(object):
    """anosql driver adapter for the standard-library ``sqlite3`` module.

    All methods are static: the adapter keeps no state and works on whatever
    connection it is handed. No method calls ``commit()``.
    """

    @staticmethod
    def process_sql(_query_name, _op_type, sql):
        """Munge queries.

        Comments (``-- ...`` and ``/* ... */``) are replaced with a space so
        the tokens around them stay separated. Whitespace outside quoted text
        is collapsed, and leading/trailing whitespace is stripped. String
        literals and quoted identifiers ('...', "...", `...`, [...]) are copied
        through unchanged, so their contents never alter the query.

        Args:
            _query_name (str): The name of the sql query.
            _op_type (anosql.SQLOperationType): The type of SQL operation performed by the sql.
            sql (str): The sql as written before processing.

        Returns:
            str: A normalized form of the query suitable to logging or copy/paste.
        """
        pieces = []  # finished output: squeezed plain SQL and verbatim literals
        pending = []  # plain SQL (comments blanked to " ") since the last literal
        pos = 0
        for match in _SQL_LEXEME.finditer(sql):
            pending.append(sql[pos:match.start()])
            literal = match.group("literal")
            if literal is None:
                # A comment: a space keeps e.g. ``a--c\nFROM`` from gluing into ``aFROM``.
                pending.append(" ")
            else:
                pieces.append(_squeeze_whitespace("".join(pending)))
                pieces.append(literal)
                pending = []
            pos = match.end()
        pending.append(sql[pos:])
        pieces.append(_squeeze_whitespace("".join(pending)))
        # Literals begin and end with a quote/bracket, so strip() never eats them.
        return "".join(pieces).strip()

    @staticmethod
    def select(conn, _query_name, sql, parameters):
        """Execute ``sql`` with ``parameters`` and return ``cursor.fetchall()``.

        The cursor is closed even if execution or fetching raises.
        """
        cur = conn.cursor()
        try:
            log.debug({"sql": sql, "parameters": parameters})
            cur.execute(sql, parameters)
            return cur.fetchall()
        finally:
            cur.close()

    @staticmethod
    @contextmanager
    def select_cursor(conn: sqlite3.Connection, _query_name, sql, parameters):
        """Context manager yielding a cursor on which ``sql`` has been executed.

        The cursor is closed when the ``with`` block exits, and also when
        ``execute`` itself raises.
        """
        cur = conn.cursor()
        try:
            log.debug({"sql": sql, "parameters": parameters})
            cur.execute(sql, parameters)
            yield cur
        finally:
            cur.close()

    @staticmethod
    def insert_update_delete(conn: sqlite3.Connection, _query_name, sql, parameters):
        """Execute a single data-modifying statement; returns nothing."""
        log.debug({"sql": sql, "parameters": parameters})
        conn.execute(sql, parameters)

    @staticmethod
    def insert_update_delete_many(
        conn: sqlite3.Connection, _query_name, sql, parameters
    ):
        """Execute ``sql`` once per parameter set via ``executemany``; returns nothing."""
        log.debug({"sql": sql, "parameters": parameters})
        conn.executemany(sql, parameters)

    @staticmethod
    def insert_returning(conn: sqlite3.Connection, _query_name, sql, parameters):
        """Execute an insert and report what it produced.

        Returns:
            ``cursor.fetchall()`` when ``sql`` carries a ``RETURNING`` clause,
            otherwise ``cursor.lastrowid``.
        """
        cur = conn.cursor()
        try:
            log.debug({"sql": sql, "parameters": parameters})
            cur.execute(sql, parameters)
            # Blank out literals and comments first, so 'returning' in a string,
            # a quoted identifier or a comment is not mistaken for the keyword.
            if _RETURNING.search(_SQL_LEXEME.sub(" ", sql)) is None:
                # Original behavior - return the last row ID
                results = cur.lastrowid
            else:
                # Honor a `RETURNING` clause
                results = cur.fetchall()
            log.debug({"results": results})
            return results
        finally:
            cur.close()

    @staticmethod
    def execute_script(conn: sqlite3.Connection, sql):
        """Run ``sql`` (possibly several statements) via ``executescript``."""
        log.debug({"sql": sql, "parameters": None})
        conn.executescript(sql)