lbry-sdk/lbry/db/query_context.py

575 lines
20 KiB
Python
Raw Normal View History

2020-06-05 06:35:22 +02:00
import os
import time
import multiprocessing as mp
from enum import Enum
from decimal import Decimal
2020-06-19 20:28:34 +02:00
from typing import Dict, List, Optional, Tuple
2020-06-05 06:35:22 +02:00
from dataclasses import dataclass
from contextvars import ContextVar
2020-06-22 02:14:14 +02:00
from sqlalchemy import create_engine, inspect, bindparam
2020-06-05 06:35:22 +02:00
from sqlalchemy.engine import Engine, Connection
from lbry.event import EventQueuePublisher
from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction, Output, Input
from lbry.schema.tags import clean_tags
from lbry.schema.result import Censor
from lbry.schema.mime_types import guess_stream_type
from .utils import pg_insert, chunk
2020-06-26 16:39:16 +02:00
from .tables import Block, TX, TXO, TXI, Claim, Tag, Support
2020-06-05 06:35:22 +02:00
from .constants import TXO_TYPES, STREAM_TYPES
# Context-local slot holding the active QueryContext; set by initialize(),
# read by context() and cleared by uninitialize().
_context: ContextVar['QueryContext'] = ContextVar('_context')


@dataclass
class QueryContext:
    """Database connection plus per-process sync state used by the db helpers.

    The instance is also a re-usable timing context manager:
    ``with ctx.with_timer(name): ...`` prints the elapsed time on exit when
    ``name`` is listed in ``print_timers``.
    """

    engine: Engine
    connection: Connection
    ledger: Ledger
    message_queue: mp.Queue
    stop_event: mp.Event
    stack: List[List]
    metrics: Dict
    is_tracking_metrics: bool
    blocked_streams: Dict
    blocked_channels: Dict
    filtered_streams: Dict
    filtered_channels: Dict
    pid: int

    # QueryContext __enter__/__exit__ state
    print_timers: List
    current_timer_name: Optional[str] = None
    current_timer_time: float = 0
    current_progress: Optional['ProgressContext'] = None

    @property
    def is_postgres(self):
        """True when the connection's dialect is PostgreSQL."""
        return self.connection.dialect.name == 'postgresql'

    @property
    def is_sqlite(self):
        """True when the connection's dialect is SQLite."""
        return self.connection.dialect.name == 'sqlite'

    def raise_unsupported_dialect(self):
        """Raise RuntimeError naming the connection's unsupported dialect."""
        raise RuntimeError(f'Unsupported database dialect: {self.connection.dialect.name}.')

    def get_resolve_censor(self) -> Censor:
        """Return a Censor built from the blocked streams and channels."""
        return Censor(self.blocked_streams, self.blocked_channels)

    def get_search_censor(self) -> Censor:
        """Return a Censor built from the filtered streams and channels."""
        return Censor(self.filtered_streams, self.filtered_channels)

    def execute(self, sql, *args):
        """Execute ``sql`` on the bound connection and return the raw result."""
        return self.connection.execute(sql, *args)

    def fetchone(self, sql, *args):
        """Execute ``sql`` and return the first row as a dict (None if no row)."""
        row = self.connection.execute(sql, *args).fetchone()
        return dict(row._mapping) if row else row

    def fetchall(self, sql, *args):
        """Execute ``sql`` and return every row as a dict."""
        rows = self.connection.execute(sql, *args).fetchall()
        return [dict(row._mapping) for row in rows]

    def insert_or_ignore(self, table):
        """Return an INSERT for ``table`` that skips rows conflicting with existing ones.

        :raises RuntimeError: for dialects other than SQLite and PostgreSQL.
        """
        if self.is_sqlite:
            return table.insert().prefix_with("OR IGNORE")
        elif self.is_postgres:
            return pg_insert(table).on_conflict_do_nothing()
        else:
            self.raise_unsupported_dialect()

    def insert_or_replace(self, table, replace):
        """Return an INSERT for ``table`` that overwrites conflicting rows.

        On PostgreSQL a primary-key conflict updates only the columns named in
        ``replace``; SQLite's ``INSERT OR REPLACE`` replaces the conflicting
        row and does not use ``replace``.

        :raises RuntimeError: for dialects other than SQLite and PostgreSQL.
        """
        if self.is_sqlite:
            return table.insert().prefix_with("OR REPLACE")
        elif self.is_postgres:
            insert = pg_insert(table)
            return insert.on_conflict_do_update(
                table.primary_key, set_={col: getattr(insert.excluded, col) for col in replace}
            )
        else:
            self.raise_unsupported_dialect()

    def has_table(self, table):
        """Return whether the database has ``table`` (presumably a table name)."""
        return inspect(self.engine).has_table(table)

    def get_bulk_loader(self) -> 'BulkLoader':
        """Return a new BulkLoader that writes through this context."""
        return BulkLoader(self)

    def reset_metrics(self):
        """Clear the metrics stack and collected metrics."""
        self.stack = []
        self.metrics = {}

    def with_timer(self, timer_name: str) -> 'QueryContext':
        """Set the timer name reported by the next ``with`` block; returns self."""
        self.current_timer_name = timer_name
        return self

    def __enter__(self) -> 'QueryContext':
        self.current_timer_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.current_timer_name and self.current_timer_name in self.print_timers:
            elapsed = time.perf_counter() - self.current_timer_time
            # Report the timer that just finished, not the list of enabled timers.
            print(f"{self.current_timer_name} in {elapsed:.6f}s", flush=True)
        # Reset timer/progress state so the instance can be re-entered.
        self.current_timer_name = None
        self.current_timer_time = 0
        self.current_progress = None
