sqlite3->apsw wip

This commit is contained in:
Victor Shyba 2020-03-25 03:54:40 -03:00
parent 121f34fde7
commit 4c709aad1a
4 changed files with 63 additions and 43 deletions

View file

@ -344,6 +344,9 @@ class SQLiteStorage(SQLiteMixin):
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
self.time_getter = time_getter or time.time self.time_getter = time_getter or time.time
async def open(self):
await super().open()
async def run_and_return_one_or_none(self, query, *args): async def run_and_return_one_or_none(self, query, *args):
for row in await self.db.execute_fetchall(query, args): for row in await self.db.execute_fetchall(query, args):
if len(row) == 1: if len(row) == 1:

View file

@ -1,7 +1,9 @@
import os import os
import logging import logging
import asyncio import asyncio
import sqlite3 from collections import namedtuple
import apsw
import platform import platform
from binascii import hexlify from binascii import hexlify
from dataclasses import dataclass from dataclasses import dataclass
@ -16,21 +18,25 @@ from .constants import TXO_TYPES, CLAIM_TYPES
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
sqlite3.enable_callback_tracebacks(True)
@dataclass @dataclass
class ReaderProcessState: class ReaderProcessState:
cursor: sqlite3.Cursor cursor: apsw.Connection
reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context') reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context')
def initializer(path): def initializer(path, reader=True):
db = sqlite3.connect(path) db = apsw.Connection(
db.row_factory = dict_row_factory path,
db.executescript("pragma journal_mode=WAL;") flags=(
apsw.SQLITE_OPEN_READONLY if reader else apsw.SQLITE_OPEN_READWRITE |
apsw.SQLITE_OPEN_URI
)
)
db.cursor().execute("pragma journal_mode=WAL;")
reader = ReaderProcessState(db.cursor()) reader = ReaderProcessState(db.cursor())
reader_context.set(reader) reader_context.set(reader)
@ -65,7 +71,7 @@ class AIOSQLite:
def __init__(self): def __init__(self):
# has to be single threaded as there is no mapping of thread:connection # has to be single threaded as there is no mapping of thread:connection
self.writer_executor = ThreadPoolExecutor(max_workers=1) self.writer_executor = ThreadPoolExecutor(max_workers=1)
self.writer_connection: Optional[sqlite3.Connection] = None self.writer_connection: Optional[apsw.Connection] = None
self._closing = False self._closing = False
self.query_count = 0 self.query_count = 0
self.write_lock = asyncio.Lock() self.write_lock = asyncio.Lock()
@ -74,11 +80,19 @@ class AIOSQLite:
@classmethod @classmethod
async def connect(cls, path: Union[bytes, str], *args, **kwargs): async def connect(cls, path: Union[bytes, str], *args, **kwargs):
sqlite3.enable_callback_tracebacks(True)
db = cls() db = cls()
def _connect_writer(): def _connect_writer():
db.writer_connection = sqlite3.connect(path, *args, **kwargs) db.writer_connection = apsw.Connection(
path,
flags=(
apsw.SQLITE_OPEN_READWRITE |
apsw.SQLITE_OPEN_CREATE |
apsw.SQLITE_OPEN_URI
)
)
db.writer_connection.cursor().execute('pragma journal_mode=WAL;').fetchone()
db.writer_connection.cursor().execute('pragma foreign_keys=on').fetchone()
readers = max(os.cpu_count() - 2, 2) readers = max(os.cpu_count() - 2, 2)
db.reader_executor = ReaderExecutorClass( db.reader_executor = ReaderExecutorClass(
@ -104,7 +118,7 @@ class AIOSQLite:
return self.run(lambda conn: conn.executemany(sql, params).fetchall()) return self.run(lambda conn: conn.executemany(sql, params).fetchall())
def executescript(self, script: str) -> Awaitable: def executescript(self, script: str) -> Awaitable:
return self.run(lambda conn: conn.executescript(script)) return self.run(lambda conn: conn.execute(script).fetchall())
async def _execute_fetch(self, sql: str, parameters: Iterable = None, async def _execute_fetch(self, sql: str, parameters: Iterable = None,
read_only=False, fetch_all: bool = False) -> List[dict]: read_only=False, fetch_all: bool = False) -> List[dict]:
@ -128,9 +142,9 @@ class AIOSQLite:
read_only=False) -> List[dict]: read_only=False) -> List[dict]:
return await self._execute_fetch(sql, parameters, read_only, fetch_all=False) return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[apsw.Connection]:
parameters = parameters if parameters is not None else [] parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters)) return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
async def run(self, fun, *args, **kwargs): async def run(self, fun, *args, **kwargs):
self.writers += 1 self.writers += 1
@ -145,16 +159,16 @@ class AIOSQLite:
if not self.writers: if not self.writers:
self.read_ready.set() self.read_ready.set()
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): def __run_transaction(self, fun: Callable[[apsw.Connection, Any, Any], Any], *args, **kwargs):
self.writer_connection.execute('begin') self.writer_connection.cursor().execute('begin;')
try: try:
self.query_count += 1 self.query_count += 1
result = fun(self.writer_connection, *args, **kwargs) # type: ignore result = fun(self.writer_connection.cursor(), *args, **kwargs) # type: ignore
self.writer_connection.commit() self.writer_connection.cursor().execute('commit;')
return result return result
except (Exception, OSError) as e: except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e) log.exception('Error running transaction:', exc_info=e)
self.writer_connection.rollback() self.writer_connection.cursor().execute('rollback;')
log.warning("rolled back") log.warning("rolled back")
raise raise
@ -164,16 +178,16 @@ class AIOSQLite:
) )
def __run_transaction_with_foreign_keys_disabled(self, def __run_transaction_with_foreign_keys_disabled(self,
fun: Callable[[sqlite3.Connection, Any, Any], Any], fun: Callable[[apsw.Connection, Any, Any], Any],
args, kwargs): args, kwargs):
foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone() foreign_keys_enabled, = self.writer_connection.cursor().execute("pragma foreign_keys").fetchone()
if not foreign_keys_enabled: if not foreign_keys_enabled:
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead") raise apsw.Error("foreign keys are disabled, use `AIOSQLite.run` instead")
try: try:
self.writer_connection.execute('pragma foreign_keys=off').fetchone() self.writer_connection.cursor().execute('pragma foreign_keys=off').fetchone()
return self.__run_transaction(fun, *args, **kwargs) return self.__run_transaction(fun, *args, **kwargs)
finally: finally:
self.writer_connection.execute('pragma foreign_keys=on').fetchone() self.writer_connection.cursor().execute('pragma foreign_keys=on').fetchone()
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
@ -327,7 +341,7 @@ class SQLiteMixin:
async def open(self): async def open(self):
log.info("connecting to database: %s", self._db_path) log.info("connecting to database: %s", self._db_path)
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None) self.db = await AIOSQLite.connect(self._db_path)
if self.SCHEMA_VERSION: if self.SCHEMA_VERSION:
tables = [t[0] for t in await self.db.execute_fetchall( tables = [t[0] for t in await self.db.execute_fetchall(
"SELECT name FROM sqlite_master WHERE type='table';" "SELECT name FROM sqlite_master WHERE type='table';"
@ -390,7 +404,6 @@ class Database(SQLiteMixin):
SCHEMA_VERSION = "1.2" SCHEMA_VERSION = "1.2"
PRAGMAS = """ PRAGMAS = """
pragma journal_mode=WAL;
""" """
CREATE_ACCOUNT_TABLE = """ CREATE_ACCOUNT_TABLE = """
@ -475,7 +488,10 @@ class Database(SQLiteMixin):
async def open(self): async def open(self):
await super().open() await super().open()
self.db.writer_connection.row_factory = dict_row_factory def exec_factory(cursor, statement, bindings):
cursor.setrowtrace(dict_row_factory)
return True
self.db.writer_connection.setexectrace(exec_factory)
def txo_to_row(self, tx, txo): def txo_to_row(self, tx, txo):
row = { row = {
@ -484,7 +500,7 @@ class Database(SQLiteMixin):
'address': txo.get_address(self.ledger), 'address': txo.get_address(self.ledger),
'position': txo.position, 'position': txo.position,
'amount': txo.amount, 'amount': txo.amount,
'script': sqlite3.Binary(txo.script.source) 'script': txo.script.source
} }
if txo.is_claim: if txo.is_claim:
if txo.can_decode_claim: if txo.can_decode_claim:
@ -510,7 +526,7 @@ class Database(SQLiteMixin):
def tx_to_row(tx): def tx_to_row(tx):
row = { row = {
'txid': tx.id, 'txid': tx.id,
'raw': sqlite3.Binary(tx.raw), 'raw': tx.raw,
'height': tx.height, 'height': tx.height,
'position': tx.position, 'position': tx.position,
'is_verified': tx.is_verified 'is_verified': tx.is_verified
@ -529,7 +545,7 @@ class Database(SQLiteMixin):
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,))) }, 'txid = ?', (tx.id,)))
def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash): def _transaction_io(self, conn: apsw.Connection, tx: Transaction, address, txhash):
conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)).fetchall() conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)).fetchall()
is_my_input = False is_my_input = False
@ -933,8 +949,8 @@ class Database(SQLiteMixin):
"(account, address, chain, pubkey, chain_code, n, depth) values " "(account, address, chain, pubkey, chain_code, n, depth) values "
"(?, ?, ?, ?, ?, ?, ?)", (( "(?, ?, ?, ?, ?, ?, ?)", ((
account.id, k.address, chain, account.id, k.address, chain,
sqlite3.Binary(k.pubkey_bytes), k.pubkey_bytes,
sqlite3.Binary(k.chain_code), k.chain_code,
k.n, k.depth k.n, k.depth
) for k in pubkeys) ) for k in pubkeys)
) )

