switch to apsw

This commit is contained in:
Lex Berezhny 2019-12-07 18:13:13 -05:00
parent c28f0f6286
commit 1c349270bf
7 changed files with 136 additions and 137 deletions

View file

@ -80,9 +80,9 @@ class LBRYBlockProcessor(BlockProcessor):
finally: finally:
self.sql.commit() self.sql.commit()
if self.db.first_sync and self.height == self.daemon.cached_height(): if self.db.first_sync and self.height == self.daemon.cached_height():
self.timer.run(self.sql.db.executescript, self.sql.SEARCH_INDEXES, timer_name='executing SEARCH_INDEXES') self.timer.run(self.sql.execute, self.sql.SEARCH_INDEXES, timer_name='executing SEARCH_INDEXES')
if self.env.individual_tag_indexes: if self.env.individual_tag_indexes:
self.timer.run(self.sql.db.executescript, self.sql.TAG_INDEXES, timer_name='executing TAG_INDEXES') self.timer.run(self.sql.execute, self.sql.TAG_INDEXES, timer_name='executing TAG_INDEXES')
for cache in self.search_cache.values(): for cache in self.search_cache.values():
cache.clear() cache.clear()

View file

@ -14,8 +14,13 @@ class FindShortestID:
break break
def finalize(self): def finalize(self):
return '#'+self.short_id if self.short_id:
return '#'+self.short_id
@classmethod
def factory(cls):
return cls(), cls.step, cls.finalize
def register_canonical_functions(connection): def register_canonical_functions(connection):
connection.create_aggregate("shortest_id", 2, FindShortestID) connection.createaggregatefunction("shortest_id", FindShortestID.factory, 2)

View file

@ -1,4 +1,3 @@
import sqlite3
from torba.client.basedatabase import constraints_to_sql from torba.client.basedatabase import constraints_to_sql
CREATE_FULL_TEXT_SEARCH = """ CREATE_FULL_TEXT_SEARCH = """
@ -26,9 +25,7 @@ def fts_action_sql(claims=None, action='insert'):
where, values = "", {} where, values = "", {}
if claims: if claims:
where, values = constraints_to_sql({ where, values = constraints_to_sql({'claim.claim_hash__in': claims})
'claim.claim_hash__in': [sqlite3.Binary(claim_hash) for claim_hash in claims]
})
where = 'WHERE '+where where = 'WHERE '+where
return f""" return f"""

View file

