diff --git a/lbry/db/database.py b/lbry/db/database.py index aed5beab3..9e965095d 100644 --- a/lbry/db/database.py +++ b/lbry/db/database.py @@ -254,12 +254,21 @@ class Database: async def get_missing_required_filters(self, height) -> Dict[int, Tuple[int, int]]: return await self.run(q.get_missing_required_filters, height) + async def get_missing_sub_filters_for_addresses(self, granularity, address_manager): + return await self.run(q.get_missing_sub_filters_for_addresses, granularity, address_manager) + + async def get_missing_tx_for_addresses(self, address_manager): + return await self.run(q.get_missing_tx_for_addresses, address_manager) + async def insert_block(self, block): return await self.run(q.insert_block, block) async def insert_block_filter(self, height: int, factor: int, address_filter: bytes): return await self.run(q.insert_block_filter, height, factor, address_filter) + async def insert_tx_filter(self, tx_hash: bytes, height: int, address_filter: bytes): + return await self.run(q.insert_tx_filter, tx_hash, height, address_filter) + async def insert_transaction(self, block_hash, tx): return await self.run(q.insert_transaction, block_hash, tx) @@ -316,6 +325,11 @@ class Database: 'depth': k.depth } for k in pubkeys]) + async def generate_addresses_using_filters(self, best_height, allowed_gap, address_manager): + return await self.run( + q.generate_addresses_using_filters, best_height, allowed_gap, address_manager + ) + async def get_transactions(self, **constraints) -> Result[Transaction]: return await self.fetch_result(q.get_transactions, **constraints) diff --git a/lbry/db/queries/address.py b/lbry/db/queries/address.py index 382d65900..3b4f76f95 100644 --- a/lbry/db/queries/address.py +++ b/lbry/db/queries/address.py @@ -10,7 +10,10 @@ from lbry.crypto.bip32 import PubKey from ..utils import query from ..query_context import context from ..tables import TXO, PubkeyAddress, AccountAddress -from .filters import get_filter_matchers, get_filter_matchers_at_granularity, has_sub_filters +from .filters import ( + get_filter_matchers, get_filter_matchers_at_granularity, has_filter_range, + get_tx_matchers_for_missing_txs, +) log = logging.getLogger(__name__) @@ -87,15 +90,14 @@ def generate_addresses_using_filters(best_height, allowed_gap, address_manager) for address_hash, n, is_new in addresses: gap += 1 address_bytes = bytearray(address_hash) - for granularity, height, matcher in matchers: + for granularity, height, matcher, filter_range in matchers: if matcher.Match(address_bytes): gap = 0 - match = (granularity, height) - if match not in need and match not in have: - if has_sub_filters(granularity, height): - have.add(match) + if filter_range not in need and filter_range not in have: + if has_filter_range(*filter_range): + have.add(filter_range) else: - need.add(match) + need.add(filter_range) if gap >= allowed_gap: break return need @@ -103,13 +105,23 @@ def generate_addresses_using_filters(best_height, allowed_gap, address_manager) def get_missing_sub_filters_for_addresses(granularity, address_manager): need = set() - with DatabaseAddressIterator(*address_manager) as addresses: - for height, matcher in get_filter_matchers_at_granularity(granularity): - for address_hash, n, is_new in addresses: - address_bytes = bytearray(address_hash) - if matcher.Match(address_bytes) and not has_sub_filters(granularity, height): - need.add((height, granularity)) - break + for height, matcher, filter_range in get_filter_matchers_at_granularity(granularity): + for address_hash, n, is_new in DatabaseAddressIterator(*address_manager): + address_bytes = bytearray(address_hash) + if matcher.Match(address_bytes) and not has_filter_range(*filter_range): + need.add(filter_range) + break + return need + + +def get_missing_tx_for_addresses(address_manager): + need = set() + for tx_hash, matcher in get_tx_matchers_for_missing_txs(): + for address_hash, n, is_new in DatabaseAddressIterator(*address_manager): + address_bytes = bytearray(address_hash) + if matcher.Match(address_bytes): + need.add(tx_hash) + break return need diff --git a/lbry/db/queries/filters.py b/lbry/db/queries/filters.py index 9047d1ee7..e192ccbc2 100644 --- a/lbry/db/queries/filters.py +++ b/lbry/db/queries/filters.py @@ -1,5 +1,5 @@ from math import log10 -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Tuple, Set, Optional from sqlalchemy import between, func, or_ from sqlalchemy.future import select @@ -7,29 +7,32 @@ from sqlalchemy.future import select from lbry.blockchain.block import PyBIP158, get_address_filter from ..query_context import context -from ..tables import BlockFilter, TXFilter +from ..tables import BlockFilter, TXFilter, TX def has_filters(): return context().has_records(BlockFilter) -def has_sub_filters(granularity: int, height: int): +def get_sub_filter_range(granularity: int, height: int): + end = height if granularity >= 3: - sub_filter_size = 10**(granularity-1) - sub_filters_count = context().fetchtotal( - (BlockFilter.c.factor == granularity-1) & - between(BlockFilter.c.height, height, height + sub_filter_size * 9) - ) - return sub_filters_count == 10 + end = height + 10**(granularity-1) * 9 elif granularity == 2: - sub_filters_count = context().fetchtotal( - (BlockFilter.c.factor == 1) & - between(BlockFilter.c.height, height, height + 99) + end = height + 99 + return granularity - 1, height, end + + +def has_filter_range(factor: int, start: int, end: int): + if factor >= 1: + filters = context().fetchtotal( + (BlockFilter.c.factor == factor) & + between(BlockFilter.c.height, start, end) ) - return sub_filters_count == 100 - elif granularity == 1: - tx_filters_count = context().fetchtotal(TXFilter.c.height == height) + expected = 10 if factor >= 2 else 100 + return filters == expected + elif factor == 0: + tx_filters_count = context().fetchtotal(TXFilter.c.height == start) return tx_filters_count > 0 @@ -93,9 +96,9 @@ def get_maximum_known_filters() -> Dict[str, Optional[int]]: return context().fetchone(query) -def get_missing_required_filters(height) -> Dict[int, Tuple[int, int]]: +def get_missing_required_filters(height) -> Set[Tuple[int, int, int]]: known_filters = get_maximum_known_filters() - missing_filters = {} + missing_filters = set() for granularity, (start, end) in get_minimal_required_filter_ranges(height).items(): known_height = known_filters.get(str(granularity)) if known_height is not None and known_height > start: @@ -104,9 +107,9 @@ def get_missing_required_filters(height) -> Dict[int, Tuple[int, int]]: else: adjusted_height = known_height + 10**granularity if adjusted_height <= end: - missing_filters[granularity] = (adjusted_height, end) + missing_filters.add((granularity, adjusted_height, end)) else: - missing_filters[granularity] = (start, end) + missing_filters.add((granularity, start, end)) return missing_filters @@ -123,20 +126,33 @@ def get_filter_matchers(height) -> List[Tuple[int, int, PyBIP158]]: .where(or_(*conditions)) .order_by(BlockFilter.c.height.desc()) ) - return [ - (bf["factor"], bf["height"], get_address_filter(bf["address_filter"])) - for bf in context().fetchall(query) - ] + return [( + bf["factor"], bf["height"], + get_address_filter(bf["address_filter"]), + get_sub_filter_range(bf["factor"], bf["height"]) + ) for bf in context().fetchall(query)] -def get_filter_matchers_at_granularity(granularity) -> List[Tuple[int, PyBIP158]]: +def get_filter_matchers_at_granularity(granularity) -> List[Tuple[int, PyBIP158, Tuple]]: query = ( select(BlockFilter.c.height, BlockFilter.c.address_filter) .where(BlockFilter.c.factor == granularity) .order_by(BlockFilter.c.height.desc()) ) + return [( + bf["height"], + get_address_filter(bf["address_filter"]), + get_sub_filter_range(granularity, bf["height"]) + ) for bf in context().fetchall(query)] + + +def get_tx_matchers_for_missing_txs() -> List[Tuple[int, PyBIP158]]: + query = ( + select(TXFilter.c.tx_hash, TXFilter.c.address_filter) + .where(TXFilter.c.tx_hash.notin_(select(TX.c.tx_hash))) + ) return [ - (bf["height"], get_address_filter(bf["address_filter"])) + (bf["tx_hash"], get_address_filter(bf["address_filter"])) for bf in context().fetchall(query) ] diff --git a/lbry/service/api.py b/lbry/service/api.py index 10427e35a..9f312c839 100644 --- a/lbry/service/api.py +++ b/lbry/service/api.py @@ -2635,7 +2635,7 @@ class API: async def transaction_search( self, txids: StrOrList, # transaction ids to find - ) -> List[Transaction]: + ) -> Dict[str, str]: """ Search for transaction(s) in the entire blockchain. @@ -3435,6 +3435,9 @@ class Client(API): self.receive_messages_task = asyncio.create_task(self.receive_messages()) async def disconnect(self): + if self.ws is not None: + await self.ws.close() + self.ws = None if self.session is not None: await self.session.close() self.session = None diff --git a/lbry/service/full_node.py b/lbry/service/full_node.py index 04a96a35f..ead0d13fb 100644 --- a/lbry/service/full_node.py +++ b/lbry/service/full_node.py @@ -56,8 +56,9 @@ class FullNode(Service): async def search_transactions(self, txids): tx_hashes = [unhexlify(txid)[::-1] for txid in txids] return { - hexlify(tx['tx_hash'][::-1]).decode(): hexlify(tx['raw']).decode() - for tx in await self.db.get_transactions(tx_hashes=tx_hashes) + #hexlify(tx['tx_hash'][::-1]).decode(): hexlify(tx['raw']).decode() + tx.id: hexlify(tx.raw).decode() + for tx in await self.db.get_transactions(tx_hash__in=tx_hashes) } async def broadcast(self, tx): diff --git a/lbry/service/light_client.py b/lbry/service/light_client.py index 8622aa41b..ef4ce3807 100644 --- a/lbry/service/light_client.py +++ b/lbry/service/light_client.py @@ -1,13 +1,13 @@ import asyncio import logging from typing import Dict -from typing import List, Optional, NamedTuple, Tuple -from binascii import unhexlify +from typing import List, Optional, Tuple +from binascii import hexlify, unhexlify -from lbry.blockchain.block import Block, get_address_filter -from lbry.event import BroadcastSubscription -from lbry.crypto.hash import hash160 -from lbry.wallet.account import AddressManager +from lbry.blockchain.block import Block +from lbry.event import EventController, BroadcastSubscription +from lbry.crypto.hash import double_sha256 +from lbry.wallet import WalletManager from lbry.blockchain import Ledger, Transaction from lbry.db import Database @@ -44,8 +44,8 @@ class LightClient(Service): return await self.client.transaction_search(txids=txids) async def get_address_filters(self, start_height: int, end_height: int = None, granularity: int = 0): - return await self.sync.filters.get_filters( - start_height=start_height, end_height=end_height, granularity=granularity + return await self.client.first.address_filter( + granularity=granularity, start_height=start_height, end_height=end_height ) async def broadcast(self, tx): @@ -69,36 +69,6 @@ class LightClient(Service): return await self.client.sum_supports(claim_hash, include_channel_content, exclude_own_supports) -class TransactionEvent(NamedTuple): - address: str - tx: Transaction - - -class AddressesGeneratedEvent(NamedTuple): - address_manager: AddressManager - addresses: List[str] - - -class TransactionCacheItem: - __slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications' - - def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None): - self.has_tx = asyncio.Event() - self.lock = lock or asyncio.Lock() - self._tx = self.tx = tx - self.pending_verifications = 0 - - @property - def tx(self) -> Optional[Transaction]: - return self._tx - - @tx.setter - def tx(self, tx: Transaction): - self._tx = tx - if tx is not None: - self.has_tx.set() - - class FilterManager: """ Efficient on-demand address filter access. @@ -106,25 +76,108 @@ class FilterManager: downloads on-demand what it doesn't have from full node. """ - def __init__(self, db, client): + def __init__(self, db: Database, client: APIClient): self.db = db self.client = client self.cache = {} - async def download(self, best_height): - our_height = await self.db.get_best_block_filter() - new_block_filters = await self.client.address_filter( - start_height=our_height+1, end_height=best_height, granularity=1 - ) - for block_filter in await new_block_filters.first: - await self.db.insert_block_filter( - block_filter["height"], unhexlify(block_filter["filter"]) + async def download_and_save_filters(self, needed_filters): + for factor, start, end in needed_filters: + filters = await self.client.first.address_filter( + granularity=factor, start_height=start, end_height=end ) + if factor == 0: + for tx_filter in filters: + await self.db.insert_tx_filter( + unhexlify(tx_filter["txid"])[::-1], tx_filter["height"], unhexlify(tx_filter["filter"]) + ) + else: + for block_filter in filters: + await self.db.insert_block_filter( + block_filter["height"], factor, unhexlify(block_filter["filter"]) + ) - async def get_filters(self, start_height, end_height, granularity): - return await self.client.address_filter( - start_height=start_height, end_height=end_height, granularity=granularity - ) + async def download_and_save_txs(self, tx_hashes): + if not tx_hashes: + return + txids = [hexlify(tx_hash[::-1]).decode() for tx_hash in tx_hashes] + txs = await self.client.first.transaction_search(txids=txids) + for raw_tx in txs.values(): + await self.db.insert_transaction(None, Transaction(unhexlify(raw_tx))) + + async def download_initial_filters(self, best_height): + missing = await self.db.get_missing_required_filters(best_height) + await self.download_and_save_filters(missing) + + async def generate_addresses(self, best_height: int, wallets: WalletManager): + for wallet in wallets: + for account in wallet.accounts: + for address_manager in account.address_managers.values(): + missing = await self.db.generate_addresses_using_filters( + best_height, address_manager.gap, ( + account.id, + address_manager.chain_number, + address_manager.public_key.pubkey_bytes, + address_manager.public_key.chain_code, + address_manager.public_key.depth + ) + ) + await self.download_and_save_filters(missing) + + async def download_sub_filters(self, granularity: int, wallets: WalletManager): + for wallet in wallets: + for account in wallet.accounts: + for address_manager in account.address_managers.values(): + missing = await self.db.get_missing_sub_filters_for_addresses( + granularity, (account.id, address_manager.chain_number) + ) + await self.download_and_save_filters(missing) + + async def download_transactions(self, wallets: WalletManager): + for wallet in wallets: + for account in wallet.accounts: + for address_manager in account.address_managers.values(): + missing = await self.db.get_missing_tx_for_addresses( + (account.id, address_manager.chain_number) + ) + await self.download_and_save_txs(missing) + + async def download(self, best_height: int, wallets: WalletManager): + await self.download_initial_filters(best_height) + await self.generate_addresses(best_height, wallets) + await self.download_sub_filters(3, wallets) + await self.download_sub_filters(2, wallets) + await self.download_sub_filters(1, wallets) + await self.download_transactions(wallets) + + @staticmethod + def get_root_of_merkle_tree(branches, branch_positions, working_branch): + for i, branch in enumerate(branches): + other_branch = unhexlify(branch)[::-1] + other_branch_on_left = bool((branch_positions >> i) & 1) + if other_branch_on_left: + combined = other_branch + working_branch + else: + combined = working_branch + other_branch + working_branch = double_sha256(combined) + return hexlify(working_branch[::-1]) + + async def maybe_verify_transaction(self, tx, remote_height, merkle=None): + tx.height = remote_height + cached = self._tx_cache.get(tx.hash) + if not cached: + # cache txs looked up by transaction_show too + cached = TransactionCacheItem() + cached.tx = tx + self._tx_cache[tx.hash] = cached + if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1: + # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case + if not merkle: + merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height) + merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) + header = await self.headers.get(remote_height) + tx.position = merkle['pos'] + tx.is_verified = merkle_root == header['merkle_root'] class BlockHeaderManager: @@ -169,11 +222,15 @@ class FastSync(Sync): self.service = service self.client = client self.advance_loop_task: Optional[asyncio.Task] = None - self.on_block = client.get_event_stream('blockchain.block') + self.on_block = client.get_event_stream("blockchain.block") self.on_block_event = asyncio.Event() self.on_block_subscription: Optional[BroadcastSubscription] = None + self._on_synced_controller = EventController() + self.on_synced = self._on_synced_controller.stream + self.conf.events.register("blockchain.block", self.on_synced) self.blocks = BlockHeaderManager(self.db, self.client) self.filters = FilterManager(self.db, self.client) + self.best_height: Optional[int] = None async def get_block_headers(self, start_height: int, end_height: int = None): return await self.client.first.block_list(start_height, end_height) @@ -182,68 +239,27 @@ class FastSync(Sync): return await self.client.first.block_tip() async def start(self): + self.on_block_subscription = self.on_block.listen(self.handle_on_block) self.advance_loop_task = asyncio.create_task(self.advance()) await self.advance_loop_task self.advance_loop_task = asyncio.create_task(self.loop()) - self.on_block_subscription = self.on_block.listen( - lambda e: self.on_block_event.set() - ) async def stop(self): for task in (self.on_block_subscription, self.advance_loop_task): if task is not None: task.cancel() + def handle_on_block(self, e): + self.best_height = e[0] + self.on_block_event.set() + async def advance(self): - best_height = await self.client.first.block_tip() + height = self.best_height or await self.client.first.block_tip() await asyncio.wait([ - self.blocks.download(best_height), - self.filters.download(best_height), + self.blocks.download(height), + self.filters.download(height, self.service.wallets), ]) - - block_filters = {} - for block_filter in await self.db.get_filters(0, best_height, 1): - block_filters[block_filter['height']] = \ - get_address_filter(unhexlify(block_filter['filter'])) - - for wallet in self.service.wallets: - for account in wallet.accounts: - for address_manager in account.address_managers.values(): - i = gap = 0 - while gap < 20: - key, i = address_manager.public_key.child(i), i+1 - address = bytearray(hash160(key.pubkey_bytes)) - for block, matcher in block_filters.items(): - if matcher.Match(address): - gap = 0 - continue - gap += 1 - - # address = None - # address_array = [bytearray(self.db.ledger.address_to_hash160(address))] - # for address_filter in filters: - # print(address_filter) - # address_filter = get_address_filter(unhexlify(address_filter['filter'])) - # print(address_filter.MatchAny(address_array)) - - -# address_array = [ -# bytearray(a['address'].encode()) -# for a in await self.service.db.get_all_addresses() -# ] -# block_filters = await self.service.get_block_address_filters() -# for block_hash, block_filter in block_filters.items(): -# bf = get_address_filter(block_filter) -# if bf.MatchAny(address_array): -# print(f'match: {block_hash} - {block_filter}') -# tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash) -# for txid, tx_filter in tx_filters.items(): -# tf = get_address_filter(tx_filter) -# if tf.MatchAny(address_array): -# print(f' match: {txid} - {tx_filter}') -# txs = await self.service.search_transactions([txid]) -# tx = Transaction(unhexlify(txs[txid])) -# await self.service.db.insert_transaction(tx) + await self._on_synced_controller.add(height) async def loop(self): while True: @@ -256,306 +272,3 @@ class FastSync(Sync): except Exception as e: log.exception(e) await self.stop() - - # async def get_local_status_and_history(self, address, history=None): - # if not history: - # address_details = await self.db.get_address(address=address) - # history = (address_details['history'] if address_details else '') or '' - # parts = history.split(':')[:-1] - # return ( - # hexlify(sha256(history.encode())).decode() if history else None, - # list(zip(parts[0::2], map(int, parts[1::2]))) - # ) - # - # @staticmethod - # def get_root_of_merkle_tree(branches, branch_positions, working_branch): - # for i, branch in enumerate(branches): - # other_branch = unhexlify(branch)[::-1] - # other_branch_on_left = bool((branch_positions >> i) & 1) - # if other_branch_on_left: - # combined = other_branch + working_branch - # else: - # combined = working_branch + other_branch - # working_branch = double_sha256(combined) - # return hexlify(working_branch[::-1]) - # - # - # @property - # def local_height_including_downloaded_height(self): - # return max(self.headers.height, self._download_height) - # - # async def initial_headers_sync(self): - # get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True) - # self.headers.chunk_getter = get_chunk - # - # async def doit(): - # for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)): - # async with self._header_processing_lock: - # await self.headers.ensure_chunk_at(height) - # self._other_tasks.add(doit()) - # await self.update_headers() - # - # async def update_headers(self, height=None, headers=None, subscription_update=False): - # rewound = 0 - # while True: - # - # if height is None or height > len(self.headers): - # # sometimes header subscription updates are for a header in the future - # # which can't be connected, so we do a normal header sync instead - # height = len(self.headers) - # headers = None - # subscription_update = False - # - # if not headers: - # header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) - # headers = header_response['hex'] - # - # if not headers: - # # Nothing to do, network thinks we're already at the latest height. - # return - # - # added = await self.headers.connect(height, unhexlify(headers)) - # if added > 0: - # height += added - # self._on_header_controller.add( - # BlockHeightEvent(self.headers.height, added)) - # - # if rewound > 0: - # # we started rewinding blocks and apparently found - # # a new chain - # rewound = 0 - # await self.db.rewind_blockchain(height) - # - # if subscription_update: - # # subscription updates are for latest header already - # # so we don't need to check if there are newer / more - # # on another loop of update_headers(), just return instead - # return - # - # elif added == 0: - # # we had headers to connect but none got connected, probably a reorganization - # height -= 1 - # rewound += 1 - # log.warning( - # "Blockchain Reorganization: attempting rewind to height %s from starting height %s", - # height, height+rewound - # ) - # - # else: - # raise IndexError(f"headers.connect() returned negative number ({added})") - # - # if height < 0: - # raise IndexError( - # "Blockchain reorganization rewound all the way back to genesis hash. " - # "Something is very wrong. Maybe you are on the wrong blockchain?" - # ) - # - # if rewound >= 100: - # raise IndexError( - # "Blockchain reorganization dropped {} headers. This is highly unusual. " - # "Will not continue to attempt reorganizing. Please, delete the ledger " - # "synchronization directory inside your wallet directory (folder: '{}') and " - # "restart the program to synchronize from scratch." - # .format(rewound, self.ledger.get_id()) - # ) - # - # headers = None # ready to download some more headers - # - # # if we made it this far and this was a subscription_update - # # it means something went wrong and now we're doing a more - # # robust sync, turn off subscription update shortcut - # subscription_update = False - # - # async def receive_header(self, response): - # async with self._header_processing_lock: - # header = response[0] - # await self.update_headers( - # height=header['height'], headers=header['hex'], subscription_update=True - # ) - # - # async def subscribe_accounts(self): - # if self.network.is_connected and self.accounts: - # log.info("Subscribe to %i accounts", len(self.accounts)) - # await asyncio.wait([ - # self.subscribe_account(a) for a in self.accounts - # ]) - # - # async def subscribe_account(self, account: Account): - # for address_manager in account.address_managers.values(): - # await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) - # await account.ensure_address_gap() - # - # async def unsubscribe_account(self, account: Account): - # for address in await account.get_addresses(): - # await self.network.unsubscribe_address(address) - # - # async def subscribe_addresses( - # self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000): - # if self.network.is_connected and addresses: - # addresses_remaining = list(addresses) - # while addresses_remaining: - # batch = addresses_remaining[:batch_size] - # results = await self.network.subscribe_address(*batch) - # for address, remote_status in zip(batch, results): - # self._update_tasks.add(self.update_history(address, remote_status, address_manager)) - # addresses_remaining = addresses_remaining[batch_size:] - # log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), - # len(addresses), *self.network.client.server_address_and_port) - # log.info( - # "finished subscribing to %i addresses on %s:%i", len(addresses), - # *self.network.client.server_address_and_port - # ) - # - # def process_status_update(self, update): - # address, remote_status = update - # self._update_tasks.add(self.update_history(address, remote_status)) - # - # async def update_history(self, address, remote_status, address_manager: AddressManager = None): - # async with self._address_update_locks[address]: - # self._known_addresses_out_of_sync.discard(address) - # - # local_status, local_history = await self.get_local_status_and_history(address) - # - # if local_status == remote_status: - # return True - # - # remote_history = await self.network.retriable_call(self.network.get_history, address) - # remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) - # we_need = set(remote_history) - set(local_history) - # if not we_need: - # return True - # - # cache_tasks: List[asyncio.Task[Transaction]] = [] - # synced_history = StringIO() - # loop = asyncio.get_running_loop() - # for i, (txid, remote_height) in enumerate(remote_history): - # if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks: - # synced_history.write(f'{txid}:{remote_height}:') - # else: - # check_local = (txid, remote_height) not in we_need - # cache_tasks.append(loop.create_task( - # self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local) - # )) - # - # synced_txs = [] - # for task in cache_tasks: - # tx = await task - # - # check_db_for_txos = [] - # for txi in tx.inputs: - # if txi.txo_ref.txo is not None: - # continue - # cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash) - # if cache_item is not None: - # if cache_item.tx is None: - # await cache_item.has_tx.wait() - # assert cache_item.tx is not None - # txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref - # else: - # check_db_for_txos.append(txi.txo_ref.hash) - # - # referenced_txos = {} if not check_db_for_txos else { - # txo.id: txo for txo in await self.db.get_txos( - # txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True - # ) - # } - # - # for txi in tx.inputs: - # if txi.txo_ref.txo is not None: - # continue - # referenced_txo = referenced_txos.get(txi.txo_ref.id) - # if referenced_txo is not None: - # txi.txo_ref = referenced_txo.ref - # - # synced_history.write(f'{tx.id}:{tx.height}:') - # synced_txs.append(tx) - # - # await self.db.save_transaction_io_batch( - # synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue() - # ) - # await asyncio.wait([ - # self.ledger._on_transaction_controller.add(TransactionEvent(address, tx)) - # for tx in synced_txs - # ]) - # - # if address_manager is None: - # address_manager = await self.get_address_manager_for_address(address) - # - # if address_manager is not None: - # await address_manager.ensure_address_gap() - # - # local_status, local_history = \ - # await self.get_local_status_and_history(address, synced_history.getvalue()) - # if local_status != remote_status: - # if local_history == remote_history: - # return True - # log.warning( - # "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items", - # remote_status, len(remote_history), local_status, len(local_history) - # ) - # log.warning("local: %s", local_history) - # log.warning("remote: %s", remote_history) - # self._known_addresses_out_of_sync.add(address) - # return False - # else: - # return True - # - # async def cache_transaction(self, tx_hash, remote_height, check_local=True): - # cache_item = self._tx_cache.get(tx_hash) - # if cache_item is None: - # cache_item = self._tx_cache[tx_hash] = TransactionCacheItem() - # elif cache_item.tx is not None and \ - # cache_item.tx.height >= remote_height and \ - # (cache_item.tx.is_verified or remote_height < 1): - # return cache_item.tx # cached tx is already up-to-date - # - # try: - # cache_item.pending_verifications += 1 - # return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local) - # finally: - # cache_item.pending_verifications -= 1 - # - # async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True): - # - # async with cache_item.lock: - # - # tx = cache_item.tx - # - # if tx is None and check_local: - # # check local db - # tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash) - # - # merkle = None - # if tx is None: - # # fetch from network - # _raw, merkle = await self.network.retriable_call( - # self.network.get_transaction_and_merkle, tx_hash, remote_height - # ) - # tx = Transaction(unhexlify(_raw), height=merkle.get('block_height')) - # cache_item.tx = tx # make sure it's saved before caching it - # await self.maybe_verify_transaction(tx, remote_height, merkle) - # return tx - # - # async def maybe_verify_transaction(self, tx, remote_height, merkle=None): - # tx.height = remote_height - # cached = self._tx_cache.get(tx.hash) - # if not cached: - # # cache txs looked up by transaction_show too - # cached = TransactionCacheItem() - # cached.tx = tx - # self._tx_cache[tx.hash] = cached - # if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1: - # # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case - # if not merkle: - # merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height) - # merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) - # header = await self.headers.get(remote_height) - # tx.position = merkle['pos'] - # tx.is_verified = merkle_root == header['merkle_root'] - # - # async def get_address_manager_for_address(self, address) -> Optional[AddressManager]: - # details = await self.db.get_address(address=address) - # for account in self.accounts: - # if account.id == details['account']: - # return account.address_managers[details['chain']] - # return None diff --git a/tests/integration/service/test_light_client.py b/tests/integration/service/test_light_client.py new file mode 100644 index 000000000..ddd32596a --- /dev/null +++ b/tests/integration/service/test_light_client.py @@ -0,0 +1,46 @@ +import asyncio +from binascii import unhexlify + +from lbry.testcase import IntegrationTestCase +from lbry.service.full_node import FullNode +from lbry.service.light_client import LightClient +from lbry.blockchain.block import get_address_filter + + +class LightClientTests(IntegrationTestCase): + + async def asyncSetUp(self): + await super().asyncSetUp() + await self.chain.generate(200) + self.full_node_daemon = await self.make_full_node_daemon() + self.full_node: FullNode = self.full_node_daemon.service + self.light_client_daemon = await self.make_light_client_daemon(self.full_node_daemon, start=False) + self.light_client: LightClient = self.light_client_daemon.service + self.light_client.conf.wallet_storage = "database" + self.addCleanup(self.light_client.client.disconnect) + await self.light_client.client.connect() + self.addCleanup(self.light_client.db.close) + await self.light_client.db.open() + self.addCleanup(self.light_client.wallets.close) + await self.light_client.wallets.open() + await self.light_client.client.start_event_streams() + self.db = self.light_client.db + self.sync = self.light_client.sync + self.client = self.light_client.client + self.account = self.light_client.wallets.default.accounts.default + + async def test_sync(self): + self.assertEqual(await self.client.first.block_tip(), 200) + + self.assertEqual(await self.db.get_best_block_height(), -1) + self.assertEqual(await self.db.get_missing_required_filters(200), {(2, 0, 100)}) + await self.sync.start() + self.assertEqual(await self.db.get_best_block_height(), 200) + self.assertEqual(await self.db.get_missing_required_filters(200), set()) + + address = await self.account.receiving.get_or_create_usable_address() + await self.chain.send_to_address(address, '5.0') + await self.chain.generate(1) + await self.assertBalance(self.account, '0.0') + await self.sync.on_synced.first + await self.assertBalance(self.account, '5.0') diff --git a/tests/unit/wallet/test_sync.py b/tests/unit/wallet/test_sync.py index 5bee79aa1..6408ae725 100644 --- a/tests/unit/wallet/test_sync.py +++ b/tests/unit/wallet/test_sync.py @@ -5,7 +5,7 @@ from lbry.crypto.hash import hash160 from lbry.crypto.bip32 import from_extended_key_string from lbry.blockchain.block import create_address_filter from lbry.db import queries as q -from lbry.db.tables import AccountAddress +from lbry.db.tables import AccountAddress, TX from lbry.db.query_context import context from lbry.testcase import UnitDBTestCase @@ -13,16 +13,16 @@ from lbry.testcase import UnitDBTestCase class TestMissingRequiredFiltersCalculation(UnitDBTestCase): def test_get_missing_required_filters(self): - self.assertEqual(q.get_missing_required_filters(99), {1: (0, 99)}) - self.assertEqual(q.get_missing_required_filters(100), {100: (0, 0)}) - self.assertEqual(q.get_missing_required_filters(199), {100: (0, 0), 1: (100, 199)}) - self.assertEqual(q.get_missing_required_filters(201), {100: (0, 100), 1: (200, 201)}) + self.assertEqual(q.get_missing_required_filters(99), {(1, 0, 99)}) + self.assertEqual(q.get_missing_required_filters(100), {(2, 0, 0)}) + self.assertEqual(q.get_missing_required_filters(199), {(2, 0, 0), (1, 100, 199)}) + self.assertEqual(q.get_missing_required_filters(201), {(2, 0, 100), (1, 200, 201)}) # all filters missing self.assertEqual(q.get_missing_required_filters(134_567), { - 10_000: (0, 120_000), - 1_000: (130_000, 133_000), - 100: (134_000, 134_400), - 1: (134_500, 134_567) + (4, 0, 120_000), + (3, 130_000, 133_000), + (2, 134_000, 134_400), + (1, 134_500, 134_567) }) q.insert_block_filter(110_000, 4, b'beef') @@ -31,10 +31,10 @@ class TestMissingRequiredFiltersCalculation(UnitDBTestCase): q.insert_block_filter(134_499, 1, b'beef') # we we have some filters, but not recent enough (all except 10k are adjusted) self.assertEqual(q.get_missing_required_filters(134_567), { - 10_000: (120_000, 120_000), # 0 -> 120_000 - 1_000: (130_000, 133_000), - 100: (134_000, 134_400), - 1: (134_500, 134_567) + (4, 120_000, 120_000), # 0 -> 120_000 + (3, 130_000, 133_000), + (2, 134_000, 134_400), + (1, 134_500, 134_567) }) q.insert_block_filter(132_000, 3, b'beef') @@ -42,10 +42,10 @@ class TestMissingRequiredFiltersCalculation(UnitDBTestCase): q.insert_block_filter(134_550, 1, b'beef') # all filters get adjusted because we have recent of each self.assertEqual(q.get_missing_required_filters(134_567), { - 10_000: (120_000, 120_000), # 0 -> 120_000 - 1_000: (133_000, 133_000), # 130_000 -> 133_000 - 100: (134_400, 134_400), # 134_000 -> 134_400 - 1: (134_551, 134_567) # 134_500 -> 134_551 + (4, 120_000, 120_000), # 0 -> 120_000 + (3, 133_000, 133_000), # 130_000 -> 133_000 + (2, 134_400, 134_400), # 134_000 -> 134_400 + (1, 134_551, 134_567) # 134_500 -> 134_551 }) q.insert_block_filter(120_000, 4, b'beef') @@ -54,15 +54,15 @@ class TestMissingRequiredFiltersCalculation(UnitDBTestCase): q.insert_block_filter(134_566, 1, b'beef') # we have latest filters for all except latest single block self.assertEqual(q.get_missing_required_filters(134_567), { - 1: (134_567, 134_567) # 134_551 -> 134_567 + (1, 134_567, 134_567) # 134_551 -> 134_567 }) q.insert_block_filter(134_567, 1, b'beef') # we have all latest filters - self.assertEqual(q.get_missing_required_filters(134_567), {}) + self.assertEqual(q.get_missing_required_filters(134_567), set()) -class TestAddressGeneration(UnitDBTestCase): +class TestAddressGenerationAndTXSync(UnitDBTestCase): RECEIVING_KEY_N = 0 @@ -119,7 +119,7 @@ class TestAddressGeneration(UnitDBTestCase): elif granularity == 1: q.insert_tx_filter(hexlify(f'tx{height}'.encode()), height, create_address_filter(addresses)) - def test_generate_from_filters(self): + def test_generate_from_filters_and_download_txs(self): # 15 addresses will get generated, 9 due to filters and 6 due to gap pubkeys = [self.receiving_pubkey.child(n) for n in range(15)] hashes = [hash160(key.pubkey_bytes) for key in pubkeys] @@ -144,7 +144,10 @@ class TestAddressGeneration(UnitDBTestCase): q.insert_block_filter(134_567, 1, create_address_filter(hashes[8:9])) # check that all required filters did get created - self.assertEqual(q.get_missing_required_filters(134_567), {}) + self.assertEqual(q.get_missing_required_filters(134_567), set()) + + # no addresses + self.assertEqual([], self.get_ordered_addresses()) # generate addresses with 6 address gap, returns new sub filters needed self.assertEqual( @@ -154,15 +157,15 @@ class TestAddressGeneration(UnitDBTestCase): self.receiving_pubkey.chain_code, self.receiving_pubkey.depth )), { - (1, 134500), - (1, 134567), - (2, 134000), - (2, 134400), - (3, 130000), - (3, 133000), - (4, 0), - (4, 100000), - (4, 120000) + (0, 134500, 134500), + (0, 134567, 134567), + (1, 134000, 134099), + (1, 134400, 134499), + (2, 130000, 130900), + (2, 133000, 133900), + (3, 0, 9000), + (3, 100000, 109000), + (3, 120000, 129000) } ) @@ -189,15 +192,108 @@ class TestAddressGeneration(UnitDBTestCase): self.receiving_pubkey.depth )), set() ) - - # no new addresses should have been generated + # no new addresses should have been generated either self.assertEqual([key.address for key in pubkeys], self.get_ordered_addresses()) + # check sub filters at 1,000 self.assertEqual( - q.generate_addresses_using_filters(134_567, 6, ( + q.get_missing_sub_filters_for_addresses(3, ( + self.root_pubkey.address, self.RECEIVING_KEY_N, + )), { + (2, 3000, 3900), + (2, 103000, 103900), + (2, 123000, 123900), + } + ) + # "download" missing 1,000 sub filters + self.insert_sub_filters(3, hashes[0:1], 3000) + self.insert_sub_filters(3, hashes[1:2], 103_000) + self.insert_sub_filters(3, hashes[2:3], 123_000) + # no more missing sub filters at 1,000 + self.assertEqual( + q.get_missing_sub_filters_for_addresses(3, ( + self.root_pubkey.address, self.RECEIVING_KEY_N, + )), set() + ) + + # check sub filters at 100 + self.assertEqual( + q.get_missing_sub_filters_for_addresses(2, ( + self.root_pubkey.address, self.RECEIVING_KEY_N, + )), { + (1, 3300, 3399), + (1, 103300, 103399), + (1, 123300, 123399), + (1, 130300, 130399), + (1, 133300, 133399), + } + ) + # "download" missing 100 sub filters + self.insert_sub_filters(2, hashes[0:1], 3300) + self.insert_sub_filters(2, hashes[1:2], 103_300) + self.insert_sub_filters(2, hashes[2:3], 123_300) + self.insert_sub_filters(2, hashes[3:4], 130_300) + self.insert_sub_filters(2, hashes[4:5], 133_300) + # no more missing sub filters at 100 + self.assertEqual( + q.get_missing_sub_filters_for_addresses(2, ( + self.root_pubkey.address, self.RECEIVING_KEY_N, + )), set() + ) + + # check tx filters + self.assertEqual( + q.get_missing_sub_filters_for_addresses(1, ( + self.root_pubkey.address, self.RECEIVING_KEY_N, + )), { + (0, 3303, 3303), + (0, 103303, 103303), + (0, 123303, 123303), + (0, 130303, 130303), + (0, 133303, 133303), + (0, 134003, 134003), + (0, 134403, 134403), + } + ) + # "download" missing tx filters + self.insert_sub_filters(1, hashes[0:1], 3303) + self.insert_sub_filters(1, hashes[1:2], 103_303) + self.insert_sub_filters(1, hashes[2:3], 123_303) + self.insert_sub_filters(1, hashes[3:4], 130_303) + self.insert_sub_filters(1, hashes[4:5], 133_303) + self.insert_sub_filters(1, hashes[5:6], 134_003) + self.insert_sub_filters(1, hashes[6:7], 134_403) + # no more missing tx filters + self.assertEqual( + q.get_missing_sub_filters_for_addresses(1, ( + self.root_pubkey.address, self.RECEIVING_KEY_N, + )), set() + ) + + # find TXs we need to download + missing_txs = { + b'7478313033333033', + b'7478313233333033', + b'7478313330333033', + b'7478313333333033', + b'7478313334303033', + b'7478313334343033', + b'7478313334353030', + b'7478313334353637', + b'747833333033' + } + self.assertEqual( + q.get_missing_tx_for_addresses(( + self.root_pubkey.address, self.RECEIVING_KEY_N, + )), missing_txs + ) + # "download" missing TXs + ctx = context() + for tx_hash in missing_txs: + ctx.execute(TX.insert().values(tx_hash=tx_hash)) + # check we have everything + self.assertEqual( + q.get_missing_tx_for_addresses(( self.root_pubkey.address, self.RECEIVING_KEY_N, - self.receiving_pubkey.pubkey_bytes, - self.receiving_pubkey.chain_code, - self.receiving_pubkey.depth )), set() )