lbry-sdk/lbry/wallet/server/db/reader.py

583 lines
21 KiB
Python
Raw Permalink Normal View History

import time
import struct
2019-12-07 18:13:13 -05:00
import apsw
import logging
2019-09-24 12:53:23 -03:00
from operator import itemgetter
from typing import Tuple, List, Dict, Union, Type, Optional
from binascii import unhexlify
from decimal import Decimal
from contextvars import ContextVar
from functools import wraps
from dataclasses import dataclass
2020-01-02 22:18:49 -05:00
from lbry.wallet.database import query, interpolate
from lbry.schema.url import URL, normalize_name
from lbry.schema.tags import clean_tags
from lbry.schema.result import Outputs, Censor
2020-01-02 22:18:49 -05:00
from lbry.wallet import Ledger, RegTestLedger
from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS
2019-11-14 14:31:49 -05:00
from .full_text_search import FTS_ORDER_BY
2019-12-07 18:13:13 -05:00
class SQLiteOperationalError(apsw.Error):
    """An ``apsw.Error`` carrying the reader metrics collected so far.

    The ``metrics`` object given to the constructor is stored unchanged on the
    instance as ``self.metrics``.
    """

    def __init__(self, metrics):
        message = 'sqlite query errored'
        super().__init__(message)
        self.metrics = metrics
2019-12-07 18:13:13 -05:00
class SQLiteInterruptedError(apsw.InterruptError):
    """An ``apsw.InterruptError`` carrying the reader metrics collected so far.

    The ``metrics`` object given to the constructor is stored unchanged on the
    instance as ``self.metrics``.
    """

    def __init__(self, metrics):
        message = 'sqlite query interrupted'
        super().__init__(message)
        self.metrics = metrics
# Maximum number of values taken from each any_/all_/not_ array constraint.
ATTRIBUTE_ARRAY_MAX_LENGTH = 100

# Constraint names whose values are converted with int() when the claims
# query is built; these names are also accepted as ordering fields.
INTEGER_PARAMS = {
    # block heights
    'height',
    'creation_height',
    'activation_height',
    'expiration_height',
    # times and fee
    'timestamp',
    'creation_timestamp',
    'release_time',
    'fee_amount',
    # position / relations
    'tx_position',
    'channel_join',
    'reposted',
    # amounts
    'amount',
    'effective_amount',
    'support_amount',
    # trending
    'trending_group',
    'trending_mixed',
    'trending_local',
    'trending_global',
}
# Every argument name that search() accepts: the integer parameters plus the
# names listed here.
SEARCH_PARAMS = INTEGER_PARAMS | {
    'name',
    'text',
    'claim_id',
    'claim_ids',
    'txid',
    'nout',
    'channel',
    'channel_ids',
    'not_channel_ids',
    'public_key_id',
    'claim_type',
    'stream_types',
    'media_types',
    'fee_currency',
    'has_channel_signature',
    'signature_valid',
    'any_tags',
    'all_tags',
    'not_tags',
    'reposted_claim_id',
    'any_locations',
    'all_locations',
    'not_locations',
    'any_languages',
    'all_languages',
    'not_languages',
    'is_controlling',
    'limit',
    'offset',
    'order_by',
    'no_totals',
}

