lbry-sdk/lbry/db/query_context.py

676 lines
23 KiB
Python
Raw Normal View History

2020-06-05 06:35:22 +02:00
import os
import time
2020-07-06 05:03:45 +02:00
import functools
2020-06-30 23:32:51 +02:00
from io import BytesIO
2020-06-05 06:35:22 +02:00
import multiprocessing as mp
from decimal import Decimal
2020-06-19 20:28:34 +02:00
from typing import Dict, List, Optional, Tuple
2020-06-30 23:32:51 +02:00
from dataclasses import dataclass, field
2020-06-05 06:35:22 +02:00
from contextvars import ContextVar
2020-07-12 00:18:33 +02:00
from sqlalchemy import create_engine, inspect, bindparam, func, exists, case, event as sqlalchemy_event
2020-07-06 05:03:45 +02:00
from sqlalchemy.future import select
2020-06-05 06:35:22 +02:00
from sqlalchemy.engine import Engine, Connection
2020-06-30 23:32:51 +02:00
from sqlalchemy.sql import Insert
try:
from pgcopy import CopyManager
except ImportError:
CopyManager = None
2020-06-05 06:35:22 +02:00
from lbry.event import EventQueuePublisher
from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction, Output, Input
from lbry.schema.tags import clean_tags
from lbry.schema.result import Censor
from lbry.schema.mime_types import guess_stream_type
from .utils import pg_insert, chunk
2020-06-26 16:39:16 +02:00
from .tables import Block, TX, TXO, TXI, Claim, Tag, Support
2020-06-05 06:35:22 +02:00
from .constants import TXO_TYPES, STREAM_TYPES
# Current QueryContext for this execution context; set by initialize() and
# reset to None by uninitialize().
_context: ContextVar['QueryContext'] = ContextVar('_context')
@dataclass
class QueryContext:
    """Database/session state installed by ``initialize()``.

    Holds the SQLAlchemy engine and connection, the ledger, the queue used to
    publish progress messages, the block/filter maps used to build censors and
    the bookkeeping for the ``with ctx:`` timer protocol.
    """
    engine: Engine
    connection: Connection
    ledger: Ledger
    message_queue: mp.Queue
    stop_event: mp.Event
    stack: List[List]
    metrics: Dict
    is_tracking_metrics: bool
    # maps handed to Censor by get_resolve_censor() / get_search_censor()
    blocked_streams: Dict
    blocked_channels: Dict
    filtered_streams: Dict
    filtered_channels: Dict
    pid: int

    # QueryContext __enter__/__exit__ state
    current_timer_name: Optional[str] = None
    current_timer_time: float = 0
    current_progress: Optional['ProgressContext'] = None

    # pgcopy CopyManager instances cached per table name (see pg_copy())
    copy_managers: Dict[str, CopyManager] = field(default_factory=dict)

    @property
    def is_postgres(self):
        """True when the connection's dialect is PostgreSQL."""
        return self.connection.dialect.name == 'postgresql'

    @property
    def is_sqlite(self):
        """True when the connection's dialect is SQLite."""
        return self.connection.dialect.name == 'sqlite'

    def raise_unsupported_dialect(self):
        """Raise ``RuntimeError`` naming the connection's unsupported dialect."""
        raise RuntimeError(f'Unsupported database dialect: {self.connection.dialect.name}.')

    def get_resolve_censor(self) -> Censor:
        """Build a ``Censor`` from the blocked streams/channels."""
        return Censor(self.blocked_streams, self.blocked_channels)

    def get_search_censor(self) -> Censor:
        """Build a ``Censor`` from the filtered streams/channels."""
        return Censor(self.filtered_streams, self.filtered_channels)

    def pg_copy(self, table, rows):
        """Bulk-load ``rows`` into ``table`` with pgcopy's ``CopyManager``, then commit.

        :param table: SQLAlchemy table to load into.
        :param rows: list of dicts. The CopyManager is cached per table name using
            the keys of the first row ever copied into that table, so every row
            must have the same keys in the same order.
        :raises RuntimeError: if the optional ``pgcopy`` package is not installed.
        """
        if CopyManager is None:
            # the module-level import fell back to None; fail with a clear message
            raise RuntimeError(
                "PostgreSQL bulk copy requires the 'pgcopy' package, which is not installed."
            )
        connection = self.connection.connection
        copy_manager = self.copy_managers.get(table.name)
        if copy_manager is None:
            self.copy_managers[table.name] = copy_manager = CopyManager(
                connection, table.name, rows[0].keys()
            )
        copy_manager.copy(map(dict.values, rows), BytesIO)
        connection.commit()

    def execute(self, sql, *args):
        """Execute ``sql`` on the connection and return the result object."""
        return self.connection.execute(sql, *args)

    def fetchone(self, sql, *args):
        """Execute ``sql`` and return the first row as a dict (or the falsy row if none)."""
        row = self.connection.execute(sql, *args).fetchone()
        return dict(row._mapping) if row else row

    def fetchall(self, sql, *args):
        """Execute ``sql`` and return every row as a dict."""
        rows = self.connection.execute(sql, *args).fetchall()
        return [dict(row._mapping) for row in rows]

    def fetchtotal(self, condition) -> int:
        """Return ``COUNT(*)`` of rows matching ``condition``."""
        sql = select(func.count('*').label('total')).where(condition)
        return self.fetchone(sql)['total']

    def fetchmax(self, column, default: int) -> int:
        """Return ``MAX(column)``, or ``default`` when the max is NULL."""
        sql = select(func.coalesce(func.max(column), default).label('max_result'))
        return self.fetchone(sql)['max_result']

    def has_records(self, table) -> bool:
        """Return True if ``table`` contains at least one row."""
        sql = select(exists([1], from_obj=table).label('result'))
        return bool(self.fetchone(sql)['result'])

    def insert_or_ignore(self, table):
        """Build a dialect-specific INSERT that skips conflicting rows."""
        if self.is_sqlite:
            return table.insert().prefix_with("OR IGNORE")
        elif self.is_postgres:
            return pg_insert(table).on_conflict_do_nothing()
        else:
            self.raise_unsupported_dialect()

    def insert_or_replace(self, table, replace):
        """Build a dialect-specific upsert.

        On PostgreSQL, conflicts on the primary key update the columns named in
        ``replace``; on SQLite, ``INSERT OR REPLACE`` is used.
        """
        if self.is_sqlite:
            return table.insert().prefix_with("OR REPLACE")
        elif self.is_postgres:
            insert = pg_insert(table)
            return insert.on_conflict_do_update(
                table.primary_key, set_={col: getattr(insert.excluded, col) for col in replace}
            )
        else:
            self.raise_unsupported_dialect()

    def has_table(self, table):
        """Return True if the engine's database has ``table`` (presumably a table name)."""
        return inspect(self.engine).has_table(table)

    def get_bulk_loader(self) -> 'BulkLoader':
        """Return a new ``BulkLoader`` bound to this context."""
        return BulkLoader(self)

    def reset_metrics(self):
        """Discard the metrics stack and collected metrics."""
        self.stack = []
        self.metrics = {}

    def with_timer(self, timer_name: str) -> 'QueryContext':
        """Set the current timer name and return self (for use in ``with``)."""
        self.current_timer_name = timer_name
        return self

    @property
    def elapsed(self):
        """Seconds since ``__enter__`` recorded ``current_timer_time``."""
        return time.perf_counter() - self.current_timer_time

    def __enter__(self) -> 'QueryContext':
        self.current_timer_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # reset timer/progress state; exceptions are not suppressed
        self.current_timer_name = None
        self.current_timer_time = 0
        self.current_progress = None