def context(with_timer: str = None) -> 'QueryContext':
    """Return the QueryContext bound by ``initialize()`` in the current context.

    When ``with_timer`` is a string it becomes the active timer name on the
    returned context, so the result can be used directly in a ``with``
    statement to time a block.

    :raises LookupError: if ``initialize()`` has not been called.
    """
    ctx = _context.get()
    if not isinstance(with_timer, str):
        return ctx
    return ctx.with_timer(with_timer)
def initialize(
        ledger: Ledger, message_queue: mp.Queue, stop_event: mp.Event,
        track_metrics=False, block_and_filter=None, print_timers=None):
    """Open a database connection and bind a new QueryContext to the current context.

    :param ledger: ledger whose ``conf.db_url_or_default`` selects the database.
    :param message_queue: queue that progress messages are put on.
    :param stop_event: stored on the context as ``stop_event``.
    :param track_metrics: stored as ``QueryContext.is_tracking_metrics``.
    :param block_and_filter: optional 4-tuple
        ``(blocked_streams, blocked_channels, filtered_streams, filtered_channels)``.
    :param print_timers: names of timers whose elapsed time should be printed.
    """
    url = ledger.conf.db_url_or_default
    engine = create_engine(url)
    connection = engine.connect()
    if block_and_filter is not None:
        blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter
    else:
        # Four independent dicts: a chained ``a = b = ... = {}`` would alias them,
        # so mutating one censor mapping would silently change all the others.
        blocked_streams, blocked_channels = {}, {}
        filtered_streams, filtered_channels = {}, {}
    _context.set(
        QueryContext(
            pid=os.getpid(),
            engine=engine, connection=connection,
            ledger=ledger, message_queue=message_queue, stop_event=stop_event,
            stack=[], metrics={}, is_tracking_metrics=track_metrics,
            blocked_streams=blocked_streams, blocked_channels=blocked_channels,
            filtered_streams=filtered_streams, filtered_channels=filtered_channels,
            print_timers=print_timers or []
        )
    )
def uninitialize():
    """Close the bound connection (if any) and clear the context variable.

    Does nothing when no QueryContext has been bound in this context.
    """
    ctx = _context.get(None)
    if ctx is None:
        return
    if ctx.connection:
        ctx.connection.close()
    _context.set(None)
