working client tx sync

Lex Berezhny 2021-01-04 10:44:36 -05:00
parent 522dc72dc1
commit 5eed7d87d3
8 changed files with 385 additions and 484 deletions

View file

@@ -254,12 +254,21 @@ class Database:
async def get_missing_required_filters(self, height) -> Dict[int, Tuple[int, int]]:
return await self.run(q.get_missing_required_filters, height)
async def get_missing_sub_filters_for_addresses(self, granularity, address_manager):
return await self.run(q.get_missing_sub_filters_for_addresses, granularity, address_manager)
async def get_missing_tx_for_addresses(self, address_manager):
return await self.run(q.get_missing_tx_for_addresses, address_manager)
async def insert_block(self, block):
return await self.run(q.insert_block, block)
async def insert_block_filter(self, height: int, factor: int, address_filter: bytes):
return await self.run(q.insert_block_filter, height, factor, address_filter)
async def insert_tx_filter(self, tx_hash: bytes, height: int, address_filter: bytes):
return await self.run(q.insert_tx_filter, tx_hash, height, address_filter)
async def insert_transaction(self, block_hash, tx):
return await self.run(q.insert_transaction, block_hash, tx)
@@ -316,6 +325,11 @@ class Database:
'depth': k.depth
} for k in pubkeys])
async def generate_addresses_using_filters(self, best_height, allowed_gap, address_manager):
return await self.run(
q.generate_addresses_using_filters, best_height, allowed_gap, address_manager
)
async def get_transactions(self, **constraints) -> Result[Transaction]:
return await self.fetch_result(q.get_transactions, **constraints)

View file

@@ -10,7 +10,10 @@ from lbry.crypto.bip32 import PubKey
from ..utils import query
from ..query_context import context
from ..tables import TXO, PubkeyAddress, AccountAddress
from .filters import get_filter_matchers, get_filter_matchers_at_granularity, has_sub_filters
from .filters import (
get_filter_matchers, get_filter_matchers_at_granularity, has_filter_range,
get_tx_matchers_for_missing_txs,
)
log = logging.getLogger(__name__)
@@ -87,15 +90,14 @@ def generate_addresses_using_filters(best_height, allowed_gap, address_manager)
for address_hash, n, is_new in addresses:
gap += 1
address_bytes = bytearray(address_hash)
for granularity, height, matcher in matchers:
for granularity, height, matcher, filter_range in matchers:
if matcher.Match(address_bytes):
gap = 0
match = (granularity, height)
if match not in need and match not in have:
if has_sub_filters(granularity, height):
have.add(match)
if filter_range not in need and filter_range not in have:
if has_filter_range(*filter_range):
have.add(filter_range)
else:
need.add(match)
need.add(filter_range)
if gap >= allowed_gap:
break
return need
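
The loop above is the usual BIP44-style gap limit applied to filters: every generated address that misses all matchers widens the gap, any hit resets it, and generation stops after allowed_gap consecutive misses. A minimal sketch of just that control flow, with a plain matches predicate standing in for the PyBIP158 matcher loop:

def scan_with_gap(addresses, matches, allowed_gap):
    # Sketch only: `matches` stands in for "any block filter matched".
    found, gap = [], 0
    for address in addresses:
        gap += 1
        if matches(address):
            found.append(address)
            gap = 0  # a hit resets the run of unused addresses
        if gap >= allowed_gap:
            break    # allowed_gap consecutive misses: stop generating
    return found

# With allowed_gap=6 and hits on the first 9 addresses, 15 addresses are
# consumed in total (9 matched + 6 trailing misses), exactly the
# "15 addresses ... 9 due to filters and 6 due to gap" case in the unit
# test at the bottom of this commit.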
@@ -103,13 +105,23 @@ def generate_addresses_using_filters(best_height, allowed_gap, address_manager)
def get_missing_sub_filters_for_addresses(granularity, address_manager):
need = set()
with DatabaseAddressIterator(*address_manager) as addresses:
for height, matcher in get_filter_matchers_at_granularity(granularity):
for address_hash, n, is_new in addresses:
address_bytes = bytearray(address_hash)
if matcher.Match(address_bytes) and not has_sub_filters(granularity, height):
need.add((height, granularity))
break
for height, matcher, filter_range in get_filter_matchers_at_granularity(granularity):
for address_hash, n, is_new in DatabaseAddressIterator(*address_manager):
address_bytes = bytearray(address_hash)
if matcher.Match(address_bytes) and not has_filter_range(*filter_range):
need.add(filter_range)
break
return need
def get_missing_tx_for_addresses(address_manager):
need = set()
for tx_hash, matcher in get_tx_matchers_for_missing_txs():
for address_hash, n, is_new in DatabaseAddressIterator(*address_manager):
address_bytes = bytearray(address_hash)
if matcher.Match(address_bytes):
need.add(tx_hash)
break
return need
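
All three query functions share one primitive: hash an address's pubkey to 20 bytes and probe it against a BIP158-style filter via Match(bytearray(...)). A self-contained sketch of that check, assuming create_address_filter and get_address_filter from lbry.blockchain.block behave as used in the unit tests below; the pubkeys here are fake placeholders:

from lbry.crypto.hash import hash160
from lbry.blockchain.block import create_address_filter, get_address_filter

# Build a filter over one address hash, then probe it the same way the
# queries above do.
address_hash = hash160(b'\x02' * 33)  # hash160 of a (fake) compressed pubkey
matcher = get_address_filter(create_address_filter([address_hash]))
assert matcher.Match(bytearray(address_hash))
# A miss should not match (BIP158 filters allow rare false positives).
assert not matcher.Match(bytearray(hash160(b'\x03' * 33)))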

View file

@@ -1,5 +1,5 @@
from math import log10
from typing import Dict, List, Tuple, Optional
from typing import Dict, List, Tuple, Set, Optional
from sqlalchemy import between, func, or_
from sqlalchemy.future import select
@@ -7,29 +7,32 @@ from sqlalchemy.future import select
from lbry.blockchain.block import PyBIP158, get_address_filter
from ..query_context import context
from ..tables import BlockFilter, TXFilter
from ..tables import BlockFilter, TXFilter, TX
def has_filters():
return context().has_records(BlockFilter)
def has_sub_filters(granularity: int, height: int):
def get_sub_filter_range(granularity: int, height: int):
end = height
if granularity >= 3:
sub_filter_size = 10**(granularity-1)
sub_filters_count = context().fetchtotal(
(BlockFilter.c.factor == granularity-1) &
between(BlockFilter.c.height, height, height + sub_filter_size * 9)
)
return sub_filters_count == 10
end = height + 10**(granularity-1) * 9
elif granularity == 2:
sub_filters_count = context().fetchtotal(
(BlockFilter.c.factor == 1) &
between(BlockFilter.c.height, height, height + 99)
end = height + 99
return granularity - 1, height, end
def has_filter_range(factor: int, start: int, end: int):
if factor >= 1:
filters = context().fetchtotal(
(BlockFilter.c.factor == factor) &
between(BlockFilter.c.height, start, end)
)
return sub_filters_count == 100
elif granularity == 1:
tx_filters_count = context().fetchtotal(TXFilter.c.height == height)
expected = 10 if factor >= 2 else 100
return filters == expected
elif factor == 0:
tx_filters_count = context().fetchtotal(TXFilter.c.height == start)
return tx_filters_count > 0
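
The factor arithmetic is easiest to read off concrete values. A granularity-N filter covers 10**N blocks; it decomposes into ten filters one level down, except granularity 2, which decomposes into a hundred per-block filters, and granularity 1, whose "sub filters" are the per-transaction (factor 0) filters of that single block. These are the ranges the unit tests below assert, assuming the get_sub_filter_range defined above:

# get_sub_filter_range(granularity, height) -> (factor, start, end)
assert get_sub_filter_range(4, 120_000) == (3, 120_000, 129_000)  # ten 1,000-block filters
assert get_sub_filter_range(3, 130_000) == (2, 130_000, 130_900)  # ten 100-block filters
assert get_sub_filter_range(2, 134_000) == (1, 134_000, 134_099)  # a hundred 1-block filters
assert get_sub_filter_range(1, 134_500) == (0, 134_500, 134_500)  # tx filters for one block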
@@ -93,9 +96,9 @@ def get_maximum_known_filters() -> Dict[str, Optional[int]]:
return context().fetchone(query)
def get_missing_required_filters(height) -> Dict[int, Tuple[int, int]]:
def get_missing_required_filters(height) -> Set[Tuple[int, int, int]]:
known_filters = get_maximum_known_filters()
missing_filters = {}
missing_filters = set()
for granularity, (start, end) in get_minimal_required_filter_ranges(height).items():
known_height = known_filters.get(str(granularity))
if known_height is not None and known_height > start:
@@ -104,9 +107,9 @@ def get_missing_required_filters(height) -> Dict[int, Tuple[int, int]]:
else:
adjusted_height = known_height + 10**granularity
if adjusted_height <= end:
missing_filters[granularity] = (adjusted_height, end)
missing_filters.add((granularity, adjusted_height, end))
else:
missing_filters[granularity] = (start, end)
missing_filters.add((granularity, start, end))
return missing_filters
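
The return type change, from a dict keyed by granularity to a flat set of (granularity, start, end) tuples, lets the downloader treat every missing range uniformly. For example, with nothing stored and the tip at 134,567, the unit test below expects:

# From test_get_missing_required_filters (all filters missing):
assert get_missing_required_filters(134_567) == {
    (4, 0, 120_000),        # 10,000-block filters
    (3, 130_000, 133_000),  # 1,000-block filters
    (2, 134_000, 134_400),  # 100-block filters
    (1, 134_500, 134_567),  # per-block filters
}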
@@ -123,20 +126,33 @@ def get_filter_matchers(height) -> List[Tuple[int, int, PyBIP158]]:
.where(or_(*conditions))
.order_by(BlockFilter.c.height.desc())
)
return [
(bf["factor"], bf["height"], get_address_filter(bf["address_filter"]))
for bf in context().fetchall(query)
]
return [(
bf["factor"], bf["height"],
get_address_filter(bf["address_filter"]),
get_sub_filter_range(bf["factor"], bf["height"])
) for bf in context().fetchall(query)]
def get_filter_matchers_at_granularity(granularity) -> List[Tuple[int, PyBIP158]]:
def get_filter_matchers_at_granularity(granularity) -> List[Tuple[int, PyBIP158, Tuple]]:
query = (
select(BlockFilter.c.height, BlockFilter.c.address_filter)
.where(BlockFilter.c.factor == granularity)
.order_by(BlockFilter.c.height.desc())
)
return [(
bf["height"],
get_address_filter(bf["address_filter"]),
get_sub_filter_range(granularity, bf["height"])
) for bf in context().fetchall(query)]
def get_tx_matchers_for_missing_txs() -> List[Tuple[int, PyBIP158]]:
query = (
select(TXFilter.c.tx_hash, TXFilter.c.address_filter)
.where(TXFilter.c.tx_hash.notin_(select(TX.c.tx_hash)))
)
return [
(bf["height"], get_address_filter(bf["address_filter"]))
(bf["tx_hash"], get_address_filter(bf["address_filter"]))
for bf in context().fetchall(query)
]

View file

@@ -2635,7 +2635,7 @@ class API:
async def transaction_search(
self,
txids: StrOrList, # transaction ids to find
) -> List[Transaction]:
) -> Dict[str, str]:
"""
Search for transaction(s) in the entire blockchain.
@@ -3435,6 +3435,9 @@ class Client(API):
self.receive_messages_task = asyncio.create_task(self.receive_messages())
async def disconnect(self):
if self.ws is not None:
await self.ws.close()
self.ws = None
if self.session is not None:
await self.session.close()
self.session = None

View file

@@ -56,8 +56,9 @@ class FullNode(Service):
async def search_transactions(self, txids):
tx_hashes = [unhexlify(txid)[::-1] for txid in txids]
return {
hexlify(tx['tx_hash'][::-1]).decode(): hexlify(tx['raw']).decode()
for tx in await self.db.get_transactions(tx_hashes=tx_hashes)
#hexlify(tx['tx_hash'][::-1]).decode(): hexlify(tx['raw']).decode()
tx.id: hexlify(tx.raw).decode()
for tx in await self.db.get_transactions(tx_hash__in=tx_hashes)
}
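
The comprehension also switches from raw rows to Transaction objects, and the [::-1] reversals are the usual txid endianness dance: txids are display-order (big-endian) hex, while tx_hash columns and filters carry the little-endian bytes. A round-trip sketch with a hypothetical txid:

from binascii import hexlify, unhexlify

txid = 'a0' * 16 + 'b1' * 16                    # hypothetical 64-char txid (hex)
tx_hash = unhexlify(txid)[::-1]                 # database/network byte order
assert hexlify(tx_hash[::-1]).decode() == txid  # back to display order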
async def broadcast(self, tx):

View file

@@ -1,13 +1,13 @@
import asyncio
import logging
from typing import Dict
from typing import List, Optional, NamedTuple, Tuple
from binascii import unhexlify
from typing import List, Optional, Tuple
from binascii import hexlify, unhexlify
from lbry.blockchain.block import Block, get_address_filter
from lbry.event import BroadcastSubscription
from lbry.crypto.hash import hash160
from lbry.wallet.account import AddressManager
from lbry.blockchain.block import Block
from lbry.event import EventController, BroadcastSubscription
from lbry.crypto.hash import double_sha256
from lbry.wallet import WalletManager
from lbry.blockchain import Ledger, Transaction
from lbry.db import Database
@@ -44,8 +44,8 @@ class LightClient(Service):
return await self.client.transaction_search(txids=txids)
async def get_address_filters(self, start_height: int, end_height: int = None, granularity: int = 0):
return await self.sync.filters.get_filters(
start_height=start_height, end_height=end_height, granularity=granularity
return await self.client.first.address_filter(
granularity=granularity, start_height=start_height, end_height=end_height
)
async def broadcast(self, tx):
@@ -69,36 +69,6 @@ class LightClient(Service):
return await self.client.sum_supports(claim_hash, include_channel_content, exclude_own_supports)
class TransactionEvent(NamedTuple):
address: str
tx: Transaction
class AddressesGeneratedEvent(NamedTuple):
address_manager: AddressManager
addresses: List[str]
class TransactionCacheItem:
__slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications'
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
self.has_tx = asyncio.Event()
self.lock = lock or asyncio.Lock()
self._tx = self.tx = tx
self.pending_verifications = 0
@property
def tx(self) -> Optional[Transaction]:
return self._tx
@tx.setter
def tx(self, tx: Transaction):
self._tx = tx
if tx is not None:
self.has_tx.set()
class FilterManager:
"""
Efficient on-demand address filter access.
@@ -106,25 +76,108 @@ class FilterManager:
downloads on-demand what it doesn't have from full node.
"""
def __init__(self, db, client):
def __init__(self, db: Database, client: APIClient):
self.db = db
self.client = client
self.cache = {}
async def download(self, best_height):
our_height = await self.db.get_best_block_filter()
new_block_filters = await self.client.address_filter(
start_height=our_height+1, end_height=best_height, granularity=1
)
for block_filter in await new_block_filters.first:
await self.db.insert_block_filter(
block_filter["height"], unhexlify(block_filter["filter"])
async def download_and_save_filters(self, needed_filters):
for factor, start, end in needed_filters:
filters = await self.client.first.address_filter(
granularity=factor, start_height=start, end_height=end
)
if factor == 0:
for tx_filter in filters:
await self.db.insert_tx_filter(
unhexlify(tx_filter["txid"])[::-1], tx_filter["height"], unhexlify(tx_filter["filter"])
)
else:
for block_filter in filters:
await self.db.insert_block_filter(
block_filter["height"], factor, unhexlify(block_filter["filter"])
)
async def get_filters(self, start_height, end_height, granularity):
return await self.client.address_filter(
start_height=start_height, end_height=end_height, granularity=granularity
)
async def download_and_save_txs(self, tx_hashes):
if not tx_hashes:
return
txids = [hexlify(tx_hash[::-1]).decode() for tx_hash in tx_hashes]
txs = await self.client.first.transaction_search(txids=txids)
for raw_tx in txs.values():
await self.db.insert_transaction(None, Transaction(unhexlify(raw_tx)))
async def download_initial_filters(self, best_height):
missing = await self.db.get_missing_required_filters(best_height)
await self.download_and_save_filters(missing)
async def generate_addresses(self, best_height: int, wallets: WalletManager):
for wallet in wallets:
for account in wallet.accounts:
for address_manager in account.address_managers.values():
missing = await self.db.generate_addresses_using_filters(
best_height, address_manager.gap, (
account.id,
address_manager.chain_number,
address_manager.public_key.pubkey_bytes,
address_manager.public_key.chain_code,
address_manager.public_key.depth
)
)
await self.download_and_save_filters(missing)
async def download_sub_filters(self, granularity: int, wallets: WalletManager):
for wallet in wallets:
for account in wallet.accounts:
for address_manager in account.address_managers.values():
missing = await self.db.get_missing_sub_filters_for_addresses(
granularity, (account.id, address_manager.chain_number)
)
await self.download_and_save_filters(missing)
async def download_transactions(self, wallets: WalletManager):
for wallet in wallets:
for account in wallet.accounts:
for address_manager in account.address_managers.values():
missing = await self.db.get_missing_tx_for_addresses(
(account.id, address_manager.chain_number)
)
await self.download_and_save_txs(missing)
async def download(self, best_height: int, wallets: WalletManager):
await self.download_initial_filters(best_height)
await self.generate_addresses(best_height, wallets)
await self.download_sub_filters(3, wallets)
await self.download_sub_filters(2, wallets)
await self.download_sub_filters(1, wallets)
await self.download_transactions(wallets)
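
The ordering in download matters: top-level filters must arrive before addresses can be generated against them, and each granularity has to be resolved before the next, because a match at one level decides which ranges get fetched at the level below. A sketch of one full pass, assuming a connected client, an open db, and wallets as in FastSync.advance() further down:

async def sync_filters_once(db, client, wallets):
    # Mirrors FilterManager.download() above, coarse to fine.
    fm = FilterManager(db, client)
    tip = await client.first.block_tip()
    await fm.download_initial_filters(tip)      # granularity >= 1 ranges
    await fm.generate_addresses(tip, wallets)   # extend the gap against them
    for granularity in (3, 2, 1):               # 1,000 -> 100 -> per-block
        await fm.download_sub_filters(granularity, wallets)
    await fm.download_transactions(wallets)     # raw txs whose tx filters matched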
@staticmethod
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
for i, branch in enumerate(branches):
other_branch = unhexlify(branch)[::-1]
other_branch_on_left = bool((branch_positions >> i) & 1)
if other_branch_on_left:
combined = other_branch + working_branch
else:
combined = working_branch + other_branch
working_branch = double_sha256(combined)
return hexlify(working_branch[::-1])
async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
tx.height = remote_height
cached = self._tx_cache.get(tx.hash)
if not cached:
# cache txs looked up by transaction_show too
cached = TransactionCacheItem()
cached.tx = tx
self._tx_cache[tx.hash] = cached
if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
if not merkle:
merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = await self.headers.get(remote_height)
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
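
get_root_of_merkle_tree folds the server-supplied proof into a root: each branch arrives as display-order hex and is flipped to internal byte order, and bit i of branch_positions says whether branch i sits to the left of the running hash. A self-contained check for a two-leaf tree, using hypothetical hashes:

from binascii import hexlify, unhexlify
from lbry.crypto.hash import double_sha256

tx_hash = unhexlify('aa' * 32)                     # our leaf, internal byte order
sibling_hex = hexlify(unhexlify('bb' * 32)[::-1])  # proof branch, display-order hex

# pos=0: bit 0 unset, so the sibling is concatenated on the right.
root = FilterManager.get_root_of_merkle_tree([sibling_hex], 0, tx_hash)
assert root == hexlify(double_sha256(tx_hash + unhexlify('bb' * 32))[::-1])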
class BlockHeaderManager:
@@ -169,11 +222,15 @@ class FastSync(Sync):
self.service = service
self.client = client
self.advance_loop_task: Optional[asyncio.Task] = None
self.on_block = client.get_event_stream('blockchain.block')
self.on_block = client.get_event_stream("blockchain.block")
self.on_block_event = asyncio.Event()
self.on_block_subscription: Optional[BroadcastSubscription] = None
self._on_synced_controller = EventController()
self.on_synced = self._on_synced_controller.stream
self.conf.events.register("blockchain.block", self.on_synced)
self.blocks = BlockHeaderManager(self.db, self.client)
self.filters = FilterManager(self.db, self.client)
self.best_height: Optional[int] = None
async def get_block_headers(self, start_height: int, end_height: int = None):
return await self.client.first.block_list(start_height, end_height)
@@ -182,68 +239,27 @@ class FastSync(Sync):
return await self.client.first.block_tip()
async def start(self):
self.on_block_subscription = self.on_block.listen(self.handle_on_block)
self.advance_loop_task = asyncio.create_task(self.advance())
await self.advance_loop_task
self.advance_loop_task = asyncio.create_task(self.loop())
self.on_block_subscription = self.on_block.listen(
lambda e: self.on_block_event.set()
)
async def stop(self):
for task in (self.on_block_subscription, self.advance_loop_task):
if task is not None:
task.cancel()
def handle_on_block(self, e):
self.best_height = e[0]
self.on_block_event.set()
async def advance(self):
best_height = await self.client.first.block_tip()
height = self.best_height or await self.client.first.block_tip()
await asyncio.wait([
self.blocks.download(best_height),
self.filters.download(best_height),
self.blocks.download(height),
self.filters.download(height, self.service.wallets),
])
block_filters = {}
for block_filter in await self.db.get_filters(0, best_height, 1):
block_filters[block_filter['height']] = \
get_address_filter(unhexlify(block_filter['filter']))
for wallet in self.service.wallets:
for account in wallet.accounts:
for address_manager in account.address_managers.values():
i = gap = 0
while gap < 20:
key, i = address_manager.public_key.child(i), i+1
address = bytearray(hash160(key.pubkey_bytes))
for block, matcher in block_filters.items():
if matcher.Match(address):
gap = 0
continue
gap += 1
# address = None
# address_array = [bytearray(self.db.ledger.address_to_hash160(address))]
# for address_filter in filters:
# print(address_filter)
# address_filter = get_address_filter(unhexlify(address_filter['filter']))
# print(address_filter.MatchAny(address_array))
# address_array = [
# bytearray(a['address'].encode())
# for a in await self.service.db.get_all_addresses()
# ]
# block_filters = await self.service.get_block_address_filters()
# for block_hash, block_filter in block_filters.items():
# bf = get_address_filter(block_filter)
# if bf.MatchAny(address_array):
# print(f'match: {block_hash} - {block_filter}')
# tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
# for txid, tx_filter in tx_filters.items():
# tf = get_address_filter(tx_filter)
# if tf.MatchAny(address_array):
# print(f' match: {txid} - {tx_filter}')
# txs = await self.service.search_transactions([txid])
# tx = Transaction(unhexlify(txs[txid]))
# await self.service.db.insert_transaction(tx)
await self._on_synced_controller.add(height)
async def loop(self):
while True:
@@ -256,306 +272,3 @@ class FastSync(Sync):
except Exception as e:
log.exception(e)
await self.stop()
# async def get_local_status_and_history(self, address, history=None):
# if not history:
# address_details = await self.db.get_address(address=address)
# history = (address_details['history'] if address_details else '') or ''
# parts = history.split(':')[:-1]
# return (
# hexlify(sha256(history.encode())).decode() if history else None,
# list(zip(parts[0::2], map(int, parts[1::2])))
# )
#
# @staticmethod
# def get_root_of_merkle_tree(branches, branch_positions, working_branch):
# for i, branch in enumerate(branches):
# other_branch = unhexlify(branch)[::-1]
# other_branch_on_left = bool((branch_positions >> i) & 1)
# if other_branch_on_left:
# combined = other_branch + working_branch
# else:
# combined = working_branch + other_branch
# working_branch = double_sha256(combined)
# return hexlify(working_branch[::-1])
#
#
# @property
# def local_height_including_downloaded_height(self):
# return max(self.headers.height, self._download_height)
#
# async def initial_headers_sync(self):
# get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
# self.headers.chunk_getter = get_chunk
#
# async def doit():
# for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
# async with self._header_processing_lock:
# await self.headers.ensure_chunk_at(height)
# self._other_tasks.add(doit())
# await self.update_headers()
#
# async def update_headers(self, height=None, headers=None, subscription_update=False):
# rewound = 0
# while True:
#
# if height is None or height > len(self.headers):
# # sometimes header subscription updates are for a header in the future
# # which can't be connected, so we do a normal header sync instead
# height = len(self.headers)
# headers = None
# subscription_update = False
#
# if not headers:
# header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
# headers = header_response['hex']
#
# if not headers:
# # Nothing to do, network thinks we're already at the latest height.
# return
#
# added = await self.headers.connect(height, unhexlify(headers))
# if added > 0:
# height += added
# self._on_header_controller.add(
# BlockHeightEvent(self.headers.height, added))
#
# if rewound > 0:
# # we started rewinding blocks and apparently found
# # a new chain
# rewound = 0
# await self.db.rewind_blockchain(height)
#
# if subscription_update:
# # subscription updates are for latest header already
# # so we don't need to check if there are newer / more
# # on another loop of update_headers(), just return instead
# return
#
# elif added == 0:
# # we had headers to connect but none got connected, probably a reorganization
# height -= 1
# rewound += 1
# log.warning(
# "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
# height, height+rewound
# )
#
# else:
# raise IndexError(f"headers.connect() returned negative number ({added})")
#
# if height < 0:
# raise IndexError(
# "Blockchain reorganization rewound all the way back to genesis hash. "
# "Something is very wrong. Maybe you are on the wrong blockchain?"
# )
#
# if rewound >= 100:
# raise IndexError(
# "Blockchain reorganization dropped {} headers. This is highly unusual. "
# "Will not continue to attempt reorganizing. Please, delete the ledger "
# "synchronization directory inside your wallet directory (folder: '{}') and "
# "restart the program to synchronize from scratch."
# .format(rewound, self.ledger.get_id())
# )
#
# headers = None # ready to download some more headers
#
# # if we made it this far and this was a subscription_update
# # it means something went wrong and now we're doing a more
# # robust sync, turn off subscription update shortcut
# subscription_update = False
#
# async def receive_header(self, response):
# async with self._header_processing_lock:
# header = response[0]
# await self.update_headers(
# height=header['height'], headers=header['hex'], subscription_update=True
# )
#
# async def subscribe_accounts(self):
# if self.network.is_connected and self.accounts:
# log.info("Subscribe to %i accounts", len(self.accounts))
# await asyncio.wait([
# self.subscribe_account(a) for a in self.accounts
# ])
#
# async def subscribe_account(self, account: Account):
# for address_manager in account.address_managers.values():
# await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
# await account.ensure_address_gap()
#
# async def unsubscribe_account(self, account: Account):
# for address in await account.get_addresses():
# await self.network.unsubscribe_address(address)
#
# async def subscribe_addresses(
# self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
# if self.network.is_connected and addresses:
# addresses_remaining = list(addresses)
# while addresses_remaining:
# batch = addresses_remaining[:batch_size]
# results = await self.network.subscribe_address(*batch)
# for address, remote_status in zip(batch, results):
# self._update_tasks.add(self.update_history(address, remote_status, address_manager))
# addresses_remaining = addresses_remaining[batch_size:]
# log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
# len(addresses), *self.network.client.server_address_and_port)
# log.info(
# "finished subscribing to %i addresses on %s:%i", len(addresses),
# *self.network.client.server_address_and_port
# )
#
# def process_status_update(self, update):
# address, remote_status = update
# self._update_tasks.add(self.update_history(address, remote_status))
#
# async def update_history(self, address, remote_status, address_manager: AddressManager = None):
# async with self._address_update_locks[address]:
# self._known_addresses_out_of_sync.discard(address)
#
# local_status, local_history = await self.get_local_status_and_history(address)
#
# if local_status == remote_status:
# return True
#
# remote_history = await self.network.retriable_call(self.network.get_history, address)
# remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
# we_need = set(remote_history) - set(local_history)
# if not we_need:
# return True
#
# cache_tasks: List[asyncio.Task[Transaction]] = []
# synced_history = StringIO()
# loop = asyncio.get_running_loop()
# for i, (txid, remote_height) in enumerate(remote_history):
# if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
# synced_history.write(f'{txid}:{remote_height}:')
# else:
# check_local = (txid, remote_height) not in we_need
# cache_tasks.append(loop.create_task(
# self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
# ))
#
# synced_txs = []
# for task in cache_tasks:
# tx = await task
#
# check_db_for_txos = []
# for txi in tx.inputs:
# if txi.txo_ref.txo is not None:
# continue
# cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
# if cache_item is not None:
# if cache_item.tx is None:
# await cache_item.has_tx.wait()
# assert cache_item.tx is not None
# txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
# else:
# check_db_for_txos.append(txi.txo_ref.hash)
#
# referenced_txos = {} if not check_db_for_txos else {
# txo.id: txo for txo in await self.db.get_txos(
# txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
# )
# }
#
# for txi in tx.inputs:
# if txi.txo_ref.txo is not None:
# continue
# referenced_txo = referenced_txos.get(txi.txo_ref.id)
# if referenced_txo is not None:
# txi.txo_ref = referenced_txo.ref
#
# synced_history.write(f'{tx.id}:{tx.height}:')
# synced_txs.append(tx)
#
# await self.db.save_transaction_io_batch(
# synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
# )
# await asyncio.wait([
# self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
# for tx in synced_txs
# ])
#
# if address_manager is None:
# address_manager = await self.get_address_manager_for_address(address)
#
# if address_manager is not None:
# await address_manager.ensure_address_gap()
#
# local_status, local_history = \
# await self.get_local_status_and_history(address, synced_history.getvalue())
# if local_status != remote_status:
# if local_history == remote_history:
# return True
# log.warning(
# "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
# remote_status, len(remote_history), local_status, len(local_history)
# )
# log.warning("local: %s", local_history)
# log.warning("remote: %s", remote_history)
# self._known_addresses_out_of_sync.add(address)
# return False
# else:
# return True
#
# async def cache_transaction(self, tx_hash, remote_height, check_local=True):
# cache_item = self._tx_cache.get(tx_hash)
# if cache_item is None:
# cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
# elif cache_item.tx is not None and \
# cache_item.tx.height >= remote_height and \
# (cache_item.tx.is_verified or remote_height < 1):
# return cache_item.tx # cached tx is already up-to-date
#
# try:
# cache_item.pending_verifications += 1
# return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
# finally:
# cache_item.pending_verifications -= 1
#
# async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
#
# async with cache_item.lock:
#
# tx = cache_item.tx
#
# if tx is None and check_local:
# # check local db
# tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
#
# merkle = None
# if tx is None:
# # fetch from network
# _raw, merkle = await self.network.retriable_call(
# self.network.get_transaction_and_merkle, tx_hash, remote_height
# )
# tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
# cache_item.tx = tx # make sure it's saved before caching it
# await self.maybe_verify_transaction(tx, remote_height, merkle)
# return tx
#
# async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
# tx.height = remote_height
# cached = self._tx_cache.get(tx.hash)
# if not cached:
# # cache txs looked up by transaction_show too
# cached = TransactionCacheItem()
# cached.tx = tx
# self._tx_cache[tx.hash] = cached
# if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
# if not merkle:
# merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
# merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
# header = await self.headers.get(remote_height)
# tx.position = merkle['pos']
# tx.is_verified = merkle_root == header['merkle_root']
#
# async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
# details = await self.db.get_address(address=address)
# for account in self.accounts:
# if account.id == details['account']:
# return account.address_managers[details['chain']]
# return None

View file

@@ -0,0 +1,46 @@
import asyncio
from binascii import unhexlify
from lbry.testcase import IntegrationTestCase
from lbry.service.full_node import FullNode
from lbry.service.light_client import LightClient
from lbry.blockchain.block import get_address_filter
class LightClientTests(IntegrationTestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
await self.chain.generate(200)
self.full_node_daemon = await self.make_full_node_daemon()
self.full_node: FullNode = self.full_node_daemon.service
self.light_client_daemon = await self.make_light_client_daemon(self.full_node_daemon, start=False)
self.light_client: LightClient = self.light_client_daemon.service
self.light_client.conf.wallet_storage = "database"
self.addCleanup(self.light_client.client.disconnect)
await self.light_client.client.connect()
self.addCleanup(self.light_client.db.close)
await self.light_client.db.open()
self.addCleanup(self.light_client.wallets.close)
await self.light_client.wallets.open()
await self.light_client.client.start_event_streams()
self.db = self.light_client.db
self.sync = self.light_client.sync
self.client = self.light_client.client
self.account = self.light_client.wallets.default.accounts.default
async def test_sync(self):
self.assertEqual(await self.client.first.block_tip(), 200)
self.assertEqual(await self.db.get_best_block_height(), -1)
self.assertEqual(await self.db.get_missing_required_filters(200), {(2, 0, 100)})
await self.sync.start()
self.assertEqual(await self.db.get_best_block_height(), 200)
self.assertEqual(await self.db.get_missing_required_filters(200), set())
address = await self.account.receiving.get_or_create_usable_address()
await self.chain.send_to_address(address, '5.0')
await self.chain.generate(1)
await self.assertBalance(self.account, '0.0')
await self.sync.on_synced.first
await self.assertBalance(self.account, '5.0')

View file

@@ -5,7 +5,7 @@ from lbry.crypto.hash import hash160
from lbry.crypto.bip32 import from_extended_key_string
from lbry.blockchain.block import create_address_filter
from lbry.db import queries as q
from lbry.db.tables import AccountAddress
from lbry.db.tables import AccountAddress, TX
from lbry.db.query_context import context
from lbry.testcase import UnitDBTestCase
@@ -13,16 +13,16 @@ from lbry.testcase import UnitDBTestCase
class TestMissingRequiredFiltersCalculation(UnitDBTestCase):
def test_get_missing_required_filters(self):
self.assertEqual(q.get_missing_required_filters(99), {1: (0, 99)})
self.assertEqual(q.get_missing_required_filters(100), {100: (0, 0)})
self.assertEqual(q.get_missing_required_filters(199), {100: (0, 0), 1: (100, 199)})
self.assertEqual(q.get_missing_required_filters(201), {100: (0, 100), 1: (200, 201)})
self.assertEqual(q.get_missing_required_filters(99), {(1, 0, 99)})
self.assertEqual(q.get_missing_required_filters(100), {(2, 0, 0)})
self.assertEqual(q.get_missing_required_filters(199), {(2, 0, 0), (1, 100, 199)})
self.assertEqual(q.get_missing_required_filters(201), {(2, 0, 100), (1, 200, 201)})
# all filters missing
self.assertEqual(q.get_missing_required_filters(134_567), {
10_000: (0, 120_000),
1_000: (130_000, 133_000),
100: (134_000, 134_400),
1: (134_500, 134_567)
(4, 0, 120_000),
(3, 130_000, 133_000),
(2, 134_000, 134_400),
(1, 134_500, 134_567)
})
q.insert_block_filter(110_000, 4, b'beef')
@@ -31,10 +31,10 @@ class TestMissingRequiredFiltersCalculation(UnitDBTestCase):
q.insert_block_filter(134_499, 1, b'beef')
# we have some filters, but not recent enough (all except 10k are adjusted)
self.assertEqual(q.get_missing_required_filters(134_567), {
10_000: (120_000, 120_000), # 0 -> 120_000
1_000: (130_000, 133_000),
100: (134_000, 134_400),
1: (134_500, 134_567)
(4, 120_000, 120_000), # 0 -> 120_000
(3, 130_000, 133_000),
(2, 134_000, 134_400),
(1, 134_500, 134_567)
})
q.insert_block_filter(132_000, 3, b'beef')
@@ -42,10 +42,10 @@ class TestMissingRequiredFiltersCalculation(UnitDBTestCase):
q.insert_block_filter(134_550, 1, b'beef')
# all filters get adjusted because we have a recent one of each
self.assertEqual(q.get_missing_required_filters(134_567), {
10_000: (120_000, 120_000), # 0 -> 120_000
1_000: (133_000, 133_000), # 130_000 -> 133_000
100: (134_400, 134_400), # 134_000 -> 134_400
1: (134_551, 134_567) # 134_500 -> 134_551
(4, 120_000, 120_000), # 0 -> 120_000
(3, 133_000, 133_000), # 130_000 -> 133_000
(2, 134_400, 134_400), # 134_000 -> 134_400
(1, 134_551, 134_567) # 134_500 -> 134_551
})
q.insert_block_filter(120_000, 4, b'beef')
@@ -54,15 +54,15 @@ class TestMissingRequiredFiltersCalculation(UnitDBTestCase):
q.insert_block_filter(134_566, 1, b'beef')
# we have latest filters for all except latest single block
self.assertEqual(q.get_missing_required_filters(134_567), {
1: (134_567, 134_567) # 134_551 -> 134_567
(1, 134_567, 134_567) # 134_551 -> 134_567
})
q.insert_block_filter(134_567, 1, b'beef')
# we have all latest filters
self.assertEqual(q.get_missing_required_filters(134_567), {})
self.assertEqual(q.get_missing_required_filters(134_567), set())
class TestAddressGeneration(UnitDBTestCase):
class TestAddressGenerationAndTXSync(UnitDBTestCase):
RECEIVING_KEY_N = 0
@@ -119,7 +119,7 @@ class TestAddressGeneration(UnitDBTestCase):
elif granularity == 1:
q.insert_tx_filter(hexlify(f'tx{height}'.encode()), height, create_address_filter(addresses))
def test_generate_from_filters(self):
def test_generate_from_filters_and_download_txs(self):
# 15 addresses will get generated, 9 due to filters and 6 due to gap
pubkeys = [self.receiving_pubkey.child(n) for n in range(15)]
hashes = [hash160(key.pubkey_bytes) for key in pubkeys]
@@ -144,7 +144,10 @@
q.insert_block_filter(134_567, 1, create_address_filter(hashes[8:9]))
# check that all required filters did get created
self.assertEqual(q.get_missing_required_filters(134_567), {})
self.assertEqual(q.get_missing_required_filters(134_567), set())
# no addresses
self.assertEqual([], self.get_ordered_addresses())
# generate addresses with 6 address gap, returns new sub filters needed
self.assertEqual(
@@ -154,15 +157,15 @@
self.receiving_pubkey.chain_code,
self.receiving_pubkey.depth
)), {
(1, 134500),
(1, 134567),
(2, 134000),
(2, 134400),
(3, 130000),
(3, 133000),
(4, 0),
(4, 100000),
(4, 120000)
(0, 134500, 134500),
(0, 134567, 134567),
(1, 134000, 134099),
(1, 134400, 134499),
(2, 130000, 130900),
(2, 133000, 133900),
(3, 0, 9000),
(3, 100000, 109000),
(3, 120000, 129000)
}
)
@@ -189,15 +192,108 @@
self.receiving_pubkey.depth
)), set()
)
# no new addresses should have been generated
# no new addresses should have been generated either
self.assertEqual([key.address for key in pubkeys], self.get_ordered_addresses())
# check sub filters at 1,000
self.assertEqual(
q.generate_addresses_using_filters(134_567, 6, (
q.get_missing_sub_filters_for_addresses(3, (
self.root_pubkey.address, self.RECEIVING_KEY_N,
)), {
(2, 3000, 3900),
(2, 103000, 103900),
(2, 123000, 123900),
}
)
# "download" missing 1,000 sub filters
self.insert_sub_filters(3, hashes[0:1], 3000)
self.insert_sub_filters(3, hashes[1:2], 103_000)
self.insert_sub_filters(3, hashes[2:3], 123_000)
# no more missing sub filters at 1,000
self.assertEqual(
q.get_missing_sub_filters_for_addresses(3, (
self.root_pubkey.address, self.RECEIVING_KEY_N,
)), set()
)
# check sub filters at 100
self.assertEqual(
q.get_missing_sub_filters_for_addresses(2, (
self.root_pubkey.address, self.RECEIVING_KEY_N,
)), {
(1, 3300, 3399),
(1, 103300, 103399),
(1, 123300, 123399),
(1, 130300, 130399),
(1, 133300, 133399),
}
)
# "download" missing 100 sub filters
self.insert_sub_filters(2, hashes[0:1], 3300)
self.insert_sub_filters(2, hashes[1:2], 103_300)
self.insert_sub_filters(2, hashes[2:3], 123_300)
self.insert_sub_filters(2, hashes[3:4], 130_300)
self.insert_sub_filters(2, hashes[4:5], 133_300)
# no more missing sub filters at 100
self.assertEqual(
q.get_missing_sub_filters_for_addresses(2, (
self.root_pubkey.address, self.RECEIVING_KEY_N,
)), set()
)
# check tx filters
self.assertEqual(
q.get_missing_sub_filters_for_addresses(1, (
self.root_pubkey.address, self.RECEIVING_KEY_N,
)), {
(0, 3303, 3303),
(0, 103303, 103303),
(0, 123303, 123303),
(0, 130303, 130303),
(0, 133303, 133303),
(0, 134003, 134003),
(0, 134403, 134403),
}
)
# "download" missing tx filters
self.insert_sub_filters(1, hashes[0:1], 3303)
self.insert_sub_filters(1, hashes[1:2], 103_303)
self.insert_sub_filters(1, hashes[2:3], 123_303)
self.insert_sub_filters(1, hashes[3:4], 130_303)
self.insert_sub_filters(1, hashes[4:5], 133_303)
self.insert_sub_filters(1, hashes[5:6], 134_003)
self.insert_sub_filters(1, hashes[6:7], 134_403)
# no more missing tx filters
self.assertEqual(
q.get_missing_sub_filters_for_addresses(1, (
self.root_pubkey.address, self.RECEIVING_KEY_N,
)), set()
)
# find TXs we need to download
missing_txs = {
b'7478313033333033',
b'7478313233333033',
b'7478313330333033',
b'7478313333333033',
b'7478313334303033',
b'7478313334343033',
b'7478313334353030',
b'7478313334353637',
b'747833333033'
}
self.assertEqual(
q.get_missing_tx_for_addresses((
self.root_pubkey.address, self.RECEIVING_KEY_N,
)), missing_txs
)
# "download" missing TXs
ctx = context()
for tx_hash in missing_txs:
ctx.execute(TX.insert().values(tx_hash=tx_hash))
# check we have everything
self.assertEqual(
q.get_missing_tx_for_addresses((
self.root_pubkey.address, self.RECEIVING_KEY_N,
self.receiving_pubkey.pubkey_bytes,
self.receiving_pubkey.chain_code,
self.receiving_pubkey.depth
)), set()
)