def context(with_timer: str = None) -> 'QueryContext':
    """Return the ``QueryContext`` stored in the ``_context`` variable.

    If ``with_timer`` is a string, it is first applied as the context's current
    timer name via ``QueryContext.with_timer``. ``ContextVar.get()`` raises
    ``LookupError`` if the variable was never set in this context.
    """
    current = _context.get()
    if not isinstance(with_timer, str):
        return current
    return current.with_timer(with_timer)
2020-07-07 04:42:15 +02:00
def set_postgres_settings(connection, _):
    """Engine "connect" listener: set ``work_mem`` on each new PostgreSQL connection.

    :param connection: the newly created DBAPI connection.
    :param _: the connection record passed by the event (unused).
    """
    cursor = connection.cursor()
    try:
        cursor.execute('SET work_mem="100MB";')
    finally:
        # close the cursor even if the SET statement fails
        cursor.close()
def set_sqlite_settings(connection, _):
    """Engine "connect" listener: enable WAL journal mode on each new SQLite connection.

    :param connection: the newly created DBAPI connection.
    :param _: the connection record passed by the event (unused).
    """
    cursor = connection.cursor()
    try:
        cursor.execute('PRAGMA journal_mode=WAL;')
    finally:
        # close the cursor even if the PRAGMA fails
        cursor.close()
