import logging from typing import Tuple, List, Set, Iterator, Optional from sqlalchemy import func from sqlalchemy.future import select from lbry.crypto.hash import hash160 from lbry.crypto.bip32 import PubKey from ..utils import query from ..query_context import context from ..tables import TXO, PubkeyAddress, AccountAddress from .filters import ( get_filter_matchers, get_filter_matchers_at_granularity, has_filter_range, get_tx_matchers_for_missing_txs, ) log = logging.getLogger(__name__) class DatabaseAddressIterator: def __init__(self, account_id, chain): self.account_id = account_id self.chain = chain self.n = -1 def __iter__(self) -> Iterator[Tuple[bytes, int, bool]]: with context().connect_streaming() as c: sql = ( select( AccountAddress.c.pubkey, AccountAddress.c.n ).where( (AccountAddress.c.account == self.account_id) & (AccountAddress.c.chain == self.chain) ).order_by(AccountAddress.c.n) ) for row in c.execute(sql): self.n = row['n'] yield hash160(row['pubkey']), self.n, False class PersistingAddressIterator(DatabaseAddressIterator): def __init__(self, account_id, chain, pubkey_bytes, chain_code, depth): super().__init__(account_id, chain) self.pubkey_bytes = pubkey_bytes self.chain_code = chain_code self.depth = depth self.pubkey_buffer = [] def flush(self): if self.pubkey_buffer: add_keys([{ 'account': self.account_id, 'address': k.address, 'chain': self.chain, 'pubkey': k.pubkey_bytes, 'chain_code': k.chain_code, 'n': k.n, 'depth': k.depth } for k in self.pubkey_buffer]) self.pubkey_buffer.clear() def __enter__(self) -> 'PersistingAddressIterator': return self def __exit__(self, exc_type, exc_val, exc_tb): self.flush() def __iter__(self) -> Iterator[Tuple[bytes, int, bool]]: yield from super().__iter__() pubkey = PubKey(context().ledger, self.pubkey_bytes, self.chain_code, 0, self.depth) while True: self.n += 1 pubkey_child = pubkey.child(self.n) self.pubkey_buffer.append(pubkey_child) if len(self.pubkey_buffer) >= 900: self.flush() yield hash160(pubkey_child.pubkey_bytes), self.n, True def generate_addresses_using_filters(best_height, allowed_gap, address_manager) -> Set: need, have = set(), set() matchers = get_filter_matchers(best_height) with PersistingAddressIterator(*address_manager) as addresses: gap = 0 for address_hash, n, is_new in addresses: # pylint: disable=unused-variable gap += 1 address_bytes = bytearray(address_hash) for matcher, filter_range in matchers: if matcher.Match(address_bytes): gap = 0 if filter_range not in need and filter_range not in have: if has_filter_range(*filter_range): have.add(filter_range) else: need.add(filter_range) if gap >= allowed_gap: break return need def get_missing_sub_filters_for_addresses(granularity, address_manager): need = set() for matcher, filter_range in get_filter_matchers_at_granularity(granularity): for address_hash, _, _ in DatabaseAddressIterator(*address_manager): address_bytes = bytearray(address_hash) if matcher.Match(address_bytes) and not has_filter_range(*filter_range): need.add(filter_range) break return need def get_missing_tx_for_addresses(address_manager): need = set() for tx_hash, matcher in get_tx_matchers_for_missing_txs(): for address_hash, _, _ in DatabaseAddressIterator(*address_manager): address_bytes = bytearray(address_hash) if matcher.Match(address_bytes): need.add(tx_hash) break return need def update_address_used_times(addresses): context().execute( PubkeyAddress.update() .values(used_times=( select(func.count(TXO.c.address)) .where((TXO.c.address == PubkeyAddress.c.address)), )) .where(PubkeyAddress.c.address._in(addresses)) ) def select_addresses(cols, **constraints): return context().fetchall(query( [AccountAddress, PubkeyAddress], select(*cols).select_from(PubkeyAddress.join(AccountAddress)), **constraints )) def get_addresses(cols=None, include_total=False, **constraints) -> Tuple[List[dict], Optional[int]]: if cols is None: cols = ( PubkeyAddress.c.address, PubkeyAddress.c.used_times, AccountAddress.c.account, AccountAddress.c.chain, AccountAddress.c.pubkey, AccountAddress.c.chain_code, AccountAddress.c.n, AccountAddress.c.depth ) return ( select_addresses(cols, **constraints), get_address_count(**constraints) if include_total else None ) def get_address_count(**constraints): count = select_addresses([func.count().label("total")], **constraints) return count[0]["total"] or 0 def get_all_addresses(): return [r["address"] for r in context().fetchall(select(PubkeyAddress.c.address))] def add_keys(pubkeys): c = context() current_limit = c.variable_limit // len(pubkeys[0]) # (overall limit) // (maximum on a query) for start in range(0, len(pubkeys), current_limit - 1): batch = pubkeys[start:(start + current_limit - 1)] c.execute(c.insert_or_ignore(PubkeyAddress).values([{'address': k['address']} for k in batch])) c.execute(c.insert_or_ignore(AccountAddress).values(batch))