From 4c709aad1a2ddd329225851f0ab4f83e98f5d2e1 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Wed, 25 Mar 2020 03:54:40 -0300 Subject: [PATCH] sqlite3->apsw wip --- lbry/extras/daemon/storage.py | 3 ++ lbry/wallet/database.py | 78 ++++++++++++++++++------------ setup.py | 1 + tests/unit/wallet/test_database.py | 24 ++++----- 4 files changed, 63 insertions(+), 43 deletions(-) diff --git a/lbry/extras/daemon/storage.py b/lbry/extras/daemon/storage.py index 11a61e45e..f1a5ad908 100644 --- a/lbry/extras/daemon/storage.py +++ b/lbry/extras/daemon/storage.py @@ -344,6 +344,9 @@ class SQLiteStorage(SQLiteMixin): self.loop = loop or asyncio.get_event_loop() self.time_getter = time_getter or time.time + async def open(self): + await super().open() + async def run_and_return_one_or_none(self, query, *args): for row in await self.db.execute_fetchall(query, args): if len(row) == 1: diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py index 218e54cc5..e6e11c1a3 100644 --- a/lbry/wallet/database.py +++ b/lbry/wallet/database.py @@ -1,7 +1,9 @@ import os import logging import asyncio -import sqlite3 +from collections import namedtuple + +import apsw import platform from binascii import hexlify from dataclasses import dataclass @@ -16,21 +18,25 @@ from .constants import TXO_TYPES, CLAIM_TYPES log = logging.getLogger(__name__) -sqlite3.enable_callback_tracebacks(True) @dataclass class ReaderProcessState: - cursor: sqlite3.Cursor + cursor: apsw.Connection reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context') -def initializer(path): - db = sqlite3.connect(path) - db.row_factory = dict_row_factory - db.executescript("pragma journal_mode=WAL;") +def initializer(path, reader=True): + db = apsw.Connection( + path, + flags=( + apsw.SQLITE_OPEN_READONLY if reader else apsw.SQLITE_OPEN_READWRITE | + apsw.SQLITE_OPEN_URI + ) + ) + db.cursor().execute("pragma journal_mode=WAL;") reader = ReaderProcessState(db.cursor()) reader_context.set(reader) @@ -65,7 +71,7 @@ class AIOSQLite: def __init__(self): # has to be single threaded as there is no mapping of thread:connection self.writer_executor = ThreadPoolExecutor(max_workers=1) - self.writer_connection: Optional[sqlite3.Connection] = None + self.writer_connection: Optional[apsw.Connection] = None self._closing = False self.query_count = 0 self.write_lock = asyncio.Lock() @@ -74,11 +80,19 @@ class AIOSQLite: @classmethod async def connect(cls, path: Union[bytes, str], *args, **kwargs): - sqlite3.enable_callback_tracebacks(True) db = cls() def _connect_writer(): - db.writer_connection = sqlite3.connect(path, *args, **kwargs) + db.writer_connection = apsw.Connection( + path, + flags=( + apsw.SQLITE_OPEN_READWRITE | + apsw.SQLITE_OPEN_CREATE | + apsw.SQLITE_OPEN_URI + ) + ) + db.writer_connection.cursor().execute('pragma journal_mode=WAL;').fetchone() + db.writer_connection.cursor().execute('pragma foreign_keys=on').fetchone() readers = max(os.cpu_count() - 2, 2) db.reader_executor = ReaderExecutorClass( @@ -104,7 +118,7 @@ class AIOSQLite: return self.run(lambda conn: conn.executemany(sql, params).fetchall()) def executescript(self, script: str) -> Awaitable: - return self.run(lambda conn: conn.executescript(script)) + return self.run(lambda conn: conn.execute(script).fetchall()) async def _execute_fetch(self, sql: str, parameters: Iterable = None, read_only=False, fetch_all: bool = False) -> List[dict]: @@ -128,9 +142,9 @@ class AIOSQLite: read_only=False) -> List[dict]: return await self._execute_fetch(sql, parameters, read_only, fetch_all=False) - def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: + def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[apsw.Connection]: parameters = parameters if parameters is not None else [] - return self.run(lambda conn: conn.execute(sql, parameters)) + return self.run(lambda conn: conn.execute(sql, parameters).fetchall()) async def run(self, fun, *args, **kwargs): self.writers += 1 @@ -145,16 +159,16 @@ class AIOSQLite: if not self.writers: self.read_ready.set() - def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): - self.writer_connection.execute('begin') + def __run_transaction(self, fun: Callable[[apsw.Connection, Any, Any], Any], *args, **kwargs): + self.writer_connection.cursor().execute('begin;') try: self.query_count += 1 - result = fun(self.writer_connection, *args, **kwargs) # type: ignore - self.writer_connection.commit() + result = fun(self.writer_connection.cursor(), *args, **kwargs) # type: ignore + self.writer_connection.cursor().execute('commit;') return result except (Exception, OSError) as e: log.exception('Error running transaction:', exc_info=e) - self.writer_connection.rollback() + self.writer_connection.cursor().execute('rollback;') log.warning("rolled back") raise @@ -164,16 +178,16 @@ class AIOSQLite: ) def __run_transaction_with_foreign_keys_disabled(self, - fun: Callable[[sqlite3.Connection, Any, Any], Any], + fun: Callable[[apsw.Connection, Any, Any], Any], args, kwargs): - foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone() + foreign_keys_enabled, = self.writer_connection.cursor().execute("pragma foreign_keys").fetchone() if not foreign_keys_enabled: - raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead") + raise apsw.Error("foreign keys are disabled, use `AIOSQLite.run` instead") try: - self.writer_connection.execute('pragma foreign_keys=off').fetchone() + self.writer_connection.cursor().execute('pragma foreign_keys=off').fetchone() return self.__run_transaction(fun, *args, **kwargs) finally: - self.writer_connection.execute('pragma foreign_keys=on').fetchone() + self.writer_connection.cursor().execute('pragma foreign_keys=on').fetchone() def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): @@ -327,7 +341,7 @@ class SQLiteMixin: async def open(self): log.info("connecting to database: %s", self._db_path) - self.db = await AIOSQLite.connect(self._db_path, isolation_level=None) + self.db = await AIOSQLite.connect(self._db_path) if self.SCHEMA_VERSION: tables = [t[0] for t in await self.db.execute_fetchall( "SELECT name FROM sqlite_master WHERE type='table';" @@ -390,7 +404,6 @@ class Database(SQLiteMixin): SCHEMA_VERSION = "1.2" PRAGMAS = """ - pragma journal_mode=WAL; """ CREATE_ACCOUNT_TABLE = """ @@ -475,7 +488,10 @@ class Database(SQLiteMixin): async def open(self): await super().open() - self.db.writer_connection.row_factory = dict_row_factory + def exec_factory(cursor, statement, bindings): + cursor.setrowtrace(dict_row_factory) + return True + self.db.writer_connection.setexectrace(exec_factory) def txo_to_row(self, tx, txo): row = { @@ -484,7 +500,7 @@ class Database(SQLiteMixin): 'address': txo.get_address(self.ledger), 'position': txo.position, 'amount': txo.amount, - 'script': sqlite3.Binary(txo.script.source) + 'script': txo.script.source } if txo.is_claim: if txo.can_decode_claim: @@ -510,7 +526,7 @@ class Database(SQLiteMixin): def tx_to_row(tx): row = { 'txid': tx.id, - 'raw': sqlite3.Binary(tx.raw), + 'raw': tx.raw, 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified @@ -529,7 +545,7 @@ class Database(SQLiteMixin): 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified }, 'txid = ?', (tx.id,))) - def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash): + def _transaction_io(self, conn: apsw.Connection, tx: Transaction, address, txhash): conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)).fetchall() is_my_input = False @@ -933,8 +949,8 @@ class Database(SQLiteMixin): "(account, address, chain, pubkey, chain_code, n, depth) values " "(?, ?, ?, ?, ?, ?, ?)", (( account.id, k.address, chain, - sqlite3.Binary(k.pubkey_bytes), - sqlite3.Binary(k.chain_code), + k.pubkey_bytes, + k.chain_code, k.n, k.depth ) for k in pubkeys) ) diff --git a/setup.py b/setup.py index e8fa1b759..ad5fb130d 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ setup( ], }, install_requires=[ + 'apsw==3.30.1.post1', 'aiohttp==3.5.4', 'aioupnp==0.0.17', 'appdirs==1.4.3', diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py index 7ceed23d7..5998e4bb1 100644 --- a/tests/unit/wallet/test_database.py +++ b/tests/unit/wallet/test_database.py @@ -1,7 +1,7 @@ import sys import os import unittest -import sqlite3 +import apsw import tempfile import asyncio from concurrent.futures.thread import ThreadPoolExecutor @@ -35,7 +35,7 @@ class TestAIOSQLite(AsyncioTestCase): async def test_foreign_keys_integrity_error(self): self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) - with self.assertRaises(sqlite3.IntegrityError): + with self.assertRaises(apsw.ConstraintError): await self.db.run(self.delete_item) self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) @@ -52,7 +52,7 @@ class TestAIOSQLite(AsyncioTestCase): async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self): await self.db.executescript("pragma foreign_keys=off;") self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) - with self.assertRaises(sqlite3.IntegrityError): + with self.assertRaises(apsw.ConstraintError): await self.db.run_with_foreign_keys_disabled(self.delete_item) self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) @@ -401,25 +401,25 @@ class TestUpgrade(AsyncioTestCase): os.remove(self.path) def get_version(self): - with sqlite3.connect(self.path) as conn: + with apsw.Connection(self.path) as conn: versions = conn.execute('select version from version').fetchall() assert len(versions) == 1 return versions[0][0] def get_tables(self): - with sqlite3.connect(self.path) as conn: + with apsw.Connection(self.path) as conn: sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;" return [col[0] for col in conn.execute(sql).fetchall()] def add_address(self, address): - with sqlite3.connect(self.path) as conn: + with apsw.Connection(self.path) as conn: conn.execute(""" INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth) VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0) """, (address,)) def get_addresses(self): - with sqlite3.connect(self.path) as conn: + with apsw.Connection(self.path) as conn: sql = "SELECT address FROM account_address ORDER BY address;" return [col[0] for col in conn.execute(sql).fetchall()] @@ -472,7 +472,7 @@ class TestSQLiteRace(AsyncioTestCase): max_misuse_attempts = 40000 def setup_db(self): - self.db = sqlite3.connect(":memory:", isolation_level=None) + self.db = apsw.Connection(":memory:") self.db.executescript( "create table test1 (id text primary key not null, val text);\n" + "create table test2 (id text primary key not null, val text);\n" + @@ -505,7 +505,7 @@ class TestSQLiteRace(AsyncioTestCase): [(unsupported_type(1), ), (unsupported_type(2), )] ) self.assertTrue(False) - except sqlite3.InterfaceError as err: + except apsw.ConstraintError as err: self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.") async def test_unhandled_sqlite_misuse(self): @@ -529,13 +529,13 @@ class TestSQLiteRace(AsyncioTestCase): ) attempts += 1 await asyncio.gather(f1, f2) - print(f"\nsqlite3 {sqlite3.version}/python {python_version} " + print(f"\nsqlite3 {apsw.sqlitelibversion()}/python {python_version} " f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition") self.assertTrue(False, 'this test failing means either the sqlite race conditions ' 'have been fixed in cpython or the test max_attempts needs to be increased') - except sqlite3.InterfaceError as err: + except apsw.ConstraintError as err: self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.") - print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE " + print(f"\nsqlite3 {apsw.sqlitelibversion()}/python {python_version} raised SQLITE_MISUSE " f"after {attempts} attempts of the race condition") @unittest.SkipTest