sqlite3->apsw wip

This commit is contained in:
Victor Shyba 2020-03-25 03:54:40 -03:00
parent 121f34fde7
commit 4c709aad1a
4 changed files with 63 additions and 43 deletions

View file

@ -344,6 +344,9 @@ class SQLiteStorage(SQLiteMixin):
self.loop = loop or asyncio.get_event_loop()
self.time_getter = time_getter or time.time
async def open(self):
await super().open()
async def run_and_return_one_or_none(self, query, *args):
for row in await self.db.execute_fetchall(query, args):
if len(row) == 1:

View file

@ -1,7 +1,9 @@
import os
import logging
import asyncio
import sqlite3
from collections import namedtuple
import apsw
import platform
from binascii import hexlify
from dataclasses import dataclass
@ -16,21 +18,25 @@ from .constants import TXO_TYPES, CLAIM_TYPES
log = logging.getLogger(__name__)
sqlite3.enable_callback_tracebacks(True)
@dataclass
class ReaderProcessState:
cursor: sqlite3.Cursor
cursor: apsw.Connection
reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context')
def initializer(path):
db = sqlite3.connect(path)
db.row_factory = dict_row_factory
db.executescript("pragma journal_mode=WAL;")
def initializer(path, reader=True):
db = apsw.Connection(
path,
flags=(
apsw.SQLITE_OPEN_READONLY if reader else apsw.SQLITE_OPEN_READWRITE |
apsw.SQLITE_OPEN_URI
)
)
db.cursor().execute("pragma journal_mode=WAL;")
reader = ReaderProcessState(db.cursor())
reader_context.set(reader)
@ -65,7 +71,7 @@ class AIOSQLite:
def __init__(self):
# has to be single threaded as there is no mapping of thread:connection
self.writer_executor = ThreadPoolExecutor(max_workers=1)
self.writer_connection: Optional[sqlite3.Connection] = None
self.writer_connection: Optional[apsw.Connection] = None
self._closing = False
self.query_count = 0
self.write_lock = asyncio.Lock()
@ -74,11 +80,19 @@ class AIOSQLite:
@classmethod
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
sqlite3.enable_callback_tracebacks(True)
db = cls()
def _connect_writer():
db.writer_connection = sqlite3.connect(path, *args, **kwargs)
db.writer_connection = apsw.Connection(
path,
flags=(
apsw.SQLITE_OPEN_READWRITE |
apsw.SQLITE_OPEN_CREATE |
apsw.SQLITE_OPEN_URI
)
)
db.writer_connection.cursor().execute('pragma journal_mode=WAL;').fetchone()
db.writer_connection.cursor().execute('pragma foreign_keys=on').fetchone()
readers = max(os.cpu_count() - 2, 2)
db.reader_executor = ReaderExecutorClass(
@ -104,7 +118,7 @@ class AIOSQLite:
return self.run(lambda conn: conn.executemany(sql, params).fetchall())
def executescript(self, script: str) -> Awaitable:
return self.run(lambda conn: conn.executescript(script))
return self.run(lambda conn: conn.execute(script).fetchall())
async def _execute_fetch(self, sql: str, parameters: Iterable = None,
read_only=False, fetch_all: bool = False) -> List[dict]:
@ -128,9 +142,9 @@ class AIOSQLite:
read_only=False) -> List[dict]:
return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[apsw.Connection]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters))
return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
async def run(self, fun, *args, **kwargs):
self.writers += 1
@ -145,16 +159,16 @@ class AIOSQLite:
if not self.writers:
self.read_ready.set()
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
self.writer_connection.execute('begin')
def __run_transaction(self, fun: Callable[[apsw.Connection, Any, Any], Any], *args, **kwargs):
self.writer_connection.cursor().execute('begin;')
try:
self.query_count += 1
result = fun(self.writer_connection, *args, **kwargs) # type: ignore
self.writer_connection.commit()
result = fun(self.writer_connection.cursor(), *args, **kwargs) # type: ignore
self.writer_connection.cursor().execute('commit;')
return result
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
self.writer_connection.rollback()
self.writer_connection.cursor().execute('rollback;')
log.warning("rolled back")
raise
@ -164,16 +178,16 @@ class AIOSQLite:
)
def __run_transaction_with_foreign_keys_disabled(self,
fun: Callable[[sqlite3.Connection, Any, Any], Any],
fun: Callable[[apsw.Connection, Any, Any], Any],
args, kwargs):
foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone()
foreign_keys_enabled, = self.writer_connection.cursor().execute("pragma foreign_keys").fetchone()
if not foreign_keys_enabled:
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
raise apsw.Error("foreign keys are disabled, use `AIOSQLite.run` instead")
try:
self.writer_connection.execute('pragma foreign_keys=off').fetchone()
self.writer_connection.cursor().execute('pragma foreign_keys=off').fetchone()
return self.__run_transaction(fun, *args, **kwargs)
finally:
self.writer_connection.execute('pragma foreign_keys=on').fetchone()
self.writer_connection.cursor().execute('pragma foreign_keys=on').fetchone()
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
@ -327,7 +341,7 @@ class SQLiteMixin:
async def open(self):
log.info("connecting to database: %s", self._db_path)
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
self.db = await AIOSQLite.connect(self._db_path)
if self.SCHEMA_VERSION:
tables = [t[0] for t in await self.db.execute_fetchall(
"SELECT name FROM sqlite_master WHERE type='table';"
@ -390,7 +404,6 @@ class Database(SQLiteMixin):
SCHEMA_VERSION = "1.2"
PRAGMAS = """
pragma journal_mode=WAL;
"""
CREATE_ACCOUNT_TABLE = """
@ -475,7 +488,10 @@ class Database(SQLiteMixin):
async def open(self):
await super().open()
self.db.writer_connection.row_factory = dict_row_factory
def exec_factory(cursor, statement, bindings):
cursor.setrowtrace(dict_row_factory)
return True
self.db.writer_connection.setexectrace(exec_factory)
def txo_to_row(self, tx, txo):
row = {
@ -484,7 +500,7 @@ class Database(SQLiteMixin):
'address': txo.get_address(self.ledger),
'position': txo.position,
'amount': txo.amount,
'script': sqlite3.Binary(txo.script.source)
'script': txo.script.source
}
if txo.is_claim:
if txo.can_decode_claim:
@ -510,7 +526,7 @@ class Database(SQLiteMixin):
def tx_to_row(tx):
row = {
'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'raw': tx.raw,
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified
@ -529,7 +545,7 @@ class Database(SQLiteMixin):
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash):
def _transaction_io(self, conn: apsw.Connection, tx: Transaction, address, txhash):
conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)).fetchall()
is_my_input = False
@ -933,8 +949,8 @@ class Database(SQLiteMixin):
"(account, address, chain, pubkey, chain_code, n, depth) values "
"(?, ?, ?, ?, ?, ?, ?)", ((
account.id, k.address, chain,
sqlite3.Binary(k.pubkey_bytes),
sqlite3.Binary(k.chain_code),
k.pubkey_bytes,
k.chain_code,
k.n, k.depth
) for k in pubkeys)
)

View file

@ -33,6 +33,7 @@ setup(
],
},
install_requires=[
'apsw==3.30.1.post1',
'aiohttp==3.5.4',
'aioupnp==0.0.17',
'appdirs==1.4.3',

View file

@ -1,7 +1,7 @@
import sys
import os
import unittest
import sqlite3
import apsw
import tempfile
import asyncio
from concurrent.futures.thread import ThreadPoolExecutor
@ -35,7 +35,7 @@ class TestAIOSQLite(AsyncioTestCase):
async def test_foreign_keys_integrity_error(self):
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
with self.assertRaises(sqlite3.IntegrityError):
with self.assertRaises(apsw.ConstraintError):
await self.db.run(self.delete_item)
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
@ -52,7 +52,7 @@ class TestAIOSQLite(AsyncioTestCase):
async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self):
await self.db.executescript("pragma foreign_keys=off;")
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
with self.assertRaises(sqlite3.IntegrityError):
with self.assertRaises(apsw.ConstraintError):
await self.db.run_with_foreign_keys_disabled(self.delete_item)
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
@ -401,25 +401,25 @@ class TestUpgrade(AsyncioTestCase):
os.remove(self.path)
def get_version(self):
with sqlite3.connect(self.path) as conn:
with apsw.Connection(self.path) as conn:
versions = conn.execute('select version from version').fetchall()
assert len(versions) == 1
return versions[0][0]
def get_tables(self):
with sqlite3.connect(self.path) as conn:
with apsw.Connection(self.path) as conn:
sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
return [col[0] for col in conn.execute(sql).fetchall()]
def add_address(self, address):
with sqlite3.connect(self.path) as conn:
with apsw.Connection(self.path) as conn:
conn.execute("""
INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth)
VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0)
""", (address,))
def get_addresses(self):
with sqlite3.connect(self.path) as conn:
with apsw.Connection(self.path) as conn:
sql = "SELECT address FROM account_address ORDER BY address;"
return [col[0] for col in conn.execute(sql).fetchall()]
@ -472,7 +472,7 @@ class TestSQLiteRace(AsyncioTestCase):
max_misuse_attempts = 40000
def setup_db(self):
self.db = sqlite3.connect(":memory:", isolation_level=None)
self.db = apsw.Connection(":memory:")
self.db.executescript(
"create table test1 (id text primary key not null, val text);\n" +
"create table test2 (id text primary key not null, val text);\n" +
@ -505,7 +505,7 @@ class TestSQLiteRace(AsyncioTestCase):
[(unsupported_type(1), ), (unsupported_type(2), )]
)
self.assertTrue(False)
except sqlite3.InterfaceError as err:
except apsw.ConstraintError as err:
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
async def test_unhandled_sqlite_misuse(self):
@ -529,13 +529,13 @@ class TestSQLiteRace(AsyncioTestCase):
)
attempts += 1
await asyncio.gather(f1, f2)
print(f"\nsqlite3 {sqlite3.version}/python {python_version} "
print(f"\nsqlite3 {apsw.sqlitelibversion()}/python {python_version} "
f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition")
self.assertTrue(False, 'this test failing means either the sqlite race conditions '
'have been fixed in cpython or the test max_attempts needs to be increased')
except sqlite3.InterfaceError as err:
except apsw.ConstraintError as err:
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE "
print(f"\nsqlite3 {apsw.sqlitelibversion()}/python {python_version} raised SQLITE_MISUSE "
f"after {attempts} attempts of the race condition")
@unittest.SkipTest