# Field names that may appear in an ``order_by`` constraint.
ORDER_FIELDS = INTEGER_PARAMS | {'name', 'claim_hash'}
@dataclass
class ReaderState:
    """State shared by the reader functions through the ``ctx`` context variable.

    Bundles the database connection, the metrics bookkeeping used by the
    ``measure`` decorator, the ledger class, the query timeout, a logger and
    the blocked/filtered stream and channel collections.
    """

    db: apsw.Connection
    stack: List[List]            # one list of nested call durations per active measured call
    metrics: Dict                # function name -> list of per-call metric dicts
    is_tracking_metrics: bool
    ledger: Type[Ledger]
    query_timeout: float         # added to time.perf_counter(), i.e. expressed in seconds
    log: logging.Logger
    blocked_streams: Dict
    blocked_channels: Dict
    filtered_streams: Dict
    filtered_channels: Dict

    def close(self):
        """Close the underlying database connection."""
        self.db.close()

    def reset_metrics(self):
        """Discard all collected metrics and the timing stack."""
        self.stack = []
        self.metrics = {}

    def set_query_timeout(self):
        """Install a progress handler that interrupts the connection after ``query_timeout``."""
        deadline = time.perf_counter() + self.query_timeout

        def interrupt_when_overdue():
            overdue = time.perf_counter() >= deadline
            if overdue:
                self.db.interrupt()
            return

        # 100 is passed through to apsw as the progress handler's step argument.
        self.db.setprogresshandler(interrupt_when_overdue, 100)

    def get_resolve_censor(self) -> Censor:
        """Return a new ``Censor`` built from the blocked streams and channels."""
        resolve_censor = Censor(self.blocked_streams, self.blocked_channels)
        return resolve_censor

    def get_search_censor(self) -> Censor:
        """Return a new ``Censor`` built from the filtered streams and channels."""
        search_censor = Censor(self.filtered_streams, self.filtered_channels)
        return search_censor
# Holds the ReaderState for the current context: initializer() stores a new
# ReaderState here and cleanup() resets it to None.
ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx')
def row_factory(cursor, row):
    """Convert a raw result row into a dict keyed by column name.

    Column names are taken from the first element of each entry returned by
    ``cursor.getdescription()``. The value of a column named ``tags`` is split
    on commas into a set; a NULL or empty ``tags`` value becomes an empty set
    instead of raising or producing ``{''}``.
    """
    record = {}
    for description, value in zip(cursor.getdescription(), row):
        column = description[0]
        if column == 'tags':
            value = set(value.split(',')) if value else set()
        record[column] = value
    return record
