385 lines
14 KiB
Python
385 lines
14 KiB
Python
import os
|
|
import logging
|
|
import asyncio
|
|
import sqlite3
|
|
import platform
|
|
from binascii import hexlify
|
|
from dataclasses import dataclass
|
|
from contextvars import ContextVar
|
|
from concurrent.futures.thread import ThreadPoolExecutor
|
|
from concurrent.futures.process import ProcessPoolExecutor
|
|
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
|
|
|
|
|
|
# Print tracebacks raised inside sqlite3 callbacks (user-defined functions,
# aggregates, converters, authorizers) instead of discarding them silently.
sqlite3.enable_callback_tracebacks(True)

# Module-level logger shared by every helper in this file.
log = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass()
class ReaderProcessState:
    """State kept by each read-only worker between queries."""

    # cursor on the connection that `initializer` opened for this worker
    cursor: sqlite3.Cursor
|
|
|
|
|
|
# Current reader worker's state: set by `initializer` in every reader worker and
# read via `.get()` by the run_read_only_* helpers. It is never set to None, so
# the annotation is not Optional; `.get()` raises LookupError where it was never set.
reader_context: ContextVar[ReaderProcessState] = ContextVar('reader_context')
|
|
|
|
|
|
def initializer(path):
    """Open the private SQLite connection used by one read-only worker.

    Passed as the ``initializer`` of ``AIOSQLite.reader_executor`` so it runs
    once per worker. Rows fetched through this connection are dicts (see
    ``dict_row_factory``), and the resulting cursor is stored in
    ``reader_context`` for the run_read_only_* helpers.

    :param path: database path, as given to ``AIOSQLite.connect``.
    """
    connection = sqlite3.connect(path)
    connection.row_factory = dict_row_factory
    # WAL journaling lets readers proceed while the writer holds a transaction.
    connection.executescript("pragma journal_mode=WAL;")
    reader_context.set(ReaderProcessState(connection.cursor()))
|
|
|
|
|
|
def run_read_only_fetchall(sql, params):
    """Execute *sql* on this worker's reader cursor and return every row.

    Runs inside a reader worker (after ``initializer``). Errors are logged
    and then re-raised to the caller.
    """
    cur = reader_context.get().cursor
    try:
        result = cur.execute(sql, params)
        rows = result.fetchall()
    except Exception as err:  # OSError is already an Exception subclass
        log.exception('Error running transaction:', exc_info=err)
        raise
    return rows
|
|
|
|
|
|
def run_read_only_fetchone(sql, params):
    """Execute *sql* on this worker's reader cursor and return the first row (or None).

    Runs inside a reader worker (after ``initializer``). Errors are logged
    and then re-raised to the caller.
    """
    cur = reader_context.get().cursor
    try:
        result = cur.execute(sql, params)
        first = result.fetchone()
    except Exception as err:  # OSError is already an Exception subclass
        log.exception('Error running transaction:', exc_info=err)
        raise
    return first
