diff --git a/tests/unit/wallet/server/reader.py b/tests/unit/wallet/server/reader.py index aef0a2369..78aa3869f 100644 --- a/tests/unit/wallet/server/reader.py +++ b/tests/unit/wallet/server/reader.py @@ -1,6 +1,6 @@ import time import struct -import apsw +import sqlite3 import logging from operator import itemgetter from typing import Tuple, List, Dict, Union, Type, Optional @@ -21,19 +21,20 @@ from lbry.wallet import Ledger, RegTestLedger from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS, INDEXED_LANGUAGES -class SQLiteOperationalError(apsw.Error): +class SQLiteOperationalError(sqlite3.OperationalError): def __init__(self, metrics): super().__init__('sqlite query errored') self.metrics = metrics -class SQLiteInterruptedError(apsw.InterruptError): +class SQLiteInterruptedError(sqlite3.OperationalError): def __init__(self, metrics): super().__init__('sqlite query interrupted') self.metrics = metrics ATTRIBUTE_ARRAY_MAX_LENGTH = 100 +sqlite3.enable_callback_tracebacks(True) INTEGER_PARAMS = { 'height', 'creation_height', 'activation_height', 'expiration_height', @@ -63,7 +64,7 @@ ORDER_FIELDS = { @dataclass class ReaderState: - db: apsw.Connection + db: sqlite3.Connection stack: List[List] metrics: Dict is_tracking_metrics: bool @@ -90,7 +91,7 @@ class ReaderState: self.db.interrupt() return - self.db.setprogresshandler(interruptor, 100) + self.db.set_progress_handler(interruptor, 100) def get_resolve_censor(self) -> Censor: return Censor(Censor.RESOLVE) @@ -105,13 +106,13 @@ ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx') def row_factory(cursor, row): return { k[0]: (set(row[i].split(',')) if k[0] == 'tags' else row[i]) - for i, k in enumerate(cursor.getdescription()) + for i, k in enumerate(cursor.description) } def initializer(log, _path, _ledger_name, query_timeout, _measure=False, block_and_filter=None): - db = apsw.Connection(_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI) - db.setrowtrace(row_factory) + db = sqlite3.connect(_path, isolation_level=None, uri=True) + db.row_factory = row_factory if block_and_filter: blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter else: @@ -184,31 +185,12 @@ def execute_query(sql, values, row_offset: int, row_limit: int, censor: Censor) context = ctx.get() context.set_query_timeout() try: - c = context.db.cursor() - def row_filter(cursor, row): - nonlocal row_offset - row = row_factory(cursor, row) - if len(row) > 1 and censor.censor(row): - return - if row_offset: - row_offset -= 1 - return - return row - c.setrowtrace(row_filter) - i, rows = 0, [] - for row in c.execute(sql, values): - i += 1 - rows.append(row) - if i >= row_limit: - break - return rows - except apsw.Error as err: + rows = context.db.execute(sql, values).fetchall() + return rows[row_offset:row_limit] + except sqlite3.OperationalError as err: plain_sql = interpolate(sql, values) if context.is_tracking_metrics: context.metrics['execute_query'][-1]['sql'] = plain_sql - if isinstance(err, apsw.InterruptError): - context.log.warning("interrupted slow sqlite query:\n%s", plain_sql) - raise SQLiteInterruptedError(context.metrics) context.log.exception('failed running query', exc_info=err) raise SQLiteOperationalError(context.metrics)