diff --git a/lbry/blockchain/__init__.py b/lbry/blockchain/__init__.py
index d42e3e1b1..e69de29bb 100644
--- a/lbry/blockchain/__init__.py
+++ b/lbry/blockchain/__init__.py
@@ -1,2 +0,0 @@
-from .sync import BlockchainSync
-from .lbrycrd import Lbrycrd
diff --git a/lbry/blockchain/bcd_data_stream.py b/lbry/blockchain/bcd_data_stream.py
index 9314f3d46..320adc066 100644
--- a/lbry/blockchain/bcd_data_stream.py
+++ b/lbry/blockchain/bcd_data_stream.py
@@ -7,6 +7,9 @@ class BCDataStream:
     def __init__(self, data=None, fp=None):
         self.data = fp or BytesIO(data)
 
+    def tell(self):
+        return self.data.tell()
+
     def reset(self):
         self.data.seek(0)
diff --git a/lbry/blockchain/block.py b/lbry/blockchain/block.py
index c59755b32..38e421307 100644
--- a/lbry/blockchain/block.py
+++ b/lbry/blockchain/block.py
@@ -1,37 +1,59 @@
 import struct
+from hashlib import sha256
+from typing import Set
+from binascii import unhexlify
+from typing import NamedTuple, List
+
+from chiabip158 import PyBIP158
+
 from lbry.crypto.hash import double_sha256
-from lbry.wallet.transaction import Transaction
-from lbry.wallet.bcd_data_stream import BCDataStream
+from lbry.blockchain.transaction import Transaction
+from lbry.blockchain.bcd_data_stream import BCDataStream
 
 ZERO_BLOCK = bytes((0,)*32)
 
 
-class Block:
+def create_block_filter(addresses: Set[str]) -> bytes:
+    return bytes(PyBIP158([bytearray(a.encode()) for a in addresses]).GetEncoded())
 
-    __slots__ = (
-        'version', 'block_hash', 'prev_block_hash',
-        'merkle_root', 'claim_trie_root', 'timestamp',
-        'bits', 'nonce', 'txs'
-    )
 
-    def __init__(self, stream: BCDataStream):
+def get_block_filter(block_filter: str) -> PyBIP158:
+    return PyBIP158(bytearray(unhexlify(block_filter)))
+
+
+class Block(NamedTuple):
+    height: int
+    version: int
+    file_number: int
+    block_hash: bytes
+    prev_block_hash: bytes
+    merkle_root: bytes
+    claim_trie_root: bytes
+    timestamp: int
+    bits: int
+    nonce: int
+    txs: List[Transaction]
+
+    @staticmethod
+    def from_data_stream(stream: BCDataStream, height: int, file_number: int):
         header = stream.data.read(112)
         version, = struct.unpack('<I', header[:4])
         timestamp, bits, nonce = struct.unpack('<III', header[100:112])
         tx_count = stream.read_compact_size()
-        self.version = version
-        self.block_hash = double_sha256(header)
-        self.prev_block_hash = header[4:36]
-        self.merkle_root = header[36:68]
-        self.claim_trie_root = header[68:100][::-1]
-        self.timestamp = timestamp
-        self.bits = bits
-        self.nonce = nonce
-        self.txs = [Transaction(position=i)._deserialize(stream) for i in range(tx_count)]
+        return Block(
+            height=height,
+            version=version,
+            file_number=file_number,
+            block_hash=double_sha256(header),
+            prev_block_hash=header[4:36],
+            merkle_root=header[36:68],
+            claim_trie_root=header[68:100][::-1],
+            timestamp=timestamp,
+            bits=bits,
+            nonce=nonce,
+            txs=[
+                Transaction(height=height, position=i).deserialize(stream)
+                for i in range(tx_count)
+            ]
+        )
 
     @property
     def is_first_block(self):
         return self.prev_block_hash == ZERO_BLOCK
diff --git a/lbry/blockchain/hash.py b/lbry/blockchain/hash.py
index 08a0aee82..872a39506 100644
--- a/lbry/blockchain/hash.py
+++ b/lbry/blockchain/hash.py
@@ -1,5 +1,5 @@
 from binascii import hexlify, unhexlify
-from .constants import NULL_HASH32
+from lbry.constants import NULL_HASH32
 
 
 class TXRef:
diff --git a/lbry/blockchain/header.py b/lbry/blockchain/header.py
index 0e51257ca..31a4f6b37 100644
--- a/lbry/blockchain/header.py
+++ b/lbry/blockchain/header.py
@@ -12,7 +12,7 @@ from typing import Optional, Iterator, Tuple, Callable
 from binascii import hexlify, unhexlify
 
 from lbry.crypto.hash import sha512, double_sha256, ripemd160
-from lbry.wallet.util import ArithUint256
+from lbry.blockchain.util import ArithUint256
 from .checkpoints import HASHES
diff --git a/lbry/blockchain/lbrycrd.py b/lbry/blockchain/lbrycrd.py
index 76d9a25a5..aee154d22 100644
--- a/lbry/blockchain/lbrycrd.py
+++ b/lbry/blockchain/lbrycrd.py
@@ -8,15 +8,16 @@ import tempfile
 import urllib.request
 from typing import Optional
 from binascii import hexlify
-from concurrent.futures import ThreadPoolExecutor
 
 import aiohttp
 import zmq
 import zmq.asyncio
 
-from lbry.wallet.stream import StreamController
+from lbry.conf import Config
+from lbry.event import EventController
 
 from .database import BlockchainDB
+from .ledger import Ledger, RegTestLedger
 
 log = logging.getLogger(__name__)
 
@@ -58,10 +59,10 @@ class Process(asyncio.SubprocessProtocol):
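# A quick sketch of how the create_block_filter()/get_block_filter() helpers
# added in block.py above fit together: a set of addresses is packed into a
# BIP158 filter for storage, and the hex-encoded filter can later be rebuilt
# and probed for membership. The Match() call is an assumption about the
# chiabip158 bindings, and the address is a made-up example.
from binascii import hexlify
from lbry.blockchain.block import create_block_filter, get_block_filter

address = 'bCrboXVztuSbZzVToCWSsu1kEpvhfntXya'       # hypothetical address
encoded = create_block_filter({address})             # bytes, ready for a DB column
probe = get_block_filter(hexlify(encoded).decode())  # back to a PyBIP158 matcher
assert probe.Match(bytearray(address.encode()))      # assumed chiabip158 API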
class Lbrycrd: - def __init__(self, path, regtest=False): - self.data_dir = self.actual_data_dir = path - self.regtest = regtest - if regtest: + def __init__(self, ledger: Ledger): + self.ledger = ledger + self.data_dir = self.actual_data_dir = ledger.conf.lbrycrd_dir + if self.is_regtest: self.actual_data_dir = os.path.join(self.data_dir, 'regtest') self.blocks_dir = os.path.join(self.actual_data_dir, 'blocks') self.bin_dir = os.path.join(os.path.dirname(__file__), 'bin') @@ -74,34 +75,27 @@ class Lbrycrd: self.rpcport = 9245 + 2 # avoid conflict with default rpc port self.rpcuser = 'rpcuser' self.rpcpassword = 'rpcpassword' - self.session: Optional[aiohttp.ClientSession] = None self.subscribed = False self.subscription: Optional[asyncio.Task] = None self.subscription_url = 'tcp://127.0.0.1:29000' self.default_generate_address = None - self._on_block_controller = StreamController() + self._on_block_controller = EventController() self.on_block = self._on_block_controller.stream self.on_block.listen(lambda e: log.info('%s %s', hexlify(e['hash']), e['msg'])) self.db = BlockchainDB(self.actual_data_dir) - self.executor = ThreadPoolExecutor(max_workers=1) + self.session: Optional[aiohttp.ClientSession] = None + + @classmethod + def temp_regtest(cls): + return cls(RegTestLedger(Config.with_same_dir(tempfile.mkdtemp()))) def get_block_file_path_from_number(self, block_file_number): return os.path.join(self.actual_data_dir, 'blocks', f'blk{block_file_number:05}.dat') - async def get_block_files(self): - return await asyncio.get_running_loop().run_in_executor( - self.executor, self.db.get_block_files - ) - - async def get_file_details(self, block_file): - return await asyncio.get_running_loop().run_in_executor( - self.executor, self.db.get_file_details, block_file - ) - - @classmethod - def temp_regtest(cls): - return cls(tempfile.mkdtemp(), True) + @property + def is_regtest(self): + return isinstance(self.ledger, RegTestLedger) @property def rpc_url(self): @@ -150,7 +144,7 @@ class Lbrycrd: return self.exists or await self.download() def get_start_command(self, *args): - if self.regtest: + if self.is_regtest: args += ('-regtest',) return ( self.daemon_bin, @@ -164,6 +158,14 @@ class Lbrycrd: *args ) + async def open(self): + self.session = aiohttp.ClientSession() + await self.db.open() + + async def close(self): + await self.db.close() + await self.session.close() + async def start(self, *args): loop = asyncio.get_event_loop() command = self.get_start_command(*args) @@ -171,11 +173,11 @@ class Lbrycrd: self.transport, self.protocol = await loop.subprocess_exec(Process, *command) await self.protocol.ready.wait() assert not self.protocol.stopped.is_set() - self.session = aiohttp.ClientSession() + await self.open() async def stop(self, cleanup=True): try: - await self.session.close() + await self.close() self.transport.terminate() await self.protocol.stopped.wait() assert self.transport.get_returncode() == 0, "lbrycrd daemon exit with error" @@ -201,7 +203,7 @@ class Lbrycrd: try: while self.subscribed: msg = await sock.recv_multipart() - self._on_block_controller.add({ + await self._on_block_controller.add({ 'hash': msg[1], 'msg': struct.unpack(' LedgerType: - return mcs.ledgers[ledger_id] - - -class TransactionEvent(NamedTuple): - address: str - tx: Transaction - - -class AddressesGeneratedEvent(NamedTuple): - address_manager: AddressManager - addresses: List[str] - - -class BlockHeightEvent(NamedTuple): - height: int - change: int - - -class TransactionCacheItem: - __slots__ = '_tx', 'lock', 
'has_tx', 'pending_verifications'
-
-    def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
-        self.has_tx = asyncio.Event()
-        self.lock = lock or asyncio.Lock()
-        self._tx = self.tx = tx
-        self.pending_verifications = 0
-
-    @property
-    def tx(self) -> Optional[Transaction]:
-        return self._tx
-
-    @tx.setter
-    def tx(self, tx: Transaction):
-        self._tx = tx
-        if tx is not None:
-            self.has_tx.set()
-
-
-class Ledger(metaclass=LedgerRegistry):
+class Ledger:
     name = 'LBRY Credits'
     symbol = 'LBC'
     network_name = 'mainnet'
@@ -107,69 +26,14 @@ class Ledger(metaclass=LedgerRegistry):
     genesis_bits = 0x1f00ffff
     target_timespan = 150
 
-    default_fee_per_byte = 50
-    default_fee_per_name_char = 200000
+    fee_per_byte = 50
+    fee_per_name_char = 200000
 
     checkpoints = HASHES
 
-    def __init__(self, config=None):
-        self.config = config or {}
-        self.db: Database = self.config.get('db') or Database(
-            os.path.join(self.path, "blockchain.db")
-        )
-        self.db.ledger = self
-        self.headers: Headers = self.config.get('headers') or self.headers_class(
-            os.path.join(self.path, "headers")
-        )
-        self.headers.checkpoints = self.checkpoints
-        self.network: Network = self.config.get('network') or Network(self)
-        self.network.on_header.listen(self.receive_header)
-        self.network.on_status.listen(self.process_status_update)
-        self.network.on_connected.listen(self.join_network)
-
-        self.accounts = []
-        self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
-
-        self._on_transaction_controller = StreamController()
-        self.on_transaction = self._on_transaction_controller.stream
-        self.on_transaction.listen(
-            lambda e: log.info(
-                '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
-                self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id
-            )
-        )
-
-        self._on_address_controller = StreamController()
-        self.on_address = self._on_address_controller.stream
-        self.on_address.listen(
-            lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses)
-        )
-
-        self._on_header_controller = StreamController()
-        self.on_header = self._on_header_controller.stream
-        self.on_header.listen(
-            lambda change: log.info(
-                '%s: added %s header blocks, final height %s',
-                self.get_id(), change, self.headers.height
-            )
-        )
-        self._download_height = 0
-
-        self._on_ready_controller = StreamController()
-        self.on_ready = self._on_ready_controller.stream
-
-        self._tx_cache = pylru.lrucache(100000)
-        self._update_tasks = TaskGroup()
-        self._other_tasks = TaskGroup()  # that we don't need to start
-        self._utxo_reservation_lock = asyncio.Lock()
-        self._header_processing_lock = asyncio.Lock()
-        self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
-
+    def __init__(self, conf: Config = None):
+        self.conf = conf or Config.with_same_dir('/dev/null')
         self.coin_selection_strategy = None
-        self._known_addresses_out_of_sync = set()
-
-        self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char)
-        self._balance_cache = pylru.lrucache(100000)
 
     @classmethod
     def get_id(cls):
@@ -189,6 +53,28 @@ class Ledger(metaclass=LedgerRegistry):
         decoded = Base58.decode_check(address)
         return decoded[0] == cls.pubkey_address_prefix[0]
 
+    @classmethod
+    def valid_address_or_error(cls, address):
+        try:
+            assert cls.is_valid_address(address)
+        except Exception:
+            raise Exception(f"'{address}' is not a valid address")
+
+    @classmethod
+    def valid_channel_name_or_error(cls, name):
+        try:
+            if not name:
+                raise Exception(
+                    "Channel name cannot be blank."
+                )
+            parsed = URL.parse(name)
+            if not parsed.has_channel:
+                raise Exception("Channel names must start with '@' symbol.")
+            if parsed.channel.name != name:
+                raise Exception("Channel name has invalid characters.")
+        except (TypeError, ValueError):
+            raise Exception("Invalid channel name.")
+
     @classmethod
     def public_key_to_address(cls, public_key):
         return cls.hash160_to_address(hash160(public_key))
@@ -197,867 +83,6 @@ class Ledger(metaclass=LedgerRegistry):
     def private_key_to_wif(private_key):
         return b'\x1c' + private_key + b'\x01'
 
-    @property
-    def path(self):
-        return os.path.join(self.config['data_path'], self.get_id())
-
-    def add_account(self, account: Account):
-        self.accounts.append(account)
-
-    async def _get_account_and_address_info_for_address(self, wallet, address):
-        match = await self.db.get_address(accounts=wallet.accounts, address=address)
-        if match:
-            for account in wallet.accounts:
-                if match['account'] == account.public_key.address:
-                    return account, match
-
-    async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
-        match = await self._get_account_and_address_info_for_address(wallet, address)
-        if match:
-            account, address_info = match
-            return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
-        return None
-
-    async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
-        match = await self._get_account_and_address_info_for_address(wallet, address)
-        if match:
-            _, address_info = match
-            return address_info['pubkey']
-        return None
-
-    async def get_account_for_address(self, wallet, address):
-        match = await self._get_account_and_address_info_for_address(wallet, address)
-        if match:
-            return match[0]
-
-    async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
-        estimators = []
-        for account in funding_accounts:
-            utxos = await account.get_utxos()
-            for utxo in utxos:
-                estimators.append(utxo.get_estimator(self))
-        return estimators
-
-    async def get_addresses(self, **constraints):
-        return await self.db.get_addresses(**constraints)
-
-    def get_address_count(self, **constraints):
-        return self.db.get_address_count(**constraints)
-
-    async def get_spendable_utxos(self, amount: int, funding_accounts):
-        async with self._utxo_reservation_lock:
-            txos = await self.get_effective_amount_estimators(funding_accounts)
-            fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
-            selector = CoinSelector(amount, fee)
-            spendables = selector.select(txos, self.coin_selection_strategy)
-            if spendables:
-                await self.reserve_outputs(s.txo for s in spendables)
-            return spendables
-
-    def reserve_outputs(self, txos):
-        return self.db.reserve_outputs(txos)
-
-    def release_outputs(self, txos):
-        return self.db.release_outputs(txos)
-
-    def release_tx(self, tx):
-        return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
-
-    def get_utxos(self, **constraints):
-        self.constraint_spending_utxos(constraints)
-        return self.db.get_utxos(**constraints)
-
-    def get_utxo_count(self, **constraints):
-        self.constraint_spending_utxos(constraints)
-        return self.db.get_utxo_count(**constraints)
-
-    async def get_txos(self, resolve=False, **constraints) -> List[Output]:
-        txos = await self.db.get_txos(**constraints)
-        if resolve:
-            return await self._resolve_for_local_results(constraints.get('accounts', []), txos)
-        return txos
-
-    def get_txo_count(self, **constraints):
-        return self.db.get_txo_count(**constraints)
-
-    def get_txo_sum(self, **constraints):
-        return
self.db.get_txo_sum(**constraints) - - def get_txo_plot(self, **constraints): - return self.db.get_txo_plot(**constraints) - - def get_transactions(self, **constraints): - return self.db.get_transactions(**constraints) - - def get_transaction_count(self, **constraints): - return self.db.get_transaction_count(**constraints) - - async def get_local_status_and_history(self, address, history=None): - if not history: - address_details = await self.db.get_address(address=address) - history = (address_details['history'] if address_details else '') or '' - parts = history.split(':')[:-1] - return ( - hexlify(sha256(history.encode())).decode() if history else None, - list(zip(parts[0::2], map(int, parts[1::2]))) - ) - - @staticmethod - def get_root_of_merkle_tree(branches, branch_positions, working_branch): - for i, branch in enumerate(branches): - other_branch = unhexlify(branch)[::-1] - other_branch_on_left = bool((branch_positions >> i) & 1) - if other_branch_on_left: - combined = other_branch + working_branch - else: - combined = working_branch + other_branch - working_branch = double_sha256(combined) - return hexlify(working_branch[::-1]) - - async def start(self): - if not os.path.exists(self.path): - os.mkdir(self.path) - await asyncio.wait([ - self.db.open(), - self.headers.open() - ]) - fully_synced = self.on_ready.first - asyncio.create_task(self.network.start()) - await self.network.on_connected.first - async with self._header_processing_lock: - await self._update_tasks.add(self.initial_headers_sync()) - await fully_synced - await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) - await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) - if len(self.accounts) > 10: - log.info("Loaded %i accounts", len(self.accounts)) - else: - await self._report_state() - self.on_transaction.listen(self._reset_balance_cache) - - async def join_network(self, *_): - log.info("Subscribing and updating accounts.") - await self._update_tasks.add(self.subscribe_accounts()) - await self._update_tasks.done.wait() - self._on_ready_controller.add(True) - - async def stop(self): - self._update_tasks.cancel() - self._other_tasks.cancel() - await self._update_tasks.done.wait() - await self._other_tasks.done.wait() - await self.network.stop() - await self.db.close() - await self.headers.close() - - @property - def local_height_including_downloaded_height(self): - return max(self.headers.height, self._download_height) - - async def initial_headers_sync(self): - get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True) - self.headers.chunk_getter = get_chunk - - async def doit(): - for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)): - async with self._header_processing_lock: - await self.headers.ensure_chunk_at(height) - self._other_tasks.add(doit()) - await self.update_headers() - - async def update_headers(self, height=None, headers=None, subscription_update=False): - rewound = 0 - while True: - - if height is None or height > len(self.headers): - # sometimes header subscription updates are for a header in the future - # which can't be connected, so we do a normal header sync instead - height = len(self.headers) - headers = None - subscription_update = False - - if not headers: - header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) - headers = header_response['hex'] - - if not headers: - # Nothing to do, network thinks we're already at the latest height. 
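# A worked example of the history format parsed by get_local_status_and_history()
# above: an address's history is a flat 'txid:height:' string, and its sha256 is
# the status that SPV servers report for the address. The txids here are
# shortened, made-up values.
from hashlib import sha256
from binascii import hexlify

history = 'deadbeef:5:cafef00d:7:'
parts = history.split(':')[:-1]                        # ['deadbeef', '5', 'cafef00d', '7']
pairs = list(zip(parts[0::2], map(int, parts[1::2])))  # [('deadbeef', 5), ('cafef00d', 7)]
status = hexlify(sha256(history.encode()).digest()).decode()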
- return - - added = await self.headers.connect(height, unhexlify(headers)) - if added > 0: - height += added - self._on_header_controller.add( - BlockHeightEvent(self.headers.height, added)) - - if rewound > 0: - # we started rewinding blocks and apparently found - # a new chain - rewound = 0 - await self.db.rewind_blockchain(height) - - if subscription_update: - # subscription updates are for latest header already - # so we don't need to check if there are newer / more - # on another loop of update_headers(), just return instead - return - - elif added == 0: - # we had headers to connect but none got connected, probably a reorganization - height -= 1 - rewound += 1 - log.warning( - "Blockchain Reorganization: attempting rewind to height %s from starting height %s", - height, height+rewound - ) - - else: - raise IndexError(f"headers.connect() returned negative number ({added})") - - if height < 0: - raise IndexError( - "Blockchain reorganization rewound all the way back to genesis hash. " - "Something is very wrong. Maybe you are on the wrong blockchain?" - ) - - if rewound >= 100: - raise IndexError( - "Blockchain reorganization dropped {} headers. This is highly unusual. " - "Will not continue to attempt reorganizing. Please, delete the ledger " - "synchronization directory inside your wallet directory (folder: '{}') and " - "restart the program to synchronize from scratch." - .format(rewound, self.get_id()) - ) - - headers = None # ready to download some more headers - - # if we made it this far and this was a subscription_update - # it means something went wrong and now we're doing a more - # robust sync, turn off subscription update shortcut - subscription_update = False - - async def receive_header(self, response): - async with self._header_processing_lock: - header = response[0] - await self.update_headers( - height=header['height'], headers=header['hex'], subscription_update=True - ) - - async def subscribe_accounts(self): - if self.network.is_connected and self.accounts: - log.info("Subscribe to %i accounts", len(self.accounts)) - await asyncio.wait([ - self.subscribe_account(a) for a in self.accounts - ]) - - async def subscribe_account(self, account: Account): - for address_manager in account.address_managers.values(): - await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) - await account.ensure_address_gap() - - async def unsubscribe_account(self, account: Account): - for address in await account.get_addresses(): - await self.network.unsubscribe_address(address) - - async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]): - await self.subscribe_addresses(address_manager, addresses) - await self._on_address_controller.add( - AddressesGeneratedEvent(address_manager, addresses) - ) - - async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000): - if self.network.is_connected and addresses: - addresses_remaining = list(addresses) - while addresses_remaining: - batch = addresses_remaining[:batch_size] - results = await self.network.subscribe_address(*batch) - for address, remote_status in zip(batch, results): - self._update_tasks.add(self.update_history(address, remote_status, address_manager)) - addresses_remaining = addresses_remaining[batch_size:] - log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), - len(addresses), *self.network.client.server_address_and_port) - log.info( - "finished subscribing to %i addresses 
on %s:%i", len(addresses), - *self.network.client.server_address_and_port - ) - - def process_status_update(self, update): - address, remote_status = update - self._update_tasks.add(self.update_history(address, remote_status)) - - async def update_history(self, address, remote_status, address_manager: AddressManager = None): - async with self._address_update_locks[address]: - self._known_addresses_out_of_sync.discard(address) - - local_status, local_history = await self.get_local_status_and_history(address) - - if local_status == remote_status: - return True - - remote_history = await self.network.retriable_call(self.network.get_history, address) - remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) - we_need = set(remote_history) - set(local_history) - if not we_need: - return True - - cache_tasks: List[asyncio.Task[Transaction]] = [] - synced_history = StringIO() - loop = asyncio.get_running_loop() - for i, (txid, remote_height) in enumerate(remote_history): - if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks: - synced_history.write(f'{txid}:{remote_height}:') - else: - check_local = (txid, remote_height) not in we_need - cache_tasks.append(loop.create_task( - self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local) - )) - - synced_txs = [] - for task in cache_tasks: - tx = await task - - check_db_for_txos = [] - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - continue - cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash) - if cache_item is not None: - if cache_item.tx is None: - await cache_item.has_tx.wait() - assert cache_item.tx is not None - txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref - else: - check_db_for_txos.append(txi.txo_ref.hash) - - referenced_txos = {} if not check_db_for_txos else { - txo.id: txo for txo in await self.db.get_txos( - txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True - ) - } - - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - continue - referenced_txo = referenced_txos.get(txi.txo_ref.id) - if referenced_txo is not None: - txi.txo_ref = referenced_txo.ref - - synced_history.write(f'{tx.id}:{tx.height}:') - synced_txs.append(tx) - - await self.db.save_transaction_io_batch( - synced_txs, address, self.address_to_hash160(address), synced_history.getvalue() - ) - await asyncio.wait([ - self._on_transaction_controller.add(TransactionEvent(address, tx)) - for tx in synced_txs - ]) - - if address_manager is None: - address_manager = await self.get_address_manager_for_address(address) - - if address_manager is not None: - await address_manager.ensure_address_gap() - - local_status, local_history = \ - await self.get_local_status_and_history(address, synced_history.getvalue()) - if local_status != remote_status: - if local_history == remote_history: - return True - log.warning( - "Wallet is out of sync after syncing. 
Remote: %s with %d items, local: %s with %d items", - remote_status, len(remote_history), local_status, len(local_history) - ) - log.warning("local: %s", local_history) - log.warning("remote: %s", remote_history) - self._known_addresses_out_of_sync.add(address) - return False - else: - return True - - async def cache_transaction(self, tx_hash, remote_height, check_local=True): - cache_item = self._tx_cache.get(tx_hash) - if cache_item is None: - cache_item = self._tx_cache[tx_hash] = TransactionCacheItem() - elif cache_item.tx is not None and \ - cache_item.tx.height >= remote_height and \ - (cache_item.tx.is_verified or remote_height < 1): - return cache_item.tx # cached tx is already up-to-date - - try: - cache_item.pending_verifications += 1 - return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local) - finally: - cache_item.pending_verifications -= 1 - - async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True): - - async with cache_item.lock: - - tx = cache_item.tx - - if tx is None and check_local: - # check local db - tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash) - - merkle = None - if tx is None: - # fetch from network - _raw, merkle = await self.network.retriable_call( - self.network.get_transaction_and_merkle, tx_hash, remote_height - ) - tx = Transaction(unhexlify(_raw), height=merkle.get('block_height')) - cache_item.tx = tx # make sure it's saved before caching it - await self.maybe_verify_transaction(tx, remote_height, merkle) - return tx - - async def maybe_verify_transaction(self, tx, remote_height, merkle=None): - tx.height = remote_height - cached = self._tx_cache.get(tx.hash) - if not cached: - # cache txs looked up by transaction_show too - cached = TransactionCacheItem() - cached.tx = tx - self._tx_cache[tx.hash] = cached - if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1: - # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case - if not merkle: - merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height) - merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) - header = await self.headers.get(remote_height) - tx.position = merkle['pos'] - tx.is_verified = merkle_root == header['merkle_root'] - - async def get_address_manager_for_address(self, address) -> Optional[AddressManager]: - details = await self.db.get_address(address=address) - for account in self.accounts: - if account.id == details['account']: - return account.address_managers[details['chain']] - return None - - def broadcast(self, tx): - # broadcast can't be a retriable call yet - return self.network.broadcast(hexlify(tx.raw).decode()) - - async def wait(self, tx: Transaction, height=-1, timeout=1): - addresses = set() - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - addresses.add( - self.hash160_to_address(txi.txo_ref.txo.pubkey_hash) - ) - for txo in tx.outputs: - if txo.has_address: - addresses.add(self.hash160_to_address(txo.pubkey_hash)) - records = await self.db.get_addresses(address__in=addresses) - _, pending = await asyncio.wait([ - self.on_transaction.where(partial( - lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id, - address_record['address'] - )) for address_record in records - ], timeout=timeout) - if pending: - records = await self.db.get_addresses(address__in=addresses) - for record in records: - found = False - local_history = (await 
self.get_local_status_and_history(
-                    record['address'], history=record['history']
-                ))[1] if record['history'] else []
-                for txid, local_height in local_history:
-                    if txid == tx.id and local_height >= height:
-                        found = True
-                if not found:
-                    print(record['history'], addresses, tx.id)
-            raise asyncio.TimeoutError('Timed out waiting for transaction.')
-
-    async def _inflate_outputs(
-            self, query, accounts,
-            include_purchase_receipt=False,
-            include_is_my_output=False,
-            include_sent_supports=False,
-            include_sent_tips=False,
-            include_received_tips=False) -> Tuple[List[Output], dict, int, int]:
-        encoded_outputs = await query
-        outputs = Outputs.from_base64(encoded_outputs or b'')  # TODO: why is the server returning None?
-        txs = []
-        if len(outputs.txs) > 0:
-            txs: List[Transaction] = await asyncio.gather(*(
-                self.cache_transaction(*tx) for tx in outputs.txs
-            ))
-
-        txos, blocked = outputs.inflate(txs)
-
-        includes = (
-            include_purchase_receipt, include_is_my_output,
-            include_sent_supports, include_sent_tips
-        )
-        if accounts and any(includes):
-            copies = []
-            receipts = {}
-            if include_purchase_receipt:
-                priced_claims = []
-                for txo in txos:
-                    if isinstance(txo, Output) and txo.has_price:
-                        priced_claims.append(txo)
-                if priced_claims:
-                    receipts = {
-                        txo.purchased_claim_id: txo for txo in
-                        await self.db.get_purchases(
-                            accounts=accounts,
-                            purchased_claim_hash__in=[c.claim_hash for c in priced_claims]
-                        )
-                    }
-            for txo in txos:
-                if isinstance(txo, Output) and txo.can_decode_claim:
-                    # transactions and outputs are cached and shared between wallets
-                    # we don't want to leak information between wallets so we add the
-                    # wallet specific metadata on throwaway copies of the txos
-                    txo_copy = copy.copy(txo)
-                    copies.append(txo_copy)
-                    if include_purchase_receipt:
-                        txo_copy.purchase_receipt = receipts.get(txo.claim_id)
-                    if include_is_my_output:
-                        mine = await self.db.get_txo_count(
-                            claim_id=txo.claim_id, txo_type__in=CLAIM_TYPES, is_my_output=True,
-                            is_spent=False, accounts=accounts
-                        )
-                        if mine:
-                            txo_copy.is_my_output = True
-                        else:
-                            txo_copy.is_my_output = False
-                    if include_sent_supports:
-                        supports = await self.db.get_txo_sum(
-                            claim_id=txo.claim_id, txo_type=TXO_TYPES['support'],
-                            is_my_input=True, is_my_output=True,
-                            is_spent=False, accounts=accounts
-                        )
-                        txo_copy.sent_supports = supports
-                    if include_sent_tips:
-                        tips = await self.db.get_txo_sum(
-                            claim_id=txo.claim_id, txo_type=TXO_TYPES['support'],
-                            is_my_input=True, is_my_output=False,
-                            accounts=accounts
-                        )
-                        txo_copy.sent_tips = tips
-                    if include_received_tips:
-                        tips = await self.db.get_txo_sum(
-                            claim_id=txo.claim_id, txo_type=TXO_TYPES['support'],
-                            is_my_input=False, is_my_output=True,
-                            accounts=accounts
-                        )
-                        txo_copy.received_tips = tips
-                else:
-                    copies.append(txo)
-            txos = copies
-        return txos, blocked, outputs.offset, outputs.total
-
-    async def resolve(self, accounts, urls, **kwargs):
-        resolve = partial(self.network.retriable_call, self.network.resolve)
-        urls_copy = list(urls)
-        txos = []
-        while urls_copy:
-            batch, urls_copy = urls_copy[:500], urls_copy[500:]
-            txos.extend((await self._inflate_outputs(resolve(batch), accounts, **kwargs))[0])
-        assert len(urls) == len(txos), "Mismatch between urls requested for resolve and responses received."
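# resolve() above fans the requested URLs out to the server in batches of 500
# and relies on response order matching request order, which is what the assert
# checks. A hypothetical caller (inside a coroutine, with `ledger`, `accounts`
# and `log` in scope) might consume the per-URL results like this:
results = await ledger.resolve(accounts, ['lbry://@channel/stream'])
txo = results['lbry://@channel/stream']
if isinstance(txo, dict) and 'error' in txo:
    log.warning('resolve failed: %s', txo['error']['text'])
else:
    log.info('resolved %s as claim %s', txo.claim_name, txo.claim_id)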
- result = {} - for url, txo in zip(urls, txos): - if txo: - if isinstance(txo, Output) and URL.parse(url).has_stream_in_channel: - if not txo.channel or not txo.is_signed_by(txo.channel, self): - txo = {'error': {'name': INVALID, 'text': f'{url} has invalid channel signature'}} - else: - txo = {'error': {'name': NOT_FOUND, 'text': f'{url} did not resolve to a claim'}} - result[url] = txo - return result - - async def claim_search( - self, accounts, include_purchase_receipt=False, include_is_my_output=False, - **kwargs) -> Tuple[List[Output], dict, int, int]: - return await self._inflate_outputs( - self.network.claim_search(**kwargs), accounts, - include_purchase_receipt=include_purchase_receipt, - include_is_my_output=include_is_my_output - ) - - async def get_claim_by_claim_id(self, accounts, claim_id, **kwargs) -> Output: - for claim in (await self.claim_search(accounts, claim_id=claim_id, **kwargs))[0]: - return claim - - async def _report_state(self): - try: - for account in self.accounts: - balance = dewies_to_lbc(await account.get_balance(include_claims=True)) - channel_count = await account.get_channel_count() - claim_count = await account.get_claim_count() - if isinstance(account.receiving, SingleKey): - log.info("Loaded single key account %s with %s LBC. " - "%d channels, %d certificates and %d claims", - account.id, balance, channel_count, len(account.channel_keys), claim_count) - else: - total_receiving = len(await account.receiving.get_addresses()) - total_change = len(await account.change.get_addresses()) - log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), " - "%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ", - account.id, balance, total_receiving, account.receiving.gap, total_change, - account.change.gap, channel_count, len(account.channel_keys), claim_count) - except Exception as err: - if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 - raise - log.exception( - 'Failed to display wallet state, please file issue ' - 'for this bug along with the traceback you see below:') - - async def _reset_balance_cache(self, e: TransactionEvent): - account_ids = [ - r['account'] for r in await self.db.get_addresses([AccountAddress.c.account], address=e.address) - ] - for account_id in account_ids: - if account_id in self._balance_cache: - del self._balance_cache[account_id] - - @staticmethod - def constraint_spending_utxos(constraints): - constraints['txo_type__in'] = (0, TXO_TYPES['purchase']) - - async def get_purchases(self, resolve=False, **constraints): - purchases = await self.db.get_purchases(**constraints) - if resolve: - claim_ids = [p.purchased_claim_id for p in purchases] - try: - resolved, _, _, _ = await self.claim_search([], claim_ids=claim_ids) - except Exception as err: - if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 - raise - log.exception("Resolve failed while looking up purchased claim ids:") - resolved = [] - lookup = {claim.claim_id: claim for claim in resolved} - for purchase in purchases: - purchase.purchased_claim = lookup.get(purchase.purchased_claim_id) - return purchases - - def get_purchase_count(self, resolve=False, **constraints): - return self.db.get_purchase_count(**constraints) - - async def _resolve_for_local_results(self, accounts, txos): - results = [] - response = await self.resolve( - accounts, [txo.permanent_url for txo in txos if txo.can_decode_claim] - ) - for txo in txos: - resolved = response.get(txo.permanent_url) if 
txo.can_decode_claim else None - if isinstance(resolved, Output): - resolved.update_annotations(txo) - results.append(resolved) - else: - if isinstance(resolved, dict) and 'error' in resolved: - txo.meta['error'] = resolved['error'] - results.append(txo) - return results - - async def get_claims(self, resolve=False, **constraints): - claims = await self.db.get_claims(**constraints) - if resolve: - return await self._resolve_for_local_results(constraints.get('accounts', []), claims) - return claims - - def get_claim_count(self, **constraints): - return self.db.get_claim_count(**constraints) - - async def get_streams(self, resolve=False, **constraints): - streams = await self.db.get_streams(**constraints) - if resolve: - return await self._resolve_for_local_results(constraints.get('accounts', []), streams) - return streams - - def get_stream_count(self, **constraints): - return self.db.get_stream_count(**constraints) - - async def get_channels(self, resolve=False, **constraints): - channels = await self.db.get_channels(**constraints) - if resolve: - return await self._resolve_for_local_results(constraints.get('accounts', []), channels) - return channels - - def get_channel_count(self, **constraints): - return self.db.get_channel_count(**constraints) - - async def resolve_collection(self, collection, offset=0, page_size=1): - claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset] - try: - resolve_results, _, _, _ = await self.claim_search([], claim_ids=claim_ids) - except Exception as err: - if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 - raise - log.exception("Resolve failed while looking up collection claim ids:") - return [] - claims = [] - for claim_id in claim_ids: - found = False - for txo in resolve_results: - if txo.claim_id == claim_id: - claims.append(txo) - found = True - break - if not found: - claims.append(None) - return claims - - async def get_collections(self, resolve_claims=0, **constraints): - collections = await self.db.get_collections(**constraints) - if resolve_claims > 0: - for collection in collections: - collection.claims = await self.resolve_collection(collection, page_size=resolve_claims) - return collections - - def get_collection_count(self, resolve_claims=0, **constraints): - return self.db.get_collection_count(**constraints) - - def get_supports(self, **constraints): - return self.db.get_supports(**constraints) - - def get_support_count(self, **constraints): - return self.db.get_support_count(**constraints) - - async def get_transaction_history(self, **constraints): - txs: List[Transaction] = await self.db.get_transactions( - include_is_my_output=True, include_is_spent=True, - **constraints - ) - headers = self.headers - history = [] - for tx in txs: # pylint: disable=too-many-nested-blocks - ts = headers.estimated_timestamp(tx.height) - item = { - 'txid': tx.id, - 'timestamp': ts, - 'date': datetime.fromtimestamp(ts).isoformat(' ')[:-3] if tx.height > 0 else None, - 'confirmations': (headers.height+1) - tx.height if tx.height > 0 else 0, - 'claim_info': [], - 'update_info': [], - 'support_info': [], - 'abandon_info': [], - 'purchase_info': [] - } - is_my_inputs = all([txi.is_my_input for txi in tx.inputs]) - if is_my_inputs: - # fees only matter if we are the ones paying them - item['value'] = dewies_to_lbc(tx.net_account_balance+tx.fee) - item['fee'] = dewies_to_lbc(-tx.fee) - else: - # someone else paid the fees - item['value'] = dewies_to_lbc(tx.net_account_balance) - item['fee'] = '0.0' - for txo in 
tx.my_claim_outputs: - item['claim_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(-txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - for txo in tx.my_update_outputs: - if is_my_inputs: # updating my own claim - previous = None - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - other_txo = txi.txo_ref.txo - if (other_txo.is_claim or other_txo.script.is_support_claim) \ - and other_txo.claim_id == txo.claim_id: - previous = other_txo - break - if previous is not None: - item['update_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(previous.amount-txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - else: # someone sent us their claim - item['update_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(0), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - for txo in tx.my_support_outputs: - item['support_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'is_tip': not is_my_inputs, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - if is_my_inputs: - for txo in tx.other_support_outputs: - item['support_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(-txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'is_tip': is_my_inputs, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - for txo in tx.my_abandon_outputs: - item['abandon_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'nout': txo.position - }) - for txo in tx.any_purchase_outputs: - item['purchase_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.purchased_claim_id, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - history.append(item) - return history - - def get_transaction_history_count(self, **constraints): - return self.db.get_transaction_count(**constraints) - - async def get_detailed_balance(self, accounts, confirmations=0): - result = { - 'total': 0, - 'available': 0, - 'reserved': 0, - 'reserved_subtotals': { - 'claims': 0, - 'supports': 0, - 'tips': 0 - } - } - for account in accounts: - balance = self._balance_cache.get(account.id) - if not balance: - balance = self._balance_cache[account.id] =\ - await account.get_detailed_balance(confirmations, reserved_subtotals=True) - for key, value in balance.items(): - if key == 'reserved_subtotals': - for subkey, subvalue in value.items(): - result['reserved_subtotals'][subkey] += subvalue - else: - result[key] += value - return result - class TestNetLedger(Ledger): network_name = 'testnet' diff --git a/lbry/blockchain/script.py b/lbry/blockchain/script.py index 393852619..b9f1b875e 100644 --- a/lbry/blockchain/script.py +++ 
b/lbry/blockchain/script.py @@ -294,20 +294,25 @@ class Template: class Script: - __slots__ = 'source', '_template', '_values', '_template_hint' + __slots__ = 'source', 'offset', '_template', '_values', '_template_hint' templates: List[Template] = [] NO_SCRIPT = Template('no_script', None) # special case - def __init__(self, source=None, template=None, values=None, template_hint=None): + def __init__(self, source=None, template=None, values=None, template_hint=None, offset=None): self.source = source + self.offset = offset self._template = template self._values = values self._template_hint = template_hint if source is None and template and values: self.generate() + @property + def length(self): + return len(self.source) + @property def template(self): if self._template is None: diff --git a/lbry/blockchain/sync.py b/lbry/blockchain/sync.py index 0f65d14a4..0f6efb274 100644 --- a/lbry/blockchain/sync.py +++ b/lbry/blockchain/sync.py @@ -1,111 +1,220 @@ import os import asyncio import logging -from threading import Thread -from multiprocessing import Queue, Event -from concurrent import futures +import multiprocessing as mp +from contextvars import ContextVar +from typing import Tuple, Optional +from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor -from lbry.wallet.stream import StreamController, EventQueuePublisher -from lbry.db import Database +from sqlalchemy import func, bindparam +from sqlalchemy.future import select + +from lbry.event import EventController, BroadcastSubscription +from lbry.service.base import Service, Sync, BlockEvent +from lbry.db import ( + queries, TXO_TYPES, Claim, Claimtrie, TX, TXO, TXI, Block as BlockTable, +) from .lbrycrd import Lbrycrd -from . import worker +from .block import Block, create_block_filter +from .bcd_data_stream import BCDataStream +from .ledger import Ledger log = logging.getLogger(__name__) +_context: ContextVar[Tuple[Lbrycrd, mp.Queue, mp.Event]] = ContextVar('ctx') -class ProgressMonitorThread(Thread): +def ctx(): + return _context.get() - STOP = 'stop' - FORMAT = '{l_bar}{bar}| {n_fmt:>6}/{total_fmt:>7} [{elapsed}<{remaining:>5}, {rate_fmt:>15}]' - def __init__(self, state: dict, queue: Queue, stream_controller: StreamController): - super().__init__() - self.state = state - self.queue = queue - self.stream_controller = stream_controller - self.loop = asyncio.get_event_loop() +def initialize(url: str, ledger: Ledger, progress: mp.Queue, stop: mp.Event, track_metrics=False): + chain = Lbrycrd(ledger) + chain.db.sync_open() + _context.set((chain, progress, stop)) + queries.initialize(url=url, ledger=ledger, track_metrics=track_metrics) - def run(self): - asyncio.set_event_loop(self.loop) - while True: - msg = self.queue.get() - if msg == self.STOP: + +def process_block_file(block_file_number): + chain, progress, stop = ctx() + block_file_path = chain.get_block_file_path_from_number(block_file_number) + num = 0 + progress.put_nowait((block_file_number, 1, num)) + best_height = queries.get_best_height() + best_block_processed = -1 + collector = queries.RowCollector(queries.ctx()) + with open(block_file_path, 'rb') as fp: + stream = BCDataStream(fp=fp) + for num, block_info in enumerate(chain.db.sync_get_file_details(block_file_number), start=1): + if stop.is_set(): return - self.stream_controller.add(msg) - - def shutdown(self): - self.queue.put(self.STOP) - self.join() - - def __enter__(self): - self.start() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.shutdown() + if num % 100 == 0: + 
progress.put_nowait((block_file_number, 1, num)) + fp.seek(block_info['data_offset']) + block = Block.from_data_stream(stream, block_info['height'], block_file_number) + if block.height <= best_height: + continue + best_block_processed = max(block.height, best_block_processed) + collector.add_block(block) + collector.save(lambda remaining, total: progress.put((block_file_number, 2, remaining, total))) + return best_block_processed -class BlockchainSync: +def process_claimtrie(): + execute = queries.ctx().execute + chain, progress, stop = ctx() - def __init__(self, chain: Lbrycrd, db: Database, use_process_pool=False): + execute(Claimtrie.delete()) + for record in chain.db.sync_get_claimtrie(): + execute( + Claimtrie.insert(), { + 'normalized': record['normalized'], + 'claim_hash': record['claim_hash'], + 'last_take_over_height': record['last_take_over_height'], + } + ) + + best_height = queries.get_best_height() + + for record in chain.db.sync_get_claims(): + execute( + Claim.update() + .where(Claim.c.claim_hash == record['claim_hash']) + .values( + activation_height=record['activation_height'], + expiration_height=record['expiration_height'] + ) + ) + + support = TXO.alias('support') + effective_amount_update = ( + Claim.update() + .where(Claim.c.activation_height <= best_height) + .values( + effective_amount=( + select(func.coalesce(func.sum(support.c.amount), 0) + Claim.c.amount) + .select_from(support).where( + (support.c.claim_hash == Claim.c.claim_hash) & + (support.c.txo_type == TXO_TYPES['support']) & + (support.c.txo_hash.notin_(select(TXI.c.txo_hash))) + ).scalar_subquery() + ) + ) + ) + execute(effective_amount_update) + + +def process_block_and_tx_filters(): + execute = queries.ctx().execute + + blocks = [] + for block in queries.get_blocks_without_filters(): + block_filter = create_block_filter( + {r['address'] for r in queries.get_block_tx_addresses(block_hash=block['block_hash'])} + ) + blocks.append({'pk': block['block_hash'], 'block_filter': block_filter}) + execute(BlockTable.update().where(BlockTable.c.block_hash == bindparam('pk')), blocks) + + txs = [] + for tx in queries.get_transactions_without_filters(): + tx_filter = create_block_filter( + {r['address'] for r in queries.get_block_tx_addresses(tx_hash=tx['tx_hash'])} + ) + txs.append({'pk': tx['tx_hash'], 'tx_filter': tx_filter}) + execute(TX.update().where(TX.c.tx_hash == bindparam('pk')), txs) + + +class BlockchainSync(Sync): + + def __init__(self, service: Service, chain: Lbrycrd, multiprocess=False): + super().__init__(service) self.chain = chain - self.db = db - self.use_process_pool = use_process_pool - self._on_progress_controller = StreamController() + self.message_queue = mp.Queue() + self.stop_event = mp.Event() + self.on_block_subscription: Optional[BroadcastSubscription] = None + self.advance_loop_task: Optional[asyncio.Task] = None + self.advance_loop_event = asyncio.Event() + self.executor = self._create_executor(multiprocess) + self._on_progress_controller = EventController() self.on_progress = self._on_progress_controller.stream - def get_worker_pool(self, queue, full_stop) -> futures.Executor: + def _create_executor(self, multiprocess) -> Executor: args = dict( - initializer=worker.initializer, - initargs=(self.chain.data_dir, self.chain.regtest, self.db.db_path, queue, full_stop) + initializer=initialize, + initargs=( + self.service.db.url, self.chain.ledger, + self.message_queue, self.stop_event + ) ) - if not self.use_process_pool: - return futures.ThreadPoolExecutor(max_workers=1, **args) - 
return futures.ProcessPoolExecutor(max_workers=max(os.cpu_count()-1, 4), **args) + if multiprocess: + return ProcessPoolExecutor( + max_workers=max(os.cpu_count() - 1, 4), **args + ) + else: + return ThreadPoolExecutor( + max_workers=1, **args + ) + + async def start(self): + await self.advance() + self.chain.subscribe() + self.advance_loop_task = asyncio.create_task(self.advance_loop()) + self.on_block_subscription = self.chain.on_block.listen( + lambda e: self.advance_loop_event.set() + ) + + async def stop(self): + self.chain.unsubscribe() + if self.on_block_subscription is not None: + self.on_block_subscription.cancel() + self.stop_event.set() + self.advance_loop_task.cancel() + self.executor.shutdown() async def load_blocks(self): - jobs = [] - queue, full_stop = Queue(), Event() - executor = self.get_worker_pool(queue, full_stop) - files = list(await self.chain.get_block_files_not_synced()) - state = { - file.file_number: { - 'status': worker.PENDING, - 'done_txs': 0, - 'total_txs': file.txs, - 'done_blocks': 0, - 'total_blocks': file.blocks, - } for file in files - } - progress = EventQueuePublisher(queue, self._on_progress_controller) - progress.start() + tasks = [] + for file in await self.chain.db.get_block_files(): + tasks.append(asyncio.get_running_loop().run_in_executor( + self.executor, process_block_file, file['file_number'] + )) + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_EXCEPTION + ) + if pending: + self.stop_event.set() + for future in pending: + future.cancel() + return max(f.result() for f in done) - def cancel_all_the_things(): - for job in jobs: - job.cancel() - full_stop.set() - for job in jobs: - exception = job.exception() - if exception is not None: - log.exception(exception) - raise exception + async def process_claims(self): + await asyncio.get_event_loop().run_in_executor( + self.executor, queries.process_claims_and_supports + ) - try: + async def process_block_and_tx_filters(self): + await asyncio.get_event_loop().run_in_executor( + self.executor, process_block_and_tx_filters + ) - for file in files: - jobs.append(executor.submit(worker.process_block_file, file.file_number)) + async def process_claimtrie(self): + await asyncio.get_event_loop().run_in_executor( + self.executor, process_claimtrie + ) - done, not_done = await asyncio.get_event_loop().run_in_executor( - None, futures.wait, jobs, None, futures.FIRST_EXCEPTION - ) - if not_done: - cancel_all_the_things() + async def post_process(self): + await self.process_claims() + if self.service.conf.spv_address_filters: + await self.process_block_and_tx_filters() + await self.process_claimtrie() - except asyncio.CancelledError: - cancel_all_the_things() - raise + async def advance(self): + best_height = await self.load_blocks() + await self.post_process() + await self._on_block_controller.add(BlockEvent(best_height)) - finally: - progress.stop() - executor.shutdown() + async def advance_loop(self): + while True: + await self.advance_loop_event.wait() + self.advance_loop_event.clear() + await self.advance() diff --git a/lbry/blockchain/testing.py b/lbry/blockchain/testing.py new file mode 100644 index 000000000..2eee89d6a --- /dev/null +++ b/lbry/blockchain/testing.py @@ -0,0 +1,78 @@ +import os +import sqlite3 +import asyncio +from typing import List + +from .block import Block +from .lbrycrd import Lbrycrd + + +def sync_create_lbrycrd_databases(dir_path: str): + for file_name, ddl in DDL.items(): + connection = sqlite3.connect(os.path.join(dir_path, file_name)) + 
connection.executescript(ddl)
+        connection.close()
+
+
+async def create_lbrycrd_databases(dir_path: str):
+    await asyncio.get_running_loop().run_in_executor(
+        None, sync_create_lbrycrd_databases, dir_path
+    )
+
+
+async def add_block_to_lbrycrd(chain: Lbrycrd, block: Block, takeovers: List[str]):
+    for tx in block.txs:
+        for txo in tx.outputs:
+            if txo.is_claim:
+                await insert_claim(chain, block, tx, txo)
+            if txo.id in takeovers:
+                await insert_takeover(chain, block, tx, txo)
+
+
+async def insert_claim(chain, block, tx, txo):
+    await chain.db.execute("""
+        INSERT OR REPLACE INTO claim (
+            claimID, name, nodeName, txID, txN, originalHeight, updateHeight, validHeight,
+            activationHeight, expirationHeight, amount
+        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 10000, ?)
+        """, (
+            txo.claim_hash, txo.claim_name, txo.claim_name, tx.hash, txo.position,
+            block.height, block.height, block.height, block.height, txo.amount
+        )
+    )
+
+
+async def insert_takeover(chain, block, tx, txo):
+    # the takeover table requires name and height to be NOT NULL
+    await chain.db.execute(
+        "INSERT INTO takeover (name, claimID, height) VALUES (?, ?, ?)",
+        (txo.claim_name, txo.claim_hash, block.height)
+    )
+
+
+# These are extracted by opening each of lbrycrd's latest sqlite databases and
+# running the '.schema' command.
+DDL = {
+    'claims.sqlite': """
+        CREATE TABLE node (name BLOB NOT NULL PRIMARY KEY, parent BLOB REFERENCES node(name) DEFERRABLE INITIALLY DEFERRED, hash BLOB);
+        CREATE TABLE claim (claimID BLOB NOT NULL PRIMARY KEY, name BLOB NOT NULL, nodeName BLOB NOT NULL REFERENCES node(name) DEFERRABLE INITIALLY DEFERRED, txID BLOB NOT NULL, txN INTEGER NOT NULL, originalHeight INTEGER NOT NULL, updateHeight INTEGER NOT NULL, validHeight INTEGER NOT NULL, activationHeight INTEGER NOT NULL, expirationHeight INTEGER NOT NULL, amount INTEGER NOT NULL);
+        CREATE TABLE support (txID BLOB NOT NULL, txN INTEGER NOT NULL, supportedClaimID BLOB NOT NULL, name BLOB NOT NULL, nodeName BLOB NOT NULL, blockHeight INTEGER NOT NULL, validHeight INTEGER NOT NULL, activationHeight INTEGER NOT NULL, expirationHeight INTEGER NOT NULL, amount INTEGER NOT NULL, PRIMARY KEY(txID, txN));
+        CREATE TABLE takeover (name BLOB NOT NULL, height INTEGER NOT NULL, claimID BLOB, PRIMARY KEY(name, height DESC));
+        CREATE INDEX node_hash_len_name ON node (hash, LENGTH(name) DESC);
+        CREATE INDEX node_parent ON node (parent);
+        CREATE INDEX takeover_height ON takeover (height);
+        CREATE INDEX claim_activationHeight ON claim (activationHeight);
+        CREATE INDEX claim_expirationHeight ON claim (expirationHeight);
+        CREATE INDEX claim_nodeName ON claim (nodeName);
+        CREATE INDEX support_supportedClaimID ON support (supportedClaimID);
+        CREATE INDEX support_activationHeight ON support (activationHeight);
+        CREATE INDEX support_expirationHeight ON support (expirationHeight);
+        CREATE INDEX support_nodeName ON support (nodeName);
+    """,
+    'block_index.sqlite': """
+        CREATE TABLE block_file (file INTEGER NOT NULL PRIMARY KEY, blocks INTEGER NOT NULL, size INTEGER NOT NULL, undoSize INTEGER NOT NULL, heightFirst INTEGER NOT NULL, heightLast INTEGER NOT NULL, timeFirst INTEGER NOT NULL, timeLast INTEGER NOT NULL );
+        CREATE TABLE block_info (hash BLOB NOT NULL PRIMARY KEY, prevHash BLOB NOT NULL, height INTEGER NOT NULL, file INTEGER NOT NULL, dataPos INTEGER NOT NULL, undoPos INTEGER NOT NULL, txCount INTEGER NOT NULL, status INTEGER NOT NULL, version INTEGER NOT NULL, rootTxHash BLOB NOT NULL, rootTrieHash BLOB NOT NULL, time INTEGER NOT NULL, bits INTEGER NOT NULL, nonce INTEGER NOT NULL );
+        CREATE TABLE tx_to_block (txID BLOB NOT NULL PRIMARY KEY, file INTEGER NOT
NULL, blockPos INTEGER NOT NULL, txPos INTEGER NOT NULL); + CREATE TABLE flag (name TEXT NOT NULL PRIMARY KEY, value INTEGER NOT NULL); + CREATE INDEX block_info_height ON block_info (height); + """, +} diff --git a/lbry/blockchain/transaction.py b/lbry/blockchain/transaction.py index fa3802049..b4ad692eb 100644 --- a/lbry/blockchain/transaction.py +++ b/lbry/blockchain/transaction.py @@ -1,9 +1,8 @@ import struct import hashlib import logging -import typing from binascii import hexlify, unhexlify -from typing import List, Iterable, Optional, Tuple +from typing import List, Iterable, Optional import ecdsa from cryptography.hazmat.backends import default_backend @@ -13,7 +12,6 @@ from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric.utils import Prehashed from cryptography.exceptions import InvalidSignature -from lbry.error import InsufficientFundsError from lbry.crypto.hash import hash160, sha256 from lbry.crypto.base58 import Base58 from lbry.schema.url import normalize_name @@ -21,16 +19,10 @@ from lbry.schema.claim import Claim from lbry.schema.purchase import Purchase from .script import InputScript, OutputScript -from .constants import COIN, NULL_HASH32 from .bcd_data_stream import BCDataStream from .hash import TXRef, TXRefImmutable from .util import ReadOnlyList -if typing.TYPE_CHECKING: - from lbry.wallet.account import Account - from lbry.wallet.ledger import Ledger - from lbry.wallet.wallet import Wallet - log = logging.getLogger() @@ -190,20 +182,6 @@ class Input(InputOutput): stream.write_uint32(self.sequence) -class OutputEffectiveAmountEstimator: - - __slots__ = 'txo', 'txi', 'fee', 'effective_amount' - - def __init__(self, ledger: 'Ledger', txo: 'Output') -> None: - self.txo = txo - self.txi = Input.spend(txo) - self.fee: int = self.txi.get_fee(ledger) - self.effective_amount: int = txo.amount - self.fee - - def __lt__(self, other): - return self.effective_amount < other.effective_amount - - class Output(InputOutput): __slots__ = ( @@ -283,18 +261,15 @@ class Output(InputOutput): def get_address(self, ledger): return ledger.hash160_to_address(self.pubkey_hash) - def get_estimator(self, ledger): - return OutputEffectiveAmountEstimator(ledger, self) - @classmethod def pay_pubkey_hash(cls, amount, pubkey_hash): return cls(amount, OutputScript.pay_pubkey_hash(pubkey_hash)) @classmethod - def deserialize_from(cls, stream): + def deserialize_from(cls, stream, offset): return cls( amount=stream.read_uint64(), - script=OutputScript(stream.read_string()) + script=OutputScript(stream.read_string(), offset=offset+9) ) def serialize_to(self, stream, alternate_script=None): @@ -525,7 +500,7 @@ class Transaction: self.position = position self._day = julian_day if raw is not None: - self._deserialize() + self.deserialize() @property def is_broadcast(self): @@ -685,9 +660,10 @@ class Transaction: stream.write_uint32(self.signature_hash_type(1)) # signature hash type: SIGHASH_ALL return stream.get_bytes() - def _deserialize(self, stream=None): + def deserialize(self, stream=None): if self._raw is not None or stream is not None: stream = stream or BCDataStream(self._raw) + start = stream.tell() self.version = stream.read_uint32() input_count = stream.read_compact_size() if input_count == 0: @@ -698,7 +674,7 @@ class Transaction: ]) output_count = stream.read_compact_size() self._add(self._outputs, [ - Output.deserialize_from(stream) for _ in range(output_count) + Output.deserialize_from(stream, stream.tell()-start) for _ in range(output_count) 
]) if self.is_segwit_flag: # drain witness portion of transaction @@ -710,180 +686,10 @@ class Transaction: self.locktime = stream.read_uint32() return self - @classmethod - def ensure_all_have_same_ledger_and_wallet( - cls, funding_accounts: Iterable['Account'], - change_account: 'Account' = None) -> Tuple['Ledger', 'Wallet']: - ledger = wallet = None - for account in funding_accounts: - if ledger is None: - ledger = account.ledger - wallet = account.wallet - if ledger != account.ledger: - raise ValueError( - 'All funding accounts used to create a transaction must be on the same ledger.' - ) - if wallet != account.wallet: - raise ValueError( - 'All funding accounts used to create a transaction must be from the same wallet.' - ) - if change_account is not None: - if change_account.ledger != ledger: - raise ValueError('Change account must use same ledger as funding accounts.') - if change_account.wallet != wallet: - raise ValueError('Change account must use same wallet as funding accounts.') - if ledger is None: - raise ValueError('No ledger found.') - if wallet is None: - raise ValueError('No wallet found.') - return ledger, wallet - - @classmethod - async def create(cls, inputs: Iterable[Input], outputs: Iterable[Output], - funding_accounts: Iterable['Account'], change_account: 'Account', - sign: bool = True): - """ Find optimal set of inputs when only outputs are provided; add change - outputs if only inputs are provided or if inputs are greater than outputs. """ - - tx = cls() \ - .add_inputs(inputs) \ - .add_outputs(outputs) - - ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) - - # value of the outputs plus associated fees - cost = ( - tx.get_base_fee(ledger) + - tx.get_total_output_sum(ledger) - ) - # value of the inputs less the cost to spend those inputs - payment = tx.get_effective_input_sum(ledger) - - try: - - for _ in range(5): - - if payment < cost: - deficit = cost - payment - spendables = await ledger.get_spendable_utxos(deficit, funding_accounts) - if not spendables: - raise InsufficientFundsError() - payment += sum(s.effective_amount for s in spendables) - tx.add_inputs(s.txi for s in spendables) - - cost_of_change = ( - tx.get_base_fee(ledger) + - Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger) - ) - if payment > cost: - change = payment - cost - if change > cost_of_change: - change_address = await change_account.change.get_or_create_usable_address() - change_hash160 = change_account.ledger.address_to_hash160(change_address) - change_amount = change - cost_of_change - change_output = Output.pay_pubkey_hash(change_amount, change_hash160) - change_output.is_internal_transfer = True - tx.add_outputs([Output.pay_pubkey_hash(change_amount, change_hash160)]) - - if tx._outputs: - break - # this condition and the outer range(5) loop cover an edge case - # whereby a single input is just enough to cover the fee and - # has some change left over, but the change left over is less - # than the cost_of_change: thus the input is completely - # consumed and no output is added, which is an invalid tx. - # to be able to spend this input we must increase the cost - # of the TX and run through the balance algorithm a second time - # adding an extra input and change output, making tx valid. 
- # we do this 5 times in case the other UTXOs added are also - # less than the fee, after 5 attempts we give up and go home - cost += cost_of_change + 1 - - if sign: - await tx.sign(funding_accounts) - - except Exception as e: - log.exception('Failed to create transaction:') - await ledger.release_tx(tx) - raise e - - return tx - @staticmethod def signature_hash_type(hash_type): return hash_type - async def sign(self, funding_accounts: Iterable['Account']): - ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts) - for i, txi in enumerate(self._inputs): - assert txi.script is not None - assert txi.txo_ref.txo is not None - txo_script = txi.txo_ref.txo.script - if txo_script.is_pay_pubkey_hash: - address = ledger.hash160_to_address(txo_script.values['pubkey_hash']) - private_key = await ledger.get_private_key_for_address(wallet, address) - assert private_key is not None, 'Cannot find private key for signing output.' - tx = self._serialize_for_signature(i) - txi.script.values['signature'] = \ - private_key.sign(tx) + bytes((self.signature_hash_type(1),)) - txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes - txi.script.generate() - else: - raise NotImplementedError("Don't know how to spend this output.") - self._reset() - - @classmethod - def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'): - ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) - output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address)) - return cls.create([], [output], funding_accounts, change_account) - - @classmethod - def claim_create( - cls, name: str, claim: Claim, amount: int, holding_address: str, - funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): - ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) - claim_output = Output.pay_claim_name_pubkey_hash( - amount, name, claim, ledger.address_to_hash160(holding_address) - ) - if signing_channel is not None: - claim_output.sign(signing_channel, b'placeholder txid:nout') - return cls.create([], [claim_output], funding_accounts, change_account, sign=False) - - @classmethod - def claim_update( - cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str, - funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): - ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) - updated_claim = Output.pay_update_claim_pubkey_hash( - amount, previous_claim.claim_name, previous_claim.claim_id, - claim, ledger.address_to_hash160(holding_address) - ) - if signing_channel is not None: - updated_claim.sign(signing_channel, b'placeholder txid:nout') - else: - updated_claim.clear_signature() - return cls.create( - [Input.spend(previous_claim)], [updated_claim], funding_accounts, change_account, sign=False - ) - - @classmethod - def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str, - funding_accounts: List['Account'], change_account: 'Account'): - ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) - support_output = Output.pay_support_pubkey_hash( - amount, claim_name, claim_id, ledger.address_to_hash160(holding_address) - ) - return cls.create([], [support_output], funding_accounts, change_account) - - @classmethod - def purchase(cls, claim_id: str, amount: int, merchant_address: bytes, - funding_accounts: 
List['Account'], change_account: 'Account'): - ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) - payment = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(merchant_address)) - data = Output.add_purchase_data(Purchase(claim_id)) - return cls.create([], [payment, data], funding_accounts, change_account) - @property def my_inputs(self): for txi in self.inputs: diff --git a/lbry/blockchain/util.py b/lbry/blockchain/util.py index a9504bff1..2ae6e4ed8 100644 --- a/lbry/blockchain/util.py +++ b/lbry/blockchain/util.py @@ -1,6 +1,6 @@ import re from typing import TypeVar, Sequence, Optional -from .constants import COIN +from lbry.constants import COIN def coins_to_satoshis(coins): @@ -40,18 +40,6 @@ def subclass_tuple(name, base): return type(name, (base,), {'__slots__': ()}) -class cachedproperty: - - def __init__(self, f): - self.f = f - - def __get__(self, obj, objtype): - obj = obj or objtype - value = self.f(obj) - setattr(obj, self.f.__name__, value) - return value - - class ArithUint256: # https://github.com/bitcoin/bitcoin/blob/master/src/arith_uint256.cpp diff --git a/lbry/blockchain/worker.py b/lbry/blockchain/worker.py deleted file mode 100644 index 7eadade4b..000000000 --- a/lbry/blockchain/worker.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Optional -from contextvars import ContextVar -from multiprocessing import Queue, Event -from dataclasses import dataclass -from itertools import islice - -from lbry.wallet.bcd_data_stream import BCDataStream -from lbry.db import Database -from .lbrycrd import Lbrycrd -from .block import Block - - -PENDING = 'pending' -RUNNING = 'running' -STOPPED = 'stopped' - - -def chunk(rows, step): - it, total = iter(rows), len(rows) - for _ in range(0, total, step): - yield min(step, total), islice(it, step) - total -= step - - -@dataclass -class WorkerContext: - lbrycrd: Lbrycrd - db: Database - progress: Queue - stop: Event - - -context: ContextVar[Optional[WorkerContext]] = ContextVar('context') - - -def initializer(data_dir: str, regtest: bool, db_path: str, progress: Queue, stop: Event): - context.set(WorkerContext( - lbrycrd=Lbrycrd(data_dir, regtest), - db=Database(db_path).sync_open(), - progress=progress, - stop=stop - )) - - -def process_block_file(block_file_number): - ctx: WorkerContext = context.get() - lbrycrd, db, progress, stop = ctx.lbrycrd, ctx.db, ctx.progress, ctx.stop - block_file_path = lbrycrd.get_block_file_path_from_number(block_file_number) - num = 0 - progress.put_nowait((block_file_number, 1, num)) - with open(block_file_path, 'rb') as fp: - stream = BCDataStream(fp=fp) - blocks, txs, claims, supports, spends = [], [], [], [], [] - for num, block_info in enumerate(lbrycrd.db.get_file_details(block_file_number), start=1): - if stop.is_set(): - return - if num % 100 == 0: - progress.put_nowait((block_file_number, 1, num)) - fp.seek(block_info['data_offset']) - block = Block(stream) - for tx in block.txs: - txs.append((block.block_hash, tx.position, tx.hash)) - for txi in tx.inputs: - if not txi.is_coinbase: - spends.append((block.block_hash, tx.hash, txi.txo_ref.hash)) - for output in tx.outputs: - try: - if output.is_support: - supports.append(( - block.block_hash, tx.hash, output.ref.hash, output.claim_hash, output.amount - )) - elif output.script.is_claim_name: - claims.append(( - block.block_hash, tx.hash, tx.position, output.ref.hash, output.claim_hash, - output.claim_name, 1, output.amount, None, None - )) - elif output.script.is_update_claim: - claims.append(( - 
block.block_hash, tx.hash, tx.position, output.ref.hash, output.claim_hash, - output.claim_name, 2, output.amount, None, None - )) - except Exception: - pass - blocks.append( - (block.block_hash, block.prev_block_hash, block_file_number, 0 if block.is_first_block else None) - ) - - progress.put((block_file_number, 1, num)) - - queries = ( - ("insert into block values (?, ?, ?, ?)", blocks), - ("insert into tx values (?, ?, ?)", txs), - ("insert into txi values (?, ?, ?)", spends), - ("insert into support values (?, ?, ?, ?, ?)", supports), - ("insert into claim_history values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", claims), - ) - total_txs = len(txs) - done_txs = 0 - step = int(sum(len(q[1]) for q in queries)/total_txs) - progress.put((block_file_number, 2, done_txs)) - for sql, rows in queries: - for chunk_size, chunk_rows in chunk(rows, 10000): - db.sync_executemany(sql, chunk_rows) - done_txs += int(chunk_size/step) - progress.put((block_file_number, 2, done_txs)) - progress.put((block_file_number, 2, total_txs)) diff --git a/tests/integration/blockchain/test_blockchain.py b/tests/integration/blockchain/test_blockchain.py index 0aec89942..d225f7264 100644 --- a/tests/integration/blockchain/test_blockchain.py +++ b/tests/integration/blockchain/test_blockchain.py @@ -1,35 +1,40 @@ import os import time import asyncio -import logging -from unittest import skip -from binascii import unhexlify, hexlify +import shutil +import tempfile +from binascii import hexlify, unhexlify from random import choice -from lbry.testcase import AsyncioTestCase -from lbry.crypto.base58 import Base58 -from lbry.blockchain import Lbrycrd, BlockchainSync +from lbry.conf import Config from lbry.db import Database -from lbry.blockchain.block import Block +from lbry.crypto.base58 import Base58 from lbry.schema.claim import Stream -from lbry.wallet.transaction import Transaction, Output -from lbry.wallet.constants import CENT -from lbry.wallet.bcd_data_stream import BCDataStream +from lbry.blockchain.lbrycrd import Lbrycrd +from lbry.blockchain.dewies import dewies_to_lbc, lbc_to_dewies +from lbry.blockchain.transaction import Transaction, Output +from lbry.constants import CENT +from lbry.blockchain.ledger import RegTestLedger +from lbry.testcase import AsyncioTestCase -#logging.getLogger('lbry.blockchain').setLevel(logging.DEBUG) -log = logging.getLogger(__name__) +from lbry.service.full_node import FullNode +from lbry.service.light_client import LightClient +from lbry.service.daemon import Daemon +from lbry.service.api import Client -@skip -class TestBlockchain(AsyncioTestCase): +class BlockchainTestCase(AsyncioTestCase): async def asyncSetUp(self): await super().asyncSetUp() - #self.chain = Lbrycrd.temp_regtest() - self.chain = Lbrycrd('/tmp/tmp0429f0ku/', True)#.temp_regtest() + self.chain = Lbrycrd.temp_regtest() + self.ledger = self.chain.ledger await self.chain.ensure() await self.chain.start('-maxblockfilesize=8', '-rpcworkqueue=128') - self.addCleanup(self.chain.stop, False) + self.addCleanup(self.chain.stop) + + +class TestEvents(BlockchainTestCase): async def test_block_event(self): msgs = [] @@ -50,29 +55,45 @@ class TestBlockchain(AsyncioTestCase): res = await self.chain.generate(3) await self.chain.on_block.where(lambda e: e['msg'] == 9) self.assertEqual(3, len(res)) - self.assertEqual([0, 1, 2, 3, 4, 7, 8, 9], msgs) + self.assertEqual([0, 1, 2, 3, 4, 7, 8, 9], msgs) # 5, 6 "missed" - async def test_sync(self): - if False: - names = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 
'ten'] - await self.chain.generate(101) - address = Base58.decode(await self.chain.get_new_address()) - for _ in range(190): - tx = Transaction().add_outputs([ - Output.pay_claim_name_pubkey_hash( - CENT, f'{choice(names)}{i}', - Stream().update( - title='a claim title', - description='Lorem ipsum '*400, - tags=['crypto', 'health', 'space'], - ).claim, - address) - for i in range(1, 20) - ]) - funded = await self.chain.fund_raw_transaction(hexlify(tx.raw).decode()) - signed = await self.chain.sign_raw_transaction_with_wallet(funded['hex']) - await self.chain.send_raw_transaction(signed['hex']) - await self.chain.generate(1) + +class TestBlockchainSync(BlockchainTestCase): + + async def asyncSetUp(self): + await super().asyncSetUp() + self.service = FullNode( + self.chain.ledger, f'sqlite:///{self.chain.data_dir}/lbry.db', self.chain + ) + self.service.conf.spv_address_filters = False + self.sync = self.service.sync + self.db = self.service.db + await self.db.open() + self.addCleanup(self.db.close) + + async def test_multi_block_file_sync(self): + names = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten'] + await self.chain.generate(101) + address = Base58.decode(await self.chain.get_new_address()) + start = time.perf_counter() + for _ in range(190): + tx = Transaction().add_outputs([ + Output.pay_claim_name_pubkey_hash( + CENT, f'{choice(names)}{i}', + Stream().update( + title='a claim title', + description='Lorem ipsum '*400, + tags=['crypto', 'health', 'space'], + ).claim, + address) + for i in range(1, 20) + ]) + funded = await self.chain.fund_raw_transaction(hexlify(tx.raw).decode()) + signed = await self.chain.sign_raw_transaction_with_wallet(funded['hex']) + await self.chain.send_raw_transaction(signed['hex']) + await self.chain.generate(1) + + print(f'generating {190*20} transactions took {time.perf_counter()-start}s') self.assertEqual( [(0, 191, 280), (1, 89, 178), (2, 12, 24)], @@ -81,9 +102,410 @@ class TestBlockchain(AsyncioTestCase): ) self.assertEqual(191, len(await self.chain.get_file_details(0))) - db = Database(os.path.join(self.chain.actual_data_dir, 'lbry.db')) - self.addCleanup(db.close) - await db.open() + await self.sync.advance() - sync = BlockchainSync(self.chain, use_process_pool=False) - await sync.load_blocks() + +class FullNodeTestCase(BlockchainTestCase): + + async def asyncSetUp(self): + await super().asyncSetUp() + + self.current_height = 0 + await self.generate(101, wait=False) + + self.service = FullNode(self.ledger, f'sqlite:///{self.chain.data_dir}/lbry.db') + self.service.conf.spv_address_filters = False + self.sync = self.service.sync + self.db = self.service.db + + self.daemon = Daemon(self.service) + self.api = self.daemon.api + self.addCleanup(self.daemon.stop) + await self.daemon.start() + + if False: #os.environ.get('TEST_LBRY_API', 'light_client') == 'light_client': + light_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, light_dir, True) + + ledger = RegTestLedger(Config( + data_dir=light_dir, + wallet_dir=light_dir, + api='localhost:5389', + )) + + self.light_client = self.service = LightClient( + ledger, f'sqlite:///{light_dir}/light_client.db' + ) + self.light_api = Daemon(self.service) + await self.light_api.start() + self.addCleanup(self.light_api.stop) + #else: + # self.service = self.full_node + + #self.client = Client(self.service, self.ledger.conf.api_connection_url) + + async def generate(self, blocks, wait=True): + block_hashes = await self.chain.generate(blocks) + self.current_height += blocks + if 
wait: + await self.service.sync.on_block.where( + lambda b: self.current_height == b.height + ) + return block_hashes + + +class TestFullNode(FullNodeTestCase): + + async def test_foo(self): + await self.generate(10) + wallet = self.service.wallet_manager.default_wallet #create_wallet('test_wallet') + account = wallet.accounts[0] + addresses = await account.ensure_address_gap() + await self.chain.send_to_address(addresses[0], '5.0') + await self.generate(1) + self.assertEqual(await account.get_balance(), lbc_to_dewies('5.0')) + #self.assertEqual((await self.client.account_balance())['total'], '5.0') + + tx = await wallet.create_channel('@foo', lbc_to_dewies('1.0'), account, [account], addresses[0]) + await self.service.broadcast(tx) + await self.generate(1) + channels = await wallet.get_channels() + print(channels) + + +class TestClaimtrieSync(FullNodeTestCase): + + async def asyncSetUp(self): + await super().asyncSetUp() + self.last_block_hash = None + self.address = await self.chain.get_new_address() + + def find_claim_txo(self, tx): + for txo in tx.outputs: + if txo.is_claim: + return txo + + async def get_transaction(self, txid): + raw = await self.chain.get_raw_transaction(txid) + return Transaction(unhexlify(raw)) + + async def claim_name(self, title, amount): + claim = Stream().update(title=title).claim + return await self.chain.claim_name( + 'foo', hexlify(claim.to_bytes()).decode(), amount + ) + + async def claim_update(self, tx, amount): + claim = self.find_claim_txo(tx).claim + return await self.chain.update_claim( + tx.outputs[0].tx_ref.id, hexlify(claim.to_bytes()).decode(), amount + ) + + async def claim_abandon(self, tx): + return await self.chain.abandon_claim(tx.id, self.address) + + async def support_claim(self, tx, amount): + txo = self.find_claim_txo(tx) + response = await self.chain.support_claim( + txo.claim_name, txo.claim_id, amount + ) + return response['txId'] + + async def advance(self, new_height, ops): + blocks = (new_height-self.current_height)-1 + if blocks > 0: + await self.generate(blocks) + txs = [] + for op in ops: + if len(op) == 3: + op_type, value, amount = op + else: + (op_type, value), amount = op, None + if op_type == 'claim': + txid = await self.claim_name(value, amount) + elif op_type == 'update': + txid = await self.claim_update(value, amount) + elif op_type == 'abandon': + txid = await self.claim_abandon(value) + elif op_type == 'support': + txid = await self.support_claim(value, amount) + else: + raise ValueError(f'"{op_type}" is an unknown operation') + txs.append(await self.get_transaction(txid)) + self.last_block_hash, = await self.generate(1) + self.current_height = new_height + return txs + + async def get_last_block(self): + return await self.chain.get_block(self.last_block_hash) + + async def get_controlling(self): + sql = f""" + select + tx.height, tx.raw, txo.position, effective_amount, activation_height + from claimtrie + join claim using (claim_hash) + join txo using (txo_hash) + join tx using (tx_hash) + where + txo.txo_type in (1, 2) and + expiration_height > {self.current_height} + """ + for claim in await self.db.execute_fetchall(sql): + tx = Transaction(claim['raw'], height=claim['height']) + txo = tx.outputs[claim['position']] + return ( + txo.claim.stream.title, dewies_to_lbc(txo.amount), + dewies_to_lbc(claim['effective_amount']), claim['activation_height'] + ) + + async def get_active(self): + controlling = await self.get_controlling() + active = [] + sql = f""" + select tx.height, tx.raw, txo.position, effective_amount, 
activation_height + from txo + join tx using (tx_hash) + join claim using (claim_hash) + where + txo.txo_type in (1, 2) and + activation_height <= {self.current_height} and + expiration_height > {self.current_height} + """ + for claim in await self.db.execute_fetchall(sql): + tx = Transaction(claim['raw'], height=claim['height']) + txo = tx.outputs[claim['position']] + if controlling and controlling[0] == txo.claim.stream.title: + continue + active.append(( + txo.claim.stream.title, dewies_to_lbc(txo.amount), + dewies_to_lbc(claim['effective_amount']), claim['activation_height'] + )) + return active + + async def get_accepted(self): + accepted = [] + sql = f""" + select tx.height, tx.raw, txo.position, effective_amount, activation_height + from txo + join tx using (tx_hash) + join claim using (claim_hash) + where + txo.txo_type in (1, 2) and + activation_height > {self.current_height} and + expiration_height > {self.current_height} + """ + for claim in await self.db.execute_fetchall(sql): + tx = Transaction(claim['raw'], height=claim['height']) + txo = tx.outputs[claim['position']] + accepted.append(( + txo.claim.stream.title, dewies_to_lbc(txo.amount), + dewies_to_lbc(claim['effective_amount']), claim['activation_height'] + )) + return accepted + + async def state(self, controlling=None, active=None, accepted=None): + self.assertEqual(controlling, await self.get_controlling()) + self.assertEqual(active or [], await self.get_active()) + self.assertEqual(accepted or [], await self.get_accepted()) + + async def test_example_from_spec(self): + # https://spec.lbry.com/#claim-activation-example + advance, state = self.advance, self.state + stream, = await advance(113, [('claim', 'Claim A', '10.0')]) + await state( + controlling=('Claim A', '10.0', '10.0', 113), + active=[], + accepted=[] + ) + await advance(501, [('claim', 'Claim B', '20.0')]) + await state( + controlling=('Claim A', '10.0', '10.0', 113), + active=[], + accepted=[('Claim B', '20.0', '0.0', 513)] + ) + await advance(510, [('support', stream, '14')]) + await state( + controlling=('Claim A', '10.0', '24.0', 113), + active=[], + accepted=[('Claim B', '20.0', '0.0', 513)] + ) + await advance(512, [('claim', 'Claim C', '50.0')]) + await state( + controlling=('Claim A', '10.0', '24.0', 113), + active=[], + accepted=[ + ('Claim B', '20.0', '0.0', 513), + ('Claim C', '50.0', '0.0', 524)] + ) + await advance(513, []) + await state( + controlling=('Claim A', '10.0', '24.0', 113), + active=[('Claim B', '20.0', '20.0', 513)], + accepted=[('Claim C', '50.0', '0.0', 524)] + ) + await advance(520, [('claim', 'Claim D', '60.0')]) + await state( + controlling=('Claim A', '10.0', '24.0', 113), + active=[('Claim B', '20.0', '20.0', 513)], + accepted=[ + ('Claim C', '50.0', '0.0', 524), + ('Claim D', '60.0', '0.0', 532)] + ) + await advance(524, []) + await state( + controlling=('Claim D', '60.0', '60.0', 524), + active=[ + ('Claim A', '10.0', '24.0', 113), + ('Claim B', '20.0', '20.0', 513), + ('Claim C', '50.0', '50.0', 524)], + accepted=[] + ) + # beyond example + await advance(525, [('update', stream, '70.0')]) + await state( + controlling=('Claim A', '70.0', '84.0', 525), + active=[ + ('Claim B', '20.0', '20.0', 513), + ('Claim C', '50.0', '50.0', 524), + ('Claim D', '60.0', '60.0', 524), + ], + accepted=[] + ) + + async def test_competing_claims_subsequent_blocks_height_wins(self): + advance, state = self.advance, self.state + await advance(113, [('claim', 'Claim A', '1.0')]) + await state( + controlling=('Claim A', '1.0', '1.0', 113), + 
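# a freshly claimed name has no prior takeover and activates immediately, + # which is why activation_height == 113 here; for later competing claims + # the activation delay is floor((claim_height - last_takeover_height) / 32), + # capped at 4032 blocks, e.g. in the spec example above Claim B, claimed at + # height 501 with the last takeover at 113, activates at 501 + (501 - 113) // 32 = 513 + 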
active=[], + accepted=[] + ) + await advance(114, [('claim', 'Claim B', '1.0')]) + await state( + controlling=('Claim A', '1.0', '1.0', 113), + active=[('Claim B', '1.0', '1.0', 114)], + accepted=[] + ) + await advance(115, [('claim', 'Claim C', '1.0')]) + await state( + controlling=('Claim A', '1.0', '1.0', 113), + active=[ + ('Claim B', '1.0', '1.0', 114), + ('Claim C', '1.0', '1.0', 115)], + accepted=[] + ) + + async def test_competing_claims_in_single_block_position_wins(self): + claim_a, claim_b = await self.advance(113, [ + ('claim', 'Claim A', '1.0'), + ('claim', 'Claim B', '1.0') + ]) + block = await self.get_last_block() + # order of txs in a block is non-deterministic, + # figure out what order we ended up with + if block['tx'][1] == claim_a.id: + winner, other = 'Claim A', 'Claim B' + else: + winner, other = 'Claim B', 'Claim A' + await self.state( + controlling=(winner, '1.0', '1.0', 113), + active=[(other, '1.0', '1.0', 113)], + accepted=[] + ) + + async def test_competing_claims_in_single_block_effective_amount_wins(self): + await self.advance(113, [ + ('claim', 'Claim A', '1.0'), + ('claim', 'Claim B', '2.0') + ]) + await self.state( + controlling=('Claim B', '2.0', '2.0', 113), + active=[('Claim A', '1.0', '1.0', 113)], + accepted=[] + ) + + async def test_winning_claim_deleted(self): + claim1, claim2 = await self.advance(113, [ + ('claim', 'Claim A', '1.0'), + ('claim', 'Claim B', '2.0') + ]) + await self.state( + controlling=('Claim B', '2.0', '2.0', 113), + active=[('Claim A', '1.0', '1.0', 113)], + accepted=[] + ) + await self.advance(114, [('abandon', claim2)]) + await self.state( + controlling=('Claim A', '1.0', '1.0', 113), + active=[], + accepted=[] + ) + + async def test_winning_claim_deleted_and_new_claim_becomes_winner(self): + claim1, claim2 = await self.advance(113, [ + ('claim', 'Claim A', '1.0'), + ('claim', 'Claim B', '2.0') + ]) + await self.state( + controlling=('Claim B', '2.0', '2.0', 113), + active=[('Claim A', '1.0', '1.0', 113)], + accepted=[] + ) + await self.advance(115, [ + ('abandon', claim2), + ('claim', 'Claim C', '3.0') + ]) + await self.state( + controlling=('Claim C', '3.0', '3.0', 115), + active=[('Claim A', '1.0', '1.0', 113)], + accepted=[] + ) + + async def test_winning_claim_expires_and_another_takes_over(self): + await self.advance(110, [('claim', 'Claim A', '2.0')]) + await self.advance(120, [('claim', 'Claim B', '1.0')]) + await self.state( + controlling=('Claim A', '2.0', '2.0', 110), + active=[('Claim B', '1.0', '1.0', 120)], + accepted=[] + ) + await self.advance(610, []) + await self.state( + controlling=('Claim B', '1.0', '1.0', 120), + active=[], + accepted=[] + ) + await self.advance(620, []) + await self.state( + controlling=None, + active=[], + accepted=[] + ) + + async def test_create_and_multiple_updates_in_same_block(self): + await self.chain.generate(10) + txid = await self.claim_name('Claim A', '1.0') + txid = await self.claim_update(await self.get_transaction(txid), '2.0') + await self.claim_update(await self.get_transaction(txid), '3.0') + await self.chain.generate(1) + await self.sync.advance() + self.current_height += 11 + await self.state( + controlling=('Claim A', '3.0', '3.0', 112), + active=[], + accepted=[] + ) + + async def test_create_and_abandon_in_same_block(self): + await self.chain.generate(10) + txid = await self.claim_name('Claim A', '1.0') + await self.claim_abandon(await self.get_transaction(txid)) + await self.chain.generate(1) + await self.sync.advance() + self.current_height += 11 + await self.state( 
+ controlling=None, + active=[], + accepted=[] + ) diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index 561b7a76c..c30927d4a 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -9,8 +9,8 @@ from lbry.error import InsufficientFundsError from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE from lbry.testcase import CommandTestCase -from lbry.wallet.transaction import Transaction -from lbry.wallet.util import satoshis_to_coins as lbc +from lbry.blockchain.transaction import Transaction +from lbry.blockchain.util import satoshis_to_coins as lbc log = logging.getLogger(__name__) @@ -142,7 +142,6 @@ class ClaimSearchCommand(ClaimTestCase): await self.assertFindsClaims([signed2], channel_ids=[channel_id2, self.channel_id], valid_channel_signature=True, invalid_channel_signature=False) # invalid signature still returns channel_id - self.ledger._tx_cache.clear() invalid_claims = await self.claim_search(invalid_channel_signature=True, has_channel_signature=True) self.assertEqual(3, len(invalid_claims)) self.assertTrue(all([not c['is_channel_signature_valid'] for c in invalid_claims])) @@ -234,7 +233,7 @@ class ClaimSearchCommand(ClaimTestCase): await self.assertFindsClaims([claim4, claim3, claim2], all_tags=['abc'], any_tags=['def', 'ghi']) async def test_order_by(self): - height = self.ledger.network.remote_height + height = self.ledger.sync.network.remote_height claims = [await self.stream_create(f'claim{i}') for i in range(5)] await self.assertFindsClaims(claims, order_by=["^height"]) @@ -820,7 +819,7 @@ class ChannelCommands(CommandTestCase): async def test_create_channel_names(self): # claim new name await self.channel_create('@foo') - self.assertItemCount(await self.daemon.jsonrpc_channel_list(), 1) + self.assertItemCount(await self.api.channel_list(), 1) await self.assertBalance(self.account, '8.991893') # fail to claim duplicate @@ -832,12 +831,12 @@ class ChannelCommands(CommandTestCase): await self.channel_create('foo') # nothing's changed after failed attempts - self.assertItemCount(await self.daemon.jsonrpc_channel_list(), 1) + self.assertItemCount(await self.api.channel_list(), 1) await self.assertBalance(self.account, '8.991893') # succeed overriding duplicate restriction await self.channel_create('@foo', allow_duplicate_name=True) - self.assertItemCount(await self.daemon.jsonrpc_channel_list(), 2) + self.assertItemCount(await self.api.channel_list(), 2) await self.assertBalance(self.account, '7.983786') async def test_channel_bids(self): diff --git a/tests/integration/blockchain/test_network.py b/tests/integration/blockchain/test_network.py index 0c0b47c79..9f449ac4e 100644 --- a/tests/integration/blockchain/test_network.py +++ b/tests/integration/blockchain/test_network.py @@ -4,7 +4,7 @@ from unittest.mock import Mock from binascii import unhexlify import lbry -from lbry.wallet.network import Network +from lbry.service.network import Network from lbry.wallet.orchstr8.node import SPVNode from lbry.wallet.rpc import RPCSession from lbry.testcase import IntegrationTestCase, AsyncioTestCase diff --git a/tests/integration/blockchain/test_sync.py b/tests/integration/blockchain/test_sync.py index 7af2bd1aa..852d47c9e 100644 --- a/tests/integration/blockchain/test_sync.py +++ b/tests/integration/blockchain/test_sync.py @@ -1,8 +1,16 @@ +import os import asyncio import logging -from lbry.testcase import IntegrationTestCase, WalletNode + +import 
aiohttp +from sqlalchemy import text + +from lbry.testcase import IntegrationTestCase, WalletNode, CommandTestCase from lbry.constants import CENT from lbry.wallet import WalletManager, RegTestLedger, Transaction, Output +from lbry.blockchain import Lbrycrd +from lbry.db import Database, TXI +from lbry.blockchain import Synchronizer class SyncTests(IntegrationTestCase): diff --git a/tests/integration/blockchain/test_transactions.py b/tests/integration/blockchain/test_transactions.py index f399d327b..821772e7a 100644 --- a/tests/integration/blockchain/test_transactions.py +++ b/tests/integration/blockchain/test_transactions.py @@ -80,7 +80,7 @@ class BasicTransactionTests(IntegrationTestCase): async def test_sending_and_receiving(self): account1, account2 = self.account, self.wallet.generate_account(self.ledger) - await self.ledger.subscribe_account(account2) + await self.ledger.sync.subscribe_account(account2) await self.assertBalance(account1, '0.0') await self.assertBalance(account2, '0.0') @@ -151,8 +151,8 @@ class BasicTransactionTests(IntegrationTestCase): for batch in range(0, len(sends), 10): txids = await asyncio.gather(*sends[batch:batch + 10]) await asyncio.wait([self.on_transaction_id(txid) for txid in txids]) - remote_status = await self.ledger.network.subscribe_address(address) - self.assertTrue(await self.ledger.update_history(address, remote_status)) + remote_status = await self.ledger.sync.network.subscribe_address(address) + self.assertTrue(await self.ledger.sync.update_history(address, remote_status)) # 20 unconfirmed txs, 10 from blockchain, 10 from local to local utxos = await self.account.get_utxos() txs = [] @@ -165,11 +165,11 @@ class BasicTransactionTests(IntegrationTestCase): await self.broadcast(tx) txs.append(tx) await asyncio.wait([self.on_transaction_address(tx, address) for tx in txs], timeout=1) - remote_status = await self.ledger.network.subscribe_address(address) - self.assertTrue(await self.ledger.update_history(address, remote_status)) + remote_status = await self.ledger.sync.network.subscribe_address(address) + self.assertTrue(await self.ledger.sync.update_history(address, remote_status)) # server history grows unordered txid = await self.blockchain.send_to_address(address, 1) await self.on_transaction_id(txid) - self.assertTrue(await self.ledger.update_history(address, remote_status)) - self.assertEqual(21, len((await self.ledger.get_local_status_and_history(address))[1])) + self.assertTrue(await self.ledger.sync.update_history(address, remote_status)) + self.assertEqual(21, len((await self.ledger.sync.get_local_status_and_history(address))[1])) self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync)) diff --git a/tests/integration/blockchain/test_wallet_server_sessions.py b/tests/integration/blockchain/test_wallet_server_sessions.py index 5876c77a1..c16645bac 100644 --- a/tests/integration/blockchain/test_wallet_server_sessions.py +++ b/tests/integration/blockchain/test_wallet_server_sessions.py @@ -3,7 +3,7 @@ import asyncio import lbry import lbry.wallet from lbry.error import ServerPaymentFeeAboveMaxAllowedError -from lbry.wallet.network import ClientSession +from lbry.service.network import ClientSession from lbry.testcase import IntegrationTestCase, CommandTestCase from lbry.wallet.orchstr8.node import SPVNode diff --git a/tests/unit/blockchain/test_bcd_data_stream.py b/tests/unit/blockchain/test_bcd_data_stream.py index ab2095e45..df76cc9c8 100644 --- a/tests/unit/blockchain/test_bcd_data_stream.py +++ 
b/tests/unit/blockchain/test_bcd_data_stream.py @@ -1,9 +1,9 @@ -import unittest +from unittest import TestCase -from lbry.wallet.bcd_data_stream import BCDataStream +from lbry.blockchain.bcd_data_stream import BCDataStream -class TestBCDataStream(unittest.TestCase): +class TestBCDataStream(TestCase): def test_write_read(self): s = BCDataStream() diff --git a/tests/unit/blockchain/test_claim_proofs.py b/tests/unit/blockchain/test_claim_proofs.py index e393043f0..12254a55e 100644 --- a/tests/unit/blockchain/test_claim_proofs.py +++ b/tests/unit/blockchain/test_claim_proofs.py @@ -1,7 +1,7 @@ import unittest from binascii import hexlify, unhexlify -from lbry.wallet.claim_proofs import get_hash_for_outpoint, verify_proof +from lbry.blockchain.claim_proofs import get_hash_for_outpoint, verify_proof from lbry.crypto.hash import double_sha256 diff --git a/tests/unit/blockchain/test_dewies.py b/tests/unit/blockchain/test_dewies.py index 7d1b7ba7d..04c289341 100644 --- a/tests/unit/blockchain/test_dewies.py +++ b/tests/unit/blockchain/test_dewies.py @@ -1,6 +1,6 @@ import unittest -from lbry.wallet.dewies import lbc_to_dewies as l2d, dewies_to_lbc as d2l +from lbry.blockchain.dewies import lbc_to_dewies as l2d, dewies_to_lbc as d2l class TestDeweyConversion(unittest.TestCase): diff --git a/tests/unit/blockchain/test_headers.py b/tests/unit/blockchain/test_headers.py index e014f6e46..e08e4cdbc 100644 --- a/tests/unit/blockchain/test_headers.py +++ b/tests/unit/blockchain/test_headers.py @@ -3,9 +3,9 @@ import asyncio import tempfile from binascii import unhexlify -from lbry.wallet.util import ArithUint256 from lbry.testcase import AsyncioTestCase -from lbry.wallet.ledger import Headers as _Headers +from lbry.blockchain.util import ArithUint256 +from lbry.blockchain.ledger import Headers as _Headers class Headers(_Headers): @@ -168,9 +168,9 @@ class TestHeaders(AsyncioTestCase): await headers.open() self.assertEqual( cm.output, [ - 'WARNING:lbry.wallet.header:Reader file size doesnt match header size. ' + 'WARNING:lbry.blockchain.header:Reader file size doesnt match header size. ' 'Repairing, might take a while.', - 'WARNING:lbry.wallet.header:Header file corrupted at height 9, truncating ' + 'WARNING:lbry.blockchain.header:Header file corrupted at height 9, truncating ' 'it.' 
] ) diff --git a/tests/unit/blockchain/test_script.py b/tests/unit/blockchain/test_script.py index 7333e1133..e3d5c351e 100644 --- a/tests/unit/blockchain/test_script.py +++ b/tests/unit/blockchain/test_script.py @@ -1,8 +1,8 @@ import unittest from binascii import hexlify, unhexlify -from lbry.wallet.bcd_data_stream import BCDataStream -from lbry.wallet.script import ( +from lbry.blockchain.bcd_data_stream import BCDataStream +from lbry.blockchain.script import ( InputScript, OutputScript, Template, ParseError, tokenize, push_data, PUSH_SINGLE, PUSH_INTEGER, PUSH_MANY, OP_HASH160, OP_EQUAL ) diff --git a/tests/unit/blockchain/test_sync.py b/tests/unit/blockchain/test_sync.py new file mode 100644 index 000000000..ee4e5edef --- /dev/null +++ b/tests/unit/blockchain/test_sync.py @@ -0,0 +1,523 @@ +import tempfile +import ecdsa +import hashlib +from binascii import hexlify +from typing import List, Tuple + +from lbry.testcase import AsyncioTestCase, get_output +from lbry.conf import Config +from lbry.db import RowCollector +from lbry.schema.claim import Claim +from lbry.schema.result import Censor +from lbry.blockchain.block import Block +from lbry.constants import COIN +from lbry.blockchain.transaction import Transaction, Input, Output +from lbry.service.full_node import FullNode +from lbry.blockchain.ledger import Ledger +from lbry.blockchain.lbrycrd import Lbrycrd +from lbry.blockchain.testing import create_lbrycrd_databases, add_block_to_lbrycrd + + +def get_input(fuzz=1): + return Input.spend(get_output(COIN, fuzz.to_bytes(32, 'little'))) + + +def get_tx(fuzz=1): + return Transaction().add_inputs([get_input(fuzz)]) + + +def search(**constraints) -> List: + return reader.search_claims(Censor(), **constraints) + + +def censored_search(**constraints) -> Tuple[List, Censor]: + rows, _, _, _, censor = reader.search(constraints) + return rows, censor + + +class TestSQLDB(AsyncioTestCase): + + async def asyncSetUp(self): + await super().asyncSetUp() + self.chain = Lbrycrd(Ledger(Config.with_same_dir(tempfile.mkdtemp()))) + self.addCleanup(self.chain.cleanup) + await create_lbrycrd_databases(self.chain.actual_data_dir) + await self.chain.open() + self.addCleanup(self.chain.close) + self.service = FullNode( + self.chain.ledger, f'sqlite:///{self.chain.data_dir}/lbry.db', self.chain + ) + self.service.conf.spv_address_filters = False + self.db = self.service.db + self.addCleanup(self.db.close) + await self.db.open() + self._txos = {} + + async def advance(self, height, txs, takeovers=None): + block = Block( + height=height, version=1, file_number=0, + block_hash=f'beef{height}'.encode(), prev_block_hash=f'beef{height-1}'.encode(), + merkle_root=b'beef', claim_trie_root=b'beef', + timestamp=99, bits=1, nonce=1, txs=txs + ) + await add_block_to_lbrycrd(self.chain, block, takeovers or []) + await RowCollector(self.db).add_block(block).save() + await self.service.sync.post_process() + return [tx.outputs[0] for tx in txs] + + def _make_tx(self, output, txi=None, **kwargs): + tx = get_tx(**kwargs).add_outputs([output]) + if txi is not None: + tx.add_inputs([txi]) + self._txos[output.ref.hash] = output + return tx + + def _set_channel_key(self, channel, key): + private_key = ecdsa.SigningKey.from_string(key*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256) + channel.private_key = private_key + channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der() + channel.script.generate() + + def get_channel(self, title, amount, name='@foo', key=b'a', **kwargs): + claim = Claim() + 
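# a channel is itself a claim whose metadata embeds a public key; streams + # signed with the matching private key (see _set_channel_key above, which + # derives a deterministic test key) are treated as members of the channel + 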
claim.channel.title = title + channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc') + self._set_channel_key(channel, key) + return self._make_tx(channel, **kwargs) + + def get_channel_update(self, channel, amount, key=b'a'): + self._set_channel_key(channel, key) + return self._make_tx( + Output.pay_update_claim_pubkey_hash( + amount, channel.claim_name, channel.claim_id, channel.claim, b'abc' + ), + Input.spend(channel) + ) + + def get_stream(self, title, amount, name='foo', channel=None, **kwargs): + claim = Claim() + claim.stream.update(title=title, **kwargs) + result = self._make_tx(Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc')) + if channel: + result.outputs[0].sign(channel) + result._reset() + return result + + def get_stream_update(self, tx, amount, channel=None): + stream = Transaction(tx[0].raw).outputs[0] + result = self._make_tx( + Output.pay_update_claim_pubkey_hash( + amount, stream.claim_name, stream.claim_id, stream.claim, b'abc' + ), + Input.spend(stream) + ) + if channel: + result.outputs[0].sign(channel) + result._reset() + return result + + def get_repost(self, claim_id, amount, channel): + claim = Claim() + claim.repost.reference.claim_id = claim_id + result = self._make_tx(Output.pay_claim_name_pubkey_hash(amount, 'repost', claim, b'abc')) + result.outputs[0].sign(channel) + result._reset() + return result + + def get_abandon(self, tx): + claim = Transaction(tx[0].raw).outputs[0] + return self._make_tx( + Output.pay_pubkey_hash(claim.amount, b'abc'), + Input.spend(claim) + ) + + def get_support(self, tx, amount): + claim = Transaction(tx[0].raw).outputs[0] + return self._make_tx( + Output.pay_support_pubkey_hash( + amount, claim.claim_name, claim.claim_id, b'abc' + ) + ) + + +class TestClaimtrie(TestSQLDB): + + def setUp(self): + super().setUp() + self._input_counter = 1 + + def _get_x_with_claim_id_prefix(self, getter, prefix, cached_iteration=None, **kwargs): + iterations = cached_iteration+1 if cached_iteration else 100 + for i in range(cached_iteration or 1, iterations): + stream = getter(f'claim #{i}', COIN, fuzz=self._input_counter, **kwargs) + if stream.outputs[0].claim_id.startswith(prefix): + cached_iteration is None and print(f'Found "{prefix}" in {i} iterations.') + self._input_counter += 1 + return stream + if cached_iteration: + raise ValueError(f'Failed to find "{prefix}" at cached iteration, run with None to find iteration.') + raise ValueError(f'Failed to find "{prefix}" in {iterations} iterations, try different values.') + + def get_channel_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs): + return self._get_x_with_claim_id_prefix(self.get_channel, prefix, cached_iteration, **kwargs) + + def get_stream_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs): + return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration, **kwargs) + + async def test_canonical_url_and_channel_validation(self): + advance, search = self.advance, partial(self.service.search_claims, []) + + tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c') + tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 20, key=b'c') + txo_chan_a = tx_chan_a.outputs[0] + txo_chan_ab = tx_chan_ab.outputs[0] + await advance(1, [tx_chan_a]) + await advance(2, [tx_chan_ab]) + (r_ab, r_a) = search(order_by=['creation_height'], limit=2) + self.assertEqual("@foo#a", r_a['short_url']) + self.assertEqual("@foo#ab", r_ab['short_url']) + self.assertIsNone(r_a['canonical_url']) + 
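# short_url is the claim name plus the shortest claim_id prefix that no + # earlier claim to the same name shares, and canonical_url prefixes a signed + # stream's short_url with its channel's short_url; a minimal sketch of the + # prefix rule (hypothetical helper, not the SDK API): + # def shortest_unique_prefix(claim_id, earlier_ids): + # for n in range(1, len(claim_id) + 1): + # if not any(other.startswith(claim_id[:n]) for other in earlier_ids): + # return claim_id[:n] + # return claim_id + 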
self.assertIsNone(r_ab['canonical_url']) + self.assertEqual(0, r_a['claims_in_channel']) + self.assertEqual(0, r_ab['claims_in_channel']) + + tx_a = self.get_stream_with_claim_id_prefix('a', 2) + tx_ab = self.get_stream_with_claim_id_prefix('ab', 42) + tx_abc = self.get_stream_with_claim_id_prefix('abc', 65) + await advance(3, [tx_a]) + await advance(4, [tx_ab, tx_abc]) + (r_abc, r_ab, r_a) = search(order_by=['creation_height', 'tx_position'], limit=3) + self.assertEqual("foo#a", r_a['short_url']) + self.assertEqual("foo#ab", r_ab['short_url']) + self.assertEqual("foo#abc", r_abc['short_url']) + self.assertIsNone(r_a['canonical_url']) + self.assertIsNone(r_ab['canonical_url']) + self.assertIsNone(r_abc['canonical_url']) + + tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a) + tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a) + a2_claim = tx_a2.outputs[0] + ab2_claim = tx_ab2.outputs[0] + await advance(6, [tx_a2]) + await advance(7, [tx_ab2]) + (r_ab2, r_a2) = search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim.claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim.claim_id[:4]}", r_ab2['short_url']) + self.assertEqual("@foo#a/foo#a", r_a2['canonical_url']) + self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url']) + self.assertEqual(2, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + + # change channel public key, invalidating stream claim signatures + await advance(8, [self.get_channel_update(txo_chan_a, COIN, key=b'a')]) + (r_ab2, r_a2) = search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim.claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim.claim_id[:4]}", r_ab2['short_url']) + self.assertIsNone(r_a2['canonical_url']) + self.assertIsNone(r_ab2['canonical_url']) + self.assertEqual(0, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + + # reinstate previous channel public key (previous stream claim signatures become valid again) + channel_update = self.get_channel_update(txo_chan_a, COIN, key=b'c') + await advance(9, [channel_update]) + (r_ab2, r_a2) = search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim.claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim.claim_id[:4]}", r_ab2['short_url']) + self.assertEqual("@foo#a/foo#a", r_a2['canonical_url']) + self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url']) + self.assertEqual(2, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + self.assertEqual(0, search(claim_id=txo_chan_ab.claim_id, limit=1)[0]['claims_in_channel']) + + # change channel of stream + self.assertEqual("@foo#a/foo#ab", search(claim_id=ab2_claim.claim_id, limit=1)[0]['canonical_url']) + tx_ab2 = self.get_stream_update(tx_ab2, COIN, txo_chan_ab) + await advance(10, [tx_ab2]) + self.assertEqual("@foo#ab/foo#a", search(claim_id=ab2_claim.claim_id, limit=1)[0]['canonical_url']) + # TODO: currently there is a bug where a stream leaving a channel does not update that channel's claims count + self.assertEqual(2, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + # TODO: after the bug is fixed, remove the test above and add the test below + #self.assertEqual(1, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + self.assertEqual(1, search(claim_id=txo_chan_ab.claim_id, limit=1)[0]['claims_in_channel']) + + # claim abandon updates claims_in_channel + await advance(11, [self.get_abandon(tx_ab2)]) + 
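# get_abandon (above) spends the claim txo into a plain pay-to-pubkey-hash + # output without re-claiming the name, which removes the claim from the + # claimtrie and should decrement the channel's claims_in_channel + 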
self.assertEqual(0, search(claim_id=txo_chan_ab.claim_id, limit=1)[0]['claims_in_channel']) + + # delete channel, invalidating stream claim signatures + await advance(12, [self.get_abandon(channel_update)]) + (r_a2,) = search(order_by=['creation_height'], limit=1) + self.assertEqual(f"foo#{a2_claim.claim_id[:2]}", r_a2['short_url']) + self.assertIsNone(r_a2['canonical_url']) + + def test_resolve_issue_2448(self): + advance = self.advance + + tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c') + tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72, key=b'c') + txo_chan_a = tx_chan_a[0].outputs[0] + txo_chan_ab = tx_chan_ab[0].outputs[0] + advance(1, [tx_chan_a]) + advance(2, [tx_chan_ab]) + + self.assertEqual(reader.resolve_url("@foo#a")['claim_hash'], txo_chan_a.claim_hash) + self.assertEqual(reader.resolve_url("@foo#ab")['claim_hash'], txo_chan_ab.claim_hash) + + # an update increases the last height change of the channel + advance(9, [self.get_channel_update(txo_chan_a, COIN, key=b'c')]) + + # make sure that activation_height is used instead of height (issue #2448) + self.assertEqual(reader.resolve_url("@foo#a")['claim_hash'], txo_chan_a.claim_hash) + self.assertEqual(reader.resolve_url("@foo#ab")['claim_hash'], txo_chan_ab.claim_hash) + + def test_canonical_find_shortest_id(self): + new_hash = 'abcdef0123456789beef' + other0 = '1bcdef0123456789beef' + other1 = 'ab1def0123456789beef' + other2 = 'abc1ef0123456789beef' + other3 = 'abcdef0123456789bee1' + f = FindShortestID() + f.step(other0, new_hash) + self.assertEqual('#a', f.finalize()) + f.step(other1, new_hash) + self.assertEqual('#abc', f.finalize()) + f.step(other2, new_hash) + self.assertEqual('#abcd', f.finalize()) + f.step(other3, new_hash) + self.assertEqual('#abcdef0123456789beef', f.finalize()) + + +class TestTrending(TestSQLDB): + + def test_trending(self): + advance = self.advance + no_trend = self.get_stream('Claim A', COIN) + downwards = self.get_stream('Claim B', COIN) + up_small = self.get_stream('Claim C', COIN) + up_medium = self.get_stream('Claim D', COIN) + up_biggly = self.get_stream('Claim E', COIN) + claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards]) + for window in range(1, 8): + advance(zscore.TRENDING_WINDOW * window, [ + self.get_support(downwards, (20-window)*COIN), + self.get_support(up_small, int(20+(window/10)*COIN)), + self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN), + self.get_support(up_biggly, (20+(window*(3 if window == 7 else 1)))*COIN), + ]) + results = search(order_by=['trending_local']) + self.assertEqual([c.claim_id for c in claims], [hexlify(c['claim_hash'][::-1]).decode() for c in results]) + self.assertEqual([10, 6, 2, 0, -2], [int(c['trending_local']) for c in results]) + self.assertEqual([53, 38, -32, 0, -6], [int(c['trending_global']) for c in results]) + self.assertEqual([4, 4, 2, 0, 1], [int(c['trending_group']) for c in results]) + self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results]) + + def test_edge(self): + problematic = self.get_stream('Problem', COIN) + self.advance(1, [problematic]) + self.advance(zscore.TRENDING_WINDOW, [self.get_support(problematic, 53000000000)]) + self.advance(zscore.TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)]) + + +class TestContentBlocking(TestSQLDB): + + def test_blocking_and_filtering(self): + # content claims and channels + tx0 = self.get_channel('A Channel', COIN, '@channel1') + regular_channel = tx0[0].outputs[0] + tx1 = self.get_stream('Claim 
One', COIN, 'claim1') + tx2 = self.get_stream('Claim Two', COIN, 'claim2', regular_channel) + tx3 = self.get_stream('Claim Three', COIN, 'claim3') + self.advance(1, [tx0, tx1, tx2, tx3]) + claim1, claim2, claim3 = tx1[0].outputs[0], tx2[0].outputs[0], tx3[0].outputs[0] + + # block and filter channels + tx0 = self.get_channel('Blocking Channel', COIN, '@block') + tx1 = self.get_channel('Filtering Channel', COIN, '@filter') + blocking_channel = tx0[0].outputs[0] + filtering_channel = tx1[0].outputs[0] + self.sql.blocking_channel_hashes.add(blocking_channel.claim_hash) + self.sql.filtering_channel_hashes.add(filtering_channel.claim_hash) + self.advance(2, [tx0, tx1]) + self.assertEqual({}, dict(self.sql.blocked_streams)) + self.assertEqual({}, dict(self.sql.blocked_channels)) + self.assertEqual({}, dict(self.sql.filtered_streams)) + self.assertEqual({}, dict(self.sql.filtered_channels)) + + # nothing blocked + results, _ = reader.resolve([ + claim1.claim_name, claim2.claim_name, + claim3.claim_name, regular_channel.claim_name + ]) + self.assertEqual(claim1.claim_hash, results[0]['claim_hash']) + self.assertEqual(claim2.claim_hash, results[1]['claim_hash']) + self.assertEqual(claim3.claim_hash, results[2]['claim_hash']) + self.assertEqual(regular_channel.claim_hash, results[3]['claim_hash']) + + # nothing filtered + results, censor = censored_search() + self.assertEqual(6, len(results)) + self.assertEqual(0, censor.total) + self.assertEqual({}, censor.censored) + + # block claim reposted to blocking channel, also gets filtered + repost_tx1 = self.get_repost(claim1.claim_id, COIN, blocking_channel) + repost1 = repost_tx1[0].outputs[0] + self.advance(3, [repost_tx1]) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_streams) + ) + self.assertEqual({}, dict(self.sql.blocked_channels)) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.filtered_streams) + ) + self.assertEqual({}, dict(self.sql.filtered_channels)) + + # claim is blocked from results by direct repost + results, censor = censored_search(text='Claim') + self.assertEqual(2, len(results)) + self.assertEqual(claim2.claim_hash, results[0]['claim_hash']) + self.assertEqual(claim3.claim_hash, results[1]['claim_hash']) + self.assertEqual(1, censor.total) + self.assertEqual({blocking_channel.claim_hash: 1}, censor.censored) + results, _ = reader.resolve([claim1.claim_name]) + self.assertEqual( + f"Resolve of 'claim1' was censored by channel with claim id '{blocking_channel.claim_id}'.", + results[0].args[0] + ) + results, _ = reader.resolve([ + claim2.claim_name, regular_channel.claim_name # claim2 and channel still resolved + ]) + self.assertEqual(claim2.claim_hash, results[0]['claim_hash']) + self.assertEqual(regular_channel.claim_hash, results[1]['claim_hash']) + + # block claim indirectly by blocking its parent channel + repost_tx2 = self.get_repost(regular_channel.claim_id, COIN, blocking_channel) + repost2 = repost_tx2[0].outputs[0] + self.advance(4, [repost_tx2]) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_streams) + ) + self.assertEqual( + {repost2.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_channels) + ) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.filtered_streams) + ) + self.assertEqual( + 
{repost2.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.filtered_channels) + ) + + # claim in blocked channel is filtered from search and can't resolve + results, censor = censored_search(text='Claim') + self.assertEqual(1, len(results)) + self.assertEqual(claim3.claim_hash, results[0]['claim_hash']) + self.assertEqual(2, censor.total) + self.assertEqual({blocking_channel.claim_hash: 2}, censor.censored) + results, _ = reader.resolve([ + claim2.claim_name, regular_channel.claim_name # claim2 and channel don't resolve + ]) + self.assertEqual( + f"Resolve of 'claim2' was censored by channel with claim id '{blocking_channel.claim_id}'.", + results[0].args[0] + ) + self.assertEqual( + f"Resolve of '@channel1' was censored by channel with claim id '{blocking_channel.claim_id}'.", + results[1].args[0] + ) + results, _ = reader.resolve([claim3.claim_name]) # claim3 still resolved + self.assertEqual(claim3.claim_hash, results[0]['claim_hash']) + + # filtered claim is only filtered and not blocked + repost_tx3 = self.get_repost(claim3.claim_id, COIN, filtering_channel) + repost3 = repost_tx3[0].outputs[0] + self.advance(5, [repost_tx3]) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_streams) + ) + self.assertEqual( + {repost2.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_channels) + ) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash, + repost3.claim.repost.reference.claim_hash: filtering_channel.claim_hash}, + dict(self.sql.filtered_streams) + ) + self.assertEqual( + {repost2.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.filtered_channels) + ) + + # filtered claim doesn't return in search but is resolvable + results, censor = censored_search(text='Claim') + self.assertEqual(0, len(results)) + self.assertEqual(3, censor.total) + self.assertEqual({blocking_channel.claim_hash: 2, filtering_channel.claim_hash: 1}, censor.censored) + results, _ = reader.resolve([claim3.claim_name]) # claim3 still resolved + self.assertEqual(claim3.claim_hash, results[0]['claim_hash']) + + # abandon unblocks content + self.advance(6, [ + self.get_abandon(repost_tx1), + self.get_abandon(repost_tx2), + self.get_abandon(repost_tx3) + ]) + self.assertEqual({}, dict(self.sql.blocked_streams)) + self.assertEqual({}, dict(self.sql.blocked_channels)) + self.assertEqual({}, dict(self.sql.filtered_streams)) + self.assertEqual({}, dict(self.sql.filtered_channels)) + results, censor = censored_search(text='Claim') + self.assertEqual(3, len(results)) + self.assertEqual(0, censor.total) + results, censor = censored_search() + self.assertEqual(6, len(results)) + self.assertEqual(0, censor.total) + results, _ = reader.resolve([ + claim1.claim_name, claim2.claim_name, + claim3.claim_name, regular_channel.claim_name + ]) + self.assertEqual(claim1.claim_hash, results[0]['claim_hash']) + self.assertEqual(claim2.claim_hash, results[1]['claim_hash']) + self.assertEqual(claim3.claim_hash, results[2]['claim_hash']) + self.assertEqual(regular_channel.claim_hash, results[3]['claim_hash']) + + def test_pagination(self): + one, two, three, four, five, six, seven, filter_channel = self.advance(1, [ + self.get_stream('One', COIN), + self.get_stream('Two', COIN), + self.get_stream('Three', COIN), + self.get_stream('Four', COIN), + self.get_stream('Five', COIN), + self.get_stream('Six', COIN), + self.get_stream('Seven', 
+    def test_pagination(self):
+        one, two, three, four, five, six, seven, filter_channel = self.advance(1, [
+            self.get_stream('One', COIN),
+            self.get_stream('Two', COIN),
+            self.get_stream('Three', COIN),
+            self.get_stream('Four', COIN),
+            self.get_stream('Five', COIN),
+            self.get_stream('Six', COIN),
+            self.get_stream('Seven', COIN),
+            self.get_channel('Filtering Channel', COIN, '@filter'),
+        ])
+        self.sql.filtering_channel_hashes.add(filter_channel.claim_hash)
+
+        # nothing filtered
+        results, censor = censored_search(order_by='^height', offset=1, limit=3)
+        self.assertEqual(3, len(results))
+        self.assertEqual(
+            [two.claim_hash, three.claim_hash, four.claim_hash],
+            [r['claim_hash'] for r in results]
+        )
+        self.assertEqual(0, censor.total)
+
+        # content filtered
+        repost1, repost2 = self.advance(2, [
+            self.get_repost(one.claim_id, COIN, filter_channel),
+            self.get_repost(two.claim_id, COIN, filter_channel),
+        ])
+        results, censor = censored_search(order_by='^height', offset=1, limit=3)
+        self.assertEqual(3, len(results))
+        self.assertEqual(
+            [four.claim_hash, five.claim_hash, six.claim_hash],
+            [r['claim_hash'] for r in results]
+        )
+        self.assertEqual(2, censor.total)
+        self.assertEqual({filter_channel.claim_hash: 2}, censor.censored)
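test_pagination pins down the ordering of operations: censored rows are dropped before offset and limit are applied, so after the two reposts the window offset=1, limit=3 slides over the five surviving streams and returns four, five, six instead of two, three, four. A small illustration of that ordering, using a hypothetical helper rather than the reader API:

    # Assumed for illustration: filter first, then slice, so censored rows
    # never consume offset/limit slots.
    def paginate_after_filtering(rows, is_censored, offset=0, limit=None):
        surviving = [row for row in rows if not is_censored(row)]
        end = offset + limit if limit is not None else None
        return surviving[offset:end]

    streams = ['one', 'two', 'three', 'four', 'five', 'six', 'seven']
    assert paginate_after_filtering(streams, lambda r: False, 1, 3) == ['two', 'three', 'four']
    assert paginate_after_filtering(streams, lambda r: r in ('one', 'two'), 1, 3) == ['four', 'five', 'six']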
diff --git a/tests/unit/blockchain/test_transaction.py b/tests/unit/blockchain/test_transaction.py
index d32d92d4a..9d37dfb26 100644
--- a/tests/unit/blockchain/test_transaction.py
+++ b/tests/unit/blockchain/test_transaction.py
@@ -1,51 +1,21 @@
-import unittest
+from unittest import TestCase
 from binascii import hexlify, unhexlify
-from itertools import cycle
 
-from lbry.testcase import AsyncioTestCase
-from lbry.wallet.constants import CENT, COIN, NULL_HASH32
-from lbry.wallet import Wallet, Account, Ledger, Headers, Transaction, Output, Input
-from lbry.db import Database
+from lbry.blockchain.ledger import Ledger
+from lbry.constants import CENT, NULL_HASH32
+from lbry.blockchain.transaction import Transaction
+from lbry.testcase import (
+    get_transaction, get_input, get_output, get_claim_transaction
+)
 
-
-NULL_HASH = b'\x00'*32
 FEE_PER_BYTE = 50
 FEE_PER_CHAR = 200000
 
 
-def get_output(amount=CENT, pubkey_hash=NULL_HASH32, height=-2):
-    return Transaction(height=height) \
-        .add_outputs([Output.pay_pubkey_hash(amount, pubkey_hash)]) \
-        .outputs[0]
+class TestSizeAndFeeEstimation(TestCase):
 
-
-def get_input(amount=CENT, pubkey_hash=NULL_HASH):
-    return Input.spend(get_output(amount, pubkey_hash))
-
-
-def get_transaction(txo=None):
-    return Transaction() \
-        .add_inputs([get_input()]) \
-        .add_outputs([txo or Output.pay_pubkey_hash(CENT, NULL_HASH32)])
-
-
-def get_claim_transaction(claim_name, claim=b''):
-    return get_transaction(
-        Output.pay_claim_name_pubkey_hash(CENT, claim_name, claim, NULL_HASH32)
-    )
-
-
-class TestSizeAndFeeEstimation(AsyncioTestCase):
-
-    async def asyncSetUp(self):
-        self.ledger = Ledger({
-            'db': Database('sqlite:///:memory:'),
-            'headers': Headers(':memory:')
-        })
-        await self.ledger.db.open()
-
-    async def asyncTearDown(self):
-        await self.ledger.db.close()
+    def setUp(self):
+        self.ledger = Ledger()
 
     def test_output_size_and_fee(self):
         txo = get_output()
@@ -81,7 +51,7 @@ class TestSizeAndFeeEstimation(AsyncioTestCase):
         self.assertEqual(tx.get_base_fee(self.ledger), FEE_PER_BYTE * tx.base_size)
 
 
-class TestAccountBalanceImpactFromTransaction(unittest.TestCase):
+class TestAccountBalanceImpactFromTransaction(TestCase):
 
     def test_is_my_output_not_set(self):
         tx = get_transaction()
@@ -97,8 +67,8 @@ class TestAccountBalanceImpactFromTransaction(unittest.TestCase):
     def test_paying_from_my_account_to_other_account(self):
         tx = Transaction() \
             .add_inputs([get_input(300*CENT)]) \
-            .add_outputs([get_output(190*CENT, NULL_HASH),
-                          get_output(100*CENT, NULL_HASH)])
+            .add_outputs([get_output(190*CENT, NULL_HASH32),
+                          get_output(100*CENT, NULL_HASH32)])
         tx.inputs[0].txo_ref.txo.is_my_output = True
         tx.outputs[0].is_my_output = False
         tx.outputs[1].is_my_output = True
@@ -107,8 +77,8 @@ class TestAccountBalanceImpactFromTransaction(unittest.TestCase):
     def test_paying_from_other_account_to_my_account(self):
         tx = Transaction() \
             .add_inputs([get_input(300*CENT)]) \
-            .add_outputs([get_output(190*CENT, NULL_HASH),
-                          get_output(100*CENT, NULL_HASH)])
+            .add_outputs([get_output(190*CENT, NULL_HASH32),
+                          get_output(100*CENT, NULL_HASH32)])
         tx.inputs[0].txo_ref.txo.is_my_output = False
         tx.outputs[0].is_my_output = True
         tx.outputs[1].is_my_output = False
@@ -117,15 +87,15 @@ class TestAccountBalanceImpactFromTransaction(unittest.TestCase):
     def test_paying_from_my_account_to_my_account(self):
         tx = Transaction() \
             .add_inputs([get_input(300*CENT)]) \
-            .add_outputs([get_output(190*CENT, NULL_HASH),
-                          get_output(100*CENT, NULL_HASH)])
+            .add_outputs([get_output(190*CENT, NULL_HASH32),
+                          get_output(100*CENT, NULL_HASH32)])
         tx.inputs[0].txo_ref.txo.is_my_output = True
         tx.outputs[0].is_my_output = True
         tx.outputs[1].is_my_output = True
         self.assertEqual(tx.net_account_balance, -10*CENT)  # lost to fee
 
 
-class TestTransactionSerialization(unittest.TestCase):
+class TestTransactionSerialization(TestCase):
 
     def test_genesis_transaction(self):
         raw = unhexlify(
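The three balance tests above reduce to one identity: net_account_balance is the sum of outputs marked is_my_output minus the sum of inputs that spend my outputs, so a payment entirely to myself nets out to exactly the negated fee. In miniature (the helper is illustrative, and CENT is assumed to be the 10**6-satoshi constant, one hundredth of COIN, as in lbry.constants):

    CENT = 10 ** 6  # satoshis; assumed to match lbry.constants.CENT

    def net_balance(my_input_amounts, my_output_amounts):
        # credit what I receive, debit what I spend
        return sum(my_output_amounts) - sum(my_input_amounts)

    # 300 in, 190 + 100 out, everything mine: only the 10-cent fee is lost
    assert net_balance([300 * CENT], [190 * CENT, 100 * CENT]) == -10 * CENT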
@@ -259,164 +229,3 @@ class TestTransactionSerialization(unittest.TestCase):
 
         tx._reset()
         self.assertEqual(tx.raw, raw)
-
-
-class TestTransactionSigning(AsyncioTestCase):
-
-    async def asyncSetUp(self):
-        self.ledger = Ledger({
-            'db': Database('sqlite:///:memory:'),
-            'headers': Headers(':memory:')
-        })
-        await self.ledger.db.open()
-
-    async def asyncTearDown(self):
-        await self.ledger.db.close()
-
-    async def test_sign(self):
-        account = Account.from_dict(
-            self.ledger, Wallet(), {
-                "seed":
-                    "carbon smart garage balance margin twelve chest sword toas"
-                    "t envelope bottom stomach absent"
-            }
-        )
-
-        await account.ensure_address_gap()
-        address1, address2 = await account.receiving.get_addresses(limit=2)
-        pubkey_hash1 = self.ledger.address_to_hash160(address1)
-        pubkey_hash2 = self.ledger.address_to_hash160(address2)
-
-        tx = Transaction() \
-            .add_inputs([Input.spend(get_output(int(2*COIN), pubkey_hash1))]) \
-            .add_outputs([Output.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)])
-
-        await tx.sign([account])
-
-        self.assertEqual(
-            hexlify(tx.inputs[0].script.values['signature']),
-            b'304402200dafa26ad7cf38c5a971c8a25ce7d85a076235f146126762296b1223c42ae21e022020ef9eeb8'
-            b'398327891008c5c0be4357683f12cb22346691ff23914f457bf679601'
-        )
-
-
-class TransactionIOBalancing(AsyncioTestCase):
-
-    async def asyncSetUp(self):
-        self.ledger = Ledger({
-            'db': Database('sqlite:///:memory:'),
-            'headers': Headers(':memory:')
-        })
-        await self.ledger.db.open()
-        self.account = Account.from_dict(
-            self.ledger, Wallet(), {
-                "seed": "carbon smart garage balance margin twelve chest sword "
-                        "toast envelope bottom stomach absent"
-            }
-        )
-
-        addresses = await self.account.ensure_address_gap()
-        self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
-        self.hash_cycler = cycle(self.pubkey_hash)
-
-    async def asyncTearDown(self):
-        await self.ledger.db.close()
-
-    def txo(self, amount, address=None):
-        return get_output(int(amount*COIN), address or next(self.hash_cycler))
-
-    def txi(self, txo):
-        return Input.spend(txo)
-
-    def tx(self, inputs, outputs):
-        return Transaction.create(inputs, outputs, [self.account], self.account)
-
-    async def create_utxos(self, amounts):
-        utxos = [self.txo(amount) for amount in amounts]
-
-        self.funding_tx = Transaction(is_verified=True) \
-            .add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
-            .add_outputs(utxos)
-
-        await self.ledger.db.insert_transaction(self.funding_tx)
-
-        for utxo in utxos:
-            await self.ledger.db.save_transaction_io(
-                self.funding_tx,
-                self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
-                utxo.script.values['pubkey_hash'], ''
-            )
-
-        return utxos
-
-    @staticmethod
-    def inputs(tx):
-        return [round(i.amount/COIN, 2) for i in tx.inputs]
-
-    @staticmethod
-    def outputs(tx):
-        return [round(o.amount/COIN, 2) for o in tx.outputs]
-
-    async def test_basic_use_cases(self):
-        self.ledger.fee_per_byte = int(.01*CENT)
-
-        # available UTXOs for filling missing inputs
-        utxos = await self.create_utxos([
-            1, 1, 3, 5, 10
-        ])
-
-        # pay 3 coins (3.02 w/ fees)
-        tx = await self.tx(
-            [],            # inputs
-            [self.txo(3)]  # outputs
-        )
-        # best UTXO match is 5 (as UTXO 3 will be short 0.02 to cover fees)
-        self.assertListEqual(self.inputs(tx), [5])
-        # a change of 1.98 is added to reach balance
-        self.assertListEqual(self.outputs(tx), [3, 1.98])
-
-        await self.ledger.release_outputs(utxos)
-
-        # pay 2.98 coins (3.00 w/ fees)
-        tx = await self.tx(
-            [],               # inputs
-            [self.txo(2.98)]  # outputs
-        )
-        # best UTXO match is 3 and no change is needed
-        self.assertListEqual(self.inputs(tx), [3])
-        self.assertListEqual(self.outputs(tx), [2.98])
-
-        await self.ledger.release_outputs(utxos)
-
-        # supplied input and output, but input is not enough to cover output
-        tx = await self.tx(
-            [self.txi(self.txo(10))],  # inputs
-            [self.txo(11)]             # outputs
-        )
-        # additional input is chosen (UTXO 3)
-        self.assertListEqual([10, 3], self.inputs(tx))
-        # change is now needed to consume extra input
-        self.assertListEqual([11, 1.96], self.outputs(tx))
-
-        await self.ledger.release_outputs(utxos)
-
-        # liquidating a UTXO
-        tx = await self.tx(
-            [self.txi(self.txo(10))],  # inputs
-            []                         # outputs
-        )
-        self.assertListEqual([10], self.inputs(tx))
-        # missing change added to consume the amount
-        self.assertListEqual([9.98], self.outputs(tx))
-
-        await self.ledger.release_outputs(utxos)
-
-        # liquidating at a loss, requires adding extra inputs
-        tx = await self.tx(
-            [self.txi(self.txo(0.01))],  # inputs
-            []                           # outputs
-        )
-        # UTXO 1 is added to cover some of the fee
-        self.assertListEqual([0.01, 1], self.inputs(tx))
-        # change is now needed to consume extra input
-        self.assertListEqual([0.97], self.outputs(tx))
diff --git a/tests/unit/blockchain/test_utils.py b/tests/unit/blockchain/test_utils.py
index b0933cb88..32fbd9f89 100644
--- a/tests/unit/blockchain/test_utils.py
+++ b/tests/unit/blockchain/test_utils.py
@@ -1,7 +1,7 @@
 import unittest
 
-from lbry.wallet.util import ArithUint256
-from lbry.wallet.util import coins_to_satoshis as c2s, satoshis_to_coins as s2c
+from lbry.blockchain.util import ArithUint256
+from lbry.blockchain.util import coins_to_satoshis as c2s, satoshis_to_coins as s2c
 
 
 class TestCoinValueParsing(unittest.TestCase):
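The utils tests carry over unchanged apart from the import move. coins_to_satoshis and satoshis_to_coins are fixed-point conversions between the string coin representation and integer satoshis; a behavior-level sketch follows, with signatures assumed rather than copied from lbry.blockchain.util:

    from decimal import Decimal

    COIN = 10 ** 8  # satoshis per coin

    def coins_to_satoshis(coins: str) -> int:
        # Decimal avoids binary-float rounding on amounts like '1.1'
        return int(Decimal(coins) * COIN)

    def satoshis_to_coins(satoshis: int) -> str:
        return str(Decimal(satoshis) / COIN)

    assert coins_to_satoshis('1.1') == 110_000_000
    assert satoshis_to_coins(110_000_000) == '1.1'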