forked from LBRYCommunity/lbry-sdk
432 lines
18 KiB
Python
432 lines
18 KiB
Python
|
import asyncio
|
||
|
import logging
|
||
|
from io import StringIO
|
||
|
from functools import partial
|
||
|
from operator import itemgetter
|
||
|
from collections import defaultdict
|
||
|
from binascii import hexlify, unhexlify
|
||
|
from typing import List, Optional, DefaultDict, NamedTuple
|
||
|
|
||
|
import pylru
|
||
|
from lbry.crypto.hash import double_sha256, sha256
|
||
|
|
||
|
from lbry.service.api import Client
|
||
|
from lbry.tasks import TaskGroup
|
||
|
from lbry.blockchain.transaction import Transaction
|
||
|
from lbry.blockchain.ledger import Ledger
|
||
|
from lbry.blockchain.block import get_block_filter
|
||
|
from lbry.db import Database
|
||
|
from lbry.event import EventController
|
||
|
from lbry.service.base import Service, Sync
|
||
|
|
||
|
from .account import Account, AddressManager
|
||
|
|
||
|
|
||
|
class TransactionEvent(NamedTuple):
    """Event payload pairing an address with a transaction.

    ``SPVSync.update_history`` builds one of these for every transaction it
    synced for an address and adds it to the ledger's transaction event
    controller.
    """
    # Address whose history the transaction was synced for.
    address: str
    # The synced transaction itself.
    tx: Transaction
|
||
|
|
||
|
|
||
|
class AddressesGeneratedEvent(NamedTuple):
    """Event payload pairing an address manager with a list of addresses.

    NOTE(review): nothing in this file constructs this event; presumably it is
    emitted elsewhere when an address manager generates new addresses --
    confirm against the emitting code.
    """
    # Address manager associated with the addresses.
    address_manager: AddressManager
    # The addresses, as strings.
    addresses: List[str]
|
||
|
|
||
|
|
||
|
class TransactionCacheItem:
    """Cache slot for a transaction that may not have been loaded yet.

    Attributes:
        lock: asyncio lock held by callers while they update this entry.
        has_tx: asyncio event that is set once a non-``None`` transaction has
            been assigned to ``tx``.
        pending_verifications: counter maintained by callers; starts at zero.
    """
    __slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications'

    def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
        """Create the entry, optionally pre-populated with ``tx`` and sharing ``lock``."""
        self.has_tx = asyncio.Event()
        self.lock = lock if lock else asyncio.Lock()
        self.pending_verifications = 0
        # Assigning through the property stores the value in ``_tx`` and, when
        # it is not None, also sets ``has_tx``.
        self.tx = tx

    @property
    def tx(self) -> Optional[Transaction]:
        """The cached transaction, or ``None`` while it is still unknown."""
        return self._tx

    @tx.setter
    def tx(self, tx: Transaction):
        self._tx = tx
        if tx is None:
            return
        # Wake up anything waiting on ``has_tx`` for the transaction to arrive.
        self.has_tx.set()