2020-06-05 06:35:22 +02:00
def initialize(
        ledger: Ledger, message_queue: mp.Queue, stop_event: mp.Event,
        track_metrics=False, block_and_filter=None):
    """Open a database connection for ``ledger`` and install a new ``QueryContext``.

    The engine is created from ``ledger.conf.db_url_or_default``. A per-dialect
    "connect" listener is registered for PostgreSQL and SQLite.

    :param block_and_filter: optional 4-tuple of
        ``(blocked_streams, blocked_channels, filtered_streams, filtered_channels)``;
        when omitted, four independent empty dicts are used.
    """
    url = ledger.conf.db_url_or_default
    engine = create_engine(url)
    if engine.name == "postgresql":
        sqlalchemy_event.listen(engine, "connect", set_postgres_settings)
    elif engine.name == "sqlite":
        sqlalchemy_event.listen(engine, "connect", set_sqlite_settings)
    connection = engine.connect()
    if block_and_filter is not None:
        blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter
    else:
        # four distinct dicts: a chained `a = b = c = d = {}` would make all
        # four maps alias one shared dict
        blocked_streams, blocked_channels, filtered_streams, filtered_channels = {}, {}, {}, {}
    _context.set(
        QueryContext(
            pid=os.getpid(),
            engine=engine, connection=connection,
            ledger=ledger, message_queue=message_queue, stop_event=stop_event,
            stack=[], metrics={}, is_tracking_metrics=track_metrics,
            blocked_streams=blocked_streams, blocked_channels=blocked_channels,
            filtered_streams=filtered_streams, filtered_channels=filtered_channels,
        )
    )
def uninitialize():
    """Close the current context's connection, dispose its engine and clear ``_context``.

    Does nothing when no context is currently set.
    """
    current = _context.get(None)
    if current is None:
        return
    if current.connection:
        current.connection.close()
    if current.engine:
        current.engine.dispose()
    _context.set(None)
2020-07-06 05:03:45 +02:00
class Event:
    """A named progress event type and the units its progress is counted in.

    Events are registered through ``Event.add()``. An event's ``id`` is its
    position in the class-level registry, and ``get_by_id`` maps it back.
    """
    _events: List['Event'] = []

    __slots__ = 'id', 'name', 'units'

    def __init__(self, name: str, units: Tuple[str]):
        self.name = name
        self.units = units

    @classmethod
    def get_by_id(cls, event_id) -> 'Event':
        """Return the registered event with ``event_id`` (IndexError if out of range)."""
        return cls._events[event_id]

    @classmethod
    def get_by_name(cls, name) -> 'Event':
        """Return the registered event called ``name``, or None if there is none."""
        for event in cls._events:
            if event.name == name:
                return event
        return None

    @classmethod
    def add(cls, name: str, *units: str) -> 'Event':
        """Register and return a new event.

        ``name`` must be unique and contain exactly three dots
        (``[module].sync.[phase].[task]``); both are checked with ``assert``.
        """
        assert cls.get_by_name(name) is None, f"Event {name} already exists."
        assert name.count('.') == 3, f"Event {name} does not follow pattern of: [module].sync.[phase].[task]"
        event = cls(name, units)
        # the id is the event's registry position; taking the length before the
        # append avoids an O(n) list.index() scan
        event.id = len(cls._events)
        cls._events.append(event)
        return event
