update test reader to use plain sqlite

This commit is contained in:
Lex Berezhny 2021-06-15 15:26:28 -04:00
parent 25e16c3565
commit b0371dd33d

View file

@ -1,6 +1,6 @@
import time import time
import struct import struct
import apsw import sqlite3
import logging import logging
from operator import itemgetter from operator import itemgetter
from typing import Tuple, List, Dict, Union, Type, Optional from typing import Tuple, List, Dict, Union, Type, Optional
@ -21,19 +21,20 @@ from lbry.wallet import Ledger, RegTestLedger
from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS, INDEXED_LANGUAGES from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS, INDEXED_LANGUAGES
class SQLiteOperationalError(apsw.Error): class SQLiteOperationalError(sqlite3.OperationalError):
def __init__(self, metrics): def __init__(self, metrics):
super().__init__('sqlite query errored') super().__init__('sqlite query errored')
self.metrics = metrics self.metrics = metrics
class SQLiteInterruptedError(apsw.InterruptError): class SQLiteInterruptedError(sqlite3.OperationalError):
def __init__(self, metrics): def __init__(self, metrics):
super().__init__('sqlite query interrupted') super().__init__('sqlite query interrupted')
self.metrics = metrics self.metrics = metrics
ATTRIBUTE_ARRAY_MAX_LENGTH = 100 ATTRIBUTE_ARRAY_MAX_LENGTH = 100
sqlite3.enable_callback_tracebacks(True)
INTEGER_PARAMS = { INTEGER_PARAMS = {
'height', 'creation_height', 'activation_height', 'expiration_height', 'height', 'creation_height', 'activation_height', 'expiration_height',
@ -63,7 +64,7 @@ ORDER_FIELDS = {
@dataclass @dataclass
class ReaderState: class ReaderState:
db: apsw.Connection db: sqlite3.Connection
stack: List[List] stack: List[List]
metrics: Dict metrics: Dict
is_tracking_metrics: bool is_tracking_metrics: bool
@ -90,7 +91,7 @@ class ReaderState:
self.db.interrupt() self.db.interrupt()
return return
self.db.setprogresshandler(interruptor, 100) self.db.set_progress_handler(interruptor, 100)
def get_resolve_censor(self) -> Censor: def get_resolve_censor(self) -> Censor:
return Censor(Censor.RESOLVE) return Censor(Censor.RESOLVE)
@ -105,13 +106,13 @@ ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx')
def row_factory(cursor, row): def row_factory(cursor, row):
return { return {
k[0]: (set(row[i].split(',')) if k[0] == 'tags' else row[i]) k[0]: (set(row[i].split(',')) if k[0] == 'tags' else row[i])
for i, k in enumerate(cursor.getdescription()) for i, k in enumerate(cursor.description)
} }
def initializer(log, _path, _ledger_name, query_timeout, _measure=False, block_and_filter=None): def initializer(log, _path, _ledger_name, query_timeout, _measure=False, block_and_filter=None):
db = apsw.Connection(_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI) db = sqlite3.connect(_path, isolation_level=None, uri=True)
db.setrowtrace(row_factory) db.row_factory = row_factory
if block_and_filter: if block_and_filter:
blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter
else: else:
@ -184,31 +185,12 @@ def execute_query(sql, values, row_offset: int, row_limit: int, censor: Censor)
context = ctx.get() context = ctx.get()
context.set_query_timeout() context.set_query_timeout()
try: try:
c = context.db.cursor() rows = context.db.execute(sql, values).fetchall()
def row_filter(cursor, row): return rows[row_offset:row_limit]
nonlocal row_offset except sqlite3.OperationalError as err:
row = row_factory(cursor, row)
if len(row) > 1 and censor.censor(row):
return
if row_offset:
row_offset -= 1
return
return row
c.setrowtrace(row_filter)
i, rows = 0, []
for row in c.execute(sql, values):
i += 1
rows.append(row)
if i >= row_limit:
break
return rows
except apsw.Error as err:
plain_sql = interpolate(sql, values) plain_sql = interpolate(sql, values)
if context.is_tracking_metrics: if context.is_tracking_metrics:
context.metrics['execute_query'][-1]['sql'] = plain_sql context.metrics['execute_query'][-1]['sql'] = plain_sql
if isinstance(err, apsw.InterruptError):
context.log.warning("interrupted slow sqlite query:\n%s", plain_sql)
raise SQLiteInterruptedError(context.metrics)
context.log.exception('failed running query', exc_info=err) context.log.exception('failed running query', exc_info=err)
raise SQLiteOperationalError(context.metrics) raise SQLiteOperationalError(context.metrics)