def initializer(log, _path, _ledger_name, query_timeout, _measure=False, block_and_filter=None):
    """Open the database read-only and store a fresh ``ReaderState`` in ``ctx``.

    Args:
        log: logger stored on the state.
        _path: database path (opened with the URI flag set).
        _ledger_name: ``'mainnet'`` selects ``Ledger``; anything else ``RegTestLedger``.
        query_timeout: timeout stored on the state for ``set_query_timeout``.
        _measure: whether metrics tracking is enabled.
        block_and_filter: optional 4-tuple of
            ``(blocked_streams, blocked_channels, filtered_streams, filtered_channels)``.
    """
    db = apsw.Connection(_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
    db.setrowtrace(row_factory)
    if block_and_filter:
        blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter
    else:
        # Give every collection its own dict; a chained ``a = b = {}`` would make
        # all four names alias one object, so an update to one would leak into the rest.
        blocked_streams, blocked_channels, filtered_streams, filtered_channels = {}, {}, {}, {}
    ctx.set(
        ReaderState(
            db=db, stack=[], metrics={}, is_tracking_metrics=_measure,
            ledger=Ledger if _ledger_name == 'mainnet' else RegTestLedger,
            query_timeout=query_timeout, log=log,
            blocked_streams=blocked_streams, blocked_channels=blocked_channels,
            filtered_streams=filtered_streams, filtered_channels=filtered_channels,
        )
    )
def cleanup():
    """Close the current reader state, if any, and reset ``ctx`` to ``None``.

    Safe to call when no state was ever set or when it was already cleaned up.
    ``ctx`` is reset even if closing the connection raises.
    """
    state = ctx.get(None)
    try:
        if state is not None:
            state.close()
    finally:
        ctx.set(None)
def measure(func):
    """Decorator that records per-call timings in the current reader state.

    When metrics tracking is off the wrapped function is called directly.
    Otherwise a metric dict is appended under the function's name with
    ``total`` (elapsed milliseconds) and ``isolated`` (``total`` minus the
    totals reported by measured calls nested inside this one).
    """
    @wraps(func)
    def timed(*args, **kwargs):
        reader = ctx.get()
        if reader.is_tracking_metrics:
            record = {}
            reader.metrics.setdefault(func.__name__, []).append(record)
            # Nested measured calls append their totals to this new frame.
            reader.stack.append([])
            began = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                total_ms = int((time.perf_counter() - began) * 1000)
                record['total'] = total_ms
                nested_ms = sum(reader.stack.pop())
                record['isolated'] = total_ms - nested_ms
                if reader.stack:
                    # Report this call's total to the enclosing measured call.
                    reader.stack[-1].append(total_ms)
        return func(*args, **kwargs)
    return timed
def reports_metrics(func):
    """Decorator that returns ``(result, metrics)`` when metrics tracking is on.

    With tracking off the wrapped function's result is returned unchanged.
    With tracking on, the state's metrics are reset before the call.
    """
    @wraps(func)
    def reporting(*args, **kwargs):
        reader = ctx.get()
        if reader.is_tracking_metrics:
            reader.reset_metrics()
            result = func(*args, **kwargs)
            return result, reader.metrics
        return func(*args, **kwargs)
    return reporting
@reports_metrics
def search_to_bytes(constraints) -> Union[bytes, Tuple[bytes, Dict]]:
    """Run ``search`` with ``constraints`` and return its encoded result."""
    found = search(constraints)
    return encode_result(found)
@reports_metrics
def resolve_to_bytes(urls) -> Union[bytes, Tuple[bytes, Dict]]:
    """Run ``resolve`` on ``urls`` and return its encoded result."""
    resolved = resolve(urls)
    return encode_result(resolved)
def encode_result(result):
    """Unpack ``result`` as positional arguments to ``Outputs.to_bytes``."""
    encoded = Outputs.to_bytes(*result)
    return encoded
@measure
def execute_query(sql, values, row_offset: int, row_limit: int, censor: Censor) -> List:
    """Run ``sql`` and return at most ``row_limit`` rows after skipping ``row_offset``.

    Censoring and offset skipping are applied in the cursor's row tracer:
    rows rejected by ``censor`` are dropped before the offset is applied, so
    they count towards neither the offset nor the limit. A ``row_limit`` of
    zero or less returns an empty list without running the query.

    Raises:
        SQLiteInterruptedError: the query raised ``apsw.InterruptError``.
        SQLiteOperationalError: the query raised any other ``apsw.Error``.
        Both carry the state's metrics and chain the original apsw error.
    """
    context = ctx.get()
    context.set_query_timeout()
    try:
        cursor = context.db.cursor()

        def row_filter(traced_cursor, row):
            nonlocal row_offset
            row = row_factory(traced_cursor, row)
            # Single-column rows (e.g. counts) are never passed to the censor.
            if len(row) > 1 and censor.censor(row):
                return
            if row_offset:
                row_offset -= 1
                return
            return row

        cursor.setrowtrace(row_filter)
        rows = []
        if row_limit > 0:
            for row in cursor.execute(sql, values):
                rows.append(row)
                if len(rows) >= row_limit:
                    break
        return rows
    except apsw.Error as err:
        plain_sql = interpolate(sql, values)
        if context.is_tracking_metrics:
            # Attach the failing statement to this call's metric entry.
            context.metrics['execute_query'][-1]['sql'] = plain_sql
        if isinstance(err, apsw.InterruptError):
            context.log.warning("interrupted slow sqlite query:\n%s", plain_sql)
            raise SQLiteInterruptedError(context.metrics) from err
        context.log.exception('failed running query', exc_info=err)
        raise SQLiteOperationalError(context.metrics) from err
def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
    """Translate search constraints into a SELECT over ``claim`` via ``query()``.

    Each recognised constraint is popped and re-added under a ``claim.``-prefixed
    key (optionally with a ``__suffix``); the resulting dict is handed to
    ``query()`` together with the SELECT clause. How ``query()`` interprets the
    key suffixes is defined outside this function.

    Args:
        cols: text placed between ``SELECT`` and ``FROM``.
        for_count: when true the ``claimtrie`` LEFT JOIN is omitted, unless
            ``is_controlling`` forces it back on (see below).
        **constraints: search constraints.

    Returns:
        Whatever ``query()`` returns for the built SELECT and constraints.

    Raises:
        NameError: an ``order_by`` entry is not in ``ORDER_FIELDS``.
    """
    if 'order_by' in constraints:
        order_by_parts = constraints['order_by']
        # Accept a single string as well as a list of order fields.
        if isinstance(order_by_parts, str):
            order_by_parts = [order_by_parts]
        sql_order_by = []
        for order_by in order_by_parts:
            # A leading '^' requests ascending order; the default is descending.
            is_asc = order_by.startswith('^')
            column = order_by[1:] if is_asc else order_by
            if column not in ORDER_FIELDS:
                raise NameError(f'{column} is not a valid order_by field')
            # Ordering by name uses the 'normalized' column.
            if column == 'name':
                column = 'normalized'
            sql_order_by.append(
                f"claim.{column} ASC" if is_asc else f"claim.{column} DESC"
            )
        constraints['order_by'] = sql_order_by

    # String values of integer params may start with a comparison operator,
    # which is turned into a key suffix; two-character operators are tried first.
    ops = {'<=': '__lte', '>=': '__gte', '<': '__lt', '>': '__gt'}
    for constraint in INTEGER_PARAMS:
        if constraint in constraints:
            value = constraints.pop(constraint)
            postfix = ''
            if isinstance(value, str):
                if len(value) >= 2 and value[:2] in ops:
                    postfix, value = ops[value[:2]], value[2:]
                elif len(value) >= 1 and value[0] in ops:
                    postfix, value = ops[value[0]], value[1:]
            # fee_amount is scaled by 1000 before being truncated to an int.
            if constraint == 'fee_amount':
                value = Decimal(value)*1000
            constraints[f'claim.{constraint}{postfix}'] = int(value)

    if constraints.pop('is_controlling', False):
        if {'sequence', 'amount_order'}.isdisjoint(constraints):
            # The claimtrie join is required for this filter, so it is kept
            # even when the caller asked for a count query.
            for_count = False
            constraints['claimtrie.claim_hash__is_not_null'] = ''
    # 'sequence' and 'amount_order' are 1-based positions: each selects a single
    # row at (position - 1) under a fixed ordering.
    if 'sequence' in constraints:
        constraints['order_by'] = 'claim.activation_height ASC'
        constraints['offset'] = int(constraints.pop('sequence')) - 1
        constraints['limit'] = 1
    if 'amount_order' in constraints:
        constraints['order_by'] = 'claim.effective_amount DESC'
        constraints['offset'] = int(constraints.pop('amount_order')) - 1
        constraints['limit'] = 1

    if 'claim_id' in constraints:
        claim_id = constraints.pop('claim_id')
        # A 40-character id is matched exactly; anything else is used as a prefix.
        if len(claim_id) == 40:
            constraints['claim.claim_id'] = claim_id
        else:
            constraints['claim.claim_id__like'] = f'{claim_id[:40]}%'
    elif 'claim_ids' in constraints:
        constraints['claim.claim_id__in'] = constraints.pop('claim_ids')

    if 'reposted_claim_id' in constraints:
        # Hex ids are converted to bytes and byte-reversed to form the stored hash.
        constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1]

    if 'name' in constraints:
        constraints['claim.normalized'] = normalize_name(constraints.pop('name'))

    if 'public_key_id' in constraints:
        constraints['claim.public_key_hash'] = (
            ctx.get().ledger.address_to_hash160(constraints.pop('public_key_id')))
    if 'channel_hash' in constraints:
        constraints['claim.channel_hash'] = constraints.pop('channel_hash')
    if 'channel_ids' in constraints:
        channel_ids = constraints.pop('channel_ids')
        if channel_ids:
            constraints['claim.channel_hash__in'] = [
                unhexlify(cid)[::-1] for cid in channel_ids
            ]
    if 'not_channel_ids' in constraints:
        not_channel_ids = constraints.pop('not_channel_ids')
        if not_channel_ids:
            not_channel_ids_binary = [
                unhexlify(ncid)[::-1] for ncid in not_channel_ids
            ]
            # has_channel_signature is only read here; it is popped further down.
            if constraints.get('has_channel_signature', False):
                constraints['claim.channel_hash__not_in'] = not_channel_ids_binary
            else:
                # Without has_channel_signature, rows with a NULL signature_valid
                # are accepted as an alternative to the exclusion list.
                constraints['null_or_not_channel__or'] = {
                    'claim.signature_valid__is_null': True,
                    'claim.channel_hash__not_in': not_channel_ids_binary
                }
    if 'signature_valid' in constraints:
        has_channel_signature = constraints.pop('has_channel_signature', False)
        if has_channel_signature:
            constraints['claim.signature_valid'] = constraints.pop('signature_valid')
        else:
            # Without has_channel_signature, a NULL signature_valid also matches.
            constraints['null_or_signature__or'] = {
                'claim.signature_valid__is_null': True,
                'claim.signature_valid': constraints.pop('signature_valid')
            }
    elif constraints.pop('has_channel_signature', False):
        constraints['claim.signature_valid__is_not_null'] = True

    if 'txid' in constraints:
        tx_hash = unhexlify(constraints.pop('txid'))[::-1]
        # 'nout' is only consumed together with 'txid' and defaults to 0.
        nout = constraints.pop('nout', 0)
        # txo_hash is the reversed tx hash followed by nout as little-endian uint32.
        constraints['claim.txo_hash'] = tx_hash + struct.pack('<I', nout)

    if 'claim_type' in constraints:
        constraints['claim.claim_type'] = CLAIM_TYPES[constraints.pop('claim_type')]
    if 'stream_types' in constraints:
        stream_types = constraints.pop('stream_types')
        if stream_types:
            constraints['claim.stream_type__in'] = [
                STREAM_TYPES[stream_type] for stream_type in stream_types
            ]
    if 'media_types' in constraints:
        media_types = constraints.pop('media_types')
        if media_types:
            constraints['claim.media_type__in'] = media_types

    if 'fee_currency' in constraints:
        constraints['claim.fee_currency'] = constraints.pop('fee_currency').lower()

    # Tags are cleaned with clean_tags; languages and locations are used as given.
    _apply_constraints_for_array_attributes(constraints, 'tag', clean_tags, for_count)
    _apply_constraints_for_array_attributes(constraints, 'language', lambda _: _, for_count)
    _apply_constraints_for_array_attributes(constraints, 'location', lambda _: _, for_count)

    if 'text' in constraints:
        # A text search selects from 'search' joined to 'claim' and replaces any
        # requested ordering with FTS_ORDER_BY.
        constraints["search"] = constraints.pop("text")
        constraints["order_by"] = FTS_ORDER_BY
        select = f"SELECT {cols} FROM search JOIN claim ON (search.rowid=claim.rowid)"
    else:
        select = f"SELECT {cols} FROM claim"
    if not for_count:
        select += " LEFT JOIN claimtrie USING (claim_hash)"
    return query(select, **constraints)