|
|
|
|
|
|
# Executor type for read-only workers: threads on Windows and Android
# (detected through the ANDROID_ARGUMENT environment variable), processes elsewhere.
ReaderExecutorClass = (
    ThreadPoolExecutor
    if platform.system() == 'Windows' or 'ANDROID_ARGUMENT' in os.environ
    else ProcessPoolExecutor
)
|
|
|
|
|
|
class AIOSQLite:
    """Asyncio front end for one SQLite database.

    Writes run on a single-threaded executor that owns the only writer
    connection, so they execute one at a time. Read-only queries can be sent
    to ``reader_executor`` (type chosen by ``ReaderExecutorClass``), whose
    workers each open their own connection via ``initializer``. Read-only
    queries wait while any ``run`` call is queued or executing.
    """

    reader_executor: ReaderExecutorClass

    def __init__(self):
        # has to be single threaded as there is no mapping of thread:connection
        self.writer_executor = ThreadPoolExecutor(max_workers=1)
        self.writer_connection: Optional[sqlite3.Connection] = None
        self._closing = False
        # number of transactions started on the writer connection
        self.query_count = 0
        self.write_lock = asyncio.Lock()
        # number of `run` calls currently queued or executing
        self.writers = 0
        # set while no `run` call is pending; read-only queries wait on it
        self.read_ready = asyncio.Event()

    @classmethod
    async def connect(cls, path: Union[bytes, str], *args, **kwargs):
        """Create an instance with a writer connection and a reader pool for *path*.

        Extra ``args``/``kwargs`` are passed to ``sqlite3.connect`` for the
        writer connection only; reader workers open *path* via ``initializer``.

        :raises: whatever ``sqlite3.connect`` raises; both executors are shut
            down before the error propagates.
        """
        sqlite3.enable_callback_tracebacks(True)
        db = cls()

        def _connect_writer():
            db.writer_connection = sqlite3.connect(path, *args, **kwargs)

        # os.cpu_count() returns None when the count cannot be determined.
        readers = max((os.cpu_count() or 0) - 2, 2)
        db.reader_executor = ReaderExecutorClass(
            max_workers=readers, initializer=initializer, initargs=(path, )
        )
        try:
            await asyncio.get_running_loop().run_in_executor(db.writer_executor, _connect_writer)
        except BaseException:
            # do not leak worker pools when the writer connection cannot be opened
            db.reader_executor.shutdown(wait=False)
            db.writer_executor.shutdown(wait=False)
            raise
        db.read_ready.set()
        return db

    async def close(self):
        """Close the writer connection and shut down both executors.

        Only the first call does anything; later calls return immediately.
        """
        if self._closing:
            return
        self._closing = True
        if self.writer_connection is not None:
            await asyncio.get_running_loop().run_in_executor(
                self.writer_executor, self.writer_connection.close
            )
        self.writer_executor.shutdown(wait=True)
        # reader_executor only exists once `connect` has created it
        reader_executor = getattr(self, 'reader_executor', None)
        if reader_executor is not None:
            reader_executor.shutdown(wait=True)
        self.read_ready.clear()
        self.writer_connection = None

    def executemany(self, sql: str, params: Iterable):
        """Run ``executemany(sql, params)`` in a write transaction."""
        params = params if params is not None else []
        # this fetchall is needed to prevent SQLITE_MISUSE
        return self.run(lambda conn: conn.executemany(sql, params).fetchall())

    def executescript(self, script: str) -> Awaitable:
        """Run ``executescript(script)`` through ``run`` on the writer connection."""
        return self.run(lambda conn: conn.executescript(script))

    async def _execute_fetch(self, sql: str, parameters: Optional[Iterable] = None,
                             read_only=False, fetch_all: bool = False) -> Any:
        """Execute *sql* and return all rows (``fetch_all``) or the first row/None.

        With ``read_only`` the query waits until no write is pending and then
        runs on a reader worker (rows are dicts); otherwise it runs inside a
        write transaction on the writer connection (rows as produced by that
        connection).
        """
        parameters = parameters if parameters is not None else []
        if read_only:
            read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone
            while self.writers:
                await self.read_ready.wait()
            return await asyncio.get_running_loop().run_in_executor(
                self.reader_executor, read_only_fn, sql, parameters
            )
        if fetch_all:
            return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
        return await self.run(lambda conn: conn.execute(sql, parameters).fetchone())

    async def execute_fetchall(self, sql: str, parameters: Optional[Iterable] = None,
                               read_only=False) -> List[Any]:
        """Execute *sql* and return every row (see ``_execute_fetch``)."""
        return await self._execute_fetch(sql, parameters, read_only, fetch_all=True)

    async def execute_fetchone(self, sql: str, parameters: Optional[Iterable] = None,
                               read_only=False) -> Any:
        """Execute *sql* and return the first row, or None (see ``_execute_fetch``)."""
        return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)

    def execute(self, sql: str, parameters: Optional[Iterable] = None) -> Awaitable[sqlite3.Cursor]:
        """Execute *sql* in a write transaction; resolves to the writer cursor."""
        parameters = parameters if parameters is not None else []
        return self.run(lambda conn: conn.execute(sql, parameters))

    async def run(self, fun, *args, **kwargs):
        """Run ``fun(writer_connection, *args, **kwargs)`` inside a write transaction.

        Calls are serialized by ``write_lock`` and executed on the writer
        executor. While any call is queued or running, ``writers`` is non-zero
        and ``read_ready`` is cleared, so read-only queries wait.

        :return: whatever *fun* returns.
        :raises: whatever *fun* raises (after the transaction is rolled back).
        """
        self.writers += 1
        self.read_ready.clear()
        # The try must also cover waiting for the lock: if this task is
        # cancelled before acquiring it, the writer count still has to be
        # released, otherwise read-only queries would wait forever.
        try:
            async with self.write_lock:
                return await asyncio.get_running_loop().run_in_executor(
                    self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
                )
        finally:
            self.writers -= 1
            if not self.writers:
                self.read_ready.set()

    def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
        """Run *fun* between BEGIN and COMMIT on the writer connection.

        Called on the writer executor. On any exception the error is logged,
        the transaction is rolled back and the exception is re-raised.
        """
        self.writer_connection.execute('begin')
        try:
            self.query_count += 1
            result = fun(self.writer_connection, *args, **kwargs)  # type: ignore
            self.writer_connection.commit()
            return result
        except Exception as e:  # OSError is already an Exception subclass
            log.exception('Error running transaction:', exc_info=e)
            self.writer_connection.rollback()
            log.warning("rolled back")
            raise

    def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
        """Like ``run`` but with ``pragma foreign_keys`` turned off for the transaction.

        NOTE(review): this submits directly to the writer executor and does not
        touch ``writers``/``write_lock``, so read-only queries are not held back
        while it runs.
        """
        return asyncio.get_event_loop().run_in_executor(
            self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
        )

    def __run_transaction_with_foreign_keys_disabled(self,
                                                     fun: Callable[[sqlite3.Connection, Any, Any], Any],
                                                     args, kwargs):
        """Disable foreign keys, run *fun* in a transaction, then re-enable them.

        :raises sqlite3.IntegrityError: if foreign keys are already disabled.
        """
        foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone()
        if not foreign_keys_enabled:
            raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
        try:
            # must happen outside a transaction: the pragma is a no-op inside one
            self.writer_connection.execute('pragma foreign_keys=off').fetchone()
            return self.__run_transaction(fun, *args, **kwargs)
        finally:
            self.writer_connection.execute('pragma foreign_keys=on').fetchone()