class ProgressUnit(Enum):
    """Unit in which a progress event's step/total are counted.

    Each member is declared as ``(label, table)``: ``label`` is the unit name
    included in published progress data, and ``table`` is the database table
    whose queued rows ``BulkLoader.save`` counts toward the progress total
    (``None`` when the unit is not tied to a table).
    """
    # NOTE: __new__ assigns member values 1, 2, 3, ... in declaration order,
    # so inserting or reordering members changes every later value.
    NONE = "", None
    TASKS = "tasks", None
    BLOCKS = "blocks", Block
    TXS = "txs", TX
    TXIS = "txis", TXI
    CLAIMS = "claims", Claim
    SUPPORTS = "supports", Support

    def __new__(cls, value, table):
        # Number members by position so members sharing the same table
        # (e.g. NONE and TASKS both use None) still get distinct values.
        next_id = len(cls.__members__) + 1
        obj = object.__new__(cls)
        obj._value_ = next_id
        obj.label = value
        obj.table = table
        return obj
class Event(Enum):
    """Progress events published through ``QueryContext.message_queue``.

    Each member is declared as ``(label, unit)``: ``label`` is the event name
    sent to listeners and ``unit`` is the ProgressUnit that step/total count.
    """
    # NOTE: like ProgressUnit, __new__ numbers members 1, 2, 3, ... in
    # declaration order. That integer (``Event.X.value``) is what
    # ProgressContext puts on the queue and what ProgressPublisher decodes
    # with ``Event(value)``, so reordering members changes the message encoding.

    START = "blockchain.sync.start", ProgressUnit.BLOCKS
    COMPLETE = "blockchain.sync.complete", ProgressUnit.BLOCKS

    # full node specific sync events
    BLOCK_READ = "blockchain.sync.block.read", ProgressUnit.BLOCKS
    BLOCK_SAVE = "blockchain.sync.block.save", ProgressUnit.TXS
    BLOCK_FILTER = "blockchain.sync.block.filter", ProgressUnit.BLOCKS
    CLAIM_META = "blockchain.sync.claim.update", ProgressUnit.CLAIMS
    CLAIM_TRIE = "blockchain.sync.claim.takeovers", ProgressUnit.CLAIMS
    STAKE_CALC = "blockchain.sync.claim.stakes", ProgressUnit.CLAIMS
    CLAIM_CHAN = "blockchain.sync.claim.channels", ProgressUnit.CLAIMS
    CLAIM_SIGN = "blockchain.sync.claim.signatures", ProgressUnit.CLAIMS
    SUPPORT_SIGN = "blockchain.sync.support.signatures", ProgressUnit.SUPPORTS
    TRENDING_CALC = "blockchain.sync.trending", ProgressUnit.BLOCKS

    # full node + light client sync events
    INPUT_UPDATE = "db.sync.input", ProgressUnit.TXIS
    CLAIM_DELETE = "db.sync.claim.delete", ProgressUnit.CLAIMS
    CLAIM_INSERT = "db.sync.claim.insert", ProgressUnit.CLAIMS
    CLAIM_UPDATE = "db.sync.claim.update", ProgressUnit.CLAIMS
    SUPPORT_DELETE = "db.sync.support.delete", ProgressUnit.SUPPORTS
    SUPPORT_INSERT = "db.sync.support.insert", ProgressUnit.SUPPORTS

    def __new__(cls, value, unit: ProgressUnit):
        # Auto-number members in declaration order; keep label/unit as attributes.
        next_id = len(cls.__members__) + 1
        obj = object.__new__(cls)
        obj._value_ = next_id
        obj.label = value
        obj.unit = unit
        return obj
class ProgressPublisher(EventQueuePublisher):
    """Turns raw progress tuples taken off the message queue into event dicts."""

    def message_to_event(self, message):
        """Convert ``(event_value, pid, step, total[, extra])`` into an event dict.

        The result has the shape ``{"event": label, "data": {...}}``. A trailing
        ``extra`` element is merged into ``data`` only when it is a dict.
        """
        event = Event(message[0])  # pylint: disable=no-value-for-parameter
        pid, step, total = message[1], message[2], message[3]
        data = {
            "pid": pid,
            "step": step,
            "total": total,
            "unit": event.unit.label,
        }
        extra = message[4] if len(message) > 4 else None
        if isinstance(extra, dict):
            data.update(extra)
        return {"event": event.label, "data": data}