2020-07-12 00:18:33 +02:00
def event_emitter(name: str, *units: str, throttle=1):
    """Decorator factory: register an ``Event`` and run the function under ``progress()``.

    The event is registered once, when ``event_emitter`` is called. Each call of
    the decorated function runs inside ``progress(event, throttle=throttle)`` and
    receives that progress object as the keyword argument ``p``.
    """
    registered = Event.add(name, *units)

    def decorate(func):
        @functools.wraps(func)
        def with_progress(*args, **kwargs):
            with progress(registered, throttle=throttle) as tracker:
                return func(*args, **kwargs, p=tracker)
        return with_progress

    return decorate
2020-06-05 06:35:22 +02:00
class ProgressPublisher(EventQueuePublisher):
    """Turns progress tuples read from the message queue into event dicts."""

    def message_to_event(self, message):
        """Convert a progress tuple to ``{"event": <name>, "data": {...}}``.

        ``message`` is ``(event_id, progress_id, done)`` or
        ``(event_id, progress_id, done, total, extra)``; other lengths raise
        ``TypeError``. A non-None ``total`` adds ``total`` and the event's
        ``units`` to the data, and a dict ``extra`` is merged in last.
        """
        size = len(message)
        if size == 3:
            event_id, progress_id, done = message
            total = extra = None
        elif size == 5:
            event_id, progress_id, done, total, extra = message
        else:
            raise TypeError("progress message must be tuple of 3 or 5 values.")
        event = Event.get_by_id(event_id)
        data = {"id": progress_id, "done": done}
        if total is not None:
            data['total'] = total
            data['units'] = event.units
        if isinstance(extra, dict):
            data.update(extra)
        return {"event": event.name, "data": data}
2020-06-19 20:28:34 +02:00
class BreakProgress(Exception):
    """Raised by ``Progress.start()`` when every total is zero, to skip the work.

    ``Progress.__exit__`` suppresses it, so the enclosing ``with`` block ends quietly.
    """