|
||
|
|
||
|
|
||
|
class SPVSync(Sync):
    """Synchronizes headers, address histories and transactions for wallet accounts.

    The class talks to a remote server through ``self.network``, keeps block
    headers in ``self.headers`` and persists results through ``self.db`` /
    ``self.service.db``.

    NOTE(review): ``log``, ``Network``, ``BlockHeightEvent`` and the bare name
    ``headers`` are used below but are not defined or imported in the visible
    part of this file; any code path touching them will raise ``NameError``
    until that is resolved.
    """

    def __init__(self, service: Service):
        """Initialise the base ``Sync`` with ``service``.

        NOTE(review): the bare ``return`` right after ``super().__init__()``
        makes the remainder of this constructor unreachable, so none of the
        attributes assigned below (``network``, ``accounts``, ``_tx_cache``,
        the task groups, locks, ...) exist on instances.  Only ``advance()``
        avoids those attributes.  The dead code is kept as-is because removing
        the ``return`` would immediately fail on the undefined name
        ``headers``.
        """
        super().__init__(service)
        return
        self.headers = headers
        self.network: Network = self.config.get('network') or Network(self)
        self.network.on_header.listen(self.receive_header)
        self.network.on_status.listen(self.process_status_update)
        self.network.on_connected.listen(self.join_network)

        self.accounts = []

        self.on_address = self.ledger.on_address

        self._on_header_controller = EventController()
        self.on_header = self._on_header_controller.stream
        self.on_header.listen(
            lambda change: log.info(
                '%s: added %s header blocks, final height %s',
                self.ledger.get_id(), change, self.headers.height
            )
        )
        self._download_height = 0

        self._on_ready_controller = EventController()
        self.on_ready = self._on_ready_controller.stream

        self._tx_cache = pylru.lrucache(100000)
        self._update_tasks = TaskGroup()
        self._other_tasks = TaskGroup()  # that we don't need to start
        self._header_processing_lock = asyncio.Lock()
        self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        self._known_addresses_out_of_sync = set()

    async def advance(self):
        """Scan address filters for the wallet's addresses and store matching transactions.

        Every block filter returned by the service is tested against all
        addresses from the database; for matching blocks the per-transaction
        filters are tested the same way and each matching transaction is
        fetched and inserted into the database.

        NOTE(review): matches are reported with ``print()`` rather than the
        logging module.
        """
        address_array = [
            bytearray(a['address'].encode())
            for a in await self.service.db.get_all_addresses()
        ]
        block_filters = await self.service.get_block_address_filters()
        for block_hash, block_filter in block_filters.items():
            bf = get_block_filter(block_filter)
            if not bf.MatchAny(address_array):
                continue
            print(f'match: {block_hash} - {block_filter}')
            tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
            for txid, tx_filter in tx_filters.items():
                tf = get_block_filter(tx_filter)
                if not tf.MatchAny(address_array):
                    continue
                print(f' match: {txid} - {tx_filter}')
                txs = await self.service.search_transactions([txid])
                tx = Transaction(unhexlify(txs[txid]))
                await self.service.db.insert_transaction(tx)

    async def get_local_status_and_history(self, address, history=None):
        """Return ``(status, history_items)`` for ``address``.

        Args:
            address: address to look up when ``history`` is not supplied.
            history: optional history string made of ``txid:height:`` pairs;
                when falsy it is read from the database record of ``address``.

        Returns:
            A tuple of the hex-encoded sha256 of the history string (``None``
            for an empty history) and a list of ``(txid, height)`` tuples with
            ``height`` converted to ``int``.
        """
        if not history:
            address_details = await self.db.get_address(address=address)
            history = (address_details['history'] if address_details else '') or ''
        # The string ends with ':' so the final split element is always empty.
        parts = history.split(':')[:-1]
        return (
            hexlify(sha256(history.encode())).decode() if history else None,
            list(zip(parts[0::2], map(int, parts[1::2])))
        )

    @staticmethod
    def get_root_of_merkle_tree(branches, branch_positions, working_branch):
        """Fold a merkle branch into its root.

        Args:
            branches: hex strings of the sibling hashes, ordered from the leaf
                upwards; each one is byte-reversed after decoding.
            branch_positions: integer whose bit ``i`` tells whether sibling
                ``i`` goes on the left (bit set) or on the right (bit clear).
            working_branch: raw bytes of the hash to start from.

        Returns:
            The hexlified (bytes), byte-reversed root hash.
        """
        for i, branch in enumerate(branches):
            other_branch = unhexlify(branch)[::-1]
            other_branch_on_left = bool((branch_positions >> i) & 1)
            if other_branch_on_left:
                combined = other_branch + working_branch
            else:
                combined = working_branch + other_branch
            working_branch = double_sha256(combined)
        return hexlify(working_branch[::-1])

    async def start(self):
        """Open the headers, connect to the network and wait until ready.

        The ``on_ready.first`` awaitable is obtained before the network is
        started and awaited at the very end, after the initial header sync
        has run under the header processing lock.
        """
        await self.headers.open()
        fully_synced = self.on_ready.first
        # Keep a reference to the task: the event loop only holds weak
        # references, so an unreferenced task may be garbage collected before
        # it finishes.
        self._network_start_task = asyncio.create_task(self.network.start())
        await self.network.on_connected.first
        async with self._header_processing_lock:
            await self._update_tasks.add(self.initial_headers_sync())
        await fully_synced

    async def join_network(self, *_):
        """Subscribe all accounts, wait for pending update tasks, then signal readiness."""
        log.info("Subscribing and updating accounts.")
        await self._update_tasks.add(self.subscribe_accounts())
        await self._update_tasks.done.wait()
        self._on_ready_controller.add(True)

    async def stop(self):
        """Cancel both task groups, wait for them to finish, then stop network and headers."""
        self._update_tasks.cancel()
        self._other_tasks.cancel()
        await self._update_tasks.done.wait()
        await self._other_tasks.done.wait()
        await self.network.stop()
        await self.headers.close()

    @property
    def local_height_including_downloaded_height(self):
        """The larger of the headers' height and ``_download_height``."""
        return max(self.headers.height, self._download_height)

    async def initial_headers_sync(self):
        """Install the header chunk getter, backfill missing chunks and update headers.

        Missing checkpointed chunks are fetched by a task added to
        ``_other_tasks``, highest height first, taking the header processing
        lock for each chunk; ``update_headers()`` is awaited directly.
        """
        get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
        self.headers.chunk_getter = get_chunk

        async def doit():
            for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
                async with self._header_processing_lock:
                    await self.headers.ensure_chunk_at(height)
        self._other_tasks.add(doit())
        await self.update_headers()

    async def update_headers(self, height=None, headers=None, subscription_update=False):
        """Connect new headers, rewinding one block at a time when none connect.

        Args:
            height: height at which ``headers`` should connect; ``None`` (or a
                height beyond the local header count) means "continue from the
                local header count" and discards ``headers``.
            headers: hex-encoded headers to connect; fetched from the network
                when falsy.
            subscription_update: when true and the headers connect, return
                immediately instead of looping to look for more headers.

        Raises:
            IndexError: if ``headers.connect()`` returns a negative number, if
                rewinding passes height 0, or if 100 or more headers have been
                rewound.
        """
        rewound = 0
        while True:

            if height is None or height > len(self.headers):
                # sometimes header subscription updates are for a header in the future
                # which can't be connected, so we do a normal header sync instead
                height = len(self.headers)
                headers = None
                subscription_update = False

            if not headers:
                header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
                headers = header_response['hex']

            if not headers:
                # Nothing to do, network thinks we're already at the latest height.
                return

            added = await self.headers.connect(height, unhexlify(headers))
            if added > 0:
                height += added
                self._on_header_controller.add(
                    BlockHeightEvent(self.headers.height, added))

                if rewound > 0:
                    # we started rewinding blocks and apparently found
                    # a new chain
                    rewound = 0
                    await self.db.rewind_blockchain(height)

                if subscription_update:
                    # subscription updates are for latest header already
                    # so we don't need to check if there are newer / more
                    # on another loop of update_headers(), just return instead
                    return

            elif added == 0:
                # we had headers to connect but none got connected, probably a reorganization
                height -= 1
                rewound += 1
                log.warning(
                    "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
                    height, height+rewound
                )

            else:
                raise IndexError(f"headers.connect() returned negative number ({added})")

            if height < 0:
                raise IndexError(
                    "Blockchain reorganization rewound all the way back to genesis hash. "
                    "Something is very wrong. Maybe you are on the wrong blockchain?"
                )

            if rewound >= 100:
                raise IndexError(
                    "Blockchain reorganization dropped {} headers. This is highly unusual. "
                    "Will not continue to attempt reorganizing. Please, delete the ledger "
                    "synchronization directory inside your wallet directory (folder: '{}') and "
                    "restart the program to synchronize from scratch."
                    .format(rewound, self.ledger.get_id())
                )

            headers = None  # ready to download some more headers

            # if we made it this far and this was a subscription_update
            # it means something went wrong and now we're doing a more
            # robust sync, turn off subscription update shortcut
            subscription_update = False

    async def receive_header(self, response):
        """Handle a header notification by connecting its first entry under the header lock."""
        async with self._header_processing_lock:
            header = response[0]
            await self.update_headers(
                height=header['height'], headers=header['hex'], subscription_update=True
            )

    async def subscribe_accounts(self):
        """Subscribe every account concurrently; no-op when disconnected or without accounts."""
        if self.network.is_connected and self.accounts:
            log.info("Subscribe to %i accounts", len(self.accounts))
            # asyncio.wait() no longer accepts bare coroutines (deprecated in
            # Python 3.8, TypeError since 3.11), so wrap each one in a task.
            await asyncio.wait([
                asyncio.ensure_future(self.subscribe_account(a)) for a in self.accounts
            ])

    async def subscribe_account(self, account: Account):
        """Subscribe all addresses of each address manager of ``account``, then ensure its address gap."""
        for address_manager in account.address_managers.values():
            await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
        await account.ensure_address_gap()

    async def unsubscribe_account(self, account: Account):
        """Unsubscribe every address of ``account`` from the network, one at a time."""
        for address in await account.get_addresses():
            await self.network.unsubscribe_address(address)

    async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
        """Subscribe ``addresses`` in batches and schedule a history update for each.

        Args:
            address_manager: manager passed through to ``update_history``.
            addresses: addresses to subscribe; nothing happens when empty or
                when the network is not connected.
            batch_size: maximum number of addresses per subscribe call.
        """
        if self.network.is_connected and addresses:
            addresses_remaining = list(addresses)
            while addresses_remaining:
                batch = addresses_remaining[:batch_size]
                results = await self.network.subscribe_address(*batch)
                for address, remote_status in zip(batch, results):
                    self._update_tasks.add(self.update_history(address, remote_status, address_manager))
                addresses_remaining = addresses_remaining[batch_size:]
                log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
                         len(addresses), *self.network.client.server_address_and_port)
            log.info(
                "finished subscribing to %i addresses on %s:%i", len(addresses),
                *self.network.client.server_address_and_port
            )

    def process_status_update(self, update):
        """Schedule a history update for the ``(address, remote_status)`` pair in ``update``."""
        address, remote_status = update
        self._update_tasks.add(self.update_history(address, remote_status))

    async def update_history(self, address, remote_status, address_manager: AddressManager = None):
        """Bring the local history of ``address`` in line with the remote one.

        Runs under the per-address lock.  Transactions missing locally (and
        every transaction after the first divergence from the local history)
        are loaded through ``cache_transaction``, their inputs are linked to
        the outputs they spend when those are known, and the batch is saved
        together with the rebuilt history string.

        Args:
            address: address to update.
            remote_status: status reported by the server for the address.
            address_manager: manager owning the address; looked up from the
                database when ``None``.

        Returns:
            ``True`` when local and remote agree (before or after syncing),
            ``False`` when they still differ afterwards, in which case the
            address is added to ``_known_addresses_out_of_sync``.
        """
        async with self._address_update_locks[address]:
            self._known_addresses_out_of_sync.discard(address)

            local_status, local_history = await self.get_local_status_and_history(address)

            if local_status == remote_status:
                return True

            remote_history = await self.network.retriable_call(self.network.get_history, address)
            remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
            we_need = set(remote_history) - set(local_history)
            if not we_need:
                return True

            cache_tasks: List[asyncio.Task[Transaction]] = []
            synced_history = StringIO()
            loop = asyncio.get_running_loop()
            for i, (txid, remote_height) in enumerate(remote_history):
                # Entries are taken over verbatim only while the local history
                # still matches and nothing has been queued yet; from the
                # first divergence on, every transaction is (re)loaded.
                if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
                    synced_history.write(f'{txid}:{remote_height}:')
                else:
                    check_local = (txid, remote_height) not in we_need
                    cache_tasks.append(loop.create_task(
                        self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
                    ))

            synced_txs = []
            for task in cache_tasks:
                tx = await task

                # First try to resolve spent outputs from the in-memory cache;
                # whatever is not cached is looked up in the database below.
                check_db_for_txos = []
                for txi in tx.inputs:
                    if txi.txo_ref.txo is not None:
                        continue
                    cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
                    if cache_item is not None:
                        if cache_item.tx is None:
                            await cache_item.has_tx.wait()
                        assert cache_item.tx is not None
                        txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
                    else:
                        check_db_for_txos.append(txi.txo_ref.hash)

                referenced_txos = {} if not check_db_for_txos else {
                    txo.id: txo for txo in await self.db.get_txos(
                        txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
                    )
                }

                for txi in tx.inputs:
                    if txi.txo_ref.txo is not None:
                        continue
                    referenced_txo = referenced_txos.get(txi.txo_ref.id)
                    if referenced_txo is not None:
                        txi.txo_ref = referenced_txo.ref

                synced_history.write(f'{tx.id}:{tx.height}:')
                synced_txs.append(tx)

            await self.db.save_transaction_io_batch(
                synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
            )
            # asyncio.wait() no longer accepts bare coroutines (deprecated in
            # Python 3.8, TypeError since 3.11); ensure_future() accepts both
            # coroutines and futures, whichever ``add()`` returns.
            await asyncio.wait([
                asyncio.ensure_future(
                    self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
                )
                for tx in synced_txs
            ])

            if address_manager is None:
                address_manager = await self.get_address_manager_for_address(address)

            if address_manager is not None:
                await address_manager.ensure_address_gap()

            local_status, local_history = \
                await self.get_local_status_and_history(address, synced_history.getvalue())
            if local_status != remote_status:
                if local_history == remote_history:
                    return True
                log.warning(
                    "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
                    remote_status, len(remote_history), local_status, len(local_history)
                )
                log.warning("local: %s", local_history)
                log.warning("remote: %s", remote_history)
                self._known_addresses_out_of_sync.add(address)
                return False
            else:
                return True

    async def cache_transaction(self, tx_hash, remote_height, check_local=True):
        """Return the transaction for ``tx_hash``, loading or re-verifying it when needed.

        A cached transaction is returned directly when its height is at least
        ``remote_height`` and it is either verified or ``remote_height`` is
        below 1.  Otherwise the cache entry is updated while its
        ``pending_verifications`` counter is incremented for the duration.
        """
        cache_item = self._tx_cache.get(tx_hash)
        if cache_item is None:
            cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
        elif cache_item.tx is not None and \
                cache_item.tx.height >= remote_height and \
                (cache_item.tx.is_verified or remote_height < 1):
            return cache_item.tx  # cached tx is already up-to-date

        try:
            cache_item.pending_verifications += 1
            return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
        finally:
            cache_item.pending_verifications -= 1

    async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
        """Fill ``cache_item`` from the database or the network and verify the transaction.

        Runs under the cache item's lock.  The database is consulted only when
        ``check_local`` is true; if the transaction is still unknown it is
        fetched from the network together with its merkle data.
        """
        async with cache_item.lock:

            tx = cache_item.tx

            if tx is None and check_local:
                # check local db
                tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)

            merkle = None
            if tx is None:
                # fetch from network
                _raw, merkle = await self.network.retriable_call(
                    self.network.get_transaction_and_merkle, tx_hash, remote_height
                )
                tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
                cache_item.tx = tx  # make sure it's saved before caching it
            await self.maybe_verify_transaction(tx, remote_height, merkle)
            return tx

    async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
        """Set ``tx.height`` and, when possible, verify ``tx`` against its block header.

        The transaction is added to the cache if it is not there yet.
        Verification happens only when ``0 < remote_height < len(headers)``
        and at most one verification is pending for the cache entry; it sets
        ``tx.position`` and ``tx.is_verified``.  ``merkle`` is fetched from
        the network when not supplied.
        """
        tx.height = remote_height
        cached = self._tx_cache.get(tx.hash)
        if not cached:
            # cache txs looked up by transaction_show too
            cached = TransactionCacheItem()
            cached.tx = tx
            self._tx_cache[tx.hash] = cached
        if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
            # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
            if not merkle:
                merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
            merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
            header = await self.headers.get(remote_height)
            tx.position = merkle['pos']
            tx.is_verified = merkle_root == header['merkle_root']

    async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
        """Return the address manager owning ``address``, or ``None`` if there is none.

        ``None`` is returned both when the database has no record for the
        address and when no account in ``self.accounts`` matches the record.
        """
        details = await self.db.get_address(address=address)
        if not details:
            # The lookup can come back empty (get_local_status_and_history
            # guards against the same case); there is nothing to match then.
            return None
        for account in self.accounts:
            if account.id == details['account']:
                return account.address_managers[details['chain']]
        return None
|