def select_claims(censor: Censor, cols: str, for_count=False, **constraints) -> List:
    """Build and execute a claims query, resolving a ``channel`` URL first.

    When a ``channel`` constraint does not resolve to a claim row, an empty
    result is returned immediately: ``[{'row_count': 0}]`` for the count
    column expression, ``[]`` otherwise. ``offset`` (default 0) and ``limit``
    (default 20) are removed from the constraints and passed to
    ``execute_query`` instead.
    """
    if 'channel' in constraints:
        resolved = resolve_url(constraints.pop('channel'))
        if not isinstance(resolved, dict):
            if cols == 'count(*) as row_count':
                return [{'row_count': 0}]
            return []
        constraints['channel_hash'] = resolved['claim_hash']
    skip = constraints.pop('offset', 0)
    take = constraints.pop('limit', 20)
    sql, values = claims_query(cols, for_count, **constraints)
    return execute_query(sql, values, skip, take, censor)
@measure
def count_claims(**constraints) -> int:
    """Return the number of claims matching ``constraints``.

    Paging and ordering constraints are ignored, and a fresh ``Censor()`` is used.
    """
    for ignored in ('offset', 'limit', 'order_by'):
        constraints.pop(ignored, None)
    rows = select_claims(Censor(), 'count(*) as row_count', for_count=True, **constraints)
    first = rows[0]
    return first['row_count']