|
|
|
|
|
|
# Constraint-key suffix -> SQL comparison operator. None of these suffixes is
# a suffix of another (or of the __in/__is_null/group suffixes), so the order
# in which they are tested does not matter.
_COMPARISON_SUFFIXES = (
    ('__not', '!='),
    ('__lt', '<'),
    ('__lte', '<='),
    ('__gt', '>'),
    ('__gte', '>='),
    ('__like', 'LIKE'),
    ('__not_like', 'NOT LIKE'),
)


def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
    """Translate a constraints mapping into an SQL condition plus bound values.

    Each key names a column ('.' is allowed; it becomes '_' in parameter
    names), optionally followed by an operator suffix:

    * ``__not``, ``__lt``, ``__lte``, ``__gt``, ``__gte``, ``__like``,
      ``__not_like``: binary comparison against a named parameter
      (no suffix means ``=``);
    * ``__is_null`` / ``__is_not_null``: the value is ignored;
    * ``__in`` / ``__not_in``: the value is a list/set/tuple of bound values
      or a raw SQL string (e.g. a sub-select). A falsy value adds no condition;
    * ``__any`` / ``__or`` and ``__and``: the value is a nested constraints
      mapping joined with OR / AND and wrapped in parentheses.

    A ``#tag`` suffix on the key (default tag ``'0'``) lets the same key appear
    several times. An empty key inserts its value verbatim as SQL. A key
    starting with ``$`` only contributes its value to the returned parameters.
    ``$$`` strips both characters from the column but keeps one ``$`` in the
    parameter name.

    :param constraints: mapping of constraint key to value.
    :param joiner: text placed between the individual conditions.
    :param prepend_key: prefix for generated parameter names; nested groups use
        it so that identical keys in sibling groups do not collide.
    :return: ``(where_sql, values)``; ``where_sql`` is ``''`` when empty.
    :raises ValueError: when an ``__in``/``__not_in`` value is not a
        list/set/tuple/str.
    """
    sql, values = [], {}
    for key, constraint in constraints.items():
        tag = '0'
        head, hash_sign, explicit_tag = key.partition('#')
        if hash_sign:
            key, tag = head, explicit_tag
        col, op, key = key, '=', key.replace('.', '_')
        if not key:
            # raw SQL fragment supplied by the caller
            sql.append(constraint)
            continue
        if key.startswith('$$'):
            col, key = col[2:], key[1:]
        elif key.startswith('$'):
            # value for a placeholder the caller wrote into the SQL directly
            values[key] = constraint
            continue

        # Every generated parameter name starts with prepend_key, including
        # those of __in lists and of nested groups, so sibling groups never
        # overwrite each other's values.
        param = f'{prepend_key}{key}{tag}'

        if key.endswith('__is_null'):
            sql.append(f"{col[:-len('__is_null')]} IS NULL")
            continue
        if key.endswith('__is_not_null'):
            sql.append(f"{col[:-len('__is_not_null')]} IS NOT NULL")
            continue

        if key.endswith('__in') or key.endswith('__not_in'):
            if key.endswith('__in'):
                col, op, one_val_op = col[:-len('__in')], 'IN', '='
            else:
                col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!='
            if constraint:
                if isinstance(constraint, (list, set, tuple)):
                    if len(constraint) == 1:
                        values[param] = next(iter(constraint))
                        sql.append(f'{col} {one_val_op} :{param}')
                    else:
                        placeholders = []
                        for i, val in enumerate(constraint):
                            values[f'{param}_{i}'] = val
                            placeholders.append(f':{param}_{i}')
                        sql.append(f'{col} {op} ({", ".join(placeholders)})')
                elif isinstance(constraint, str):
                    sql.append(f'{col} {op} ({constraint})')
                else:
                    raise ValueError(f"{col} requires a list, set or string as constraint value.")
            continue

        if key.endswith('__any') or key.endswith('__or') or key.endswith('__and'):
            group_joiner = ' AND ' if key.endswith('__and') else ' OR '
            where, subvalues = constraints_to_sql(constraint, group_joiner, f'{param}_')
            sql.append(f'({where})')
            values.update(subvalues)
            continue

        for suffix, operator in _COMPARISON_SUFFIXES:
            if key.endswith(suffix):
                col, op = col[:-len(suffix)], operator
                break
        sql.append(f'{col} {op} :{param}')
        values[param] = constraint
    return joiner.join(sql) if sql else '', values