View file

@ -33,6 +33,7 @@ setup(
], ],
}, },
install_requires=[ install_requires=[
'apsw==3.30.1.post1',
'aiohttp==3.5.4', 'aiohttp==3.5.4',
'aioupnp==0.0.17', 'aioupnp==0.0.17',
'appdirs==1.4.3', 'appdirs==1.4.3',

View file

@ -1,7 +1,7 @@
import sys import sys
import os import os
import unittest import unittest
import sqlite3 import apsw
import tempfile import tempfile
import asyncio import asyncio
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
@ -35,7 +35,7 @@ class TestAIOSQLite(AsyncioTestCase):
async def test_foreign_keys_integrity_error(self): async def test_foreign_keys_integrity_error(self):
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
with self.assertRaises(sqlite3.IntegrityError): with self.assertRaises(apsw.ConstraintError):
await self.db.run(self.delete_item) await self.db.run(self.delete_item)
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
@ -52,7 +52,7 @@ class TestAIOSQLite(AsyncioTestCase):
async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self): async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self):
await self.db.executescript("pragma foreign_keys=off;") await self.db.executescript("pragma foreign_keys=off;")
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
with self.assertRaises(sqlite3.IntegrityError): with self.assertRaises(apsw.ConstraintError):
await self.db.run_with_foreign_keys_disabled(self.delete_item) await self.db.run_with_foreign_keys_disabled(self.delete_item)
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
@ -401,25 +401,25 @@ class TestUpgrade(AsyncioTestCase):
os.remove(self.path) os.remove(self.path)
def get_version(self): def get_version(self):
with sqlite3.connect(self.path) as conn: with apsw.Connection(self.path) as conn:
versions = conn.execute('select version from version').fetchall() versions = conn.execute('select version from version').fetchall()
assert len(versions) == 1 assert len(versions) == 1
return versions[0][0] return versions[0][0]
def get_tables(self): def get_tables(self):
with sqlite3.connect(self.path) as conn: with apsw.Connection(self.path) as conn:
sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;" sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
return [col[0] for col in conn.execute(sql).fetchall()] return [col[0] for col in conn.execute(sql).fetchall()]
def add_address(self, address): def add_address(self, address):
with sqlite3.connect(self.path) as conn: with apsw.Connection(self.path) as conn:
conn.execute(""" conn.execute("""
INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth) INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth)
VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0) VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0)
""", (address,)) """, (address,))
def get_addresses(self): def get_addresses(self):
with sqlite3.connect(self.path) as conn: with apsw.Connection(self.path) as conn:
sql = "SELECT address FROM account_address ORDER BY address;" sql = "SELECT address FROM account_address ORDER BY address;"
return [col[0] for col in conn.execute(sql).fetchall()] return [col[0] for col in conn.execute(sql).fetchall()]
@ -472,7 +472,7 @@ class TestSQLiteRace(AsyncioTestCase):
max_misuse_attempts = 40000 max_misuse_attempts = 40000
def setup_db(self): def setup_db(self):
self.db = sqlite3.connect(":memory:", isolation_level=None) self.db = apsw.Connection(":memory:")
self.db.executescript( self.db.executescript(
"create table test1 (id text primary key not null, val text);\n" + "create table test1 (id text primary key not null, val text);\n" +
"create table test2 (id text primary key not null, val text);\n" + "create table test2 (id text primary key not null, val text);\n" +
@ -505,7 +505,7 @@ class TestSQLiteRace(AsyncioTestCase):
[(unsupported_type(1), ), (unsupported_type(2), )] [(unsupported_type(1), ), (unsupported_type(2), )]
) )
self.assertTrue(False) self.assertTrue(False)
except sqlite3.InterfaceError as err: except apsw.ConstraintError as err:
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.") self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
async def test_unhandled_sqlite_misuse(self): async def test_unhandled_sqlite_misuse(self):
@ -529,13 +529,13 @@ class TestSQLiteRace(AsyncioTestCase):
) )
attempts += 1 attempts += 1
await asyncio.gather(f1, f2) await asyncio.gather(f1, f2)
print(f"\nsqlite3 {sqlite3.version}/python {python_version} " print(f"\nsqlite3 {apsw.sqlitelibversion()}/python {python_version} "
f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition") f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition")
self.assertTrue(False, 'this test failing means either the sqlite race conditions ' self.assertTrue(False, 'this test failing means either the sqlite race conditions '
'have been fixed in cpython or the test max_attempts needs to be increased') 'have been fixed in cpython or the test max_attempts needs to be increased')
except sqlite3.InterfaceError as err: except apsw.ConstraintError as err:
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.") self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE " print(f"\nsqlite3 {apsw.sqlitelibversion()}/python {python_version} raised SQLITE_MISUSE "
f"after {attempts} attempts of the race condition") f"after {attempts} attempts of the race condition")
@unittest.SkipTest @unittest.SkipTest