def search_claims(censor: Censor, **constraints) -> List:
    """Select the standard claim columns for claims matching ``constraints``."""
    # Column list handed to select_claims as the SELECT expression.
    columns = """
claimtrie.claim_hash as is_controlling,
claimtrie.last_take_over_height,
claim.claim_hash, claim.txo_hash,
claim.claims_in_channel, claim.reposted,
claim.height, claim.creation_height,
claim.activation_height, claim.expiration_height,
claim.effective_amount, claim.support_amount,
claim.trending_group, claim.trending_mixed,
claim.trending_local, claim.trending_global,
claim.short_url, claim.canonical_url,
claim.channel_hash, claim.reposted_claim_hash,
claim.signature_valid
"""
    return select_claims(censor, columns, **constraints)
def _get_referenced_rows(censor: Censor, txo_rows: List[dict]):
    """Fetch the channels and reposted claims referenced by ``txo_rows``.

    Channels of the reposted claims are included as well. The returned list
    has the channel rows first, followed by the reposted rows.
    """
    repost_hashes = {
        row['reposted_claim_hash'] for row in txo_rows if row['reposted_claim_hash']
    }
    channel_hashes = {
        row['channel_hash'] for row in txo_rows if row['channel_hash']
    }

    reposted_txos = []
    if repost_hashes:
        reposted_txos = search_claims(censor, **{'claim.claim_hash__in': repost_hashes})
        channel_hashes.update(
            row['channel_hash'] for row in reposted_txos if row['channel_hash']
        )

    channel_txos = []
    if channel_hashes:
        channel_txos = search_claims(censor, **{'claim.claim_hash__in': channel_hashes})

    # channels must come first for client side inflation to work properly
    return channel_txos + reposted_txos