|
|
|
|
|
|
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
    """Build a full SQL statement from a SELECT prefix and keyword constraints.

    These keywords are removed from *constraints* before the rest are passed
    to ``constraints_to_sql``:

    * ``limit`` / ``offset``: appended as ``LIMIT`` / ``OFFSET`` when not None;
    * ``order_by``: a string, or a list/tuple of strings joined with ``', '``;
      ignored when falsy;
    * ``group_by``: appended as ``GROUP BY`` when not None;
    * ``accounts``: objects exposing ``public_key.address``. When non-empty,
      their addresses become an ``account__in`` constraint, replacing any
      existing one.

    NOTE: limit, offset, order_by and group_by are inserted into the SQL text
    verbatim and must never come from untrusted input.

    :return: ``(sql, values)`` where *values* holds the named parameters.
    :raises ValueError: if ``order_by`` is neither a string nor a list/tuple.
    """
    sql = [select]
    limit = constraints.pop('limit', None)
    offset = constraints.pop('offset', None)
    order_by = constraints.pop('order_by', None)
    group_by = constraints.pop('group_by', None)

    accounts = constraints.pop('accounts', [])
    if accounts:
        constraints['account__in'] = [a.public_key.address for a in accounts]

    where, values = constraints_to_sql(constraints)
    if where:
        sql.append('WHERE')
        sql.append(where)

    if group_by is not None:
        sql.append(f'GROUP BY {group_by}')

    if order_by:
        sql.append('ORDER BY')
        if isinstance(order_by, str):
            sql.append(order_by)
        elif isinstance(order_by, (list, tuple)):
            sql.append(', '.join(order_by))
        else:
            raise ValueError("order_by must be string or list")

    if limit is not None:
        sql.append(f'LIMIT {limit}')

    if offset is not None:
        sql.append(f'OFFSET {offset}')

    return ' '.join(sql), values
|
|
|
|
|
|
def interpolate(sql, values):
    """Return *sql* with each ``:name`` placeholder replaced by a literal of ``values[name]``.

    Literals: bytes become ``X'<hex>'``; str becomes a single-quoted literal
    with embedded quotes doubled; None becomes ``NULL``; anything else uses
    ``str()``. Substitution is a single pass, so text inside an inserted value
    is never treated as a placeholder. Longer names are tried first, so
    ``:k_10`` is not read as ``:k_1`` followed by ``0``.

    :param sql: SQL text containing named placeholders.
    :param values: mapping of placeholder name (without ``:``) to value.
    :return: the SQL with all known placeholders substituted.
    """
    import re  # local import: only this helper needs regular expressions

    if not values:
        return sql

    def render(value):
        if isinstance(value, bytes):
            return f"X'{hexlify(value).decode()}'"
        if isinstance(value, str):
            return "'" + value.replace("'", "''") + "'"
        if value is None:
            return 'NULL'
        return str(value)

    names = sorted(values, key=len, reverse=True)
    pattern = re.compile(':(' + '|'.join(re.escape(name) for name in names) + ')')
    return pattern.sub(lambda match: render(values[match.group(1)]), sql)