2020-06-19 20:28:34 +02:00
class BreakProgress(Exception):
    """Raised by ``ProgressContext.start`` when the total is 0.

    ``ProgressContext.__exit__`` suppresses it, so it works as an early exit
    from a ``with progress(...)`` block that has nothing to do.
    """


class ProgressContext:
    """Context manager that reports step/total progress for one ``Event``.

    Messages are the tuples built by :meth:`get_event_args`, put on the query
    context's ``message_queue``. Entering and exiting also drive the query
    context's own timer (``QueryContext.__enter__``/``__exit__``).
    """

    def __init__(self, ctx: QueryContext, event: Event, step_size=1):
        self.ctx = ctx
        self.event = event
        self.step_size = step_size
        self.extra = None
        self.total = 0
        self.last_step = -1

    def __enter__(self) -> 'ProgressContext':
        self.ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always publish a final "done == total" message, including on early exit.
        final_message = self.get_event_args(self.total)
        self.ctx.message_queue.put(final_message)
        self.ctx.__exit__(exc_type, exc_val, exc_tb)
        # Swallow BreakProgress so a zero-sized job simply ends the ``with`` block.
        return True if exc_type == BreakProgress else None

    def start(self, total, extra=None):
        """Begin reporting against ``total`` units and publish the 0 step.

        :raises BreakProgress: when ``total`` is falsy.
        """
        if not total:
            raise BreakProgress
        self.total = total
        if extra is not None:
            self.extra = extra
        self.step(0)

    def step(self, done):
        """Publish ``done`` unless throttled, final, or a repeat of the last step."""
        if self.step_size != 1 and done % self.step_size:
            return  # not on a step_size boundary
        if not done < self.total:
            return  # the final step is published by __exit__
        if done == self.last_step:
            return  # already published
        self.ctx.message_queue.put_nowait(self.get_event_args(done))
        self.last_step = done

    def get_event_args(self, done):
        """Build the queue message ``(event_value, pid, done, total[, extra])``."""
        args = (self.event.value, self.ctx.pid, done, self.total)
        if self.extra is None:
            return args
        return args + (self.extra,)
def progress(e: Event, step_size=1) -> ProgressContext:
    """Create a ProgressContext for ``e`` and register it on the current QueryContext.

    The event label doubles as the timer name, so the block is timed when that
    label is listed in ``print_timers``.
    """
    ctx = context(e.label)
    tracker = ProgressContext(ctx, e, step_size=step_size)
    ctx.current_progress = tracker
    return tracker