@measure
def search(constraints) -> Tuple[List, List, int, int, Censor]:
    """Search claims and return the rows plus the rows they reference.

    Args:
        constraints: mapping of search arguments; every key must be in
            ``SEARCH_PARAMS``. The mapping is copied, not modified.

    Returns:
        ``(txo_rows, extra_txo_rows, offset, total, search_censor)`` where
        ``total`` is ``None`` when ``no_totals`` was requested.

    Raises:
        AssertionError: ``constraints`` contains keys outside ``SEARCH_PARAMS``.
    """
    # Validate explicitly rather than with ``assert`` so the check is not
    # stripped under ``python -O``; the exception type is kept for callers.
    invalid = set(constraints).difference(SEARCH_PARAMS)
    if invalid:
        raise AssertionError(f"Search query contains invalid arguments: {invalid}")
    # Work on a copy so the caller's mapping is left untouched.
    constraints = dict(constraints)
    total = None
    if not constraints.pop('no_totals', False):
        total = count_claims(**constraints)
    constraints['offset'] = abs(constraints.get('offset', 0))
    # The page size defaults to 10 and is capped at 50.
    constraints['limit'] = min(abs(constraints.get('limit', 10)), 50)
    if 'order_by' not in constraints:
        constraints['order_by'] = ["claim_hash"]
    context = ctx.get()
    search_censor = context.get_search_censor()
    txo_rows = search_claims(search_censor, **constraints)
    # Referenced rows are censored with the resolve censor, not the search censor.
    extra_txo_rows = _get_referenced_rows(context.get_resolve_censor(), txo_rows)
    return txo_rows, extra_txo_rows, constraints['offset'], total, search_censor