|
|
|
|
|
|
def constrain_single_or_list(constraints, column, value, convert=lambda x: x):
    """Add an equality or ``__in`` constraint on *column* to *constraints*.

    :param constraints: dict of constraints; updated in place and returned.
    :param column: constraint key for the column.
    :param value: None (nothing is added), a single value, or a list/tuple of
        values. A one-element sequence becomes a plain equality constraint, a
        longer one an ``<column>__in`` list, and an empty one adds nothing.
    :param convert: applied to every value before it is stored.
    :return: *constraints*.
    """
    if value is None:
        return constraints
    # tuples are treated like lists instead of being bound as one scalar value
    if isinstance(value, (list, tuple)):
        converted = [convert(v) for v in value]
        if len(converted) == 1:
            constraints[column] = converted[0]
        elif converted:
            constraints[f"{column}__in"] = converted
    else:
        constraints[column] = convert(value)
    return constraints
|
|
|
|
|
|
class SQLiteMixin:
    """Base for objects backed by an ``AIOSQLite`` database with a versioned schema.

    Subclasses define ``CREATE_TABLES_QUERY`` and optionally ``SCHEMA_VERSION``.
    When ``SCHEMA_VERSION`` is set and the stored version differs from it,
    ``open`` drops every user table before recreating the schema.
    """

    SCHEMA_VERSION: Optional[str] = None
    CREATE_TABLES_QUERY: str
    # NOTE(review): not referenced in this class; presumably an upper bound on
    # bound parameters per statement for subclasses (SQLite's historical
    # default limit is 999). Confirm against its users.
    MAX_QUERY_VARIABLES = 900

    CREATE_VERSION_TABLE = """
        create table if not exists version (
            version text
        );
    """

    def __init__(self, path):
        self._db_path = path
        self.db: Optional[AIOSQLite] = None
        self.ledger = None

    async def open(self):
        """Connect to the database and make sure the schema is in place.

        If ``SCHEMA_VERSION`` is set and the stored version matches it, this
        returns right after connecting. Otherwise all user tables are dropped,
        the version row is written, and ``CREATE_TABLES_QUERY`` is executed.
        Without ``SCHEMA_VERSION`` only ``CREATE_TABLES_QUERY`` is executed.
        """
        log.info("connecting to database: %s", self._db_path)
        self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
        if self.SCHEMA_VERSION:
            tables = [
                row[0] for row in await self.db.execute_fetchall(
                    "SELECT name FROM sqlite_master WHERE type='table';"
                )
                # "sqlite_" names are SQLite-internal tables (e.g. sqlite_sequence,
                # created for AUTOINCREMENT) that cannot be dropped
                if not row[0].startswith('sqlite_')
            ]
            if tables:
                if 'version' in tables:
                    version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
                    if version == (self.SCHEMA_VERSION,):
                        return
                # quote identifiers so unusual table names cannot break the script
                await self.db.executescript('\n'.join(
                    'DROP TABLE "{}";'.format(table.replace('"', '""')) for table in tables
                ))
            await self.db.execute(self.CREATE_VERSION_TABLE)
            await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
        await self.db.executescript(self.CREATE_TABLES_QUERY)

    async def close(self):
        """Close the database if ``open`` created one; otherwise do nothing."""
        if self.db is not None:
            await self.db.close()

    @staticmethod
    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
                    replace: bool = False) -> Tuple[str, List]:
        """Build a parameterized INSERT for *data* (column -> value).

        ``replace`` takes precedence over ``ignore_duplicate``. Table and column
        names are inserted verbatim, so they must be trusted.

        :return: ``(sql, values)`` with ``?`` placeholders in column order.
        """
        columns, values = list(data), list(data.values())
        if replace:
            policy = " OR REPLACE"
        elif ignore_duplicate:
            policy = " OR IGNORE"
        else:
            policy = ""
        sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
            policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
        )
        return sql, values

    @staticmethod
    def _update_sql(table: str, data: dict, where: str,
                    constraints: Union[list, tuple]) -> Tuple[str, list]:
        """Build a parameterized UPDATE setting *data* on rows matching *where*.

        *where* may contain ``?`` placeholders; *constraints* supplies their
        values and is appended after the SET values.

        :return: ``(sql, values)``.
        """
        values = list(data.values())
        values.extend(constraints)
        sql = "UPDATE {} SET {} WHERE {}".format(
            table, ', '.join(f"{column} = ?" for column in data), where
        )
        return sql, values
|
|
|
|
|
|
def dict_row_factory(cursor, row):
    """sqlite3 row factory returning each row as a ``{column_name: value}`` dict.

    When several result columns share a name, the last one wins.
    """
    return {column[0]: row[index] for index, column in enumerate(cursor.description)}
|