class BulkLoader:
    """Accumulate rows for blocks, transactions, claims, supports and tags, then
    write them to the database in batches with :meth:`save`.

    The ``add_*`` and ``update_claim`` methods only build row dicts in memory
    and return ``self`` so calls can be chained; nothing touches the database
    until :meth:`save` runs.
    """

    def __init__(self, ctx: QueryContext):
        self.ctx = ctx
        self.ledger = ctx.ledger
        self.blocks = []
        self.txs = []
        self.txos = []
        self.txis = []
        self.supports = []
        self.claims = []
        self.tags = []
        self.update_claims = []
        self.delete_tags = []

    @staticmethod
    def block_to_row(block: Block) -> dict:
        """Return the ``Block`` table row for a parsed block.

        NOTE(review): ``Block`` here is the table imported from ``.tables``, but
        the argument is a parsed block object (it has ``block_hash``, ``txs``,
        ...), so the annotation looks inaccurate.
        """
        return {
            'block_hash': block.block_hash,
            'previous_hash': block.prev_block_hash,
            'file_number': block.file_number,
            # the first block is always stored at height 0
            'height': 0 if block.is_first_block else block.height,
            'timestamp': block.timestamp,
        }

    @staticmethod
    def tx_to_row(block_hash: bytes, tx: Transaction) -> dict:
        """Return the ``TX`` table row for ``tx`` included in ``block_hash``.

        Side effect: when output 1 decodes as purchase data it is attached to
        output 0 as ``purchase``, which :meth:`txo_to_row` uses to classify
        output 0 as a purchase; so this must run before the outputs are converted.
        """
        row = {
            'tx_hash': tx.hash,
            'block_hash': block_hash,
            'raw': tx.raw,
            'height': tx.height,
            'position': tx.position,
            'is_verified': tx.is_verified,
            'timestamp': tx.timestamp,
            'day': tx.day,
            'purchased_claim_hash': None,
        }
        txos = tx.outputs
        if len(txos) >= 2 and txos[1].can_decode_purchase_data:
            txos[0].purchase = txos[1]
            row['purchased_claim_hash'] = txos[1].purchase_data.claim_hash
        return row

    @staticmethod
    def txi_to_row(tx: Transaction, txi: Input) -> dict:
        """Return the ``TXI`` table row linking ``tx`` to the output ``txi`` spends."""
        return {
            'tx_hash': tx.hash,
            'txo_hash': txi.txo_ref.hash,
            'position': txi.position,
        }

    def txo_to_row(self, tx: Transaction, txo: Output) -> dict:
        """Return the ``TXO`` table row for output ``txo`` of ``tx``.

        Classifies the output (claim type, support, purchase or plain) and, for
        claim-related outputs, records claim id/hash/name, signing channel and,
        for channels, the public key and its hash160.
        """
        row = {
            'tx_hash': tx.hash,
            'txo_hash': txo.hash,
            'address': txo.get_address(self.ledger) if txo.has_address else None,
            'position': txo.position,
            'amount': txo.amount,
            'height': tx.height,
            'script_offset': txo.script.offset,
            'script_length': txo.script.length,
            'txo_type': 0,
            'claim_id': None,
            'claim_hash': None,
            'claim_name': None,
            'channel_hash': None,
            'public_key': None,
            'public_key_hash': None
        }
        if txo.is_claim:
            if txo.can_decode_claim:
                claim = txo.claim
                row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream'])
                if claim.is_signed:
                    row['channel_hash'] = claim.signing_channel_hash
                if claim.is_channel:
                    row['public_key'] = claim.channel.public_key_bytes
                    row['public_key_hash'] = self.ledger.address_to_hash160(
                        self.ledger.public_key_to_address(claim.channel.public_key_bytes)
                    )
            else:
                # undecodable claims are stored as streams
                row['txo_type'] = TXO_TYPES['stream']
        elif txo.is_support:
            row['txo_type'] = TXO_TYPES['support']
            if txo.can_decode_support:
                support = txo.support
                if support.is_signed:
                    row['channel_hash'] = support.signing_channel_hash
        elif txo.purchase is not None:
            # ``purchase`` is attached by tx_to_row
            row['txo_type'] = TXO_TYPES['purchase']
            row['claim_id'] = txo.purchased_claim_id
            row['claim_hash'] = txo.purchased_claim_hash
        if txo.script.is_claim_involved:
            row['claim_id'] = txo.claim_id
            row['claim_hash'] = txo.claim_hash
            try:
                claim_name = txo.claim_name
            except UnicodeDecodeError:
                # name is not valid unicode; leave claim_name as NULL
                claim_name = None
            if claim_name is not None and '\x00' not in claim_name:
                # names containing a NULL (\x00) character are also left as NULL
                row['claim_name'] = claim_name
        return row

    def claim_to_rows(self, txo: Output) -> Tuple[dict, List]:
        """Return ``(claim_row, tag_rows)`` for claim output ``txo``.

        Returns ``({}, [])`` when the claim name cannot be decoded or is empty,
        and ``(claim_row, [])`` with only the generic fields filled when the
        claim payload itself cannot be parsed.
        """
        try:
            # explicit check instead of ``assert`` so validation survives ``python -O``
            if not (txo.claim_name and txo.normalized_name):
                return {}, []
        except Exception:
            # claim name could not be decoded; nothing can be stored for this output
            return {}, []

        tx = txo.tx_ref.tx
        claim_hash = txo.claim_hash
        claim_record = {
            'claim_hash': claim_hash,
            'claim_id': txo.claim_id,
            'claim_name': txo.claim_name,
            'normalized': txo.normalized_name,
            'address': txo.get_address(self.ledger),
            'txo_hash': txo.ref.hash,
            'amount': txo.amount,
            'timestamp': tx.timestamp,
            'release_time': None,
            'height': tx.height,
            'title': None,
            'author': None,
            'description': None,
            'claim_type': None,
            # streams
            'stream_type': None,
            'media_type': None,
            'fee_amount': 0,
            'fee_currency': None,
            'duration': None,
            # reposts
            'reposted_claim_hash': None,
            # signed claims
            'channel_hash': None,
            'signature': None,
            'signature_digest': None,
            'is_signature_valid': None,
        }

        try:
            claim = txo.claim
        except Exception:
            # payload could not be parsed; keep only the generic fields
            return claim_record, []

        if claim.is_stream:
            claim_record['claim_type'] = TXO_TYPES['stream']
            # media_type must be read first: stream_type is guessed from it.
            claim_record['media_type'] = claim.stream.source.media_type
            claim_record['stream_type'] = STREAM_TYPES[guess_stream_type(claim_record['media_type'])]
            claim_record['title'] = claim.stream.title
            claim_record['description'] = claim.stream.description
            claim_record['author'] = claim.stream.author
            if claim.stream.video and claim.stream.video.duration:
                claim_record['duration'] = claim.stream.video.duration
            if claim.stream.audio and claim.stream.audio.duration:
                claim_record['duration'] = claim.stream.audio.duration
            if claim.stream.release_time:
                claim_record['release_time'] = claim.stream.release_time
            if claim.stream.has_fee:
                fee = claim.stream.fee
                if isinstance(fee.currency, str):
                    claim_record['fee_currency'] = fee.currency.lower()
                if isinstance(fee.amount, Decimal):
                    # NOTE(review): amount is scaled by 1000 and truncated; confirm intended precision
                    claim_record['fee_amount'] = int(fee.amount*1000)
        elif claim.is_repost:
            claim_record['claim_type'] = TXO_TYPES['repost']
            claim_record['reposted_claim_hash'] = claim.repost.reference.claim_hash
        elif claim.is_channel:
            claim_record['claim_type'] = TXO_TYPES['channel']

        if claim.is_signed:
            claim_record['channel_hash'] = claim.signing_channel_hash
            claim_record['signature'] = txo.get_encoded_signature()
            claim_record['signature_digest'] = txo.get_signature_digest(self.ledger)

        tags = [
            {'claim_hash': claim_hash, 'tag': tag} for tag in clean_tags(claim.message.tags)
        ]
        return claim_record, tags

    def add_block(self, block: Block, add_claims_supports: Optional[set] = None):
        """Queue ``block`` and all of its transactions; returns ``self``.

        :param add_claims_supports: output hashes whose claims/supports should
            also be queued (see :meth:`add_transaction`).
        """
        self.blocks.append(self.block_to_row(block))
        for tx in block.txs:
            self.add_transaction(block.block_hash, tx, add_claims_supports)
        return self

    def add_transaction(self, block_hash: bytes, tx: Transaction,
                        add_claims_supports: Optional[set] = None):
        """Queue ``tx`` with its inputs and outputs; returns ``self``.

        Coinbase inputs are skipped. Claim and support outputs whose hash is in
        ``add_claims_supports`` are additionally queued as claim/support rows.
        """
        # tx_to_row must run first: it attaches purchase data used by txo_to_row.
        self.txs.append(self.tx_to_row(block_hash, tx))
        for txi in tx.inputs:
            if txi.coinbase is None:
                self.txis.append(self.txi_to_row(tx, txi))
        for txo in tx.outputs:
            self.txos.append(self.txo_to_row(tx, txo))
            if add_claims_supports:
                if txo.is_support and txo.hash in add_claims_supports:
                    self.add_support(txo)
                elif txo.is_claim and txo.hash in add_claims_supports:
                    self.add_claim(txo)
        return self

    def add_support(self, txo: Output):
        """Queue a ``Support`` row for support output ``txo``; returns ``self``."""
        tx = txo.tx_ref.tx
        claim_hash = txo.claim_hash
        support_record = {
            'txo_hash': txo.ref.hash,
            'claim_hash': claim_hash,
            'address': txo.get_address(self.ledger),
            'amount': txo.amount,
            'height': tx.height,
            'emoji': None,
            'channel_hash': None,
            'signature': None,
            'signature_digest': None,
        }
        self.supports.append(support_record)
        # can_decode_support yields the decoded support (presumably falsy when
        # the payload cannot be decoded); the queued dict is filled in place.
        support = txo.can_decode_support
        if support:
            support_record['emoji'] = support.emoji
            if support.is_signed:
                support_record['channel_hash'] = support.signing_channel_hash
                support_record['signature'] = txo.get_encoded_signature()
                # pass the ledger, as claim_to_rows does for claim signatures
                support_record['signature_digest'] = txo.get_signature_digest(self.ledger)
        return self

    def add_claim(self, txo: Output):
        """Queue a new ``Claim`` row plus its tags for ``txo``; returns ``self``.

        Creation height/timestamp are set only for outputs whose script is a
        claim-name script; otherwise they are stored as NULL.
        """
        claim, tags = self.claim_to_rows(txo)
        if claim:
            tx = txo.tx_ref.tx
            if txo.script.is_claim_name:
                claim['creation_height'] = tx.height
                claim['creation_timestamp'] = tx.timestamp
            else:
                claim['creation_height'] = None
                claim['creation_timestamp'] = None
            self.claims.append(claim)
            self.tags.extend(tags)
        return self

    def update_claim(self, txo: Output):
        """Queue an UPDATE of the existing claim for ``txo`` and replace its tags.

        ``claim_hash`` is renamed to ``claim_hash_`` so it can be bound by the
        WHERE clause's ``bindparam('claim_hash_')`` in :meth:`save` without
        clashing with the ``claim_hash`` column name. Creation height/timestamp
        are not part of the row, so existing values are kept. Returns ``self``.
        """
        claim, tags = self.claim_to_rows(txo)
        if claim:
            claim['claim_hash_'] = claim.pop('claim_hash')
            self.update_claims.append(claim)
            self.delete_tags.append({'claim_hash_': claim['claim_hash_']})
            self.tags.extend(tags)
        return self

    def save(self, batch_size=10000):
        """Write every queued row to the database in chunks of ``batch_size``.

        Statements run in the order listed below; tags of updated claims are
        deleted before the new tags are inserted. When a ProgressContext is
        active, its total is the number of queued rows for the event unit's
        table, and each executed chunk advances it in proportion to the total
        row count.
        """
        queries = (
            (Block.insert(), self.blocks),
            (TX.insert(), self.txs),
            (TXO.insert(), self.txos),
            (TXI.insert(), self.txis),
            (Claim.insert(), self.claims),
            (Tag.delete().where(Tag.c.claim_hash == bindparam('claim_hash_')), self.delete_tags),
            (Claim.update().where(Claim.c.claim_hash == bindparam('claim_hash_')), self.update_claims),
            (Tag.insert(), self.tags),
            (Support.insert(), self.supports),
        )

        p = self.ctx.current_progress
        done = row_scale = 0
        if p:
            unit_table = p.event.unit.table
            progress_total, row_total = 0, sum(len(q[1]) for q in queries)
            for sql, rows in queries:
                if sql.table == unit_table:
                    progress_total += len(rows)
            if not progress_total:
                assert row_total == 0, "Rows used for progress are empty but other rows present."
                return
            # number of executed rows that corresponds to one progress unit
            row_scale = row_total / progress_total
            p.start(progress_total)

        execute = self.ctx.connection.execute
        for sql, rows in queries:
            for chunk_rows in chunk(rows, batch_size):
                execute(sql, chunk_rows)
                if p:
                    done += int(len(chunk_rows)/row_scale)
                    p.step(done)