@measure
def resolve(urls) -> Tuple[List, List]:
    """Resolve each URL and collect the rows referenced by the resolved claims.

    The first list holds one entry per URL in input order: a claim row dict,
    or whatever non-dict value ``resolve_url`` returned for it. Only dict
    entries are used to look up referenced rows.
    """
    txo_rows = []
    for raw_url in urls:
        txo_rows.append(resolve_url(raw_url))
    censor = ctx.get().get_resolve_censor()
    resolved_rows = [row for row in txo_rows if isinstance(row, dict)]
    extra_txo_rows = _get_referenced_rows(censor, resolved_rows)
    return txo_rows, extra_txo_rows
@measure
def resolve_url(raw_url):
    """Resolve one URL to a claim row, or return (not raise) an error object.

    Returns:
        The stream row when the URL has a stream part, otherwise the channel
        row (``None`` if the URL has neither part). Failures are returned
        as values: the ``ValueError`` from URL parsing, or a ``LookupError``
        when the channel or stream cannot be found.
    """
    censor = ctx.get().get_resolve_censor()

    try:
        url = URL.parse(raw_url)
    except ValueError as parse_error:
        return parse_error

    channel = None
    if url.has_channel:
        channel_lookup = url.channel.to_dict()
        if set(channel_lookup) == {'name'}:
            channel_lookup['is_controlling'] = True
        else:
            channel_lookup['order_by'] = ['^creation_height']
        channel_matches = search_claims(censor, **channel_lookup, limit=1)
        if not channel_matches:
            return LookupError(f'Could not find channel in "{raw_url}".')
        channel = channel_matches[0]

    if not url.has_stream:
        return channel

    stream_lookup = url.stream.to_dict()
    name_only = set(stream_lookup) == {'name'}
    if channel is not None:
        if name_only:
            # temporarily emulate is_controlling for claims in channel
            stream_lookup['order_by'] = ['effective_amount', '^height']
        else:
            stream_lookup['order_by'] = ['^channel_join']
        stream_lookup['channel_hash'] = channel['claim_hash']
        stream_lookup['signature_valid'] = 1
    elif name_only:
        stream_lookup['is_controlling'] = 1

    stream_matches = search_claims(censor, **stream_lookup, limit=1)
    if not stream_matches:
        return LookupError(f'Could not find stream in "{raw_url}".')
    return stream_matches[0]