2020-07-12 00:18:33 +02:00
class Progress:
    """Publishes the progress of one unit of work for an ``Event`` onto a queue.

    Messages are ``(event_id, progress_id, done)`` for ordinary steps, or
    ``(event_id, progress_id, done, totals, extra)`` when ``extra`` is given
    (``start()`` always gives it). ``done`` and ``totals`` hold one value per
    event unit. On exit, a ``done`` of all ``-1`` is sent if the work stopped
    before reaching its totals.
    """

    def __init__(self, message_queue: mp.Queue, event: 'Event', throttle=1):
        self.message_queue = message_queue
        self.event = event
        self.progress_id = 0
        # step() publishes only every `throttle`-th value of the first unit
        self.throttle = throttle
        self.last_done = (0,)*len(event.units)
        self.last_done_queued = (0,)*len(event.units)
        self.totals = (0,)*len(event.units)

    def __enter__(self) -> 'Progress':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # flush the latest value if throttling/deduplication held it back
        if self.last_done != self.last_done_queued:
            self.message_queue.put((self.event.id, self.progress_id, self.last_done))
            self.last_done_queued = self.last_done
        # start() raises BreakProgress when there is nothing to do; swallow it
        # (subclasses included)
        if exc_type is not None and issubclass(exc_type, BreakProgress):
            return True
        if self.last_done != self.totals:  # or exc_type is not None:
            # TODO: add exception info into closing message if there is any
            self.message_queue.put((
                self.event.id, self.progress_id, (-1,)*len(self.event.units)
            ))

    def start(self, *totals: int, progress_id=0, label=None, extra=None):
        """Begin tracking with one total per event unit and publish the initial step.

        :raises BreakProgress: if every total is zero.
        """
        assert len(totals) == len(self.event.units), \
            f"Totals {totals} do not match up with units {self.event.units}."
        if not any(totals):
            raise BreakProgress
        self.totals = totals
        self.progress_id = progress_id
        extra = {} if extra is None else extra.copy()
        if label is not None:
            extra['label'] = label
        self.step(*((0,)*len(totals)), force=True, extra=extra)

    def step(self, *done: int, force=False, extra=None):
        """Record absolute progress ``done`` (or +1 when called with no values on a
        single-unit event) and publish it unless throttled or redundant.
        """
        if done == ():
            assert len(self.totals) == 1, "Incrementing step() only works with one unit progress."
            done = (self.last_done[0]+1,)
        assert len(done) == len(self.totals), \
            f"Done elements {done} don't match total elements {self.totals}."
        self.last_done = done
        send_condition = force or extra is not None or (
            # throttle rate of events being generated (only throttles first unit value)
            (self.throttle == 1 or done[0] % self.throttle == 0) and
            # deduplicate finish event by not sending a step where done == total
            any(i < j for i, j in zip(done, self.totals)) and
            # deduplicate same event
            done != self.last_done_queued
        )
        if send_condition:
            if extra is not None:
                self.message_queue.put_nowait(
                    (self.event.id, self.progress_id, done, self.totals, extra)
                )
            else:
                self.message_queue.put_nowait(
                    (self.event.id, self.progress_id, done)
                )
            self.last_done_queued = done

    def add(self, *done: int, force=False, extra=None):
        """Advance progress by the increments ``done`` (one per unit)."""
        assert len(done) == len(self.last_done), \
            f"Done elements {done} don't match total elements {self.last_done}."
        self.step(
            *(i+j for i, j in zip(self.last_done, done)),
            force=force, extra=extra
        )

    def iter(self, items: List):
        """Yield ``items``, calling ``start(len(items))`` first and ``step()`` after each."""
        self.start(len(items))
        for item in items:
            yield item
            self.step()
class ProgressContext(Progress):
    """A ``Progress`` that also enters and exits its owning ``QueryContext``."""

    def __init__(self, ctx: QueryContext, event: Event, throttle=1):
        super().__init__(ctx.message_queue, event, throttle)
        self.ctx = ctx

    def __enter__(self) -> 'ProgressContext':
        self.ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # both exits always run (query context first); suppress if either asks to
        ctx_suppress = self.ctx.__exit__(exc_type, exc_val, exc_tb)
        progress_suppress = super().__exit__(exc_type, exc_val, exc_tb)
        return bool(ctx_suppress or progress_suppress)
2020-06-05 06:35:22 +02:00
2020-07-12 00:18:33 +02:00
def progress(e: Event, throttle=1) -> ProgressContext:
    """Create a ``ProgressContext`` for event ``e`` on the current query context.

    The context's timer name is set to ``e.name``, and the new progress object
    is stored as ``ctx.current_progress`` before being returned.
    """
    ctx = context(e.name)
    tracker = ProgressContext(ctx, e, throttle=throttle)
    ctx.current_progress = tracker
    return tracker
