diff --git a/lbry/lbry/wallet/server/block_processor.py b/lbry/lbry/wallet/server/block_processor.py index 4c4897e65..d86eda514 100644 --- a/lbry/lbry/wallet/server/block_processor.py +++ b/lbry/lbry/wallet/server/block_processor.py @@ -80,9 +80,9 @@ class LBRYBlockProcessor(BlockProcessor): finally: self.sql.commit() if self.db.first_sync and self.height == self.daemon.cached_height(): - self.timer.run(self.sql.db.executescript, self.sql.SEARCH_INDEXES, timer_name='executing SEARCH_INDEXES') + self.timer.run(self.sql.execute, self.sql.SEARCH_INDEXES, timer_name='executing SEARCH_INDEXES') if self.env.individual_tag_indexes: - self.timer.run(self.sql.db.executescript, self.sql.TAG_INDEXES, timer_name='executing TAG_INDEXES') + self.timer.run(self.sql.execute, self.sql.TAG_INDEXES, timer_name='executing TAG_INDEXES') for cache in self.search_cache.values(): cache.clear() diff --git a/lbry/lbry/wallet/server/db/canonical.py b/lbry/lbry/wallet/server/db/canonical.py index dd2f83483..a85fc8369 100644 --- a/lbry/lbry/wallet/server/db/canonical.py +++ b/lbry/lbry/wallet/server/db/canonical.py @@ -14,8 +14,13 @@ class FindShortestID: break def finalize(self): - return '#'+self.short_id + if self.short_id: + return '#'+self.short_id + + @classmethod + def factory(cls): + return cls(), cls.step, cls.finalize def register_canonical_functions(connection): - connection.create_aggregate("shortest_id", 2, FindShortestID) + connection.createaggregatefunction("shortest_id", FindShortestID.factory, 2) diff --git a/lbry/lbry/wallet/server/db/full_text_search.py b/lbry/lbry/wallet/server/db/full_text_search.py index ed4fe834d..c553fc1b2 100644 --- a/lbry/lbry/wallet/server/db/full_text_search.py +++ b/lbry/lbry/wallet/server/db/full_text_search.py @@ -1,4 +1,3 @@ -import sqlite3 from torba.client.basedatabase import constraints_to_sql CREATE_FULL_TEXT_SEARCH = """ @@ -26,9 +25,7 @@ def fts_action_sql(claims=None, action='insert'): where, values = "", {} if claims: - where, values = constraints_to_sql({ - 'claim.claim_hash__in': [sqlite3.Binary(claim_hash) for claim_hash in claims] - }) + where, values = constraints_to_sql({'claim.claim_hash__in': claims}) where = 'WHERE '+where return f""" diff --git a/lbry/lbry/wallet/server/db/reader.py b/lbry/lbry/wallet/server/db/reader.py index ec18310a9..2f3d71421 100644 --- a/lbry/lbry/wallet/server/db/reader.py +++ b/lbry/lbry/wallet/server/db/reader.py @@ -1,6 +1,6 @@ import time import struct -import sqlite3 +import apsw import logging from operator import itemgetter from typing import Tuple, List, Dict, Union, Type, Optional @@ -21,20 +21,19 @@ from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS from .full_text_search import FTS_ORDER_BY -class SQLiteOperationalError(sqlite3.OperationalError): +class SQLiteOperationalError(apsw.Error): def __init__(self, metrics): super().__init__('sqlite query errored') self.metrics = metrics -class SQLiteInterruptedError(sqlite3.OperationalError): +class SQLiteInterruptedError(apsw.InterruptError): def __init__(self, metrics): super().__init__('sqlite query interrupted') self.metrics = metrics ATTRIBUTE_ARRAY_MAX_LENGTH = 100 -sqlite3.enable_callback_tracebacks(True) INTEGER_PARAMS = { 'height', 'creation_height', 'activation_height', 'expiration_height', @@ -62,14 +61,9 @@ ORDER_FIELDS = { } | INTEGER_PARAMS -PRAGMAS = """ - pragma journal_mode=WAL; -""" - - @dataclass class ReaderState: - db: sqlite3.Connection + db: apsw.Connection stack: List[List] metrics: Dict is_tracking_metrics: bool @@ -92,15 +86,17 @@ class ReaderState: self.db.interrupt() return - self.db.set_progress_handler(interruptor, 100) + self.db.setprogresshandler(interruptor, 100) ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx') def initializer(log, _path, _ledger_name, query_timeout, _measure=False): - db = sqlite3.connect(_path, isolation_level=None, uri=True) - db.row_factory = sqlite3.Row + db = apsw.Connection(_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI) + def row_factory(cursor, row): + return {k[0]: row[i] for i, k in enumerate(cursor.getdescription())} + db.setrowtrace(row_factory) ctx.set( ReaderState( db=db, stack=[], metrics={}, is_tracking_metrics=_measure, @@ -167,8 +163,8 @@ def execute_query(sql, values) -> List: context = ctx.get() context.set_query_timeout() try: - return context.db.execute(sql, values).fetchall() - except sqlite3.OperationalError as err: + return context.db.cursor().execute(sql, values).fetchall() + except apsw.Error as err: plain_sql = interpolate(sql, values) if context.is_tracking_metrics: context.metrics['execute_query'][-1]['sql'] = plain_sql @@ -231,27 +227,27 @@ def _get_claims(cols, for_count=False, **constraints) -> Tuple[str, Dict]: constraints['claim.claim_id__in'] = constraints.pop('claim_ids') if 'reposted_claim_id' in constraints: - constraints['claim.reposted_claim_hash'] = sqlite3.Binary(unhexlify(constraints.pop('reposted_claim_id'))[::-1]) + constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1] if 'name' in constraints: constraints['claim.normalized'] = normalize_name(constraints.pop('name')) if 'public_key_id' in constraints: - constraints['claim.public_key_hash'] = sqlite3.Binary( + constraints['claim.public_key_hash'] = ( ctx.get().ledger.address_to_hash160(constraints.pop('public_key_id'))) if 'channel_hash' in constraints: - constraints['claim.channel_hash'] = sqlite3.Binary(constraints.pop('channel_hash')) + constraints['claim.channel_hash'] = constraints.pop('channel_hash') if 'channel_ids' in constraints: channel_ids = constraints.pop('channel_ids') if channel_ids: constraints['claim.channel_hash__in'] = [ - sqlite3.Binary(unhexlify(cid)[::-1]) for cid in channel_ids + unhexlify(cid)[::-1] for cid in channel_ids ] if 'not_channel_ids' in constraints: not_channel_ids = constraints.pop('not_channel_ids') if not_channel_ids: not_channel_ids_binary = [ - sqlite3.Binary(unhexlify(ncid)[::-1]) for ncid in not_channel_ids + unhexlify(ncid)[::-1] for ncid in not_channel_ids ] if constraints.get('has_channel_signature', False): constraints['claim.channel_hash__not_in'] = not_channel_ids_binary @@ -264,7 +260,7 @@ def _get_claims(cols, for_count=False, **constraints) -> Tuple[str, Dict]: blocklist_ids = constraints.pop('blocklist_channel_ids') if blocklist_ids: blocking_channels = [ - sqlite3.Binary(unhexlify(channel_id)[::-1]) for channel_id in blocklist_ids + unhexlify(channel_id)[::-1] for channel_id in blocklist_ids ] constraints.update({ f'$blocking_channel{i}': a for i, a in enumerate(blocking_channels) @@ -290,9 +286,7 @@ def _get_claims(cols, for_count=False, **constraints) -> Tuple[str, Dict]: if 'txid' in constraints: tx_hash = unhexlify(constraints.pop('txid'))[::-1] nout = constraints.pop('nout', 0) - constraints['claim.txo_hash'] = sqlite3.Binary( - tx_hash + struct.pack(' List: if 'channel' in constraints: channel_url = constraints.pop('channel') match = resolve_url(channel_url) - if isinstance(match, sqlite3.Row): + if isinstance(match, dict): constraints['channel_hash'] = match['claim_hash'] else: - return [[0]] if cols == 'count(*)' else [] + return [{'row_count': 0}] if cols == 'count(*) as row_count' else [] sql, values = _get_claims(cols, for_count, **constraints) return execute_query(sql, values) @@ -342,8 +336,8 @@ def get_claims_count(**constraints) -> int: constraints.pop('offset', None) constraints.pop('limit', None) constraints.pop('order_by', None) - count = get_claims('count(*)', for_count=True, **constraints) - return count[0][0] + count = get_claims('count(*) as row_count', for_count=True, **constraints) + return count[0]['row_count'] def _search(**constraints): @@ -365,22 +359,18 @@ def _search(**constraints): ) -def _get_referenced_rows(txo_rows: List[sqlite3.Row]): +def _get_referenced_rows(txo_rows: List[dict]): repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows))) channel_hashes = set(filter(None, map(itemgetter('channel_hash'), txo_rows))) reposted_txos = [] if repost_hashes: - reposted_txos = _search( - **{'claim.claim_hash__in': [sqlite3.Binary(h) for h in repost_hashes]} - ) + reposted_txos = _search(**{'claim.claim_hash__in': repost_hashes}) channel_hashes |= set(filter(None, map(itemgetter('channel_hash'), reposted_txos))) channel_txos = [] if channel_hashes: - channel_txos = _search( - **{'claim.claim_hash__in': [sqlite3.Binary(h) for h in channel_hashes]} - ) + channel_txos = _search(**{'claim.claim_hash__in': channel_hashes}) # channels must come first for client side inflation to work properly return channel_txos + reposted_txos @@ -405,7 +395,7 @@ def search(constraints) -> Tuple[List, List, int, int]: @measure def resolve(urls) -> Tuple[List, List]: txo_rows = [resolve_url(raw_url) for raw_url in urls] - extra_txo_rows = _get_referenced_rows([r for r in txo_rows if isinstance(r, sqlite3.Row)]) + extra_txo_rows = _get_referenced_rows([r for r in txo_rows if isinstance(r, dict)]) return txo_rows, extra_txo_rows diff --git a/lbry/lbry/wallet/server/db/trending.py b/lbry/lbry/wallet/server/db/trending.py index 3ec0562b5..0fc425a3b 100644 --- a/lbry/lbry/wallet/server/db/trending.py +++ b/lbry/lbry/wallet/server/db/trending.py @@ -47,9 +47,13 @@ class ZScore: return self.last return (self.last - self.mean) / (self.standard_deviation or 1) + @classmethod + def factory(cls): + return cls(), cls.step, cls.finalize + def register_trending_functions(connection): - connection.create_aggregate("zscore", 1, ZScore) + connection.createaggregatefunction("zscore", ZScore.factory, 1) def calculate_trending(db, height, final_height): @@ -75,8 +79,8 @@ def calculate_trending(db, height, final_height): """) zscore = ZScore() - for (global_sum,) in db.execute("SELECT AVG(amount) FROM trend GROUP BY height"): - zscore.step(global_sum) + for global_sum in db.execute("SELECT AVG(amount) AS avg_amount FROM trend GROUP BY height"): + zscore.step(global_sum.avg_amount) global_mean, global_deviation = 0, 1 if zscore.count > 0: global_mean = zscore.mean diff --git a/lbry/lbry/wallet/server/db/writer.py b/lbry/lbry/wallet/server/db/writer.py index a3e149160..712db014d 100644 --- a/lbry/lbry/wallet/server/db/writer.py +++ b/lbry/lbry/wallet/server/db/writer.py @@ -1,8 +1,9 @@ import os -import sqlite3 +import apsw from typing import Union, Tuple, Set, List from itertools import chain from decimal import Decimal +from collections import namedtuple from torba.server.db import DB from torba.server.util import class_logger @@ -22,7 +23,6 @@ from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS ATTRIBUTE_ARRAY_MAX_LENGTH = 100 -sqlite3.enable_callback_tracebacks(True) class SQLDB: @@ -158,7 +158,6 @@ class SQLDB: ) CREATE_TABLES_QUERY = ( - PRAGMAS + CREATE_CLAIM_TABLE + CREATE_TREND_TABLE + CREATE_FULL_TEXT_SEARCH + @@ -176,14 +175,26 @@ class SQLDB: self._fts_synced = False def open(self): - self.db = sqlite3.connect(self._db_path, isolation_level=None, check_same_thread=False, uri=True) - self.db.row_factory = sqlite3.Row - self.db.executescript(self.CREATE_TABLES_QUERY) + self.db = apsw.Connection( + self._db_path, + flags=( + apsw.SQLITE_OPEN_READWRITE | + apsw.SQLITE_OPEN_CREATE | + apsw.SQLITE_OPEN_URI + ) + ) + def exec_factory(cursor, statement, bindings): + tpl = namedtuple('row', (d[0] for d in cursor.getdescription())) + cursor.setrowtrace(lambda cursor, row: tpl(*row)) + return True + self.db.setexectrace(exec_factory) + self.execute(self.CREATE_TABLES_QUERY) register_canonical_functions(self.db) register_trending_functions(self.db) def close(self): - self.db.close() + if self.db is not None: + self.db.close() @staticmethod def _insert_sql(table: str, data: dict) -> Tuple[str, list]: @@ -213,7 +224,10 @@ class SQLDB: return f"DELETE FROM {table} WHERE {where}", values def execute(self, *args): - return self.db.execute(*args) + return self.db.cursor().execute(*args) + + def executemany(self, *args): + return self.db.cursor().executemany(*args) def begin(self): self.execute('begin;') @@ -222,7 +236,7 @@ class SQLDB: self.execute('commit;') def _upsertable_claims(self, txos: List[Output], header, clear_first=False): - claim_hashes, claims, tags = [], [], {} + claim_hashes, claims, tags = set(), [], {} for txo in txos: tx = txo.tx_ref.tx @@ -233,14 +247,14 @@ class SQLDB: #self.logger.exception(f"Could not decode claim name for {tx.id}:{txo.position}.") continue - claim_hash = sqlite3.Binary(txo.claim_hash) - claim_hashes.append(claim_hash) + claim_hash = txo.claim_hash + claim_hashes.add(claim_hash) claim_record = { 'claim_hash': claim_hash, 'claim_id': txo.claim_id, 'claim_name': txo.claim_name, 'normalized': txo.normalized_name, - 'txo_hash': sqlite3.Binary(txo.ref.hash), + 'txo_hash': txo.ref.hash, 'tx_position': tx.position, 'amount': txo.amount, 'timestamp': header['timestamp'], @@ -291,7 +305,7 @@ class SQLDB: self._clear_claim_metadata(claim_hashes) if tags: - self.db.executemany( + self.executemany( "INSERT OR IGNORE INTO tag (tag, claim_hash, height) VALUES (?, ?, ?)", tags.values() ) @@ -300,7 +314,7 @@ class SQLDB: def insert_claims(self, txos: List[Output], header): claims = self._upsertable_claims(txos, header) if claims: - self.db.executemany(""" + self.executemany(""" INSERT OR IGNORE INTO claim ( claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position, amount, claim_type, media_type, stream_type, timestamp, creation_timestamp, @@ -322,7 +336,7 @@ class SQLDB: def update_claims(self, txos: List[Output], header): claims = self._upsertable_claims(txos, header, clear_first=True) if claims: - self.db.executemany(""" + self.executemany(""" UPDATE claim SET txo_hash=:txo_hash, tx_position=:tx_position, amount=:amount, height=:height, claim_type=:claim_type, media_type=:media_type, stream_type=:stream_type, @@ -335,35 +349,32 @@ class SQLDB: def delete_claims(self, claim_hashes: Set[bytes]): """ Deletes claim supports and from claimtrie in case of an abandon. """ if claim_hashes: - binary_claim_hashes = [sqlite3.Binary(claim_hash) for claim_hash in claim_hashes] affected_channels = self.execute(*query( - "SELECT channel_hash FROM claim", channel_hash__is_not_null=1, claim_hash__in=binary_claim_hashes + "SELECT channel_hash FROM claim", channel_hash__is_not_null=1, claim_hash__in=claim_hashes )).fetchall() for table in ('claim', 'support', 'claimtrie'): - self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) - self._clear_claim_metadata(binary_claim_hashes) - return {r['channel_hash'] for r in affected_channels} + self.execute(*self._delete_sql(table, {'claim_hash__in': claim_hashes})) + self._clear_claim_metadata(claim_hashes) + return {r.channel_hash for r in affected_channels} return set() - def _clear_claim_metadata(self, binary_claim_hashes: List[sqlite3.Binary]): - if binary_claim_hashes: + def _clear_claim_metadata(self, claim_hashes: Set[bytes]): + if claim_hashes: for table in ('tag',): # 'language', 'location', etc - self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) + self.execute(*self._delete_sql(table, {'claim_hash__in': claim_hashes})) def split_inputs_into_claims_supports_and_other(self, txis): txo_hashes = {txi.txo_ref.hash for txi in txis} claims = self.execute(*query( - "SELECT txo_hash, claim_hash, normalized FROM claim", - txo_hash__in=[sqlite3.Binary(txo_hash) for txo_hash in txo_hashes] + "SELECT txo_hash, claim_hash, normalized FROM claim", txo_hash__in=txo_hashes )).fetchall() - txo_hashes -= {r['txo_hash'] for r in claims} + txo_hashes -= {r.txo_hash for r in claims} supports = {} if txo_hashes: supports = self.execute(*query( - "SELECT txo_hash, claim_hash FROM support", - txo_hash__in=[sqlite3.Binary(txo_hash) for txo_hash in txo_hashes] + "SELECT txo_hash, claim_hash FROM support", txo_hash__in=txo_hashes )).fetchall() - txo_hashes -= {r['txo_hash'] for r in supports} + txo_hashes -= {r.txo_hash for r in supports} return claims, supports, txo_hashes def insert_supports(self, txos: List[Output]): @@ -371,11 +382,11 @@ class SQLDB: for txo in txos: tx = txo.tx_ref.tx supports.append(( - sqlite3.Binary(txo.ref.hash), tx.position, tx.height, - sqlite3.Binary(txo.claim_hash), txo.amount + txo.ref.hash, tx.position, tx.height, + txo.claim_hash, txo.amount )) if supports: - self.db.executemany( + self.executemany( "INSERT OR IGNORE INTO support (" " txo_hash, tx_position, height, claim_hash, amount" ") " @@ -384,9 +395,7 @@ class SQLDB: def delete_supports(self, txo_hashes: Set[bytes]): if txo_hashes: - self.execute(*self._delete_sql( - 'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]} - )) + self.execute(*self._delete_sql('support', {'txo_hash__in': txo_hashes})) def calculate_reposts(self, txos: List[Output]): targets = set() @@ -398,7 +407,7 @@ class SQLDB: if claim.is_repost: targets.add((claim.repost.reference.claim_hash,)) if targets: - self.db.executemany( + self.executemany( """ UPDATE claim SET reposted = ( SELECT count(*) FROM claim AS repost WHERE repost.reposted_claim_hash = claim.claim_hash @@ -441,10 +450,7 @@ class SQLDB: if new_channel_keys or missing_channel_keys or affected_channels: all_channel_keys = dict(self.execute(*query( "SELECT claim_hash, public_key_bytes FROM claim", - claim_hash__in=[ - sqlite3.Binary(channel_hash) for channel_hash in - set(new_channel_keys) | missing_channel_keys | affected_channels - ] + claim_hash__in=set(new_channel_keys) | missing_channel_keys | affected_channels ))) sub_timer.stop() @@ -461,7 +467,7 @@ class SQLDB: for claim_hash, txo in signables.items(): claim = txo.claim update = { - 'claim_hash': sqlite3.Binary(claim_hash), + 'claim_hash': claim_hash, 'channel_hash': None, 'signature': None, 'signature_digest': None, @@ -469,9 +475,9 @@ class SQLDB: } if claim.is_signed: update.update({ - 'channel_hash': sqlite3.Binary(claim.signing_channel_hash), - 'signature': sqlite3.Binary(txo.get_encoded_signature()), - 'signature_digest': sqlite3.Binary(txo.get_signature_digest(self.ledger)), + 'channel_hash': claim.signing_channel_hash, + 'signature': txo.get_encoded_signature(), + 'signature_digest': txo.get_signature_digest(self.ledger), 'signature_valid': 0 }) claim_updates.append(update) @@ -485,13 +491,13 @@ class SQLDB: channel_hash IN ({','.join('?' for _ in changed_channel_keys)}) AND signature IS NOT NULL """ - for affected_claim in self.execute(sql, [sqlite3.Binary(h) for h in changed_channel_keys]): - if affected_claim['claim_hash'] not in signables: + for affected_claim in self.execute(sql, changed_channel_keys): + if affected_claim.claim_hash not in signables: claim_updates.append({ - 'claim_hash': sqlite3.Binary(affected_claim['claim_hash']), - 'channel_hash': sqlite3.Binary(affected_claim['channel_hash']), - 'signature': sqlite3.Binary(affected_claim['signature']), - 'signature_digest': sqlite3.Binary(affected_claim['signature_digest']), + 'claim_hash': affected_claim.claim_hash, + 'channel_hash': affected_claim.channel_hash, + 'signature': affected_claim.signature, + 'signature_digest': affected_claim.signature_digest, 'signature_valid': 0 }) sub_timer.stop() @@ -509,7 +515,7 @@ class SQLDB: sub_timer = timer.add_timer('update claims') sub_timer.start() if claim_updates: - self.db.executemany(f""" + self.executemany(f""" UPDATE claim SET channel_hash=:channel_hash, signature=:signature, signature_digest=:signature_digest, signature_valid=:signature_valid, @@ -542,24 +548,24 @@ class SQLDB: signature_valid=CASE WHEN signature IS NOT NULL THEN 0 END, channel_join=NULL, canonical_url=NULL WHERE channel_hash IN ({','.join('?' for _ in spent_claims)}) - """, [sqlite3.Binary(cid) for cid in spent_claims] + """, spent_claims ) sub_timer.stop() sub_timer = timer.add_timer('update channels') sub_timer.start() if channels: - self.db.executemany( + self.executemany( """ UPDATE claim SET public_key_bytes=:public_key_bytes, public_key_hash=:public_key_hash WHERE claim_hash=:claim_hash""", [{ - 'claim_hash': sqlite3.Binary(claim_hash), - 'public_key_bytes': sqlite3.Binary(txo.claim.channel.public_key_bytes), - 'public_key_hash': sqlite3.Binary( - self.ledger.address_to_hash160( - self.ledger.public_key_to_address(txo.claim.channel.public_key_bytes))) + 'claim_hash': claim_hash, + 'public_key_bytes': txo.claim.channel.public_key_bytes, + 'public_key_hash': self.ledger.address_to_hash160( + self.ledger.public_key_to_address(txo.claim.channel.public_key_bytes) + ) } for claim_hash, txo in channels.items()] ) sub_timer.stop() @@ -567,7 +573,7 @@ class SQLDB: sub_timer = timer.add_timer('update claims_in_channel counts') sub_timer.start() if all_channel_keys: - self.db.executemany(f""" + self.executemany(f""" UPDATE claim SET claims_in_channel=( SELECT COUNT(*) FROM claim AS claim_in_channel @@ -575,7 +581,7 @@ class SQLDB: claim_in_channel.channel_hash=claim.claim_hash ) WHERE claim_hash = ? - """, [(sqlite3.Binary(channel_hash),) for channel_hash in all_channel_keys.keys()]) + """, [(channel_hash,) for channel_hash in all_channel_keys.keys()]) sub_timer.stop() def _update_support_amount(self, claim_hashes): @@ -623,7 +629,7 @@ class SQLDB: overtakes = self.execute(f""" SELECT winner.normalized, winner.claim_hash, claimtrie.claim_hash AS current_winner, - MAX(winner.effective_amount) + MAX(winner.effective_amount) AS max_winner_effective_amount FROM ( SELECT normalized, claim_hash, effective_amount FROM claim WHERE normalized IN ( @@ -635,22 +641,22 @@ class SQLDB: HAVING current_winner IS NULL OR current_winner <> winner.claim_hash """, changed_claim_hashes+deleted_names) for overtake in overtakes: - if overtake['current_winner']: + if overtake.current_winner: self.execute( f"UPDATE claimtrie SET claim_hash = ?, last_take_over_height = {height} " f"WHERE normalized = ?", - (sqlite3.Binary(overtake['claim_hash']), overtake['normalized']) + (overtake.claim_hash, overtake.normalized) ) else: self.execute( f"INSERT INTO claimtrie (claim_hash, normalized, last_take_over_height) " f"VALUES (?, ?, {height})", - (sqlite3.Binary(overtake['claim_hash']), overtake['normalized']) + (overtake.claim_hash, overtake.normalized) ) self.execute( f"UPDATE claim SET activation_height = {height} WHERE normalized = ? " f"AND (activation_height IS NULL OR activation_height > {height})", - (overtake['normalized'],) + (overtake.normalized,) ) def _copy(self, height): @@ -660,15 +666,12 @@ class SQLDB: def update_claimtrie(self, height, changed_claim_hashes, deleted_names, timer): r = timer.run - binary_claim_hashes = [ - sqlite3.Binary(claim_hash) for claim_hash in changed_claim_hashes - ] r(self._calculate_activation_height, height) - r(self._update_support_amount, binary_claim_hashes) + r(self._update_support_amount, changed_claim_hashes) - r(self._update_effective_amount, height, binary_claim_hashes) - r(self._perform_overtake, height, binary_claim_hashes, list(deleted_names)) + r(self._update_effective_amount, height, changed_claim_hashes) + r(self._perform_overtake, height, changed_claim_hashes, list(deleted_names)) r(self._update_effective_amount, height) r(self._perform_overtake, height, [], []) @@ -697,10 +700,10 @@ class SQLDB: self.split_inputs_into_claims_supports_and_other, tx.inputs ) body_timer.start() - delete_claim_hashes.update({r['claim_hash'] for r in spent_claims}) - deleted_claim_names.update({r['normalized'] for r in spent_claims}) - delete_support_txo_hashes.update({r['txo_hash'] for r in spent_supports}) - recalculate_claim_hashes.update({r['claim_hash'] for r in spent_supports}) + delete_claim_hashes.update({r.claim_hash for r in spent_claims}) + deleted_claim_names.update({r.normalized for r in spent_claims}) + delete_support_txo_hashes.update({r.txo_hash for r in spent_supports}) + recalculate_claim_hashes.update({r.claim_hash for r in spent_supports}) delete_others.update(spent_others) # Outputs for output in tx.outputs: @@ -728,31 +731,31 @@ class SQLDB: expire_timer = timer.add_timer('recording expired claims') expire_timer.start() for expired in self.get_expiring(height): - delete_claim_hashes.add(expired['claim_hash']) - deleted_claim_names.add(expired['normalized']) + delete_claim_hashes.add(expired.claim_hash) + deleted_claim_names.add(expired.normalized) expire_timer.stop() r = timer.run r(update_full_text_search, 'before-delete', - delete_claim_hashes, self.db, self.main.first_sync) + delete_claim_hashes, self.db.cursor(), self.main.first_sync) affected_channels = r(self.delete_claims, delete_claim_hashes) r(self.delete_supports, delete_support_txo_hashes) r(self.insert_claims, insert_claims, header) r(self.calculate_reposts, insert_claims) r(update_full_text_search, 'after-insert', - [txo.claim_hash for txo in insert_claims], self.db, self.main.first_sync) + [txo.claim_hash for txo in insert_claims], self.db.cursor(), self.main.first_sync) r(update_full_text_search, 'before-update', - [txo.claim_hash for txo in update_claims], self.db, self.main.first_sync) + [txo.claim_hash for txo in update_claims], self.db.cursor(), self.main.first_sync) r(self.update_claims, update_claims, header) r(update_full_text_search, 'after-update', - [txo.claim_hash for txo in update_claims], self.db, self.main.first_sync) + [txo.claim_hash for txo in update_claims], self.db.cursor(), self.main.first_sync) r(self.validate_channel_signatures, height, insert_claims, update_claims, delete_claim_hashes, affected_channels, forward_timer=True) r(self.insert_supports, insert_supports) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) - r(calculate_trending, self.db, height, daemon_height) + r(calculate_trending, self.db.cursor(), height, daemon_height) if not self._fts_synced and self.main.first_sync and height == daemon_height: - r(first_sync_finished, self.db) + r(first_sync_finished, self.db.cursor()) self._fts_synced = True diff --git a/lbry/tests/unit/wallet/server/test_sqldb.py b/lbry/tests/unit/wallet/server/test_sqldb.py index e313863bc..8432b6de2 100644 --- a/lbry/tests/unit/wallet/server/test_sqldb.py +++ b/lbry/tests/unit/wallet/server/test_sqldb.py @@ -38,10 +38,10 @@ class TestSQLDB(unittest.TestCase): db_url = 'file:test_sqldb?mode=memory&cache=shared' self.sql = writer.SQLDB(self, db_url) self.addCleanup(self.sql.close) + self.sql.open() reader.initializer(logging.getLogger(__name__), db_url, 'regtest', self.query_timeout) self.addCleanup(reader.cleanup) self.timer = Timer('BlockProcessor') - self.sql.open() self._current_height = 0 self._txos = {} @@ -113,8 +113,8 @@ class TestSQLDB(unittest.TestCase): def get_controlling(self): for claim in self.sql.execute("select claim.* from claimtrie natural join claim"): - txo = self._txos[claim['txo_hash']] - controlling = txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'] + txo = self._txos[claim.txo_hash] + controlling = txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height return controlling def get_active(self): @@ -122,18 +122,18 @@ class TestSQLDB(unittest.TestCase): active = [] for claim in self.sql.execute( f"select * from claim where activation_height <= {self._current_height}"): - txo = self._txos[claim['txo_hash']] + txo = self._txos[claim.txo_hash] if controlling and controlling[0] == txo.claim.stream.title: continue - active.append((txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'])) + active.append((txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height)) return active def get_accepted(self): accepted = [] for claim in self.sql.execute( f"select * from claim where activation_height > {self._current_height}"): - txo = self._txos[claim['txo_hash']] - accepted.append((txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'])) + txo = self._txos[claim.txo_hash] + accepted.append((txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height)) return accepted def advance(self, height, txs):