def _apply_constraints_for_array_attributes(constraints, attr, cleaner, for_count=False):
    """Replace ``any_/all_/not_<attr>s`` constraints with SQL fragment constraints.

    Mutates ``constraints`` in place: the three list constraints are popped and
    replaced by keys holding SQL fragments, plus ``$``-prefixed keys holding the
    values that the fragments reference as ``:$name`` placeholders. How the key
    prefixes/suffixes (``#``, ``__in``, ``__any`` ...) are interpreted is up to
    ``query()`` and not visible here.

    Args:
        constraints: constraint dict being built for the claims query.
        attr: singular attribute name; used as both table and column name
            in the generated SQL (called with 'tag', 'language', 'location').
        cleaner: callable applied to each popped list before truncation.
        for_count: selects the ``IN (subquery)`` form of the fragments instead
            of the correlated ``EXISTS`` / ``count(*)`` form.
    """
    # Each list is cleaned, truncated to ATTRIBUTE_ARRAY_MAX_LENGTH and de-duplicated.
    any_items = set(cleaner(constraints.pop(f'any_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH])
    all_items = set(cleaner(constraints.pop(f'all_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH])
    not_items = set(cleaner(constraints.pop(f'not_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH])

    # An item that is also excluded is dropped from the 'all' and 'any' sets.
    all_items = {item for item in all_items if item not in not_items}
    any_items = {item for item in any_items if item not in not_items}

    # Fragments that are alternatives of each other for the 'any' match.
    any_queries = {}

    if attr == 'tag':
        # Tags listed in COMMON_TAGS are handled separately from the other 'any' tags.
        common_tags = any_items & COMMON_TAGS.keys()
        if common_tags:
            any_items -= common_tags
        if len(common_tags) < 5:
            # Few common tags: one EXISTS per tag, each naming that tag's index.
            # The tag is inlined into the SQL; it is always a COMMON_TAGS key.
            for item in common_tags:
                index_name = COMMON_TAGS[item]
                any_queries[f'#_common_tag_{index_name}'] = f"""
EXISTS(
SELECT 1 FROM tag INDEXED BY tag_{index_name}_idx WHERE claim.claim_hash=tag.claim_hash
AND tag = '{item}'
)
"""
        elif len(common_tags) >= 5:
            # Five or more common tags: a single EXISTS with an IN list of placeholders.
            constraints.update({
                f'$any_common_tag{i}': item for i, item in enumerate(common_tags)
            })
            values = ', '.join(
                f':$any_common_tag{i}' for i in range(len(common_tags))
            )
            any_queries[f'#_any_common_tags'] = f"""
EXISTS(
SELECT 1 FROM tag WHERE claim.claim_hash=tag.claim_hash
AND tag IN ({values})
)
"""

    if any_items:
        constraints.update({
            f'$any_{attr}{i}': item for i, item in enumerate(any_items)
        })
        values = ', '.join(
            f':$any_{attr}{i}' for i in range(len(any_items))
        )
        # Tags always use the IN (subquery) form, as do count queries.
        if for_count or attr == 'tag':
            any_queries[f'claim.claim_hash__in#_any_{attr}'] = f"""
SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
"""
        else:
            any_queries[f'#_any_{attr}'] = f"""
EXISTS(
SELECT 1 FROM {attr} WHERE
claim.claim_hash={attr}.claim_hash
AND {attr} IN ({values})
)
"""

    # A single 'any' fragment is added directly; several are grouped under one
    # '__any' key.
    if len(any_queries) == 1:
        constraints.update(any_queries)
    elif len(any_queries) > 1:
        constraints[f'ORed_{attr}_queries__any'] = any_queries

    if all_items:
        # The count placeholder is always added but only referenced by the
        # for_count fragment; the other fragment inlines the number.
        constraints[f'$all_{attr}_count'] = len(all_items)
        constraints.update({
            f'$all_{attr}{i}': item for i, item in enumerate(all_items)
        })
        values = ', '.join(
            f':$all_{attr}{i}' for i in range(len(all_items))
        )
        # Both forms require the number of matching rows to equal len(all_items).
        if for_count:
            constraints[f'claim.claim_hash__in#_all_{attr}'] = f"""
SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
GROUP BY claim_hash HAVING COUNT({attr}) = :$all_{attr}_count
"""
        else:
            constraints[f'#_all_{attr}'] = f"""
{len(all_items)}=(
SELECT count(*) FROM {attr} WHERE
claim.claim_hash={attr}.claim_hash
AND {attr} IN ({values})
)
"""

    if not_items:
        constraints.update({
            f'$not_{attr}{i}': item for i, item in enumerate(not_items)
        })
        values = ', '.join(
            f':$not_{attr}{i}' for i in range(len(not_items))
        )
        if for_count:
            constraints[f'claim.claim_hash__not_in#_not_{attr}'] = f"""
SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
"""
        else:
            constraints[f'#_not_{attr}'] = f"""
NOT EXISTS(
SELECT 1 FROM {attr} WHERE
claim.claim_hash={attr}.claim_hash
AND {attr} IN ({values})
)
"""