class BulkLoader:
    """Collects row dicts for blocks, transactions, claims, supports and tags,
    and writes them to the database in batches with ``flush()``.
    """

    def __init__(self, ctx: QueryContext):
        self.ctx = ctx
        self.ledger = ctx.ledger
        self.blocks = []
        self.txs = []
        self.txos = []
        self.txis = []
        self.supports = []
        self.claims = []
        self.tags = []
        self.update_claims = []
        self.delete_tags = []

    @staticmethod
    def block_to_row(block: Block) -> dict:
        """Build a block row; the first block is always stored at height 0."""
        return {
            'block_hash': block.block_hash,
            'previous_hash': block.prev_block_hash,
            'file_number': block.file_number,
            'height': 0 if block.is_first_block else block.height,
            'timestamp': block.timestamp,
        }

    @staticmethod
    def tx_to_row(block_hash: bytes, tx: Transaction) -> dict:
        """Build a transaction row.

        Side effect: when output 1 decodes as purchase data, it is attached as
        ``tx.outputs[0].purchase``, which ``txo_to_row`` later reads.
        """
        row = {
            'tx_hash': tx.hash,
            'block_hash': block_hash,
            'raw': tx.raw,
            'height': tx.height,
            'position': tx.position,
            'is_verified': tx.is_verified,
            'timestamp': tx.timestamp,
            'day': tx.day,
            'purchased_claim_hash': None,
        }
        txos = tx.outputs
        if len(txos) >= 2 and txos[1].can_decode_purchase_data:
            txos[0].purchase = txos[1]
            row['purchased_claim_hash'] = txos[1].purchase_data.claim_hash
        return row

    @staticmethod
    def txi_to_row(tx: Transaction, txi: Input) -> dict:
        """Build an input row referencing the spent output's hash."""
        return {
            'tx_hash': tx.hash,
            'txo_hash': txi.txo_ref.hash,
            'position': txi.position,
            'height': tx.height,
        }

    def txo_to_row(self, tx: Transaction, txo: Output) -> dict:
        """Build an output row, classifying it by ``txo_type``. Claim, support and
        purchase details and signing data are filled in when applicable.
        """
        row = {
            'tx_hash': tx.hash,
            'txo_hash': txo.hash,
            'address': txo.get_address(self.ledger) if txo.has_address else None,
            'position': txo.position,
            'amount': txo.amount,
            'height': tx.height,
            'script_offset': txo.script.offset,
            'script_length': txo.script.length,
            'txo_type': 0,
            'claim_id': None,
            'claim_hash': None,
            'claim_name': None,
            'channel_hash': None,
            'signature': None,
            'signature_digest': None,
            'public_key': None,
            'public_key_hash': None
        }
        if txo.is_claim:
            if txo.can_decode_claim:
                claim = txo.claim
                row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream'])
                if claim.is_channel:
                    row['public_key'] = claim.channel.public_key_bytes
                    row['public_key_hash'] = self.ledger.address_to_hash160(
                        self.ledger.public_key_to_address(claim.channel.public_key_bytes)
                    )
            else:
                # undecodable claims are recorded as streams
                row['txo_type'] = TXO_TYPES['stream']
        elif txo.is_support:
            row['txo_type'] = TXO_TYPES['support']
        elif txo.purchase is not None:
            row['txo_type'] = TXO_TYPES['purchase']
            row['claim_id'] = txo.purchased_claim_id
            row['claim_hash'] = txo.purchased_claim_hash
        if txo.script.is_claim_involved:
            signable = txo.can_decode_signable
            if signable and signable.is_signed:
                row['channel_hash'] = signable.signing_channel_hash
                row['signature'] = txo.get_encoded_signature()
                row['signature_digest'] = txo.get_signature_digest(self.ledger)
            row['claim_id'] = txo.claim_id
            row['claim_hash'] = txo.claim_hash
            try:
                # NUL characters are stripped from claim names before storing
                row['claim_name'] = txo.claim_name.replace('\x00', '')
            except UnicodeDecodeError:
                pass
        return row

    def claim_to_rows(
            self, txo: Output, staked_support_amount: int, staked_support_count: int,
            signature: bytes = None, signature_digest: bytes = None, channel_public_key: bytes = None,
    ) -> Tuple[dict, List]:
        """Build the claim-metadata row and tag rows for claim output ``txo``.

        Returns ``(row, tags)``; when the claim cannot be decoded, the row holds
        only defaults and ``tags`` is empty.
        """
        tx = txo.tx_ref

        d = {
            'claim_type': None,
            'address': txo.get_address(self.ledger),
            'txo_hash': txo.hash,
            'amount': txo.amount,
            'height': tx.height,
            'timestamp': tx.timestamp,
            # support
            'staked_amount': txo.amount + staked_support_amount,
            'staked_support_amount': staked_support_amount,
            'staked_support_count': staked_support_count,
            # basic metadata
            'title': None,
            'description': None,
            'author': None,
            # streams
            'stream_type': None,
            'media_type': None,
            'duration': None,
            'release_time': None,
            'fee_amount': 0,
            'fee_currency': None,
            # reposts
            'reposted_claim_hash': None,
            # signed claims
            'channel_hash': None,
            'is_signature_valid': None,
        }

        claim = txo.can_decode_claim
        if not claim:
            return d, []

        if claim.is_stream:
            d['claim_type'] = TXO_TYPES['stream']
            d['media_type'] = claim.stream.source.media_type
            # stream_type is derived from media_type, so compute it only after
            # media_type is set (previously it was always derived from None)
            d['stream_type'] = STREAM_TYPES[guess_stream_type(d['media_type'])]
            d['title'] = claim.stream.title.replace('\x00', '')
            d['description'] = claim.stream.description.replace('\x00', '')
            d['author'] = claim.stream.author.replace('\x00', '')
            if claim.stream.video and claim.stream.video.duration:
                d['duration'] = claim.stream.video.duration
            if claim.stream.audio and claim.stream.audio.duration:
                d['duration'] = claim.stream.audio.duration
            if claim.stream.release_time:
                d['release_time'] = claim.stream.release_time
            if claim.stream.has_fee:
                fee = claim.stream.fee
                if isinstance(fee.amount, Decimal):
                    d['fee_amount'] = int(fee.amount*1000)
                if isinstance(fee.currency, str):
                    d['fee_currency'] = fee.currency.lower()
        elif claim.is_repost:
            d['claim_type'] = TXO_TYPES['repost']
            d['reposted_claim_hash'] = claim.repost.reference.claim_hash
        elif claim.is_channel:
            d['claim_type'] = TXO_TYPES['channel']
        if claim.is_signed:
            d['channel_hash'] = claim.signing_channel_hash
            d['is_signature_valid'] = Output.is_signature_valid(
                signature, signature_digest, channel_public_key
            )

        tags = []
        if claim.message.tags:
            claim_hash = txo.claim_hash
            tags = [
                {'claim_hash': claim_hash, 'tag': tag}
                for tag in clean_tags(claim.message.tags)
            ]

        return d, tags

    def support_to_row(
            self, txo: Output, channel_public_key: bytes = None,
            signature: bytes = None, signature_digest: bytes = None
    ):
        """Build a support row, adding emoji and signing data when the support decodes."""
        tx = txo.tx_ref
        d = {
            'txo_hash': txo.ref.hash,
            'claim_hash': txo.claim_hash,
            'address': txo.get_address(self.ledger),
            'amount': txo.amount,
            'height': tx.height,
            'timestamp': tx.timestamp,
            'emoji': None,
            'channel_hash': None,
            'is_signature_valid': None,
        }
        support = txo.can_decode_support
        if support:
            d['emoji'] = support.emoji
            if support.is_signed:
                d['channel_hash'] = support.signing_channel_hash
                d['is_signature_valid'] = Output.is_signature_valid(
                    signature, signature_digest, channel_public_key
                )
        return d

    def add_block(self, block: Block):
        """Queue ``block`` and all of its transactions."""
        self.blocks.append(self.block_to_row(block))
        for tx in block.txs:
            self.add_transaction(block.block_hash, tx)
        return self

    def add_transaction(self, block_hash: bytes, tx: Transaction):
        """Queue ``tx`` with its non-coinbase inputs and all outputs."""
        # tx_to_row must run before txo_to_row: it sets outputs[0].purchase
        self.txs.append(self.tx_to_row(block_hash, tx))
        for txi in tx.inputs:
            if txi.coinbase is None:
                self.txis.append(self.txi_to_row(tx, txi))
        for txo in tx.outputs:
            self.txos.append(self.txo_to_row(tx, txo))
        return self

    def add_support(self, txo: Output, **extra):
        """Queue a support row; ``extra`` is forwarded to ``support_to_row``."""
        self.supports.append(self.support_to_row(txo, **extra))

    def add_claim(
            self, txo: Output, short_url: str,
            creation_height: int, activation_height: int, expiration_height: int,
            takeover_height: int = None, channel_url: str = None, **extra
    ):
        """Queue a new claim row and its tags.

        Claims whose name cannot be decoded are skipped. ``canonical_url`` is set
        only when the claim's signature is valid.
        """
        try:
            claim_name = txo.claim_name.replace('\x00', '')
            normalized_name = txo.normalized_name
        except UnicodeDecodeError:
            return self
        d, tags = self.claim_to_rows(txo, **extra)
        d['claim_hash'] = txo.claim_hash
        d['claim_id'] = txo.claim_id
        d['claim_name'] = claim_name
        d['normalized'] = normalized_name
        d['short_url'] = short_url
        d['creation_height'] = creation_height
        d['activation_height'] = activation_height
        d['expiration_height'] = expiration_height
        d['takeover_height'] = takeover_height
        d['is_controlling'] = takeover_height is not None
        if d['is_signature_valid']:
            d['canonical_url'] = channel_url + '/' + short_url
        else:
            d['canonical_url'] = None
        self.claims.append(d)
        self.tags.extend(tags)
        return self

    def update_claim(self, txo: Output, channel_url: str = None, **extra):
        """Queue an update of an existing claim; its tags are deleted and re-inserted."""
        d, tags = self.claim_to_rows(txo, **extra)
        d['pk'] = txo.claim_hash
        d['channel_url'] = channel_url
        d['set_canonical_url'] = d['is_signature_valid']
        self.update_claims.append(d)
        self.delete_tags.append({'pk': txo.claim_hash})
        self.tags.extend(tags)
        return self

    def get_queries(self):
        """Return ``(statement, rows)`` pairs in the order ``flush()`` executes them."""
        return (
            (Block.insert(), self.blocks),
            (TX.insert(), self.txs),
            (TXO.insert(), self.txos),
            (TXI.insert(), self.txis),
            (Claim.insert(), self.claims),
            (Tag.delete().where(Tag.c.claim_hash == bindparam('pk')), self.delete_tags),
            (Claim.update().where(Claim.c.claim_hash == bindparam('pk')).values(
                canonical_url=case([
                    (bindparam('set_canonical_url'), bindparam('channel_url') + '/' + Claim.c.short_url)
                ], else_=None)
            ), self.update_claims),
            (Tag.insert(), self.tags),
            (Support.insert(), self.supports),
        )

    def flush(self, return_row_count_for_table) -> int:
        """Write and clear every non-empty batch.

        INSERTs go through ``pg_copy`` on PostgreSQL; everything else is executed
        as an executemany. Returns the number of rows in batches whose statement
        targets ``return_row_count_for_table``.
        """
        execute = self.ctx.connection.execute
        done = 0
        for sql, rows in self.get_queries():
            if not rows:
                continue
            if self.ctx.is_postgres and isinstance(sql, Insert):
                self.ctx.pg_copy(sql.table, rows)
            else:
                execute(sql, rows)
            if sql.table == return_row_count_for_table:
                done += len(rows)
            rows.clear()
        return done