@ -1,6 +1,6 @@
import time import time
import struct import struct
import sqlite3 import apsw
import logging import logging
from operator import itemgetter from operator import itemgetter
from typing import Tuple, List, Dict, Union, Type, Optional from typing import Tuple, List, Dict, Union, Type, Optional
@ -21,20 +21,19 @@ from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS
from .full_text_search import FTS_ORDER_BY from .full_text_search import FTS_ORDER_BY
class SQLiteOperationalError(sqlite3.OperationalError): class SQLiteOperationalError(apsw.Error):
def __init__(self, metrics): def __init__(self, metrics):
super().__init__('sqlite query errored') super().__init__('sqlite query errored')
self.metrics = metrics self.metrics = metrics
class SQLiteInterruptedError(sqlite3.OperationalError): class SQLiteInterruptedError(apsw.InterruptError):
def __init__(self, metrics): def __init__(self, metrics):
super().__init__('sqlite query interrupted') super().__init__('sqlite query interrupted')
self.metrics = metrics self.metrics = metrics
ATTRIBUTE_ARRAY_MAX_LENGTH = 100 ATTRIBUTE_ARRAY_MAX_LENGTH = 100
sqlite3.enable_callback_tracebacks(True)
INTEGER_PARAMS = { INTEGER_PARAMS = {
'height', 'creation_height', 'activation_height', 'expiration_height', 'height', 'creation_height', 'activation_height', 'expiration_height',
@ -62,14 +61,9 @@ ORDER_FIELDS = {
} | INTEGER_PARAMS } | INTEGER_PARAMS
PRAGMAS = """
pragma journal_mode=WAL;
"""
@dataclass @dataclass
class ReaderState: class ReaderState:
db: sqlite3.Connection db: apsw.Connection
stack: List[List] stack: List[List]
metrics: Dict metrics: Dict
is_tracking_metrics: bool is_tracking_metrics: bool
@ -92,15 +86,17 @@ class ReaderState:
self.db.interrupt() self.db.interrupt()
return return
self.db.set_progress_handler(interruptor, 100) self.db.setprogresshandler(interruptor, 100)
ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx') ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx')
def initializer(log, _path, _ledger_name, query_timeout, _measure=False): def initializer(log, _path, _ledger_name, query_timeout, _measure=False):
db = sqlite3.connect(_path, isolation_level=None, uri=True) db = apsw.Connection(_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
db.row_factory = sqlite3.Row def row_factory(cursor, row):
return {k[0]: row[i] for i, k in enumerate(cursor.getdescription())}
db.setrowtrace(row_factory)
ctx.set( ctx.set(
ReaderState( ReaderState(
db=db, stack=[], metrics={}, is_tracking_metrics=_measure, db=db, stack=[], metrics={}, is_tracking_metrics=_measure,
@ -167,8 +163,8 @@ def execute_query(sql, values) -> List:
context = ctx.get() context = ctx.get()
context.set_query_timeout() context.set_query_timeout()
try: try:
return context.db.execute(sql, values).fetchall() return context.db.cursor().execute(sql, values).fetchall()
except sqlite3.OperationalError as err: except apsw.Error as err:
plain_sql = interpolate(sql, values) plain_sql = interpolate(sql, values)
if context.is_tracking_metrics: if context.is_tracking_metrics:
context.metrics['execute_query'][-1]['sql'] = plain_sql context.metrics['execute_query'][-1]['sql'] = plain_sql
@ -231,27 +227,27 @@ def _get_claims(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
constraints['claim.claim_id__in'] = constraints.pop('claim_ids') constraints['claim.claim_id__in'] = constraints.pop('claim_ids')
if 'reposted_claim_id' in constraints: if 'reposted_claim_id' in constraints:
constraints['claim.reposted_claim_hash'] = sqlite3.Binary(unhexlify(constraints.pop('reposted_claim_id'))[::-1]) constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1]
if 'name' in constraints: if 'name' in constraints:
constraints['claim.normalized'] = normalize_name(constraints.pop('name')) constraints['claim.normalized'] = normalize_name(constraints.pop('name'))
if 'public_key_id' in constraints: if 'public_key_id' in constraints:
constraints['claim.public_key_hash'] = sqlite3.Binary( constraints['claim.public_key_hash'] = (
ctx.get().ledger.address_to_hash160(constraints.pop('public_key_id'))) ctx.get().ledger.address_to_hash160(constraints.pop('public_key_id')))
if 'channel_hash' in constraints: if 'channel_hash' in constraints:
constraints['claim.channel_hash'] = sqlite3.Binary(constraints.pop('channel_hash')) constraints['claim.channel_hash'] = constraints.pop('channel_hash')
if 'channel_ids' in constraints: if 'channel_ids' in constraints:
channel_ids = constraints.pop('channel_ids') channel_ids = constraints.pop('channel_ids')
if channel_ids: if channel_ids:
constraints['claim.channel_hash__in'] = [ constraints['claim.channel_hash__in'] = [
sqlite3.Binary(unhexlify(cid)[::-1]) for cid in channel_ids unhexlify(cid)[::-1] for cid in channel_ids
] ]
if 'not_channel_ids' in constraints: if 'not_channel_ids' in constraints:
not_channel_ids = constraints.pop('not_channel_ids') not_channel_ids = constraints.pop('not_channel_ids')
if not_channel_ids: if not_channel_ids:
not_channel_ids_binary = [ not_channel_ids_binary = [
sqlite3.Binary(unhexlify(ncid)[::-1]) for ncid in not_channel_ids unhexlify(ncid)[::-1] for ncid in not_channel_ids
] ]
if constraints.get('has_channel_signature', False): if constraints.get('has_channel_signature', False):
constraints['claim.channel_hash__not_in'] = not_channel_ids_binary constraints['claim.channel_hash__not_in'] = not_channel_ids_binary
@ -264,7 +260,7 @@ def _get_claims(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
blocklist_ids = constraints.pop('blocklist_channel_ids') blocklist_ids = constraints.pop('blocklist_channel_ids')
if blocklist_ids: if blocklist_ids:
blocking_channels = [ blocking_channels = [
sqlite3.Binary(unhexlify(channel_id)[::-1]) for channel_id in blocklist_ids unhexlify(channel_id)[::-1] for channel_id in blocklist_ids
] ]
constraints.update({ constraints.update({
f'$blocking_channel{i}': a for i, a in enumerate(blocking_channels) f'$blocking_channel{i}': a for i, a in enumerate(blocking_channels)
@ -290,9 +286,7 @@ def _get_claims(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
if 'txid' in constraints: if 'txid' in constraints:
tx_hash = unhexlify(constraints.pop('txid'))[::-1] tx_hash = unhexlify(constraints.pop('txid'))[::-1]
nout = constraints.pop('nout', 0) nout = constraints.pop('nout', 0)
constraints['claim.txo_hash'] = sqlite3.Binary( constraints['claim.txo_hash'] = tx_hash + struct.pack('<I', nout)
tx_hash + struct.pack('<I', nout)
)
if 'claim_type' in constraints: if 'claim_type' in constraints:
constraints['claim.claim_type'] = CLAIM_TYPES[constraints.pop('claim_type')] constraints['claim.claim_type'] = CLAIM_TYPES[constraints.pop('claim_type')]
@ -329,10 +323,10 @@ def get_claims(cols, for_count=False, **constraints) -> List:
if 'channel' in constraints: if 'channel' in constraints:
channel_url = constraints.pop('channel') channel_url = constraints.pop('channel')
match = resolve_url(channel_url) match = resolve_url(channel_url)
if isinstance(match, sqlite3.Row): if isinstance(match, dict):
constraints['channel_hash'] = match['claim_hash'] constraints['channel_hash'] = match['claim_hash']
else: else:
return [[0]] if cols == 'count(*)' else [] return [{'row_count': 0}] if cols == 'count(*) as row_count' else []
sql, values = _get_claims(cols, for_count, **constraints) sql, values = _get_claims(cols, for_count, **constraints)
return execute_query(sql, values) return execute_query(sql, values)
@ -342,8 +336,8 @@ def get_claims_count(**constraints) -> int:
constraints.pop('offset', None) constraints.pop('offset', None)
constraints.pop('limit', None) constraints.pop('limit', None)
constraints.pop('order_by', None) constraints.pop('order_by', None)
count = get_claims('count(*)', for_count=True, **constraints) count = get_claims('count(*) as row_count', for_count=True, **constraints)
return count[0][0] return count[0]['row_count']
def _search(**constraints): def _search(**constraints):
@ -365,22 +359,18 @@ def _search(**constraints):
) )
def _get_referenced_rows(txo_rows: List[sqlite3.Row]): def _get_referenced_rows(txo_rows: List[dict]):
repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows))) repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
channel_hashes = set(filter(None, map(itemgetter('channel_hash'), txo_rows))) channel_hashes = set(filter(None, map(itemgetter('channel_hash'), txo_rows)))
reposted_txos = [] reposted_txos = []
if repost_hashes: if repost_hashes:
reposted_txos = _search( reposted_txos = _search(**{'claim.claim_hash__in': repost_hashes})
**{'claim.claim_hash__in': [sqlite3.Binary(h) for h in repost_hashes]}
)
channel_hashes |= set(filter(None, map(itemgetter('channel_hash'), reposted_txos))) channel_hashes |= set(filter(None, map(itemgetter('channel_hash'), reposted_txos)))
channel_txos = [] channel_txos = []
if channel_hashes: if channel_hashes:
channel_txos = _search( channel_txos = _search(**{'claim.claim_hash__in': channel_hashes})
**{'claim.claim_hash__in': [sqlite3.Binary(h) for h in channel_hashes]}
)
# channels must come first for client side inflation to work properly # channels must come first for client side inflation to work properly
return channel_txos + reposted_txos return channel_txos + reposted_txos
@ -405,7 +395,7 @@ def search(constraints) -> Tuple[List, List, int, int]:
@measure @measure
def resolve(urls) -> Tuple[List, List]: def resolve(urls) -> Tuple[List, List]:
txo_rows = [resolve_url(raw_url) for raw_url in urls] txo_rows = [resolve_url(raw_url) for raw_url in urls]
extra_txo_rows = _get_referenced_rows([r for r in txo_rows if isinstance(r, sqlite3.Row)]) extra_txo_rows = _get_referenced_rows([r for r in txo_rows if isinstance(r, dict)])
return txo_rows, extra_txo_rows return txo_rows, extra_txo_rows

View file

@ -47,9 +47,13 @@ class ZScore:
return self.last return self.last
return (self.last - self.mean) / (self.standard_deviation or 1) return (self.last - self.mean) / (self.standard_deviation or 1)
@classmethod
def factory(cls):
return cls(), cls.step, cls.finalize
def register_trending_functions(connection): def register_trending_functions(connection):
connection.create_aggregate("zscore", 1, ZScore) connection.createaggregatefunction("zscore", ZScore.factory, 1)
def calculate_trending(db, height, final_height): def calculate_trending(db, height, final_height):
@ -75,8 +79,8 @@ def calculate_trending(db, height, final_height):
""") """)
zscore = ZScore() zscore = ZScore()
for (global_sum,) in db.execute("SELECT AVG(amount) FROM trend GROUP BY height"): for global_sum in db.execute("SELECT AVG(amount) AS avg_amount FROM trend GROUP BY height"):
zscore.step(global_sum) zscore.step(global_sum.avg_amount)
global_mean, global_deviation = 0, 1 global_mean, global_deviation = 0, 1
if zscore.count > 0: if zscore.count > 0:
global_mean = zscore.mean global_mean = zscore.mean

View file

@ -1,8 +1,9 @@
import os import os
import sqlite3 import apsw
from typing import Union, Tuple, Set, List from typing import Union, Tuple, Set, List
from itertools import chain from itertools import chain
from decimal import Decimal from decimal import Decimal
from collections import namedtuple
from torba.server.db import DB from torba.server.db import DB
from torba.server.util import class_logger from torba.server.util import class_logger
@ -22,7 +23,6 @@ from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS
ATTRIBUTE_ARRAY_MAX_LENGTH = 100 ATTRIBUTE_ARRAY_MAX_LENGTH = 100
sqlite3.enable_callback_tracebacks(True)
class SQLDB: class SQLDB:
@ -158,7 +158,6 @@ class SQLDB:
) )
CREATE_TABLES_QUERY = ( CREATE_TABLES_QUERY = (
PRAGMAS +
CREATE_CLAIM_TABLE + CREATE_CLAIM_TABLE +
CREATE_TREND_TABLE + CREATE_TREND_TABLE +
CREATE_FULL_TEXT_SEARCH + CREATE_FULL_TEXT_SEARCH +
@ -176,14 +175,26 @@ class SQLDB:
self._fts_synced = False self._fts_synced = False
def open(self): def open(self):
self.db = sqlite3.connect(self._db_path, isolation_level=None, check_same_thread=False, uri=True) self.db = apsw.Connection(
self.db.row_factory = sqlite3.Row self._db_path,
self.db.executescript(self.CREATE_TABLES_QUERY) flags=(
apsw.SQLITE_OPEN_READWRITE |
apsw.SQLITE_OPEN_CREATE |
apsw.SQLITE_OPEN_URI
)
)
def exec_factory(cursor, statement, bindings):
tpl = namedtuple('row', (d[0] for d in cursor.getdescription()))
cursor.setrowtrace(lambda cursor, row: tpl(*row))
return True
self.db.setexectrace(exec_factory)
self.execute(self.CREATE_TABLES_QUERY)
register_canonical_functions(self.db) register_canonical_functions(self.db)
register_trending_functions(self.db) register_trending_functions(self.db)
def close(self): def close(self):
self.db.close() if self.db is not None:
self.db.close()
@staticmethod @staticmethod
def _insert_sql(table: str, data: dict) -> Tuple[str, list]: def _insert_sql(table: str, data: dict) -> Tuple[str, list]:
@ -213,7 +224,10 @@ class SQLDB:
return f"DELETE FROM {table} WHERE {where}", values return f"DELETE FROM {table} WHERE {where}", values
def execute(self, *args): def execute(self, *args):
return self.db.execute(*args) return self.db.cursor().execute(*args)
def executemany(self, *args):
return self.db.cursor().executemany(*args)
def begin(self): def begin(self):
self.execute('begin;') self.execute('begin;')
@ -222,7 +236,7 @@ class SQLDB:
self.execute('commit;') self.execute('commit;')
def _upsertable_claims(self, txos: List[Output], header, clear_first=False): def _upsertable_claims(self, txos: List[Output], header, clear_first=False):
claim_hashes, claims, tags = [], [], {} claim_hashes, claims, tags = set(), [], {}
for txo in txos: for txo in txos:
tx = txo.tx_ref.tx tx = txo.tx_ref.tx
@ -233,14 +247,14 @@ class SQLDB:
#self.logger.exception(f"Could not decode claim name for {tx.id}:{txo.position}.") #self.logger.exception(f"Could not decode claim name for {tx.id}:{txo.position}.")
continue continue
claim_hash = sqlite3.Binary(txo.claim_hash) claim_hash = txo.claim_hash
claim_hashes.append(claim_hash) claim_hashes.add(claim_hash)
claim_record = { claim_record = {
'claim_hash': claim_hash, 'claim_hash': claim_hash,
'claim_id': txo.claim_id, 'claim_id': txo.claim_id,
'claim_name': txo.claim_name, 'claim_name': txo.claim_name,
'normalized': txo.normalized_name, 'normalized': txo.normalized_name,
'txo_hash': sqlite3.Binary(txo.ref.hash), 'txo_hash': txo.ref.hash,
'tx_position': tx.position, 'tx_position': tx.position,
'amount': txo.amount, 'amount': txo.amount,
'timestamp': header['timestamp'], 'timestamp': header['timestamp'],
@ -291,7 +305,7 @@ class SQLDB:
self._clear_claim_metadata(claim_hashes) self._clear_claim_metadata(claim_hashes)
if tags: if tags:
self.db.executemany( self.executemany(
"INSERT OR IGNORE INTO tag (tag, claim_hash, height) VALUES (?, ?, ?)", tags.values() "INSERT OR IGNORE INTO tag (tag, claim_hash, height) VALUES (?, ?, ?)", tags.values()
) )
@ -300,7 +314,7 @@ class SQLDB:
def insert_claims(self, txos: List[Output], header): def insert_claims(self, txos: List[Output], header):
claims = self._upsertable_claims(txos, header) claims = self._upsertable_claims(txos, header)
if claims: if claims:
self.db.executemany(""" self.executemany("""
INSERT OR IGNORE INTO claim ( INSERT OR IGNORE INTO claim (
claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position, amount, claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position, amount,
claim_type, media_type, stream_type, timestamp, creation_timestamp, claim_type, media_type, stream_type, timestamp, creation_timestamp,
@ -322,7 +336,7 @@ class SQLDB:
def update_claims(self, txos: List[Output], header): def update_claims(self, txos: List[Output], header):
claims = self._upsertable_claims(txos, header, clear_first=True) claims = self._upsertable_claims(txos, header, clear_first=True)
if claims: if claims:
self.db.executemany(""" self.executemany("""
UPDATE claim SET UPDATE claim SET
txo_hash=:txo_hash, tx_position=:tx_position, amount=:amount, height=:height, txo_hash=:txo_hash, tx_position=:tx_position, amount=:amount, height=:height,
claim_type=:claim_type, media_type=:media_type, stream_type=:stream_type, claim_type=:claim_type, media_type=:media_type, stream_type=:stream_type,
@ -335,35 +349,32 @@ class SQLDB:
def delete_claims(self, claim_hashes: Set[bytes]): def delete_claims(self, claim_hashes: Set[bytes]):
""" Deletes claim supports and from claimtrie in case of an abandon. """ """ Deletes claim supports and from claimtrie in case of an abandon. """
if claim_hashes: if claim_hashes:
binary_claim_hashes = [sqlite3.Binary(claim_hash) for claim_hash in claim_hashes]
affected_channels = self.execute(*query( affected_channels = self.execute(*query(
"SELECT channel_hash FROM claim", channel_hash__is_not_null=1, claim_hash__in=binary_claim_hashes "SELECT channel_hash FROM claim", channel_hash__is_not_null=1, claim_hash__in=claim_hashes
)).fetchall() )).fetchall()
for table in ('claim', 'support', 'claimtrie'): for table in ('claim', 'support', 'claimtrie'):
self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) self.execute(*self._delete_sql(table, {'claim_hash__in': claim_hashes}))
self._clear_claim_metadata(binary_claim_hashes) self._clear_claim_metadata(claim_hashes)
return {r['channel_hash'] for r in affected_channels} return {r.channel_hash for r in affected_channels}
return set() return set()
def _clear_claim_metadata(self, binary_claim_hashes: List[sqlite3.Binary]): def _clear_claim_metadata(self, claim_hashes: Set[bytes]):
if binary_claim_hashes: if claim_hashes:
for table in ('tag',): # 'language', 'location', etc for table in ('tag',): # 'language', 'location', etc
self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) self.execute(*self._delete_sql(table, {'claim_hash__in': claim_hashes}))
def split_inputs_into_claims_supports_and_other(self, txis): def split_inputs_into_claims_supports_and_other(self, txis):
txo_hashes = {txi.txo_ref.hash for txi in txis} txo_hashes = {txi.txo_ref.hash for txi in txis}
claims = self.execute(*query( claims = self.execute(*query(
"SELECT txo_hash, claim_hash, normalized FROM claim", "SELECT txo_hash, claim_hash, normalized FROM claim", txo_hash__in=txo_hashes
txo_hash__in=[sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]
)).fetchall() )).fetchall()
txo_hashes -= {r['txo_hash'] for r in claims} txo_hashes -= {r.txo_hash for r in claims}
supports = {} supports = {}
if txo_hashes: if txo_hashes:
supports = self.execute(*query( supports = self.execute(*query(
"SELECT txo_hash, claim_hash FROM support", "SELECT txo_hash, claim_hash FROM support", txo_hash__in=txo_hashes
txo_hash__in=[sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]
)).fetchall() )).fetchall()
txo_hashes -= {r['txo_hash'] for r in supports} txo_hashes -= {r.txo_hash for r in supports}
return claims, supports, txo_hashes return claims, supports, txo_hashes
def insert_supports(self, txos: List[Output]): def insert_supports(self, txos: List[Output]):
@ -371,11 +382,11 @@ class SQLDB:
for txo in txos: for txo in txos:
tx = txo.tx_ref.tx tx = txo.tx_ref.tx
supports.append(( supports.append((
sqlite3.Binary(txo.ref.hash), tx.position, tx.height, txo.ref.hash, tx.position, tx.height,
sqlite3.Binary(txo.claim_hash), txo.amount txo.claim_hash, txo.amount
)) ))
if supports: if supports:
self.db.executemany( self.executemany(
"INSERT OR IGNORE INTO support (" "INSERT OR IGNORE INTO support ("
" txo_hash, tx_position, height, claim_hash, amount" " txo_hash, tx_position, height, claim_hash, amount"
") " ") "
@ -384,9 +395,7 @@ class SQLDB:
def delete_supports(self, txo_hashes: Set[bytes]): def delete_supports(self, txo_hashes: Set[bytes]):
if txo_hashes: if txo_hashes:
self.execute(*self._delete_sql( self.execute(*self._delete_sql('support', {'txo_hash__in': txo_hashes}))
'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]}
))
def calculate_reposts(self, txos: List[Output]): def calculate_reposts(self, txos: List[Output]):
targets = set() targets = set()
@ -398,7 +407,7 @@ class SQLDB:
if claim.is_repost: if claim.is_repost:
targets.add((claim.repost.reference.claim_hash,)) targets.add((claim.repost.reference.claim_hash,))
if targets: if targets:
self.db.executemany( self.executemany(
""" """
UPDATE claim SET reposted = ( UPDATE claim SET reposted = (
SELECT count(*) FROM claim AS repost WHERE repost.reposted_claim_hash = claim.claim_hash SELECT count(*) FROM claim AS repost WHERE repost.reposted_claim_hash = claim.claim_hash
@ -441,10 +450,7 @@ class SQLDB:
if new_channel_keys or missing_channel_keys or affected_channels: if new_channel_keys or missing_channel_keys or affected_channels:
all_channel_keys = dict(self.execute(*query( all_channel_keys = dict(self.execute(*query(
"SELECT claim_hash, public_key_bytes FROM claim", "SELECT claim_hash, public_key_bytes FROM claim",
claim_hash__in=[ claim_hash__in=set(new_channel_keys) | missing_channel_keys | affected_channels
sqlite3.Binary(channel_hash) for channel_hash in
set(new_channel_keys) | missing_channel_keys | affected_channels
]
))) )))
sub_timer.stop() sub_timer.stop()
@ -461,7 +467,7 @@ class SQLDB:
for claim_hash, txo in signables.items(): for claim_hash, txo in signables.items():
claim = txo.claim claim = txo.claim
update = { update = {
'claim_hash': sqlite3.Binary(claim_hash), 'claim_hash': claim_hash,
'channel_hash': None, 'channel_hash': None,
'signature': None, 'signature': None,
'signature_digest': None, 'signature_digest': None,
@ -469,9 +475,9 @@ class SQLDB:
} }
if claim.is_signed: if claim.is_signed:
update.update({ update.update({
'channel_hash': sqlite3.Binary(claim.signing_channel_hash), 'channel_hash': claim.signing_channel_hash,
'signature': sqlite3.Binary(txo.get_encoded_signature()), 'signature': txo.get_encoded_signature(),
'signature_digest': sqlite3.Binary(txo.get_signature_digest(self.ledger)), 'signature_digest': txo.get_signature_digest(self.ledger),
'signature_valid': 0 'signature_valid': 0
}) })
claim_updates.append(update) claim_updates.append(update)
@ -485,13 +491,13 @@ class SQLDB:
channel_hash IN ({','.join('?' for _ in changed_channel_keys)}) AND channel_hash IN ({','.join('?' for _ in changed_channel_keys)}) AND
signature IS NOT NULL signature IS NOT NULL
""" """
for affected_claim in self.execute(sql, [sqlite3.Binary(h) for h in changed_channel_keys]): for affected_claim in self.execute(sql, changed_channel_keys):
if affected_claim['claim_hash'] not in signables: if affected_claim.claim_hash not in signables:
claim_updates.append({ claim_updates.append({
'claim_hash': sqlite3.Binary(affected_claim['claim_hash']), 'claim_hash': affected_claim.claim_hash,
'channel_hash': sqlite3.Binary(affected_claim['channel_hash']), 'channel_hash': affected_claim.channel_hash,
'signature': sqlite3.Binary(affected_claim['signature']), 'signature': affected_claim.signature,
'signature_digest': sqlite3.Binary(affected_claim['signature_digest']), 'signature_digest': affected_claim.signature_digest,
'signature_valid': 0 'signature_valid': 0
}) })
sub_timer.stop() sub_timer.stop()
@ -509,7 +515,7 @@ class SQLDB:
sub_timer = timer.add_timer('update claims') sub_timer = timer.add_timer('update claims')
sub_timer.start() sub_timer.start()
if claim_updates: if claim_updates:
self.db.executemany(f""" self.executemany(f"""
UPDATE claim SET UPDATE claim SET
channel_hash=:channel_hash, signature=:signature, signature_digest=:signature_digest, channel_hash=:channel_hash, signature=:signature, signature_digest=:signature_digest,
signature_valid=:signature_valid, signature_valid=:signature_valid,
@ -542,24 +548,24 @@ class SQLDB:
signature_valid=CASE WHEN signature IS NOT NULL THEN 0 END, signature_valid=CASE WHEN signature IS NOT NULL THEN 0 END,
channel_join=NULL, canonical_url=NULL channel_join=NULL, canonical_url=NULL
WHERE channel_hash IN ({','.join('?' for _ in spent_claims)}) WHERE channel_hash IN ({','.join('?' for _ in spent_claims)})
""", [sqlite3.Binary(cid) for cid in spent_claims] """, spent_claims
) )
sub_timer.stop() sub_timer.stop()
sub_timer = timer.add_timer('update channels') sub_timer = timer.add_timer('update channels')
sub_timer.start() sub_timer.start()
if channels: if channels:
self.db.executemany( self.executemany(
""" """
UPDATE claim SET UPDATE claim SET
public_key_bytes=:public_key_bytes, public_key_bytes=:public_key_bytes,
public_key_hash=:public_key_hash public_key_hash=:public_key_hash
WHERE claim_hash=:claim_hash""", [{ WHERE claim_hash=:claim_hash""", [{
'claim_hash': sqlite3.Binary(claim_hash), 'claim_hash': claim_hash,
'public_key_bytes': sqlite3.Binary(txo.claim.channel.public_key_bytes), 'public_key_bytes': txo.claim.channel.public_key_bytes,
'public_key_hash': sqlite3.Binary( 'public_key_hash': self.ledger.address_to_hash160(
self.ledger.address_to_hash160( self.ledger.public_key_to_address(txo.claim.channel.public_key_bytes)
self.ledger.public_key_to_address(txo.claim.channel.public_key_bytes))) )
} for claim_hash, txo in channels.items()] } for claim_hash, txo in channels.items()]
) )
sub_timer.stop() sub_timer.stop()
@ -567,7 +573,7 @@ class SQLDB:
sub_timer = timer.add_timer('update claims_in_channel counts') sub_timer = timer.add_timer('update claims_in_channel counts')
sub_timer.start() sub_timer.start()
if all_channel_keys: if all_channel_keys:
self.db.executemany(f""" self.executemany(f"""
UPDATE claim SET UPDATE claim SET
claims_in_channel=( claims_in_channel=(
SELECT COUNT(*) FROM claim AS claim_in_channel SELECT COUNT(*) FROM claim AS claim_in_channel
@ -575,7 +581,7 @@ class SQLDB:
claim_in_channel.channel_hash=claim.claim_hash claim_in_channel.channel_hash=claim.claim_hash
) )
WHERE claim_hash = ? WHERE claim_hash = ?
""", [(sqlite3.Binary(channel_hash),) for channel_hash in all_channel_keys.keys()]) """, [(channel_hash,) for channel_hash in all_channel_keys.keys()])
sub_timer.stop() sub_timer.stop()
def _update_support_amount(self, claim_hashes): def _update_support_amount(self, claim_hashes):
@ -623,7 +629,7 @@ class SQLDB:
overtakes = self.execute(f""" overtakes = self.execute(f"""
SELECT winner.normalized, winner.claim_hash, SELECT winner.normalized, winner.claim_hash,
claimtrie.claim_hash AS current_winner, claimtrie.claim_hash AS current_winner,
MAX(winner.effective_amount) MAX(winner.effective_amount) AS max_winner_effective_amount
FROM ( FROM (
SELECT normalized, claim_hash, effective_amount FROM claim SELECT normalized, claim_hash, effective_amount FROM claim
WHERE normalized IN ( WHERE normalized IN (
@ -635,22 +641,22 @@ class SQLDB:
HAVING current_winner IS NULL OR current_winner <> winner.claim_hash HAVING current_winner IS NULL OR current_winner <> winner.claim_hash
""", changed_claim_hashes+deleted_names) """, changed_claim_hashes+deleted_names)
for overtake in overtakes: for overtake in overtakes:
if overtake['current_winner']: if overtake.current_winner:
self.execute( self.execute(
f"UPDATE claimtrie SET claim_hash = ?, last_take_over_height = {height} " f"UPDATE claimtrie SET claim_hash = ?, last_take_over_height = {height} "
f"WHERE normalized = ?", f"WHERE normalized = ?",
(sqlite3.Binary(overtake['claim_hash']), overtake['normalized']) (overtake.claim_hash, overtake.normalized)
) )
else: else:
self.execute( self.execute(
f"INSERT INTO claimtrie (claim_hash, normalized, last_take_over_height) " f"INSERT INTO claimtrie (claim_hash, normalized, last_take_over_height) "
f"VALUES (?, ?, {height})", f"VALUES (?, ?, {height})",
(sqlite3.Binary(overtake['claim_hash']), overtake['normalized']) (overtake.claim_hash, overtake.normalized)
) )
self.execute( self.execute(
f"UPDATE claim SET activation_height = {height} WHERE normalized = ? " f"UPDATE claim SET activation_height = {height} WHERE normalized = ? "
f"AND (activation_height IS NULL OR activation_height > {height})", f"AND (activation_height IS NULL OR activation_height > {height})",
(overtake['normalized'],) (overtake.normalized,)
) )
def _copy(self, height): def _copy(self, height):
@ -660,15 +666,12 @@ class SQLDB:
def update_claimtrie(self, height, changed_claim_hashes, deleted_names, timer): def update_claimtrie(self, height, changed_claim_hashes, deleted_names, timer):
r = timer.run r = timer.run
binary_claim_hashes = [
sqlite3.Binary(claim_hash) for claim_hash in changed_claim_hashes
]
r(self._calculate_activation_height, height) r(self._calculate_activation_height, height)
r(self._update_support_amount, binary_claim_hashes) r(self._update_support_amount, changed_claim_hashes)
r(self._update_effective_amount, height, binary_claim_hashes) r(self._update_effective_amount, height, changed_claim_hashes)
r(self._perform_overtake, height, binary_claim_hashes, list(deleted_names)) r(self._perform_overtake, height, changed_claim_hashes, list(deleted_names))
r(self._update_effective_amount, height) r(self._update_effective_amount, height)
r(self._perform_overtake, height, [], []) r(self._perform_overtake, height, [], [])
@ -697,10 +700,10 @@ class SQLDB:
self.split_inputs_into_claims_supports_and_other, tx.inputs self.split_inputs_into_claims_supports_and_other, tx.inputs
) )
body_timer.start() body_timer.start()
delete_claim_hashes.update({r['claim_hash'] for r in spent_claims}) delete_claim_hashes.update({r.claim_hash for r in spent_claims})
deleted_claim_names.update({r['normalized'] for r in spent_claims}) deleted_claim_names.update({r.normalized for r in spent_claims})
delete_support_txo_hashes.update({r['txo_hash'] for r in spent_supports}) delete_support_txo_hashes.update({r.txo_hash for r in spent_supports})
recalculate_claim_hashes.update({r['claim_hash'] for r in spent_supports}) recalculate_claim_hashes.update({r.claim_hash for r in spent_supports})
delete_others.update(spent_others) delete_others.update(spent_others)
# Outputs # Outputs
for output in tx.outputs: for output in tx.outputs:
@ -728,31 +731,31 @@ class SQLDB:
expire_timer = timer.add_timer('recording expired claims') expire_timer = timer.add_timer('recording expired claims')
expire_timer.start() expire_timer.start()
for expired in self.get_expiring(height): for expired in self.get_expiring(height):
delete_claim_hashes.add(expired['claim_hash']) delete_claim_hashes.add(expired.claim_hash)
deleted_claim_names.add(expired['normalized']) deleted_claim_names.add(expired.normalized)
expire_timer.stop() expire_timer.stop()
r = timer.run r = timer.run
r(update_full_text_search, 'before-delete', r(update_full_text_search, 'before-delete',
delete_claim_hashes, self.db, self.main.first_sync) delete_claim_hashes, self.db.cursor(), self.main.first_sync)
affected_channels = r(self.delete_claims, delete_claim_hashes) affected_channels = r(self.delete_claims, delete_claim_hashes)
r(self.delete_supports, delete_support_txo_hashes) r(self.delete_supports, delete_support_txo_hashes)
r(self.insert_claims, insert_claims, header) r(self.insert_claims, insert_claims, header)
r(self.calculate_reposts, insert_claims) r(self.calculate_reposts, insert_claims)
r(update_full_text_search, 'after-insert', r(update_full_text_search, 'after-insert',
[txo.claim_hash for txo in insert_claims], self.db, self.main.first_sync) [txo.claim_hash for txo in insert_claims], self.db.cursor(), self.main.first_sync)
r(update_full_text_search, 'before-update', r(update_full_text_search, 'before-update',
[txo.claim_hash for txo in update_claims], self.db, self.main.first_sync) [txo.claim_hash for txo in update_claims], self.db.cursor(), self.main.first_sync)
r(self.update_claims, update_claims, header) r(self.update_claims, update_claims, header)
r(update_full_text_search, 'after-update', r(update_full_text_search, 'after-update',
[txo.claim_hash for txo in update_claims], self.db, self.main.first_sync) [txo.claim_hash for txo in update_claims], self.db.cursor(), self.main.first_sync)
r(self.validate_channel_signatures, height, insert_claims, r(self.validate_channel_signatures, height, insert_claims,
update_claims, delete_claim_hashes, affected_channels, forward_timer=True) update_claims, delete_claim_hashes, affected_channels, forward_timer=True)
r(self.insert_supports, insert_supports) r(self.insert_supports, insert_supports)
r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
r(calculate_trending, self.db, height, daemon_height) r(calculate_trending, self.db.cursor(), height, daemon_height)
if not self._fts_synced and self.main.first_sync and height == daemon_height: if not self._fts_synced and self.main.first_sync and height == daemon_height:
r(first_sync_finished, self.db) r(first_sync_finished, self.db.cursor())
self._fts_synced = True self._fts_synced = True

View file

@ -38,10 +38,10 @@ class TestSQLDB(unittest.TestCase):
db_url = 'file:test_sqldb?mode=memory&cache=shared' db_url = 'file:test_sqldb?mode=memory&cache=shared'
self.sql = writer.SQLDB(self, db_url) self.sql = writer.SQLDB(self, db_url)
self.addCleanup(self.sql.close) self.addCleanup(self.sql.close)
self.sql.open()
reader.initializer(logging.getLogger(__name__), db_url, 'regtest', self.query_timeout) reader.initializer(logging.getLogger(__name__), db_url, 'regtest', self.query_timeout)
self.addCleanup(reader.cleanup) self.addCleanup(reader.cleanup)
self.timer = Timer('BlockProcessor') self.timer = Timer('BlockProcessor')
self.sql.open()
self._current_height = 0 self._current_height = 0
self._txos = {} self._txos = {}
@ -113,8 +113,8 @@ class TestSQLDB(unittest.TestCase):
def get_controlling(self): def get_controlling(self):
for claim in self.sql.execute("select claim.* from claimtrie natural join claim"): for claim in self.sql.execute("select claim.* from claimtrie natural join claim"):
txo = self._txos[claim['txo_hash']] txo = self._txos[claim.txo_hash]
controlling = txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'] controlling = txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height
return controlling return controlling
def get_active(self): def get_active(self):
@ -122,18 +122,18 @@ class TestSQLDB(unittest.TestCase):
active = [] active = []
for claim in self.sql.execute( for claim in self.sql.execute(
f"select * from claim where activation_height <= {self._current_height}"): f"select * from claim where activation_height <= {self._current_height}"):
txo = self._txos[claim['txo_hash']] txo = self._txos[claim.txo_hash]
if controlling and controlling[0] == txo.claim.stream.title: if controlling and controlling[0] == txo.claim.stream.title:
continue continue
active.append((txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'])) active.append((txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height))
return active return active
def get_accepted(self): def get_accepted(self):
accepted = [] accepted = []
for claim in self.sql.execute( for claim in self.sql.execute(
f"select * from claim where activation_height > {self._current_height}"): f"select * from claim where activation_height > {self._current_height}"):
txo = self._txos[claim['txo_hash']] txo = self._txos[claim.txo_hash]
accepted.append((txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'])) accepted.append((txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height))
return accepted return accepted
def advance(self, height, txs): def advance(self, height, txs):