adds a simple sqlite asyncio wrapper
This commit is contained in:
parent
6fe648b9a9
commit
3b021af1bd
1
setup.py
1
setup.py
|
@ -27,7 +27,6 @@ setup(
|
||||||
python_requires='>=3.6',
|
python_requires='>=3.6',
|
||||||
install_requires=(
|
install_requires=(
|
||||||
'aiorpcx',
|
'aiorpcx',
|
||||||
'aiosqlite',
|
|
||||||
'coincurve',
|
'coincurve',
|
||||||
'pbkdf2',
|
'pbkdf2',
|
||||||
'cryptography'
|
'cryptography'
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Tuple, List, Union
|
import asyncio
|
||||||
|
from asyncio import wrap_future
|
||||||
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
|
|
||||||
|
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import aiosqlite
|
|
||||||
|
|
||||||
from torba.hash import TXRefImmutable
|
from torba.hash import TXRefImmutable
|
||||||
from torba.basetransaction import BaseTransaction
|
from torba.basetransaction import BaseTransaction
|
||||||
|
@ -11,6 +14,54 @@ from torba.baseaccount import BaseAccount
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AIOSQLite:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# has to be single threaded as there is no mapping of thread:connection
|
||||||
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
self.connection: sqlite3.Connection = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
|
||||||
|
db = cls()
|
||||||
|
db.connection = await wrap_future(db.executor.submit(sqlite3.connect, path, *args, **kwargs))
|
||||||
|
return db
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
def __close(conn):
|
||||||
|
self.executor.submit(conn.close)
|
||||||
|
self.executor.shutdown(wait=True)
|
||||||
|
conn = self.connection
|
||||||
|
self.connection = None
|
||||||
|
return asyncio._get_running_loop().call_later(0.01, __close, conn)
|
||||||
|
|
||||||
|
def executescript(self, script: str) -> Awaitable:
|
||||||
|
return wrap_future(self.executor.submit(self.connection.executescript, script))
|
||||||
|
|
||||||
|
def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||||
|
parameters = parameters if parameters is not None else []
|
||||||
|
def __fetchall(conn: sqlite3.Connection, *args, **kwargs):
|
||||||
|
return conn.execute(*args, **kwargs).fetchall()
|
||||||
|
return wrap_future(self.executor.submit(__fetchall, self.connection, sql, parameters))
|
||||||
|
|
||||||
|
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
|
||||||
|
parameters = parameters if parameters is not None else []
|
||||||
|
return self.run(lambda conn, sql, parameters: conn.execute(sql, parameters), sql, parameters)
|
||||||
|
|
||||||
|
def run(self, fn: Callable[[sqlite3.Connection, Any], Any], *args, **kwargs) -> Awaitable:
|
||||||
|
return wrap_future(self.executor.submit(self.__run_transaction, fn, *args, **kwargs))
|
||||||
|
|
||||||
|
def __run_transaction(self, fn: Callable[[sqlite3.Connection, Any], Any], *args, **kwargs):
|
||||||
|
self.connection.execute('begin')
|
||||||
|
try:
|
||||||
|
fn(self.connection, *args, **kwargs)
|
||||||
|
except (Exception, OSError):
|
||||||
|
self.connection.rollback()
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
|
||||||
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
||||||
sql, values = [], {}
|
sql, values = [], {}
|
||||||
for key, constraint in constraints.items():
|
for key, constraint in constraints.items():
|
||||||
|
@ -105,13 +156,12 @@ class SQLiteMixin:
|
||||||
|
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self._db_path = path
|
self._db_path = path
|
||||||
self.db: aiosqlite.Connection = None
|
self.db: AIOSQLite = None
|
||||||
self.ledger = None
|
self.ledger = None
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
log.info("connecting to database: %s", self._db_path)
|
log.info("connecting to database: %s", self._db_path)
|
||||||
self.db = aiosqlite.connect(self._db_path, isolation_level=None)
|
self.db = await AIOSQLite.connect(self._db_path)
|
||||||
await self.db.__aenter__()
|
|
||||||
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
|
|
Loading…
Reference in a new issue