diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index fb25207f7..ea7a4da6f 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -53,8 +53,7 @@ from lbry.extras.daemon.security import ensure_request_allowed from lbry.file_analysis import VideoFileAnalyzer from lbry.schema.claim import Claim from lbry.schema.url import URL, normalize_name -from lbry.wallet.server.db.elasticsearch.constants import RANGE_FIELDS, REPLACEMENTS -MY_RANGE_FIELDS = RANGE_FIELDS - {"limit_claims_per_channel"} + if typing.TYPE_CHECKING: from lbry.blob.blob_manager import BlobManager @@ -67,6 +66,29 @@ if typing.TYPE_CHECKING: log = logging.getLogger(__name__) +RANGE_FIELDS = { + 'height', 'creation_height', 'activation_height', 'expiration_height', + 'timestamp', 'creation_timestamp', 'duration', 'release_time', 'fee_amount', + 'tx_position', 'repost_count', 'limit_claims_per_channel', + 'amount', 'effective_amount', 'support_amount', + 'trending_score', 'censor_type', 'tx_num' +} +MY_RANGE_FIELDS = RANGE_FIELDS - {"limit_claims_per_channel"} +REPLACEMENTS = { + 'claim_name': 'normalized_name', + 'name': 'normalized_name', + 'txid': 'tx_id', + 'nout': 'tx_nout', + 'trending_group': 'trending_score', + 'trending_mixed': 'trending_score', + 'trending_global': 'trending_score', + 'trending_local': 'trending_score', + 'reposted': 'repost_count', + 'stream_types': 'stream_type', + 'media_types': 'media_type', + 'valid_channel_signature': 'is_signature_valid' +} + def is_transactional_function(name): for action in ('create', 'update', 'abandon', 'send', 'fund'): diff --git a/lbry/schema/result.py b/lbry/schema/result.py index 4d193133d..1b9ef083c 100644 --- a/lbry/schema/result.py +++ b/lbry/schema/result.py @@ -1,13 +1,11 @@ import base64 -from typing import List, TYPE_CHECKING, Union, Optional +from typing import List, Union, Optional, NamedTuple from binascii import hexlify from itertools import chain from lbry.error import ResolveCensoredError from lbry.schema.types.v2.result_pb2 import Outputs as OutputsMessage from lbry.schema.types.v2.result_pb2 import Error as ErrorMessage -if TYPE_CHECKING: - from lbry.wallet.server.leveldb import ResolveResult INVALID = ErrorMessage.Code.Name(ErrorMessage.INVALID) NOT_FOUND = ErrorMessage.Code.Name(ErrorMessage.NOT_FOUND) @@ -24,6 +22,31 @@ def set_reference(reference, claim_hash, rows): return +class ResolveResult(NamedTuple): + name: str + normalized_name: str + claim_hash: bytes + tx_num: int + position: int + tx_hash: bytes + height: int + amount: int + short_url: str + is_controlling: bool + canonical_url: str + creation_height: int + activation_height: int + expiration_height: int + effective_amount: int + support_amount: int + reposted: int + last_takeover_height: Optional[int] + claims_in_channel: Optional[int] + channel_hash: Optional[bytes] + reposted_claim_hash: Optional[bytes] + signature_valid: Optional[bool] + + class Censor: NOT_CENSORED = 0 diff --git a/lbry/testcase.py b/lbry/testcase.py index 6214553e5..f6520dcef 100644 --- a/lbry/testcase.py +++ b/lbry/testcase.py @@ -19,7 +19,7 @@ from lbry.conf import Config from lbry.wallet.util import satoshis_to_coins from lbry.wallet.dewies import lbc_to_dewies from lbry.wallet.orchstr8 import Conductor -from lbry.wallet.orchstr8.node import BlockchainNode, WalletNode, HubNode +from lbry.wallet.orchstr8.node import LBCWalletNode, WalletNode, HubNode from lbry.schema.claim import Claim from lbry.extras.daemon.daemon import Daemon, jsonrpc_dumps_pretty @@ -236,7 +236,7 @@ class IntegrationTestCase(AsyncioTestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.conductor: Optional[Conductor] = None - self.blockchain: Optional[BlockchainNode] = None + self.blockchain: Optional[LBCWalletNode] = None self.hub: Optional[HubNode] = None self.wallet_node: Optional[WalletNode] = None self.manager: Optional[WalletManager] = None @@ -246,15 +246,17 @@ class IntegrationTestCase(AsyncioTestCase): async def asyncSetUp(self): self.conductor = Conductor(seed=self.SEED) - await self.conductor.start_blockchain() - self.addCleanup(self.conductor.stop_blockchain) + await self.conductor.start_lbcd() + self.addCleanup(self.conductor.stop_lbcd) + await self.conductor.start_lbcwallet() + self.addCleanup(self.conductor.stop_lbcwallet) await self.conductor.start_spv() self.addCleanup(self.conductor.stop_spv) await self.conductor.start_wallet() self.addCleanup(self.conductor.stop_wallet) await self.conductor.start_hub() self.addCleanup(self.conductor.stop_hub) - self.blockchain = self.conductor.blockchain_node + self.blockchain = self.conductor.lbcwallet_node self.hub = self.conductor.hub_node self.wallet_node = self.conductor.wallet_node self.manager = self.wallet_node.manager @@ -269,6 +271,13 @@ class IntegrationTestCase(AsyncioTestCase): def broadcast(self, tx): return self.ledger.broadcast(tx) + async def broadcast_and_confirm(self, tx, ledger=None): + ledger = ledger or self.ledger + notifications = asyncio.create_task(ledger.wait(tx)) + await ledger.broadcast(tx) + await notifications + await self.generate_and_wait(1, [tx.id], ledger) + async def on_header(self, height): if self.ledger.headers.height < height: await self.ledger.on_header.where( @@ -276,11 +285,36 @@ class IntegrationTestCase(AsyncioTestCase): ) return True - def on_transaction_id(self, txid, ledger=None): - return (ledger or self.ledger).on_transaction.where( - lambda e: e.tx.id == txid + async def send_to_address_and_wait(self, address, amount, blocks_to_generate=0, ledger=None): + tx_watch = [] + txid = None + done = False + watcher = (ledger or self.ledger).on_transaction.where( + lambda e: e.tx.id == txid or done or tx_watch.append(e.tx.id) ) + txid = await self.blockchain.send_to_address(address, amount) + done = txid in tx_watch + await watcher + + await self.generate_and_wait(blocks_to_generate, [txid], ledger) + return txid + + async def generate_and_wait(self, blocks_to_generate, txids, ledger=None): + if blocks_to_generate > 0: + watcher = (ledger or self.ledger).on_transaction.where( + lambda e: ((e.tx.id in txids and txids.remove(e.tx.id)), len(txids) <= 0)[-1] # multi-statement lambda + ) + self.conductor.spv_node.server.synchronized.clear() + await self.blockchain.generate(blocks_to_generate) + height = self.blockchain.block_expected + await watcher + while True: + await self.conductor.spv_node.server.synchronized.wait() + self.conductor.spv_node.server.synchronized.clear() + if self.conductor.spv_node.server.db.db_height >= height: + break + def on_address_update(self, address): return self.ledger.on_transaction.where( lambda e: e.address == address @@ -291,6 +325,19 @@ class IntegrationTestCase(AsyncioTestCase): lambda e: e.tx.id == tx.id and e.address == address ) + async def generate(self, blocks): + """ Ask lbrycrd to generate some blocks and wait until ledger has them. """ + prepare = self.ledger.on_header.where(self.blockchain.is_expected_block) + height = self.blockchain.block_expected + self.conductor.spv_node.server.synchronized.clear() + await self.blockchain.generate(blocks) + await prepare # no guarantee that it didn't happen already, so start waiting from before calling generate + while True: + await self.conductor.spv_node.server.synchronized.wait() + self.conductor.spv_node.server.synchronized.clear() + if self.conductor.spv_node.server.db.db_height >= height: + break + class FakeExchangeRateManager(ExchangeRateManager): @@ -351,20 +398,19 @@ class CommandTestCase(IntegrationTestCase): self.skip_libtorrent = True async def asyncSetUp(self): - await super().asyncSetUp() logging.getLogger('lbry.blob_exchange').setLevel(self.VERBOSITY) logging.getLogger('lbry.daemon').setLevel(self.VERBOSITY) logging.getLogger('lbry.stream').setLevel(self.VERBOSITY) logging.getLogger('lbry.wallet').setLevel(self.VERBOSITY) + await super().asyncSetUp() + self.daemon = await self.add_daemon(self.wallet_node) await self.account.ensure_address_gap() address = (await self.account.receiving.get_addresses(limit=1, only_usable=True))[0] - sendtxid = await self.blockchain.send_to_address(address, 10) - await self.confirm_tx(sendtxid) - await self.generate(5) + await self.send_to_address_and_wait(address, 10, 6) server_tmp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, server_tmp_dir) @@ -461,9 +507,14 @@ class CommandTestCase(IntegrationTestCase): async def confirm_tx(self, txid, ledger=None): """ Wait for tx to be in mempool, then generate a block, wait for tx to be in a block. """ - await self.on_transaction_id(txid, ledger) - await self.generate(1) - await self.on_transaction_id(txid, ledger) + # await (ledger or self.ledger).on_transaction.where(lambda e: e.tx.id == txid) + on_tx = (ledger or self.ledger).on_transaction.where(lambda e: e.tx.id == txid) + await asyncio.wait([self.generate(1), on_tx], timeout=5) + + # # actually, if it's in the mempool or in the block we're fine + # await self.generate_and_wait(1, [txid], ledger=ledger) + # return txid + return txid async def on_transaction_dict(self, tx): @@ -478,12 +529,6 @@ class CommandTestCase(IntegrationTestCase): addresses.add(txo['address']) return list(addresses) - async def generate(self, blocks): - """ Ask lbrycrd to generate some blocks and wait until ledger has them. """ - prepare = self.ledger.on_header.where(self.blockchain.is_expected_block) - await self.blockchain.generate(blocks) - await prepare # no guarantee that it didn't happen already, so start waiting from before calling generate - async def blockchain_claim_name(self, name: str, value: str, amount: str, confirm=True): txid = await self.blockchain._cli_cmnd('claimname', name, value, amount) if confirm: @@ -514,7 +559,7 @@ class CommandTestCase(IntegrationTestCase): return self.sout(tx) return tx - async def create_nondeterministic_channel(self, name, price, pubkey_bytes, daemon=None): + async def create_nondeterministic_channel(self, name, price, pubkey_bytes, daemon=None, blocking=False): account = (daemon or self.daemon).wallet_manager.default_account claim_address = await account.receiving.get_or_create_usable_address() claim = Claim() @@ -524,7 +569,7 @@ class CommandTestCase(IntegrationTestCase): claim_address, [self.account], self.account ) await tx.sign([self.account]) - await (daemon or self.daemon).broadcast_or_release(tx, False) + await (daemon or self.daemon).broadcast_or_release(tx, blocking) return self.sout(tx) def create_upload_file(self, data, prefix=None, suffix=None): diff --git a/lbry/utils.py b/lbry/utils.py index 6a6cdd618..dc3a6c06e 100644 --- a/lbry/utils.py +++ b/lbry/utils.py @@ -405,7 +405,7 @@ async def fallback_get_external_ip(): # used if spv servers can't be used for i async def _get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]: # used if upnp is disabled or non-functioning - from lbry.wallet.server.udp import SPVStatusClientProtocol # pylint: disable=C0415 + from lbry.wallet.udp import SPVStatusClientProtocol # pylint: disable=C0415 hostname_to_ip = {} ip_to_hostnames = collections.defaultdict(list) diff --git a/lbry/wallet/__init__.py b/lbry/wallet/__init__.py index 5f2fffa21..4d99646da 100644 --- a/lbry/wallet/__init__.py +++ b/lbry/wallet/__init__.py @@ -1,17 +1,23 @@ -__node_daemon__ = 'lbrycrdd' -__node_cli__ = 'lbrycrd-cli' -__node_bin__ = '' -__node_url__ = ( - 'https://github.com/lbryio/lbrycrd/releases/download/v0.17.4.6/lbrycrd-linux-1746.zip' +__lbcd__ = 'lbcd' +__lbcctl__ = 'lbcctl' +__lbcwallet__ = 'lbcwallet' +__lbcd_url__ = ( + 'https://github.com/lbryio/lbcd/releases/download/' + + 'v0.22.200-beta/lbcd_0.22.200-beta_TARGET_PLATFORM.tar.gz' +) +__lbcwallet_url__ = ( + 'https://github.com/lbryio/lbcwallet/releases/download/' + + 'v0.13.100-alpha-rc2/lbcwallet_0.13.100-alpha-rc2_TARGET_PLATFORM.tar.gz' ) __spvserver__ = 'lbry.wallet.server.coin.LBCRegTest' -from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK -from .manager import WalletManager -from .network import Network -from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent -from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic, DeterministicChannelKeyManager -from .transaction import Transaction, Output, Input -from .script import OutputScript, InputScript -from .database import SQLiteMixin, Database -from .header import Headers +from lbry.wallet.wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK +from lbry.wallet.manager import WalletManager +from lbry.wallet.network import Network +from lbry.wallet.ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent +from lbry.wallet.account import Account, AddressManager, SingleKey, HierarchicalDeterministic, \ + DeterministicChannelKeyManager +from lbry.wallet.transaction import Transaction, Output, Input +from lbry.wallet.script import OutputScript, InputScript +from lbry.wallet.database import SQLiteMixin, Database +from lbry.wallet.header import Headers diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index 652c764d4..57e3ab0db 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -16,18 +16,18 @@ from lbry.crypto.hash import hash160, double_sha256, sha256 from lbry.crypto.base58 import Base58 from lbry.utils import LRUCacheWithMetrics -from .tasks import TaskGroup -from .database import Database -from .stream import StreamController -from .dewies import dewies_to_lbc -from .account import Account, AddressManager, SingleKey -from .network import Network -from .transaction import Transaction, Output -from .header import Headers, UnvalidatedHeaders -from .checkpoints import HASHES -from .constants import TXO_TYPES, CLAIM_TYPES, COIN, NULL_HASH32 -from .bip32 import PublicKey, PrivateKey -from .coinselection import CoinSelector +from lbry.wallet.tasks import TaskGroup +from lbry.wallet.database import Database +from lbry.wallet.stream import StreamController +from lbry.wallet.dewies import dewies_to_lbc +from lbry.wallet.account import Account, AddressManager, SingleKey +from lbry.wallet.network import Network +from lbry.wallet.transaction import Transaction, Output +from lbry.wallet.header import Headers, UnvalidatedHeaders +from lbry.wallet.checkpoints import HASHES +from lbry.wallet.constants import TXO_TYPES, CLAIM_TYPES, COIN, NULL_HASH32 +from lbry.wallet.bip32 import PublicKey, PrivateKey +from lbry.wallet.coinselection import CoinSelector log = logging.getLogger(__name__) @@ -365,6 +365,10 @@ class Ledger(metaclass=LedgerRegistry): await self.db.close() await self.headers.close() + async def tasks_are_done(self): + await self._update_tasks.done.wait() + await self._other_tasks.done.wait() + @property def local_height_including_downloaded_height(self): return max(self.headers.height, self._download_height) @@ -739,7 +743,7 @@ class Ledger(metaclass=LedgerRegistry): while timeout and (int(time.perf_counter()) - start) <= timeout: if await self._wait_round(tx, height, addresses): return - raise asyncio.TimeoutError('Timed out waiting for transaction.') + raise asyncio.TimeoutError(f'Timed out waiting for transaction. {tx.id}') async def _wait_round(self, tx: Transaction, height: int, addresses: Iterable[str]): records = await self.db.get_addresses(address__in=addresses) @@ -782,7 +786,7 @@ class Ledger(metaclass=LedgerRegistry): if hub_server: outputs = Outputs.from_grpc(encoded_outputs) else: - outputs = Outputs.from_base64(encoded_outputs or b'') # TODO: why is the server returning None? + outputs = Outputs.from_base64(encoded_outputs or '') # TODO: why is the server returning None? txs: List[Transaction] = [] if len(outputs.txs) > 0: async for tx in self.request_transactions(tuple(outputs.txs), cached=True): diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index d5b73aa21..adf6b3962 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -12,13 +12,13 @@ from typing import List, Type, MutableSequence, MutableMapping, Optional from lbry.error import KeyFeeAboveMaxAllowedError, WalletNotLoadedError from lbry.conf import Config, NOT_SET -from .dewies import dewies_to_lbc -from .account import Account -from .ledger import Ledger, LedgerRegistry -from .transaction import Transaction, Output -from .database import Database -from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK -from .rpc.jsonrpc import CodeMessageError +from lbry.wallet.dewies import dewies_to_lbc +from lbry.wallet.account import Account +from lbry.wallet.ledger import Ledger, LedgerRegistry +from lbry.wallet.transaction import Transaction, Output +from lbry.wallet.database import Database +from lbry.wallet.wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK +from lbry.wallet.rpc.jsonrpc import CodeMessageError if typing.TYPE_CHECKING: from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index 5f796bef5..27c57dd1b 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -16,7 +16,7 @@ from lbry.utils import resolve_host from lbry.error import IncompatibleWalletServerError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.stream import StreamController -from lbry.wallet.server.udp import SPVStatusClientProtocol, SPVPong +from lbry.wallet.udp import SPVStatusClientProtocol, SPVPong from lbry.conf import KnownHubsList log = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class ClientSession(BaseClientSession): await asyncio.sleep(max(0, max_idle - (now - self.last_send))) except Exception as err: if isinstance(err, asyncio.CancelledError): - log.warning("closing connection to %s:%i", *self.server) + log.info("closing connection to %s:%i", *self.server) else: log.exception("lost connection to spv") finally: @@ -140,7 +140,7 @@ class ClientSession(BaseClientSession): controller.add(request.args) def connection_lost(self, exc): - log.warning("Connection lost: %s:%d", *self.server) + log.debug("Connection lost: %s:%d", *self.server) super().connection_lost(exc) self.response_time = None self.connection_latency = None @@ -303,7 +303,7 @@ class Network: concurrency=self.config.get('concurrent_hub_requests', 30)) try: await client.create_connection() - log.warning("Connected to spv server %s:%i", host, port) + log.info("Connected to spv server %s:%i", host, port) await client.ensure_server_version() return client except (asyncio.TimeoutError, ConnectionError, OSError, IncompatibleWalletServerError, RPCError): @@ -357,7 +357,7 @@ class Network: self._keepalive_task = None self.client = None self.server_features = None - log.warning("connection lost to %s", server_str) + log.info("connection lost to %s", server_str) log.info("network loop finished") async def stop(self): diff --git a/lbry/wallet/orchstr8/__init__.py b/lbry/wallet/orchstr8/__init__.py index 72791f2a3..5827bb95e 100644 --- a/lbry/wallet/orchstr8/__init__.py +++ b/lbry/wallet/orchstr8/__init__.py @@ -1,5 +1,5 @@ __hub_url__ = ( "https://github.com/lbryio/hub/releases/download/v0.2022.01.21.1/hub" ) -from .node import Conductor -from .service import ConductorService +from lbry.wallet.orchstr8.node import Conductor +from lbry.wallet.orchstr8.service import ConductorService diff --git a/lbry/wallet/orchstr8/cli.py b/lbry/wallet/orchstr8/cli.py index ee4ddc60c..f75216882 100644 --- a/lbry/wallet/orchstr8/cli.py +++ b/lbry/wallet/orchstr8/cli.py @@ -5,7 +5,9 @@ import aiohttp from lbry import wallet from lbry.wallet.orchstr8.node import ( - Conductor, get_blockchain_node_from_ledger + Conductor, + get_lbcd_node_from_ledger, + get_lbcwallet_node_from_ledger ) from lbry.wallet.orchstr8.service import ConductorService @@ -16,10 +18,11 @@ def get_argument_parser(): ) subparsers = parser.add_subparsers(dest='command', help='sub-command help') - subparsers.add_parser("download", help="Download blockchain node binary.") + subparsers.add_parser("download", help="Download lbcd and lbcwallet node binaries.") start = subparsers.add_parser("start", help="Start orchstr8 service.") - start.add_argument("--blockchain", help="Hostname to start blockchain node.") + start.add_argument("--lbcd", help="Hostname to start lbcd node.") + start.add_argument("--lbcwallet", help="Hostname to start lbcwallet node.") start.add_argument("--spv", help="Hostname to start SPV server.") start.add_argument("--wallet", help="Hostname to start wallet daemon.") @@ -47,7 +50,8 @@ def main(): if command == 'download': logging.getLogger('blockchain').setLevel(logging.INFO) - get_blockchain_node_from_ledger(wallet).ensure() + get_lbcd_node_from_ledger(wallet).ensure() + get_lbcwallet_node_from_ledger(wallet).ensure() elif command == 'generate': loop.run_until_complete(run_remote_command( @@ -57,9 +61,12 @@ def main(): elif command == 'start': conductor = Conductor() - if getattr(args, 'blockchain', False): - conductor.blockchain_node.hostname = args.blockchain - loop.run_until_complete(conductor.start_blockchain()) + if getattr(args, 'lbcd', False): + conductor.lbcd_node.hostname = args.lbcd + loop.run_until_complete(conductor.start_lbcd()) + if getattr(args, 'lbcwallet', False): + conductor.lbcwallet_node.hostname = args.lbcwallet + loop.run_until_complete(conductor.start_lbcwallet()) if getattr(args, 'spv', False): conductor.spv_node.hostname = args.spv loop.run_until_complete(conductor.start_spv()) diff --git a/lbry/wallet/orchstr8/node.py b/lbry/wallet/orchstr8/node.py index 572331629..640a74cfb 100644 --- a/lbry/wallet/orchstr8/node.py +++ b/lbry/wallet/orchstr8/node.py @@ -1,4 +1,6 @@ +# pylint: disable=import-error import os +import signal import json import shutil import asyncio @@ -7,7 +9,7 @@ import tarfile import logging import tempfile import subprocess -import importlib +import platform from distutils.util import strtobool from binascii import hexlify @@ -15,9 +17,15 @@ from typing import Type, Optional import urllib.request from uuid import uuid4 +try: + from scribe.env import Env + from scribe.hub.service import HubServerService + from scribe.elasticsearch.service import ElasticSyncService + from scribe.blockchain.service import BlockchainProcessorService +except ImportError: + pass + import lbry -from lbry.wallet.server.server import Server -from lbry.wallet.server.env import Env from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent from lbry.conf import KnownHubsList, Config from lbry.wallet.orchstr8 import __hub_url__ @@ -25,17 +33,19 @@ from lbry.wallet.orchstr8 import __hub_url__ log = logging.getLogger(__name__) -def get_spvserver_from_ledger(ledger_module): - spvserver_path, regtest_class_name = ledger_module.__spvserver__.rsplit('.', 1) - spvserver_module = importlib.import_module(spvserver_path) - return getattr(spvserver_module, regtest_class_name) +def get_lbcd_node_from_ledger(ledger_module): + return LBCDNode( + ledger_module.__lbcd_url__, + ledger_module.__lbcd__, + ledger_module.__lbcctl__ + ) -def get_blockchain_node_from_ledger(ledger_module): - return BlockchainNode( - ledger_module.__node_url__, - os.path.join(ledger_module.__node_bin__, ledger_module.__node_daemon__), - os.path.join(ledger_module.__node_bin__, ledger_module.__node_cli__) +def get_lbcwallet_node_from_ledger(ledger_module): + return LBCWalletNode( + ledger_module.__lbcwallet_url__, + ledger_module.__lbcwallet__, + ledger_module.__lbcctl__ ) @@ -43,53 +53,51 @@ class Conductor: def __init__(self, seed=None): self.manager_module = WalletManager - self.spv_module = get_spvserver_from_ledger(lbry.wallet) - - self.blockchain_node = get_blockchain_node_from_ledger(lbry.wallet) - self.spv_node = SPVNode(self.spv_module) + self.lbcd_node = get_lbcd_node_from_ledger(lbry.wallet) + self.lbcwallet_node = get_lbcwallet_node_from_ledger(lbry.wallet) + self.spv_node = SPVNode() self.wallet_node = WalletNode( self.manager_module, RegTestLedger, default_seed=seed ) self.hub_node = HubNode(__hub_url__, "hub", self.spv_node) - self.blockchain_started = False + self.lbcd_started = False + self.lbcwallet_started = False self.spv_started = False self.wallet_started = False self.hub_started = False self.log = log.getChild('conductor') - async def start_blockchain(self): - if not self.blockchain_started: - asyncio.create_task(self.blockchain_node.start()) - await self.blockchain_node.running.wait() - await self.blockchain_node.generate(200) - self.blockchain_started = True + async def start_lbcd(self): + if not self.lbcd_started: + await self.lbcd_node.start() + self.lbcd_started = True - async def stop_blockchain(self): - if self.blockchain_started: - await self.blockchain_node.stop(cleanup=True) - self.blockchain_started = False + async def stop_lbcd(self, cleanup=True): + if self.lbcd_started: + await self.lbcd_node.stop(cleanup) + self.lbcd_started = False async def start_hub(self): if not self.hub_started: - asyncio.create_task(self.hub_node.start()) - await self.blockchain_node.running.wait() + await self.hub_node.start() + await self.lbcwallet_node.running.wait() self.hub_started = True - async def stop_hub(self): + async def stop_hub(self, cleanup=True): if self.hub_started: - await self.hub_node.stop(cleanup=True) + await self.hub_node.stop(cleanup) self.hub_started = False async def start_spv(self): if not self.spv_started: - await self.spv_node.start(self.blockchain_node) + await self.spv_node.start(self.lbcwallet_node) self.spv_started = True - async def stop_spv(self): + async def stop_spv(self, cleanup=True): if self.spv_started: - await self.spv_node.stop(cleanup=True) + await self.spv_node.stop(cleanup) self.spv_started = False async def start_wallet(self): @@ -97,21 +105,41 @@ class Conductor: await self.wallet_node.start(self.spv_node) self.wallet_started = True - async def stop_wallet(self): + async def stop_wallet(self, cleanup=True): if self.wallet_started: - await self.wallet_node.stop(cleanup=True) + await self.wallet_node.stop(cleanup) self.wallet_started = False + async def start_lbcwallet(self, clean=True): + if not self.lbcwallet_started: + await self.lbcwallet_node.start() + if clean: + mining_addr = await self.lbcwallet_node.get_new_address() + self.lbcwallet_node.mining_addr = mining_addr + await self.lbcwallet_node.generate(200) + # unlock the wallet for the next 1 hour + await self.lbcwallet_node.wallet_passphrase("password", 3600) + self.lbcwallet_started = True + + async def stop_lbcwallet(self, cleanup=True): + if self.lbcwallet_started: + await self.lbcwallet_node.stop(cleanup) + self.lbcwallet_started = False + async def start(self): - await self.start_blockchain() + await self.start_lbcd() + await self.start_lbcwallet() await self.start_spv() + await self.start_hub() await self.start_wallet() async def stop(self): all_the_stops = [ self.stop_wallet, + self.stop_hub, self.stop_spv, - self.stop_blockchain + self.stop_lbcwallet, + self.stop_lbcd ] for stop in all_the_stops: try: @@ -119,6 +147,12 @@ class Conductor: except Exception as e: log.exception('Exception raised while stopping services:', exc_info=e) + async def clear_mempool(self): + await self.stop_lbcwallet(cleanup=False) + await self.stop_lbcd(cleanup=False) + await self.start_lbcd() + await self.start_lbcwallet(clean=False) + class WalletNode: @@ -139,10 +173,11 @@ class WalletNode: async def start(self, spv_node: 'SPVNode', seed=None, connect=True, config=None): wallets_dir = os.path.join(self.data_path, 'wallets') - os.mkdir(wallets_dir) wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json') - with open(wallet_file_name, 'w') as wallet_file: - wallet_file.write('{"version": 1, "accounts": []}\n') + if not os.path.isdir(wallets_dir): + os.mkdir(wallets_dir) + with open(wallet_file_name, 'w') as wallet_file: + wallet_file.write('{"version": 1, "accounts": []}\n') self.manager = self.manager_class.from_config({ 'ledgers': { self.ledger_class.get_id(): { @@ -184,55 +219,72 @@ class WalletNode: class SPVNode: - - def __init__(self, coin_class, node_number=1): - self.coin_class = coin_class + def __init__(self, node_number=1): + self.node_number = node_number self.controller = None self.data_path = None - self.server = None + self.server: Optional[HubServerService] = None + self.writer: Optional[BlockchainProcessorService] = None + self.es_writer: Optional[ElasticSyncService] = None self.hostname = 'localhost' self.port = 50001 + node_number # avoid conflict with default daemon self.udp_port = self.port + self.elastic_notifier_port = 19080 + node_number self.session_timeout = 600 - self.rpc_port = '0' # disabled by default - self.stopped = False + self.stopped = True self.index_name = uuid4().hex - async def start(self, blockchain_node: 'BlockchainNode', extraconf=None): - self.data_path = tempfile.mkdtemp() - conf = { - 'DESCRIPTION': '', - 'PAYMENT_ADDRESS': '', - 'DAILY_FEE': '0', - 'DB_DIRECTORY': self.data_path, - 'DAEMON_URL': blockchain_node.rpc_url, - 'REORG_LIMIT': '100', - 'HOST': self.hostname, - 'TCP_PORT': str(self.port), - 'UDP_PORT': str(self.udp_port), - 'SESSION_TIMEOUT': str(self.session_timeout), - 'MAX_QUERY_WORKERS': '0', - 'INDIVIDUAL_TAG_INDEXES': '', - 'RPC_PORT': self.rpc_port, - 'ES_INDEX_PREFIX': self.index_name, - 'ES_MODE': 'writer', - } - if extraconf: - conf.update(extraconf) - # TODO: don't use os.environ - os.environ.update(conf) - self.server = Server(Env(self.coin_class)) - self.server.bp.mempool.refresh_secs = self.server.bp.prefetcher.polling_delay = 0.5 - await self.server.start() + async def start(self, lbcwallet_node: 'LBCWalletNode', extraconf=None): + if not self.stopped: + log.warning("spv node is already running") + return + self.stopped = False + try: + self.data_path = tempfile.mkdtemp() + conf = { + 'description': '', + 'payment_address': '', + 'daily_fee': '0', + 'db_dir': self.data_path, + 'daemon_url': lbcwallet_node.rpc_url, + 'reorg_limit': 100, + 'host': self.hostname, + 'tcp_port': self.port, + 'udp_port': self.udp_port, + 'elastic_notifier_port': self.elastic_notifier_port, + 'session_timeout': self.session_timeout, + 'max_query_workers': 0, + 'es_index_prefix': self.index_name, + 'chain': 'regtest' + } + if extraconf: + conf.update(extraconf) + env = Env(**conf) + self.writer = BlockchainProcessorService(env) + self.server = HubServerService(env) + self.es_writer = ElasticSyncService(env) + await self.writer.start() + await self.es_writer.start() + await self.server.start() + except Exception as e: + self.stopped = True + if not isinstance(e, asyncio.CancelledError): + log.exception("failed to start spv node") + raise e async def stop(self, cleanup=True): if self.stopped: + log.warning("spv node is already stopped") return try: - await self.server.db.search_index.delete_index() - await self.server.db.search_index.stop() await self.server.stop() + await self.es_writer.delete_index() + await self.es_writer.stop() + await self.writer.stop() self.stopped = True + except Exception as e: + log.exception("failed to stop spv node") + raise e finally: cleanup and self.cleanup() @@ -240,18 +292,19 @@ class SPVNode: shutil.rmtree(self.data_path, ignore_errors=True) -class BlockchainProcess(asyncio.SubprocessProtocol): +class LBCDProcess(asyncio.SubprocessProtocol): IGNORE_OUTPUT = [ b'keypool keep', b'keypool reserve', b'keypool return', + b'Block submitted', ] def __init__(self): self.ready = asyncio.Event() self.stopped = asyncio.Event() - self.log = log.getChild('blockchain') + self.log = log.getChild('lbcd') def pipe_data_received(self, fd, data): if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT): @@ -262,7 +315,7 @@ class BlockchainProcess(asyncio.SubprocessProtocol): if b'Error:' in data: self.ready.set() raise SystemError(data.decode()) - if b'Done loading' in data: + if b'RPCS: RPC server listening on' in data: self.ready.set() def process_exited(self): @@ -270,39 +323,57 @@ class BlockchainProcess(asyncio.SubprocessProtocol): self.ready.set() -class BlockchainNode: +class WalletProcess(asyncio.SubprocessProtocol): - P2SH_SEGWIT_ADDRESS = "p2sh-segwit" - BECH32_ADDRESS = "bech32" + IGNORE_OUTPUT = [ + ] + def __init__(self): + self.ready = asyncio.Event() + self.stopped = asyncio.Event() + self.log = log.getChild('lbcwallet') + self.transport: Optional[asyncio.transports.SubprocessTransport] = None + + def pipe_data_received(self, fd, data): + if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT): + if b'Error:' in data: + self.log.error(data.decode()) + else: + self.log.info(data.decode()) + if b'Error:' in data: + self.ready.set() + raise SystemError(data.decode()) + if b'WLLT: Finished rescan' in data: + self.ready.set() + + def process_exited(self): + self.stopped.set() + self.ready.set() + + +class LBCDNode: def __init__(self, url, daemon, cli): self.latest_release_url = url self.project_dir = os.path.dirname(os.path.dirname(__file__)) self.bin_dir = os.path.join(self.project_dir, 'bin') self.daemon_bin = os.path.join(self.bin_dir, daemon) self.cli_bin = os.path.join(self.bin_dir, cli) - self.log = log.getChild('blockchain') - self.data_path = None + self.log = log.getChild('lbcd') + self.data_path = tempfile.mkdtemp() self.protocol = None self.transport = None - self.block_expected = 0 self.hostname = 'localhost' - self.peerport = 9246 + 2 # avoid conflict with default peer port - self.rpcport = 9245 + 2 # avoid conflict with default rpc port + self.peerport = 29246 + self.rpcport = 29245 self.rpcuser = 'rpcuser' self.rpcpassword = 'rpcpassword' - self.stopped = False - self.restart_ready = asyncio.Event() - self.restart_ready.set() + self.stopped = True self.running = asyncio.Event() @property def rpc_url(self): return f'http://{self.rpcuser}:{self.rpcpassword}@{self.hostname}:{self.rpcport}/' - def is_expected_block(self, e: BlockHeightEvent): - return self.block_expected == e.height - @property def exists(self): return ( @@ -311,6 +382,12 @@ class BlockchainNode: ) def download(self): + uname = platform.uname() + target_os = str.lower(uname.system) + target_arch = str.replace(uname.machine, 'x86_64', 'amd64') + target_platform = target_os + '_' + target_arch + self.latest_release_url = str.replace(self.latest_release_url, 'TARGET_PLATFORM', target_platform) + downloaded_file = os.path.join( self.bin_dir, self.latest_release_url[self.latest_release_url.rfind('/')+1:] @@ -344,72 +421,206 @@ class BlockchainNode: return self.exists or self.download() async def start(self): - assert self.ensure() - self.data_path = tempfile.mkdtemp() - loop = asyncio.get_event_loop() - asyncio.get_child_watcher().attach_loop(loop) - command = [ - self.daemon_bin, - f'-datadir={self.data_path}', '-printtoconsole', '-regtest', '-server', '-txindex', - f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}', - f'-port={self.peerport}' - ] - self.log.info(' '.join(command)) - while not self.stopped: - if self.running.is_set(): - await asyncio.sleep(1) - continue - await self.restart_ready.wait() - try: - self.transport, self.protocol = await loop.subprocess_exec( - BlockchainProcess, *command - ) - await self.protocol.ready.wait() - assert not self.protocol.stopped.is_set() - self.running.set() - except asyncio.CancelledError: - self.running.clear() - raise - except Exception as e: - self.running.clear() - log.exception('failed to start lbrycrdd', exc_info=e) + if not self.stopped: + return + self.stopped = False + try: + assert self.ensure() + loop = asyncio.get_event_loop() + asyncio.get_child_watcher().attach_loop(loop) + command = [ + self.daemon_bin, + '--notls', + f'--datadir={self.data_path}', + '--regtest', f'--listen=127.0.0.1:{self.peerport}', f'--rpclisten=127.0.0.1:{self.rpcport}', + '--txindex', f'--rpcuser={self.rpcuser}', f'--rpcpass={self.rpcpassword}' + ] + self.log.info(' '.join(command)) + self.transport, self.protocol = await loop.subprocess_exec( + LBCDProcess, *command + ) + await self.protocol.ready.wait() + assert not self.protocol.stopped.is_set() + self.running.set() + except asyncio.CancelledError: + self.running.clear() + self.stopped = True + raise + except Exception as e: + self.running.clear() + self.stopped = True + log.exception('failed to start lbcd', exc_info=e) + raise async def stop(self, cleanup=True): + if self.stopped: + return + try: + if self.transport: + self.transport.terminate() + await self.protocol.stopped.wait() + self.transport.close() + except Exception as e: + log.exception('failed to stop lbcd', exc_info=e) + raise + finally: + self.log.info("Done shutting down " + self.daemon_bin) + self.stopped = True + if cleanup: + self.cleanup() + self.running.clear() + + def cleanup(self): + assert self.stopped + shutil.rmtree(self.data_path, ignore_errors=True) + + +class LBCWalletNode: + P2SH_SEGWIT_ADDRESS = "p2sh-segwit" + BECH32_ADDRESS = "bech32" + + def __init__(self, url, lbcwallet, cli): + self.latest_release_url = url + self.project_dir = os.path.dirname(os.path.dirname(__file__)) + self.bin_dir = os.path.join(self.project_dir, 'bin') + self.lbcwallet_bin = os.path.join(self.bin_dir, lbcwallet) + self.cli_bin = os.path.join(self.bin_dir, cli) + self.log = log.getChild('lbcwallet') + self.protocol = None + self.transport = None + self.hostname = 'localhost' + self.lbcd_rpcport = 29245 + self.lbcwallet_rpcport = 29244 + self.rpcuser = 'rpcuser' + self.rpcpassword = 'rpcpassword' + self.data_path = tempfile.mkdtemp() self.stopped = True + self.running = asyncio.Event() + self.block_expected = 0 + self.mining_addr = '' + + @property + def rpc_url(self): + # FIXME: somehow the hub/sdk doesn't learn the blocks through the Walet RPC port, why? + # return f'http://{self.rpcuser}:{self.rpcpassword}@{self.hostname}:{self.lbcwallet_rpcport}/' + return f'http://{self.rpcuser}:{self.rpcpassword}@{self.hostname}:{self.lbcd_rpcport}/' + + def is_expected_block(self, e: BlockHeightEvent): + return self.block_expected == e.height + + @property + def exists(self): + return ( + os.path.exists(self.lbcwallet_bin) + ) + + def download(self): + uname = platform.uname() + target_os = str.lower(uname.system) + target_arch = str.replace(uname.machine, 'x86_64', 'amd64') + target_platform = target_os + '_' + target_arch + self.latest_release_url = str.replace(self.latest_release_url, 'TARGET_PLATFORM', target_platform) + + downloaded_file = os.path.join( + self.bin_dir, + self.latest_release_url[self.latest_release_url.rfind('/')+1:] + ) + + if not os.path.exists(self.bin_dir): + os.mkdir(self.bin_dir) + + if not os.path.exists(downloaded_file): + self.log.info('Downloading: %s', self.latest_release_url) + with urllib.request.urlopen(self.latest_release_url) as response: + with open(downloaded_file, 'wb') as out_file: + shutil.copyfileobj(response, out_file) + + self.log.info('Extracting: %s', downloaded_file) + + if downloaded_file.endswith('.zip'): + with zipfile.ZipFile(downloaded_file) as dotzip: + dotzip.extractall(self.bin_dir) + # zipfile bug https://bugs.python.org/issue15795 + os.chmod(self.lbcwallet_bin, 0o755) + + elif downloaded_file.endswith('.tar.gz'): + with tarfile.open(downloaded_file) as tar: + tar.extractall(self.bin_dir) + + return self.exists + + def ensure(self): + return self.exists or self.download() + + async def start(self): + assert self.ensure() + loop = asyncio.get_event_loop() + asyncio.get_child_watcher().attach_loop(loop) + + command = [ + self.lbcwallet_bin, + '--noservertls', '--noclienttls', + '--regtest', + f'--rpcconnect=127.0.0.1:{self.lbcd_rpcport}', f'--rpclisten=127.0.0.1:{self.lbcwallet_rpcport}', + '--createtemp', f'--appdata={self.data_path}', + f'--username={self.rpcuser}', f'--password={self.rpcpassword}' + ] + self.log.info(' '.join(command)) + try: + self.transport, self.protocol = await loop.subprocess_exec( + WalletProcess, *command + ) + self.protocol.transport = self.transport + await self.protocol.ready.wait() + assert not self.protocol.stopped.is_set() + self.running.set() + self.stopped = False + except asyncio.CancelledError: + self.running.clear() + raise + except Exception as e: + self.running.clear() + log.exception('failed to start lbcwallet', exc_info=e) + + def cleanup(self): + assert self.stopped + shutil.rmtree(self.data_path, ignore_errors=True) + + async def stop(self, cleanup=True): + if self.stopped: + return try: self.transport.terminate() await self.protocol.stopped.wait() self.transport.close() + except Exception as e: + log.exception('failed to stop lbcwallet', exc_info=e) + raise finally: + self.log.info("Done shutting down " + self.lbcwallet_bin) + self.stopped = True if cleanup: self.cleanup() - - async def clear_mempool(self): - self.restart_ready.clear() - self.transport.terminate() - await self.protocol.stopped.wait() - self.transport.close() - self.running.clear() - os.remove(os.path.join(self.data_path, 'regtest', 'mempool.dat')) - self.restart_ready.set() - await self.running.wait() - - def cleanup(self): - shutil.rmtree(self.data_path, ignore_errors=True) + self.running.clear() async def _cli_cmnd(self, *args): cmnd_args = [ - self.cli_bin, f'-datadir={self.data_path}', '-regtest', - f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}' + self.cli_bin, + f'--rpcuser={self.rpcuser}', f'--rpcpass={self.rpcpassword}', '--notls', '--regtest', '--wallet' ] + list(args) self.log.info(' '.join(cmnd_args)) loop = asyncio.get_event_loop() asyncio.get_child_watcher().attach_loop(loop) process = await asyncio.create_subprocess_exec( - *cmnd_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + *cmnd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - out, _ = await process.communicate() + out, err = await process.communicate() result = out.decode().strip() + err = err.decode().strip() + if len(result) <= 0 and err.startswith('-'): + raise Exception(err) + if err and 'creating a default config file' not in err: + log.warning(err) self.log.info(result) if result.startswith('error code'): raise Exception(result) @@ -417,7 +628,14 @@ class BlockchainNode: def generate(self, blocks): self.block_expected += blocks - return self._cli_cmnd('generate', str(blocks)) + return self._cli_cmnd('generatetoaddress', str(blocks), self.mining_addr) + + def generate_to_address(self, blocks, addr): + self.block_expected += blocks + return self._cli_cmnd('generatetoaddress', str(blocks), addr) + + def wallet_passphrase(self, passphrase, timeout): + return self._cli_cmnd('walletpassphrase', passphrase, str(timeout)) def invalidate_block(self, blockhash): return self._cli_cmnd('invalidateblock', blockhash) @@ -434,7 +652,7 @@ class BlockchainNode: def get_raw_change_address(self): return self._cli_cmnd('getrawchangeaddress') - def get_new_address(self, address_type): + def get_new_address(self, address_type='legacy'): return self._cli_cmnd('getnewaddress', "", address_type) async def get_balance(self): @@ -450,7 +668,10 @@ class BlockchainNode: return self._cli_cmnd('createrawtransaction', json.dumps(inputs), json.dumps(outputs)) async def sign_raw_transaction_with_wallet(self, tx): - return json.loads(await self._cli_cmnd('signrawtransactionwithwallet', tx))['hex'].encode() + # the "withwallet" portion should only come into play if we are doing segwit. + # and "withwallet" doesn't exist on lbcd yet. + result = await self._cli_cmnd('signrawtransaction', tx) + return json.loads(result)['hex'].encode() def decode_raw_transaction(self, tx): return self._cli_cmnd('decoderawtransaction', hexlify(tx.raw).decode()) @@ -460,12 +681,15 @@ class BlockchainNode: class HubProcess(asyncio.SubprocessProtocol): - def __init__(self): - self.ready = asyncio.Event() - self.stopped = asyncio.Event() + def __init__(self, ready, stopped): + self.ready = ready + self.stopped = stopped self.log = log.getChild('hub') + self.transport = None def pipe_data_received(self, fd, data): + self.stopped.clear() + self.ready.set() if self.log: self.log.info(data.decode()) if b'error' in data.lower(): @@ -479,16 +703,26 @@ class HubProcess(asyncio.SubprocessProtocol): print(line) def process_exited(self): + self.ready.clear() self.stopped.set() - self.ready.set() + + async def stop(self): + t = asyncio.create_task(self.stopped.wait()) + try: + self.transport.send_signal(signal.SIGINT) + await asyncio.wait_for(t, 3) + # log.warning("stopped go hub") + except asyncio.TimeoutError: + if not t.done(): + t.cancel() + self.transport.terminate() + await self.stopped.wait() + log.warning("terminated go hub") class HubNode: - def __init__(self, url, daemon, spv_node): self.spv_node = spv_node - self.debug = False - self.latest_release_url = url self.project_dir = os.path.dirname(os.path.dirname(__file__)) self.bin_dir = os.path.join(self.project_dir, 'bin') @@ -499,11 +733,13 @@ class HubNode: self.protocol = None self.hostname = 'localhost' self.rpcport = 50051 # avoid conflict with default rpc port - self.stopped = False - self.restart_ready = asyncio.Event() - self.restart_ready.set() + self._stopped = asyncio.Event() self.running = asyncio.Event() + @property + def stopped(self): + return not self.running.is_set() + @property def exists(self): return ( @@ -554,33 +790,24 @@ class HubNode: self.daemon_bin, 'serve', '--esindex', self.spv_node.index_name + 'claims', '--debug' ] self.log.info(' '.join(command)) - while not self.stopped: - if self.running.is_set(): - await asyncio.sleep(1) - continue - await self.restart_ready.wait() - try: - if not self.debug: - self.transport, self.protocol = await loop.subprocess_exec( - HubProcess, *command - ) - await self.protocol.ready.wait() - assert not self.protocol.stopped.is_set() - self.running.set() - except asyncio.CancelledError: - self.running.clear() - raise - except Exception as e: - self.running.clear() - log.exception('failed to start hub', exc_info=e) + self.protocol = HubProcess(self.running, self._stopped) + try: + self.transport, _ = await loop.subprocess_exec( + lambda: self.protocol, *command + ) + self.protocol.transport = self.transport + except Exception as e: + log.exception('failed to start go hub', exc_info=e) + raise e + await self.protocol.ready.wait() async def stop(self, cleanup=True): - self.stopped = True try: - if not self.debug: - self.transport.terminate() - await self.protocol.stopped.wait() - self.transport.close() + if self.protocol: + await self.protocol.stop() + except Exception as e: + log.exception('failed to stop go hub', exc_info=e) + raise e finally: if cleanup: self.cleanup() diff --git a/lbry/wallet/orchstr8/service.py b/lbry/wallet/orchstr8/service.py index 495f68a07..fac3e49ea 100644 --- a/lbry/wallet/orchstr8/service.py +++ b/lbry/wallet/orchstr8/service.py @@ -61,8 +61,10 @@ class ConductorService: #set_logging( # self.stack.ledger_module, logging.DEBUG, WebSocketLogHandler(self.send_message) #) - self.stack.blockchain_started or await self.stack.start_blockchain() - self.send_message({'type': 'service', 'name': 'blockchain', 'port': self.stack.blockchain_node.port}) + self.stack.lbcd_started or await self.stack.start_lbcd() + self.send_message({'type': 'service', 'name': 'lbcd', 'port': self.stack.lbcd_node.port}) + self.stack.lbcwallet_started or await self.stack.start_lbcwallet() + self.send_message({'type': 'service', 'name': 'lbcwallet', 'port': self.stack.lbcwallet_node.port}) self.stack.spv_started or await self.stack.start_spv() self.send_message({'type': 'service', 'name': 'spv', 'port': self.stack.spv_node.port}) self.stack.wallet_started or await self.stack.start_wallet() @@ -74,7 +76,7 @@ class ConductorService: async def generate(self, request): data = await request.post() blocks = data.get('blocks', 1) - await self.stack.blockchain_node.generate(int(blocks)) + await self.stack.lbcwallet_node.generate(int(blocks)) return json_response({'blocks': blocks}) async def transfer(self, request): @@ -85,11 +87,14 @@ class ConductorService: if not address: raise ValueError("No address was provided.") amount = data.get('amount', 1) - txid = await self.stack.blockchain_node.send_to_address(address, amount) if self.stack.wallet_started: - await self.stack.wallet_node.ledger.on_transaction.where( - lambda e: e.tx.id == txid and e.address == address + watcher = self.stack.wallet_node.ledger.on_transaction.where( + lambda e: e.address == address # and e.tx.id == txid -- might stall; see send_to_address_and_wait ) + txid = await self.stack.lbcwallet_node.send_to_address(address, amount) + await watcher + else: + txid = await self.stack.lbcwallet_node.send_to_address(address, amount) return json_response({ 'address': address, 'amount': amount, @@ -98,7 +103,7 @@ class ConductorService: async def balance(self, _): return json_response({ - 'balance': await self.stack.blockchain_node.get_balance() + 'balance': await self.stack.lbcwallet_node.get_balance() }) async def log(self, request): @@ -129,7 +134,7 @@ class ConductorService: 'type': 'status', 'height': self.stack.wallet_node.ledger.headers.height, 'balance': satoshis_to_coins(await self.stack.wallet_node.account.get_balance()), - 'miner': await self.stack.blockchain_node.get_balance() + 'miner': await self.stack.lbcwallet_node.get_balance() }) def send_message(self, msg): diff --git a/lbry/wallet/server/__init__.py b/lbry/wallet/server/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lbry/wallet/server/block_processor.py b/lbry/wallet/server/block_processor.py deleted file mode 100644 index bb233e2d5..000000000 --- a/lbry/wallet/server/block_processor.py +++ /dev/null @@ -1,1787 +0,0 @@ -import time -import asyncio -import typing -from bisect import bisect_right -from struct import pack, unpack -from concurrent.futures.thread import ThreadPoolExecutor -from typing import Optional, List, Tuple, Set, DefaultDict, Dict, NamedTuple -from prometheus_client import Gauge, Histogram -from collections import defaultdict - -import lbry -from lbry.schema.url import URL -from lbry.schema.claim import Claim -from lbry.wallet.ledger import Ledger, TestNetLedger, RegTestLedger -from lbry.utils import LRUCache -from lbry.wallet.transaction import OutputScript, Output, Transaction -from lbry.wallet.server.tx import Tx, TxOutput, TxInput -from lbry.wallet.server.daemon import DaemonError -from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN -from lbry.wallet.server.util import chunks, class_logger -from lbry.crypto.hash import hash160 -from lbry.wallet.server.mempool import MemPool -from lbry.wallet.server.db.prefixes import ACTIVATED_SUPPORT_TXO_TYPE, ACTIVATED_CLAIM_TXO_TYPE -from lbry.wallet.server.db.prefixes import PendingActivationKey, PendingActivationValue, ClaimToTXOValue -from lbry.wallet.server.udp import StatusServer -from lbry.wallet.server.db.revertable import RevertableOpStack -if typing.TYPE_CHECKING: - from lbry.wallet.server.leveldb import LevelDB - - -class TrendingNotification(NamedTuple): - height: int - prev_amount: int - new_amount: int - - -class Prefetcher: - """Prefetches blocks (in the forward direction only).""" - - def __init__(self, daemon, coin, blocks_event): - self.logger = class_logger(__name__, self.__class__.__name__) - self.daemon = daemon - self.coin = coin - self.blocks_event = blocks_event - self.blocks = [] - self.caught_up = False - # Access to fetched_height should be protected by the semaphore - self.fetched_height = None - self.semaphore = asyncio.Semaphore() - self.refill_event = asyncio.Event() - # The prefetched block cache size. The min cache size has - # little effect on sync time. - self.cache_size = 0 - self.min_cache_size = 10 * 1024 * 1024 - # This makes the first fetch be 10 blocks - self.ave_size = self.min_cache_size // 10 - self.polling_delay = 5 - - async def main_loop(self, bp_height): - """Loop forever polling for more blocks.""" - await self.reset_height(bp_height) - while True: - try: - # Sleep a while if there is nothing to prefetch - await self.refill_event.wait() - if not await self._prefetch_blocks(): - await asyncio.sleep(self.polling_delay) - except DaemonError as e: - self.logger.info(f'ignoring daemon error: {e}') - - def get_prefetched_blocks(self): - """Called by block processor when it is processing queued blocks.""" - blocks = self.blocks - self.blocks = [] - self.cache_size = 0 - self.refill_event.set() - return blocks - - async def reset_height(self, height): - """Reset to prefetch blocks from the block processor's height. - - Used in blockchain reorganisations. This coroutine can be - called asynchronously to the _prefetch_blocks coroutine so we - must synchronize with a semaphore. - """ - async with self.semaphore: - self.blocks.clear() - self.cache_size = 0 - self.fetched_height = height - self.refill_event.set() - - daemon_height = await self.daemon.height() - behind = daemon_height - height - if behind > 0: - self.logger.info(f'catching up to daemon height {daemon_height:,d} ' - f'({behind:,d} blocks behind)') - else: - self.logger.info(f'caught up to daemon height {daemon_height:,d}') - - async def _prefetch_blocks(self): - """Prefetch some blocks and put them on the queue. - - Repeats until the queue is full or caught up. - """ - daemon = self.daemon - daemon_height = await daemon.height() - async with self.semaphore: - while self.cache_size < self.min_cache_size: - # Try and catch up all blocks but limit to room in cache. - # Constrain fetch count to between 0 and 500 regardless; - # testnet can be lumpy. - cache_room = self.min_cache_size // self.ave_size - count = min(daemon_height - self.fetched_height, cache_room) - count = min(500, max(count, 0)) - if not count: - self.caught_up = True - return False - - first = self.fetched_height + 1 - hex_hashes = await daemon.block_hex_hashes(first, count) - if self.caught_up: - self.logger.info('new block height {:,d} hash {}' - .format(first + count-1, hex_hashes[-1])) - blocks = await daemon.raw_blocks(hex_hashes) - - assert count == len(blocks) - - # Special handling for genesis block - if first == 0: - blocks[0] = self.coin.genesis_block(blocks[0]) - self.logger.info(f'verified genesis block with hash {hex_hashes[0]}') - - # Update our recent average block size estimate - size = sum(len(block) for block in blocks) - if count >= 10: - self.ave_size = size // count - else: - self.ave_size = (size + (10 - count) * self.ave_size) // 10 - - self.blocks.extend(blocks) - self.cache_size += size - self.fetched_height += count - self.blocks_event.set() - - self.refill_event.clear() - return True - - -class ChainError(Exception): - """Raised on error processing blocks.""" - - -class StagedClaimtrieItem(typing.NamedTuple): - name: str - normalized_name: str - claim_hash: bytes - amount: int - expiration_height: int - tx_num: int - position: int - root_tx_num: int - root_position: int - channel_signature_is_valid: bool - signing_hash: Optional[bytes] - reposted_claim_hash: Optional[bytes] - - @property - def is_update(self) -> bool: - return (self.tx_num, self.position) != (self.root_tx_num, self.root_position) - - def invalidate_signature(self) -> 'StagedClaimtrieItem': - return StagedClaimtrieItem( - self.name, self.normalized_name, self.claim_hash, self.amount, self.expiration_height, self.tx_num, - self.position, self.root_tx_num, self.root_position, False, None, self.reposted_claim_hash - ) - - -NAMESPACE = "wallet_server" -HISTOGRAM_BUCKETS = ( - .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf') -) - - -class BlockProcessor: - """Process blocks and update the DB state to match. - - Employ a prefetcher to prefetch blocks in batches for processing. - Coordinate backing up in case of chain reorganisations. - """ - - block_count_metric = Gauge( - "block_count", "Number of processed blocks", namespace=NAMESPACE - ) - block_update_time_metric = Histogram( - "block_time", "Block update times", namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS - ) - reorg_count_metric = Gauge( - "reorg_count", "Number of reorgs", namespace=NAMESPACE - ) - - def __init__(self, env, db: 'LevelDB', daemon, shutdown_event: asyncio.Event): - self.state_lock = asyncio.Lock() - self.env = env - self.db = db - self.daemon = daemon - self._chain_executor = ThreadPoolExecutor(1, thread_name_prefix='block-processor') - self._sync_reader_executor = ThreadPoolExecutor(1, thread_name_prefix='hub-es-sync') - self.mempool = MemPool(env.coin, daemon, db, self.state_lock) - self.shutdown_event = shutdown_event - self.coin = env.coin - if env.coin.NET == 'mainnet': - self.ledger = Ledger - elif env.coin.NET == 'testnet': - self.ledger = TestNetLedger - else: - self.ledger = RegTestLedger - - self._caught_up_event: Optional[asyncio.Event] = None - self.height = 0 - self.tip = bytes.fromhex(self.coin.GENESIS_HASH)[::-1] - self.tx_count = 0 - - self.blocks_event = asyncio.Event() - self.prefetcher = Prefetcher(daemon, env.coin, self.blocks_event) - self.logger = class_logger(__name__, self.__class__.__name__) - - # Meta - self.touched_hashXs: Set[bytes] = set() - - # UTXO cache - self.utxo_cache: Dict[Tuple[bytes, int], Tuple[bytes, int]] = {} - - # Claimtrie cache - self.db_op_stack: Optional[RevertableOpStack] = None - - # self.search_cache = {} - self.resolve_cache = LRUCache(2**16) - self.resolve_outputs_cache = LRUCache(2 ** 16) - - self.history_cache = {} - self.status_server = StatusServer() - - ################################# - # attributes used for calculating stake activations and takeovers per block - ################################# - - self.taken_over_names: Set[str] = set() - # txo to pending claim - self.txo_to_claim: Dict[Tuple[int, int], StagedClaimtrieItem] = {} - # claim hash to pending claim txo - self.claim_hash_to_txo: Dict[bytes, Tuple[int, int]] = {} - # claim hash to lists of pending support txos - self.support_txos_by_claim: DefaultDict[bytes, List[Tuple[int, int]]] = defaultdict(list) - # support txo: (supported claim hash, support amount) - self.support_txo_to_claim: Dict[Tuple[int, int], Tuple[bytes, int]] = {} - # removed supports {name: {claim_hash: [(tx_num, nout), ...]}} - self.removed_support_txos_by_name_by_claim: DefaultDict[str, DefaultDict[bytes, List[Tuple[int, int]]]] = \ - defaultdict(lambda: defaultdict(list)) - self.abandoned_claims: Dict[bytes, StagedClaimtrieItem] = {} - self.updated_claims: Set[bytes] = set() - # removed activated support amounts by claim hash - self.removed_active_support_amount_by_claim: DefaultDict[bytes, List[int]] = defaultdict(list) - # pending activated support amounts by claim hash - self.activated_support_amount_by_claim: DefaultDict[bytes, List[int]] = defaultdict(list) - # pending activated name and claim hash to claim/update txo amount - self.activated_claim_amount_by_name_and_hash: Dict[Tuple[str, bytes], int] = {} - # pending claim and support activations per claim hash per name, - # used to process takeovers due to added activations - activation_by_claim_by_name_type = DefaultDict[str, DefaultDict[bytes, List[Tuple[PendingActivationKey, int]]]] - self.activation_by_claim_by_name: activation_by_claim_by_name_type = defaultdict(lambda: defaultdict(list)) - # these are used for detecting early takeovers by not yet activated claims/supports - self.possible_future_support_amounts_by_claim_hash: DefaultDict[bytes, List[int]] = defaultdict(list) - self.possible_future_claim_amount_by_name_and_hash: Dict[Tuple[str, bytes], int] = {} - self.possible_future_support_txos_by_claim_hash: DefaultDict[bytes, List[Tuple[int, int]]] = defaultdict(list) - - self.removed_claims_to_send_es = set() # cumulative changes across blocks to send ES - self.touched_claims_to_send_es = set() - self.activation_info_to_send_es: DefaultDict[str, List[TrendingNotification]] = defaultdict(list) - - self.removed_claim_hashes: Set[bytes] = set() # per block changes - self.touched_claim_hashes: Set[bytes] = set() - - self.signatures_changed = set() - - self.pending_reposted = set() - self.pending_channel_counts = defaultdict(lambda: 0) - self.pending_support_amount_change = defaultdict(lambda: 0) - - self.pending_channels = {} - self.amount_cache = {} - self.expired_claim_hashes: Set[bytes] = set() - - self.doesnt_have_valid_signature: Set[bytes] = set() - self.claim_channels: Dict[bytes, bytes] = {} - self.hashXs_by_tx: DefaultDict[bytes, List[int]] = defaultdict(list) - - self.pending_transaction_num_mapping: Dict[bytes, int] = {} - self.pending_transactions: Dict[int, bytes] = {} - - async def claim_producer(self): - if self.db.db_height <= 1: - return - - for claim_hash in self.removed_claims_to_send_es: - yield 'delete', claim_hash.hex() - - to_update = await asyncio.get_event_loop().run_in_executor( - self._sync_reader_executor, self.db.claims_producer, self.touched_claims_to_send_es - ) - for claim in to_update: - yield 'update', claim - - async def run_in_thread_with_lock(self, func, *args): - # Run in a thread to prevent blocking. Shielded so that - # cancellations from shutdown don't lose work - when the task - # completes the data will be flushed and then we shut down. - # Take the state lock to be certain in-memory state is - # consistent and not being updated elsewhere. - async def run_in_thread_locked(): - async with self.state_lock: - return await asyncio.get_event_loop().run_in_executor(self._chain_executor, func, *args) - return await asyncio.shield(run_in_thread_locked()) - - async def run_in_thread(self, func, *args): - async def run_in_thread(): - return await asyncio.get_event_loop().run_in_executor(self._chain_executor, func, *args) - return await asyncio.shield(run_in_thread()) - - async def check_and_advance_blocks(self, raw_blocks): - """Process the list of raw blocks passed. Detects and handles - reorgs. - """ - - if not raw_blocks: - return - first = self.height + 1 - blocks = [self.coin.block(raw_block, first + n) - for n, raw_block in enumerate(raw_blocks)] - headers = [block.header for block in blocks] - hprevs = [self.coin.header_prevhash(h) for h in headers] - chain = [self.tip] + [self.coin.header_hash(h) for h in headers[:-1]] - - if hprevs == chain: - total_start = time.perf_counter() - try: - for block in blocks: - start = time.perf_counter() - await self.run_in_thread(self.advance_block, block) - await self.flush() - - self.logger.info("advanced to %i in %0.3fs", self.height, time.perf_counter() - start) - if self.height == self.coin.nExtendedClaimExpirationForkHeight: - self.logger.warning( - "applying extended claim expiration fork on claims accepted by, %i", self.height - ) - await self.run_in_thread_with_lock(self.db.apply_expiration_extension_fork) - if self.db.first_sync: - self.db.search_index.clear_caches() - self.touched_claims_to_send_es.clear() - self.removed_claims_to_send_es.clear() - self.activation_info_to_send_es.clear() - # TODO: we shouldnt wait on the search index updating before advancing to the next block - if not self.db.first_sync: - await self.db.reload_blocking_filtering_streams() - await self.db.search_index.claim_consumer(self.claim_producer()) - await self.db.search_index.apply_filters(self.db.blocked_streams, self.db.blocked_channels, - self.db.filtered_streams, self.db.filtered_channels) - await self.db.search_index.update_trending_score(self.activation_info_to_send_es) - await self._es_caught_up() - self.db.search_index.clear_caches() - self.touched_claims_to_send_es.clear() - self.removed_claims_to_send_es.clear() - self.activation_info_to_send_es.clear() - # print("******************\n") - except: - self.logger.exception("advance blocks failed") - raise - processed_time = time.perf_counter() - total_start - self.block_count_metric.set(self.height) - self.block_update_time_metric.observe(processed_time) - self.status_server.set_height(self.db.fs_height, self.db.db_tip) - if not self.db.first_sync: - s = '' if len(blocks) == 1 else 's' - self.logger.info('processed {:,d} block{} in {:.1f}s'.format(len(blocks), s, processed_time)) - if self._caught_up_event.is_set(): - await self.mempool.on_block(self.touched_hashXs, self.height) - self.touched_hashXs.clear() - elif hprevs[0] != chain[0]: - min_start_height = max(self.height - self.coin.REORG_LIMIT, 0) - count = 1 - block_hashes_from_lbrycrd = await self.daemon.block_hex_hashes( - min_start_height, self.coin.REORG_LIMIT - ) - for height, block_hash in zip( - reversed(range(min_start_height, min_start_height + self.coin.REORG_LIMIT)), - reversed(block_hashes_from_lbrycrd)): - if self.db.get_block_hash(height)[::-1].hex() == block_hash: - break - count += 1 - self.logger.warning(f"blockchain reorg detected at {self.height}, unwinding last {count} blocks") - try: - assert count > 0, count - for _ in range(count): - await self.backup_block() - self.logger.info(f'backed up to height {self.height:,d}') - - if self.env.cache_all_claim_txos: - await self.db._read_claim_txos() # TODO: don't do this - for touched in self.touched_claims_to_send_es: - if not self.db.get_claim_txo(touched): - self.removed_claims_to_send_es.add(touched) - self.touched_claims_to_send_es.difference_update(self.removed_claims_to_send_es) - await self.db.search_index.claim_consumer(self.claim_producer()) - self.db.search_index.clear_caches() - self.touched_claims_to_send_es.clear() - self.removed_claims_to_send_es.clear() - self.activation_info_to_send_es.clear() - await self.prefetcher.reset_height(self.height) - self.reorg_count_metric.inc() - except: - self.logger.exception("reorg blocks failed") - raise - finally: - self.logger.info("backed up to block %i", self.height) - else: - # It is probably possible but extremely rare that what - # bitcoind returns doesn't form a chain because it - # reorg-ed the chain as it was processing the batched - # block hash requests. Should this happen it's simplest - # just to reset the prefetcher and try again. - self.logger.warning('daemon blocks do not form a chain; ' - 'resetting the prefetcher') - await self.prefetcher.reset_height(self.height) - - async def flush(self): - save_undo = (self.daemon.cached_height() - self.height) <= self.env.reorg_limit - - def flush(): - self.db.write_db_state() - if save_undo: - self.db.prefix_db.commit(self.height) - else: - self.db.prefix_db.unsafe_commit() - self.clear_after_advance_or_reorg() - self.db.assert_db_state() - await self.run_in_thread_with_lock(flush) - - def _add_claim_or_update(self, height: int, txo: 'Output', tx_hash: bytes, tx_num: int, nout: int, - spent_claims: typing.Dict[bytes, typing.Tuple[int, int, str]]): - try: - claim_name = txo.script.values['claim_name'].decode() - except UnicodeDecodeError: - claim_name = ''.join(chr(c) for c in txo.script.values['claim_name']) - try: - normalized_name = txo.normalized_name - except UnicodeDecodeError: - normalized_name = claim_name - if txo.script.is_claim_name: - claim_hash = hash160(tx_hash + pack('>I', nout))[::-1] - # print(f"\tnew {claim_hash.hex()} ({tx_num} {txo.amount})") - else: - claim_hash = txo.claim_hash[::-1] - # print(f"\tupdate {claim_hash.hex()} ({tx_num} {txo.amount})") - - signing_channel_hash = None - channel_signature_is_valid = False - try: - signable = txo.signable - is_repost = txo.claim.is_repost - is_channel = txo.claim.is_channel - if txo.claim.is_signed: - signing_channel_hash = txo.signable.signing_channel_hash[::-1] - except: # google.protobuf.message.DecodeError: Could not parse JSON. - signable = None - is_repost = False - is_channel = False - - reposted_claim_hash = None - - if is_repost: - reposted_claim_hash = txo.claim.repost.reference.claim_hash[::-1] - self.pending_reposted.add(reposted_claim_hash) - - if is_channel: - self.pending_channels[claim_hash] = txo.claim.channel.public_key_bytes - - self.doesnt_have_valid_signature.add(claim_hash) - raw_channel_tx = None - if signable and signable.signing_channel_hash: - signing_channel = self.db.get_claim_txo(signing_channel_hash) - - if signing_channel: - raw_channel_tx = self.db.prefix_db.tx.get( - self.db.get_tx_hash(signing_channel.tx_num), deserialize_value=False - ) - channel_pub_key_bytes = None - try: - if not signing_channel: - if txo.signable.signing_channel_hash[::-1] in self.pending_channels: - channel_pub_key_bytes = self.pending_channels[signing_channel_hash] - elif raw_channel_tx: - chan_output = self.coin.transaction(raw_channel_tx).outputs[signing_channel.position] - chan_script = OutputScript(chan_output.pk_script) - chan_script.parse() - channel_meta = Claim.from_bytes(chan_script.values['claim']) - - channel_pub_key_bytes = channel_meta.channel.public_key_bytes - if channel_pub_key_bytes: - channel_signature_is_valid = Output.is_signature_valid( - txo.signable.signature, txo.get_signature_digest(self.ledger), channel_pub_key_bytes - ) - if channel_signature_is_valid: - self.pending_channel_counts[signing_channel_hash] += 1 - self.doesnt_have_valid_signature.remove(claim_hash) - self.claim_channels[claim_hash] = signing_channel_hash - except: - self.logger.exception(f"error validating channel signature for %s:%i", tx_hash[::-1].hex(), nout) - - if txo.script.is_claim_name: # it's a root claim - root_tx_num, root_idx = tx_num, nout - previous_amount = 0 - else: # it's a claim update - if claim_hash not in spent_claims: - # print(f"\tthis is a wonky tx, contains unlinked claim update {claim_hash.hex()}") - return - if normalized_name != spent_claims[claim_hash][2]: - self.logger.warning( - f"{tx_hash[::-1].hex()} contains mismatched name for claim update {claim_hash.hex()}" - ) - return - (prev_tx_num, prev_idx, _) = spent_claims.pop(claim_hash) - # print(f"\tupdate {claim_hash.hex()} {tx_hash[::-1].hex()} {txo.amount}") - if (prev_tx_num, prev_idx) in self.txo_to_claim: - previous_claim = self.txo_to_claim.pop((prev_tx_num, prev_idx)) - self.claim_hash_to_txo.pop(claim_hash) - root_tx_num, root_idx = previous_claim.root_tx_num, previous_claim.root_position - else: - previous_claim = self._make_pending_claim_txo(claim_hash) - root_tx_num, root_idx = previous_claim.root_tx_num, previous_claim.root_position - activation = self.db.get_activation(prev_tx_num, prev_idx) - claim_name = previous_claim.name - self.get_remove_activate_ops( - ACTIVATED_CLAIM_TXO_TYPE, claim_hash, prev_tx_num, prev_idx, activation, normalized_name, - previous_claim.amount - ) - previous_amount = previous_claim.amount - self.updated_claims.add(claim_hash) - - if self.env.cache_all_claim_txos: - self.db.claim_to_txo[claim_hash] = ClaimToTXOValue( - tx_num, nout, root_tx_num, root_idx, txo.amount, channel_signature_is_valid, claim_name - ) - self.db.txo_to_claim[tx_num][nout] = claim_hash - - pending = StagedClaimtrieItem( - claim_name, normalized_name, claim_hash, txo.amount, self.coin.get_expiration_height(height), tx_num, nout, - root_tx_num, root_idx, channel_signature_is_valid, signing_channel_hash, reposted_claim_hash - ) - self.txo_to_claim[(tx_num, nout)] = pending - self.claim_hash_to_txo[claim_hash] = (tx_num, nout) - self.get_add_claim_utxo_ops(pending) - - def get_add_claim_utxo_ops(self, pending: StagedClaimtrieItem): - # claim tip by claim hash - self.db.prefix_db.claim_to_txo.stage_put( - (pending.claim_hash,), (pending.tx_num, pending.position, pending.root_tx_num, pending.root_position, - pending.amount, pending.channel_signature_is_valid, pending.name) - ) - # claim hash by txo - self.db.prefix_db.txo_to_claim.stage_put( - (pending.tx_num, pending.position), (pending.claim_hash, pending.normalized_name) - ) - - # claim expiration - self.db.prefix_db.claim_expiration.stage_put( - (pending.expiration_height, pending.tx_num, pending.position), - (pending.claim_hash, pending.normalized_name) - ) - - # short url resolution - for prefix_len in range(10): - self.db.prefix_db.claim_short_id.stage_put( - (pending.normalized_name, pending.claim_hash.hex()[:prefix_len + 1], - pending.root_tx_num, pending.root_position), - (pending.tx_num, pending.position) - ) - - if pending.signing_hash and pending.channel_signature_is_valid: - # channel by stream - self.db.prefix_db.claim_to_channel.stage_put( - (pending.claim_hash, pending.tx_num, pending.position), (pending.signing_hash,) - ) - # stream by channel - self.db.prefix_db.channel_to_claim.stage_put( - (pending.signing_hash, pending.normalized_name, pending.tx_num, pending.position), - (pending.claim_hash,) - ) - - if pending.reposted_claim_hash: - self.db.prefix_db.repost.stage_put((pending.claim_hash,), (pending.reposted_claim_hash,)) - self.db.prefix_db.reposted_claim.stage_put( - (pending.reposted_claim_hash, pending.tx_num, pending.position), (pending.claim_hash,) - ) - - def get_remove_claim_utxo_ops(self, pending: StagedClaimtrieItem): - # claim tip by claim hash - self.db.prefix_db.claim_to_txo.stage_delete( - (pending.claim_hash,), (pending.tx_num, pending.position, pending.root_tx_num, pending.root_position, - pending.amount, pending.channel_signature_is_valid, pending.name) - ) - # claim hash by txo - self.db.prefix_db.txo_to_claim.stage_delete( - (pending.tx_num, pending.position), (pending.claim_hash, pending.normalized_name) - ) - - # claim expiration - self.db.prefix_db.claim_expiration.stage_delete( - (pending.expiration_height, pending.tx_num, pending.position), - (pending.claim_hash, pending.normalized_name) - ) - - # short url resolution - for prefix_len in range(10): - self.db.prefix_db.claim_short_id.stage_delete( - (pending.normalized_name, pending.claim_hash.hex()[:prefix_len + 1], - pending.root_tx_num, pending.root_position), - (pending.tx_num, pending.position) - ) - - if pending.signing_hash and pending.channel_signature_is_valid: - # channel by stream - self.db.prefix_db.claim_to_channel.stage_delete( - (pending.claim_hash, pending.tx_num, pending.position), (pending.signing_hash,) - ) - # stream by channel - self.db.prefix_db.channel_to_claim.stage_delete( - (pending.signing_hash, pending.normalized_name, pending.tx_num, pending.position), - (pending.claim_hash,) - ) - - if pending.reposted_claim_hash: - self.db.prefix_db.repost.stage_delete((pending.claim_hash,), (pending.reposted_claim_hash,)) - self.db.prefix_db.reposted_claim.stage_delete( - (pending.reposted_claim_hash, pending.tx_num, pending.position), (pending.claim_hash,) - ) - - def _add_support(self, height: int, txo: 'Output', tx_num: int, nout: int): - supported_claim_hash = txo.claim_hash[::-1] - self.support_txos_by_claim[supported_claim_hash].append((tx_num, nout)) - self.support_txo_to_claim[(tx_num, nout)] = supported_claim_hash, txo.amount - # print(f"\tsupport claim {supported_claim_hash.hex()} +{txo.amount}") - - self.db.prefix_db.claim_to_support.stage_put((supported_claim_hash, tx_num, nout), (txo.amount,)) - self.db.prefix_db.support_to_claim.stage_put((tx_num, nout), (supported_claim_hash,)) - self.pending_support_amount_change[supported_claim_hash] += txo.amount - - def _add_claim_or_support(self, height: int, tx_hash: bytes, tx_num: int, nout: int, txo: 'Output', - spent_claims: typing.Dict[bytes, Tuple[int, int, str]]): - if txo.script.is_claim_name or txo.script.is_update_claim: - self._add_claim_or_update(height, txo, tx_hash, tx_num, nout, spent_claims) - elif txo.script.is_support_claim or txo.script.is_support_claim_data: - self._add_support(height, txo, tx_num, nout) - - def _spend_support_txo(self, height: int, txin: TxInput): - txin_num = self.get_pending_tx_num(txin.prev_hash) - activation = 0 - if (txin_num, txin.prev_idx) in self.support_txo_to_claim: - spent_support, support_amount = self.support_txo_to_claim.pop((txin_num, txin.prev_idx)) - self.support_txos_by_claim[spent_support].remove((txin_num, txin.prev_idx)) - supported_name = self._get_pending_claim_name(spent_support) - self.removed_support_txos_by_name_by_claim[supported_name][spent_support].append((txin_num, txin.prev_idx)) - else: - spent_support, support_amount = self.db.get_supported_claim_from_txo(txin_num, txin.prev_idx) - if not spent_support: # it is not a support - return - supported_name = self._get_pending_claim_name(spent_support) - if supported_name is not None: - self.removed_support_txos_by_name_by_claim[supported_name][spent_support].append( - (txin_num, txin.prev_idx)) - activation = self.db.get_activation(txin_num, txin.prev_idx, is_support=True) - if 0 < activation < self.height + 1: - self.removed_active_support_amount_by_claim[spent_support].append(support_amount) - if supported_name is not None and activation > 0: - self.get_remove_activate_ops( - ACTIVATED_SUPPORT_TXO_TYPE, spent_support, txin_num, txin.prev_idx, activation, supported_name, - support_amount - ) - # print(f"\tspent support for {spent_support.hex()} activation:{activation} {support_amount}") - self.db.prefix_db.claim_to_support.stage_delete((spent_support, txin_num, txin.prev_idx), (support_amount,)) - self.db.prefix_db.support_to_claim.stage_delete((txin_num, txin.prev_idx), (spent_support,)) - self.pending_support_amount_change[spent_support] -= support_amount - - def _spend_claim_txo(self, txin: TxInput, spent_claims: Dict[bytes, Tuple[int, int, str]]) -> bool: - txin_num = self.get_pending_tx_num(txin.prev_hash) - if (txin_num, txin.prev_idx) in self.txo_to_claim: - spent = self.txo_to_claim[(txin_num, txin.prev_idx)] - else: - if not self.db.get_cached_claim_exists(txin_num, txin.prev_idx): - # txo is not a claim - return False - spent_claim_hash_and_name = self.db.get_claim_from_txo( - txin_num, txin.prev_idx - ) - assert spent_claim_hash_and_name is not None - spent = self._make_pending_claim_txo(spent_claim_hash_and_name.claim_hash) - - if self.env.cache_all_claim_txos: - claim_hash = self.db.txo_to_claim[txin_num].pop(txin.prev_idx) - if not self.db.txo_to_claim[txin_num]: - self.db.txo_to_claim.pop(txin_num) - self.db.claim_to_txo.pop(claim_hash) - if spent.reposted_claim_hash: - self.pending_reposted.add(spent.reposted_claim_hash) - if spent.signing_hash and spent.channel_signature_is_valid and spent.signing_hash not in self.abandoned_claims: - self.pending_channel_counts[spent.signing_hash] -= 1 - spent_claims[spent.claim_hash] = (spent.tx_num, spent.position, spent.normalized_name) - # print(f"\tspend lbry://{spent.name}#{spent.claim_hash.hex()}") - self.get_remove_claim_utxo_ops(spent) - return True - - def _spend_claim_or_support_txo(self, height: int, txin: TxInput, spent_claims): - if not self._spend_claim_txo(txin, spent_claims): - self._spend_support_txo(height, txin) - - def _abandon_claim(self, claim_hash: bytes, tx_num: int, nout: int, normalized_name: str): - if (tx_num, nout) in self.txo_to_claim: - pending = self.txo_to_claim.pop((tx_num, nout)) - self.claim_hash_to_txo.pop(claim_hash) - self.abandoned_claims[pending.claim_hash] = pending - claim_root_tx_num, claim_root_idx = pending.root_tx_num, pending.root_position - prev_amount, prev_signing_hash = pending.amount, pending.signing_hash - reposted_claim_hash, name = pending.reposted_claim_hash, pending.name - expiration = self.coin.get_expiration_height(self.height) - signature_is_valid = pending.channel_signature_is_valid - else: - v = self.db.get_claim_txo( - claim_hash - ) - claim_root_tx_num, claim_root_idx, prev_amount = v.root_tx_num, v.root_position, v.amount - signature_is_valid, name = v.channel_signature_is_valid, v.name - prev_signing_hash = self.db.get_channel_for_claim(claim_hash, tx_num, nout) - reposted_claim_hash = self.db.get_repost(claim_hash) - expiration = self.coin.get_expiration_height(bisect_right(self.db.tx_counts, tx_num)) - self.abandoned_claims[claim_hash] = staged = StagedClaimtrieItem( - name, normalized_name, claim_hash, prev_amount, expiration, tx_num, nout, claim_root_tx_num, - claim_root_idx, signature_is_valid, prev_signing_hash, reposted_claim_hash - ) - for support_txo_to_clear in self.support_txos_by_claim[claim_hash]: - self.support_txo_to_claim.pop(support_txo_to_clear) - self.support_txos_by_claim[claim_hash].clear() - self.support_txos_by_claim.pop(claim_hash) - if claim_hash.hex() in self.activation_info_to_send_es: - self.activation_info_to_send_es.pop(claim_hash.hex()) - if normalized_name.startswith('@'): # abandon a channel, invalidate signatures - self._invalidate_channel_signatures(claim_hash) - - def _get_invalidate_signature_ops(self, pending: StagedClaimtrieItem): - if not pending.signing_hash: - return - self.db.prefix_db.claim_to_channel.stage_delete( - (pending.claim_hash, pending.tx_num, pending.position), (pending.signing_hash,) - ) - if pending.channel_signature_is_valid: - self.db.prefix_db.channel_to_claim.stage_delete( - (pending.signing_hash, pending.normalized_name, pending.tx_num, pending.position), - (pending.claim_hash,) - ) - self.db.prefix_db.claim_to_txo.stage_delete( - (pending.claim_hash,), - (pending.tx_num, pending.position, pending.root_tx_num, pending.root_position, pending.amount, - pending.channel_signature_is_valid, pending.name) - ) - self.db.prefix_db.claim_to_txo.stage_put( - (pending.claim_hash,), - (pending.tx_num, pending.position, pending.root_tx_num, pending.root_position, pending.amount, - False, pending.name) - ) - - def _invalidate_channel_signatures(self, claim_hash: bytes): - for (signed_claim_hash, ) in self.db.prefix_db.channel_to_claim.iterate( - prefix=(claim_hash, ), include_key=False): - if signed_claim_hash in self.abandoned_claims or signed_claim_hash in self.expired_claim_hashes: - continue - # there is no longer a signing channel for this claim as of this block - if signed_claim_hash in self.doesnt_have_valid_signature: - continue - # the signing channel changed in this block - if signed_claim_hash in self.claim_channels and signed_claim_hash != self.claim_channels[signed_claim_hash]: - continue - - # if the claim with an invalidated signature is in this block, update the StagedClaimtrieItem - # so that if we later try to spend it in this block we won't try to delete the channel info twice - if signed_claim_hash in self.claim_hash_to_txo: - signed_claim_txo = self.claim_hash_to_txo[signed_claim_hash] - claim = self.txo_to_claim[signed_claim_txo] - if claim.signing_hash != claim_hash: # claim was already invalidated this block - continue - self.txo_to_claim[signed_claim_txo] = claim.invalidate_signature() - else: - claim = self._make_pending_claim_txo(signed_claim_hash) - self.signatures_changed.add(signed_claim_hash) - self.pending_channel_counts[claim_hash] -= 1 - self._get_invalidate_signature_ops(claim) - - for staged in list(self.txo_to_claim.values()): - needs_invalidate = staged.claim_hash not in self.doesnt_have_valid_signature - if staged.signing_hash == claim_hash and needs_invalidate: - self._get_invalidate_signature_ops(staged) - self.txo_to_claim[self.claim_hash_to_txo[staged.claim_hash]] = staged.invalidate_signature() - self.signatures_changed.add(staged.claim_hash) - self.pending_channel_counts[claim_hash] -= 1 - - def _make_pending_claim_txo(self, claim_hash: bytes): - claim = self.db.get_claim_txo(claim_hash) - if claim_hash in self.doesnt_have_valid_signature: - signing_hash = None - else: - signing_hash = self.db.get_channel_for_claim(claim_hash, claim.tx_num, claim.position) - reposted_claim_hash = self.db.get_repost(claim_hash) - return StagedClaimtrieItem( - claim.name, claim.normalized_name, claim_hash, claim.amount, - self.coin.get_expiration_height( - bisect_right(self.db.tx_counts, claim.tx_num), - extended=self.height >= self.coin.nExtendedClaimExpirationForkHeight - ), - claim.tx_num, claim.position, claim.root_tx_num, claim.root_position, - claim.channel_signature_is_valid, signing_hash, reposted_claim_hash - ) - - def _expire_claims(self, height: int): - expired = self.db.get_expired_by_height(height) - self.expired_claim_hashes.update(set(expired.keys())) - spent_claims = {} - for expired_claim_hash, (tx_num, position, name, txi) in expired.items(): - if (tx_num, position) not in self.txo_to_claim: - self._spend_claim_txo(txi, spent_claims) - if expired: - # abandon the channels last to handle abandoned signed claims in the same tx, - # see test_abandon_channel_and_claims_in_same_tx - expired_channels = {} - for abandoned_claim_hash, (tx_num, nout, normalized_name) in spent_claims.items(): - self._abandon_claim(abandoned_claim_hash, tx_num, nout, normalized_name) - - if normalized_name.startswith('@'): - expired_channels[abandoned_claim_hash] = (tx_num, nout, normalized_name) - else: - # print(f"\texpire {abandoned_claim_hash.hex()} {tx_num} {nout}") - self._abandon_claim(abandoned_claim_hash, tx_num, nout, normalized_name) - - # do this to follow the same content claim removing pathway as if a claim (possible channel) was abandoned - for abandoned_claim_hash, (tx_num, nout, normalized_name) in expired_channels.items(): - # print(f"\texpire {abandoned_claim_hash.hex()} {tx_num} {nout}") - self._abandon_claim(abandoned_claim_hash, tx_num, nout, normalized_name) - - def _cached_get_active_amount(self, claim_hash: bytes, txo_type: int, height: int) -> int: - if (claim_hash, txo_type, height) in self.amount_cache: - return self.amount_cache[(claim_hash, txo_type, height)] - if txo_type == ACTIVATED_CLAIM_TXO_TYPE: - if claim_hash in self.claim_hash_to_txo: - amount = self.txo_to_claim[self.claim_hash_to_txo[claim_hash]].amount - else: - amount = self.db.get_active_amount_as_of_height( - claim_hash, height - ) - self.amount_cache[(claim_hash, txo_type, height)] = amount - else: - self.amount_cache[(claim_hash, txo_type, height)] = amount = self.db._get_active_amount( - claim_hash, txo_type, height - ) - return amount - - def _get_pending_claim_amount(self, name: str, claim_hash: bytes, height=None) -> int: - if (name, claim_hash) in self.activated_claim_amount_by_name_and_hash: - if claim_hash in self.claim_hash_to_txo: - return self.txo_to_claim[self.claim_hash_to_txo[claim_hash]].amount - return self.activated_claim_amount_by_name_and_hash[(name, claim_hash)] - if (name, claim_hash) in self.possible_future_claim_amount_by_name_and_hash: - return self.possible_future_claim_amount_by_name_and_hash[(name, claim_hash)] - return self._cached_get_active_amount(claim_hash, ACTIVATED_CLAIM_TXO_TYPE, height or (self.height + 1)) - - def _get_pending_claim_name(self, claim_hash: bytes) -> Optional[str]: - assert claim_hash is not None - if claim_hash in self.claim_hash_to_txo: - return self.txo_to_claim[self.claim_hash_to_txo[claim_hash]].normalized_name - claim_info = self.db.get_claim_txo(claim_hash) - if claim_info: - return claim_info.normalized_name - - def _get_pending_supported_amount(self, claim_hash: bytes, height: Optional[int] = None) -> int: - amount = self._cached_get_active_amount(claim_hash, ACTIVATED_SUPPORT_TXO_TYPE, height or (self.height + 1)) - if claim_hash in self.activated_support_amount_by_claim: - amount += sum(self.activated_support_amount_by_claim[claim_hash]) - if claim_hash in self.possible_future_support_amounts_by_claim_hash: - amount += sum(self.possible_future_support_amounts_by_claim_hash[claim_hash]) - if claim_hash in self.removed_active_support_amount_by_claim: - return amount - sum(self.removed_active_support_amount_by_claim[claim_hash]) - return amount - - def _get_pending_effective_amount(self, name: str, claim_hash: bytes, height: Optional[int] = None) -> int: - claim_amount = self._get_pending_claim_amount(name, claim_hash, height=height) - support_amount = self._get_pending_supported_amount(claim_hash, height=height) - return claim_amount + support_amount - - def get_activate_ops(self, txo_type: int, claim_hash: bytes, tx_num: int, position: int, - activation_height: int, name: str, amount: int): - self.db.prefix_db.activated.stage_put( - (txo_type, tx_num, position), (activation_height, claim_hash, name) - ) - self.db.prefix_db.pending_activation.stage_put( - (activation_height, txo_type, tx_num, position), (claim_hash, name) - ) - self.db.prefix_db.active_amount.stage_put( - (claim_hash, txo_type, activation_height, tx_num, position), (amount,) - ) - - def get_remove_activate_ops(self, txo_type: int, claim_hash: bytes, tx_num: int, position: int, - activation_height: int, name: str, amount: int): - self.db.prefix_db.activated.stage_delete( - (txo_type, tx_num, position), (activation_height, claim_hash, name) - ) - self.db.prefix_db.pending_activation.stage_delete( - (activation_height, txo_type, tx_num, position), (claim_hash, name) - ) - self.db.prefix_db.active_amount.stage_delete( - (claim_hash, txo_type, activation_height, tx_num, position), (amount,) - ) - - def _get_takeover_ops(self, height: int): - - # cache for controlling claims as of the previous block - controlling_claims = {} - - def get_controlling(_name): - if _name not in controlling_claims: - _controlling = self.db.get_controlling_claim(_name) - controlling_claims[_name] = _controlling - else: - _controlling = controlling_claims[_name] - return _controlling - - names_with_abandoned_or_updated_controlling_claims: List[str] = [] - - # get the claims and supports previously scheduled to be activated at this block - activated_at_height = self.db.get_activated_at_height(height) - activate_in_future = defaultdict(lambda: defaultdict(list)) - future_activations = defaultdict(dict) - - def get_delayed_activate_ops(name: str, claim_hash: bytes, is_new_claim: bool, tx_num: int, nout: int, - amount: int, is_support: bool): - controlling = get_controlling(name) - nothing_is_controlling = not controlling - staged_is_controlling = False if not controlling else claim_hash == controlling.claim_hash - controlling_is_abandoned = False if not controlling else \ - name in names_with_abandoned_or_updated_controlling_claims - - if nothing_is_controlling or staged_is_controlling or controlling_is_abandoned: - delay = 0 - elif is_new_claim: - delay = self.coin.get_delay_for_name(height - controlling.height) - else: - controlling_effective_amount = self._get_pending_effective_amount(name, controlling.claim_hash) - staged_effective_amount = self._get_pending_effective_amount(name, claim_hash) - staged_update_could_cause_takeover = staged_effective_amount > controlling_effective_amount - delay = 0 if not staged_update_could_cause_takeover else self.coin.get_delay_for_name( - height - controlling.height - ) - if delay == 0: # if delay was 0 it needs to be considered for takeovers - activated_at_height[PendingActivationValue(claim_hash, name)].append( - PendingActivationKey( - height, ACTIVATED_SUPPORT_TXO_TYPE if is_support else ACTIVATED_CLAIM_TXO_TYPE, tx_num, nout - ) - ) - else: # if the delay was higher if still needs to be considered if something else triggers a takeover - activate_in_future[name][claim_hash].append(( - PendingActivationKey( - height + delay, ACTIVATED_SUPPORT_TXO_TYPE if is_support else ACTIVATED_CLAIM_TXO_TYPE, - tx_num, nout - ), amount - )) - if is_support: - self.possible_future_support_txos_by_claim_hash[claim_hash].append((tx_num, nout)) - self.get_activate_ops( - ACTIVATED_SUPPORT_TXO_TYPE if is_support else ACTIVATED_CLAIM_TXO_TYPE, claim_hash, tx_num, nout, - height + delay, name, amount - ) - - # determine names needing takeover/deletion due to controlling claims being abandoned - # and add ops to deactivate abandoned claims - for claim_hash, staged in self.abandoned_claims.items(): - controlling = get_controlling(staged.normalized_name) - if controlling and controlling.claim_hash == claim_hash: - names_with_abandoned_or_updated_controlling_claims.append(staged.normalized_name) - # print(f"\t{staged.name} needs takeover") - activation = self.db.get_activation(staged.tx_num, staged.position) - if activation > 0: # db returns -1 for non-existent txos - # removed queued future activation from the db - self.get_remove_activate_ops( - ACTIVATED_CLAIM_TXO_TYPE, staged.claim_hash, staged.tx_num, staged.position, - activation, staged.normalized_name, staged.amount - ) - else: - # it hadn't yet been activated - pass - - # get the removed activated supports for controlling claims to determine if takeovers are possible - abandoned_support_check_need_takeover = defaultdict(list) - for claim_hash, amounts in self.removed_active_support_amount_by_claim.items(): - name = self._get_pending_claim_name(claim_hash) - if name is None: - continue - controlling = get_controlling(name) - if controlling and controlling.claim_hash == claim_hash and \ - name not in names_with_abandoned_or_updated_controlling_claims: - abandoned_support_check_need_takeover[(name, claim_hash)].extend(amounts) - - # get the controlling claims with updates to the claim to check if takeover is needed - for claim_hash in self.updated_claims: - if claim_hash in self.abandoned_claims: - continue - name = self._get_pending_claim_name(claim_hash) - if name is None: - continue - controlling = get_controlling(name) - if controlling and controlling.claim_hash == claim_hash and \ - name not in names_with_abandoned_or_updated_controlling_claims: - names_with_abandoned_or_updated_controlling_claims.append(name) - - # prepare to activate or delay activation of the pending claims being added this block - for (tx_num, nout), staged in self.txo_to_claim.items(): - is_delayed = not staged.is_update - prev_txo = self.db.get_cached_claim_txo(staged.claim_hash) - if prev_txo: - prev_activation = self.db.get_activation(prev_txo.tx_num, prev_txo.position) - if height < prev_activation or prev_activation < 0: - is_delayed = True - get_delayed_activate_ops( - staged.normalized_name, staged.claim_hash, is_delayed, tx_num, nout, staged.amount, - is_support=False - ) - - # and the supports - for (tx_num, nout), (claim_hash, amount) in self.support_txo_to_claim.items(): - if claim_hash in self.abandoned_claims: - continue - elif claim_hash in self.claim_hash_to_txo: - name = self.txo_to_claim[self.claim_hash_to_txo[claim_hash]].normalized_name - staged_is_new_claim = not self.txo_to_claim[self.claim_hash_to_txo[claim_hash]].is_update - else: - supported_claim_info = self.db.get_claim_txo(claim_hash) - if not supported_claim_info: - # the supported claim doesn't exist - continue - else: - v = supported_claim_info - name = v.normalized_name - staged_is_new_claim = (v.root_tx_num, v.root_position) == (v.tx_num, v.position) - get_delayed_activate_ops( - name, claim_hash, staged_is_new_claim, tx_num, nout, amount, is_support=True - ) - - # add the activation/delayed-activation ops - for activated, activated_txos in activated_at_height.items(): - controlling = get_controlling(activated.normalized_name) - if activated.claim_hash in self.abandoned_claims: - continue - reactivate = False - if not controlling or controlling.claim_hash == activated.claim_hash: - # there is no delay for claims to a name without a controlling value or to the controlling value - reactivate = True - for activated_txo in activated_txos: - if activated_txo.is_support and (activated_txo.tx_num, activated_txo.position) in \ - self.removed_support_txos_by_name_by_claim[activated.normalized_name][activated.claim_hash]: - # print("\tskip activate support for pending abandoned claim") - continue - if activated_txo.is_claim: - txo_type = ACTIVATED_CLAIM_TXO_TYPE - txo_tup = (activated_txo.tx_num, activated_txo.position) - if txo_tup in self.txo_to_claim: - amount = self.txo_to_claim[txo_tup].amount - else: - amount = self.db.get_claim_txo_amount( - activated.claim_hash - ) - if amount is None: - # print("\tskip activate for non existent claim") - continue - self.activated_claim_amount_by_name_and_hash[(activated.normalized_name, activated.claim_hash)] = amount - else: - txo_type = ACTIVATED_SUPPORT_TXO_TYPE - txo_tup = (activated_txo.tx_num, activated_txo.position) - if txo_tup in self.support_txo_to_claim: - amount = self.support_txo_to_claim[txo_tup][1] - else: - amount = self.db.get_support_txo_amount( - activated.claim_hash, activated_txo.tx_num, activated_txo.position - ) - if amount is None: - # print("\tskip activate support for non existent claim") - continue - self.activated_support_amount_by_claim[activated.claim_hash].append(amount) - self.activation_by_claim_by_name[activated.normalized_name][activated.claim_hash].append((activated_txo, amount)) - # print(f"\tactivate {'support' if txo_type == ACTIVATED_SUPPORT_TXO_TYPE else 'claim'} " - # f"{activated.claim_hash.hex()} @ {activated_txo.height}") - - # go through claims where the controlling claim or supports to the controlling claim have been abandoned - # check if takeovers are needed or if the name node is now empty - need_reactivate_if_takes_over = {} - for need_takeover in names_with_abandoned_or_updated_controlling_claims: - existing = self.db.get_claim_txos_for_name(need_takeover) - has_candidate = False - # add existing claims to the queue for the takeover - # track that we need to reactivate these if one of them becomes controlling - for candidate_claim_hash, (tx_num, nout) in existing.items(): - if candidate_claim_hash in self.abandoned_claims: - continue - has_candidate = True - existing_activation = self.db.get_activation(tx_num, nout) - activate_key = PendingActivationKey( - existing_activation, ACTIVATED_CLAIM_TXO_TYPE, tx_num, nout - ) - self.activation_by_claim_by_name[need_takeover][candidate_claim_hash].append(( - activate_key, self.db.get_claim_txo_amount(candidate_claim_hash) - )) - need_reactivate_if_takes_over[(need_takeover, candidate_claim_hash)] = activate_key - # print(f"\tcandidate to takeover abandoned controlling claim for " - # f"{activate_key.tx_num}:{activate_key.position} {activate_key.is_claim}") - if not has_candidate: - # remove name takeover entry, the name is now unclaimed - controlling = get_controlling(need_takeover) - self.db.prefix_db.claim_takeover.stage_delete( - (need_takeover,), (controlling.claim_hash, controlling.height) - ) - - # scan for possible takeovers out of the accumulated activations, of these make sure there - # aren't any future activations for the taken over names with yet higher amounts, if there are - # these need to get activated now and take over instead. for example: - # claim A is winning for 0.1 for long enough for a > 1 takeover delay - # claim B is made for 0.2 - # a block later, claim C is made for 0.3, it will schedule to activate 1 (or rarely 2) block(s) after B - # upon the delayed activation of B, we need to detect to activate C and make it take over early instead - - claim_exists = {} - for activated, activated_claim_txo in self.db.get_future_activated(height).items(): - # uses the pending effective amount for the future activation height, not the current height - future_amount = self._get_pending_claim_amount( - activated.normalized_name, activated.claim_hash, activated_claim_txo.height + 1 - ) - if activated.claim_hash not in claim_exists: - claim_exists[activated.claim_hash] = activated.claim_hash in self.claim_hash_to_txo or ( - self.db.get_claim_txo(activated.claim_hash) is not None) - if claim_exists[activated.claim_hash] and activated.claim_hash not in self.abandoned_claims: - v = future_amount, activated, activated_claim_txo - future_activations[activated.normalized_name][activated.claim_hash] = v - - for name, future_activated in activate_in_future.items(): - for claim_hash, activated in future_activated.items(): - if claim_hash not in claim_exists: - claim_exists[claim_hash] = claim_hash in self.claim_hash_to_txo or ( - self.db.get_claim_txo(claim_hash) is not None) - if not claim_exists[claim_hash]: - continue - if claim_hash in self.abandoned_claims: - continue - for txo in activated: - v = txo[1], PendingActivationValue(claim_hash, name), txo[0] - future_activations[name][claim_hash] = v - if txo[0].is_claim: - self.possible_future_claim_amount_by_name_and_hash[(name, claim_hash)] = txo[1] - else: - self.possible_future_support_amounts_by_claim_hash[claim_hash].append(txo[1]) - - # process takeovers - checked_names = set() - for name, activated in self.activation_by_claim_by_name.items(): - checked_names.add(name) - controlling = controlling_claims[name] - amounts = { - claim_hash: self._get_pending_effective_amount(name, claim_hash) - for claim_hash in activated.keys() if claim_hash not in self.abandoned_claims - } - # if there is a controlling claim include it in the amounts to ensure it remains the max - if controlling and controlling.claim_hash not in self.abandoned_claims: - amounts[controlling.claim_hash] = self._get_pending_effective_amount(name, controlling.claim_hash) - winning_claim_hash = max(amounts, key=lambda x: amounts[x]) - if not controlling or (winning_claim_hash != controlling.claim_hash and - name in names_with_abandoned_or_updated_controlling_claims) or \ - ((winning_claim_hash != controlling.claim_hash) and (amounts[winning_claim_hash] > amounts[controlling.claim_hash])): - amounts_with_future_activations = {claim_hash: amount for claim_hash, amount in amounts.items()} - amounts_with_future_activations.update( - { - claim_hash: self._get_pending_effective_amount( - name, claim_hash, self.height + 1 + self.coin.maxTakeoverDelay - ) for claim_hash in future_activations[name] - } - ) - winning_including_future_activations = max( - amounts_with_future_activations, key=lambda x: amounts_with_future_activations[x] - ) - future_winning_amount = amounts_with_future_activations[winning_including_future_activations] - - if winning_claim_hash != winning_including_future_activations and \ - future_winning_amount > amounts[winning_claim_hash]: - # print(f"\ttakeover by {winning_claim_hash.hex()} triggered early activation and " - # f"takeover by {winning_including_future_activations.hex()} at {height}") - # handle a pending activated claim jumping the takeover delay when another name takes over - if winning_including_future_activations not in self.claim_hash_to_txo: - claim = self.db.get_claim_txo(winning_including_future_activations) - tx_num = claim.tx_num - position = claim.position - amount = claim.amount - activation = self.db.get_activation(tx_num, position) - else: - tx_num, position = self.claim_hash_to_txo[winning_including_future_activations] - amount = self.txo_to_claim[(tx_num, position)].amount - activation = None - for (k, tx_amount) in activate_in_future[name][winning_including_future_activations]: - if (k.tx_num, k.position) == (tx_num, position): - activation = k.height - break - if activation is None: - # TODO: reproduce this in an integration test (block 604718) - _k = PendingActivationValue(winning_including_future_activations, name) - if _k in activated_at_height: - for pending_activation in activated_at_height[_k]: - if (pending_activation.tx_num, pending_activation.position) == (tx_num, position): - activation = pending_activation.height - break - assert None not in (amount, activation) - # update the claim that's activating early - self.get_remove_activate_ops( - ACTIVATED_CLAIM_TXO_TYPE, winning_including_future_activations, tx_num, - position, activation, name, amount - ) - self.get_activate_ops( - ACTIVATED_CLAIM_TXO_TYPE, winning_including_future_activations, tx_num, - position, height, name, amount - ) - - for (k, amount) in activate_in_future[name][winning_including_future_activations]: - txo = (k.tx_num, k.position) - if txo in self.possible_future_support_txos_by_claim_hash[winning_including_future_activations]: - self.get_remove_activate_ops( - ACTIVATED_SUPPORT_TXO_TYPE, winning_including_future_activations, k.tx_num, - k.position, k.height, name, amount - ) - self.get_activate_ops( - ACTIVATED_SUPPORT_TXO_TYPE, winning_including_future_activations, k.tx_num, - k.position, height, name, amount - ) - self.taken_over_names.add(name) - if controlling: - self.db.prefix_db.claim_takeover.stage_delete( - (name,), (controlling.claim_hash, controlling.height) - ) - self.db.prefix_db.claim_takeover.stage_put((name,), (winning_including_future_activations, height)) - self.touched_claim_hashes.add(winning_including_future_activations) - if controlling and controlling.claim_hash not in self.abandoned_claims: - self.touched_claim_hashes.add(controlling.claim_hash) - elif not controlling or (winning_claim_hash != controlling.claim_hash and - name in names_with_abandoned_or_updated_controlling_claims) or \ - ((winning_claim_hash != controlling.claim_hash) and (amounts[winning_claim_hash] > amounts[controlling.claim_hash])): - # print(f"\ttakeover by {winning_claim_hash.hex()} at {height}") - if (name, winning_claim_hash) in need_reactivate_if_takes_over: - previous_pending_activate = need_reactivate_if_takes_over[(name, winning_claim_hash)] - amount = self.db.get_claim_txo_amount( - winning_claim_hash - ) - if winning_claim_hash in self.claim_hash_to_txo: - tx_num, position = self.claim_hash_to_txo[winning_claim_hash] - amount = self.txo_to_claim[(tx_num, position)].amount - else: - tx_num, position = previous_pending_activate.tx_num, previous_pending_activate.position - if previous_pending_activate.height > height: - # the claim had a pending activation in the future, move it to now - if tx_num < self.tx_count: - self.get_remove_activate_ops( - ACTIVATED_CLAIM_TXO_TYPE, winning_claim_hash, tx_num, - position, previous_pending_activate.height, name, amount - ) - self.get_activate_ops( - ACTIVATED_CLAIM_TXO_TYPE, winning_claim_hash, tx_num, - position, height, name, amount - ) - self.taken_over_names.add(name) - if controlling: - self.db.prefix_db.claim_takeover.stage_delete( - (name,), (controlling.claim_hash, controlling.height) - ) - self.db.prefix_db.claim_takeover.stage_put((name,), (winning_claim_hash, height)) - if controlling and controlling.claim_hash not in self.abandoned_claims: - self.touched_claim_hashes.add(controlling.claim_hash) - self.touched_claim_hashes.add(winning_claim_hash) - elif winning_claim_hash == controlling.claim_hash: - # print("\tstill winning") - pass - else: - # print("\tno takeover") - pass - - # handle remaining takeovers from abandoned supports - for (name, claim_hash), amounts in abandoned_support_check_need_takeover.items(): - if name in checked_names: - continue - checked_names.add(name) - controlling = get_controlling(name) - amounts = { - claim_hash: self._get_pending_effective_amount(name, claim_hash) - for claim_hash in self.db.get_claims_for_name(name) if claim_hash not in self.abandoned_claims - } - if controlling and controlling.claim_hash not in self.abandoned_claims: - amounts[controlling.claim_hash] = self._get_pending_effective_amount(name, controlling.claim_hash) - winning = max(amounts, key=lambda x: amounts[x]) - - if (controlling and winning != controlling.claim_hash) or (not controlling and winning): - self.taken_over_names.add(name) - # print(f"\ttakeover from abandoned support {controlling.claim_hash.hex()} -> {winning.hex()}") - if controlling: - self.db.prefix_db.claim_takeover.stage_delete( - (name,), (controlling.claim_hash, controlling.height) - ) - self.db.prefix_db.claim_takeover.stage_put((name,), (winning, height)) - if controlling: - self.touched_claim_hashes.add(controlling.claim_hash) - self.touched_claim_hashes.add(winning) - - def _add_claim_activation_change_notification(self, claim_id: str, height: int, prev_amount: int, - new_amount: int): - self.activation_info_to_send_es[claim_id].append(TrendingNotification(height, prev_amount, new_amount)) - - def _get_cumulative_update_ops(self, height: int): - # update the last takeover height for names with takeovers - for name in self.taken_over_names: - self.touched_claim_hashes.update( - {claim_hash for claim_hash in self.db.get_claims_for_name(name) - if claim_hash not in self.abandoned_claims} - ) - - # gather cumulative removed/touched sets to update the search index - self.removed_claim_hashes.update(set(self.abandoned_claims.keys())) - self.touched_claim_hashes.difference_update(self.removed_claim_hashes) - self.touched_claim_hashes.update( - set( - map(lambda item: item[1], self.activated_claim_amount_by_name_and_hash.keys()) - ).union( - set(self.claim_hash_to_txo.keys()) - ).union( - self.removed_active_support_amount_by_claim.keys() - ).union( - self.signatures_changed - ).union( - set(self.removed_active_support_amount_by_claim.keys()) - ).union( - set(self.activated_support_amount_by_claim.keys()) - ).union( - set(self.pending_support_amount_change.keys()) - ).difference( - self.removed_claim_hashes - ) - ) - - # update support amount totals - for supported_claim, amount in self.pending_support_amount_change.items(): - existing = self.db.prefix_db.support_amount.get(supported_claim) - total = amount - if existing is not None: - total += existing.amount - self.db.prefix_db.support_amount.stage_delete((supported_claim,), existing) - self.db.prefix_db.support_amount.stage_put((supported_claim,), (total,)) - - # use the cumulative changes to update bid ordered resolve - for removed in self.removed_claim_hashes: - removed_claim = self.db.get_claim_txo(removed) - if removed_claim: - amt = self.db.get_url_effective_amount( - removed_claim.normalized_name, removed - ) - if amt: - self.db.prefix_db.effective_amount.stage_delete( - (removed_claim.normalized_name, amt.effective_amount, amt.tx_num, amt.position), (removed,) - ) - for touched in self.touched_claim_hashes: - prev_effective_amount = 0 - - if touched in self.claim_hash_to_txo: - pending = self.txo_to_claim[self.claim_hash_to_txo[touched]] - name, tx_num, position = pending.normalized_name, pending.tx_num, pending.position - claim_from_db = self.db.get_claim_txo(touched) - if claim_from_db: - claim_amount_info = self.db.get_url_effective_amount(name, touched) - if claim_amount_info: - prev_effective_amount = claim_amount_info.effective_amount - self.db.prefix_db.effective_amount.stage_delete( - (name, claim_amount_info.effective_amount, claim_amount_info.tx_num, - claim_amount_info.position), (touched,) - ) - else: - v = self.db.get_claim_txo(touched) - if not v: - continue - name, tx_num, position = v.normalized_name, v.tx_num, v.position - amt = self.db.get_url_effective_amount(name, touched) - if amt: - prev_effective_amount = amt.effective_amount - self.db.prefix_db.effective_amount.stage_delete( - (name, prev_effective_amount, amt.tx_num, amt.position), (touched,) - ) - - new_effective_amount = self._get_pending_effective_amount(name, touched) - self.db.prefix_db.effective_amount.stage_put( - (name, new_effective_amount, tx_num, position), (touched,) - ) - if touched in self.claim_hash_to_txo or touched in self.removed_claim_hashes \ - or touched in self.pending_support_amount_change: - # exclude sending notifications for claims/supports that activated but - # weren't added/spent in this block - self._add_claim_activation_change_notification( - touched.hex(), height, prev_effective_amount, new_effective_amount - ) - - for channel_hash, count in self.pending_channel_counts.items(): - if count != 0: - channel_count_val = self.db.prefix_db.channel_count.get(channel_hash) - channel_count = 0 if not channel_count_val else channel_count_val.count - if channel_count_val is not None: - self.db.prefix_db.channel_count.stage_delete((channel_hash,), (channel_count,)) - self.db.prefix_db.channel_count.stage_put((channel_hash,), (channel_count + count,)) - - self.touched_claim_hashes.update( - {k for k in self.pending_reposted if k not in self.removed_claim_hashes} - ) - self.touched_claim_hashes.update( - {k for k, v in self.pending_channel_counts.items() if v != 0 and k not in self.removed_claim_hashes} - ) - self.touched_claims_to_send_es.update(self.touched_claim_hashes) - self.touched_claims_to_send_es.difference_update(self.removed_claim_hashes) - self.removed_claims_to_send_es.update(self.removed_claim_hashes) - - def advance_block(self, block): - height = self.height + 1 - # print("advance ", height) - # Use local vars for speed in the loops - tx_count = self.tx_count - spend_utxo = self.spend_utxo - add_utxo = self.add_utxo - spend_claim_or_support_txo = self._spend_claim_or_support_txo - add_claim_or_support = self._add_claim_or_support - txs: List[Tuple[Tx, bytes]] = block.transactions - - self.db.prefix_db.block_hash.stage_put(key_args=(height,), value_args=(self.coin.header_hash(block.header),)) - self.db.prefix_db.header.stage_put(key_args=(height,), value_args=(block.header,)) - self.db.prefix_db.block_txs.stage_put(key_args=(height,), value_args=([tx_hash for tx, tx_hash in txs],)) - - for tx, tx_hash in txs: - spent_claims = {} - txos = Transaction(tx.raw).outputs - - self.db.prefix_db.tx.stage_put(key_args=(tx_hash,), value_args=(tx.raw,)) - self.db.prefix_db.tx_num.stage_put(key_args=(tx_hash,), value_args=(tx_count,)) - self.db.prefix_db.tx_hash.stage_put(key_args=(tx_count,), value_args=(tx_hash,)) - - # Spend the inputs - for txin in tx.inputs: - if txin.is_generation(): - continue - # spend utxo for address histories - hashX = spend_utxo(txin.prev_hash, txin.prev_idx) - if hashX: - if tx_count not in self.hashXs_by_tx[hashX]: - self.hashXs_by_tx[hashX].append(tx_count) - # spend claim/support txo - spend_claim_or_support_txo(height, txin, spent_claims) - - # Add the new UTXOs - for nout, txout in enumerate(tx.outputs): - # Get the hashX. Ignore unspendable outputs - hashX = add_utxo(tx_hash, tx_count, nout, txout) - if hashX: - # self._set_hashX_cache(hashX) - if tx_count not in self.hashXs_by_tx[hashX]: - self.hashXs_by_tx[hashX].append(tx_count) - # add claim/support txo - add_claim_or_support( - height, tx_hash, tx_count, nout, txos[nout], spent_claims - ) - - # Handle abandoned claims - abandoned_channels = {} - # abandon the channels last to handle abandoned signed claims in the same tx, - # see test_abandon_channel_and_claims_in_same_tx - for abandoned_claim_hash, (tx_num, nout, normalized_name) in spent_claims.items(): - if normalized_name.startswith('@'): - abandoned_channels[abandoned_claim_hash] = (tx_num, nout, normalized_name) - else: - # print(f"\tabandon {normalized_name} {abandoned_claim_hash.hex()} {tx_num} {nout}") - self._abandon_claim(abandoned_claim_hash, tx_num, nout, normalized_name) - - for abandoned_claim_hash, (tx_num, nout, normalized_name) in abandoned_channels.items(): - # print(f"\tabandon {normalized_name} {abandoned_claim_hash.hex()} {tx_num} {nout}") - self._abandon_claim(abandoned_claim_hash, tx_num, nout, normalized_name) - self.pending_transactions[tx_count] = tx_hash - self.pending_transaction_num_mapping[tx_hash] = tx_count - if self.env.cache_all_tx_hashes: - self.db.total_transactions.append(tx_hash) - self.db.tx_num_mapping[tx_hash] = tx_count - tx_count += 1 - - # handle expired claims - self._expire_claims(height) - - # activate claims and process takeovers - self._get_takeover_ops(height) - - # update effective amount and update sets of touched and deleted claims - self._get_cumulative_update_ops(height) - - self.db.prefix_db.tx_count.stage_put(key_args=(height,), value_args=(tx_count,)) - - for hashX, new_history in self.hashXs_by_tx.items(): - if not new_history: - continue - self.db.prefix_db.hashX_history.stage_put(key_args=(hashX, height), value_args=(new_history,)) - - self.tx_count = tx_count - self.db.tx_counts.append(self.tx_count) - - cached_max_reorg_depth = self.daemon.cached_height() - self.env.reorg_limit - - # if height >= cached_max_reorg_depth: - self.db.prefix_db.touched_or_deleted.stage_put( - key_args=(height,), value_args=(self.touched_claim_hashes, self.removed_claim_hashes) - ) - - self.height = height - self.db.headers.append(block.header) - self.tip = self.coin.header_hash(block.header) - - min_height = self.db.min_undo_height(self.db.db_height) - if min_height > 0: # delete undos for blocks deep enough they can't be reorged - undo_to_delete = list(self.db.prefix_db.undo.iterate(start=(0,), stop=(min_height,))) - for (k, v) in undo_to_delete: - self.db.prefix_db.undo.stage_delete((k,), (v,)) - touched_or_deleted_to_delete = list(self.db.prefix_db.touched_or_deleted.iterate( - start=(0,), stop=(min_height,)) - ) - for (k, v) in touched_or_deleted_to_delete: - self.db.prefix_db.touched_or_deleted.stage_delete(k, v) - - self.db.fs_height = self.height - self.db.fs_tx_count = self.tx_count - self.db.hist_flush_count += 1 - self.db.hist_unflushed_count = 0 - self.db.utxo_flush_count = self.db.hist_flush_count - self.db.db_height = self.height - self.db.db_tx_count = self.tx_count - self.db.db_tip = self.tip - self.db.last_flush_tx_count = self.db.fs_tx_count - now = time.time() - self.db.wall_time += now - self.db.last_flush - self.db.last_flush = now - self.db.write_db_state() - - def clear_after_advance_or_reorg(self): - self.txo_to_claim.clear() - self.claim_hash_to_txo.clear() - self.support_txos_by_claim.clear() - self.support_txo_to_claim.clear() - self.removed_support_txos_by_name_by_claim.clear() - self.abandoned_claims.clear() - self.removed_active_support_amount_by_claim.clear() - self.activated_support_amount_by_claim.clear() - self.activated_claim_amount_by_name_and_hash.clear() - self.activation_by_claim_by_name.clear() - self.possible_future_claim_amount_by_name_and_hash.clear() - self.possible_future_support_amounts_by_claim_hash.clear() - self.possible_future_support_txos_by_claim_hash.clear() - self.pending_channels.clear() - self.amount_cache.clear() - self.signatures_changed.clear() - self.expired_claim_hashes.clear() - self.doesnt_have_valid_signature.clear() - self.claim_channels.clear() - self.utxo_cache.clear() - self.hashXs_by_tx.clear() - self.history_cache.clear() - self.mempool.notified_mempool_txs.clear() - self.removed_claim_hashes.clear() - self.touched_claim_hashes.clear() - self.pending_reposted.clear() - self.pending_channel_counts.clear() - self.updated_claims.clear() - self.taken_over_names.clear() - self.pending_transaction_num_mapping.clear() - self.pending_transactions.clear() - self.pending_support_amount_change.clear() - self.resolve_cache.clear() - self.resolve_outputs_cache.clear() - - async def backup_block(self): - assert len(self.db.prefix_db._op_stack) == 0 - touched_and_deleted = self.db.prefix_db.touched_or_deleted.get(self.height) - self.touched_claims_to_send_es.update(touched_and_deleted.touched_claims) - self.removed_claims_to_send_es.difference_update(touched_and_deleted.touched_claims) - self.removed_claims_to_send_es.update(touched_and_deleted.deleted_claims) - - # self.db.assert_flushed(self.flush_data()) - self.logger.info("backup block %i", self.height) - # Check and update self.tip - - self.db.headers.pop() - self.db.tx_counts.pop() - self.tip = self.coin.header_hash(self.db.headers[-1]) - if self.env.cache_all_tx_hashes: - while len(self.db.total_transactions) > self.db.tx_counts[-1]: - self.db.tx_num_mapping.pop(self.db.total_transactions.pop()) - self.tx_count -= 1 - else: - self.tx_count = self.db.tx_counts[-1] - self.height -= 1 - - # self.touched can include other addresses which is - # harmless, but remove None. - self.touched_hashXs.discard(None) - - assert self.height < self.db.db_height - assert not self.db.hist_unflushed - - start_time = time.time() - tx_delta = self.tx_count - self.db.last_flush_tx_count - ### - self.db.fs_tx_count = self.tx_count - # Truncate header_mc: header count is 1 more than the height. - self.db.header_mc.truncate(self.height + 1) - ### - # Not certain this is needed, but it doesn't hurt - self.db.hist_flush_count += 1 - - while self.db.fs_height > self.height: - self.db.fs_height -= 1 - self.db.utxo_flush_count = self.db.hist_flush_count - self.db.db_height = self.height - self.db.db_tx_count = self.tx_count - self.db.db_tip = self.tip - # Flush state last as it reads the wall time. - now = time.time() - self.db.wall_time += now - self.db.last_flush - self.db.last_flush = now - self.db.last_flush_tx_count = self.db.fs_tx_count - - def rollback(): - self.db.prefix_db.rollback(self.height + 1) - self.db.es_sync_height = self.height - self.db.write_db_state() - self.db.prefix_db.unsafe_commit() - - await self.run_in_thread_with_lock(rollback) - self.clear_after_advance_or_reorg() - self.db.assert_db_state() - - elapsed = self.db.last_flush - start_time - self.logger.warning(f'backup flush #{self.db.hist_flush_count:,d} took {elapsed:.1f}s. ' - f'Height {self.height:,d} txs: {self.tx_count:,d} ({tx_delta:+,d})') - - def add_utxo(self, tx_hash: bytes, tx_num: int, nout: int, txout: 'TxOutput') -> Optional[bytes]: - hashX = self.coin.hashX_from_script(txout.pk_script) - if hashX: - self.touched_hashXs.add(hashX) - self.utxo_cache[(tx_hash, nout)] = (hashX, txout.value) - self.db.prefix_db.utxo.stage_put((hashX, tx_num, nout), (txout.value,)) - self.db.prefix_db.hashX_utxo.stage_put((tx_hash[:4], tx_num, nout), (hashX,)) - return hashX - - def get_pending_tx_num(self, tx_hash: bytes) -> int: - if tx_hash in self.pending_transaction_num_mapping: - return self.pending_transaction_num_mapping[tx_hash] - else: - return self.db.get_tx_num(tx_hash) - - def spend_utxo(self, tx_hash: bytes, nout: int): - hashX, amount = self.utxo_cache.pop((tx_hash, nout), (None, None)) - txin_num = self.get_pending_tx_num(tx_hash) - if not hashX: - hashX_value = self.db.prefix_db.hashX_utxo.get(tx_hash[:4], txin_num, nout) - if not hashX_value: - return - hashX = hashX_value.hashX - utxo_value = self.db.prefix_db.utxo.get(hashX, txin_num, nout) - if not utxo_value: - self.logger.warning( - "%s:%s is not found in UTXO db for %s", hash_to_hex_str(tx_hash), nout, hash_to_hex_str(hashX) - ) - raise ChainError( - f"{hash_to_hex_str(tx_hash)}:{nout} is not found in UTXO db for {hash_to_hex_str(hashX)}" - ) - self.touched_hashXs.add(hashX) - self.db.prefix_db.hashX_utxo.stage_delete((tx_hash[:4], txin_num, nout), hashX_value) - self.db.prefix_db.utxo.stage_delete((hashX, txin_num, nout), utxo_value) - return hashX - elif amount is not None: - self.db.prefix_db.hashX_utxo.stage_delete((tx_hash[:4], txin_num, nout), (hashX,)) - self.db.prefix_db.utxo.stage_delete((hashX, txin_num, nout), (amount,)) - self.touched_hashXs.add(hashX) - return hashX - - async def _process_prefetched_blocks(self): - """Loop forever processing blocks as they arrive.""" - while True: - if self.height == self.daemon.cached_height(): - if not self._caught_up_event.is_set(): - await self._first_caught_up() - self._caught_up_event.set() - await self.blocks_event.wait() - self.blocks_event.clear() - blocks = self.prefetcher.get_prefetched_blocks() - try: - await self.check_and_advance_blocks(blocks) - except Exception: - self.logger.exception("error while processing txs") - raise - - async def _es_caught_up(self): - self.db.es_sync_height = self.height - - def flush(): - assert len(self.db.prefix_db._op_stack) == 0 - self.db.write_db_state() - self.db.prefix_db.unsafe_commit() - self.db.assert_db_state() - - await self.run_in_thread_with_lock(flush) - - async def _first_caught_up(self): - self.logger.info(f'caught up to height {self.height}') - # Flush everything but with first_sync->False state. - first_sync = self.db.first_sync - self.db.first_sync = False - - def flush(): - assert len(self.db.prefix_db._op_stack) == 0 - self.db.write_db_state() - self.db.prefix_db.unsafe_commit() - self.db.assert_db_state() - - await self.run_in_thread_with_lock(flush) - - if first_sync: - self.logger.info(f'{lbry.__version__} synced to ' - f'height {self.height:,d}, halting here.') - self.shutdown_event.set() - - async def fetch_and_process_blocks(self, caught_up_event): - """Fetch, process and index blocks from the daemon. - - Sets caught_up_event when first caught up. Flushes to disk - and shuts down cleanly if cancelled. - - This is mainly because if, during initial sync ElectrumX is - asked to shut down when a large number of blocks have been - processed but not written to disk, it should write those to - disk before exiting, as otherwise a significant amount of work - could be lost. - """ - - self._caught_up_event = caught_up_event - try: - self.db.open_db() - self.height = self.db.db_height - self.tip = self.db.db_tip - self.tx_count = self.db.db_tx_count - self.status_server.set_height(self.db.fs_height, self.db.db_tip) - await self.db.initialize_caches() - await self.db.search_index.start() - await asyncio.wait([ - self.prefetcher.main_loop(self.height), - self._process_prefetched_blocks() - ]) - except asyncio.CancelledError: - raise - except: - self.logger.exception("Block processing failed!") - raise - finally: - self.status_server.stop() - # Shut down block processing - self.logger.info('closing the DB for a clean shutdown...') - self._sync_reader_executor.shutdown(wait=True) - self._chain_executor.shutdown(wait=True) - self.db.close() diff --git a/lbry/wallet/server/cli.py b/lbry/wallet/server/cli.py deleted file mode 100644 index 74a3d092a..000000000 --- a/lbry/wallet/server/cli.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging -import traceback -import argparse -from lbry.wallet.server.env import Env -from lbry.wallet.server.server import Server - - -def get_argument_parser(): - parser = argparse.ArgumentParser( - prog="lbry-hub" - ) - Env.contribute_to_arg_parser(parser) - return parser - - -def main(): - parser = get_argument_parser() - args = parser.parse_args() - logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)-4s %(name)s:%(lineno)d: %(message)s") - logging.info('lbry.server starting') - logging.getLogger('aiohttp').setLevel(logging.WARNING) - logging.getLogger('elasticsearch').setLevel(logging.WARNING) - try: - server = Server(Env.from_arg_parser(args)) - server.run() - except Exception: - traceback.print_exc() - logging.critical('lbry.server terminated abnormally') - else: - logging.info('lbry.server terminated normally') - - -if __name__ == "__main__": - main() diff --git a/lbry/wallet/server/coin.py b/lbry/wallet/server/coin.py deleted file mode 100644 index bd379f112..000000000 --- a/lbry/wallet/server/coin.py +++ /dev/null @@ -1,386 +0,0 @@ -import re -import struct -from typing import List -from hashlib import sha256 -from decimal import Decimal -from collections import namedtuple - -import lbry.wallet.server.tx as lib_tx -from lbry.wallet.script import OutputScript, OP_CLAIM_NAME, OP_UPDATE_CLAIM, OP_SUPPORT_CLAIM -from lbry.wallet.server.tx import DeserializerSegWit -from lbry.wallet.server.util import cachedproperty, subclasses -from lbry.wallet.server.hash import Base58, hash160, double_sha256, hash_to_hex_str, HASHX_LEN -from lbry.wallet.server.daemon import Daemon, LBCDaemon -from lbry.wallet.server.script import ScriptPubKey, OpCodes -from lbry.wallet.server.leveldb import LevelDB -from lbry.wallet.server.session import LBRYElectrumX, LBRYSessionManager -from lbry.wallet.server.block_processor import BlockProcessor - - -Block = namedtuple("Block", "raw header transactions") -OP_RETURN = OpCodes.OP_RETURN - - -class CoinError(Exception): - """Exception raised for coin-related errors.""" - - -class Coin: - """Base class of coin hierarchy.""" - - REORG_LIMIT = 200 - # Not sure if these are coin-specific - RPC_URL_REGEX = re.compile('.+@(\\[[0-9a-fA-F:]+\\]|[^:]+)(:[0-9]+)?') - VALUE_PER_COIN = 100000000 - CHUNK_SIZE = 2016 - BASIC_HEADER_SIZE = 80 - STATIC_BLOCK_HEADERS = True - SESSIONCLS = LBRYElectrumX - DESERIALIZER = lib_tx.Deserializer - DAEMON = Daemon - BLOCK_PROCESSOR = BlockProcessor - SESSION_MANAGER = LBRYSessionManager - DB = LevelDB - HEADER_VALUES = [ - 'version', 'prev_block_hash', 'merkle_root', 'timestamp', 'bits', 'nonce' - ] - HEADER_UNPACK = struct.Struct('< I 32s 32s I I I').unpack_from - MEMPOOL_HISTOGRAM_REFRESH_SECS = 500 - XPUB_VERBYTES = bytes('????', 'utf-8') - XPRV_VERBYTES = bytes('????', 'utf-8') - ENCODE_CHECK = Base58.encode_check - DECODE_CHECK = Base58.decode_check - # Peer discovery - PEER_DEFAULT_PORTS = {'t': '50001', 's': '50002'} - PEERS: List[str] = [] - - @classmethod - def lookup_coin_class(cls, name, net): - """Return a coin class given name and network. - - Raise an exception if unrecognised.""" - req_attrs = ['TX_COUNT', 'TX_COUNT_HEIGHT', 'TX_PER_BLOCK'] - for coin in subclasses(Coin): - if (coin.NAME.lower() == name.lower() and - coin.NET.lower() == net.lower()): - coin_req_attrs = req_attrs.copy() - missing = [attr for attr in coin_req_attrs - if not hasattr(coin, attr)] - if missing: - raise CoinError(f'coin {name} missing {missing} attributes') - return coin - raise CoinError(f'unknown coin {name} and network {net} combination') - - @classmethod - def sanitize_url(cls, url): - # Remove surrounding ws and trailing /s - url = url.strip().rstrip('/') - match = cls.RPC_URL_REGEX.match(url) - if not match: - raise CoinError(f'invalid daemon URL: "{url}"') - if match.groups()[1] is None: - url += f':{cls.RPC_PORT:d}' - if not url.startswith('http://') and not url.startswith('https://'): - url = 'http://' + url - return url + '/' - - @classmethod - def genesis_block(cls, block): - """Check the Genesis block is the right one for this coin. - - Return the block less its unspendable coinbase. - """ - header = cls.block_header(block, 0) - header_hex_hash = hash_to_hex_str(cls.header_hash(header)) - if header_hex_hash != cls.GENESIS_HASH: - raise CoinError(f'genesis block has hash {header_hex_hash} expected {cls.GENESIS_HASH}') - - return header + bytes(1) - - @classmethod - def hashX_from_script(cls, script): - """Returns a hashX from a script, or None if the script is provably - unspendable so the output can be dropped. - """ - if script and script[0] == OP_RETURN: - return None - return sha256(script).digest()[:HASHX_LEN] - - @staticmethod - def lookup_xverbytes(verbytes): - """Return a (is_xpub, coin_class) pair given xpub/xprv verbytes.""" - # Order means BTC testnet will override NMC testnet - for coin in subclasses(Coin): - if verbytes == coin.XPUB_VERBYTES: - return True, coin - if verbytes == coin.XPRV_VERBYTES: - return False, coin - raise CoinError('version bytes unrecognised') - - @classmethod - def address_to_hashX(cls, address): - """Return a hashX given a coin address.""" - return cls.hashX_from_script(cls.pay_to_address_script(address)) - - @classmethod - def P2PKH_address_from_hash160(cls, hash160): - """Return a P2PKH address given a public key.""" - assert len(hash160) == 20 - return cls.ENCODE_CHECK(cls.P2PKH_VERBYTE + hash160) - - @classmethod - def P2PKH_address_from_pubkey(cls, pubkey): - """Return a coin address given a public key.""" - return cls.P2PKH_address_from_hash160(hash160(pubkey)) - - @classmethod - def P2SH_address_from_hash160(cls, hash160): - """Return a coin address given a hash160.""" - assert len(hash160) == 20 - return cls.ENCODE_CHECK(cls.P2SH_VERBYTES[0] + hash160) - - @classmethod - def hash160_to_P2PKH_script(cls, hash160): - return ScriptPubKey.P2PKH_script(hash160) - - @classmethod - def hash160_to_P2PKH_hashX(cls, hash160): - return cls.hashX_from_script(cls.hash160_to_P2PKH_script(hash160)) - - @classmethod - def pay_to_address_script(cls, address): - """Return a pubkey script that pays to a pubkey hash. - - Pass the address (either P2PKH or P2SH) in base58 form. - """ - raw = cls.DECODE_CHECK(address) - - # Require version byte(s) plus hash160. - verbyte = -1 - verlen = len(raw) - 20 - if verlen > 0: - verbyte, hash160 = raw[:verlen], raw[verlen:] - - if verbyte == cls.P2PKH_VERBYTE: - return cls.hash160_to_P2PKH_script(hash160) - if verbyte in cls.P2SH_VERBYTES: - return ScriptPubKey.P2SH_script(hash160) - - raise CoinError(f'invalid address: {address}') - - @classmethod - def privkey_WIF(cls, privkey_bytes, compressed): - """Return the private key encoded in Wallet Import Format.""" - payload = bytearray(cls.WIF_BYTE) + privkey_bytes - if compressed: - payload.append(0x01) - return cls.ENCODE_CHECK(payload) - - @classmethod - def header_hash(cls, header): - """Given a header return hash""" - return double_sha256(header) - - @classmethod - def header_prevhash(cls, header): - """Given a header return previous hash""" - return header[4:36] - - @classmethod - def static_header_offset(cls, height): - """Given a header height return its offset in the headers file. - - If header sizes change at some point, this is the only code - that needs updating.""" - assert cls.STATIC_BLOCK_HEADERS - return height * cls.BASIC_HEADER_SIZE - - @classmethod - def static_header_len(cls, height): - """Given a header height return its length.""" - return (cls.static_header_offset(height + 1) - - cls.static_header_offset(height)) - - @classmethod - def block_header(cls, block, height): - """Returns the block header given a block and its height.""" - return block[:cls.static_header_len(height)] - - @classmethod - def block(cls, raw_block, height): - """Return a Block namedtuple given a raw block and its height.""" - header = cls.block_header(raw_block, height) - txs = cls.DESERIALIZER(raw_block, start=len(header)).read_tx_block() - return Block(raw_block, header, txs) - - @classmethod - def transaction(cls, raw_tx: bytes): - """Return a Block namedtuple given a raw block and its height.""" - return cls.DESERIALIZER(raw_tx).read_tx() - - @classmethod - def decimal_value(cls, value): - """Return the number of standard coin units as a Decimal given a - quantity of smallest units. - - For example 1 BTC is returned for 100 million satoshis. - """ - return Decimal(value) / cls.VALUE_PER_COIN - - @classmethod - def electrum_header(cls, header, height): - h = dict(zip(cls.HEADER_VALUES, cls.HEADER_UNPACK(header))) - # Add the height that is not present in the header itself - h['block_height'] = height - # Convert bytes to str - h['prev_block_hash'] = hash_to_hex_str(h['prev_block_hash']) - h['merkle_root'] = hash_to_hex_str(h['merkle_root']) - return h - - -class LBC(Coin): - DAEMON = LBCDaemon - SESSIONCLS = LBRYElectrumX - SESSION_MANAGER = LBRYSessionManager - DESERIALIZER = DeserializerSegWit - DB = LevelDB - NAME = "LBRY" - SHORTNAME = "LBC" - NET = "mainnet" - BASIC_HEADER_SIZE = 112 - CHUNK_SIZE = 96 - XPUB_VERBYTES = bytes.fromhex("0488b21e") - XPRV_VERBYTES = bytes.fromhex("0488ade4") - P2PKH_VERBYTE = bytes.fromhex("55") - P2SH_VERBYTES = bytes.fromhex("7A") - WIF_BYTE = bytes.fromhex("1C") - GENESIS_HASH = ('9c89283ba0f3227f6c03b70216b9f665' - 'f0118d5e0fa729cedf4fb34d6a34f463') - TX_COUNT = 2716936 - TX_COUNT_HEIGHT = 329554 - TX_PER_BLOCK = 1 - RPC_PORT = 9245 - REORG_LIMIT = 200 - - nOriginalClaimExpirationTime = 262974 - nExtendedClaimExpirationTime = 2102400 - nExtendedClaimExpirationForkHeight = 400155 - nNormalizedNameForkHeight = 539940 # targeting 21 March 2019 - nMinTakeoverWorkaroundHeight = 496850 - nMaxTakeoverWorkaroundHeight = 658300 # targeting 30 Oct 2019 - nWitnessForkHeight = 680770 # targeting 11 Dec 2019 - nAllClaimsInMerkleForkHeight = 658310 # targeting 30 Oct 2019 - proportionalDelayFactor = 32 - maxTakeoverDelay = 4032 - - PEERS = [ - ] - - @classmethod - def genesis_block(cls, block): - '''Check the Genesis block is the right one for this coin. - - Return the block less its unspendable coinbase. - ''' - header = cls.block_header(block, 0) - header_hex_hash = hash_to_hex_str(cls.header_hash(header)) - if header_hex_hash != cls.GENESIS_HASH: - raise CoinError(f'genesis block has hash {header_hex_hash} expected {cls.GENESIS_HASH}') - - return block - - @classmethod - def electrum_header(cls, header, height): - version, = struct.unpack(' int: - if extended: - return last_updated_height + cls.nExtendedClaimExpirationTime - if last_updated_height < cls.nExtendedClaimExpirationForkHeight: - return last_updated_height + cls.nOriginalClaimExpirationTime - return last_updated_height + cls.nExtendedClaimExpirationTime - - @classmethod - def get_delay_for_name(cls, blocks_of_continuous_ownership: int) -> int: - return min(blocks_of_continuous_ownership // cls.proportionalDelayFactor, cls.maxTakeoverDelay) - - -class LBCRegTest(LBC): - NET = "regtest" - GENESIS_HASH = '6e3fcf1299d4ec5d79c3a4c91d624a4acf9e2e173d95a1a0504f677669687556' - XPUB_VERBYTES = bytes.fromhex('043587cf') - XPRV_VERBYTES = bytes.fromhex('04358394') - P2PKH_VERBYTE = bytes.fromhex("6f") - P2SH_VERBYTES = bytes.fromhex("c4") - - nOriginalClaimExpirationTime = 500 - nExtendedClaimExpirationTime = 600 - nExtendedClaimExpirationForkHeight = 800 - nNormalizedNameForkHeight = 250 - nMinTakeoverWorkaroundHeight = -1 - nMaxTakeoverWorkaroundHeight = -1 - nWitnessForkHeight = 150 - nAllClaimsInMerkleForkHeight = 350 - - -class LBCTestNet(LBCRegTest): - NET = "testnet" - GENESIS_HASH = '9c89283ba0f3227f6c03b70216b9f665f0118d5e0fa729cedf4fb34d6a34f463' diff --git a/lbry/wallet/server/daemon.py b/lbry/wallet/server/daemon.py deleted file mode 100644 index c487de0c7..000000000 --- a/lbry/wallet/server/daemon.py +++ /dev/null @@ -1,375 +0,0 @@ -import asyncio -import itertools -import json -import time -from functools import wraps - -import aiohttp -from prometheus_client import Gauge, Histogram -from lbry.utils import LRUCacheWithMetrics -from lbry.wallet.rpc.jsonrpc import RPCError -from lbry.wallet.server.util import hex_to_bytes, class_logger -from lbry.wallet.rpc import JSONRPC - - -class DaemonError(Exception): - """Raised when the daemon returns an error in its results.""" - - -class WarmingUpError(Exception): - """Internal - when the daemon is warming up.""" - - -class WorkQueueFullError(Exception): - """Internal - when the daemon's work queue is full.""" - - -NAMESPACE = "wallet_server" - - -class Daemon: - """Handles connections to a daemon at the given URL.""" - - WARMING_UP = -28 - id_counter = itertools.count() - - lbrycrd_request_time_metric = Histogram( - "lbrycrd_request", "lbrycrd requests count", namespace=NAMESPACE, labelnames=("method",) - ) - lbrycrd_pending_count_metric = Gauge( - "lbrycrd_pending_count", "Number of lbrycrd rpcs that are in flight", namespace=NAMESPACE, - labelnames=("method",) - ) - - def __init__(self, coin, url, max_workqueue=10, init_retry=0.25, - max_retry=4.0): - self.coin = coin - self.logger = class_logger(__name__, self.__class__.__name__) - self.set_url(url) - # Limit concurrent RPC calls to this number. - # See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16 - self.workqueue_semaphore = asyncio.Semaphore(value=max_workqueue) - self.init_retry = init_retry - self.max_retry = max_retry - self._height = None - self.available_rpcs = {} - self.connector = aiohttp.TCPConnector() - self._block_hash_cache = LRUCacheWithMetrics(100000) - self._block_cache = LRUCacheWithMetrics(2 ** 13, metric_name='block', namespace=NAMESPACE) - - async def close(self): - if self.connector: - await self.connector.close() - self.connector = None - - def set_url(self, url): - """Set the URLS to the given list, and switch to the first one.""" - urls = url.split(',') - urls = [self.coin.sanitize_url(url) for url in urls] - for n, url in enumerate(urls): - status = '' if n else ' (current)' - logged_url = self.logged_url(url) - self.logger.info(f'daemon #{n + 1} at {logged_url}{status}') - self.url_index = 0 - self.urls = urls - - def current_url(self): - """Returns the current daemon URL.""" - return self.urls[self.url_index] - - def logged_url(self, url=None): - """The host and port part, for logging.""" - url = url or self.current_url() - return url[url.rindex('@') + 1:] - - def failover(self): - """Call to fail-over to the next daemon URL. - - Returns False if there is only one, otherwise True. - """ - if len(self.urls) > 1: - self.url_index = (self.url_index + 1) % len(self.urls) - self.logger.info(f'failing over to {self.logged_url()}') - return True - return False - - def client_session(self): - """An aiohttp client session.""" - return aiohttp.ClientSession(connector=self.connector, connector_owner=False) - - async def _send_data(self, data): - if not self.connector: - raise asyncio.CancelledError('Tried to send request during shutdown.') - async with self.workqueue_semaphore: - async with self.client_session() as session: - async with session.post(self.current_url(), data=data) as resp: - kind = resp.headers.get('Content-Type', None) - if kind == 'application/json': - return await resp.json() - # bitcoind's HTTP protocol "handling" is a bad joke - text = await resp.text() - if 'Work queue depth exceeded' in text: - raise WorkQueueFullError - text = text.strip() or resp.reason - self.logger.error(text) - raise DaemonError(text) - - async def _send(self, payload, processor): - """Send a payload to be converted to JSON. - - Handles temporary connection issues. Daemon response errors - are raise through DaemonError. - """ - - def log_error(error): - nonlocal last_error_log, retry - now = time.time() - if now - last_error_log > 60: - last_error_log = now - self.logger.error(f'{error} Retrying occasionally...') - if retry == self.max_retry and self.failover(): - retry = 0 - - on_good_message = None - last_error_log = 0 - data = json.dumps(payload) - retry = self.init_retry - methods = tuple( - [payload['method']] if isinstance(payload, dict) else [request['method'] for request in payload] - ) - while True: - try: - for method in methods: - self.lbrycrd_pending_count_metric.labels(method=method).inc() - result = await self._send_data(data) - result = processor(result) - if on_good_message: - self.logger.info(on_good_message) - return result - except asyncio.TimeoutError: - log_error('timeout error.') - except aiohttp.ServerDisconnectedError: - log_error('disconnected.') - on_good_message = 'connection restored' - except aiohttp.ClientConnectionError: - log_error('connection problem - is your daemon running?') - on_good_message = 'connection restored' - except aiohttp.ClientError as e: - log_error(f'daemon error: {e}') - on_good_message = 'running normally' - except WarmingUpError: - log_error('starting up checking blocks.') - on_good_message = 'running normally' - except WorkQueueFullError: - log_error('work queue full.') - on_good_message = 'running normally' - finally: - for method in methods: - self.lbrycrd_pending_count_metric.labels(method=method).dec() - await asyncio.sleep(retry) - retry = max(min(self.max_retry, retry * 2), self.init_retry) - - async def _send_single(self, method, params=None): - """Send a single request to the daemon.""" - - start = time.perf_counter() - - def processor(result): - err = result['error'] - if not err: - return result['result'] - if err.get('code') == self.WARMING_UP: - raise WarmingUpError - raise DaemonError(err) - - payload = {'method': method, 'id': next(self.id_counter)} - if params: - payload['params'] = params - result = await self._send(payload, processor) - self.lbrycrd_request_time_metric.labels(method=method).observe(time.perf_counter() - start) - return result - - async def _send_vector(self, method, params_iterable, replace_errs=False): - """Send several requests of the same method. - - The result will be an array of the same length as params_iterable. - If replace_errs is true, any item with an error is returned as None, - otherwise an exception is raised.""" - - start = time.perf_counter() - - def processor(result): - errs = [item['error'] for item in result if item['error']] - if any(err.get('code') == self.WARMING_UP for err in errs): - raise WarmingUpError - if not errs or replace_errs: - return [item['result'] for item in result] - raise DaemonError(errs) - - payload = [{'method': method, 'params': p, 'id': next(self.id_counter)} - for p in params_iterable] - result = [] - if payload: - result = await self._send(payload, processor) - self.lbrycrd_request_time_metric.labels(method=method).observe(time.perf_counter() - start) - return result - - async def _is_rpc_available(self, method): - """Return whether given RPC method is available in the daemon. - - Results are cached and the daemon will generally not be queried with - the same method more than once.""" - available = self.available_rpcs.get(method) - if available is None: - available = True - try: - await self._send_single(method) - except DaemonError as e: - err = e.args[0] - error_code = err.get("code") - available = error_code != JSONRPC.METHOD_NOT_FOUND - self.available_rpcs[method] = available - return available - - async def block_hex_hashes(self, first, count): - """Return the hex hashes of count block starting at height first.""" - if first + count < (self.cached_height() or 0) - 200: - return await self._cached_block_hex_hashes(first, count) - params_iterable = ((h, ) for h in range(first, first + count)) - return await self._send_vector('getblockhash', params_iterable) - - async def _cached_block_hex_hashes(self, first, count): - """Return the hex hashes of count block starting at height first.""" - cached = self._block_hash_cache.get((first, count)) - if cached: - return cached - params_iterable = ((h, ) for h in range(first, first + count)) - self._block_hash_cache[(first, count)] = await self._send_vector('getblockhash', params_iterable) - return self._block_hash_cache[(first, count)] - - async def deserialised_block(self, hex_hash): - """Return the deserialised block with the given hex hash.""" - if hex_hash not in self._block_cache: - block = await self._send_single('getblock', (hex_hash, True)) - self._block_cache[hex_hash] = block - return block - return self._block_cache[hex_hash] - - async def raw_blocks(self, hex_hashes): - """Return the raw binary blocks with the given hex hashes.""" - params_iterable = ((h, False) for h in hex_hashes) - blocks = await self._send_vector('getblock', params_iterable) - # Convert hex string to bytes - return [hex_to_bytes(block) for block in blocks] - - async def mempool_hashes(self): - """Update our record of the daemon's mempool hashes.""" - return await self._send_single('getrawmempool') - - async def estimatefee(self, block_count): - """Return the fee estimate for the block count. Units are whole - currency units per KB, e.g. 0.00000995, or -1 if no estimate - is available. - """ - args = (block_count, ) - if await self._is_rpc_available('estimatesmartfee'): - estimate = await self._send_single('estimatesmartfee', args) - return estimate.get('feerate', -1) - return await self._send_single('estimatefee', args) - - async def getnetworkinfo(self): - """Return the result of the 'getnetworkinfo' RPC call.""" - return await self._send_single('getnetworkinfo') - - async def relayfee(self): - """The minimum fee a low-priority tx must pay in order to be accepted - to the daemon's memory pool.""" - network_info = await self.getnetworkinfo() - return network_info['relayfee'] - - async def getrawtransaction(self, hex_hash, verbose=False): - """Return the serialized raw transaction with the given hash.""" - # Cast to int because some coin daemons are old and require it - return await self._send_single('getrawtransaction', - (hex_hash, int(verbose))) - - async def getrawtransactions(self, hex_hashes, replace_errs=True): - """Return the serialized raw transactions with the given hashes. - - Replaces errors with None by default.""" - params_iterable = ((hex_hash, 0) for hex_hash in hex_hashes) - txs = await self._send_vector('getrawtransaction', params_iterable, - replace_errs=replace_errs) - # Convert hex strings to bytes - return [hex_to_bytes(tx) if tx else None for tx in txs] - - async def broadcast_transaction(self, raw_tx): - """Broadcast a transaction to the network.""" - return await self._send_single('sendrawtransaction', (raw_tx, )) - - async def height(self): - """Query the daemon for its current height.""" - self._height = await self._send_single('getblockcount') - return self._height - - def cached_height(self): - """Return the cached daemon height. - - If the daemon has not been queried yet this returns None.""" - return self._height - - -def handles_errors(decorated_function): - @wraps(decorated_function) - async def wrapper(*args, **kwargs): - try: - return await decorated_function(*args, **kwargs) - except DaemonError as daemon_error: - raise RPCError(1, daemon_error.args[0]) - return wrapper - - -class LBCDaemon(Daemon): - @handles_errors - async def getrawtransaction(self, hex_hash, verbose=False): - return await super().getrawtransaction(hex_hash=hex_hash, verbose=verbose) - - @handles_errors - async def getclaimbyid(self, claim_id): - '''Given a claim id, retrieves claim information.''' - return await self._send_single('getclaimbyid', (claim_id,)) - - @handles_errors - async def getclaimsbyids(self, claim_ids): - '''Given a list of claim ids, batches calls to retrieve claim information.''' - return await self._send_vector('getclaimbyid', ((claim_id,) for claim_id in claim_ids)) - - @handles_errors - async def getclaimsforname(self, name): - '''Given a name, retrieves all claims matching that name.''' - return await self._send_single('getclaimsforname', (name,)) - - @handles_errors - async def getclaimsfortx(self, txid): - '''Given a txid, returns the claims it make.''' - return await self._send_single('getclaimsfortx', (txid,)) or [] - - @handles_errors - async def getnameproof(self, name, block_hash=None): - '''Given a name and optional block_hash, returns a name proof and winner, if any.''' - return await self._send_single('getnameproof', (name, block_hash,) if block_hash else (name,)) - - @handles_errors - async def getvalueforname(self, name): - '''Given a name, returns the winning claim value.''' - return await self._send_single('getvalueforname', (name,)) - - @handles_errors - async def getnamesintrie(self): - '''Given a name, returns the winning claim value.''' - return await self._send_single('getnamesintrie') - - @handles_errors - async def claimname(self, name, hexvalue, amount): - '''Claim a name, used for functional tests only.''' - return await self._send_single('claimname', (name, hexvalue, float(amount))) diff --git a/lbry/wallet/server/db/__init__.py b/lbry/wallet/server/db/__init__.py deleted file mode 100644 index 7da046edc..000000000 --- a/lbry/wallet/server/db/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -import enum - - -@enum.unique -class DB_PREFIXES(enum.Enum): - claim_to_support = b'K' - support_to_claim = b'L' - - claim_to_txo = b'E' - txo_to_claim = b'G' - - claim_to_channel = b'I' - channel_to_claim = b'J' - - claim_short_id_prefix = b'F' - effective_amount = b'D' - claim_expiration = b'O' - - claim_takeover = b'P' - pending_activation = b'Q' - activated_claim_and_support = b'R' - active_amount = b'S' - - repost = b'V' - reposted_claim = b'W' - - undo = b'M' - claim_diff = b'Y' - - tx = b'B' - block_hash = b'C' - header = b'H' - tx_num = b'N' - tx_count = b'T' - tx_hash = b'X' - utxo = b'u' - hashx_utxo = b'h' - hashx_history = b'x' - db_state = b's' - channel_count = b'Z' - support_amount = b'a' - block_txs = b'b' diff --git a/lbry/wallet/server/db/common.py b/lbry/wallet/server/db/common.py deleted file mode 100644 index dce98711d..000000000 --- a/lbry/wallet/server/db/common.py +++ /dev/null @@ -1,447 +0,0 @@ -import typing - -CLAIM_TYPES = { - 'stream': 1, - 'channel': 2, - 'repost': 3, - 'collection': 4, -} - -STREAM_TYPES = { - 'video': 1, - 'audio': 2, - 'image': 3, - 'document': 4, - 'binary': 5, - 'model': 6, -} - -# 9/21/2020 -MOST_USED_TAGS = { - "gaming", - "people & blogs", - "entertainment", - "music", - "pop culture", - "education", - "technology", - "blockchain", - "news", - "funny", - "science & technology", - "learning", - "gameplay", - "news & politics", - "comedy", - "bitcoin", - "beliefs", - "nature", - "art", - "economics", - "film & animation", - "lets play", - "games", - "sports", - "howto & style", - "game", - "cryptocurrency", - "playstation 4", - "automotive", - "crypto", - "mature", - "sony interactive entertainment", - "walkthrough", - "tutorial", - "video game", - "weapons", - "playthrough", - "pc", - "anime", - "how to", - "btc", - "fun", - "ethereum", - "food", - "travel & events", - "minecraft", - "science", - "autos & vehicles", - "play", - "politics", - "commentary", - "twitch", - "ps4live", - "love", - "ps4", - "nonprofits & activism", - "ps4share", - "fortnite", - "xbox", - "porn", - "video games", - "trump", - "español", - "money", - "music video", - "nintendo", - "movie", - "coronavirus", - "donald trump", - "steam", - "trailer", - "android", - "podcast", - "xbox one", - "survival", - "audio", - "linux", - "travel", - "funny moments", - "litecoin", - "animation", - "gamer", - "lets", - "playstation", - "bitcoin news", - "history", - "xxx", - "fox news", - "dance", - "god", - "adventure", - "liberal", - "2020", - "horror", - "government", - "freedom", - "reaction", - "meme", - "photography", - "truth", - "health", - "lbry", - "family", - "online", - "eth", - "crypto news", - "diy", - "trading", - "gold", - "memes", - "world", - "space", - "lol", - "covid-19", - "rpg", - "humor", - "democrat", - "film", - "call of duty", - "tech", - "religion", - "conspiracy", - "rap", - "cnn", - "hangoutsonair", - "unboxing", - "fiction", - "conservative", - "cars", - "hoa", - "epic", - "programming", - "progressive", - "cryptocurrency news", - "classical", - "jesus", - "movies", - "book", - "ps3", - "republican", - "fitness", - "books", - "multiplayer", - "animals", - "pokemon", - "bitcoin price", - "facebook", - "sharefactory", - "criptomonedas", - "cod", - "bible", - "business", - "stream", - "comics", - "how", - "fail", - "nsfw", - "new music", - "satire", - "pets & animals", - "computer", - "classical music", - "indie", - "musica", - "msnbc", - "fps", - "mod", - "sport", - "sony", - "ripple", - "auto", - "rock", - "marvel", - "complete", - "mining", - "political", - "mobile", - "pubg", - "hip hop", - "flat earth", - "xbox 360", - "reviews", - "vlogging", - "latest news", - "hack", - "tarot", - "iphone", - "media", - "cute", - "christian", - "free speech", - "trap", - "war", - "remix", - "ios", - "xrp", - "spirituality", - "song", - "league of legends", - "cat" -} - -MATURE_TAGS = [ - 'nsfw', 'porn', 'xxx', 'mature', 'adult', 'sex' -] - - -def normalize_tag(tag): - return tag.replace(" ", "_").replace("&", "and").replace("-", "_") - - -COMMON_TAGS = { - tag: normalize_tag(tag) for tag in list(MOST_USED_TAGS) -} - -INDEXED_LANGUAGES = [ - 'none', - 'en', - 'aa', - 'ab', - 'ae', - 'af', - 'ak', - 'am', - 'an', - 'ar', - 'as', - 'av', - 'ay', - 'az', - 'ba', - 'be', - 'bg', - 'bh', - 'bi', - 'bm', - 'bn', - 'bo', - 'br', - 'bs', - 'ca', - 'ce', - 'ch', - 'co', - 'cr', - 'cs', - 'cu', - 'cv', - 'cy', - 'da', - 'de', - 'dv', - 'dz', - 'ee', - 'el', - 'eo', - 'es', - 'et', - 'eu', - 'fa', - 'ff', - 'fi', - 'fj', - 'fo', - 'fr', - 'fy', - 'ga', - 'gd', - 'gl', - 'gn', - 'gu', - 'gv', - 'ha', - 'he', - 'hi', - 'ho', - 'hr', - 'ht', - 'hu', - 'hy', - 'hz', - 'ia', - 'id', - 'ie', - 'ig', - 'ii', - 'ik', - 'io', - 'is', - 'it', - 'iu', - 'ja', - 'jv', - 'ka', - 'kg', - 'ki', - 'kj', - 'kk', - 'kl', - 'km', - 'kn', - 'ko', - 'kr', - 'ks', - 'ku', - 'kv', - 'kw', - 'ky', - 'la', - 'lb', - 'lg', - 'li', - 'ln', - 'lo', - 'lt', - 'lu', - 'lv', - 'mg', - 'mh', - 'mi', - 'mk', - 'ml', - 'mn', - 'mr', - 'ms', - 'mt', - 'my', - 'na', - 'nb', - 'nd', - 'ne', - 'ng', - 'nl', - 'nn', - 'no', - 'nr', - 'nv', - 'ny', - 'oc', - 'oj', - 'om', - 'or', - 'os', - 'pa', - 'pi', - 'pl', - 'ps', - 'pt', - 'qu', - 'rm', - 'rn', - 'ro', - 'ru', - 'rw', - 'sa', - 'sc', - 'sd', - 'se', - 'sg', - 'si', - 'sk', - 'sl', - 'sm', - 'sn', - 'so', - 'sq', - 'sr', - 'ss', - 'st', - 'su', - 'sv', - 'sw', - 'ta', - 'te', - 'tg', - 'th', - 'ti', - 'tk', - 'tl', - 'tn', - 'to', - 'tr', - 'ts', - 'tt', - 'tw', - 'ty', - 'ug', - 'uk', - 'ur', - 'uz', - 've', - 'vi', - 'vo', - 'wa', - 'wo', - 'xh', - 'yi', - 'yo', - 'za', - 'zh', - 'zu' -] - - -class ResolveResult(typing.NamedTuple): - name: str - normalized_name: str - claim_hash: bytes - tx_num: int - position: int - tx_hash: bytes - height: int - amount: int - short_url: str - is_controlling: bool - canonical_url: str - creation_height: int - activation_height: int - expiration_height: int - effective_amount: int - support_amount: int - reposted: int - last_takeover_height: typing.Optional[int] - claims_in_channel: typing.Optional[int] - channel_hash: typing.Optional[bytes] - reposted_claim_hash: typing.Optional[bytes] - signature_valid: typing.Optional[bool] diff --git a/lbry/wallet/server/db/db.py b/lbry/wallet/server/db/db.py deleted file mode 100644 index 6d613df93..000000000 --- a/lbry/wallet/server/db/db.py +++ /dev/null @@ -1,119 +0,0 @@ -import struct -from typing import Optional -from lbry.wallet.server.db import DB_PREFIXES -from lbry.wallet.server.db.revertable import RevertableOpStack, RevertablePut, RevertableDelete - - -class KeyValueStorage: - def get(self, key: bytes, fill_cache: bool = True) -> Optional[bytes]: - raise NotImplemented() - - def iterator(self, reverse=False, start=None, stop=None, include_start=True, include_stop=False, prefix=None, - include_key=True, include_value=True, fill_cache=True): - raise NotImplemented() - - def write_batch(self, transaction: bool = False): - raise NotImplemented() - - def close(self): - raise NotImplemented() - - @property - def closed(self) -> bool: - raise NotImplemented() - - -class PrefixDB: - UNDO_KEY_STRUCT = struct.Struct(b'>Q') - - def __init__(self, db: KeyValueStorage, max_undo_depth: int = 200, unsafe_prefixes=None): - self._db = db - self._op_stack = RevertableOpStack(db.get, unsafe_prefixes=unsafe_prefixes) - self._max_undo_depth = max_undo_depth - - def unsafe_commit(self): - """ - Write staged changes to the database without keeping undo information - Changes written cannot be undone - """ - try: - with self._db.write_batch(transaction=True) as batch: - batch_put = batch.put - batch_delete = batch.delete - for staged_change in self._op_stack: - if staged_change.is_put: - batch_put(staged_change.key, staged_change.value) - else: - batch_delete(staged_change.key) - finally: - self._op_stack.clear() - - def commit(self, height: int): - """ - Write changes for a block height to the database and keep undo information so that the changes can be reverted - """ - undo_ops = self._op_stack.get_undo_ops() - delete_undos = [] - if height > self._max_undo_depth: - delete_undos.extend(self._db.iterator( - start=DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(0), - stop=DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(height - self._max_undo_depth), - include_value=False - )) - try: - with self._db.write_batch(transaction=True) as batch: - batch_put = batch.put - batch_delete = batch.delete - for staged_change in self._op_stack: - if staged_change.is_put: - batch_put(staged_change.key, staged_change.value) - else: - batch_delete(staged_change.key) - for undo_to_delete in delete_undos: - batch_delete(undo_to_delete) - batch_put(DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(height), undo_ops) - finally: - self._op_stack.clear() - - def rollback(self, height: int): - """ - Revert changes for a block height - """ - undo_key = DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(height) - self._op_stack.apply_packed_undo_ops(self._db.get(undo_key)) - try: - with self._db.write_batch(transaction=True) as batch: - batch_put = batch.put - batch_delete = batch.delete - for staged_change in self._op_stack: - if staged_change.is_put: - batch_put(staged_change.key, staged_change.value) - else: - batch_delete(staged_change.key) - batch_delete(undo_key) - finally: - self._op_stack.clear() - - def get(self, key: bytes, fill_cache: bool = True) -> Optional[bytes]: - return self._db.get(key, fill_cache=fill_cache) - - def iterator(self, reverse=False, start=None, stop=None, include_start=True, include_stop=False, prefix=None, - include_key=True, include_value=True, fill_cache=True): - return self._db.iterator( - reverse=reverse, start=start, stop=stop, include_start=include_start, include_stop=include_stop, - prefix=prefix, include_key=include_key, include_value=include_value, fill_cache=fill_cache - ) - - def close(self): - if not self._db.closed: - self._db.close() - - @property - def closed(self): - return self._db.closed - - def stage_raw_put(self, key: bytes, value: bytes): - self._op_stack.append_op(RevertablePut(key, value)) - - def stage_raw_delete(self, key: bytes, value: bytes): - self._op_stack.append_op(RevertableDelete(key, value)) diff --git a/lbry/wallet/server/db/elasticsearch/__init__.py b/lbry/wallet/server/db/elasticsearch/__init__.py deleted file mode 100644 index 385e96219..000000000 --- a/lbry/wallet/server/db/elasticsearch/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .search import SearchIndex \ No newline at end of file diff --git a/lbry/wallet/server/db/elasticsearch/constants.py b/lbry/wallet/server/db/elasticsearch/constants.py deleted file mode 100644 index afdfd6fe1..000000000 --- a/lbry/wallet/server/db/elasticsearch/constants.py +++ /dev/null @@ -1,100 +0,0 @@ -INDEX_DEFAULT_SETTINGS = { - "settings": - {"analysis": - {"analyzer": { - "default": {"tokenizer": "whitespace", "filter": ["lowercase", "porter_stem"]}}}, - "index": - {"refresh_interval": -1, - "number_of_shards": 1, - "number_of_replicas": 0, - "sort": { - "field": ["trending_score", "release_time"], - "order": ["desc", "desc"] - }} - }, - "mappings": { - "properties": { - "claim_id": { - "fields": { - "keyword": { - "ignore_above": 256, - "type": "keyword" - } - }, - "type": "text", - "index_prefixes": { - "min_chars": 1, - "max_chars": 10 - } - }, - "sd_hash": { - "fields": { - "keyword": { - "ignore_above": 96, - "type": "keyword" - } - }, - "type": "text", - "index_prefixes": { - "min_chars": 1, - "max_chars": 4 - } - }, - "height": {"type": "integer"}, - "claim_type": {"type": "byte"}, - "censor_type": {"type": "byte"}, - "trending_score": {"type": "double"}, - "release_time": {"type": "long"} - } - } -} - -FIELDS = { - '_id', - 'claim_id', 'claim_type', 'claim_name', 'normalized_name', - 'tx_id', 'tx_nout', 'tx_position', - 'short_url', 'canonical_url', - 'is_controlling', 'last_take_over_height', - 'public_key_bytes', 'public_key_id', 'claims_in_channel', - 'channel_id', 'signature', 'signature_digest', 'is_signature_valid', - 'amount', 'effective_amount', 'support_amount', - 'fee_amount', 'fee_currency', - 'height', 'creation_height', 'activation_height', 'expiration_height', - 'stream_type', 'media_type', 'censor_type', - 'title', 'author', 'description', - 'timestamp', 'creation_timestamp', - 'duration', 'release_time', - 'tags', 'languages', 'has_source', 'reposted_claim_type', - 'reposted_claim_id', 'repost_count', 'sd_hash', - 'trending_score', 'tx_num' -} - -TEXT_FIELDS = {'author', 'canonical_url', 'channel_id', 'description', 'claim_id', 'censoring_channel_id', - 'media_type', 'normalized_name', 'public_key_bytes', 'public_key_id', 'short_url', 'signature', - 'claim_name', 'signature_digest', 'title', 'tx_id', 'fee_currency', 'reposted_claim_id', - 'tags', 'sd_hash'} - -RANGE_FIELDS = { - 'height', 'creation_height', 'activation_height', 'expiration_height', - 'timestamp', 'creation_timestamp', 'duration', 'release_time', 'fee_amount', - 'tx_position', 'repost_count', 'limit_claims_per_channel', - 'amount', 'effective_amount', 'support_amount', - 'trending_score', 'censor_type', 'tx_num' -} - -ALL_FIELDS = RANGE_FIELDS | TEXT_FIELDS | FIELDS - -REPLACEMENTS = { - 'claim_name': 'normalized_name', - 'name': 'normalized_name', - 'txid': 'tx_id', - 'nout': 'tx_nout', - 'trending_group': 'trending_score', - 'trending_mixed': 'trending_score', - 'trending_global': 'trending_score', - 'trending_local': 'trending_score', - 'reposted': 'repost_count', - 'stream_types': 'stream_type', - 'media_types': 'media_type', - 'valid_channel_signature': 'is_signature_valid' -} diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py deleted file mode 100644 index 3111155a9..000000000 --- a/lbry/wallet/server/db/elasticsearch/search.py +++ /dev/null @@ -1,726 +0,0 @@ -import time -import asyncio -import struct -from binascii import unhexlify -from collections import Counter, deque -from decimal import Decimal -from operator import itemgetter -from typing import Optional, List, Iterable, Union - -from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError -from elasticsearch.helpers import async_streaming_bulk -from lbry.error import ResolveCensoredError, TooManyClaimSearchParametersError -from lbry.schema.result import Outputs, Censor -from lbry.schema.tags import clean_tags -from lbry.schema.url import URL, normalize_name -from lbry.utils import LRUCache -from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES -from lbry.wallet.server.db.elasticsearch.constants import INDEX_DEFAULT_SETTINGS, REPLACEMENTS, FIELDS, TEXT_FIELDS, \ - RANGE_FIELDS, ALL_FIELDS -from lbry.wallet.server.util import class_logger -from lbry.wallet.server.db.common import ResolveResult - - -class ChannelResolution(str): - @classmethod - def lookup_error(cls, url): - return LookupError(f'Could not find channel in "{url}".') - - -class StreamResolution(str): - @classmethod - def lookup_error(cls, url): - return LookupError(f'Could not find claim at "{url}".') - - -class IndexVersionMismatch(Exception): - def __init__(self, got_version, expected_version): - self.got_version = got_version - self.expected_version = expected_version - - -class SearchIndex: - VERSION = 1 - - def __init__(self, index_prefix: str, search_timeout=3.0, elastic_host='localhost', elastic_port=9200): - self.search_timeout = search_timeout - self.sync_timeout = 600 # wont hit that 99% of the time, but can hit on a fresh import - self.search_client: Optional[AsyncElasticsearch] = None - self.sync_client: Optional[AsyncElasticsearch] = None - self.index = index_prefix + 'claims' - self.logger = class_logger(__name__, self.__class__.__name__) - self.claim_cache = LRUCache(2 ** 15) - self.search_cache = LRUCache(2 ** 17) - self._elastic_host = elastic_host - self._elastic_port = elastic_port - - async def get_index_version(self) -> int: - try: - template = await self.sync_client.indices.get_template(self.index) - return template[self.index]['version'] - except NotFoundError: - return 0 - - async def set_index_version(self, version): - await self.sync_client.indices.put_template( - self.index, body={'version': version, 'index_patterns': ['ignored']}, ignore=400 - ) - - async def start(self) -> bool: - if self.sync_client: - return False - hosts = [{'host': self._elastic_host, 'port': self._elastic_port}] - self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout) - self.search_client = AsyncElasticsearch(hosts, timeout=self.search_timeout) - while True: - try: - await self.sync_client.cluster.health(wait_for_status='yellow') - break - except ConnectionError: - self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!") - await asyncio.sleep(1) - - res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400) - acked = res.get('acknowledged', False) - if acked: - await self.set_index_version(self.VERSION) - return acked - index_version = await self.get_index_version() - if index_version != self.VERSION: - self.logger.error("es search index has an incompatible version: %s vs %s", index_version, self.VERSION) - raise IndexVersionMismatch(index_version, self.VERSION) - await self.sync_client.indices.refresh(self.index) - return acked - - def stop(self): - clients = [self.sync_client, self.search_client] - self.sync_client, self.search_client = None, None - return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients))) - - def delete_index(self): - return self.sync_client.indices.delete(self.index, ignore_unavailable=True) - - async def _consume_claim_producer(self, claim_producer): - count = 0 - async for op, doc in claim_producer: - if op == 'delete': - yield { - '_index': self.index, - '_op_type': 'delete', - '_id': doc - } - else: - yield { - 'doc': {key: value for key, value in doc.items() if key in ALL_FIELDS}, - '_id': doc['claim_id'], - '_index': self.index, - '_op_type': 'update', - 'doc_as_upsert': True - } - count += 1 - if count % 100 == 0: - self.logger.info("Indexing in progress, %d claims.", count) - if count: - self.logger.info("Indexing done for %d claims.", count) - else: - self.logger.debug("Indexing done for %d claims.", count) - - async def claim_consumer(self, claim_producer): - touched = set() - async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer), - raise_on_error=False): - if not ok: - self.logger.warning("indexing failed for an item: %s", item) - else: - item = item.popitem()[1] - touched.add(item['_id']) - await self.sync_client.indices.refresh(self.index) - self.logger.debug("Indexing done.") - - def update_filter_query(self, censor_type, blockdict, channels=False): - blockdict = {blocked.hex(): blocker.hex() for blocked, blocker in blockdict.items()} - if channels: - update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}") - else: - update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}") - key = 'channel_id' if channels else 'claim_id' - update['script'] = { - "source": f"ctx._source.censor_type={censor_type}; " - f"ctx._source.censoring_channel_id=params[ctx._source.{key}];", - "lang": "painless", - "params": blockdict - } - return update - - async def update_trending_score(self, params): - update_trending_score_script = """ - double softenLBC(double lbc) { return (Math.pow(lbc, 1.0 / 3.0)); } - - double logsumexp(double x, double y) - { - double top; - if(x > y) - top = x; - else - top = y; - double result = top + Math.log(Math.exp(x-top) + Math.exp(y-top)); - return(result); - } - - double logdiffexp(double big, double small) - { - return big + Math.log(1.0 - Math.exp(small - big)); - } - - double squash(double x) - { - if(x < 0.0) - return -Math.log(1.0 - x); - else - return Math.log(x + 1.0); - } - - double unsquash(double x) - { - if(x < 0.0) - return 1.0 - Math.exp(-x); - else - return Math.exp(x) - 1.0; - } - - double log_to_squash(double x) - { - return logsumexp(x, 0.0); - } - - double squash_to_log(double x) - { - //assert x > 0.0; - return logdiffexp(x, 0.0); - } - - double squashed_add(double x, double y) - { - // squash(unsquash(x) + unsquash(y)) but avoiding overflow. - // Cases where the signs are the same - if (x < 0.0 && y < 0.0) - return -logsumexp(-x, logdiffexp(-y, 0.0)); - if (x >= 0.0 && y >= 0.0) - return logsumexp(x, logdiffexp(y, 0.0)); - // Where the signs differ - if (x >= 0.0 && y < 0.0) - if (Math.abs(x) >= Math.abs(y)) - return logsumexp(0.0, logdiffexp(x, -y)); - else - return -logsumexp(0.0, logdiffexp(-y, x)); - if (x < 0.0 && y >= 0.0) - { - // Addition is commutative, hooray for new math - return squashed_add(y, x); - } - return 0.0; - } - - double squashed_multiply(double x, double y) - { - // squash(unsquash(x)*unsquash(y)) but avoiding overflow. - int sign; - if(x*y >= 0.0) - sign = 1; - else - sign = -1; - return sign*logsumexp(squash_to_log(Math.abs(x)) - + squash_to_log(Math.abs(y)), 0.0); - } - - // Squashed inflated units - double inflateUnits(int height) { - double timescale = 576.0; // Half life of 400 = e-folding time of a day - // by coincidence, so may as well go with it - return log_to_squash(height / timescale); - } - - double spikePower(double newAmount) { - if (newAmount < 50.0) { - return(0.5); - } else if (newAmount < 85.0) { - return(newAmount / 100.0); - } else { - return(0.85); - } - } - - double spikeMass(double oldAmount, double newAmount) { - double softenedChange = softenLBC(Math.abs(newAmount - oldAmount)); - double changeInSoftened = Math.abs(softenLBC(newAmount) - softenLBC(oldAmount)); - double power = spikePower(newAmount); - if (oldAmount > newAmount) { - -1.0 * Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power) - } else { - Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power) - } - } - for (i in params.src.changes) { - double units = inflateUnits(i.height); - if (ctx._source.trending_score == null) { - ctx._source.trending_score = 0.0; - } - double bigSpike = squashed_multiply(units, squash(spikeMass(i.prev_amount, i.new_amount))); - ctx._source.trending_score = squashed_add(ctx._source.trending_score, bigSpike); - } - """ - start = time.perf_counter() - - def producer(): - for claim_id, claim_updates in params.items(): - yield { - '_id': claim_id, - '_index': self.index, - '_op_type': 'update', - 'script': { - 'lang': 'painless', - 'source': update_trending_score_script, - 'params': {'src': { - 'changes': [ - { - 'height': p.height, - 'prev_amount': p.prev_amount / 1E8, - 'new_amount': p.new_amount / 1E8, - } for p in claim_updates - ] - }} - }, - } - if not params: - return - async for ok, item in async_streaming_bulk(self.sync_client, producer(), raise_on_error=False): - if not ok: - self.logger.warning("updating trending failed for an item: %s", item) - await self.sync_client.indices.refresh(self.index) - self.logger.info("updated trending scores in %ims", int((time.perf_counter() - start) * 1000)) - - async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels): - if filtered_streams: - await self.sync_client.update_by_query( - self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4) - await self.sync_client.indices.refresh(self.index) - if filtered_channels: - await self.sync_client.update_by_query( - self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4) - await self.sync_client.indices.refresh(self.index) - await self.sync_client.update_by_query( - self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4) - await self.sync_client.indices.refresh(self.index) - if blocked_streams: - await self.sync_client.update_by_query( - self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4) - await self.sync_client.indices.refresh(self.index) - if blocked_channels: - await self.sync_client.update_by_query( - self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4) - await self.sync_client.indices.refresh(self.index) - await self.sync_client.update_by_query( - self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4) - await self.sync_client.indices.refresh(self.index) - self.clear_caches() - - def clear_caches(self): - self.search_cache.clear() - self.claim_cache.clear() - - async def cached_search(self, kwargs): - total_referenced = [] - cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache) - if cache_item.result is not None: - return cache_item.result - async with cache_item.lock: - if cache_item.result: - return cache_item.result - censor = Censor(Censor.SEARCH) - if kwargs.get('no_totals'): - response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED) - else: - response, offset, total = await self.search(**kwargs) - censor.apply(response) - total_referenced.extend(response) - - if censor.censored: - response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED) - total_referenced.extend(response) - response = [ - ResolveResult( - name=r['claim_name'], - normalized_name=r['normalized_name'], - claim_hash=r['claim_hash'], - tx_num=r['tx_num'], - position=r['tx_nout'], - tx_hash=r['tx_hash'], - height=r['height'], - amount=r['amount'], - short_url=r['short_url'], - is_controlling=r['is_controlling'], - canonical_url=r['canonical_url'], - creation_height=r['creation_height'], - activation_height=r['activation_height'], - expiration_height=r['expiration_height'], - effective_amount=r['effective_amount'], - support_amount=r['support_amount'], - last_takeover_height=r['last_take_over_height'], - claims_in_channel=r['claims_in_channel'], - channel_hash=r['channel_hash'], - reposted_claim_hash=r['reposted_claim_hash'], - reposted=r['reposted'], - signature_valid=r['signature_valid'] - ) for r in response - ] - extra = [ - ResolveResult( - name=r['claim_name'], - normalized_name=r['normalized_name'], - claim_hash=r['claim_hash'], - tx_num=r['tx_num'], - position=r['tx_nout'], - tx_hash=r['tx_hash'], - height=r['height'], - amount=r['amount'], - short_url=r['short_url'], - is_controlling=r['is_controlling'], - canonical_url=r['canonical_url'], - creation_height=r['creation_height'], - activation_height=r['activation_height'], - expiration_height=r['expiration_height'], - effective_amount=r['effective_amount'], - support_amount=r['support_amount'], - last_takeover_height=r['last_take_over_height'], - claims_in_channel=r['claims_in_channel'], - channel_hash=r['channel_hash'], - reposted_claim_hash=r['reposted_claim_hash'], - reposted=r['reposted'], - signature_valid=r['signature_valid'] - ) for r in await self._get_referenced_rows(total_referenced) - ] - result = Outputs.to_base64( - response, extra, offset, total, censor - ) - cache_item.result = result - return result - - async def get_many(self, *claim_ids): - await self.populate_claim_cache(*claim_ids) - return filter(None, map(self.claim_cache.get, claim_ids)) - - async def populate_claim_cache(self, *claim_ids): - missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None] - if missing: - results = await self.search_client.mget( - index=self.index, body={"ids": missing} - ) - for result in expand_result(filter(lambda doc: doc['found'], results["docs"])): - self.claim_cache.set(result['claim_id'], result) - - - async def search(self, **kwargs): - try: - return await self.search_ahead(**kwargs) - except NotFoundError: - return [], 0, 0 - # return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0) - - async def search_ahead(self, **kwargs): - # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return - per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0 - remove_duplicates = kwargs.pop('remove_duplicates', False) - page_size = kwargs.pop('limit', 10) - offset = kwargs.pop('offset', 0) - kwargs['limit'] = 1000 - cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache) - if cache_item.result is not None: - reordered_hits = cache_item.result - else: - async with cache_item.lock: - if cache_item.result: - reordered_hits = cache_item.result - else: - query = expand_query(**kwargs) - search_hits = deque((await self.search_client.search( - query, index=self.index, track_total_hits=False, - _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height'] - ))['hits']['hits']) - if remove_duplicates: - search_hits = self.__remove_duplicates(search_hits) - if per_channel_per_page > 0: - reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page) - else: - reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits] - cache_item.result = reordered_hits - result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)]))) - return result, 0, len(reordered_hits) - - def __remove_duplicates(self, search_hits: deque) -> deque: - known_ids = {} # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original - dropped = set() - for hit in search_hits: - hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id'] - if hit_id not in known_ids: - known_ids[hit_id] = (hit_height, hit['_id']) - else: - previous_height, previous_id = known_ids[hit_id] - if hit_height < previous_height: - known_ids[hit_id] = (hit_height, hit['_id']) - dropped.add(previous_id) - else: - dropped.add(hit['_id']) - return deque(hit for hit in search_hits if hit['_id'] not in dropped) - - def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int): - reordered_hits = [] - channel_counters = Counter() - next_page_hits_maybe_check_later = deque() - while search_hits or next_page_hits_maybe_check_later: - if reordered_hits and len(reordered_hits) % page_size == 0: - channel_counters.clear() - elif not reordered_hits: - pass - else: - break # means last page was incomplete and we are left with bad replacements - for _ in range(len(next_page_hits_maybe_check_later)): - claim_id, channel_id = next_page_hits_maybe_check_later.popleft() - if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page: - reordered_hits.append((claim_id, channel_id)) - channel_counters[channel_id] += 1 - else: - next_page_hits_maybe_check_later.append((claim_id, channel_id)) - while search_hits: - hit = search_hits.popleft() - hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id'] - if hit_channel_id is None or per_channel_per_page <= 0: - reordered_hits.append((hit_id, hit_channel_id)) - elif channel_counters[hit_channel_id] < per_channel_per_page: - reordered_hits.append((hit_id, hit_channel_id)) - channel_counters[hit_channel_id] += 1 - if len(reordered_hits) % page_size == 0: - break - else: - next_page_hits_maybe_check_later.append((hit_id, hit_channel_id)) - return reordered_hits - - async def _get_referenced_rows(self, txo_rows: List[dict]): - txo_rows = [row for row in txo_rows if isinstance(row, dict)] - referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows))) - referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows))) - referenced_ids |= set(filter(None, (row['censoring_channel_id'] for row in txo_rows))) - - referenced_txos = [] - if referenced_ids: - referenced_txos.extend(await self.get_many(*referenced_ids)) - referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos))) - - if referenced_ids: - referenced_txos.extend(await self.get_many(*referenced_ids)) - - return referenced_txos - - -def expand_query(**kwargs): - if "amount_order" in kwargs: - kwargs["limit"] = 1 - kwargs["order_by"] = "effective_amount" - kwargs["offset"] = int(kwargs["amount_order"]) - 1 - if 'name' in kwargs: - kwargs['name'] = normalize_name(kwargs.pop('name')) - if kwargs.get('is_controlling') is False: - kwargs.pop('is_controlling') - query = {'must': [], 'must_not': []} - collapse = None - if 'fee_currency' in kwargs and kwargs['fee_currency'] is not None: - kwargs['fee_currency'] = kwargs['fee_currency'].upper() - for key, value in kwargs.items(): - key = key.replace('claim.', '') - many = key.endswith('__in') or isinstance(value, list) - if many and len(value) > 2048: - raise TooManyClaimSearchParametersError(key, 2048) - if many: - key = key.replace('__in', '') - value = list(filter(None, value)) - if value is None or isinstance(value, list) and len(value) == 0: - continue - key = REPLACEMENTS.get(key, key) - if key in FIELDS: - partial_id = False - if key == 'claim_type': - if isinstance(value, str): - value = CLAIM_TYPES[value] - else: - value = [CLAIM_TYPES[claim_type] for claim_type in value] - elif key == 'stream_type': - value = [STREAM_TYPES[value]] if isinstance(value, str) else list(map(STREAM_TYPES.get, value)) - if key == '_id': - if isinstance(value, Iterable): - value = [item[::-1].hex() for item in value] - else: - value = value[::-1].hex() - if not many and key in ('_id', 'claim_id', 'sd_hash') and len(value) < 20: - partial_id = True - if key in ('signature_valid', 'has_source'): - continue # handled later - if key in TEXT_FIELDS: - key += '.keyword' - ops = {'<=': 'lte', '>=': 'gte', '<': 'lt', '>': 'gt'} - if partial_id: - query['must'].append({"prefix": {key: value}}) - elif key in RANGE_FIELDS and isinstance(value, str) and value[0] in ops: - operator_length = 2 if value[:2] in ops else 1 - operator, value = value[:operator_length], value[operator_length:] - if key == 'fee_amount': - value = str(Decimal(value)*1000) - query['must'].append({"range": {key: {ops[operator]: value}}}) - elif key in RANGE_FIELDS and isinstance(value, list) and all(v[0] in ops for v in value): - range_constraints = [] - for v in value: - operator_length = 2 if v[:2] in ops else 1 - operator, stripped_op_v = v[:operator_length], v[operator_length:] - if key == 'fee_amount': - stripped_op_v = str(Decimal(stripped_op_v)*1000) - range_constraints.append((operator, stripped_op_v)) - query['must'].append({"range": {key: {ops[operator]: v for operator, v in range_constraints}}}) - elif many: - query['must'].append({"terms": {key: value}}) - else: - if key == 'fee_amount': - value = str(Decimal(value)*1000) - query['must'].append({"term": {key: {"value": value}}}) - elif key == 'not_channel_ids': - for channel_id in value: - query['must_not'].append({"term": {'channel_id.keyword': channel_id}}) - query['must_not'].append({"term": {'_id': channel_id}}) - elif key == 'channel_ids': - query['must'].append({"terms": {'channel_id.keyword': value}}) - elif key == 'claim_ids': - query['must'].append({"terms": {'claim_id.keyword': value}}) - elif key == 'media_types': - query['must'].append({"terms": {'media_type.keyword': value}}) - elif key == 'any_languages': - query['must'].append({"terms": {'languages': clean_tags(value)}}) - elif key == 'any_languages': - query['must'].append({"terms": {'languages': value}}) - elif key == 'all_languages': - query['must'].extend([{"term": {'languages': tag}} for tag in value]) - elif key == 'any_tags': - query['must'].append({"terms": {'tags.keyword': clean_tags(value)}}) - elif key == 'all_tags': - query['must'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)]) - elif key == 'not_tags': - query['must_not'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)]) - elif key == 'not_claim_id': - query['must_not'].extend([{"term": {'claim_id.keyword': cid}} for cid in value]) - elif key == 'limit_claims_per_channel': - collapse = ('channel_id.keyword', value) - if kwargs.get('has_channel_signature'): - query['must'].append({"exists": {"field": "signature"}}) - if 'signature_valid' in kwargs: - query['must'].append({"term": {"is_signature_valid": bool(kwargs["signature_valid"])}}) - elif 'signature_valid' in kwargs: - query.setdefault('should', []) - query["minimum_should_match"] = 1 - query['should'].append({"bool": {"must_not": {"exists": {"field": "signature"}}}}) - query['should'].append({"term": {"is_signature_valid": bool(kwargs["signature_valid"])}}) - if 'has_source' in kwargs: - query.setdefault('should', []) - query["minimum_should_match"] = 1 - is_stream_or_repost = {"terms": {"claim_type": [CLAIM_TYPES['stream'], CLAIM_TYPES['repost']]}} - query['should'].append( - {"bool": {"must": [{"match": {"has_source": kwargs['has_source']}}, is_stream_or_repost]}}) - query['should'].append({"bool": {"must_not": [is_stream_or_repost]}}) - query['should'].append({"bool": {"must": [{"term": {"reposted_claim_type": CLAIM_TYPES['channel']}}]}}) - if kwargs.get('text'): - query['must'].append( - {"simple_query_string": - {"query": kwargs["text"], "fields": [ - "claim_name^4", "channel_name^8", "title^1", "description^.5", "author^1", "tags^.5" - ]}}) - query = { - "_source": {"excludes": ["description", "title"]}, - 'query': {'bool': query}, - "sort": [], - } - if "limit" in kwargs: - query["size"] = kwargs["limit"] - if 'offset' in kwargs: - query["from"] = kwargs["offset"] - if 'order_by' in kwargs: - if isinstance(kwargs["order_by"], str): - kwargs["order_by"] = [kwargs["order_by"]] - for value in kwargs['order_by']: - if 'trending_group' in value: - # fixme: trending_mixed is 0 for all records on variable decay, making sort slow. - continue - is_asc = value.startswith('^') - value = value[1:] if is_asc else value - value = REPLACEMENTS.get(value, value) - if value in TEXT_FIELDS: - value += '.keyword' - query['sort'].append({value: "asc" if is_asc else "desc"}) - if collapse: - query["collapse"] = { - "field": collapse[0], - "inner_hits": { - "name": collapse[0], - "size": collapse[1], - "sort": query["sort"] - } - } - return query - - -def expand_result(results): - inner_hits = [] - expanded = [] - for result in results: - if result.get("inner_hits"): - for _, inner_hit in result["inner_hits"].items(): - inner_hits.extend(inner_hit["hits"]["hits"]) - continue - result = result['_source'] - result['claim_hash'] = unhexlify(result['claim_id'])[::-1] - if result['reposted_claim_id']: - result['reposted_claim_hash'] = unhexlify(result['reposted_claim_id'])[::-1] - else: - result['reposted_claim_hash'] = None - result['channel_hash'] = unhexlify(result['channel_id'])[::-1] if result['channel_id'] else None - result['txo_hash'] = unhexlify(result['tx_id'])[::-1] + struct.pack(' str: - return self._result - - @result.setter - def result(self, result: str): - self._result = result - if result is not None: - self.has_result.set() - - @classmethod - def from_cache(cls, cache_key, cache): - cache_item = cache.get(cache_key) - if cache_item is None: - cache_item = cache[cache_key] = ResultCacheItem() - return cache_item diff --git a/lbry/wallet/server/db/elasticsearch/sync.py b/lbry/wallet/server/db/elasticsearch/sync.py deleted file mode 100644 index d34c88d80..000000000 --- a/lbry/wallet/server/db/elasticsearch/sync.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -import argparse -import asyncio -import logging -from elasticsearch import AsyncElasticsearch -from elasticsearch.helpers import async_streaming_bulk -from lbry.wallet.server.env import Env -from lbry.wallet.server.leveldb import LevelDB -from lbry.wallet.server.db.elasticsearch.search import SearchIndex, IndexVersionMismatch -from lbry.wallet.server.db.elasticsearch.constants import ALL_FIELDS - - -async def get_recent_claims(env, index_name='claims', db=None): - log = logging.getLogger() - need_open = db is None - db = db or LevelDB(env) - try: - if need_open: - db.open_db() - if db.es_sync_height == db.db_height or db.db_height <= 0: - return - if need_open: - await db.initialize_caches() - log.info(f"catching up ES ({db.es_sync_height}) to leveldb height: {db.db_height}") - cnt = 0 - touched_claims = set() - deleted_claims = set() - for height in range(db.es_sync_height, db.db_height + 1): - touched_or_deleted = db.prefix_db.touched_or_deleted.get(height) - touched_claims.update(touched_or_deleted.touched_claims) - deleted_claims.update(touched_or_deleted.deleted_claims) - touched_claims.difference_update(deleted_claims) - - for deleted in deleted_claims: - yield { - '_index': index_name, - '_op_type': 'delete', - '_id': deleted.hex() - } - for touched in touched_claims: - claim = db.claim_producer(touched) - if claim: - yield { - 'doc': {key: value for key, value in claim.items() if key in ALL_FIELDS}, - '_id': claim['claim_id'], - '_index': index_name, - '_op_type': 'update', - 'doc_as_upsert': True - } - cnt += 1 - else: - logging.warning("could not sync claim %s", touched.hex()) - if cnt % 10000 == 0: - logging.info("%i claims sent to ES", cnt) - - db.es_sync_height = db.db_height - db.write_db_state() - db.prefix_db.unsafe_commit() - db.assert_db_state() - - logging.info("finished sending %i claims to ES, deleted %i", cnt, len(deleted_claims)) - finally: - if need_open: - db.close() - - -async def get_all_claims(env, index_name='claims', db=None): - need_open = db is None - db = db or LevelDB(env) - if need_open: - db.open_db() - await db.initialize_caches() - logging.info("Fetching claims to send ES from leveldb") - try: - cnt = 0 - async for claim in db.all_claims_producer(): - yield { - 'doc': {key: value for key, value in claim.items() if key in ALL_FIELDS}, - '_id': claim['claim_id'], - '_index': index_name, - '_op_type': 'update', - 'doc_as_upsert': True - } - cnt += 1 - if cnt % 10000 == 0: - logging.info("sent %i claims to ES", cnt) - finally: - if need_open: - db.close() - - -async def make_es_index_and_run_sync(env: Env, clients=32, force=False, db=None, index_name='claims'): - index = SearchIndex(env.es_index_prefix, elastic_host=env.elastic_host, elastic_port=env.elastic_port) - logging.info("ES sync host: %s:%i", env.elastic_host, env.elastic_port) - try: - created = await index.start() - except IndexVersionMismatch as err: - logging.info( - "dropping ES search index (version %s) for upgrade to version %s", err.got_version, err.expected_version - ) - await index.delete_index() - await index.stop() - created = await index.start() - finally: - index.stop() - - es = AsyncElasticsearch([{'host': env.elastic_host, 'port': env.elastic_port}]) - if force or created: - claim_generator = get_all_claims(env, index_name=index_name, db=db) - else: - claim_generator = get_recent_claims(env, index_name=index_name, db=db) - try: - async for ok, item in async_streaming_bulk(es, claim_generator, request_timeout=600, raise_on_error=False): - if not ok: - logging.warning("indexing failed for an item: %s", item) - await es.indices.refresh(index=index_name) - finally: - await es.close() - - -def run_elastic_sync(): - logging.basicConfig(level=logging.INFO) - logging.getLogger('aiohttp').setLevel(logging.WARNING) - logging.getLogger('elasticsearch').setLevel(logging.WARNING) - - logging.info('lbry.server starting') - parser = argparse.ArgumentParser(prog="lbry-hub-elastic-sync") - parser.add_argument("-c", "--clients", type=int, default=32) - parser.add_argument("-f", "--force", default=False, action='store_true') - Env.contribute_to_arg_parser(parser) - args = parser.parse_args() - env = Env.from_arg_parser(args) - - if not os.path.exists(os.path.join(args.db_dir, 'lbry-leveldb')): - logging.info("DB path doesnt exist, nothing to sync to ES") - return - - asyncio.run(make_es_index_and_run_sync(env, clients=args.clients, force=args.force)) diff --git a/lbry/wallet/server/db/prefixes.py b/lbry/wallet/server/db/prefixes.py deleted file mode 100644 index 4dbfe707e..000000000 --- a/lbry/wallet/server/db/prefixes.py +++ /dev/null @@ -1,1669 +0,0 @@ -import typing -import struct -import array -import base64 -from typing import Union, Tuple, NamedTuple, Optional -from lbry.wallet.server.db import DB_PREFIXES -from lbry.wallet.server.db.db import KeyValueStorage, PrefixDB -from lbry.wallet.server.db.revertable import RevertableOpStack, RevertablePut, RevertableDelete -from lbry.schema.url import normalize_name - -ACTIVATED_CLAIM_TXO_TYPE = 1 -ACTIVATED_SUPPORT_TXO_TYPE = 2 - - -def length_encoded_name(name: str) -> bytes: - encoded = name.encode('utf-8') - return len(encoded).to_bytes(2, byteorder='big') + encoded - - -def length_prefix(key: str) -> bytes: - return len(key).to_bytes(1, byteorder='big') + key.encode() - - -ROW_TYPES = {} - - -class PrefixRowType(type): - def __new__(cls, name, bases, kwargs): - klass = super().__new__(cls, name, bases, kwargs) - if name != "PrefixRow": - ROW_TYPES[klass.prefix] = klass - return klass - - -class PrefixRow(metaclass=PrefixRowType): - prefix: bytes - key_struct: struct.Struct - value_struct: struct.Struct - key_part_lambdas = [] - - def __init__(self, db: KeyValueStorage, op_stack: RevertableOpStack): - self._db = db - self._op_stack = op_stack - - def iterate(self, prefix=None, start=None, stop=None, - reverse: bool = False, include_key: bool = True, include_value: bool = True, - fill_cache: bool = True, deserialize_key: bool = True, deserialize_value: bool = True): - if not prefix and not start and not stop: - prefix = () - if prefix is not None: - prefix = self.pack_partial_key(*prefix) - if start is not None: - start = self.pack_partial_key(*start) - if stop is not None: - stop = self.pack_partial_key(*stop) - - if deserialize_key: - key_getter = lambda k: self.unpack_key(k) - else: - key_getter = lambda k: k - if deserialize_value: - value_getter = lambda v: self.unpack_value(v) - else: - value_getter = lambda v: v - - if include_key and include_value: - for k, v in self._db.iterator(prefix=prefix, start=start, stop=stop, reverse=reverse, - fill_cache=fill_cache): - yield key_getter(k), value_getter(v) - elif include_key: - for k in self._db.iterator(prefix=prefix, start=start, stop=stop, reverse=reverse, include_value=False, - fill_cache=fill_cache): - yield key_getter(k) - elif include_value: - for v in self._db.iterator(prefix=prefix, start=start, stop=stop, reverse=reverse, include_key=False, - fill_cache=fill_cache): - yield value_getter(v) - else: - for _ in self._db.iterator(prefix=prefix, start=start, stop=stop, reverse=reverse, include_key=False, - include_value=False, fill_cache=fill_cache): - yield None - - def get(self, *key_args, fill_cache=True, deserialize_value=True): - v = self._db.get(self.pack_key(*key_args), fill_cache=fill_cache) - if v: - return v if not deserialize_value else self.unpack_value(v) - - def get_pending(self, *key_args, fill_cache=True, deserialize_value=True): - packed_key = self.pack_key(*key_args) - last_op = self._op_stack.get_last_op_for_key(packed_key) - if last_op: - if last_op.is_put: - return last_op.value if not deserialize_value else self.unpack_value(last_op.value) - else: # it's a delete - return - v = self._db.get(packed_key, fill_cache=fill_cache) - if v: - return v if not deserialize_value else self.unpack_value(v) - - def stage_put(self, key_args=(), value_args=()): - self._op_stack.append_op(RevertablePut(self.pack_key(*key_args), self.pack_value(*value_args))) - - def stage_delete(self, key_args=(), value_args=()): - self._op_stack.append_op(RevertableDelete(self.pack_key(*key_args), self.pack_value(*value_args))) - - @classmethod - def pack_partial_key(cls, *args) -> bytes: - return cls.prefix + cls.key_part_lambdas[len(args)](*args) - - @classmethod - def pack_key(cls, *args) -> bytes: - return cls.prefix + cls.key_struct.pack(*args) - - @classmethod - def pack_value(cls, *args) -> bytes: - return cls.value_struct.pack(*args) - - @classmethod - def unpack_key(cls, key: bytes): - assert key[:1] == cls.prefix - return cls.key_struct.unpack(key[1:]) - - @classmethod - def unpack_value(cls, data: bytes): - return cls.value_struct.unpack(data) - - @classmethod - def unpack_item(cls, key: bytes, value: bytes): - return cls.unpack_key(key), cls.unpack_value(value) - - -class UTXOKey(NamedTuple): - hashX: bytes - tx_num: int - nout: int - - def __str__(self): - return f"{self.__class__.__name__}(hashX={self.hashX.hex()}, tx_num={self.tx_num}, nout={self.nout})" - - -class UTXOValue(NamedTuple): - amount: int - - -class HashXUTXOKey(NamedTuple): - short_tx_hash: bytes - tx_num: int - nout: int - - def __str__(self): - return f"{self.__class__.__name__}(short_tx_hash={self.short_tx_hash.hex()}, tx_num={self.tx_num}, nout={self.nout})" - - -class HashXUTXOValue(NamedTuple): - hashX: bytes - - def __str__(self): - return f"{self.__class__.__name__}(hashX={self.hashX.hex()})" - - -class HashXHistoryKey(NamedTuple): - hashX: bytes - height: int - - def __str__(self): - return f"{self.__class__.__name__}(hashX={self.hashX.hex()}, height={self.height})" - - -class HashXHistoryValue(NamedTuple): - hashXes: typing.List[int] - - -class BlockHashKey(NamedTuple): - height: int - - -class BlockHashValue(NamedTuple): - block_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(block_hash={self.block_hash.hex()})" - - -class BlockTxsKey(NamedTuple): - height: int - - -class BlockTxsValue(NamedTuple): - tx_hashes: typing.List[bytes] - - -class TxCountKey(NamedTuple): - height: int - - -class TxCountValue(NamedTuple): - tx_count: int - - -class TxHashKey(NamedTuple): - tx_num: int - - -class TxHashValue(NamedTuple): - tx_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(tx_hash={self.tx_hash.hex()})" - - -class TxNumKey(NamedTuple): - tx_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(tx_hash={self.tx_hash.hex()})" - - -class TxNumValue(NamedTuple): - tx_num: int - - -class TxKey(NamedTuple): - tx_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(tx_hash={self.tx_hash.hex()})" - - -class TxValue(NamedTuple): - raw_tx: bytes - - def __str__(self): - return f"{self.__class__.__name__}(raw_tx={base64.b64encode(self.raw_tx)})" - - -class BlockHeaderKey(NamedTuple): - height: int - - -class BlockHeaderValue(NamedTuple): - header: bytes - - def __str__(self): - return f"{self.__class__.__name__}(header={base64.b64encode(self.header)})" - - -class ClaimToTXOKey(typing.NamedTuple): - claim_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()})" - - -class ClaimToTXOValue(typing.NamedTuple): - tx_num: int - position: int - root_tx_num: int - root_position: int - amount: int - # activation: int - channel_signature_is_valid: bool - name: str - - @property - def normalized_name(self) -> str: - try: - return normalize_name(self.name) - except UnicodeDecodeError: - return self.name - - -class TXOToClaimKey(typing.NamedTuple): - tx_num: int - position: int - - -class TXOToClaimValue(typing.NamedTuple): - claim_hash: bytes - name: str - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()}, name={self.name})" - - -class ClaimShortIDKey(typing.NamedTuple): - normalized_name: str - partial_claim_id: str - root_tx_num: int - root_position: int - - def __str__(self): - return f"{self.__class__.__name__}(normalized_name={self.normalized_name}, " \ - f"partial_claim_id={self.partial_claim_id}, " \ - f"root_tx_num={self.root_tx_num}, root_position={self.root_position})" - - -class ClaimShortIDValue(typing.NamedTuple): - tx_num: int - position: int - - -class ClaimToChannelKey(typing.NamedTuple): - claim_hash: bytes - tx_num: int - position: int - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()}, " \ - f"tx_num={self.tx_num}, position={self.position})" - - -class ClaimToChannelValue(typing.NamedTuple): - signing_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(signing_hash={self.signing_hash.hex()})" - - -class ChannelToClaimKey(typing.NamedTuple): - signing_hash: bytes - name: str - tx_num: int - position: int - - def __str__(self): - return f"{self.__class__.__name__}(signing_hash={self.signing_hash.hex()}, name={self.name}, " \ - f"tx_num={self.tx_num}, position={self.position})" - - -class ChannelToClaimValue(typing.NamedTuple): - claim_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()})" - - -class ChannelCountKey(typing.NamedTuple): - channel_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(channel_hash={self.channel_hash.hex()})" - - -class ChannelCountValue(typing.NamedTuple): - count: int - - -class SupportAmountKey(typing.NamedTuple): - claim_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()})" - - -class SupportAmountValue(typing.NamedTuple): - amount: int - - -class ClaimToSupportKey(typing.NamedTuple): - claim_hash: bytes - tx_num: int - position: int - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()}, tx_num={self.tx_num}, " \ - f"position={self.position})" - - -class ClaimToSupportValue(typing.NamedTuple): - amount: int - - -class SupportToClaimKey(typing.NamedTuple): - tx_num: int - position: int - - -class SupportToClaimValue(typing.NamedTuple): - claim_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()})" - - -class ClaimExpirationKey(typing.NamedTuple): - expiration: int - tx_num: int - position: int - - -class ClaimExpirationValue(typing.NamedTuple): - claim_hash: bytes - normalized_name: str - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()}, normalized_name={self.normalized_name})" - - -class ClaimTakeoverKey(typing.NamedTuple): - normalized_name: str - - -class ClaimTakeoverValue(typing.NamedTuple): - claim_hash: bytes - height: int - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()}, height={self.height})" - - -class PendingActivationKey(typing.NamedTuple): - height: int - txo_type: int - tx_num: int - position: int - - @property - def is_support(self) -> bool: - return self.txo_type == ACTIVATED_SUPPORT_TXO_TYPE - - @property - def is_claim(self) -> bool: - return self.txo_type == ACTIVATED_CLAIM_TXO_TYPE - - -class PendingActivationValue(typing.NamedTuple): - claim_hash: bytes - normalized_name: str - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()}, normalized_name={self.normalized_name})" - - -class ActivationKey(typing.NamedTuple): - txo_type: int - tx_num: int - position: int - - -class ActivationValue(typing.NamedTuple): - height: int - claim_hash: bytes - normalized_name: str - - def __str__(self): - return f"{self.__class__.__name__}(height={self.height}, claim_hash={self.claim_hash.hex()}, " \ - f"normalized_name={self.normalized_name})" - - -class ActiveAmountKey(typing.NamedTuple): - claim_hash: bytes - txo_type: int - activation_height: int - tx_num: int - position: int - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()}, txo_type={self.txo_type}, " \ - f"activation_height={self.activation_height}, tx_num={self.tx_num}, position={self.position})" - - -class ActiveAmountValue(typing.NamedTuple): - amount: int - - -class EffectiveAmountKey(typing.NamedTuple): - normalized_name: str - effective_amount: int - tx_num: int - position: int - - -class EffectiveAmountValue(typing.NamedTuple): - claim_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()})" - - -class RepostKey(typing.NamedTuple): - claim_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()})" - - -class RepostValue(typing.NamedTuple): - reposted_claim_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(reposted_claim_hash={self.reposted_claim_hash.hex()})" - - -class RepostedKey(typing.NamedTuple): - reposted_claim_hash: bytes - tx_num: int - position: int - - def __str__(self): - return f"{self.__class__.__name__}(reposted_claim_hash={self.reposted_claim_hash.hex()}, " \ - f"tx_num={self.tx_num}, position={self.position})" - - -class RepostedValue(typing.NamedTuple): - claim_hash: bytes - - def __str__(self): - return f"{self.__class__.__name__}(claim_hash={self.claim_hash.hex()})" - - -class TouchedOrDeletedClaimKey(typing.NamedTuple): - height: int - - -class TouchedOrDeletedClaimValue(typing.NamedTuple): - touched_claims: typing.Set[bytes] - deleted_claims: typing.Set[bytes] - - def __str__(self): - return f"{self.__class__.__name__}(" \ - f"touched_claims={','.join(map(lambda x: x.hex(), self.touched_claims))}," \ - f"deleted_claims={','.join(map(lambda x: x.hex(), self.deleted_claims))})" - - -class DBState(typing.NamedTuple): - genesis: bytes - height: int - tx_count: int - tip: bytes - utxo_flush_count: int - wall_time: int - first_sync: bool - db_version: int - hist_flush_count: int - comp_flush_count: int - comp_cursor: int - es_sync_height: int - - -class ActiveAmountPrefixRow(PrefixRow): - prefix = DB_PREFIXES.active_amount.value - key_struct = struct.Struct(b'>20sBLLH') - value_struct = struct.Struct(b'>Q') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack, - struct.Struct(b'>20sB').pack, - struct.Struct(b'>20sBL').pack, - struct.Struct(b'>20sBLL').pack, - struct.Struct(b'>20sBLLH').pack - ] - - @classmethod - def pack_key(cls, claim_hash: bytes, txo_type: int, activation_height: int, tx_num: int, position: int): - return super().pack_key(claim_hash, txo_type, activation_height, tx_num, position) - - @classmethod - def unpack_key(cls, key: bytes) -> ActiveAmountKey: - return ActiveAmountKey(*super().unpack_key(key)) - - @classmethod - def unpack_value(cls, data: bytes) -> ActiveAmountValue: - return ActiveAmountValue(*super().unpack_value(data)) - - @classmethod - def pack_value(cls, amount: int) -> bytes: - return cls.value_struct.pack(amount) - - @classmethod - def pack_item(cls, claim_hash: bytes, txo_type: int, activation_height: int, tx_num: int, position: int, amount: int): - return cls.pack_key(claim_hash, txo_type, activation_height, tx_num, position), cls.pack_value(amount) - - -class ClaimToTXOPrefixRow(PrefixRow): - prefix = DB_PREFIXES.claim_to_txo.value - key_struct = struct.Struct(b'>20s') - value_struct = struct.Struct(b'>LHLHQB') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack - ] - - @classmethod - def pack_key(cls, claim_hash: bytes): - return super().pack_key(claim_hash) - - @classmethod - def unpack_key(cls, key: bytes) -> ClaimToTXOKey: - assert key[:1] == cls.prefix and len(key) == 21 - return ClaimToTXOKey(key[1:]) - - @classmethod - def unpack_value(cls, data: bytes) -> ClaimToTXOValue: - tx_num, position, root_tx_num, root_position, amount, channel_signature_is_valid = cls.value_struct.unpack( - data[:21] - ) - name_len = int.from_bytes(data[21:23], byteorder='big') - name = data[23:23 + name_len].decode() - return ClaimToTXOValue( - tx_num, position, root_tx_num, root_position, amount, bool(channel_signature_is_valid), name - ) - - @classmethod - def pack_value(cls, tx_num: int, position: int, root_tx_num: int, root_position: int, amount: int, - channel_signature_is_valid: bool, name: str) -> bytes: - return cls.value_struct.pack( - tx_num, position, root_tx_num, root_position, amount, int(channel_signature_is_valid) - ) + length_encoded_name(name) - - @classmethod - def pack_item(cls, claim_hash: bytes, tx_num: int, position: int, root_tx_num: int, root_position: int, - amount: int, channel_signature_is_valid: bool, name: str): - return cls.pack_key(claim_hash), \ - cls.pack_value(tx_num, position, root_tx_num, root_position, amount, channel_signature_is_valid, name) - - -class TXOToClaimPrefixRow(PrefixRow): - prefix = DB_PREFIXES.txo_to_claim.value - key_struct = struct.Struct(b'>LH') - value_struct = struct.Struct(b'>20s') - - @classmethod - def pack_key(cls, tx_num: int, position: int): - return super().pack_key(tx_num, position) - - @classmethod - def unpack_key(cls, key: bytes) -> TXOToClaimKey: - return TXOToClaimKey(*super().unpack_key(key)) - - @classmethod - def unpack_value(cls, data: bytes) -> TXOToClaimValue: - claim_hash, = cls.value_struct.unpack(data[:20]) - name_len = int.from_bytes(data[20:22], byteorder='big') - name = data[22:22 + name_len].decode() - return TXOToClaimValue(claim_hash, name) - - @classmethod - def pack_value(cls, claim_hash: bytes, name: str) -> bytes: - return cls.value_struct.pack(claim_hash) + length_encoded_name(name) - - @classmethod - def pack_item(cls, tx_num: int, position: int, claim_hash: bytes, name: str): - return cls.pack_key(tx_num, position), \ - cls.pack_value(claim_hash, name) - - -def shortid_key_helper(struct_fmt): - packer = struct.Struct(struct_fmt).pack - def wrapper(name, *args): - return length_encoded_name(name) + packer(*args) - return wrapper - - -def shortid_key_partial_claim_helper(name: str, partial_claim_id: str): - assert len(partial_claim_id) < 40 - return length_encoded_name(name) + length_prefix(partial_claim_id) - - -class ClaimShortIDPrefixRow(PrefixRow): - prefix = DB_PREFIXES.claim_short_id_prefix.value - key_struct = struct.Struct(b'>LH') - value_struct = struct.Struct(b'>LH') - key_part_lambdas = [ - lambda: b'', - length_encoded_name, - shortid_key_partial_claim_helper - ] - - @classmethod - def pack_key(cls, name: str, short_claim_id: str, root_tx_num: int, root_position: int): - return cls.prefix + length_encoded_name(name) + length_prefix(short_claim_id) +\ - cls.key_struct.pack(root_tx_num, root_position) - - @classmethod - def pack_value(cls, tx_num: int, position: int): - return super().pack_value(tx_num, position) - - @classmethod - def unpack_key(cls, key: bytes) -> ClaimShortIDKey: - assert key[:1] == cls.prefix - name_len = int.from_bytes(key[1:3], byteorder='big') - name = key[3:3 + name_len].decode() - claim_id_len = int.from_bytes(key[3+name_len:4+name_len], byteorder='big') - partial_claim_id = key[4+name_len:4+name_len+claim_id_len].decode() - return ClaimShortIDKey(name, partial_claim_id, *cls.key_struct.unpack(key[4 + name_len + claim_id_len:])) - - @classmethod - def unpack_value(cls, data: bytes) -> ClaimShortIDValue: - return ClaimShortIDValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, name: str, partial_claim_id: str, root_tx_num: int, root_position: int, - tx_num: int, position: int): - return cls.pack_key(name, partial_claim_id, root_tx_num, root_position), \ - cls.pack_value(tx_num, position) - - -class ClaimToChannelPrefixRow(PrefixRow): - prefix = DB_PREFIXES.claim_to_channel.value - key_struct = struct.Struct(b'>20sLH') - value_struct = struct.Struct(b'>20s') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack, - struct.Struct(b'>20sL').pack, - struct.Struct(b'>20sLH').pack - ] - - @classmethod - def pack_key(cls, claim_hash: bytes, tx_num: int, position: int): - return super().pack_key(claim_hash, tx_num, position) - - @classmethod - def pack_value(cls, signing_hash: bytes): - return super().pack_value(signing_hash) - - @classmethod - def unpack_key(cls, key: bytes) -> ClaimToChannelKey: - return ClaimToChannelKey(*super().unpack_key(key)) - - @classmethod - def unpack_value(cls, data: bytes) -> ClaimToChannelValue: - return ClaimToChannelValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, claim_hash: bytes, tx_num: int, position: int, signing_hash: bytes): - return cls.pack_key(claim_hash, tx_num, position), cls.pack_value(signing_hash) - - -def channel_to_claim_helper(struct_fmt): - packer = struct.Struct(struct_fmt).pack - - def wrapper(signing_hash: bytes, name: str, *args): - return signing_hash + length_encoded_name(name) + packer(*args) - - return wrapper - - -class ChannelToClaimPrefixRow(PrefixRow): - prefix = DB_PREFIXES.channel_to_claim.value - key_struct = struct.Struct(b'>LH') - value_struct = struct.Struct(b'>20s') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack, - channel_to_claim_helper(b''), - channel_to_claim_helper(b'>s'), - channel_to_claim_helper(b'>L'), - channel_to_claim_helper(b'>LH'), - ] - - @classmethod - def pack_key(cls, signing_hash: bytes, name: str, tx_num: int, position: int): - return cls.prefix + signing_hash + length_encoded_name(name) + cls.key_struct.pack( - tx_num, position - ) - - @classmethod - def unpack_key(cls, key: bytes) -> ChannelToClaimKey: - assert key[:1] == cls.prefix - signing_hash = key[1:21] - name_len = int.from_bytes(key[21:23], byteorder='big') - name = key[23:23 + name_len].decode() - tx_num, position = cls.key_struct.unpack(key[23 + name_len:]) - return ChannelToClaimKey( - signing_hash, name, tx_num, position - ) - - @classmethod - def pack_value(cls, claim_hash: bytes) -> bytes: - return super().pack_value(claim_hash) - - @classmethod - def unpack_value(cls, data: bytes) -> ChannelToClaimValue: - return ChannelToClaimValue(*cls.value_struct.unpack(data)) - - @classmethod - def pack_item(cls, signing_hash: bytes, name: str, tx_num: int, position: int, - claim_hash: bytes): - return cls.pack_key(signing_hash, name, tx_num, position), \ - cls.pack_value(claim_hash) - - -class ClaimToSupportPrefixRow(PrefixRow): - prefix = DB_PREFIXES.claim_to_support.value - key_struct = struct.Struct(b'>20sLH') - value_struct = struct.Struct(b'>Q') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack, - struct.Struct(b'>20sL').pack, - struct.Struct(b'>20sLH').pack - ] - - @classmethod - def pack_key(cls, claim_hash: bytes, tx_num: int, position: int): - return super().pack_key(claim_hash, tx_num, position) - - @classmethod - def unpack_key(cls, key: bytes) -> ClaimToSupportKey: - return ClaimToSupportKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, amount: int) -> bytes: - return super().pack_value(amount) - - @classmethod - def unpack_value(cls, data: bytes) -> ClaimToSupportValue: - return ClaimToSupportValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, claim_hash: bytes, tx_num: int, position: int, amount: int): - return cls.pack_key(claim_hash, tx_num, position), \ - cls.pack_value(amount) - - -class SupportToClaimPrefixRow(PrefixRow): - prefix = DB_PREFIXES.support_to_claim.value - key_struct = struct.Struct(b'>LH') - value_struct = struct.Struct(b'>20s') - - @classmethod - def pack_key(cls, tx_num: int, position: int): - return super().pack_key(tx_num, position) - - @classmethod - def unpack_key(cls, key: bytes) -> SupportToClaimKey: - return SupportToClaimKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, claim_hash: bytes) -> bytes: - return super().pack_value(claim_hash) - - @classmethod - def unpack_value(cls, data: bytes) -> SupportToClaimValue: - return SupportToClaimValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, tx_num: int, position: int, claim_hash: bytes): - return cls.pack_key(tx_num, position), \ - cls.pack_value(claim_hash) - - -class ClaimExpirationPrefixRow(PrefixRow): - prefix = DB_PREFIXES.claim_expiration.value - key_struct = struct.Struct(b'>LLH') - value_struct = struct.Struct(b'>20s') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>L').pack, - struct.Struct(b'>LL').pack, - struct.Struct(b'>LLH').pack, - ] - - @classmethod - def pack_key(cls, expiration: int, tx_num: int, position: int) -> bytes: - return super().pack_key(expiration, tx_num, position) - - @classmethod - def pack_value(cls, claim_hash: bytes, name: str) -> bytes: - return cls.value_struct.pack(claim_hash) + length_encoded_name(name) - - @classmethod - def pack_item(cls, expiration: int, tx_num: int, position: int, claim_hash: bytes, name: str) -> typing.Tuple[bytes, bytes]: - return cls.pack_key(expiration, tx_num, position), cls.pack_value(claim_hash, name) - - @classmethod - def unpack_key(cls, key: bytes) -> ClaimExpirationKey: - return ClaimExpirationKey(*super().unpack_key(key)) - - @classmethod - def unpack_value(cls, data: bytes) -> ClaimExpirationValue: - name_len = int.from_bytes(data[20:22], byteorder='big') - name = data[22:22 + name_len].decode() - claim_id, = cls.value_struct.unpack(data[:20]) - return ClaimExpirationValue(claim_id, name) - - @classmethod - def unpack_item(cls, key: bytes, value: bytes) -> typing.Tuple[ClaimExpirationKey, ClaimExpirationValue]: - return cls.unpack_key(key), cls.unpack_value(value) - - -class ClaimTakeoverPrefixRow(PrefixRow): - prefix = DB_PREFIXES.claim_takeover.value - value_struct = struct.Struct(b'>20sL') - - key_part_lambdas = [ - lambda: b'', - length_encoded_name - ] - - @classmethod - def pack_key(cls, name: str): - return cls.prefix + length_encoded_name(name) - - @classmethod - def pack_value(cls, claim_hash: bytes, takeover_height: int): - return super().pack_value(claim_hash, takeover_height) - - @classmethod - def unpack_key(cls, key: bytes) -> ClaimTakeoverKey: - assert key[:1] == cls.prefix - name_len = int.from_bytes(key[1:3], byteorder='big') - name = key[3:3 + name_len].decode() - return ClaimTakeoverKey(name) - - @classmethod - def unpack_value(cls, data: bytes) -> ClaimTakeoverValue: - return ClaimTakeoverValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, name: str, claim_hash: bytes, takeover_height: int): - return cls.pack_key(name), cls.pack_value(claim_hash, takeover_height) - - -class PendingActivationPrefixRow(PrefixRow): - prefix = DB_PREFIXES.pending_activation.value - key_struct = struct.Struct(b'>LBLH') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>L').pack, - struct.Struct(b'>LB').pack, - struct.Struct(b'>LBL').pack, - struct.Struct(b'>LBLH').pack - ] - - @classmethod - def pack_key(cls, height: int, txo_type: int, tx_num: int, position: int): - return super().pack_key(height, txo_type, tx_num, position) - - @classmethod - def unpack_key(cls, key: bytes) -> PendingActivationKey: - return PendingActivationKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, claim_hash: bytes, name: str) -> bytes: - return claim_hash + length_encoded_name(name) - - @classmethod - def unpack_value(cls, data: bytes) -> PendingActivationValue: - claim_hash = data[:20] - name_len = int.from_bytes(data[20:22], byteorder='big') - name = data[22:22 + name_len].decode() - return PendingActivationValue(claim_hash, name) - - @classmethod - def pack_item(cls, height: int, txo_type: int, tx_num: int, position: int, claim_hash: bytes, name: str): - return cls.pack_key(height, txo_type, tx_num, position), \ - cls.pack_value(claim_hash, name) - - -class ActivatedPrefixRow(PrefixRow): - prefix = DB_PREFIXES.activated_claim_and_support.value - key_struct = struct.Struct(b'>BLH') - value_struct = struct.Struct(b'>L20s') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>B').pack, - struct.Struct(b'>BL').pack, - struct.Struct(b'>BLH').pack - ] - - @classmethod - def pack_key(cls, txo_type: int, tx_num: int, position: int): - return super().pack_key(txo_type, tx_num, position) - - @classmethod - def unpack_key(cls, key: bytes) -> ActivationKey: - return ActivationKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, height: int, claim_hash: bytes, name: str) -> bytes: - return cls.value_struct.pack(height, claim_hash) + length_encoded_name(name) - - @classmethod - def unpack_value(cls, data: bytes) -> ActivationValue: - height, claim_hash = cls.value_struct.unpack(data[:24]) - name_len = int.from_bytes(data[24:26], byteorder='big') - name = data[26:26 + name_len].decode() - return ActivationValue(height, claim_hash, name) - - @classmethod - def pack_item(cls, txo_type: int, tx_num: int, position: int, height: int, claim_hash: bytes, name: str): - return cls.pack_key(txo_type, tx_num, position), \ - cls.pack_value(height, claim_hash, name) - - -def effective_amount_helper(struct_fmt): - packer = struct.Struct(struct_fmt).pack - - def wrapper(name, *args): - if not args: - return length_encoded_name(name) - if len(args) == 1: - return length_encoded_name(name) + packer(0xffffffffffffffff - args[0]) - return length_encoded_name(name) + packer(0xffffffffffffffff - args[0], *args[1:]) - - return wrapper - - -class EffectiveAmountPrefixRow(PrefixRow): - prefix = DB_PREFIXES.effective_amount.value - key_struct = struct.Struct(b'>QLH') - value_struct = struct.Struct(b'>20s') - key_part_lambdas = [ - lambda: b'', - length_encoded_name, - shortid_key_helper(b'>Q'), - shortid_key_helper(b'>QL'), - shortid_key_helper(b'>QLH'), - ] - - @classmethod - def pack_key(cls, name: str, effective_amount: int, tx_num: int, position: int): - return cls.prefix + length_encoded_name(name) + cls.key_struct.pack( - 0xffffffffffffffff - effective_amount, tx_num, position - ) - - @classmethod - def unpack_key(cls, key: bytes) -> EffectiveAmountKey: - assert key[:1] == cls.prefix - name_len = int.from_bytes(key[1:3], byteorder='big') - name = key[3:3 + name_len].decode() - ones_comp_effective_amount, tx_num, position = cls.key_struct.unpack(key[3 + name_len:]) - return EffectiveAmountKey(name, 0xffffffffffffffff - ones_comp_effective_amount, tx_num, position) - - @classmethod - def unpack_value(cls, data: bytes) -> EffectiveAmountValue: - return EffectiveAmountValue(*super().unpack_value(data)) - - @classmethod - def pack_value(cls, claim_hash: bytes) -> bytes: - return super().pack_value(claim_hash) - - @classmethod - def pack_item(cls, name: str, effective_amount: int, tx_num: int, position: int, claim_hash: bytes): - return cls.pack_key(name, effective_amount, tx_num, position), cls.pack_value(claim_hash) - - -class RepostPrefixRow(PrefixRow): - prefix = DB_PREFIXES.repost.value - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack - ] - - @classmethod - def pack_key(cls, claim_hash: bytes): - return cls.prefix + claim_hash - - @classmethod - def unpack_key(cls, key: bytes) -> RepostKey: - assert key[:1] == cls.prefix - assert len(key) == 21 - return RepostKey(key[1:]) - - @classmethod - def pack_value(cls, reposted_claim_hash: bytes) -> bytes: - return reposted_claim_hash - - @classmethod - def unpack_value(cls, data: bytes) -> RepostValue: - return RepostValue(data) - - @classmethod - def pack_item(cls, claim_hash: bytes, reposted_claim_hash: bytes): - return cls.pack_key(claim_hash), cls.pack_value(reposted_claim_hash) - - -class RepostedPrefixRow(PrefixRow): - prefix = DB_PREFIXES.reposted_claim.value - key_struct = struct.Struct(b'>20sLH') - value_struct = struct.Struct(b'>20s') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack, - struct.Struct(b'>20sL').pack, - struct.Struct(b'>20sLH').pack - ] - - @classmethod - def pack_key(cls, reposted_claim_hash: bytes, tx_num: int, position: int): - return super().pack_key(reposted_claim_hash, tx_num, position) - - @classmethod - def unpack_key(cls, key: bytes) -> RepostedKey: - return RepostedKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, claim_hash: bytes) -> bytes: - return super().pack_value(claim_hash) - - @classmethod - def unpack_value(cls, data: bytes) -> RepostedValue: - return RepostedValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, reposted_claim_hash: bytes, tx_num: int, position: int, claim_hash: bytes): - return cls.pack_key(reposted_claim_hash, tx_num, position), cls.pack_value(claim_hash) - - -class UndoPrefixRow(PrefixRow): - prefix = DB_PREFIXES.undo.value - key_struct = struct.Struct(b'>Q') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>Q').pack - ] - - @classmethod - def pack_key(cls, height: int): - return super().pack_key(height) - - @classmethod - def unpack_key(cls, key: bytes) -> int: - assert key[:1] == cls.prefix - height, = cls.key_struct.unpack(key[1:]) - return height - - @classmethod - def pack_value(cls, undo_ops: bytes) -> bytes: - return undo_ops - - @classmethod - def unpack_value(cls, data: bytes) -> bytes: - return data - - @classmethod - def pack_item(cls, height: int, undo_ops: bytes): - return cls.pack_key(height), cls.pack_value(undo_ops) - - -class BlockHashPrefixRow(PrefixRow): - prefix = DB_PREFIXES.block_hash.value - key_struct = struct.Struct(b'>L') - value_struct = struct.Struct(b'>32s') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>L').pack - ] - - @classmethod - def pack_key(cls, height: int) -> bytes: - return super().pack_key(height) - - @classmethod - def unpack_key(cls, key: bytes) -> BlockHashKey: - return BlockHashKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, block_hash: bytes) -> bytes: - return super().pack_value(block_hash) - - @classmethod - def unpack_value(cls, data: bytes) -> BlockHashValue: - return BlockHashValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, height: int, block_hash: bytes): - return cls.pack_key(height), cls.pack_value(block_hash) - - -class BlockHeaderPrefixRow(PrefixRow): - prefix = DB_PREFIXES.header.value - key_struct = struct.Struct(b'>L') - value_struct = struct.Struct(b'>112s') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>L').pack - ] - - @classmethod - def pack_key(cls, height: int) -> bytes: - return super().pack_key(height) - - @classmethod - def unpack_key(cls, key: bytes) -> BlockHeaderKey: - return BlockHeaderKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, header: bytes) -> bytes: - return super().pack_value(header) - - @classmethod - def unpack_value(cls, data: bytes) -> BlockHeaderValue: - return BlockHeaderValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, height: int, header: bytes): - return cls.pack_key(height), cls.pack_value(header) - - -class TXNumPrefixRow(PrefixRow): - prefix = DB_PREFIXES.tx_num.value - key_struct = struct.Struct(b'>32s') - value_struct = struct.Struct(b'>L') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>32s').pack - ] - - @classmethod - def pack_key(cls, tx_hash: bytes) -> bytes: - return super().pack_key(tx_hash) - - @classmethod - def unpack_key(cls, tx_hash: bytes) -> TxNumKey: - return TxNumKey(*super().unpack_key(tx_hash)) - - @classmethod - def pack_value(cls, tx_num: int) -> bytes: - return super().pack_value(tx_num) - - @classmethod - def unpack_value(cls, data: bytes) -> TxNumValue: - return TxNumValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, tx_hash: bytes, tx_num: int): - return cls.pack_key(tx_hash), cls.pack_value(tx_num) - - -class TxCountPrefixRow(PrefixRow): - prefix = DB_PREFIXES.tx_count.value - key_struct = struct.Struct(b'>L') - value_struct = struct.Struct(b'>L') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>L').pack - ] - - @classmethod - def pack_key(cls, height: int) -> bytes: - return super().pack_key(height) - - @classmethod - def unpack_key(cls, key: bytes) -> TxCountKey: - return TxCountKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, tx_count: int) -> bytes: - return super().pack_value(tx_count) - - @classmethod - def unpack_value(cls, data: bytes) -> TxCountValue: - return TxCountValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, height: int, tx_count: int): - return cls.pack_key(height), cls.pack_value(tx_count) - - -class TXHashPrefixRow(PrefixRow): - prefix = DB_PREFIXES.tx_hash.value - key_struct = struct.Struct(b'>L') - value_struct = struct.Struct(b'>32s') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>L').pack - ] - - @classmethod - def pack_key(cls, tx_num: int) -> bytes: - return super().pack_key(tx_num) - - @classmethod - def unpack_key(cls, key: bytes) -> TxHashKey: - return TxHashKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, tx_hash: bytes) -> bytes: - return super().pack_value(tx_hash) - - @classmethod - def unpack_value(cls, data: bytes) -> TxHashValue: - return TxHashValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, tx_num: int, tx_hash: bytes): - return cls.pack_key(tx_num), cls.pack_value(tx_hash) - - -class TXPrefixRow(PrefixRow): - prefix = DB_PREFIXES.tx.value - key_struct = struct.Struct(b'>32s') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>32s').pack - ] - - @classmethod - def pack_key(cls, tx_hash: bytes) -> bytes: - return super().pack_key(tx_hash) - - @classmethod - def unpack_key(cls, tx_hash: bytes) -> TxKey: - return TxKey(*super().unpack_key(tx_hash)) - - @classmethod - def pack_value(cls, tx: bytes) -> bytes: - return tx - - @classmethod - def unpack_value(cls, data: bytes) -> TxValue: - return TxValue(data) - - @classmethod - def pack_item(cls, tx_hash: bytes, raw_tx: bytes): - return cls.pack_key(tx_hash), cls.pack_value(raw_tx) - - -class UTXOPrefixRow(PrefixRow): - prefix = DB_PREFIXES.utxo.value - key_struct = struct.Struct(b'>11sLH') - value_struct = struct.Struct(b'>Q') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>11s').pack, - struct.Struct(b'>11sL').pack, - struct.Struct(b'>11sLH').pack - ] - - @classmethod - def pack_key(cls, hashX: bytes, tx_num, nout: int): - return super().pack_key(hashX, tx_num, nout) - - @classmethod - def unpack_key(cls, key: bytes) -> UTXOKey: - return UTXOKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, amount: int) -> bytes: - return super().pack_value(amount) - - @classmethod - def unpack_value(cls, data: bytes) -> UTXOValue: - return UTXOValue(*cls.value_struct.unpack(data)) - - @classmethod - def pack_item(cls, hashX: bytes, tx_num: int, nout: int, amount: int): - return cls.pack_key(hashX, tx_num, nout), cls.pack_value(amount) - - -class HashXUTXOPrefixRow(PrefixRow): - prefix = DB_PREFIXES.hashx_utxo.value - key_struct = struct.Struct(b'>4sLH') - value_struct = struct.Struct(b'>11s') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>4s').pack, - struct.Struct(b'>4sL').pack, - struct.Struct(b'>4sLH').pack - ] - - @classmethod - def pack_key(cls, short_tx_hash: bytes, tx_num, nout: int): - return super().pack_key(short_tx_hash, tx_num, nout) - - @classmethod - def unpack_key(cls, key: bytes) -> HashXUTXOKey: - return HashXUTXOKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, hashX: bytes) -> bytes: - return super().pack_value(hashX) - - @classmethod - def unpack_value(cls, data: bytes) -> HashXUTXOValue: - return HashXUTXOValue(*cls.value_struct.unpack(data)) - - @classmethod - def pack_item(cls, short_tx_hash: bytes, tx_num: int, nout: int, hashX: bytes): - return cls.pack_key(short_tx_hash, tx_num, nout), cls.pack_value(hashX) - - -class HashXHistoryPrefixRow(PrefixRow): - prefix = DB_PREFIXES.hashx_history.value - key_struct = struct.Struct(b'>11sL') - - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>11s').pack, - struct.Struct(b'>11sL').pack - ] - - @classmethod - def pack_key(cls, hashX: bytes, height: int): - return super().pack_key(hashX, height) - - @classmethod - def unpack_key(cls, key: bytes) -> HashXHistoryKey: - return HashXHistoryKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, history: typing.List[int]) -> bytes: - a = array.array('I') - a.fromlist(history) - return a.tobytes() - - @classmethod - def unpack_value(cls, data: bytes) -> array.array: - a = array.array('I') - a.frombytes(data) - return a - - @classmethod - def pack_item(cls, hashX: bytes, height: int, history: typing.List[int]): - return cls.pack_key(hashX, height), cls.pack_value(history) - - -class TouchedOrDeletedPrefixRow(PrefixRow): - prefix = DB_PREFIXES.claim_diff.value - key_struct = struct.Struct(b'>L') - value_struct = struct.Struct(b'>LL') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>L').pack - ] - - @classmethod - def pack_key(cls, height: int): - return super().pack_key(height) - - @classmethod - def unpack_key(cls, key: bytes) -> TouchedOrDeletedClaimKey: - return TouchedOrDeletedClaimKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, touched: typing.Set[bytes], deleted: typing.Set[bytes]) -> bytes: - assert True if not touched else all(len(item) == 20 for item in touched) - assert True if not deleted else all(len(item) == 20 for item in deleted) - return cls.value_struct.pack(len(touched), len(deleted)) + b''.join(sorted(touched)) + b''.join(sorted(deleted)) - - @classmethod - def unpack_value(cls, data: bytes) -> TouchedOrDeletedClaimValue: - touched_len, deleted_len = cls.value_struct.unpack(data[:8]) - data = data[8:] - assert len(data) == 20 * (touched_len + deleted_len) - touched_bytes, deleted_bytes = data[:touched_len*20], data[touched_len*20:] - return TouchedOrDeletedClaimValue( - {touched_bytes[20*i:20*(i+1)] for i in range(touched_len)}, - {deleted_bytes[20*i:20*(i+1)] for i in range(deleted_len)} - ) - - @classmethod - def pack_item(cls, height, touched, deleted): - return cls.pack_key(height), cls.pack_value(touched, deleted) - - -class ChannelCountPrefixRow(PrefixRow): - prefix = DB_PREFIXES.channel_count.value - key_struct = struct.Struct(b'>20s') - value_struct = struct.Struct(b'>L') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack - ] - - @classmethod - def pack_key(cls, channel_hash: bytes): - return super().pack_key(channel_hash) - - @classmethod - def unpack_key(cls, key: bytes) -> ChannelCountKey: - return ChannelCountKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, count: int) -> bytes: - return super().pack_value(count) - - @classmethod - def unpack_value(cls, data: bytes) -> ChannelCountValue: - return ChannelCountValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, channel_hash, count): - return cls.pack_key(channel_hash), cls.pack_value(count) - - -class SupportAmountPrefixRow(PrefixRow): - prefix = DB_PREFIXES.support_amount.value - key_struct = struct.Struct(b'>20s') - value_struct = struct.Struct(b'>Q') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>20s').pack - ] - - @classmethod - def pack_key(cls, claim_hash: bytes): - return super().pack_key(claim_hash) - - @classmethod - def unpack_key(cls, key: bytes) -> SupportAmountKey: - return SupportAmountKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, amount: int) -> bytes: - return super().pack_value(amount) - - @classmethod - def unpack_value(cls, data: bytes) -> SupportAmountValue: - return SupportAmountValue(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, claim_hash, amount): - return cls.pack_key(claim_hash), cls.pack_value(amount) - - -class DBStatePrefixRow(PrefixRow): - prefix = DB_PREFIXES.db_state.value - value_struct = struct.Struct(b'>32sLL32sLLBBlllL') - key_struct = struct.Struct(b'') - - key_part_lambdas = [ - lambda: b'' - ] - - @classmethod - def pack_key(cls) -> bytes: - return cls.prefix - - @classmethod - def unpack_key(cls, key: bytes): - return - - @classmethod - def pack_value(cls, genesis: bytes, height: int, tx_count: int, tip: bytes, utxo_flush_count: int, wall_time: int, - first_sync: bool, db_version: int, hist_flush_count: int, comp_flush_count: int, - comp_cursor: int, es_sync_height: int) -> bytes: - return super().pack_value( - genesis, height, tx_count, tip, utxo_flush_count, - wall_time, 1 if first_sync else 0, db_version, hist_flush_count, - comp_flush_count, comp_cursor, es_sync_height - ) - - @classmethod - def unpack_value(cls, data: bytes) -> DBState: - if len(data) == 94: - # TODO: delete this after making a new snapshot - 10/20/21 - # migrate in the es_sync_height if it doesnt exist - data += data[32:36] - return DBState(*super().unpack_value(data)) - - @classmethod - def pack_item(cls, genesis: bytes, height: int, tx_count: int, tip: bytes, utxo_flush_count: int, wall_time: int, - first_sync: bool, db_version: int, hist_flush_count: int, comp_flush_count: int, - comp_cursor: int, es_sync_height: int): - return cls.pack_key(), cls.pack_value( - genesis, height, tx_count, tip, utxo_flush_count, wall_time, first_sync, db_version, hist_flush_count, - comp_flush_count, comp_cursor, es_sync_height - ) - - -class BlockTxsPrefixRow(PrefixRow): - prefix = DB_PREFIXES.block_txs.value - key_struct = struct.Struct(b'>L') - key_part_lambdas = [ - lambda: b'', - struct.Struct(b'>L').pack - ] - - @classmethod - def pack_key(cls, height: int): - return super().pack_key(height) - - @classmethod - def unpack_key(cls, key: bytes) -> BlockTxsKey: - return BlockTxsKey(*super().unpack_key(key)) - - @classmethod - def pack_value(cls, tx_hashes: typing.List[bytes]) -> bytes: - assert all(len(tx_hash) == 32 for tx_hash in tx_hashes) - return b''.join(tx_hashes) - - @classmethod - def unpack_value(cls, data: bytes) -> BlockTxsValue: - return BlockTxsValue([data[i*32:(i+1)*32] for i in range(len(data) // 32)]) - - @classmethod - def pack_item(cls, height, tx_hashes): - return cls.pack_key(height), cls.pack_value(tx_hashes) - - -class LevelDBStore(KeyValueStorage): - def __init__(self, path: str, cache_mb: int, max_open_files: int): - import plyvel - self.db = plyvel.DB( - path, create_if_missing=True, max_open_files=max_open_files, - lru_cache_size=cache_mb * 1024 * 1024, write_buffer_size=64 * 1024 * 1024, - max_file_size=1024 * 1024 * 64, bloom_filter_bits=32 - ) - - def get(self, key: bytes, fill_cache: bool = True) -> Optional[bytes]: - return self.db.get(key, fill_cache=fill_cache) - - def iterator(self, reverse=False, start=None, stop=None, include_start=True, include_stop=False, prefix=None, - include_key=True, include_value=True, fill_cache=True): - return self.db.iterator( - reverse=reverse, start=start, stop=stop, include_start=include_start, include_stop=include_stop, - prefix=prefix, include_key=include_key, include_value=include_value, fill_cache=fill_cache - ) - - def write_batch(self, transaction: bool = False, sync: bool = False): - return self.db.write_batch(transaction=transaction, sync=sync) - - def close(self): - return self.db.close() - - @property - def closed(self) -> bool: - return self.db.closed - - -class HubDB(PrefixDB): - def __init__(self, path: str, cache_mb: int = 128, reorg_limit: int = 200, max_open_files: int = 512, - unsafe_prefixes: Optional[typing.Set[bytes]] = None): - db = LevelDBStore(path, cache_mb, max_open_files) - super().__init__(db, reorg_limit, unsafe_prefixes=unsafe_prefixes) - self.claim_to_support = ClaimToSupportPrefixRow(db, self._op_stack) - self.support_to_claim = SupportToClaimPrefixRow(db, self._op_stack) - self.claim_to_txo = ClaimToTXOPrefixRow(db, self._op_stack) - self.txo_to_claim = TXOToClaimPrefixRow(db, self._op_stack) - self.claim_to_channel = ClaimToChannelPrefixRow(db, self._op_stack) - self.channel_to_claim = ChannelToClaimPrefixRow(db, self._op_stack) - self.claim_short_id = ClaimShortIDPrefixRow(db, self._op_stack) - self.claim_expiration = ClaimExpirationPrefixRow(db, self._op_stack) - self.claim_takeover = ClaimTakeoverPrefixRow(db, self._op_stack) - self.pending_activation = PendingActivationPrefixRow(db, self._op_stack) - self.activated = ActivatedPrefixRow(db, self._op_stack) - self.active_amount = ActiveAmountPrefixRow(db, self._op_stack) - self.effective_amount = EffectiveAmountPrefixRow(db, self._op_stack) - self.repost = RepostPrefixRow(db, self._op_stack) - self.reposted_claim = RepostedPrefixRow(db, self._op_stack) - self.undo = UndoPrefixRow(db, self._op_stack) - self.utxo = UTXOPrefixRow(db, self._op_stack) - self.hashX_utxo = HashXUTXOPrefixRow(db, self._op_stack) - self.hashX_history = HashXHistoryPrefixRow(db, self._op_stack) - self.block_hash = BlockHashPrefixRow(db, self._op_stack) - self.tx_count = TxCountPrefixRow(db, self._op_stack) - self.tx_hash = TXHashPrefixRow(db, self._op_stack) - self.tx_num = TXNumPrefixRow(db, self._op_stack) - self.tx = TXPrefixRow(db, self._op_stack) - self.header = BlockHeaderPrefixRow(db, self._op_stack) - self.touched_or_deleted = TouchedOrDeletedPrefixRow(db, self._op_stack) - self.channel_count = ChannelCountPrefixRow(db, self._op_stack) - self.db_state = DBStatePrefixRow(db, self._op_stack) - self.support_amount = SupportAmountPrefixRow(db, self._op_stack) - self.block_txs = BlockTxsPrefixRow(db, self._op_stack) - - -def auto_decode_item(key: bytes, value: bytes) -> Union[Tuple[NamedTuple, NamedTuple], Tuple[bytes, bytes]]: - try: - return ROW_TYPES[key[:1]].unpack_item(key, value) - except KeyError: - return key, value diff --git a/lbry/wallet/server/db/revertable.py b/lbry/wallet/server/db/revertable.py deleted file mode 100644 index e59bbcdf3..000000000 --- a/lbry/wallet/server/db/revertable.py +++ /dev/null @@ -1,175 +0,0 @@ -import struct -import logging -from string import printable -from collections import defaultdict -from typing import Tuple, Iterable, Callable, Optional -from lbry.wallet.server.db import DB_PREFIXES - -_OP_STRUCT = struct.Struct('>BLL') -log = logging.getLogger() - - -class RevertableOp: - __slots__ = [ - 'key', - 'value', - ] - is_put = 0 - - def __init__(self, key: bytes, value: bytes): - self.key = key - self.value = value - - @property - def is_delete(self) -> bool: - return not self.is_put - - def invert(self) -> 'RevertableOp': - raise NotImplementedError() - - def pack(self) -> bytes: - """ - Serialize to bytes - """ - return struct.pack( - f'>BLL{len(self.key)}s{len(self.value)}s', int(self.is_put), len(self.key), len(self.value), self.key, - self.value - ) - - @classmethod - def unpack(cls, packed: bytes) -> Tuple['RevertableOp', bytes]: - """ - Deserialize from bytes - - :param packed: bytes containing at least one packed revertable op - :return: tuple of the deserialized op (a put or a delete) and the remaining serialized bytes - """ - is_put, key_len, val_len = _OP_STRUCT.unpack(packed[:9]) - key = packed[9:9 + key_len] - value = packed[9 + key_len:9 + key_len + val_len] - if is_put == 1: - return RevertablePut(key, value), packed[9 + key_len + val_len:] - return RevertableDelete(key, value), packed[9 + key_len + val_len:] - - def __eq__(self, other: 'RevertableOp') -> bool: - return (self.is_put, self.key, self.value) == (other.is_put, other.key, other.value) - - def __repr__(self) -> str: - return str(self) - - def __str__(self) -> str: - from lbry.wallet.server.db.prefixes import auto_decode_item - k, v = auto_decode_item(self.key, self.value) - key = ''.join(c if c in printable else '.' for c in str(k)) - val = ''.join(c if c in printable else '.' for c in str(v)) - return f"{'PUT' if self.is_put else 'DELETE'} {DB_PREFIXES(self.key[:1]).name}: {key} | {val}" - - -class RevertableDelete(RevertableOp): - def invert(self): - return RevertablePut(self.key, self.value) - - -class RevertablePut(RevertableOp): - is_put = True - - def invert(self): - return RevertableDelete(self.key, self.value) - - -class OpStackIntegrity(Exception): - pass - - -class RevertableOpStack: - def __init__(self, get_fn: Callable[[bytes], Optional[bytes]], unsafe_prefixes=None): - """ - This represents a sequence of revertable puts and deletes to a key-value database that checks for integrity - violations when applying the puts and deletes. The integrity checks assure that keys that do not exist - are not deleted, and that when keys are deleted the current value is correctly known so that the delete - may be undone. When putting values, the integrity checks assure that existing values are not overwritten - without first being deleted. Updates are performed by applying a delete op for the old value and a put op - for the new value. - - :param get_fn: getter function from an object implementing `KeyValueStorage` - :param unsafe_prefixes: optional set of prefixes to ignore integrity errors for, violations are still logged - """ - self._get = get_fn - self._items = defaultdict(list) - self._unsafe_prefixes = unsafe_prefixes or set() - - def append_op(self, op: RevertableOp): - """ - Apply a put or delete op, checking that it introduces no integrity errors - """ - - inverted = op.invert() - if self._items[op.key] and inverted == self._items[op.key][-1]: - self._items[op.key].pop() # if the new op is the inverse of the last op, we can safely null both - return - elif self._items[op.key] and self._items[op.key][-1] == op: # duplicate of last op - return # raise an error? - stored_val = self._get(op.key) - has_stored_val = stored_val is not None - delete_stored_op = None if not has_stored_val else RevertableDelete(op.key, stored_val) - will_delete_existing_stored = False if delete_stored_op is None else (delete_stored_op in self._items[op.key]) - try: - if op.is_put and has_stored_val and not will_delete_existing_stored: - raise OpStackIntegrity( - f"db op tries to add on top of existing key without deleting first: {op}" - ) - elif op.is_delete and has_stored_val and stored_val != op.value and not will_delete_existing_stored: - # there is a value and we're not deleting it in this op - # check that a delete for the stored value is in the stack - raise OpStackIntegrity(f"db op tries to delete with incorrect existing value {op}") - elif op.is_delete and not has_stored_val: - raise OpStackIntegrity(f"db op tries to delete nonexistent key: {op}") - elif op.is_delete and stored_val != op.value: - raise OpStackIntegrity(f"db op tries to delete with incorrect value: {op}") - except OpStackIntegrity as err: - if op.key[:1] in self._unsafe_prefixes: - log.debug(f"skipping over integrity error: {err}") - else: - raise err - self._items[op.key].append(op) - - def extend_ops(self, ops: Iterable[RevertableOp]): - """ - Apply a sequence of put or delete ops, checking that they introduce no integrity errors - """ - for op in ops: - self.append_op(op) - - def clear(self): - self._items.clear() - - def __len__(self): - return sum(map(len, self._items.values())) - - def __iter__(self): - for key, ops in self._items.items(): - for op in ops: - yield op - - def __reversed__(self): - for key, ops in self._items.items(): - for op in reversed(ops): - yield op - - def get_undo_ops(self) -> bytes: - """ - Get the serialized bytes to undo all of the changes made by the pending ops - """ - return b''.join(op.invert().pack() for op in reversed(self)) - - def apply_packed_undo_ops(self, packed: bytes): - """ - Unpack and apply a sequence of undo ops from serialized undo bytes - """ - while packed: - op, packed = RevertableOp.unpack(packed) - self.append_op(op) - - def get_last_op_for_key(self, key: bytes) -> Optional[RevertableOp]: - if key in self._items and self._items[key]: - return self._items[key][-1] diff --git a/lbry/wallet/server/env.py b/lbry/wallet/server/env.py deleted file mode 100644 index a109abf76..000000000 --- a/lbry/wallet/server/env.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright (c) 2016, Neil Booth -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - -import math -import re -import resource -from os import environ -from collections import namedtuple -from ipaddress import ip_address - -from lbry.wallet.server.util import class_logger -from lbry.wallet.server.coin import Coin, LBC, LBCTestNet, LBCRegTest -import lbry.wallet.server.util as lib_util - - -NetIdentity = namedtuple('NetIdentity', 'host tcp_port ssl_port nick_suffix') - - -class Env: - - # Peer discovery - PD_OFF, PD_SELF, PD_ON = range(3) - - class Error(Exception): - pass - - def __init__(self, coin=None, db_dir=None, daemon_url=None, host=None, rpc_host=None, elastic_host=None, - elastic_port=None, loop_policy=None, max_query_workers=None, websocket_host=None, websocket_port=None, - chain=None, es_index_prefix=None, es_mode=None, cache_MB=None, reorg_limit=None, tcp_port=None, - udp_port=None, ssl_port=None, ssl_certfile=None, ssl_keyfile=None, rpc_port=None, - prometheus_port=None, max_subscriptions=None, banner_file=None, anon_logs=None, log_sessions=None, - allow_lan_udp=None, cache_all_tx_hashes=None, cache_all_claim_txos=None, country=None, - payment_address=None, donation_address=None, max_send=None, max_receive=None, max_sessions=None, - session_timeout=None, drop_client=None, description=None, daily_fee=None, - database_query_timeout=None, db_max_open_files=512): - self.logger = class_logger(__name__, self.__class__.__name__) - - self.db_dir = db_dir if db_dir is not None else self.required('DB_DIRECTORY') - self.daemon_url = daemon_url if daemon_url is not None else self.required('DAEMON_URL') - self.db_max_open_files = db_max_open_files - - self.host = host if host is not None else self.default('HOST', 'localhost') - self.rpc_host = rpc_host if rpc_host is not None else self.default('RPC_HOST', 'localhost') - self.elastic_host = elastic_host if elastic_host is not None else self.default('ELASTIC_HOST', 'localhost') - self.elastic_port = elastic_port if elastic_port is not None else self.integer('ELASTIC_PORT', 9200) - self.loop_policy = self.set_event_loop_policy( - loop_policy if loop_policy is not None else self.default('EVENT_LOOP_POLICY', None) - ) - self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK']) - self.max_query_workers = max_query_workers if max_query_workers is not None else self.integer('MAX_QUERY_WORKERS', 4) - self.websocket_host = websocket_host if websocket_host is not None else self.default('WEBSOCKET_HOST', self.host) - self.websocket_port = websocket_port if websocket_port is not None else self.integer('WEBSOCKET_PORT', None) - if coin is not None: - assert issubclass(coin, Coin) - self.coin = coin - else: - chain = chain if chain is not None else self.default('NET', 'mainnet').strip().lower() - if chain == 'mainnet': - self.coin = LBC - elif chain == 'testnet': - self.coin = LBCTestNet - else: - self.coin = LBCRegTest - self.es_index_prefix = es_index_prefix if es_index_prefix is not None else self.default('ES_INDEX_PREFIX', '') - self.es_mode = es_mode if es_mode is not None else self.default('ES_MODE', 'writer') - self.cache_MB = cache_MB if cache_MB is not None else self.integer('CACHE_MB', 1024) - self.reorg_limit = reorg_limit if reorg_limit is not None else self.integer('REORG_LIMIT', self.coin.REORG_LIMIT) - # Server stuff - self.tcp_port = tcp_port if tcp_port is not None else self.integer('TCP_PORT', None) - self.udp_port = udp_port if udp_port is not None else self.integer('UDP_PORT', self.tcp_port) - self.ssl_port = ssl_port if ssl_port is not None else self.integer('SSL_PORT', None) - if self.ssl_port: - self.ssl_certfile = ssl_certfile if ssl_certfile is not None else self.required('SSL_CERTFILE') - self.ssl_keyfile = ssl_keyfile if ssl_keyfile is not None else self.required('SSL_KEYFILE') - self.rpc_port = rpc_port if rpc_port is not None else self.integer('RPC_PORT', 8000) - self.prometheus_port = prometheus_port if prometheus_port is not None else self.integer('PROMETHEUS_PORT', 0) - self.max_subscriptions = max_subscriptions if max_subscriptions is not None else self.integer('MAX_SUBSCRIPTIONS', 10000) - self.banner_file = banner_file if banner_file is not None else self.default('BANNER_FILE', None) - # self.tor_banner_file = self.default('TOR_BANNER_FILE', self.banner_file) - self.anon_logs = anon_logs if anon_logs is not None else self.boolean('ANON_LOGS', False) - self.log_sessions = log_sessions if log_sessions is not None else self.integer('LOG_SESSIONS', 3600) - self.allow_lan_udp = allow_lan_udp if allow_lan_udp is not None else self.boolean('ALLOW_LAN_UDP', False) - self.cache_all_tx_hashes = cache_all_tx_hashes if cache_all_tx_hashes is not None else self.boolean('CACHE_ALL_TX_HASHES', False) - self.cache_all_claim_txos = cache_all_claim_txos if cache_all_claim_txos is not None else self.boolean('CACHE_ALL_CLAIM_TXOS', False) - self.country = country if country is not None else self.default('COUNTRY', 'US') - # Peer discovery - self.peer_discovery = self.peer_discovery_enum() - self.peer_announce = self.boolean('PEER_ANNOUNCE', True) - self.peer_hubs = self.extract_peer_hubs() - # self.tor_proxy_host = self.default('TOR_PROXY_HOST', 'localhost') - # self.tor_proxy_port = self.integer('TOR_PROXY_PORT', None) - # The electrum client takes the empty string as unspecified - self.payment_address = payment_address if payment_address is not None else self.default('PAYMENT_ADDRESS', '') - self.donation_address = donation_address if donation_address is not None else self.default('DONATION_ADDRESS', '') - # Server limits to help prevent DoS - self.max_send = max_send if max_send is not None else self.integer('MAX_SEND', 1000000) - self.max_receive = max_receive if max_receive is not None else self.integer('MAX_RECEIVE', 1000000) - # self.max_subs = self.integer('MAX_SUBS', 250000) - self.max_sessions = max_sessions if max_sessions is not None else self.sane_max_sessions() - # self.max_session_subs = self.integer('MAX_SESSION_SUBS', 50000) - self.session_timeout = session_timeout if session_timeout is not None else self.integer('SESSION_TIMEOUT', 600) - self.drop_client = drop_client if drop_client is not None else self.custom("DROP_CLIENT", None, re.compile) - self.description = description if description is not None else self.default('DESCRIPTION', '') - self.daily_fee = daily_fee if daily_fee is not None else self.string_amount('DAILY_FEE', '0') - - # Identities - clearnet_identity = self.clearnet_identity() - tor_identity = self.tor_identity(clearnet_identity) - self.identities = [identity - for identity in (clearnet_identity, tor_identity) - if identity is not None] - self.database_query_timeout = database_query_timeout if database_query_timeout is not None else \ - (float(self.integer('QUERY_TIMEOUT_MS', 10000)) / 1000.0) - - @classmethod - def default(cls, envvar, default): - return environ.get(envvar, default) - - @classmethod - def boolean(cls, envvar, default): - default = 'Yes' if default else '' - return bool(cls.default(envvar, default).strip()) - - @classmethod - def required(cls, envvar): - value = environ.get(envvar) - if value is None: - raise cls.Error(f'required envvar {envvar} not set') - return value - - @classmethod - def string_amount(cls, envvar, default): - value = environ.get(envvar, default) - amount_pattern = re.compile("[0-9]{0,10}(\.[0-9]{1,8})?") - if len(value) > 0 and not amount_pattern.fullmatch(value): - raise cls.Error(f'{value} is not a valid amount for {envvar}') - return value - - @classmethod - def integer(cls, envvar, default): - value = environ.get(envvar) - if value is None: - return default - try: - return int(value) - except Exception: - raise cls.Error(f'cannot convert envvar {envvar} value {value} to an integer') - - @classmethod - def custom(cls, envvar, default, parse): - value = environ.get(envvar) - if value is None: - return default - try: - return parse(value) - except Exception as e: - raise cls.Error(f'cannot parse envvar {envvar} value {value}') from e - - @classmethod - def obsolete(cls, envvars): - bad = [envvar for envvar in envvars if environ.get(envvar)] - if bad: - raise cls.Error(f'remove obsolete environment variables {bad}') - - @classmethod - def set_event_loop_policy(cls, policy_name: str = None): - if not policy_name or policy_name == 'default': - import asyncio - return asyncio.get_event_loop_policy() - elif policy_name == 'uvloop': - import uvloop - import asyncio - loop_policy = uvloop.EventLoopPolicy() - asyncio.set_event_loop_policy(loop_policy) - return loop_policy - raise cls.Error(f'unknown event loop policy "{policy_name}"') - - def cs_host(self, *, for_rpc): - """Returns the 'host' argument to pass to asyncio's create_server - call. The result can be a single host name string, a list of - host name strings, or an empty string to bind to all interfaces. - - If rpc is True the host to use for the RPC server is returned. - Otherwise the host to use for SSL/TCP servers is returned. - """ - host = self.rpc_host if for_rpc else self.host - result = [part.strip() for part in host.split(',')] - if len(result) == 1: - result = result[0] - # An empty result indicates all interfaces, which we do not - # permitted for an RPC server. - if for_rpc and not result: - result = 'localhost' - if result == 'localhost': - # 'localhost' resolves to ::1 (ipv6) on many systems, which fails on default setup of - # docker, using 127.0.0.1 instead forces ipv4 - result = '127.0.0.1' - return result - - def sane_max_sessions(self): - """Return the maximum number of sessions to permit. Normally this - is MAX_SESSIONS. However, to prevent open file exhaustion, ajdust - downwards if running with a small open file rlimit.""" - env_value = self.integer('MAX_SESSIONS', 1000) - nofile_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0] - # We give the DB 250 files; allow ElectrumX 100 for itself - value = max(0, min(env_value, nofile_limit - 350)) - if value < env_value: - self.logger.warning(f'lowered maximum sessions from {env_value:,d} to {value:,d} ' - f'because your open file limit is {nofile_limit:,d}') - return value - - def clearnet_identity(self): - host = self.default('REPORT_HOST', None) - if host is None: - return None - try: - ip = ip_address(host) - except ValueError: - bad = (not lib_util.is_valid_hostname(host) - or host.lower() == 'localhost') - else: - bad = (ip.is_multicast or ip.is_unspecified - or (ip.is_private and self.peer_announce)) - if bad: - raise self.Error(f'"{host}" is not a valid REPORT_HOST') - tcp_port = self.integer('REPORT_TCP_PORT', self.tcp_port) or None - ssl_port = self.integer('REPORT_SSL_PORT', self.ssl_port) or None - if tcp_port == ssl_port: - raise self.Error('REPORT_TCP_PORT and REPORT_SSL_PORT ' - f'both resolve to {tcp_port}') - return NetIdentity( - host, - tcp_port, - ssl_port, - '' - ) - - def tor_identity(self, clearnet): - host = self.default('REPORT_HOST_TOR', None) - if host is None: - return None - if not host.endswith('.onion'): - raise self.Error(f'tor host "{host}" must end with ".onion"') - - def port(port_kind): - """Returns the clearnet identity port, if any and not zero, - otherwise the listening port.""" - result = 0 - if clearnet: - result = getattr(clearnet, port_kind) - return result or getattr(self, port_kind) - - tcp_port = self.integer('REPORT_TCP_PORT_TOR', - port('tcp_port')) or None - ssl_port = self.integer('REPORT_SSL_PORT_TOR', - port('ssl_port')) or None - if tcp_port == ssl_port: - raise self.Error('REPORT_TCP_PORT_TOR and REPORT_SSL_PORT_TOR ' - f'both resolve to {tcp_port}') - - return NetIdentity( - host, - tcp_port, - ssl_port, - '_tor', - ) - - def hosts_dict(self): - return {identity.host: {'tcp_port': identity.tcp_port, - 'ssl_port': identity.ssl_port} - for identity in self.identities} - - def peer_discovery_enum(self): - pd = self.default('PEER_DISCOVERY', 'on').strip().lower() - if pd in ('off', ''): - return self.PD_OFF - elif pd == 'self': - return self.PD_SELF - else: - return self.PD_ON - - def extract_peer_hubs(self): - return [hub.strip() for hub in self.default('PEER_HUBS', '').split(',') if hub.strip()] - - @classmethod - def contribute_to_arg_parser(cls, parser): - parser.add_argument('--db_dir', type=str, help='path of the directory containing lbry-leveldb', - default=cls.default('DB_DIRECTORY', None)) - parser.add_argument('--daemon_url', - help='URL for rpc from lbrycrd, :@', - default=cls.default('DAEMON_URL', None)) - parser.add_argument('--db_max_open_files', type=int, default=512, - help='number of files leveldb can have open at a time') - parser.add_argument('--host', type=str, default=cls.default('HOST', 'localhost'), - help='Interface for hub server to listen on') - parser.add_argument('--tcp_port', type=int, default=cls.integer('TCP_PORT', 50001), - help='TCP port to listen on for hub server') - parser.add_argument('--udp_port', type=int, default=cls.integer('UDP_PORT', 50001), - help='UDP port to listen on for hub server') - parser.add_argument('--rpc_host', default=cls.default('RPC_HOST', 'localhost'), type=str, - help='Listening interface for admin rpc') - parser.add_argument('--rpc_port', default=cls.integer('RPC_PORT', 8000), type=int, - help='Listening port for admin rpc') - parser.add_argument('--websocket_host', default=cls.default('WEBSOCKET_HOST', 'localhost'), type=str, - help='Listening interface for websocket') - parser.add_argument('--websocket_port', default=cls.integer('WEBSOCKET_PORT', None), type=int, - help='Listening port for websocket') - - parser.add_argument('--ssl_port', default=cls.integer('SSL_PORT', None), type=int, - help='SSL port to listen on for hub server') - parser.add_argument('--ssl_certfile', default=cls.default('SSL_CERTFILE', None), type=str, - help='Path to SSL cert file') - parser.add_argument('--ssl_keyfile', default=cls.default('SSL_KEYFILE', None), type=str, - help='Path to SSL key file') - parser.add_argument('--reorg_limit', default=cls.integer('REORG_LIMIT', 200), type=int, help='Max reorg depth') - parser.add_argument('--elastic_host', default=cls.default('ELASTIC_HOST', 'localhost'), type=str, - help='elasticsearch host') - parser.add_argument('--elastic_port', default=cls.integer('ELASTIC_PORT', 9200), type=int, - help='elasticsearch port') - parser.add_argument('--es_mode', default=cls.default('ES_MODE', 'writer'), type=str, - choices=['reader', 'writer']) - parser.add_argument('--es_index_prefix', default=cls.default('ES_INDEX_PREFIX', ''), type=str) - parser.add_argument('--loop_policy', default=cls.default('EVENT_LOOP_POLICY', 'default'), type=str, - choices=['default', 'uvloop']) - parser.add_argument('--max_query_workers', type=int, default=cls.integer('MAX_QUERY_WORKERS', 4), - help='number of threads used by the request handler to read the database') - parser.add_argument('--cache_MB', type=int, default=cls.integer('CACHE_MB', 1024), - help='size of the leveldb lru cache, in megabytes') - parser.add_argument('--cache_all_tx_hashes', type=bool, - help='Load all tx hashes into memory. This will make address subscriptions and sync, ' - 'resolve, transaction fetching, and block sync all faster at the expense of higher ' - 'memory usage') - parser.add_argument('--cache_all_claim_txos', type=bool, - help='Load all claim txos into memory. This will make address subscriptions and sync, ' - 'resolve, transaction fetching, and block sync all faster at the expense of higher ' - 'memory usage') - parser.add_argument('--prometheus_port', type=int, default=cls.integer('PROMETHEUS_PORT', 0), - help='port for hub prometheus metrics to listen on, disabled by default') - parser.add_argument('--max_subscriptions', type=int, default=cls.integer('MAX_SUBSCRIPTIONS', 10000), - help='max subscriptions per connection') - parser.add_argument('--banner_file', type=str, default=cls.default('BANNER_FILE', None), - help='path to file containing banner text') - parser.add_argument('--anon_logs', type=bool, default=cls.boolean('ANON_LOGS', False), - help="don't log ip addresses") - parser.add_argument('--allow_lan_udp', type=bool, default=cls.boolean('ALLOW_LAN_UDP', False), - help='reply to hub UDP ping messages from LAN ip addresses') - parser.add_argument('--country', type=str, default=cls.default('COUNTRY', 'US'), help='') - parser.add_argument('--max_send', type=int, default=cls.default('MAX_SEND', 1000000), help='') - parser.add_argument('--max_receive', type=int, default=cls.default('MAX_RECEIVE', 1000000), help='') - parser.add_argument('--max_sessions', type=int, default=cls.default('MAX_SESSIONS', 1000), help='') - parser.add_argument('--session_timeout', type=int, default=cls.default('SESSION_TIMEOUT', 600), help='') - parser.add_argument('--drop_client', type=str, default=cls.default('DROP_CLIENT', None), help='') - parser.add_argument('--description', type=str, default=cls.default('DESCRIPTION', ''), help='') - parser.add_argument('--daily_fee', type=float, default=cls.default('DAILY_FEE', 0.0), help='') - parser.add_argument('--payment_address', type=str, default=cls.default('PAYMENT_ADDRESS', ''), help='') - parser.add_argument('--donation_address', type=str, default=cls.default('DONATION_ADDRESS', ''), help='') - parser.add_argument('--chain', type=str, default=cls.default('NET', 'mainnet'), - help="Which chain to use, default is mainnet") - parser.add_argument('--query_timeout_ms', type=int, default=cls.integer('QUERY_TIMEOUT_MS', 10000), - help="elasticsearch query timeout") - - @classmethod - def from_arg_parser(cls, args): - return cls( - db_dir=args.db_dir, daemon_url=args.daemon_url, db_max_open_files=args.db_max_open_files, - host=args.host, rpc_host=args.rpc_host, elastic_host=args.elastic_host, elastic_port=args.elastic_port, - loop_policy=args.loop_policy, max_query_workers=args.max_query_workers, websocket_host=args.websocket_host, - websocket_port=args.websocket_port, chain=args.chain, es_index_prefix=args.es_index_prefix, - es_mode=args.es_mode, cache_MB=args.cache_MB, reorg_limit=args.reorg_limit, tcp_port=args.tcp_port, - udp_port=args.udp_port, ssl_port=args.ssl_port, ssl_certfile=args.ssl_certfile, - ssl_keyfile=args.ssl_keyfile, rpc_port=args.rpc_port, prometheus_port=args.prometheus_port, - max_subscriptions=args.max_subscriptions, banner_file=args.banner_file, anon_logs=args.anon_logs, - log_sessions=None, allow_lan_udp=args.allow_lan_udp, - cache_all_tx_hashes=args.cache_all_tx_hashes, cache_all_claim_txos=args.cache_all_claim_txos, - country=args.country, payment_address=args.payment_address, donation_address=args.donation_address, - max_send=args.max_send, max_receive=args.max_receive, max_sessions=args.max_sessions, - session_timeout=args.session_timeout, drop_client=args.drop_client, description=args.description, - daily_fee=args.daily_fee, database_query_timeout=(args.query_timeout_ms / 1000) - ) diff --git a/lbry/wallet/server/hash.py b/lbry/wallet/server/hash.py deleted file mode 100644 index e9d088684..000000000 --- a/lbry/wallet/server/hash.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -"""Cryptograph hash functions and related classes.""" - - -import hashlib -import hmac - -from lbry.wallet.server.util import bytes_to_int, int_to_bytes, hex_to_bytes - -_sha256 = hashlib.sha256 -_sha512 = hashlib.sha512 -_new_hash = hashlib.new -_new_hmac = hmac.new -HASHX_LEN = 11 -CLAIM_HASH_LEN = 20 - - -def sha256(x): - """Simple wrapper of hashlib sha256.""" - return _sha256(x).digest() - - -def ripemd160(x): - """Simple wrapper of hashlib ripemd160.""" - h = _new_hash('ripemd160') - h.update(x) - return h.digest() - - -def double_sha256(x): - """SHA-256 of SHA-256, as used extensively in bitcoin.""" - return sha256(sha256(x)) - - -def hmac_sha512(key, msg): - """Use SHA-512 to provide an HMAC.""" - return _new_hmac(key, msg, _sha512).digest() - - -def hash160(x): - """RIPEMD-160 of SHA-256. - - Used to make bitcoin addresses from pubkeys.""" - return ripemd160(sha256(x)) - - -def hash_to_hex_str(x: bytes) -> str: - """Convert a big-endian binary hash to displayed hex string. - - Display form of a binary hash is reversed and converted to hex. - """ - return x[::-1].hex() - - -def hex_str_to_hash(x: str) -> bytes: - """Convert a displayed hex string to a binary hash.""" - return hex_to_bytes(x)[::-1] - - -class Base58Error(Exception): - """Exception used for Base58 errors.""" - - -class Base58: - """Class providing base 58 functionality.""" - - chars = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' - assert len(chars) == 58 - cmap = {c: n for n, c in enumerate(chars)} - - @staticmethod - def char_value(c): - val = Base58.cmap.get(c) - if val is None: - raise Base58Error(f'invalid base 58 character "{c}"') - return val - - @staticmethod - def decode(txt): - """Decodes txt into a big-endian bytearray.""" - if not isinstance(txt, str): - raise TypeError('a string is required') - - if not txt: - raise Base58Error('string cannot be empty') - - value = 0 - for c in txt: - value = value * 58 + Base58.char_value(c) - - result = int_to_bytes(value) - - # Prepend leading zero bytes if necessary - count = 0 - for c in txt: - if c != '1': - break - count += 1 - if count: - result = bytes(count) + result - - return result - - @staticmethod - def encode(be_bytes): - """Converts a big-endian bytearray into a base58 string.""" - value = bytes_to_int(be_bytes) - - txt = '' - while value: - value, mod = divmod(value, 58) - txt += Base58.chars[mod] - - for byte in be_bytes: - if byte != 0: - break - txt += '1' - - return txt[::-1] - - @staticmethod - def decode_check(txt, *, hash_fn=double_sha256): - """Decodes a Base58Check-encoded string to a payload. The version - prefixes it.""" - be_bytes = Base58.decode(txt) - result, check = be_bytes[:-4], be_bytes[-4:] - if check != hash_fn(result)[:4]: - raise Base58Error(f'invalid base 58 checksum for {txt}') - return result - - @staticmethod - def encode_check(payload, *, hash_fn=double_sha256): - """Encodes a payload bytearray (which includes the version byte(s)) - into a Base58Check string.""" - be_bytes = payload + hash_fn(payload)[:4] - return Base58.encode(be_bytes) diff --git a/lbry/wallet/server/leveldb.py b/lbry/wallet/server/leveldb.py deleted file mode 100644 index 007d7c02c..000000000 --- a/lbry/wallet/server/leveldb.py +++ /dev/null @@ -1,1151 +0,0 @@ -# Copyright (c) 2016, Neil Booth -# Copyright (c) 2017, the ElectrumX authors -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - -"""Interface to the blockchain database.""" - -import os -import asyncio -import array -import time -import typing -import struct -import zlib -import base64 -from typing import Optional, Iterable, Tuple, DefaultDict, Set, Dict, List, TYPE_CHECKING -from functools import partial -from asyncio import sleep -from bisect import bisect_right -from collections import defaultdict - -from lbry.error import ResolveCensoredError -from lbry.schema.result import Censor -from lbry.utils import LRUCacheWithMetrics -from lbry.schema.url import URL, normalize_name -from lbry.wallet.server import util -from lbry.wallet.server.hash import hash_to_hex_str -from lbry.wallet.server.tx import TxInput -from lbry.wallet.server.merkle import Merkle, MerkleCache -from lbry.wallet.server.db.common import ResolveResult, STREAM_TYPES, CLAIM_TYPES -from lbry.wallet.server.db.prefixes import PendingActivationValue, ClaimTakeoverValue, ClaimToTXOValue, HubDB -from lbry.wallet.server.db.prefixes import ACTIVATED_CLAIM_TXO_TYPE, ACTIVATED_SUPPORT_TXO_TYPE -from lbry.wallet.server.db.prefixes import PendingActivationKey, TXOToClaimValue, DBStatePrefixRow -from lbry.wallet.transaction import OutputScript -from lbry.schema.claim import Claim, guess_stream_type -from lbry.wallet.ledger import Ledger, RegTestLedger, TestNetLedger - -from lbry.wallet.server.db.elasticsearch import SearchIndex - -if TYPE_CHECKING: - from lbry.wallet.server.db.prefixes import EffectiveAmountKey - - -class UTXO(typing.NamedTuple): - tx_num: int - tx_pos: int - tx_hash: bytes - height: int - value: int - - -TXO_STRUCT = struct.Struct(b'>LH') -TXO_STRUCT_unpack = TXO_STRUCT.unpack -TXO_STRUCT_pack = TXO_STRUCT.pack -OptionalResolveResultOrError = Optional[typing.Union[ResolveResult, ResolveCensoredError, LookupError, ValueError]] - - -class ExpandedResolveResult(typing.NamedTuple): - stream: OptionalResolveResultOrError - channel: OptionalResolveResultOrError - repost: OptionalResolveResultOrError - reposted_channel: OptionalResolveResultOrError - - -class DBError(Exception): - """Raised on general DB errors generally indicating corruption.""" - - -class LevelDB: - DB_VERSIONS = HIST_DB_VERSIONS = [7] - - def __init__(self, env): - self.logger = util.class_logger(__name__, self.__class__.__name__) - self.env = env - self.coin = env.coin - - self.logger.info(f'switching current directory to {env.db_dir}') - - self.prefix_db = None - - self.hist_unflushed = defaultdict(partial(array.array, 'I')) - self.hist_unflushed_count = 0 - self.hist_flush_count = 0 - self.hist_comp_flush_count = -1 - self.hist_comp_cursor = -1 - - self.es_sync_height = 0 - - # blocking/filtering dicts - blocking_channels = self.env.default('BLOCKING_CHANNEL_IDS', '').split(' ') - filtering_channels = self.env.default('FILTERING_CHANNEL_IDS', '').split(' ') - self.blocked_streams = {} - self.blocked_channels = {} - self.blocking_channel_hashes = { - bytes.fromhex(channel_id) for channel_id in blocking_channels if channel_id - } - self.filtered_streams = {} - self.filtered_channels = {} - self.filtering_channel_hashes = { - bytes.fromhex(channel_id) for channel_id in filtering_channels if channel_id - } - - self.tx_counts = None - self.headers = None - self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace='wallet_server') - self.last_flush = time.time() - - # Header merkle cache - self.merkle = Merkle() - self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes) - - self._tx_and_merkle_cache = LRUCacheWithMetrics(2 ** 16, metric_name='tx_and_merkle', namespace="wallet_server") - - # these are only used if the cache_all_tx_hashes setting is on - self.total_transactions: List[bytes] = [] - self.tx_num_mapping: Dict[bytes, int] = {} - - # these are only used if the cache_all_claim_txos setting is on - self.claim_to_txo: Dict[bytes, ClaimToTXOValue] = {} - self.txo_to_claim: DefaultDict[int, Dict[int, bytes]] = defaultdict(dict) - - # Search index - self.search_index = SearchIndex( - self.env.es_index_prefix, self.env.database_query_timeout, - elastic_host=env.elastic_host, elastic_port=env.elastic_port - ) - - self.genesis_bytes = bytes.fromhex(self.coin.GENESIS_HASH) - - if env.coin.NET == 'mainnet': - self.ledger = Ledger - elif env.coin.NET == 'testnet': - self.ledger = TestNetLedger - else: - self.ledger = RegTestLedger - - def get_claim_from_txo(self, tx_num: int, tx_idx: int) -> Optional[TXOToClaimValue]: - claim_hash_and_name = self.prefix_db.txo_to_claim.get(tx_num, tx_idx) - if not claim_hash_and_name: - return - return claim_hash_and_name - - def get_repost(self, claim_hash) -> Optional[bytes]: - repost = self.prefix_db.repost.get(claim_hash) - if repost: - return repost.reposted_claim_hash - return - - def get_reposted_count(self, claim_hash: bytes) -> int: - return sum( - 1 for _ in self.prefix_db.reposted_claim.iterate(prefix=(claim_hash,), include_value=False) - ) - - def get_activation(self, tx_num, position, is_support=False) -> int: - activation = self.prefix_db.activated.get( - ACTIVATED_SUPPORT_TXO_TYPE if is_support else ACTIVATED_CLAIM_TXO_TYPE, tx_num, position - ) - if activation: - return activation.height - return -1 - - def get_supported_claim_from_txo(self, tx_num: int, position: int) -> typing.Tuple[Optional[bytes], Optional[int]]: - supported_claim_hash = self.prefix_db.support_to_claim.get(tx_num, position) - if supported_claim_hash: - packed_support_amount = self.prefix_db.claim_to_support.get( - supported_claim_hash.claim_hash, tx_num, position - ) - if packed_support_amount: - return supported_claim_hash.claim_hash, packed_support_amount.amount - return None, None - - def get_support_amount(self, claim_hash: bytes): - support_amount_val = self.prefix_db.support_amount.get(claim_hash) - if support_amount_val is None: - return 0 - return support_amount_val.amount - - def get_supports(self, claim_hash: bytes): - return [ - (k.tx_num, k.position, v.amount) for k, v in self.prefix_db.claim_to_support.iterate(prefix=(claim_hash,)) - ] - - def get_short_claim_id_url(self, name: str, normalized_name: str, claim_hash: bytes, - root_tx_num: int, root_position: int) -> str: - claim_id = claim_hash.hex() - for prefix_len in range(10): - for k in self.prefix_db.claim_short_id.iterate(prefix=(normalized_name, claim_id[:prefix_len+1]), - include_value=False): - if k.root_tx_num == root_tx_num and k.root_position == root_position: - return f'{name}#{k.partial_claim_id}' - break - print(f"{claim_id} has a collision") - return f'{name}#{claim_id}' - - def _prepare_resolve_result(self, tx_num: int, position: int, claim_hash: bytes, name: str, - root_tx_num: int, root_position: int, activation_height: int, - signature_valid: bool) -> ResolveResult: - try: - normalized_name = normalize_name(name) - except UnicodeDecodeError: - normalized_name = name - controlling_claim = self.get_controlling_claim(normalized_name) - - tx_hash = self.get_tx_hash(tx_num) - height = bisect_right(self.tx_counts, tx_num) - created_height = bisect_right(self.tx_counts, root_tx_num) - last_take_over_height = controlling_claim.height - - expiration_height = self.coin.get_expiration_height(height) - support_amount = self.get_support_amount(claim_hash) - claim_amount = self.get_cached_claim_txo(claim_hash).amount - - effective_amount = self.get_effective_amount(claim_hash) - channel_hash = self.get_channel_for_claim(claim_hash, tx_num, position) - reposted_claim_hash = self.get_repost(claim_hash) - short_url = self.get_short_claim_id_url(name, normalized_name, claim_hash, root_tx_num, root_position) - canonical_url = short_url - claims_in_channel = self.get_claims_in_channel_count(claim_hash) - if channel_hash: - channel_vals = self.get_cached_claim_txo(channel_hash) - if channel_vals: - channel_short_url = self.get_short_claim_id_url( - channel_vals.name, channel_vals.normalized_name, channel_hash, channel_vals.root_tx_num, - channel_vals.root_position - ) - canonical_url = f'{channel_short_url}/{short_url}' - return ResolveResult( - name, normalized_name, claim_hash, tx_num, position, tx_hash, height, claim_amount, short_url=short_url, - is_controlling=controlling_claim.claim_hash == claim_hash, canonical_url=canonical_url, - last_takeover_height=last_take_over_height, claims_in_channel=claims_in_channel, - creation_height=created_height, activation_height=activation_height, - expiration_height=expiration_height, effective_amount=effective_amount, support_amount=support_amount, - channel_hash=channel_hash, reposted_claim_hash=reposted_claim_hash, - reposted=self.get_reposted_count(claim_hash), - signature_valid=None if not channel_hash else signature_valid - ) - - def _resolve_parsed_url(self, name: str, claim_id: Optional[str] = None, - amount_order: Optional[int] = None) -> Optional[ResolveResult]: - """ - :param normalized_name: name - :param claim_id: partial or complete claim id - :param amount_order: '$' suffix to a url, defaults to 1 (winning) if no claim id modifier is provided - """ - try: - normalized_name = normalize_name(name) - except UnicodeDecodeError: - normalized_name = name - if (not amount_order and not claim_id) or amount_order == 1: - # winning resolution - controlling = self.get_controlling_claim(normalized_name) - if not controlling: - # print(f"none controlling for lbry://{normalized_name}") - return - # print(f"resolved controlling lbry://{normalized_name}#{controlling.claim_hash.hex()}") - return self._fs_get_claim_by_hash(controlling.claim_hash) - - amount_order = max(int(amount_order or 1), 1) - - if claim_id: - if len(claim_id) == 40: # a full claim id - claim_txo = self.get_claim_txo(bytes.fromhex(claim_id)) - if not claim_txo or normalized_name != claim_txo.normalized_name: - return - return self._prepare_resolve_result( - claim_txo.tx_num, claim_txo.position, bytes.fromhex(claim_id), claim_txo.name, - claim_txo.root_tx_num, claim_txo.root_position, - self.get_activation(claim_txo.tx_num, claim_txo.position), claim_txo.channel_signature_is_valid - ) - # resolve by partial/complete claim id - for key, claim_txo in self.prefix_db.claim_short_id.iterate(prefix=(normalized_name, claim_id[:10])): - full_claim_hash = self.get_cached_claim_hash(claim_txo.tx_num, claim_txo.position) - c = self.get_cached_claim_txo(full_claim_hash) - - non_normalized_name = c.name - signature_is_valid = c.channel_signature_is_valid - return self._prepare_resolve_result( - claim_txo.tx_num, claim_txo.position, full_claim_hash, non_normalized_name, key.root_tx_num, - key.root_position, self.get_activation(claim_txo.tx_num, claim_txo.position), - signature_is_valid - ) - return - - # resolve by amount ordering, 1 indexed - for idx, (key, claim_val) in enumerate(self.prefix_db.effective_amount.iterate(prefix=(normalized_name,))): - if amount_order > idx + 1: - continue - claim_txo = self.get_cached_claim_txo(claim_val.claim_hash) - activation = self.get_activation(key.tx_num, key.position) - return self._prepare_resolve_result( - key.tx_num, key.position, claim_val.claim_hash, key.normalized_name, claim_txo.root_tx_num, - claim_txo.root_position, activation, claim_txo.channel_signature_is_valid - ) - return - - def _resolve_claim_in_channel(self, channel_hash: bytes, normalized_name: str): - candidates = [] - for key, stream in self.prefix_db.channel_to_claim.iterate(prefix=(channel_hash, normalized_name)): - effective_amount = self.get_effective_amount(stream.claim_hash) - if not candidates or candidates[-1][-1] == effective_amount: - candidates.append((stream.claim_hash, key.tx_num, key.position, effective_amount)) - else: - break - if not candidates: - return - return list(sorted(candidates, key=lambda item: item[1]))[0] - - def _resolve(self, url) -> ExpandedResolveResult: - try: - parsed = URL.parse(url) - except ValueError as e: - return ExpandedResolveResult(e, None, None, None) - - stream = channel = resolved_channel = resolved_stream = None - if parsed.has_stream_in_channel: - channel = parsed.channel - stream = parsed.stream - elif parsed.has_channel: - channel = parsed.channel - elif parsed.has_stream: - stream = parsed.stream - if channel: - resolved_channel = self._resolve_parsed_url(channel.name, channel.claim_id, channel.amount_order) - if not resolved_channel: - return ExpandedResolveResult(None, LookupError(f'Could not find channel in "{url}".'), None, None) - if stream: - if resolved_channel: - stream_claim = self._resolve_claim_in_channel(resolved_channel.claim_hash, stream.normalized) - if stream_claim: - stream_claim_id, stream_tx_num, stream_tx_pos, effective_amount = stream_claim - resolved_stream = self._fs_get_claim_by_hash(stream_claim_id) - else: - resolved_stream = self._resolve_parsed_url(stream.name, stream.claim_id, stream.amount_order) - if not channel and not resolved_channel and resolved_stream and resolved_stream.channel_hash: - resolved_channel = self._fs_get_claim_by_hash(resolved_stream.channel_hash) - if not resolved_stream: - return ExpandedResolveResult(LookupError(f'Could not find claim at "{url}".'), None, None, None) - - repost = None - reposted_channel = None - if resolved_stream or resolved_channel: - claim_hash = resolved_stream.claim_hash if resolved_stream else resolved_channel.claim_hash - claim = resolved_stream if resolved_stream else resolved_channel - reposted_claim_hash = resolved_stream.reposted_claim_hash if resolved_stream else None - blocker_hash = self.blocked_streams.get(claim_hash) or self.blocked_streams.get( - reposted_claim_hash) or self.blocked_channels.get(claim_hash) or self.blocked_channels.get( - reposted_claim_hash) or self.blocked_channels.get(claim.channel_hash) - if blocker_hash: - reason_row = self._fs_get_claim_by_hash(blocker_hash) - return ExpandedResolveResult( - None, ResolveCensoredError(url, blocker_hash, censor_row=reason_row), None, None - ) - if claim.reposted_claim_hash: - repost = self._fs_get_claim_by_hash(claim.reposted_claim_hash) - if repost and repost.channel_hash and repost.signature_valid: - reposted_channel = self._fs_get_claim_by_hash(repost.channel_hash) - return ExpandedResolveResult(resolved_stream, resolved_channel, repost, reposted_channel) - - async def resolve(self, url) -> ExpandedResolveResult: - return await asyncio.get_event_loop().run_in_executor(None, self._resolve, url) - - def _fs_get_claim_by_hash(self, claim_hash): - claim = self.get_cached_claim_txo(claim_hash) - if claim: - activation = self.get_activation(claim.tx_num, claim.position) - return self._prepare_resolve_result( - claim.tx_num, claim.position, claim_hash, claim.name, claim.root_tx_num, claim.root_position, - activation, claim.channel_signature_is_valid - ) - - async def fs_getclaimbyid(self, claim_id): - return await asyncio.get_event_loop().run_in_executor( - None, self._fs_get_claim_by_hash, bytes.fromhex(claim_id) - ) - - def get_claim_txo_amount(self, claim_hash: bytes) -> Optional[int]: - claim = self.get_claim_txo(claim_hash) - if claim: - return claim.amount - - def get_block_hash(self, height: int) -> Optional[bytes]: - v = self.prefix_db.block_hash.get(height) - if v: - return v.block_hash - - def get_support_txo_amount(self, claim_hash: bytes, tx_num: int, position: int) -> Optional[int]: - v = self.prefix_db.claim_to_support.get(claim_hash, tx_num, position) - return None if not v else v.amount - - def get_claim_txo(self, claim_hash: bytes) -> Optional[ClaimToTXOValue]: - assert claim_hash - return self.prefix_db.claim_to_txo.get(claim_hash) - - def _get_active_amount(self, claim_hash: bytes, txo_type: int, height: int) -> int: - return sum( - v.amount for v in self.prefix_db.active_amount.iterate( - start=(claim_hash, txo_type, 0), stop=(claim_hash, txo_type, height), include_key=False - ) - ) - - def get_active_amount_as_of_height(self, claim_hash: bytes, height: int) -> int: - for v in self.prefix_db.active_amount.iterate( - start=(claim_hash, ACTIVATED_CLAIM_TXO_TYPE, 0), stop=(claim_hash, ACTIVATED_CLAIM_TXO_TYPE, height), - include_key=False, reverse=True): - return v.amount - return 0 - - def get_effective_amount(self, claim_hash: bytes, support_only=False) -> int: - support_amount = self._get_active_amount(claim_hash, ACTIVATED_SUPPORT_TXO_TYPE, self.db_height + 1) - if support_only: - return support_only - return support_amount + self._get_active_amount(claim_hash, ACTIVATED_CLAIM_TXO_TYPE, self.db_height + 1) - - def get_url_effective_amount(self, name: str, claim_hash: bytes) -> Optional['EffectiveAmountKey']: - for k, v in self.prefix_db.effective_amount.iterate(prefix=(name,)): - if v.claim_hash == claim_hash: - return k - - def get_claims_for_name(self, name): - claims = [] - prefix = self.prefix_db.claim_short_id.pack_partial_key(name) + bytes([1]) - for _k, _v in self.prefix_db.iterator(prefix=prefix): - v = self.prefix_db.claim_short_id.unpack_value(_v) - claim_hash = self.get_claim_from_txo(v.tx_num, v.position).claim_hash - if claim_hash not in claims: - claims.append(claim_hash) - return claims - - def get_claims_in_channel_count(self, channel_hash) -> int: - channel_count_val = self.prefix_db.channel_count.get(channel_hash) - if channel_count_val is None: - return 0 - return channel_count_val.count - - async def reload_blocking_filtering_streams(self): - def reload(): - self.blocked_streams, self.blocked_channels = self.get_streams_and_channels_reposted_by_channel_hashes( - self.blocking_channel_hashes - ) - self.filtered_streams, self.filtered_channels = self.get_streams_and_channels_reposted_by_channel_hashes( - self.filtering_channel_hashes - ) - await asyncio.get_event_loop().run_in_executor(None, reload) - - def get_streams_and_channels_reposted_by_channel_hashes(self, reposter_channel_hashes: Set[bytes]): - streams, channels = {}, {} - for reposter_channel_hash in reposter_channel_hashes: - for stream in self.prefix_db.channel_to_claim.iterate((reposter_channel_hash, ), include_key=False): - repost = self.get_repost(stream.claim_hash) - if repost: - txo = self.get_claim_txo(repost) - if txo: - if txo.normalized_name.startswith('@'): - channels[repost] = reposter_channel_hash - else: - streams[repost] = reposter_channel_hash - return streams, channels - - def get_channel_for_claim(self, claim_hash, tx_num, position) -> Optional[bytes]: - v = self.prefix_db.claim_to_channel.get(claim_hash, tx_num, position) - if v: - return v.signing_hash - - def get_expired_by_height(self, height: int) -> Dict[bytes, Tuple[int, int, str, TxInput]]: - expired = {} - for k, v in self.prefix_db.claim_expiration.iterate(prefix=(height,)): - tx_hash = self.get_tx_hash(k.tx_num) - tx = self.coin.transaction(self.prefix_db.tx.get(tx_hash, deserialize_value=False)) - # treat it like a claim spend so it will delete/abandon properly - # the _spend_claim function this result is fed to expects a txi, so make a mock one - # print(f"\texpired lbry://{v.name} {v.claim_hash.hex()}") - expired[v.claim_hash] = ( - k.tx_num, k.position, v.normalized_name, - TxInput(prev_hash=tx_hash, prev_idx=k.position, script=tx.outputs[k.position].pk_script, sequence=0) - ) - return expired - - def get_controlling_claim(self, name: str) -> Optional[ClaimTakeoverValue]: - controlling = self.prefix_db.claim_takeover.get(name) - if not controlling: - return - return controlling - - def get_claim_txos_for_name(self, name: str): - txos = {} - prefix = self.prefix_db.claim_short_id.pack_partial_key(name) + int(1).to_bytes(1, byteorder='big') - for k, v in self.prefix_db.iterator(prefix=prefix): - tx_num, nout = self.prefix_db.claim_short_id.unpack_value(v) - txos[self.get_claim_from_txo(tx_num, nout).claim_hash] = tx_num, nout - return txos - - def get_claim_metadata(self, tx_hash, nout): - raw = self.prefix_db.tx.get(tx_hash, deserialize_value=False) - try: - output = self.coin.transaction(raw).outputs[nout] - script = OutputScript(output.pk_script) - script.parse() - return Claim.from_bytes(script.values['claim']) - except: - self.logger.error("claim parsing for ES failed with tx: %s", tx_hash[::-1].hex()) - return - - def _prepare_claim_metadata(self, claim_hash: bytes, claim: ResolveResult): - metadata = self.get_claim_metadata(claim.tx_hash, claim.position) - if not metadata: - return - metadata = metadata - if not metadata.is_stream or not metadata.stream.has_fee: - fee_amount = 0 - else: - fee_amount = int(max(metadata.stream.fee.amount or 0, 0) * 1000) - if fee_amount >= 9223372036854775807: - return - reposted_claim_hash = None if not metadata.is_repost else metadata.repost.reference.claim_hash[::-1] - reposted_claim = None - reposted_metadata = None - if reposted_claim_hash: - reposted_claim = self.get_cached_claim_txo(reposted_claim_hash) - if not reposted_claim: - return - reposted_metadata = self.get_claim_metadata( - self.get_tx_hash(reposted_claim.tx_num), reposted_claim.position - ) - if not reposted_metadata: - return - reposted_tags = [] - reposted_languages = [] - reposted_has_source = False - reposted_claim_type = None - reposted_stream_type = None - reposted_media_type = None - reposted_fee_amount = None - reposted_fee_currency = None - reposted_duration = None - if reposted_claim: - reposted_tx_hash = self.get_tx_hash(reposted_claim.tx_num) - raw_reposted_claim_tx = self.prefix_db.tx.get(reposted_tx_hash, deserialize_value=False) - try: - reposted_claim_txo = self.coin.transaction( - raw_reposted_claim_tx - ).outputs[reposted_claim.position] - reposted_script = OutputScript(reposted_claim_txo.pk_script) - reposted_script.parse() - reposted_metadata = Claim.from_bytes(reposted_script.values['claim']) - except: - self.logger.error("failed to parse reposted claim in tx %s that was reposted by %s", - reposted_tx_hash[::-1].hex(), claim_hash.hex()) - return - if reposted_metadata: - if reposted_metadata.is_stream: - meta = reposted_metadata.stream - elif reposted_metadata.is_channel: - meta = reposted_metadata.channel - elif reposted_metadata.is_collection: - meta = reposted_metadata.collection - elif reposted_metadata.is_repost: - meta = reposted_metadata.repost - else: - return - reposted_tags = [tag for tag in meta.tags] - reposted_languages = [lang.language or 'none' for lang in meta.languages] or ['none'] - reposted_has_source = False if not reposted_metadata.is_stream else reposted_metadata.stream.has_source - reposted_claim_type = CLAIM_TYPES[reposted_metadata.claim_type] - reposted_stream_type = STREAM_TYPES[guess_stream_type(reposted_metadata.stream.source.media_type)] \ - if reposted_has_source else 0 - reposted_media_type = reposted_metadata.stream.source.media_type if reposted_metadata.is_stream else 0 - if not reposted_metadata.is_stream or not reposted_metadata.stream.has_fee: - reposted_fee_amount = 0 - else: - reposted_fee_amount = int(max(reposted_metadata.stream.fee.amount or 0, 0) * 1000) - if reposted_fee_amount >= 9223372036854775807: - return - reposted_fee_currency = None if not reposted_metadata.is_stream else reposted_metadata.stream.fee.currency - reposted_duration = None - if reposted_metadata.is_stream and \ - (reposted_metadata.stream.video.duration or reposted_metadata.stream.audio.duration): - reposted_duration = reposted_metadata.stream.video.duration or reposted_metadata.stream.audio.duration - if metadata.is_stream: - meta = metadata.stream - elif metadata.is_channel: - meta = metadata.channel - elif metadata.is_collection: - meta = metadata.collection - elif metadata.is_repost: - meta = metadata.repost - else: - return - claim_tags = [tag for tag in meta.tags] - claim_languages = [lang.language or 'none' for lang in meta.languages] or ['none'] - - tags = list(set(claim_tags).union(set(reposted_tags))) - languages = list(set(claim_languages).union(set(reposted_languages))) - blocked_hash = self.blocked_streams.get(claim_hash) or self.blocked_streams.get( - reposted_claim_hash) or self.blocked_channels.get(claim_hash) or self.blocked_channels.get( - reposted_claim_hash) or self.blocked_channels.get(claim.channel_hash) - filtered_hash = self.filtered_streams.get(claim_hash) or self.filtered_streams.get( - reposted_claim_hash) or self.filtered_channels.get(claim_hash) or self.filtered_channels.get( - reposted_claim_hash) or self.filtered_channels.get(claim.channel_hash) - value = { - 'claim_id': claim_hash.hex(), - 'claim_name': claim.name, - 'normalized_name': claim.normalized_name, - 'tx_id': claim.tx_hash[::-1].hex(), - 'tx_num': claim.tx_num, - 'tx_nout': claim.position, - 'amount': claim.amount, - 'timestamp': self.estimate_timestamp(claim.height), - 'creation_timestamp': self.estimate_timestamp(claim.creation_height), - 'height': claim.height, - 'creation_height': claim.creation_height, - 'activation_height': claim.activation_height, - 'expiration_height': claim.expiration_height, - 'effective_amount': claim.effective_amount, - 'support_amount': claim.support_amount, - 'is_controlling': bool(claim.is_controlling), - 'last_take_over_height': claim.last_takeover_height, - 'short_url': claim.short_url, - 'canonical_url': claim.canonical_url, - 'title': None if not metadata.is_stream else metadata.stream.title, - 'author': None if not metadata.is_stream else metadata.stream.author, - 'description': None if not metadata.is_stream else metadata.stream.description, - 'claim_type': CLAIM_TYPES[metadata.claim_type], - 'has_source': reposted_has_source if metadata.is_repost else ( - False if not metadata.is_stream else metadata.stream.has_source), - 'sd_hash': metadata.stream.source.sd_hash if metadata.is_stream and metadata.stream.has_source else None, - 'stream_type': STREAM_TYPES[guess_stream_type(metadata.stream.source.media_type)] - if metadata.is_stream and metadata.stream.has_source - else reposted_stream_type if metadata.is_repost else 0, - 'media_type': metadata.stream.source.media_type - if metadata.is_stream else reposted_media_type if metadata.is_repost else None, - 'fee_amount': fee_amount if not metadata.is_repost else reposted_fee_amount, - 'fee_currency': metadata.stream.fee.currency - if metadata.is_stream else reposted_fee_currency if metadata.is_repost else None, - 'repost_count': self.get_reposted_count(claim_hash), - 'reposted_claim_id': None if not reposted_claim_hash else reposted_claim_hash.hex(), - 'reposted_claim_type': reposted_claim_type, - 'reposted_has_source': reposted_has_source, - 'channel_id': None if not metadata.is_signed else metadata.signing_channel_hash[::-1].hex(), - 'public_key_id': None if not metadata.is_channel else - self.ledger.public_key_to_address(metadata.channel.public_key_bytes), - 'signature': (metadata.signature or b'').hex() or None, - # 'signature_digest': metadata.signature, - 'is_signature_valid': bool(claim.signature_valid), - 'tags': tags, - 'languages': languages, - 'censor_type': Censor.RESOLVE if blocked_hash else Censor.SEARCH if filtered_hash else Censor.NOT_CENSORED, - 'censoring_channel_id': (blocked_hash or filtered_hash or b'').hex() or None, - 'claims_in_channel': None if not metadata.is_channel else self.get_claims_in_channel_count(claim_hash) - } - - if metadata.is_repost and reposted_duration is not None: - value['duration'] = reposted_duration - elif metadata.is_stream and (metadata.stream.video.duration or metadata.stream.audio.duration): - value['duration'] = metadata.stream.video.duration or metadata.stream.audio.duration - if metadata.is_stream: - value['release_time'] = metadata.stream.release_time or value['creation_timestamp'] - elif metadata.is_repost or metadata.is_collection: - value['release_time'] = value['creation_timestamp'] - return value - - async def all_claims_producer(self, batch_size=500_000): - batch = [] - if self.env.cache_all_claim_txos: - claim_iterator = self.claim_to_txo.items() - else: - claim_iterator = map(lambda item: (item[0].claim_hash, item[1]), self.prefix_db.claim_to_txo.iterate()) - - for claim_hash, claim_txo in claim_iterator: - # TODO: fix the couple of claim txos that dont have controlling names - if not self.prefix_db.claim_takeover.get(claim_txo.normalized_name): - continue - activation = self.get_activation(claim_txo.tx_num, claim_txo.position) - claim = self._prepare_resolve_result( - claim_txo.tx_num, claim_txo.position, claim_hash, claim_txo.name, claim_txo.root_tx_num, - claim_txo.root_position, activation, claim_txo.channel_signature_is_valid - ) - - if claim: - batch.append(claim) - if len(batch) == batch_size: - batch.sort(key=lambda x: x.tx_hash) # sort is to improve read-ahead hits - for claim in batch: - meta = self._prepare_claim_metadata(claim.claim_hash, claim) - if meta: - yield meta - batch.clear() - batch.sort(key=lambda x: x.tx_hash) - for claim in batch: - meta = self._prepare_claim_metadata(claim.claim_hash, claim) - if meta: - yield meta - batch.clear() - - def claim_producer(self, claim_hash: bytes) -> Optional[Dict]: - claim_txo = self.get_cached_claim_txo(claim_hash) - if not claim_txo: - self.logger.warning("can't sync non existent claim to ES: %s", claim_hash.hex()) - return - if not self.prefix_db.claim_takeover.get(claim_txo.normalized_name): - self.logger.warning("can't sync non existent claim to ES: %s", claim_hash.hex()) - return - activation = self.get_activation(claim_txo.tx_num, claim_txo.position) - claim = self._prepare_resolve_result( - claim_txo.tx_num, claim_txo.position, claim_hash, claim_txo.name, claim_txo.root_tx_num, - claim_txo.root_position, activation, claim_txo.channel_signature_is_valid - ) - if not claim: - return - return self._prepare_claim_metadata(claim.claim_hash, claim) - - def claims_producer(self, claim_hashes: Set[bytes]): - batch = [] - results = [] - - for claim_hash in claim_hashes: - claim_txo = self.get_cached_claim_txo(claim_hash) - if not claim_txo: - self.logger.warning("can't sync non existent claim to ES: %s", claim_hash.hex()) - continue - if not self.prefix_db.claim_takeover.get(claim_txo.normalized_name): - self.logger.warning("can't sync non existent claim to ES: %s", claim_hash.hex()) - continue - - activation = self.get_activation(claim_txo.tx_num, claim_txo.position) - claim = self._prepare_resolve_result( - claim_txo.tx_num, claim_txo.position, claim_hash, claim_txo.name, claim_txo.root_tx_num, - claim_txo.root_position, activation, claim_txo.channel_signature_is_valid - ) - if claim: - batch.append(claim) - - batch.sort(key=lambda x: x.tx_hash) - - for claim in batch: - _meta = self._prepare_claim_metadata(claim.claim_hash, claim) - if _meta: - results.append(_meta) - return results - - def get_activated_at_height(self, height: int) -> DefaultDict[PendingActivationValue, List[PendingActivationKey]]: - activated = defaultdict(list) - for k, v in self.prefix_db.pending_activation.iterate(prefix=(height,)): - activated[v].append(k) - return activated - - def get_future_activated(self, height: int) -> typing.Dict[PendingActivationValue, PendingActivationKey]: - results = {} - for k, v in self.prefix_db.pending_activation.iterate( - start=(height + 1,), stop=(height + 1 + self.coin.maxTakeoverDelay,), reverse=True): - if v not in results: - results[v] = k - return results - - async def _read_tx_counts(self): - if self.tx_counts is not None: - return - # tx_counts[N] has the cumulative number of txs at the end of - # height N. So tx_counts[0] is 1 - the genesis coinbase - - def get_counts(): - return [ - v.tx_count for v in self.prefix_db.tx_count.iterate(include_key=False, fill_cache=False) - ] - - tx_counts = await asyncio.get_event_loop().run_in_executor(None, get_counts) - assert len(tx_counts) == self.db_height + 1, f"{len(tx_counts)} vs {self.db_height + 1}" - self.tx_counts = array.array('I', tx_counts) - - if self.tx_counts: - assert self.db_tx_count == self.tx_counts[-1], \ - f"{self.db_tx_count} vs {self.tx_counts[-1]} ({len(self.tx_counts)} counts)" - else: - assert self.db_tx_count == 0 - - async def _read_claim_txos(self): - def read_claim_txos(): - set_claim_to_txo = self.claim_to_txo.__setitem__ - for k, v in self.prefix_db.claim_to_txo.iterate(fill_cache=False): - set_claim_to_txo(k.claim_hash, v) - self.txo_to_claim[v.tx_num][v.position] = k.claim_hash - - self.claim_to_txo.clear() - self.txo_to_claim.clear() - start = time.perf_counter() - self.logger.info("loading claims") - await asyncio.get_event_loop().run_in_executor(None, read_claim_txos) - ts = time.perf_counter() - start - self.logger.info("loaded %i claim txos in %ss", len(self.claim_to_txo), round(ts, 4)) - - async def _read_headers(self): - if self.headers is not None: - return - - def get_headers(): - return [ - header for header in self.prefix_db.header.iterate( - include_key=False, fill_cache=False, deserialize_value=False - ) - ] - - headers = await asyncio.get_event_loop().run_in_executor(None, get_headers) - assert len(headers) - 1 == self.db_height, f"{len(headers)} vs {self.db_height}" - self.headers = headers - - async def _read_tx_hashes(self): - def _read_tx_hashes(): - return list(self.prefix_db.tx_hash.iterate(include_key=False, fill_cache=False, deserialize_value=False)) - - self.logger.info("loading tx hashes") - self.total_transactions.clear() - self.tx_num_mapping.clear() - start = time.perf_counter() - self.total_transactions.extend(await asyncio.get_event_loop().run_in_executor(None, _read_tx_hashes)) - self.tx_num_mapping = { - tx_hash: tx_num for tx_num, tx_hash in enumerate(self.total_transactions) - } - ts = time.perf_counter() - start - self.logger.info("loaded %i tx hashes in %ss", len(self.total_transactions), round(ts, 4)) - - def estimate_timestamp(self, height: int) -> int: - if height < len(self.headers): - return struct.unpack(' bytes: - if self.env.cache_all_tx_hashes: - return self.total_transactions[tx_num] - return self.prefix_db.tx_hash.get(tx_num, deserialize_value=False) - - def get_tx_num(self, tx_hash: bytes) -> int: - if self.env.cache_all_tx_hashes: - return self.tx_num_mapping[tx_hash] - return self.prefix_db.tx_num.get(tx_hash).tx_num - - def get_cached_claim_txo(self, claim_hash: bytes) -> Optional[ClaimToTXOValue]: - if self.env.cache_all_claim_txos: - return self.claim_to_txo.get(claim_hash) - return self.prefix_db.claim_to_txo.get_pending(claim_hash) - - def get_cached_claim_hash(self, tx_num: int, position: int) -> Optional[bytes]: - if self.env.cache_all_claim_txos: - if tx_num not in self.txo_to_claim: - return - return self.txo_to_claim[tx_num].get(position, None) - v = self.prefix_db.txo_to_claim.get_pending(tx_num, position) - return None if not v else v.claim_hash - - def get_cached_claim_exists(self, tx_num: int, position: int) -> bool: - return self.get_cached_claim_hash(tx_num, position) is not None - - # Header merkle cache - - async def populate_header_merkle_cache(self): - self.logger.info('populating header merkle cache...') - length = max(1, self.db_height - self.env.reorg_limit) - start = time.time() - await self.header_mc.initialize(length) - elapsed = time.time() - start - self.logger.info(f'header merkle cache populated in {elapsed:.1f}s') - - async def header_branch_and_root(self, length, height): - return await self.header_mc.branch_and_root(length, height) - - def raw_header(self, height): - """Return the binary header at the given height.""" - header, n = self.read_headers(height, 1) - if n != 1: - raise IndexError(f'height {height:,d} out of range') - return header - - def encode_headers(self, start_height, count, headers): - key = (start_height, count) - if not self.encoded_headers.get(key): - compressobj = zlib.compressobj(wbits=-15, level=1, memLevel=9) - headers = base64.b64encode(compressobj.compress(headers) + compressobj.flush()).decode() - if start_height % 1000 != 0: - return headers - self.encoded_headers[key] = headers - return self.encoded_headers.get(key) - - def read_headers(self, start_height, count) -> typing.Tuple[bytes, int]: - """Requires start_height >= 0, count >= 0. Reads as many headers as - are available starting at start_height up to count. This - would be zero if start_height is beyond self.db_height, for - example. - - Returns a (binary, n) pair where binary is the concatenated - binary headers, and n is the count of headers returned. - """ - - if start_height < 0 or count < 0: - raise DBError(f'{count:,d} headers starting at {start_height:,d} not on disk') - - disk_count = max(0, min(count, self.db_height + 1 - start_height)) - if disk_count: - return b''.join(self.headers[start_height:start_height + disk_count]), disk_count - return b'', 0 - - def fs_tx_hash(self, tx_num): - """Return a par (tx_hash, tx_height) for the given tx number. - - If the tx_height is not on disk, returns (None, tx_height).""" - tx_height = bisect_right(self.tx_counts, tx_num) - if tx_height > self.db_height: - return None, tx_height - try: - return self.get_tx_hash(tx_num), tx_height - except IndexError: - self.logger.exception( - "Failed to access a cached transaction, known bug #3142 " - "should be fixed in #3205" - ) - return None, tx_height - - def get_block_txs(self, height: int) -> List[bytes]: - return self.prefix_db.block_txs.get(height).tx_hashes - - async def get_transactions_and_merkles(self, tx_hashes: Iterable[str]): - tx_infos = {} - for tx_hash in tx_hashes: - tx_infos[tx_hash] = await asyncio.get_event_loop().run_in_executor( - None, self._get_transaction_and_merkle, tx_hash - ) - await asyncio.sleep(0) - return tx_infos - - def _get_transaction_and_merkle(self, tx_hash): - cached_tx = self._tx_and_merkle_cache.get(tx_hash) - if cached_tx: - tx, merkle = cached_tx - else: - tx_hash_bytes = bytes.fromhex(tx_hash)[::-1] - tx_num = self.prefix_db.tx_num.get(tx_hash_bytes) - tx = None - tx_height = -1 - tx_num = None if not tx_num else tx_num.tx_num - if tx_num is not None: - if self.env.cache_all_claim_txos: - fill_cache = tx_num in self.txo_to_claim and len(self.txo_to_claim[tx_num]) > 0 - else: - fill_cache = False - tx_height = bisect_right(self.tx_counts, tx_num) - tx = self.prefix_db.tx.get(tx_hash_bytes, fill_cache=fill_cache, deserialize_value=False) - if tx_height == -1: - merkle = { - 'block_height': -1 - } - else: - tx_pos = tx_num - self.tx_counts[tx_height - 1] - branch, root = self.merkle.branch_and_root( - self.get_block_txs(tx_height), tx_pos - ) - merkle = { - 'block_height': tx_height, - 'merkle': [ - hash_to_hex_str(hash) - for hash in branch - ], - 'pos': tx_pos - } - if tx_height + 10 < self.db_height: - self._tx_and_merkle_cache[tx_hash] = tx, merkle - return (None if not tx else tx.hex(), merkle) - - async def fs_block_hashes(self, height, count): - if height + count > len(self.headers): - raise DBError(f'only got {len(self.headers) - height:,d} headers starting at {height:,d}, not {count:,d}') - return [self.coin.header_hash(header) for header in self.headers[height:height + count]] - - def read_history(self, hashX: bytes, limit: int = 1000) -> List[Tuple[bytes, int]]: - txs = [] - txs_extend = txs.extend - for hist in self.prefix_db.hashX_history.iterate(prefix=(hashX,), include_key=False): - txs_extend(hist) - if len(txs) >= limit: - break - return [ - (self.get_tx_hash(tx_num), bisect_right(self.tx_counts, tx_num)) - for tx_num in txs - ] - - async def limited_history(self, hashX, *, limit=1000): - """Return an unpruned, sorted list of (tx_hash, height) tuples of - confirmed transactions that touched the address, earliest in - the blockchain first. Includes both spending and receiving - transactions. By default returns at most 1000 entries. Set - limit to None to get them all. - """ - return await asyncio.get_event_loop().run_in_executor(None, self.read_history, hashX, limit) - - # -- Undo information - - def min_undo_height(self, max_height): - """Returns a height from which we should store undo info.""" - return max_height - self.env.reorg_limit + 1 - - def apply_expiration_extension_fork(self): - # TODO: this can't be reorged - for k, v in self.prefix_db.claim_expiration.iterate(): - self.prefix_db.claim_expiration.stage_delete(k, v) - self.prefix_db.claim_expiration.stage_put( - (bisect_right(self.tx_counts, k.tx_num) + self.coin.nExtendedClaimExpirationTime, - k.tx_num, k.position), v - ) - self.prefix_db.unsafe_commit() - - def write_db_state(self): - """Write (UTXO) state to the batch.""" - if self.db_height > 0: - self.prefix_db.db_state.stage_delete((), self.prefix_db.db_state.get()) - self.prefix_db.db_state.stage_put((), ( - self.genesis_bytes, self.db_height, self.db_tx_count, self.db_tip, - self.utxo_flush_count, int(self.wall_time), self.first_sync, self.db_version, - self.hist_flush_count, self.hist_comp_flush_count, self.hist_comp_cursor, - self.es_sync_height - ) - ) - - def read_db_state(self): - state = self.prefix_db.db_state.get() - - if not state: - self.db_height = -1 - self.db_tx_count = 0 - self.db_tip = b'\0' * 32 - self.db_version = max(self.DB_VERSIONS) - self.utxo_flush_count = 0 - self.wall_time = 0 - self.first_sync = True - self.hist_flush_count = 0 - self.hist_comp_flush_count = -1 - self.hist_comp_cursor = -1 - self.hist_db_version = max(self.DB_VERSIONS) - self.es_sync_height = 0 - else: - self.db_version = state.db_version - if self.db_version not in self.DB_VERSIONS: - raise DBError(f'your DB version is {self.db_version} but this ' - f'software only handles versions {self.DB_VERSIONS}') - # backwards compat - genesis_hash = state.genesis - if genesis_hash.hex() != self.coin.GENESIS_HASH: - raise DBError(f'DB genesis hash {genesis_hash} does not ' - f'match coin {self.coin.GENESIS_HASH}') - self.db_height = state.height - self.db_tx_count = state.tx_count - self.db_tip = state.tip - self.utxo_flush_count = state.utxo_flush_count - self.wall_time = state.wall_time - self.first_sync = state.first_sync - self.hist_flush_count = state.hist_flush_count - self.hist_comp_flush_count = state.comp_flush_count - self.hist_comp_cursor = state.comp_cursor - self.hist_db_version = state.db_version - self.es_sync_height = state.es_sync_height - - def assert_db_state(self): - state = self.prefix_db.db_state.get() - assert self.db_version == state.db_version, f"{self.db_version} != {state.db_version}" - assert self.db_height == state.height, f"{self.db_height} != {state.height}" - assert self.db_tx_count == state.tx_count, f"{self.db_tx_count} != {state.tx_count}" - assert self.db_tip == state.tip, f"{self.db_tip} != {state.tip}" - assert self.first_sync == state.first_sync, f"{self.first_sync} != {state.first_sync}" - assert self.es_sync_height == state.es_sync_height, f"{self.es_sync_height} != {state.es_sync_height}" - - async def all_utxos(self, hashX): - """Return all UTXOs for an address sorted in no particular order.""" - def read_utxos(): - utxos = [] - utxos_append = utxos.append - fs_tx_hash = self.fs_tx_hash - for k, v in self.prefix_db.utxo.iterate(prefix=(hashX, )): - tx_hash, height = fs_tx_hash(k.tx_num) - utxos_append(UTXO(k.tx_num, k.nout, tx_hash, height, v.amount)) - return utxos - - while True: - utxos = await asyncio.get_event_loop().run_in_executor(None, read_utxos) - if all(utxo.tx_hash is not None for utxo in utxos): - return utxos - self.logger.warning(f'all_utxos: tx hash not ' - f'found (reorg?), retrying...') - await sleep(0.25) - - async def lookup_utxos(self, prevouts): - def lookup_utxos(): - utxos = [] - utxo_append = utxos.append - for (tx_hash, nout) in prevouts: - tx_num_val = self.prefix_db.tx_num.get(tx_hash) - if not tx_num_val: - continue - tx_num = tx_num_val.tx_num - hashX_val = self.prefix_db.hashX_utxo.get(tx_hash[:4], tx_num, nout) - if not hashX_val: - continue - hashX = hashX_val.hashX - utxo_value = self.prefix_db.utxo.get(hashX, tx_num, nout) - if utxo_value: - utxo_append((hashX, utxo_value.amount)) - return utxos - return await asyncio.get_event_loop().run_in_executor(None, lookup_utxos) diff --git a/lbry/wallet/server/mempool.py b/lbry/wallet/server/mempool.py deleted file mode 100644 index 27d7a2352..000000000 --- a/lbry/wallet/server/mempool.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright (c) 2016-2018, Neil Booth -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - -"""Mempool handling.""" -import asyncio -import itertools -import time -import attr -import typing -from typing import Set, Optional, Callable, Awaitable -from collections import defaultdict -from prometheus_client import Histogram -from lbry.wallet.server.hash import hash_to_hex_str, hex_str_to_hash -from lbry.wallet.server.util import class_logger, chunks -from lbry.wallet.server.leveldb import UTXO -if typing.TYPE_CHECKING: - from lbry.wallet.server.session import LBRYSessionManager - - -@attr.s(slots=True) -class MemPoolTx: - prevouts = attr.ib() - # A pair is a (hashX, value) tuple - in_pairs = attr.ib() - out_pairs = attr.ib() - fee = attr.ib() - size = attr.ib() - raw_tx = attr.ib() - - -@attr.s(slots=True) -class MemPoolTxSummary: - hash = attr.ib() - fee = attr.ib() - has_unconfirmed_inputs = attr.ib() - - -NAMESPACE = "wallet_server" -HISTOGRAM_BUCKETS = ( - .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf') -) -mempool_process_time_metric = Histogram( - "processed_mempool", "Time to process mempool and notify touched addresses", - namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS -) - - -class MemPool: - def __init__(self, coin, daemon, db, state_lock: asyncio.Lock, refresh_secs=1.0, log_status_secs=120.0): - self.coin = coin - self._daemon = daemon - self._db = db - self._touched_mp = {} - self._touched_bp = {} - self._highest_block = -1 - - self.logger = class_logger(__name__, self.__class__.__name__) - self.txs = {} - self.hashXs = defaultdict(set) # None can be a key - self.cached_compact_histogram = [] - self.refresh_secs = refresh_secs - self.log_status_secs = log_status_secs - # Prevents mempool refreshes during fee histogram calculation - self.lock = state_lock - self.wakeup = asyncio.Event() - self.mempool_process_time_metric = mempool_process_time_metric - self.notified_mempool_txs = set() - self.notify_sessions: Optional[Callable[[int, Set[bytes], Set[bytes]], Awaitable[None]]] = None - - async def _logging(self, synchronized_event): - """Print regular logs of mempool stats.""" - self.logger.info('beginning processing of daemon mempool. ' - 'This can take some time...') - start = time.perf_counter() - await synchronized_event.wait() - elapsed = time.perf_counter() - start - self.logger.info(f'synced in {elapsed:.2f}s') - while True: - self.logger.info(f'{len(self.txs):,d} txs ' - f'touching {len(self.hashXs):,d} addresses') - await asyncio.sleep(self.log_status_secs) - await synchronized_event.wait() - - def _accept_transactions(self, tx_map, utxo_map, touched): - """Accept transactions in tx_map to the mempool if all their inputs - can be found in the existing mempool or a utxo_map from the - DB. - - Returns an (unprocessed tx_map, unspent utxo_map) pair. - """ - hashXs = self.hashXs - txs = self.txs - - deferred = {} - unspent = set(utxo_map) - # Try to find all prevouts so we can accept the TX - for hash, tx in tx_map.items(): - in_pairs = [] - try: - for prevout in tx.prevouts: - utxo = utxo_map.get(prevout) - if not utxo: - prev_hash, prev_index = prevout - # Raises KeyError if prev_hash is not in txs - utxo = txs[prev_hash].out_pairs[prev_index] - in_pairs.append(utxo) - except KeyError: - deferred[hash] = tx - continue - - # Spend the prevouts - unspent.difference_update(tx.prevouts) - - # Save the in_pairs, compute the fee and accept the TX - tx.in_pairs = tuple(in_pairs) - # Avoid negative fees if dealing with generation-like transactions - # because some in_parts would be missing - tx.fee = max(0, (sum(v for _, v in tx.in_pairs) - - sum(v for _, v in tx.out_pairs))) - txs[hash] = tx - - for hashX, value in itertools.chain(tx.in_pairs, tx.out_pairs): - touched.add(hashX) - hashXs[hashX].add(hash) - - return deferred, {prevout: utxo_map[prevout] for prevout in unspent} - - async def _mempool_loop(self, synchronized_event): - try: - return await self._refresh_hashes(synchronized_event) - except asyncio.CancelledError: - raise - except Exception as e: - self.logger.exception("MEMPOOL DIED") - raise e - - async def _refresh_hashes(self, synchronized_event): - """Refresh our view of the daemon's mempool.""" - while True: - start = time.perf_counter() - height = self._daemon.cached_height() - hex_hashes = await self._daemon.mempool_hashes() - if height != await self._daemon.height(): - continue - hashes = {hex_str_to_hash(hh) for hh in hex_hashes} - async with self.lock: - new_hashes = hashes.difference(self.notified_mempool_txs) - touched = await self._process_mempool(hashes) - self.notified_mempool_txs.update(new_hashes) - new_touched = { - touched_hashx for touched_hashx, txs in self.hashXs.items() if txs.intersection(new_hashes) - } - synchronized_event.set() - synchronized_event.clear() - await self.on_mempool(touched, new_touched, height) - duration = time.perf_counter() - start - self.mempool_process_time_metric.observe(duration) - try: - # we wait up to `refresh_secs` but go early if a broadcast happens (which triggers wakeup event) - await asyncio.wait_for(self.wakeup.wait(), timeout=self.refresh_secs) - except asyncio.TimeoutError: - pass - finally: - self.wakeup.clear() - - async def _process_mempool(self, all_hashes): - # Re-sync with the new set of hashes - txs = self.txs - - hashXs = self.hashXs # hashX: [tx_hash, ...] - touched = set() - - # First handle txs that have disappeared - for tx_hash in set(txs).difference(all_hashes): - tx = txs.pop(tx_hash) - tx_hashXs = {hashX for hashX, value in tx.in_pairs} - tx_hashXs.update(hashX for hashX, value in tx.out_pairs) - for hashX in tx_hashXs: - hashXs[hashX].remove(tx_hash) - if not hashXs[hashX]: - del hashXs[hashX] - touched.update(tx_hashXs) - - # Process new transactions - new_hashes = list(all_hashes.difference(txs)) - if new_hashes: - fetches = [] - for hashes in chunks(new_hashes, 200): - fetches.append(self._fetch_and_accept(hashes, all_hashes, touched)) - tx_map = {} - utxo_map = {} - for fetch in asyncio.as_completed(fetches): - deferred, unspent = await fetch - tx_map.update(deferred) - utxo_map.update(unspent) - - prior_count = 0 - # FIXME: this is not particularly efficient - while tx_map and len(tx_map) != prior_count: - prior_count = len(tx_map) - tx_map, utxo_map = self._accept_transactions(tx_map, utxo_map, touched) - - if tx_map: - self.logger.info(f'{len(tx_map)} txs dropped') - - return touched - - async def _fetch_and_accept(self, hashes, all_hashes, touched): - """Fetch a list of mempool transactions.""" - raw_txs = await self._daemon.getrawtransactions((hash_to_hex_str(hash) for hash in hashes)) - - to_hashX = self.coin.hashX_from_script - deserializer = self.coin.DESERIALIZER - - tx_map = {} - for hash, raw_tx in zip(hashes, raw_txs): - # The daemon may have evicted the tx from its - # mempool or it may have gotten in a block - if not raw_tx: - continue - tx, tx_size = deserializer(raw_tx).read_tx_and_vsize() - # Convert the inputs and outputs into (hashX, value) pairs - # Drop generation-like inputs from MemPoolTx.prevouts - txin_pairs = tuple((txin.prev_hash, txin.prev_idx) - for txin in tx.inputs - if not txin.is_generation()) - txout_pairs = tuple((to_hashX(txout.pk_script), txout.value) - for txout in tx.outputs) - tx_map[hash] = MemPoolTx(txin_pairs, None, txout_pairs, - 0, tx_size, raw_tx) - - # Determine all prevouts not in the mempool, and fetch the - # UTXO information from the database. Failed prevout lookups - # return None - concurrent database updates happen - which is - # relied upon by _accept_transactions. Ignore prevouts that are - # generation-like. - prevouts = tuple(prevout for tx in tx_map.values() - for prevout in tx.prevouts - if prevout[0] not in all_hashes) - utxos = await self._db.lookup_utxos(prevouts) - utxo_map = dict(zip(prevouts, utxos)) - - return self._accept_transactions(tx_map, utxo_map, touched) - - # - # External interface - # - - async def keep_synchronized(self, synchronized_event): - """Keep the mempool synchronized with the daemon.""" - await asyncio.wait([ - self._mempool_loop(synchronized_event), - # self._refresh_histogram(synchronized_event), - self._logging(synchronized_event) - ]) - - async def balance_delta(self, hashX): - """Return the unconfirmed amount in the mempool for hashX. - - Can be positive or negative. - """ - value = 0 - if hashX in self.hashXs: - for hash in self.hashXs[hashX]: - tx = self.txs[hash] - value -= sum(v for h168, v in tx.in_pairs if h168 == hashX) - value += sum(v for h168, v in tx.out_pairs if h168 == hashX) - return value - - def compact_fee_histogram(self): - """Return a compact fee histogram of the current mempool.""" - return self.cached_compact_histogram - - async def potential_spends(self, hashX): - """Return a set of (prev_hash, prev_idx) pairs from mempool - transactions that touch hashX. - - None, some or all of these may be spends of the hashX, but all - actual spends of it (in the DB or mempool) will be included. - """ - result = set() - for tx_hash in self.hashXs.get(hashX, ()): - tx = self.txs[tx_hash] - result.update(tx.prevouts) - return result - - def transaction_summaries(self, hashX): - """Return a list of MemPoolTxSummary objects for the hashX.""" - result = [] - for tx_hash in self.hashXs.get(hashX, ()): - tx = self.txs[tx_hash] - has_ui = any(hash in self.txs for hash, idx in tx.prevouts) - result.append(MemPoolTxSummary(tx_hash, tx.fee, has_ui)) - return result - - async def unordered_UTXOs(self, hashX): - """Return an unordered list of UTXO named tuples from mempool - transactions that pay to hashX. - - This does not consider if any other mempool transactions spend - the outputs. - """ - utxos = [] - for tx_hash in self.hashXs.get(hashX, ()): - tx = self.txs.get(tx_hash) - for pos, (hX, value) in enumerate(tx.out_pairs): - if hX == hashX: - utxos.append(UTXO(-1, pos, tx_hash, 0, value)) - return utxos - - def get_mempool_height(self, tx_hash): - # Height Progression - # -2: not broadcast - # -1: in mempool but has unconfirmed inputs - # 0: in mempool and all inputs confirmed - # +num: confirmed in a specific block (height) - if tx_hash not in self.txs: - return -2 - tx = self.txs[tx_hash] - unspent_inputs = sum(1 if hash in self.txs else 0 for hash, idx in tx.prevouts) - if unspent_inputs: - return -1 - return 0 - - async def _maybe_notify(self, new_touched): - tmp, tbp = self._touched_mp, self._touched_bp - common = set(tmp).intersection(tbp) - if common: - height = max(common) - elif tmp and max(tmp) == self._highest_block: - height = self._highest_block - else: - # Either we are processing a block and waiting for it to - # come in, or we have not yet had a mempool update for the - # new block height - return - touched = tmp.pop(height) - for old in [h for h in tmp if h <= height]: - del tmp[old] - for old in [h for h in tbp if h <= height]: - touched.update(tbp.pop(old)) - # print("notify", height, len(touched), len(new_touched)) - await self.notify_sessions(height, touched, new_touched) - - async def start(self, height, session_manager: 'LBRYSessionManager'): - self._highest_block = height - self.notify_sessions = session_manager._notify_sessions - await self.notify_sessions(height, set(), set()) - - async def on_mempool(self, touched, new_touched, height): - self._touched_mp[height] = touched - await self._maybe_notify(new_touched) - - async def on_block(self, touched, height): - self._touched_bp[height] = touched - self._highest_block = height - await self._maybe_notify(set()) diff --git a/lbry/wallet/server/merkle.py b/lbry/wallet/server/merkle.py deleted file mode 100644 index 8cf1ca08b..000000000 --- a/lbry/wallet/server/merkle.py +++ /dev/null @@ -1,258 +0,0 @@ -# Copyright (c) 2018, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# and warranty status of this software. - -"""Merkle trees, branches, proofs and roots.""" - -from asyncio import Event -from math import ceil, log - -from lbry.wallet.server.hash import double_sha256 - - -class Merkle: - """Perform merkle tree calculations on binary hashes using a given hash - function. - - If the hash count is not even, the final hash is repeated when - calculating the next merkle layer up the tree. - """ - - def __init__(self, hash_func=double_sha256): - self.hash_func = hash_func - - @staticmethod - def tree_depth(hash_count): - return Merkle.branch_length(hash_count) + 1 - - @staticmethod - def branch_length(hash_count): - """Return the length of a merkle branch given the number of hashes.""" - if not isinstance(hash_count, int): - raise TypeError('hash_count must be an integer') - if hash_count < 1: - raise ValueError('hash_count must be at least 1') - return ceil(log(hash_count, 2)) - - @staticmethod - def branch_and_root(hashes, index, length=None, hash_func=double_sha256): - """Return a (merkle branch, merkle_root) pair given hashes, and the - index of one of those hashes. - """ - hashes = list(hashes) - if not isinstance(index, int): - raise TypeError('index must be an integer') - # This also asserts hashes is not empty - if not 0 <= index < len(hashes): - raise ValueError(f"index '{index}/{len(hashes)}' out of range") - natural_length = Merkle.branch_length(len(hashes)) - if length is None: - length = natural_length - else: - if not isinstance(length, int): - raise TypeError('length must be an integer') - if length < natural_length: - raise ValueError('length out of range') - - branch = [] - for _ in range(length): - if len(hashes) & 1: - hashes.append(hashes[-1]) - branch.append(hashes[index ^ 1]) - index >>= 1 - hashes = [hash_func(hashes[n] + hashes[n + 1]) - for n in range(0, len(hashes), 2)] - - return branch, hashes[0] - - @staticmethod - def root(hashes, length=None): - """Return the merkle root of a non-empty iterable of binary hashes.""" - branch, root = Merkle.branch_and_root(hashes, 0, length) - return root - - # @staticmethod - # def root_from_proof(hash, branch, index, hash_func=double_sha256): - # """Return the merkle root given a hash, a merkle branch to it, and - # its index in the hashes array. - # - # branch is an iterable sorted deepest to shallowest. If the - # returned root is the expected value then the merkle proof is - # verified. - # - # The caller should have confirmed the length of the branch with - # branch_length(). Unfortunately this is not easily done for - # bitcoin transactions as the number of transactions in a block - # is unknown to an SPV client. - # """ - # for elt in branch: - # if index & 1: - # hash = hash_func(elt + hash) - # else: - # hash = hash_func(hash + elt) - # index >>= 1 - # if index: - # raise ValueError('index out of range for branch') - # return hash - - @staticmethod - def level(hashes, depth_higher): - """Return a level of the merkle tree of hashes the given depth - higher than the bottom row of the original tree.""" - size = 1 << depth_higher - root = Merkle.root - return [root(hashes[n: n + size], depth_higher) - for n in range(0, len(hashes), size)] - - @staticmethod - def branch_and_root_from_level(level, leaf_hashes, index, - depth_higher): - """Return a (merkle branch, merkle_root) pair when a merkle-tree has a - level cached. - - To maximally reduce the amount of data hashed in computing a - markle branch, cache a tree of depth N at level N // 2. - - level is a list of hashes in the middle of the tree (returned - by level()) - - leaf_hashes are the leaves needed to calculate a partial branch - up to level. - - depth_higher is how much higher level is than the leaves of the tree - - index is the index in the full list of hashes of the hash whose - merkle branch we want. - """ - if not isinstance(level, list): - raise TypeError("level must be a list") - if not isinstance(leaf_hashes, list): - raise TypeError("leaf_hashes must be a list") - leaf_index = (index >> depth_higher) << depth_higher - leaf_branch, leaf_root = Merkle.branch_and_root( - leaf_hashes, index - leaf_index, depth_higher) - index >>= depth_higher - level_branch, root = Merkle.branch_and_root(level, index) - # Check last so that we know index is in-range - if leaf_root != level[index]: - raise ValueError('leaf hashes inconsistent with level') - return leaf_branch + level_branch, root - - -class MerkleCache: - """A cache to calculate merkle branches efficiently.""" - - def __init__(self, merkle, source_func): - """Initialise a cache hashes taken from source_func: - - async def source_func(index, count): - ... - """ - self.merkle = merkle - self.source_func = source_func - self.length = 0 - self.depth_higher = 0 - self.initialized = Event() - - def _segment_length(self): - return 1 << self.depth_higher - - def _leaf_start(self, index): - """Given a level's depth higher and a hash index, return the leaf - index and leaf hash count needed to calculate a merkle branch. - """ - depth_higher = self.depth_higher - return (index >> depth_higher) << depth_higher - - def _level(self, hashes): - return self.merkle.level(hashes, self.depth_higher) - - async def _extend_to(self, length): - """Extend the length of the cache if necessary.""" - if length <= self.length: - return - # Start from the beginning of any final partial segment. - # Retain the value of depth_higher; in practice this is fine - start = self._leaf_start(self.length) - hashes = await self.source_func(start, length - start) - self.level[start >> self.depth_higher:] = self._level(hashes) - self.length = length - - async def _level_for(self, length): - """Return a (level_length, final_hash) pair for a truncation - of the hashes to the given length.""" - if length == self.length: - return self.level - level = self.level[:length >> self.depth_higher] - leaf_start = self._leaf_start(length) - count = min(self._segment_length(), length - leaf_start) - hashes = await self.source_func(leaf_start, count) - level += self._level(hashes) - return level - - async def initialize(self, length): - """Call to initialize the cache to a source of given length.""" - self.length = length - self.depth_higher = self.merkle.tree_depth(length) // 2 - self.level = self._level(await self.source_func(0, length)) - self.initialized.set() - - def truncate(self, length): - """Truncate the cache so it covers no more than length underlying - hashes.""" - if not isinstance(length, int): - raise TypeError('length must be an integer') - if length <= 0: - raise ValueError('length must be positive') - if length >= self.length: - return - length = self._leaf_start(length) - self.length = length - self.level[length >> self.depth_higher:] = [] - - async def branch_and_root(self, length, index): - """Return a merkle branch and root. Length is the number of - hashes used to calculate the merkle root, index is the position - of the hash to calculate the branch of. - - index must be less than length, which must be at least 1.""" - if not isinstance(length, int): - raise TypeError('length must be an integer') - if not isinstance(index, int): - raise TypeError('index must be an integer') - if length <= 0: - raise ValueError('length must be positive') - if index >= length: - raise ValueError('index must be less than length') - await self.initialized.wait() - await self._extend_to(length) - leaf_start = self._leaf_start(index) - count = min(self._segment_length(), length - leaf_start) - leaf_hashes = await self.source_func(leaf_start, count) - if length < self._segment_length(): - return self.merkle.branch_and_root(leaf_hashes, index) - level = await self._level_for(length) - return self.merkle.branch_and_root_from_level( - level, leaf_hashes, index, self.depth_higher) diff --git a/lbry/wallet/server/metrics.py b/lbry/wallet/server/metrics.py deleted file mode 100644 index f1ee4e5d1..000000000 --- a/lbry/wallet/server/metrics.py +++ /dev/null @@ -1,135 +0,0 @@ -import time -import math -from typing import Tuple - - -def calculate_elapsed(start) -> int: - return int((time.perf_counter() - start) * 1000) - - -def calculate_avg_percentiles(data) -> Tuple[int, int, int, int, int, int, int, int]: - if not data: - return 0, 0, 0, 0, 0, 0, 0, 0 - data.sort() - size = len(data) - return ( - int(sum(data) / size), - data[0], - data[math.ceil(size * .05) - 1], - data[math.ceil(size * .25) - 1], - data[math.ceil(size * .50) - 1], - data[math.ceil(size * .75) - 1], - data[math.ceil(size * .95) - 1], - data[-1] - ) - - -def remove_select_list(sql) -> str: - return sql[sql.index('FROM'):] - - -class APICallMetrics: - - def __init__(self, name): - self.name = name - - # total requests received - self.receive_count = 0 - self.cache_response_count = 0 - - # millisecond timings for query based responses - self.query_response_times = [] - self.query_intrp_times = [] - self.query_error_times = [] - - self.query_python_times = [] - self.query_wait_times = [] - self.query_sql_times = [] # aggregate total of multiple SQL calls made per request - - self.individual_sql_times = [] # every SQL query run on server - - # actual queries - self.errored_queries = set() - self.interrupted_queries = set() - - def to_json(self): - return { - # total requests received - "receive_count": self.receive_count, - # sum of these is total responses made - "cache_response_count": self.cache_response_count, - "query_response_count": len(self.query_response_times), - "intrp_response_count": len(self.query_intrp_times), - "error_response_count": len(self.query_error_times), - # millisecond timings for non-cache responses - "response": calculate_avg_percentiles(self.query_response_times), - "interrupt": calculate_avg_percentiles(self.query_intrp_times), - "error": calculate_avg_percentiles(self.query_error_times), - # response, interrupt and error each also report the python, wait and sql stats: - "python": calculate_avg_percentiles(self.query_python_times), - "wait": calculate_avg_percentiles(self.query_wait_times), - "sql": calculate_avg_percentiles(self.query_sql_times), - # extended timings for individual sql executions - "individual_sql": calculate_avg_percentiles(self.individual_sql_times), - "individual_sql_count": len(self.individual_sql_times), - # actual queries - "errored_queries": list(self.errored_queries), - "interrupted_queries": list(self.interrupted_queries), - } - - def start(self): - self.receive_count += 1 - - def cache_response(self): - self.cache_response_count += 1 - - def _add_query_timings(self, request_total_time, metrics): - if metrics and 'execute_query' in metrics: - sub_process_total = metrics[self.name][0]['total'] - individual_query_times = [f['total'] for f in metrics['execute_query']] - aggregated_query_time = sum(individual_query_times) - self.individual_sql_times.extend(individual_query_times) - self.query_sql_times.append(aggregated_query_time) - self.query_python_times.append(sub_process_total - aggregated_query_time) - self.query_wait_times.append(request_total_time - sub_process_total) - - @staticmethod - def _add_queries(query_set, metrics): - if metrics and 'execute_query' in metrics: - for execute_query in metrics['execute_query']: - if 'sql' in execute_query: - query_set.add(remove_select_list(execute_query['sql'])) - - def query_response(self, start, metrics): - self.query_response_times.append(calculate_elapsed(start)) - self._add_query_timings(self.query_response_times[-1], metrics) - - def query_interrupt(self, start, metrics): - self.query_intrp_times.append(calculate_elapsed(start)) - self._add_queries(self.interrupted_queries, metrics) - self._add_query_timings(self.query_intrp_times[-1], metrics) - - def query_error(self, start, metrics): - self.query_error_times.append(calculate_elapsed(start)) - self._add_queries(self.errored_queries, metrics) - self._add_query_timings(self.query_error_times[-1], metrics) - - -class ServerLoadData: - - def __init__(self): - self._apis = {} - - def for_api(self, name) -> APICallMetrics: - if name not in self._apis: - self._apis[name] = APICallMetrics(name) - return self._apis[name] - - def to_json_and_reset(self, status): - try: - return { - 'api': {name: api.to_json() for name, api in self._apis.items()}, - 'status': status - } - finally: - self._apis = {} diff --git a/lbry/wallet/server/script.py b/lbry/wallet/server/script.py deleted file mode 100644 index ce8d5e5a1..000000000 --- a/lbry/wallet/server/script.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# and warranty status of this software. - -"""Script-related classes and functions.""" - - -from collections import namedtuple - -from lbry.wallet.server.util import unpack_le_uint16_from, unpack_le_uint32_from, \ - pack_le_uint16, pack_le_uint32 - - -class EnumError(Exception): - pass - - -class Enumeration: - - def __init__(self, name, enumList): - self.__doc__ = name - - lookup = {} - reverseLookup = {} - i = 0 - uniqueNames = set() - uniqueValues = set() - for x in enumList: - if isinstance(x, tuple): - x, i = x - if not isinstance(x, str): - raise EnumError(f"enum name {x} not a string") - if not isinstance(i, int): - raise EnumError(f"enum value {i} not an integer") - if x in uniqueNames: - raise EnumError(f"enum name {x} not unique") - if i in uniqueValues: - raise EnumError(f"enum value {i} not unique") - uniqueNames.add(x) - uniqueValues.add(i) - lookup[x] = i - reverseLookup[i] = x - i = i + 1 - self.lookup = lookup - self.reverseLookup = reverseLookup - - def __getattr__(self, attr): - result = self.lookup.get(attr) - if result is None: - raise AttributeError(f'enumeration has no member {attr}') - return result - - def whatis(self, value): - return self.reverseLookup[value] - - -class ScriptError(Exception): - """Exception used for script errors.""" - - -OpCodes = Enumeration("Opcodes", [ - ("OP_0", 0), ("OP_PUSHDATA1", 76), - "OP_PUSHDATA2", "OP_PUSHDATA4", "OP_1NEGATE", - "OP_RESERVED", - "OP_1", "OP_2", "OP_3", "OP_4", "OP_5", "OP_6", "OP_7", "OP_8", - "OP_9", "OP_10", "OP_11", "OP_12", "OP_13", "OP_14", "OP_15", "OP_16", - "OP_NOP", "OP_VER", "OP_IF", "OP_NOTIF", "OP_VERIF", "OP_VERNOTIF", - "OP_ELSE", "OP_ENDIF", "OP_VERIFY", "OP_RETURN", - "OP_TOALTSTACK", "OP_FROMALTSTACK", "OP_2DROP", "OP_2DUP", "OP_3DUP", - "OP_2OVER", "OP_2ROT", "OP_2SWAP", "OP_IFDUP", "OP_DEPTH", "OP_DROP", - "OP_DUP", "OP_NIP", "OP_OVER", "OP_PICK", "OP_ROLL", "OP_ROT", - "OP_SWAP", "OP_TUCK", - "OP_CAT", "OP_SUBSTR", "OP_LEFT", "OP_RIGHT", "OP_SIZE", - "OP_INVERT", "OP_AND", "OP_OR", "OP_XOR", "OP_EQUAL", "OP_EQUALVERIFY", - "OP_RESERVED1", "OP_RESERVED2", - "OP_1ADD", "OP_1SUB", "OP_2MUL", "OP_2DIV", "OP_NEGATE", "OP_ABS", - "OP_NOT", "OP_0NOTEQUAL", "OP_ADD", "OP_SUB", "OP_MUL", "OP_DIV", "OP_MOD", - "OP_LSHIFT", "OP_RSHIFT", "OP_BOOLAND", "OP_BOOLOR", "OP_NUMEQUAL", - "OP_NUMEQUALVERIFY", "OP_NUMNOTEQUAL", "OP_LESSTHAN", "OP_GREATERTHAN", - "OP_LESSTHANOREQUAL", "OP_GREATERTHANOREQUAL", "OP_MIN", "OP_MAX", - "OP_WITHIN", - "OP_RIPEMD160", "OP_SHA1", "OP_SHA256", "OP_HASH160", "OP_HASH256", - "OP_CODESEPARATOR", "OP_CHECKSIG", "OP_CHECKSIGVERIFY", "OP_CHECKMULTISIG", - "OP_CHECKMULTISIGVERIFY", - "OP_NOP1", - "OP_CHECKLOCKTIMEVERIFY", "OP_CHECKSEQUENCEVERIFY" -]) - - -# Paranoia to make it hard to create bad scripts -assert OpCodes.OP_DUP == 0x76 -assert OpCodes.OP_HASH160 == 0xa9 -assert OpCodes.OP_EQUAL == 0x87 -assert OpCodes.OP_EQUALVERIFY == 0x88 -assert OpCodes.OP_CHECKSIG == 0xac -assert OpCodes.OP_CHECKMULTISIG == 0xae - - -def _match_ops(ops, pattern): - if len(ops) != len(pattern): - return False - for op, pop in zip(ops, pattern): - if pop != op: - # -1 means 'data push', whose op is an (op, data) tuple - if pop == -1 and isinstance(op, tuple): - continue - return False - - return True - - -class ScriptPubKey: - """A class for handling a tx output script that gives conditions - necessary for spending. - """ - - TO_ADDRESS_OPS = [OpCodes.OP_DUP, OpCodes.OP_HASH160, -1, - OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG] - TO_P2SH_OPS = [OpCodes.OP_HASH160, -1, OpCodes.OP_EQUAL] - TO_PUBKEY_OPS = [-1, OpCodes.OP_CHECKSIG] - - PayToHandlers = namedtuple('PayToHandlers', 'address script_hash pubkey ' - 'unspendable strange') - - @classmethod - def pay_to(cls, handlers, script): - """Parse a script, invoke the appropriate handler and - return the result. - - One of the following handlers is invoked: - handlers.address(hash160) - handlers.script_hash(hash160) - handlers.pubkey(pubkey) - handlers.unspendable() - handlers.strange(script) - """ - try: - ops = Script.get_ops(script) - except ScriptError: - return handlers.unspendable() - - match = _match_ops - - if match(ops, cls.TO_ADDRESS_OPS): - return handlers.address(ops[2][-1]) - if match(ops, cls.TO_P2SH_OPS): - return handlers.script_hash(ops[1][-1]) - if match(ops, cls.TO_PUBKEY_OPS): - return handlers.pubkey(ops[0][-1]) - if ops and ops[0] == OpCodes.OP_RETURN: - return handlers.unspendable() - return handlers.strange(script) - - @classmethod - def P2SH_script(cls, hash160): - return (bytes([OpCodes.OP_HASH160]) - + Script.push_data(hash160) - + bytes([OpCodes.OP_EQUAL])) - - @classmethod - def P2PKH_script(cls, hash160): - return (bytes([OpCodes.OP_DUP, OpCodes.OP_HASH160]) - + Script.push_data(hash160) - + bytes([OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG])) - - @classmethod - def validate_pubkey(cls, pubkey, req_compressed=False): - if isinstance(pubkey, (bytes, bytearray)): - if len(pubkey) == 33 and pubkey[0] in (2, 3): - return # Compressed - if len(pubkey) == 65 and pubkey[0] == 4: - if not req_compressed: - return - raise PubKeyError('uncompressed pubkeys are invalid') - raise PubKeyError(f'invalid pubkey {pubkey}') - - @classmethod - def pubkey_script(cls, pubkey): - cls.validate_pubkey(pubkey) - return Script.push_data(pubkey) + bytes([OpCodes.OP_CHECKSIG]) - - @classmethod - def multisig_script(cls, m, pubkeys): - """Returns the script for a pay-to-multisig transaction.""" - n = len(pubkeys) - if not 1 <= m <= n <= 15: - raise ScriptError(f'{m:d} of {n:d} multisig script not possible') - for pubkey in pubkeys: - cls.validate_pubkey(pubkey, req_compressed=True) - # See https://bitcoin.org/en/developer-guide - # 2 of 3 is: OP_2 pubkey1 pubkey2 pubkey3 OP_3 OP_CHECKMULTISIG - return (bytes([OP_1 + m - 1]) - + b''.join(cls.push_data(pubkey) for pubkey in pubkeys) - + bytes([OP_1 + n - 1, OP_CHECK_MULTISIG])) - - -class Script: - - @classmethod - def get_ops(cls, script): - ops = [] - - # The unpacks or script[n] below throw on truncated scripts - try: - n = 0 - while n < len(script): - op = script[n] - n += 1 - - if op <= OpCodes.OP_PUSHDATA4: - # Raw bytes follow - if op < OpCodes.OP_PUSHDATA1: - dlen = op - elif op == OpCodes.OP_PUSHDATA1: - dlen = script[n] - n += 1 - elif op == OpCodes.OP_PUSHDATA2: - dlen, = unpack_le_uint16_from(script[n: n + 2]) - n += 2 - else: - dlen, = unpack_le_uint32_from(script[n: n + 4]) - n += 4 - if n + dlen > len(script): - raise IndexError - op = (op, script[n:n + dlen]) - n += dlen - - ops.append(op) - except Exception: - # Truncated script; e.g. tx_hash - # ebc9fa1196a59e192352d76c0f6e73167046b9d37b8302b6bb6968dfd279b767 - raise ScriptError('truncated script') - - return ops - - @classmethod - def push_data(cls, data): - """Returns the opcodes to push the data on the stack.""" - assert isinstance(data, (bytes, bytearray)) - - n = len(data) - if n < OpCodes.OP_PUSHDATA1: - return bytes([n]) + data - if n < 256: - return bytes([OpCodes.OP_PUSHDATA1, n]) + data - if n < 65536: - return bytes([OpCodes.OP_PUSHDATA2]) + pack_le_uint16(n) + data - return bytes([OpCodes.OP_PUSHDATA4]) + pack_le_uint32(n) + data - - @classmethod - def opcode_name(cls, opcode): - if OpCodes.OP_0 < opcode < OpCodes.OP_PUSHDATA1: - return f'OP_{opcode:d}' - try: - return OpCodes.whatis(opcode) - except KeyError: - return f'OP_UNKNOWN:{opcode:d}' - - @classmethod - def dump(cls, script): - opcodes, datas = cls.get_ops(script) - for opcode, data in zip(opcodes, datas): - name = cls.opcode_name(opcode) - if data is None: - print(name) - else: - print(f'{name} {data.hex()} ({len(data):d} bytes)') diff --git a/lbry/wallet/server/server.py b/lbry/wallet/server/server.py deleted file mode 100644 index 966d5c31e..000000000 --- a/lbry/wallet/server/server.py +++ /dev/null @@ -1,91 +0,0 @@ -import signal -import logging -import asyncio -from concurrent.futures.thread import ThreadPoolExecutor -import typing - -import lbry -from lbry.wallet.server.mempool import MemPool -from lbry.wallet.server.block_processor import BlockProcessor -from lbry.wallet.server.leveldb import LevelDB -from lbry.wallet.server.session import LBRYSessionManager -from lbry.prometheus import PrometheusServer - - -class Server: - - def __init__(self, env): - self.env = env - self.log = logging.getLogger(__name__).getChild(self.__class__.__name__) - self.shutdown_event = asyncio.Event() - self.cancellable_tasks = [] - - self.daemon = daemon = env.coin.DAEMON(env.coin, env.daemon_url) - self.db = db = LevelDB(env) - self.bp = bp = BlockProcessor(env, db, daemon, self.shutdown_event) - self.prometheus_server: typing.Optional[PrometheusServer] = None - - self.session_mgr = LBRYSessionManager( - env, db, bp, daemon, self.shutdown_event - ) - self._indexer_task = None - - async def start(self): - env = self.env - min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() - self.log.info(f'software version: {lbry.__version__}') - self.log.info(f'supported protocol versions: {min_str}-{max_str}') - self.log.info(f'event loop policy: {env.loop_policy}') - self.log.info(f'reorg limit is {env.reorg_limit:,d} blocks') - - await self.daemon.height() - - def _start_cancellable(run, *args): - _flag = asyncio.Event() - self.cancellable_tasks.append(asyncio.ensure_future(run(*args, _flag))) - return _flag.wait() - - await self.start_prometheus() - if self.env.udp_port: - await self.bp.status_server.start( - 0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1], self.env.country, - self.env.host, self.env.udp_port, self.env.allow_lan_udp - ) - await _start_cancellable(self.bp.fetch_and_process_blocks) - - await self.db.populate_header_merkle_cache() - await _start_cancellable(self.bp.mempool.keep_synchronized) - await _start_cancellable(self.session_mgr.serve, self.bp.mempool) - - async def stop(self): - for task in reversed(self.cancellable_tasks): - task.cancel() - await asyncio.wait(self.cancellable_tasks) - if self.prometheus_server: - await self.prometheus_server.stop() - self.prometheus_server = None - self.shutdown_event.set() - await self.daemon.close() - - def run(self): - loop = asyncio.get_event_loop() - executor = ThreadPoolExecutor(self.env.max_query_workers, thread_name_prefix='hub-worker') - loop.set_default_executor(executor) - - def __exit(): - raise SystemExit() - try: - loop.add_signal_handler(signal.SIGINT, __exit) - loop.add_signal_handler(signal.SIGTERM, __exit) - loop.run_until_complete(self.start()) - loop.run_until_complete(self.shutdown_event.wait()) - except (SystemExit, KeyboardInterrupt): - pass - finally: - loop.run_until_complete(self.stop()) - executor.shutdown(True) - - async def start_prometheus(self): - if not self.prometheus_server and self.env.prometheus_port: - self.prometheus_server = PrometheusServer() - await self.prometheus_server.start("0.0.0.0", self.env.prometheus_port) diff --git a/lbry/wallet/server/session.py b/lbry/wallet/server/session.py deleted file mode 100644 index 6218b3992..000000000 --- a/lbry/wallet/server/session.py +++ /dev/null @@ -1,1563 +0,0 @@ -import os -import ssl -import math -import time -import codecs -import typing -import asyncio -import logging -import itertools -import collections - -from asyncio import Event, sleep -from collections import defaultdict -from functools import partial - -from elasticsearch import ConnectionTimeout -from prometheus_client import Counter, Info, Histogram, Gauge - -import lbry -from lbry.error import ResolveCensoredError, TooManyClaimSearchParametersError -from lbry.build_info import BUILD, COMMIT_HASH, DOCKER_TAG -from lbry.schema.result import Outputs -from lbry.wallet.server.block_processor import BlockProcessor -from lbry.wallet.server.leveldb import LevelDB -from lbry.wallet.server.websocket import AdminWebSocket -from lbry.wallet.rpc.framing import NewlineFramer - -import lbry.wallet.server.version as VERSION - -from lbry.wallet.rpc import ( - RPCSession, JSONRPCAutoDetect, JSONRPCConnection, - handler_invocation, RPCError, Request, JSONRPC, Notification, Batch -) -from lbry.wallet.server import util -from lbry.wallet.server.hash import sha256, hash_to_hex_str, hex_str_to_hash, HASHX_LEN, Base58Error -from lbry.wallet.server.daemon import DaemonError -if typing.TYPE_CHECKING: - from lbry.wallet.server.env import Env - from lbry.wallet.server.daemon import Daemon - -BAD_REQUEST = 1 -DAEMON_ERROR = 2 - -log = logging.getLogger(__name__) - - -def scripthash_to_hashX(scripthash: str) -> bytes: - try: - bin_hash = hex_str_to_hash(scripthash) - if len(bin_hash) == 32: - return bin_hash[:HASHX_LEN] - except Exception: - pass - raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash') - - -def non_negative_integer(value) -> int: - """Return param value it is or can be converted to a non-negative - integer, otherwise raise an RPCError.""" - try: - value = int(value) - if value >= 0: - return value - except ValueError: - pass - raise RPCError(BAD_REQUEST, - f'{value} should be a non-negative integer') - - -def assert_boolean(value) -> bool: - """Return param value it is boolean otherwise raise an RPCError.""" - if value in (False, True): - return value - raise RPCError(BAD_REQUEST, f'{value} should be a boolean value') - - -def assert_tx_hash(value: str) -> None: - """Raise an RPCError if the value is not a valid transaction - hash.""" - try: - if len(util.hex_to_bytes(value)) == 32: - return - except Exception: - pass - raise RPCError(BAD_REQUEST, f'{value} should be a transaction hash') - - -class Semaphores: - """For aiorpcX's semaphore handling.""" - - def __init__(self, semaphores): - self.semaphores = semaphores - self.acquired = [] - - async def __aenter__(self): - for semaphore in self.semaphores: - await semaphore.acquire() - self.acquired.append(semaphore) - - async def __aexit__(self, exc_type, exc_value, traceback): - for semaphore in self.acquired: - semaphore.release() - - -class SessionGroup: - - def __init__(self, gid: int): - self.gid = gid - # Concurrency per group - self.semaphore = asyncio.Semaphore(20) - - -NAMESPACE = "wallet_server" -HISTOGRAM_BUCKETS = ( - .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf') -) - - -class SessionManager: - """Holds global state about all sessions.""" - - version_info_metric = Info( - 'build', 'Wallet server build info (e.g. version, commit hash)', namespace=NAMESPACE - ) - version_info_metric.info({ - 'build': BUILD, - "commit": COMMIT_HASH, - "docker_tag": DOCKER_TAG, - 'version': lbry.__version__, - "min_version": util.version_string(VERSION.PROTOCOL_MIN), - "cpu_count": str(os.cpu_count()) - }) - session_count_metric = Gauge("session_count", "Number of connected client sessions", namespace=NAMESPACE, - labelnames=("version",)) - request_count_metric = Counter("requests_count", "Number of requests received", namespace=NAMESPACE, - labelnames=("method", "version")) - tx_request_count_metric = Counter("requested_transaction", "Number of transactions requested", namespace=NAMESPACE) - tx_replied_count_metric = Counter("replied_transaction", "Number of transactions responded", namespace=NAMESPACE) - urls_to_resolve_count_metric = Counter("urls_to_resolve", "Number of urls to resolve", namespace=NAMESPACE) - resolved_url_count_metric = Counter("resolved_url", "Number of resolved urls", namespace=NAMESPACE) - - interrupt_count_metric = Counter("interrupt", "Number of interrupted queries", namespace=NAMESPACE) - db_operational_error_metric = Counter( - "operational_error", "Number of queries that raised operational errors", namespace=NAMESPACE - ) - db_error_metric = Counter( - "internal_error", "Number of queries raising unexpected errors", namespace=NAMESPACE - ) - executor_time_metric = Histogram( - "executor_time", "SQLite executor times", namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS - ) - pending_query_metric = Gauge( - "pending_queries_count", "Number of pending and running sqlite queries", namespace=NAMESPACE - ) - - client_version_metric = Counter( - "clients", "Number of connections received per client version", - namespace=NAMESPACE, labelnames=("version",) - ) - address_history_metric = Histogram( - "address_history", "Time to fetch an address history", - namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS - ) - notifications_in_flight_metric = Gauge( - "notifications_in_flight", "Count of notifications in flight", - namespace=NAMESPACE - ) - notifications_sent_metric = Histogram( - "notifications_sent", "Time to send an address notification", - namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS - ) - - def __init__(self, env: 'Env', db: LevelDB, bp: BlockProcessor, daemon: 'Daemon', shutdown_event: asyncio.Event): - env.max_send = max(350000, env.max_send) - self.env = env - self.db = db - self.bp = bp - self.daemon = daemon - self.mempool = bp.mempool - self.shutdown_event = shutdown_event - self.logger = util.class_logger(__name__, self.__class__.__name__) - self.servers: typing.Dict[str, asyncio.AbstractServer] = {} - self.sessions: typing.Dict[int, 'SessionBase'] = {} - self.hashx_subscriptions_by_session: typing.DefaultDict[str, typing.Set[int]] = defaultdict(set) - self.mempool_statuses = {} - self.cur_group = SessionGroup(0) - self.txs_sent = 0 - self.start_time = time.time() - self.history_cache = self.bp.history_cache - self.notified_height: typing.Optional[int] = None - # Cache some idea of room to avoid recounting on each subscription - self.subs_room = 0 - - self.session_event = Event() - - async def _start_server(self, kind, *args, **kw_args): - loop = asyncio.get_event_loop() - if kind == 'RPC': - protocol_class = LocalRPC - else: - protocol_class = self.env.coin.SESSIONCLS - protocol_factory = partial(protocol_class, self, self.db, - self.mempool, kind) - - host, port = args[:2] - try: - self.servers[kind] = await loop.create_server(protocol_factory, *args, **kw_args) - except OSError as e: # don't suppress CancelledError - self.logger.error(f'{kind} server failed to listen on {host}:' - f'{port:d} :{e!r}') - else: - self.logger.info(f'{kind} server listening on {host}:{port:d}') - - async def _start_external_servers(self): - """Start listening on TCP and SSL ports, but only if the respective - port was given in the environment. - """ - env = self.env - host = env.cs_host(for_rpc=False) - if env.tcp_port is not None: - await self._start_server('TCP', host, env.tcp_port) - if env.ssl_port is not None: - sslc = ssl.SSLContext(ssl.PROTOCOL_TLS) - sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile) - await self._start_server('SSL', host, env.ssl_port, ssl=sslc) - - async def _close_servers(self, kinds): - """Close the servers of the given kinds (TCP etc.).""" - if kinds: - self.logger.info('closing down {} listening servers' - .format(', '.join(kinds))) - for kind in kinds: - server = self.servers.pop(kind, None) - if server: - server.close() - await server.wait_closed() - - async def _manage_servers(self): - paused = False - max_sessions = self.env.max_sessions - low_watermark = int(max_sessions * 0.95) - while True: - await self.session_event.wait() - self.session_event.clear() - if not paused and len(self.sessions) >= max_sessions: - self.bp.status_server.set_unavailable() - self.logger.info(f'maximum sessions {max_sessions:,d} ' - f'reached, stopping new connections until ' - f'count drops to {low_watermark:,d}') - await self._close_servers(['TCP', 'SSL']) - paused = True - # Start listening for incoming connections if paused and - # session count has fallen - if paused and len(self.sessions) <= low_watermark: - self.bp.status_server.set_available() - self.logger.info('resuming listening for incoming connections') - await self._start_external_servers() - paused = False - - def _group_map(self): - group_map = defaultdict(list) - for session in self.sessions.values(): - group_map[session.group].append(session) - return group_map - - def _sub_count(self) -> int: - return sum(s.sub_count() for s in self.sessions.values()) - - def _lookup_session(self, session_id): - try: - session_id = int(session_id) - except Exception: - pass - else: - for session in self.sessions.values(): - if session.session_id == session_id: - return session - return None - - async def _for_each_session(self, session_ids, operation): - if not isinstance(session_ids, list): - raise RPCError(BAD_REQUEST, 'expected a list of session IDs') - - result = [] - for session_id in session_ids: - session = self._lookup_session(session_id) - if session: - result.append(await operation(session)) - else: - result.append(f'unknown session: {session_id}') - return result - - async def _clear_stale_sessions(self): - """Cut off sessions that haven't done anything for 10 minutes.""" - session_timeout = self.env.session_timeout - while True: - await sleep(session_timeout // 10) - stale_cutoff = time.perf_counter() - session_timeout - stale_sessions = [session for session in self.sessions.values() - if session.last_recv < stale_cutoff] - if stale_sessions: - text = ', '.join(str(session.session_id) - for session in stale_sessions) - self.logger.info(f'closing stale connections {text}') - # Give the sockets some time to close gracefully - if stale_sessions: - await asyncio.wait([ - session.close(force_after=session_timeout // 10) for session in stale_sessions - ]) - - # Consolidate small groups - group_map = self._group_map() - groups = [group for group, sessions in group_map.items() - if len(sessions) <= 5] # fixme: apply session cost here - if len(groups) > 1: - new_group = groups[-1] - for group in groups: - for session in group_map[group]: - session.group = new_group - - def _get_info(self): - """A summary of server state.""" - group_map = self._group_map() - method_counts = collections.defaultdict(int) - error_count = 0 - logged = 0 - paused = 0 - pending_requests = 0 - closing = 0 - - for s in self.sessions.values(): - error_count += s.errors - if s.log_me: - logged += 1 - if not s._can_send.is_set(): - paused += 1 - pending_requests += s.count_pending_items() - if s.is_closing(): - closing += 1 - for request, _ in s.connection._requests.values(): - method_counts[request.method] += 1 - return { - 'closing': closing, - 'daemon': self.daemon.logged_url(), - 'daemon_height': self.daemon.cached_height(), - 'db_height': self.db.db_height, - 'errors': error_count, - 'groups': len(group_map), - 'logged': logged, - 'paused': paused, - 'pid': os.getpid(), - 'peers': [], - 'requests': pending_requests, - 'method_counts': method_counts, - 'sessions': self.session_count(), - 'subs': self._sub_count(), - 'txs_sent': self.txs_sent, - 'uptime': util.formatted_time(time.time() - self.start_time), - 'version': lbry.__version__, - } - - def _group_data(self): - """Returned to the RPC 'groups' call.""" - result = [] - group_map = self._group_map() - for group, sessions in group_map.items(): - result.append([group.gid, - len(sessions), - sum(s.bw_charge for s in sessions), - sum(s.count_pending_items() for s in sessions), - sum(s.txs_sent for s in sessions), - sum(s.sub_count() for s in sessions), - sum(s.recv_count for s in sessions), - sum(s.recv_size for s in sessions), - sum(s.send_count for s in sessions), - sum(s.send_size for s in sessions), - ]) - return result - - async def _electrum_and_raw_headers(self, height): - raw_header = await self.raw_header(height) - electrum_header = self.env.coin.electrum_header(raw_header, height) - return electrum_header, raw_header - - async def _refresh_hsub_results(self, height): - """Refresh the cached header subscription responses to be for height, - and record that as notified_height. - """ - # Paranoia: a reorg could race and leave db_height lower - height = min(height, self.db.db_height) - electrum, raw = await self._electrum_and_raw_headers(height) - self.hsub_results = (electrum, {'hex': raw.hex(), 'height': height}) - self.notified_height = height - - # --- LocalRPC command handlers - - async def rpc_add_peer(self, real_name): - """Add a peer. - - real_name: "bch.electrumx.cash t50001 s50002" for example - """ - await self._notify_peer(real_name) - return f"peer '{real_name}' added" - - async def rpc_disconnect(self, session_ids): - """Disconnect sessions. - - session_ids: array of session IDs - """ - async def close(session): - """Close the session's transport.""" - await session.close(force_after=2) - return f'disconnected {session.session_id}' - - return await self._for_each_session(session_ids, close) - - async def rpc_log(self, session_ids): - """Toggle logging of sessions. - - session_ids: array of session IDs - """ - async def toggle_logging(session): - """Toggle logging of the session.""" - session.toggle_logging() - return f'log {session.session_id}: {session.log_me}' - - return await self._for_each_session(session_ids, toggle_logging) - - async def rpc_daemon_url(self, daemon_url): - """Replace the daemon URL.""" - daemon_url = daemon_url or self.env.daemon_url - try: - self.daemon.set_url(daemon_url) - except Exception as e: - raise RPCError(BAD_REQUEST, f'an error occurred: {e!r}') - return f'now using daemon at {self.daemon.logged_url()}' - - async def rpc_stop(self): - """Shut down the server cleanly.""" - self.shutdown_event.set() - return 'stopping' - - async def rpc_getinfo(self): - """Return summary information about the server process.""" - return self._get_info() - - async def rpc_groups(self): - """Return statistics about the session groups.""" - return self._group_data() - - async def rpc_peers(self): - """Return a list of data about server peers.""" - return self.env.peer_hubs - - async def rpc_query(self, items, limit): - """Return a list of data about server peers.""" - coin = self.env.coin - db = self.db - lines = [] - - def arg_to_hashX(arg): - try: - script = bytes.fromhex(arg) - lines.append(f'Script: {arg}') - return coin.hashX_from_script(script) - except ValueError: - pass - - try: - hashX = coin.address_to_hashX(arg) - except Base58Error as e: - lines.append(e.args[0]) - return None - lines.append(f'Address: {arg}') - return hashX - - for arg in items: - hashX = arg_to_hashX(arg) - if not hashX: - continue - n = None - history = await db.limited_history(hashX, limit=limit) - for n, (tx_hash, height) in enumerate(history): - lines.append(f'History #{n:,d}: height {height:,d} ' - f'tx_hash {hash_to_hex_str(tx_hash)}') - if n is None: - lines.append('No history found') - n = None - utxos = await db.all_utxos(hashX) - for n, utxo in enumerate(utxos, start=1): - lines.append(f'UTXO #{n:,d}: tx_hash ' - f'{hash_to_hex_str(utxo.tx_hash)} ' - f'tx_pos {utxo.tx_pos:,d} height ' - f'{utxo.height:,d} value {utxo.value:,d}') - if n == limit: - break - if n is None: - lines.append('No UTXOs found') - - balance = sum(utxo.value for utxo in utxos) - lines.append(f'Balance: {coin.decimal_value(balance):,f} ' - f'{coin.SHORTNAME}') - - return lines - - # async def rpc_reorg(self, count): - # """Force a reorg of the given number of blocks. - # - # count: number of blocks to reorg - # """ - # count = non_negative_integer(count) - # if not self.bp.force_chain_reorg(count): - # raise RPCError(BAD_REQUEST, 'still catching up with daemon') - # return f'scheduled a reorg of {count:,d} blocks' - - # --- External Interface - - async def serve(self, mempool, server_listening_event): - """Start the RPC server if enabled. When the event is triggered, - start TCP and SSL servers.""" - try: - if self.env.rpc_port is not None: - await self._start_server('RPC', self.env.cs_host(for_rpc=True), - self.env.rpc_port) - self.logger.info(f'max session count: {self.env.max_sessions:,d}') - self.logger.info(f'session timeout: ' - f'{self.env.session_timeout:,d} seconds') - self.logger.info(f'max response size {self.env.max_send:,d} bytes') - if self.env.drop_client is not None: - self.logger.info(f'drop clients matching: {self.env.drop_client.pattern}') - # Start notifications; initialize hsub_results - await mempool.start(self.db.db_height, self) - await self.start_other() - await self._start_external_servers() - server_listening_event.set() - self.bp.status_server.set_available() - # Peer discovery should start after the external servers - # because we connect to ourself - await asyncio.wait([ - self._clear_stale_sessions(), - self._manage_servers() - ]) - except Exception as err: - if not isinstance(err, asyncio.CancelledError): - log.exception("hub server died") - raise err - finally: - await self._close_servers(list(self.servers.keys())) - log.warning("disconnect %i sessions", len(self.sessions)) - if self.sessions: - await asyncio.wait([ - session.close(force_after=1) for session in self.sessions.values() - ]) - await self.stop_other() - - async def start_other(self): - pass - - async def stop_other(self): - pass - - def session_count(self) -> int: - """The number of connections that we've sent something to.""" - return len(self.sessions) - - async def daemon_request(self, method, *args): - """Catch a DaemonError and convert it to an RPCError.""" - try: - return await getattr(self.daemon, method)(*args) - except DaemonError as e: - raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None - - async def raw_header(self, height): - """Return the binary header at the given height.""" - try: - return self.db.raw_header(height) - except IndexError: - raise RPCError(BAD_REQUEST, f'height {height:,d} ' - 'out of range') from None - - async def electrum_header(self, height): - """Return the deserialized header at the given height.""" - electrum_header, _ = await self._electrum_and_raw_headers(height) - return electrum_header - - async def broadcast_transaction(self, raw_tx): - hex_hash = await self.daemon.broadcast_transaction(raw_tx) - self.txs_sent += 1 - return hex_hash - - async def limited_history(self, hashX): - """A caching layer.""" - if hashX not in self.history_cache: - # History DoS limit. Each element of history is about 99 - # bytes when encoded as JSON. This limits resource usage - # on bloated history requests, and uses a smaller divisor - # so large requests are logged before refusing them. - limit = self.env.max_send // 97 - self.history_cache[hashX] = await self.db.limited_history(hashX, limit=limit) - return self.history_cache[hashX] - - def _notify_peer(self, peer): - notify_tasks = [ - session.send_notification('blockchain.peers.subscribe', [peer]) - for session in self.sessions.values() if session.subscribe_peers - ] - if notify_tasks: - self.logger.info(f'notify {len(notify_tasks)} sessions of new peers') - asyncio.create_task(asyncio.wait(notify_tasks)) - - async def _notify_sessions(self, height, touched, new_touched): - """Notify sessions about height changes and touched addresses.""" - height_changed = height != self.notified_height - if height_changed: - await self._refresh_hsub_results(height) - - if not self.sessions: - return - - if height_changed: - header_tasks = [ - session.send_notification('blockchain.headers.subscribe', (self.hsub_results[session.subscribe_headers_raw], )) - for session in self.sessions.values() if session.subscribe_headers - ] - if header_tasks: - self.logger.info(f'notify {len(header_tasks)} sessions of new header') - asyncio.create_task(asyncio.wait(header_tasks)) - for hashX in touched.intersection(self.mempool_statuses.keys()): - self.mempool_statuses.pop(hashX, None) - - await asyncio.get_event_loop().run_in_executor( - self.bp._chain_executor, touched.intersection_update, self.hashx_subscriptions_by_session.keys() - ) - - if touched or new_touched or (height_changed and self.mempool_statuses): - notified_hashxs = 0 - session_hashxes_to_notify = defaultdict(list) - to_notify = touched if height_changed else new_touched - - for hashX in to_notify: - if hashX not in self.hashx_subscriptions_by_session: - continue - for session_id in self.hashx_subscriptions_by_session[hashX]: - session_hashxes_to_notify[session_id].append(hashX) - notified_hashxs += 1 - for session_id, hashXes in session_hashxes_to_notify.items(): - asyncio.create_task(self.sessions[session_id].send_history_notifications(*hashXes)) - if session_hashxes_to_notify: - self.logger.info(f'notified {len(session_hashxes_to_notify)} sessions/{notified_hashxs:,d} touched addresses') - - def add_session(self, session): - self.sessions[id(session)] = session - self.session_event.set() - gid = int(session.start_time - self.start_time) // 900 - if self.cur_group.gid != gid: - self.cur_group = SessionGroup(gid) - return self.cur_group - - def remove_session(self, session): - """Remove a session from our sessions list if there.""" - session_id = id(session) - for hashX in session.hashX_subs: - sessions = self.hashx_subscriptions_by_session[hashX] - sessions.remove(session_id) - if not sessions: - self.hashx_subscriptions_by_session.pop(hashX) - self.sessions.pop(session_id) - self.session_event.set() - - -class SessionBase(RPCSession): - """Base class of ElectrumX JSON sessions. - - Each session runs its tasks in asynchronous parallelism with other - sessions. - """ - - MAX_CHUNK_SIZE = 40960 - session_counter = itertools.count() - request_handlers: typing.Dict[str, typing.Callable] = {} - version = '0.5.7' - - def __init__(self, session_mgr, db, mempool, kind): - connection = JSONRPCConnection(JSONRPCAutoDetect) - self.env = session_mgr.env - super().__init__(connection=connection) - self.logger = util.class_logger(__name__, self.__class__.__name__) - self.session_mgr = session_mgr - self.db = db - self.mempool = mempool - self.kind = kind # 'RPC', 'TCP' etc. - self.coin = self.env.coin - self.anon_logs = self.env.anon_logs - self.txs_sent = 0 - self.log_me = False - self.daemon_request = self.session_mgr.daemon_request - # Hijack the connection so we can log messages - self._receive_message_orig = self.connection.receive_message - self.connection.receive_message = self.receive_message - - - def default_framer(self): - return NewlineFramer(self.env.max_receive) - - def peer_address_str(self, *, for_log=True): - """Returns the peer's IP address and port as a human-readable - string, respecting anon logs if the output is for a log.""" - if for_log and self.anon_logs: - return 'xx.xx.xx.xx:xx' - return super().peer_address_str() - - def receive_message(self, message): - if self.log_me: - self.logger.info(f'processing {message}') - return self._receive_message_orig(message) - - def toggle_logging(self): - self.log_me = not self.log_me - - def connection_made(self, transport): - """Handle an incoming client connection.""" - super().connection_made(transport) - self.session_id = next(self.session_counter) - context = {'conn_id': f'{self.session_id}'} - self.logger = util.ConnectionLogger(self.logger, context) - self.group = self.session_mgr.add_session(self) - self.session_mgr.session_count_metric.labels(version=self.client_version).inc() - peer_addr_str = self.peer_address_str() - self.logger.info(f'{self.kind} {peer_addr_str}, ' - f'{self.session_mgr.session_count():,d} total') - - def connection_lost(self, exc): - """Handle client disconnection.""" - super().connection_lost(exc) - self.session_mgr.remove_session(self) - self.session_mgr.session_count_metric.labels(version=self.client_version).dec() - msg = '' - if not self._can_send.is_set(): - msg += ' whilst paused' - if self.send_size >= 1024*1024: - msg += ('. Sent {:,d} bytes in {:,d} messages' - .format(self.send_size, self.send_count)) - if msg: - msg = 'disconnected' + msg - self.logger.info(msg) - - def count_pending_items(self): - return len(self.connection.pending_requests()) - - def semaphore(self): - return Semaphores([self.group.semaphore]) - - def sub_count(self): - return 0 - - async def handle_request(self, request): - """Handle an incoming request. ElectrumX doesn't receive - notifications from client sessions. - """ - self.session_mgr.request_count_metric.labels(method=request.method, version=self.client_version).inc() - if isinstance(request, Request): - handler = self.request_handlers.get(request.method) - handler = partial(handler, self) - else: - handler = None - coro = handler_invocation(handler, request)() - return await coro - - -class LBRYSessionManager(SessionManager): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.websocket = None - # self.metrics = ServerLoadData() - # self.metrics_loop = None - self.running = False - if self.env.websocket_host is not None and self.env.websocket_port is not None: - self.websocket = AdminWebSocket(self) - - # async def process_metrics(self): - # while self.running: - # data = self.metrics.to_json_and_reset({ - # 'sessions': self.session_count(), - # 'height': self.db.db_height, - # }) - # if self.websocket is not None: - # self.websocket.send_message(data) - # await asyncio.sleep(1) - - async def start_other(self): - self.running = True - if self.websocket is not None: - await self.websocket.start() - - async def stop_other(self): - self.running = False - if self.websocket is not None: - await self.websocket.stop() - - -class LBRYElectrumX(SessionBase): - """A TCP server that handles incoming Electrum connections.""" - - PROTOCOL_MIN = VERSION.PROTOCOL_MIN - PROTOCOL_MAX = VERSION.PROTOCOL_MAX - max_errors = math.inf # don't disconnect people for errors! let them happen... - session_mgr: LBRYSessionManager - version = lbry.__version__ - cached_server_features = {} - - @classmethod - def initialize_request_handlers(cls): - cls.request_handlers.update({ - 'blockchain.block.get_chunk': cls.block_get_chunk, - 'blockchain.block.get_header': cls.block_get_header, - 'blockchain.estimatefee': cls.estimatefee, - 'blockchain.relayfee': cls.relayfee, - 'blockchain.scripthash.get_balance': cls.scripthash_get_balance, - 'blockchain.scripthash.get_history': cls.scripthash_get_history, - 'blockchain.scripthash.get_mempool': cls.scripthash_get_mempool, - 'blockchain.scripthash.listunspent': cls.scripthash_listunspent, - 'blockchain.scripthash.subscribe': cls.scripthash_subscribe, - 'blockchain.transaction.broadcast': cls.transaction_broadcast, - 'blockchain.transaction.get': cls.transaction_get, - 'blockchain.transaction.get_batch': cls.transaction_get_batch, - 'blockchain.transaction.info': cls.transaction_info, - 'blockchain.transaction.get_merkle': cls.transaction_merkle, - 'server.add_peer': cls.add_peer, - 'server.banner': cls.banner, - 'server.payment_address': cls.payment_address, - 'server.donation_address': cls.donation_address, - 'server.features': cls.server_features_async, - 'server.peers.subscribe': cls.peers_subscribe, - 'server.version': cls.server_version, - 'blockchain.transaction.get_height': cls.transaction_get_height, - 'blockchain.claimtrie.search': cls.claimtrie_search, - 'blockchain.claimtrie.resolve': cls.claimtrie_resolve, - 'blockchain.claimtrie.getclaimbyid': cls.claimtrie_getclaimbyid, - # 'blockchain.claimtrie.getclaimsbyids': cls.claimtrie_getclaimsbyids, - 'blockchain.block.get_server_height': cls.get_server_height, - 'mempool.get_fee_histogram': cls.mempool_compact_histogram, - 'blockchain.block.headers': cls.block_headers, - 'server.ping': cls.ping, - 'blockchain.headers.subscribe': cls.headers_subscribe_False, - 'blockchain.address.get_balance': cls.address_get_balance, - 'blockchain.address.get_history': cls.address_get_history, - 'blockchain.address.get_mempool': cls.address_get_mempool, - 'blockchain.address.listunspent': cls.address_listunspent, - 'blockchain.address.subscribe': cls.address_subscribe, - 'blockchain.address.unsubscribe': cls.address_unsubscribe, - }) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if not LBRYElectrumX.request_handlers: - LBRYElectrumX.initialize_request_handlers() - if not LBRYElectrumX.cached_server_features: - LBRYElectrumX.set_server_features(self.env) - self.subscribe_headers = False - self.subscribe_headers_raw = False - self.subscribe_peers = False - self.connection.max_response_size = self.env.max_send - self.hashX_subs = {} - self.sv_seen = False - self.protocol_tuple = self.PROTOCOL_MIN - self.protocol_string = None - self.daemon = self.session_mgr.daemon - self.bp: BlockProcessor = self.session_mgr.bp - self.db: LevelDB = self.bp.db - - @classmethod - def protocol_min_max_strings(cls): - return [util.version_string(ver) - for ver in (cls.PROTOCOL_MIN, cls.PROTOCOL_MAX)] - - @classmethod - def set_server_features(cls, env): - """Return the server features dictionary.""" - min_str, max_str = cls.protocol_min_max_strings() - cls.cached_server_features.update({ - 'hosts': env.hosts_dict(), - 'pruning': None, - 'server_version': cls.version, - 'protocol_min': min_str, - 'protocol_max': max_str, - 'genesis_hash': env.coin.GENESIS_HASH, - 'description': env.description, - 'payment_address': env.payment_address, - 'donation_address': env.donation_address, - 'daily_fee': env.daily_fee, - 'hash_function': 'sha256', - 'trending_algorithm': 'fast_ar' - }) - - async def server_features_async(self): - return self.cached_server_features - - @classmethod - def server_version_args(cls): - """The arguments to a server.version RPC call to a peer.""" - return [cls.version, cls.protocol_min_max_strings()] - - def protocol_version_string(self): - return util.version_string(self.protocol_tuple) - - def sub_count(self): - return len(self.hashX_subs) - - async def send_history_notifications(self, *hashXes: typing.Iterable[bytes]): - notifications = [] - for hashX in hashXes: - alias = self.hashX_subs[hashX] - if len(alias) == 64: - method = 'blockchain.scripthash.subscribe' - else: - method = 'blockchain.address.subscribe' - start = time.perf_counter() - db_history = await self.session_mgr.limited_history(hashX) - mempool = self.mempool.transaction_summaries(hashX) - - status = ''.join(f'{hash_to_hex_str(tx_hash)}:' - f'{height:d}:' - for tx_hash, height in db_history) - status += ''.join(f'{hash_to_hex_str(tx.hash)}:' - f'{-tx.has_unconfirmed_inputs:d}:' - for tx in mempool) - if status: - status = sha256(status.encode()).hex() - else: - status = None - if mempool: - self.session_mgr.mempool_statuses[hashX] = status - else: - self.session_mgr.mempool_statuses.pop(hashX, None) - - self.session_mgr.address_history_metric.observe(time.perf_counter() - start) - notifications.append((method, (alias, status))) - - start = time.perf_counter() - self.session_mgr.notifications_in_flight_metric.inc() - for method, args in notifications: - self.NOTIFICATION_COUNT.labels(method=method, version=self.client_version).inc() - try: - await self.send_notifications( - Batch([Notification(method, (alias, status)) for (method, (alias, status)) in notifications]) - ) - self.session_mgr.notifications_sent_metric.observe(time.perf_counter() - start) - finally: - self.session_mgr.notifications_in_flight_metric.dec() - - # def get_metrics_or_placeholder_for_api(self, query_name): - # """ Do not hold on to a reference to the metrics - # returned by this method past an `await` or - # you may be working with a stale metrics object. - # """ - # if self.env.track_metrics: - # # return self.session_mgr.metrics.for_api(query_name) - # else: - # return APICallMetrics(query_name) - - - # async def run_and_cache_query(self, query_name, kwargs): - # start = time.perf_counter() - # if isinstance(kwargs, dict): - # kwargs['release_time'] = format_release_time(kwargs.get('release_time')) - # try: - # self.session_mgr.pending_query_metric.inc() - # return await self.db.search_index.session_query(query_name, kwargs) - # except ConnectionTimeout: - # self.session_mgr.interrupt_count_metric.inc() - # raise RPCError(JSONRPC.QUERY_TIMEOUT, 'query timed out') - # finally: - # self.session_mgr.pending_query_metric.dec() - # self.session_mgr.executor_time_metric.observe(time.perf_counter() - start) - - async def mempool_compact_histogram(self): - return self.mempool.compact_fee_histogram() - - async def claimtrie_search(self, **kwargs): - start = time.perf_counter() - if 'release_time' in kwargs: - release_time = kwargs.pop('release_time') - release_times = release_time if isinstance(release_time, list) else [release_time] - try: - kwargs['release_time'] = [format_release_time(release_time) for release_time in release_times] - except ValueError: - pass - try: - self.session_mgr.pending_query_metric.inc() - if 'channel' in kwargs: - channel_url = kwargs.pop('channel') - _, channel_claim, _, _ = await self.db.resolve(channel_url) - if not channel_claim or isinstance(channel_claim, (ResolveCensoredError, LookupError, ValueError)): - return Outputs.to_base64([], [], 0, None, None) - kwargs['channel_id'] = channel_claim.claim_hash.hex() - return await self.db.search_index.cached_search(kwargs) - except ConnectionTimeout: - self.session_mgr.interrupt_count_metric.inc() - raise RPCError(JSONRPC.QUERY_TIMEOUT, 'query timed out') - except TooManyClaimSearchParametersError as err: - await asyncio.sleep(2) - self.logger.warning("Got an invalid query from %s, for %s with more than %d elements.", - self.peer_address()[0], err.key, err.limit) - return RPCError(1, str(err)) - finally: - self.session_mgr.pending_query_metric.dec() - self.session_mgr.executor_time_metric.observe(time.perf_counter() - start) - - async def _cached_resolve_url(self, url): - if url not in self.bp.resolve_cache: - self.bp.resolve_cache[url] = await self.loop.run_in_executor(None, self.db._resolve, url) - return self.bp.resolve_cache[url] - - async def claimtrie_resolve(self, *urls) -> str: - sorted_urls = tuple(sorted(urls)) - self.session_mgr.urls_to_resolve_count_metric.inc(len(sorted_urls)) - try: - if sorted_urls in self.bp.resolve_outputs_cache: - return self.bp.resolve_outputs_cache[sorted_urls] - rows, extra = [], [] - for url in urls: - if url not in self.bp.resolve_cache: - self.bp.resolve_cache[url] = await self._cached_resolve_url(url) - stream, channel, repost, reposted_channel = self.bp.resolve_cache[url] - if isinstance(channel, ResolveCensoredError): - rows.append(channel) - extra.append(channel.censor_row) - elif isinstance(stream, ResolveCensoredError): - rows.append(stream) - extra.append(stream.censor_row) - elif channel and not stream: - rows.append(channel) - # print("resolved channel", channel.name.decode()) - if repost: - extra.append(repost) - if reposted_channel: - extra.append(reposted_channel) - elif stream: - # print("resolved stream", stream.name.decode()) - rows.append(stream) - if channel: - # print("and channel", channel.name.decode()) - extra.append(channel) - if repost: - extra.append(repost) - if reposted_channel: - extra.append(reposted_channel) - await asyncio.sleep(0) - self.bp.resolve_outputs_cache[sorted_urls] = result = await self.loop.run_in_executor( - None, Outputs.to_base64, rows, extra, 0, None, None - ) - return result - finally: - self.session_mgr.resolved_url_count_metric.inc(len(sorted_urls)) - - async def get_server_height(self): - return self.bp.height - - async def transaction_get_height(self, tx_hash): - self.assert_tx_hash(tx_hash) - transaction_info = await self.daemon.getrawtransaction(tx_hash, True) - if transaction_info and 'hex' in transaction_info and 'confirmations' in transaction_info: - # an unconfirmed transaction from lbrycrdd will not have a 'confirmations' field - return (self.db.db_height - transaction_info['confirmations']) + 1 - elif transaction_info and 'hex' in transaction_info: - return -1 - return None - - async def claimtrie_getclaimbyid(self, claim_id): - rows = [] - extra = [] - stream = await self.db.fs_getclaimbyid(claim_id) - if not stream: - stream = LookupError(f"Could not find claim at {claim_id}") - rows.append(stream) - return Outputs.to_base64(rows, extra, 0, None, None) - - def assert_tx_hash(self, value): - '''Raise an RPCError if the value is not a valid transaction - hash.''' - try: - if len(util.hex_to_bytes(value)) == 32: - return - except Exception: - pass - raise RPCError(1, f'{value} should be a transaction hash') - - async def subscribe_headers_result(self): - """The result of a header subscription or notification.""" - return self.session_mgr.hsub_results[self.subscribe_headers_raw] - - async def _headers_subscribe(self, raw): - """Subscribe to get headers of new blocks.""" - self.subscribe_headers_raw = assert_boolean(raw) - self.subscribe_headers = True - return await self.subscribe_headers_result() - - async def headers_subscribe(self): - """Subscribe to get raw headers of new blocks.""" - return await self._headers_subscribe(True) - - async def headers_subscribe_True(self, raw=True): - """Subscribe to get headers of new blocks.""" - return await self._headers_subscribe(raw) - - async def headers_subscribe_False(self, raw=False): - """Subscribe to get headers of new blocks.""" - return await self._headers_subscribe(raw) - - async def add_peer(self, features): - """Add a peer (but only if the peer resolves to the source).""" - return await self.peer_mgr.on_add_peer(features, self.peer_address()) - - async def peers_subscribe(self): - """Return the server peers as a list of (ip, host, details) tuples.""" - self.subscribe_peers = True - return self.env.peer_hubs - - async def address_status(self, hashX): - """Returns an address status. - - Status is a hex string, but must be None if there is no history. - """ - # Note history is ordered and mempool unordered in electrum-server - # For mempool, height is -1 if it has unconfirmed inputs, otherwise 0 - - db_history = await self.session_mgr.limited_history(hashX) - mempool = self.mempool.transaction_summaries(hashX) - - status = ''.join(f'{hash_to_hex_str(tx_hash)}:' - f'{height:d}:' - for tx_hash, height in db_history) - status += ''.join(f'{hash_to_hex_str(tx.hash)}:' - f'{-tx.has_unconfirmed_inputs:d}:' - for tx in mempool) - if status: - status = sha256(status.encode()).hex() - else: - status = None - - if mempool: - self.session_mgr.mempool_statuses[hashX] = status - else: - self.session_mgr.mempool_statuses.pop(hashX, None) - return status - - async def hashX_listunspent(self, hashX): - """Return the list of UTXOs of a script hash, including mempool - effects.""" - utxos = await self.db.all_utxos(hashX) - utxos = sorted(utxos) - utxos.extend(await self.mempool.unordered_UTXOs(hashX)) - spends = await self.mempool.potential_spends(hashX) - - return [{'tx_hash': hash_to_hex_str(utxo.tx_hash), - 'tx_pos': utxo.tx_pos, - 'height': utxo.height, 'value': utxo.value} - for utxo in utxos - if (utxo.tx_hash, utxo.tx_pos) not in spends] - - async def hashX_subscribe(self, hashX, alias): - self.hashX_subs[hashX] = alias - self.session_mgr.hashx_subscriptions_by_session[hashX].add(id(self)) - return await self.address_status(hashX) - - async def hashX_unsubscribe(self, hashX, alias): - sessions = self.session_mgr.hashx_subscriptions_by_session[hashX] - sessions.remove(id(self)) - if not sessions: - self.hashX_subs.pop(hashX, None) - - def address_to_hashX(self, address): - try: - return self.coin.address_to_hashX(address) - except Exception: - pass - raise RPCError(BAD_REQUEST, f'{address} is not a valid address') - - async def address_get_balance(self, address): - """Return the confirmed and unconfirmed balance of an address.""" - hashX = self.address_to_hashX(address) - return await self.get_balance(hashX) - - async def address_get_history(self, address): - """Return the confirmed and unconfirmed history of an address.""" - hashX = self.address_to_hashX(address) - return await self.confirmed_and_unconfirmed_history(hashX) - - async def address_get_mempool(self, address): - """Return the mempool transactions touching an address.""" - hashX = self.address_to_hashX(address) - return self.unconfirmed_history(hashX) - - async def address_listunspent(self, address): - """Return the list of UTXOs of an address.""" - hashX = self.address_to_hashX(address) - return await self.hashX_listunspent(hashX) - - async def address_subscribe(self, *addresses): - """Subscribe to an address. - - address: the address to subscribe to""" - if len(addresses) > 1000: - raise RPCError(BAD_REQUEST, f'too many addresses in subscription request: {len(addresses)}') - results = [] - for address in addresses: - results.append(await self.hashX_subscribe(self.address_to_hashX(address), address)) - await asyncio.sleep(0) - return results - - async def address_unsubscribe(self, address): - """Unsubscribe an address. - - address: the address to unsubscribe""" - hashX = self.address_to_hashX(address) - return await self.hashX_unsubscribe(hashX, address) - - async def get_balance(self, hashX): - utxos = await self.db.all_utxos(hashX) - confirmed = sum(utxo.value for utxo in utxos) - unconfirmed = await self.mempool.balance_delta(hashX) - return {'confirmed': confirmed, 'unconfirmed': unconfirmed} - - async def scripthash_get_balance(self, scripthash): - """Return the confirmed and unconfirmed balance of a scripthash.""" - hashX = scripthash_to_hashX(scripthash) - return await self.get_balance(hashX) - - def unconfirmed_history(self, hashX): - # Note unconfirmed history is unordered in electrum-server - # height is -1 if it has unconfirmed inputs, otherwise 0 - return [{'tx_hash': hash_to_hex_str(tx.hash), - 'height': -tx.has_unconfirmed_inputs, - 'fee': tx.fee} - for tx in self.mempool.transaction_summaries(hashX)] - - async def confirmed_and_unconfirmed_history(self, hashX): - # Note history is ordered but unconfirmed is unordered in e-s - history = await self.session_mgr.limited_history(hashX) - conf = [{'tx_hash': hash_to_hex_str(tx_hash), 'height': height} - for tx_hash, height in history] - return conf + self.unconfirmed_history(hashX) - - async def scripthash_get_history(self, scripthash): - """Return the confirmed and unconfirmed history of a scripthash.""" - hashX = scripthash_to_hashX(scripthash) - return await self.confirmed_and_unconfirmed_history(hashX) - - async def scripthash_get_mempool(self, scripthash): - """Return the mempool transactions touching a scripthash.""" - hashX = scripthash_to_hashX(scripthash) - return self.unconfirmed_history(hashX) - - async def scripthash_listunspent(self, scripthash): - """Return the list of UTXOs of a scripthash.""" - hashX = scripthash_to_hashX(scripthash) - return await self.hashX_listunspent(hashX) - - async def scripthash_subscribe(self, scripthash): - """Subscribe to a script hash. - - scripthash: the SHA256 hash of the script to subscribe to""" - hashX = scripthash_to_hashX(scripthash) - return await self.hashX_subscribe(hashX, scripthash) - - async def _merkle_proof(self, cp_height, height): - max_height = self.db.db_height - if not height <= cp_height <= max_height: - raise RPCError(BAD_REQUEST, - f'require header height {height:,d} <= ' - f'cp_height {cp_height:,d} <= ' - f'chain height {max_height:,d}') - branch, root = await self.db.header_branch_and_root(cp_height + 1, height) - return { - 'branch': [hash_to_hex_str(elt) for elt in branch], - 'root': hash_to_hex_str(root), - } - - async def block_headers(self, start_height, count, cp_height=0, b64=False): - """Return count concatenated block headers as hex for the main chain; - starting at start_height. - - start_height and count must be non-negative integers. At most - MAX_CHUNK_SIZE headers will be returned. - """ - start_height = non_negative_integer(start_height) - count = non_negative_integer(count) - cp_height = non_negative_integer(cp_height) - - max_size = self.MAX_CHUNK_SIZE - count = min(count, max_size) - headers, count = self.db.read_headers(start_height, count) - - if b64: - headers = self.db.encode_headers(start_height, count, headers) - else: - headers = headers.hex() - result = { - 'base64' if b64 else 'hex': headers, - 'count': count, - 'max': max_size - } - if count and cp_height: - last_height = start_height + count - 1 - result.update(await self._merkle_proof(cp_height, last_height)) - return result - - async def block_get_chunk(self, index): - """Return a chunk of block headers as a hexadecimal string. - - index: the chunk index""" - index = non_negative_integer(index) - size = self.coin.CHUNK_SIZE - start_height = index * size - headers, _ = self.db.read_headers(start_height, size) - return headers.hex() - - async def block_get_header(self, height): - """The deserialized header at a given height. - - height: the header's height""" - height = non_negative_integer(height) - return await self.session_mgr.electrum_header(height) - - def is_tor(self): - """Try to detect if the connection is to a tor hidden service we are - running.""" - peername = self.peer_mgr.proxy_peername() - if not peername: - return False - peer_address = self.peer_address() - return peer_address and peer_address[0] == peername[0] - - async def replaced_banner(self, banner): - network_info = await self.daemon_request('getnetworkinfo') - ni_version = network_info['version'] - major, minor = divmod(ni_version, 1000000) - minor, revision = divmod(minor, 10000) - revision //= 100 - daemon_version = f'{major:d}.{minor:d}.{revision:d}' - for pair in [ - ('$SERVER_VERSION', self.version), - ('$DAEMON_VERSION', daemon_version), - ('$DAEMON_SUBVERSION', network_info['subversion']), - ('$PAYMENT_ADDRESS', self.env.payment_address), - ('$DONATION_ADDRESS', self.env.donation_address), - ]: - banner = banner.replace(*pair) - return banner - - async def payment_address(self): - """Return the payment address as a string, empty if there is none.""" - return self.env.payment_address - - async def donation_address(self): - """Return the donation address as a string, empty if there is none.""" - return self.env.donation_address - - async def banner(self): - """Return the server banner text.""" - banner = f'You are connected to an {self.version} server.' - banner_file = self.env.banner_file - if banner_file: - try: - with codecs.open(banner_file, 'r', 'utf-8') as f: - banner = f.read() - except Exception as e: - self.logger.error(f'reading banner file {banner_file}: {e!r}') - else: - banner = await self.replaced_banner(banner) - - return banner - - async def relayfee(self): - """The minimum fee a low-priority tx must pay in order to be accepted - to the daemon's memory pool.""" - return await self.daemon_request('relayfee') - - async def estimatefee(self, number): - """The estimated transaction fee per kilobyte to be paid for a - transaction to be included within a certain number of blocks. - - number: the number of blocks - """ - number = non_negative_integer(number) - return await self.daemon_request('estimatefee', number) - - async def ping(self): - """Serves as a connection keep-alive mechanism and for the client to - confirm the server is still responding. - """ - return None - - async def server_version(self, client_name='', protocol_version=None): - """Returns the server version as a string. - - client_name: a string identifying the client - protocol_version: the protocol version spoken by the client - """ - if self.protocol_string is not None: - return self.version, self.protocol_string - if self.sv_seen and self.protocol_tuple >= (1, 4): - raise RPCError(BAD_REQUEST, f'server.version already sent') - self.sv_seen = True - - if client_name: - client_name = str(client_name) - if self.env.drop_client is not None and \ - self.env.drop_client.match(client_name): - self.close_after_send = True - raise RPCError(BAD_REQUEST, f'unsupported client: {client_name}') - if self.client_version != client_name[:17]: - self.session_mgr.session_count_metric.labels(version=self.client_version).dec() - self.client_version = client_name[:17] - self.session_mgr.session_count_metric.labels(version=self.client_version).inc() - self.session_mgr.client_version_metric.labels(version=self.client_version).inc() - - # Find the highest common protocol version. Disconnect if - # that protocol version in unsupported. - ptuple, client_min = util.protocol_version(protocol_version, self.PROTOCOL_MIN, self.PROTOCOL_MAX) - if ptuple is None: - ptuple, client_min = util.protocol_version(protocol_version, (1, 1, 0), (1, 4, 0)) - if ptuple is None: - self.close_after_send = True - raise RPCError(BAD_REQUEST, f'unsupported protocol version: {protocol_version}') - - self.protocol_tuple = ptuple - self.protocol_string = util.version_string(ptuple) - return self.version, self.protocol_string - - async def transaction_broadcast(self, raw_tx): - """Broadcast a raw transaction to the network. - - raw_tx: the raw transaction as a hexadecimal string""" - # This returns errors as JSON RPC errors, as is natural - try: - hex_hash = await self.session_mgr.broadcast_transaction(raw_tx) - self.txs_sent += 1 - self.mempool.wakeup.set() - self.logger.info(f'sent tx: {hex_hash}') - return hex_hash - except DaemonError as e: - error, = e.args - message = error['message'] - self.logger.info(f'error sending transaction: {message}') - raise RPCError(BAD_REQUEST, 'the transaction was rejected by ' - f'network rules.\n\n{message}\n[{raw_tx}]') - - async def transaction_info(self, tx_hash: str): - return (await self.transaction_get_batch(tx_hash))[tx_hash] - - async def transaction_get_batch(self, *tx_hashes): - self.session_mgr.tx_request_count_metric.inc(len(tx_hashes)) - if len(tx_hashes) > 100: - raise RPCError(BAD_REQUEST, f'too many tx hashes in request: {len(tx_hashes)}') - for tx_hash in tx_hashes: - assert_tx_hash(tx_hash) - batch_result = await self.db.get_transactions_and_merkles(tx_hashes) - needed_merkles = {} - - for tx_hash in tx_hashes: - if tx_hash in batch_result and batch_result[tx_hash][0]: - continue - tx_hash_bytes = bytes.fromhex(tx_hash)[::-1] - mempool_tx = self.mempool.txs.get(tx_hash_bytes, None) - if mempool_tx: - raw_tx, block_hash = mempool_tx.raw_tx.hex(), None - else: - tx_info = await self.daemon_request('getrawtransaction', tx_hash, True) - raw_tx = tx_info['hex'] - block_hash = tx_info.get('blockhash') - if block_hash: - block = await self.daemon.deserialised_block(block_hash) - height = block['height'] - try: - pos = block['tx'].index(tx_hash) - except ValueError: - raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in ' - f'block {block_hash} at height {height:,d}') - needed_merkles[tx_hash] = raw_tx, block['tx'], pos, height - else: - batch_result[tx_hash] = [raw_tx, {'block_height': -1}] - - if needed_merkles: - for tx_hash, (raw_tx, block_txs, pos, block_height) in needed_merkles.items(): - batch_result[tx_hash] = raw_tx, { - 'merkle': self._get_merkle_branch(block_txs, pos), - 'pos': pos, - 'block_height': block_height - } - await asyncio.sleep(0) # heavy call, give other tasks a chance - - self.session_mgr.tx_replied_count_metric.inc(len(tx_hashes)) - return batch_result - - async def transaction_get(self, tx_hash, verbose=False): - """Return the serialized raw transaction given its hash - - tx_hash: the transaction hash as a hexadecimal string - verbose: passed on to the daemon - """ - assert_tx_hash(tx_hash) - if verbose not in (True, False): - raise RPCError(BAD_REQUEST, f'"verbose" must be a boolean') - - return await self.daemon_request('getrawtransaction', tx_hash, verbose) - - def _get_merkle_branch(self, tx_hashes, tx_pos): - """Return a merkle branch to a transaction. - - tx_hashes: ordered list of hex strings of tx hashes in a block - tx_pos: index of transaction in tx_hashes to create branch for - """ - hashes = [hex_str_to_hash(hash) for hash in tx_hashes] - branch, root = self.db.merkle.branch_and_root(hashes, tx_pos) - branch = [hash_to_hex_str(hash) for hash in branch] - return branch - - async def transaction_merkle(self, tx_hash, height): - """Return the markle branch to a confirmed transaction given its hash - and height. - - tx_hash: the transaction hash as a hexadecimal string - height: the height of the block it is in - """ - assert_tx_hash(tx_hash) - result = await self.transaction_get_batch(tx_hash) - if tx_hash not in result or result[tx_hash][1]['block_height'] <= 0: - raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in ' - f'block at height {height:,d}') - return result[tx_hash][1] - - -class LocalRPC(SessionBase): - """A local TCP RPC server session.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.client = 'RPC' - self.connection._max_response_size = 0 - - def protocol_version_string(self): - return 'RPC' - - -def get_from_possible_keys(dictionary, *keys): - for key in keys: - if key in dictionary: - return dictionary[key] - - -def format_release_time(release_time): - # round release time to 1000 so it caches better - # also set a default so we dont show claims in the future - def roundup_time(number, factor=360): - return int(1 + int(number / factor)) * factor - if isinstance(release_time, str) and len(release_time) > 0: - time_digits = ''.join(filter(str.isdigit, release_time)) - time_prefix = release_time[:-len(time_digits)] - return time_prefix + str(roundup_time(int(time_digits))) - elif isinstance(release_time, int): - return roundup_time(release_time) diff --git a/lbry/wallet/server/tx.py b/lbry/wallet/server/tx.py deleted file mode 100644 index 33cf3da3a..000000000 --- a/lbry/wallet/server/tx.py +++ /dev/null @@ -1,626 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# Copyright (c) 2017, the ElectrumX authors -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# and warranty status of this software. - -"""Transaction-related classes and functions.""" -import typing -from collections import namedtuple - -from lbry.wallet.server.hash import sha256, double_sha256, hash_to_hex_str -from lbry.wallet.server.script import OpCodes -from lbry.wallet.server.util import ( - unpack_le_int32_from, unpack_le_int64_from, unpack_le_uint16_from, - unpack_le_uint32_from, unpack_le_uint64_from, pack_le_int32, pack_varint, - pack_le_uint32, pack_le_int64, pack_varbytes, -) - -ZERO = bytes(32) -MINUS_1 = 4294967295 - - -class Tx(typing.NamedTuple): - version: int - inputs: typing.List['TxInput'] - outputs: typing.List['TxOutput'] - locktime: int - raw: bytes - - -class TxInput(typing.NamedTuple): - prev_hash: bytes - prev_idx: int - script: bytes - sequence: int - - """Class representing a transaction input.""" - def __str__(self): - script = self.script.hex() - prev_hash = hash_to_hex_str(self.prev_hash) - return (f"Input({prev_hash}, {self.prev_idx:d}, script={script}, sequence={self.sequence:d})") - - def is_generation(self): - """Test if an input is generation/coinbase like""" - return self.prev_idx == MINUS_1 and self.prev_hash == ZERO - - def serialize(self): - return b''.join(( - self.prev_hash, - pack_le_uint32(self.prev_idx), - pack_varbytes(self.script), - pack_le_uint32(self.sequence), - )) - - -class TxOutput(typing.NamedTuple): - value: int - pk_script: bytes - - def serialize(self): - return b''.join(( - pack_le_int64(self.value), - pack_varbytes(self.pk_script), - )) - - -class Deserializer: - """Deserializes blocks into transactions. - - External entry points are read_tx(), read_tx_and_hash(), - read_tx_and_vsize() and read_block(). - - This code is performance sensitive as it is executed 100s of - millions of times during sync. - """ - - TX_HASH_FN = staticmethod(double_sha256) - - def __init__(self, binary, start=0): - assert isinstance(binary, bytes) - self.binary = binary - self.binary_length = len(binary) - self.cursor = start - self.flags = 0 - - def read_tx(self): - """Return a deserialized transaction.""" - start = self.cursor - return Tx( - self._read_le_int32(), # version - self._read_inputs(), # inputs - self._read_outputs(), # outputs - self._read_le_uint32(), # locktime - self.binary[start:self.cursor], - ) - - def read_tx_and_hash(self): - """Return a (deserialized TX, tx_hash) pair. - - The hash needs to be reversed for human display; for efficiency - we process it in the natural serialized order. - """ - start = self.cursor - return self.read_tx(), self.TX_HASH_FN(self.binary[start:self.cursor]) - - def read_tx_and_vsize(self): - """Return a (deserialized TX, vsize) pair.""" - return self.read_tx(), self.binary_length - - def read_tx_block(self): - """Returns a list of (deserialized_tx, tx_hash) pairs.""" - read = self.read_tx_and_hash - # Some coins have excess data beyond the end of the transactions - return [read() for _ in range(self._read_varint())] - - def _read_inputs(self): - read_input = self._read_input - return [read_input() for i in range(self._read_varint())] - - def _read_input(self): - return TxInput( - self._read_nbytes(32), # prev_hash - self._read_le_uint32(), # prev_idx - self._read_varbytes(), # script - self._read_le_uint32() # sequence - ) - - def _read_outputs(self): - read_output = self._read_output - return [read_output() for i in range(self._read_varint())] - - def _read_output(self): - return TxOutput( - self._read_le_int64(), # value - self._read_varbytes(), # pk_script - ) - - def _read_byte(self): - cursor = self.cursor - self.cursor += 1 - return self.binary[cursor] - - def _read_nbytes(self, n): - cursor = self.cursor - self.cursor = end = cursor + n - assert self.binary_length >= end - return self.binary[cursor:end] - - def _read_varbytes(self): - return self._read_nbytes(self._read_varint()) - - def _read_varint(self): - n = self.binary[self.cursor] - self.cursor += 1 - if n < 253: - return n - if n == 253: - return self._read_le_uint16() - if n == 254: - return self._read_le_uint32() - return self._read_le_uint64() - - def _read_le_int32(self): - result, = unpack_le_int32_from(self.binary, self.cursor) - self.cursor += 4 - return result - - def _read_le_int64(self): - result, = unpack_le_int64_from(self.binary, self.cursor) - self.cursor += 8 - return result - - def _read_le_uint16(self): - result, = unpack_le_uint16_from(self.binary, self.cursor) - self.cursor += 2 - return result - - def _read_le_uint32(self): - result, = unpack_le_uint32_from(self.binary, self.cursor) - self.cursor += 4 - return result - - def _read_le_uint64(self): - result, = unpack_le_uint64_from(self.binary, self.cursor) - self.cursor += 8 - return result - - -class TxSegWit(namedtuple("Tx", "version marker flag inputs outputs " - "witness locktime raw")): - """Class representing a SegWit transaction.""" - - -class DeserializerSegWit(Deserializer): - - # https://bitcoincore.org/en/segwit_wallet_dev/#transaction-serialization - - def _read_witness(self, fields): - read_witness_field = self._read_witness_field - return [read_witness_field() for i in range(fields)] - - def _read_witness_field(self): - read_varbytes = self._read_varbytes - return [read_varbytes() for i in range(self._read_varint())] - - def _read_tx_parts(self): - """Return a (deserialized TX, tx_hash, vsize) tuple.""" - start = self.cursor - marker = self.binary[self.cursor + 4] - if marker: - tx = super().read_tx() - tx_hash = self.TX_HASH_FN(self.binary[start:self.cursor]) - return tx, tx_hash, self.binary_length - - # Ugh, this is nasty. - version = self._read_le_int32() - orig_ser = self.binary[start:self.cursor] - - marker = self._read_byte() - flag = self._read_byte() - - start = self.cursor - inputs = self._read_inputs() - outputs = self._read_outputs() - orig_ser += self.binary[start:self.cursor] - - base_size = self.cursor - start - witness = self._read_witness(len(inputs)) - - start = self.cursor - locktime = self._read_le_uint32() - orig_ser += self.binary[start:self.cursor] - vsize = (3 * base_size + self.binary_length) // 4 - - return TxSegWit(version, marker, flag, inputs, outputs, witness, - locktime, orig_ser), self.TX_HASH_FN(orig_ser), vsize - - def read_tx(self): - return self._read_tx_parts()[0] - - def read_tx_and_hash(self): - tx, tx_hash, vsize = self._read_tx_parts() - return tx, tx_hash - - def read_tx_and_vsize(self): - tx, tx_hash, vsize = self._read_tx_parts() - return tx, vsize - - -class DeserializerAuxPow(Deserializer): - VERSION_AUXPOW = (1 << 8) - - def read_header(self, height, static_header_size): - """Return the AuxPow block header bytes""" - start = self.cursor - version = self._read_le_uint32() - if version & self.VERSION_AUXPOW: - # We are going to calculate the block size then read it as bytes - self.cursor = start - self.cursor += static_header_size # Block normal header - self.read_tx() # AuxPow transaction - self.cursor += 32 # Parent block hash - merkle_size = self._read_varint() - self.cursor += 32 * merkle_size # Merkle branch - self.cursor += 4 # Index - merkle_size = self._read_varint() - self.cursor += 32 * merkle_size # Chain merkle branch - self.cursor += 4 # Chain index - self.cursor += 80 # Parent block header - header_end = self.cursor - else: - header_end = static_header_size - self.cursor = start - return self._read_nbytes(header_end) - - -class DeserializerAuxPowSegWit(DeserializerSegWit, DeserializerAuxPow): - pass - - -class DeserializerEquihash(Deserializer): - def read_header(self, height, static_header_size): - """Return the block header bytes""" - start = self.cursor - # We are going to calculate the block size then read it as bytes - self.cursor += static_header_size - solution_size = self._read_varint() - self.cursor += solution_size - header_end = self.cursor - self.cursor = start - return self._read_nbytes(header_end) - - -class DeserializerEquihashSegWit(DeserializerSegWit, DeserializerEquihash): - pass - - -class TxJoinSplit(namedtuple("Tx", "version inputs outputs locktime")): - """Class representing a JoinSplit transaction.""" - - -class DeserializerZcash(DeserializerEquihash): - def read_tx(self): - header = self._read_le_uint32() - overwintered = ((header >> 31) == 1) - if overwintered: - version = header & 0x7fffffff - self.cursor += 4 # versionGroupId - else: - version = header - - is_overwinter_v3 = version == 3 - is_sapling_v4 = version == 4 - - base_tx = TxJoinSplit( - version, - self._read_inputs(), # inputs - self._read_outputs(), # outputs - self._read_le_uint32() # locktime - ) - - if is_overwinter_v3 or is_sapling_v4: - self.cursor += 4 # expiryHeight - - has_shielded = False - if is_sapling_v4: - self.cursor += 8 # valueBalance - shielded_spend_size = self._read_varint() - self.cursor += shielded_spend_size * 384 # vShieldedSpend - shielded_output_size = self._read_varint() - self.cursor += shielded_output_size * 948 # vShieldedOutput - has_shielded = shielded_spend_size > 0 or shielded_output_size > 0 - - if base_tx.version >= 2: - joinsplit_size = self._read_varint() - if joinsplit_size > 0: - joinsplit_desc_len = 1506 + (192 if is_sapling_v4 else 296) - # JSDescription - self.cursor += joinsplit_size * joinsplit_desc_len - self.cursor += 32 # joinSplitPubKey - self.cursor += 64 # joinSplitSig - - if is_sapling_v4 and has_shielded: - self.cursor += 64 # bindingSig - - return base_tx - - -class TxTime(namedtuple("Tx", "version time inputs outputs locktime")): - """Class representing transaction that has a time field.""" - - -class DeserializerTxTime(Deserializer): - def read_tx(self): - return TxTime( - self._read_le_int32(), # version - self._read_le_uint32(), # time - self._read_inputs(), # inputs - self._read_outputs(), # outputs - self._read_le_uint32(), # locktime - ) - - -class DeserializerReddcoin(Deserializer): - def read_tx(self): - version = self._read_le_int32() - inputs = self._read_inputs() - outputs = self._read_outputs() - locktime = self._read_le_uint32() - if version > 1: - time = self._read_le_uint32() - else: - time = 0 - - return TxTime(version, time, inputs, outputs, locktime) - - -class DeserializerTxTimeAuxPow(DeserializerTxTime): - VERSION_AUXPOW = (1 << 8) - - def is_merged_block(self): - start = self.cursor - self.cursor = 0 - version = self._read_le_uint32() - self.cursor = start - if version & self.VERSION_AUXPOW: - return True - return False - - def read_header(self, height, static_header_size): - """Return the AuxPow block header bytes""" - start = self.cursor - version = self._read_le_uint32() - if version & self.VERSION_AUXPOW: - # We are going to calculate the block size then read it as bytes - self.cursor = start - self.cursor += static_header_size # Block normal header - self.read_tx() # AuxPow transaction - self.cursor += 32 # Parent block hash - merkle_size = self._read_varint() - self.cursor += 32 * merkle_size # Merkle branch - self.cursor += 4 # Index - merkle_size = self._read_varint() - self.cursor += 32 * merkle_size # Chain merkle branch - self.cursor += 4 # Chain index - self.cursor += 80 # Parent block header - header_end = self.cursor - else: - header_end = static_header_size - self.cursor = start - return self._read_nbytes(header_end) - - -class DeserializerBitcoinAtom(DeserializerSegWit): - FORK_BLOCK_HEIGHT = 505888 - - def read_header(self, height, static_header_size): - """Return the block header bytes""" - header_len = static_header_size - if height >= self.FORK_BLOCK_HEIGHT: - header_len += 4 # flags - return self._read_nbytes(header_len) - - -class DeserializerGroestlcoin(DeserializerSegWit): - TX_HASH_FN = staticmethod(sha256) - - -class TxInputTokenPay(TxInput): - """Class representing a TokenPay transaction input.""" - - OP_ANON_MARKER = 0xb9 - # 2byte marker (cpubkey + sigc + sigr) - MIN_ANON_IN_SIZE = 2 + (33 + 32 + 32) - - def _is_anon_input(self): - return (len(self.script) >= self.MIN_ANON_IN_SIZE and - self.script[0] == OpCodes.OP_RETURN and - self.script[1] == self.OP_ANON_MARKER) - - def is_generation(self): - # Transactions coming in from stealth addresses are seen by - # the blockchain as newly minted coins. The reverse, where coins - # are sent TO a stealth address, are seen by the blockchain as - # a coin burn. - if self._is_anon_input(): - return True - return super().is_generation() - - -class TxInputTokenPayStealth( - namedtuple("TxInput", "keyimage ringsize script sequence")): - """Class representing a TokenPay stealth transaction input.""" - - def __str__(self): - script = self.script.hex() - keyimage = bytes(self.keyimage).hex() - return (f"Input({keyimage}, {self.ringsize[1]:d}, script={script}, sequence={self.sequence:d})") - - def is_generation(self): - return True - - def serialize(self): - return b''.join(( - self.keyimage, - self.ringsize, - pack_varbytes(self.script), - pack_le_uint32(self.sequence), - )) - - -class DeserializerTokenPay(DeserializerTxTime): - - def _read_input(self): - txin = TxInputTokenPay( - self._read_nbytes(32), # prev_hash - self._read_le_uint32(), # prev_idx - self._read_varbytes(), # script - self._read_le_uint32(), # sequence - ) - if txin._is_anon_input(): - # Not sure if this is actually needed, and seems - # extra work for no immediate benefit, but it at - # least correctly represents a stealth input - raw = txin.serialize() - deserializer = Deserializer(raw) - txin = TxInputTokenPayStealth( - deserializer._read_nbytes(33), # keyimage - deserializer._read_nbytes(3), # ringsize - deserializer._read_varbytes(), # script - deserializer._read_le_uint32() # sequence - ) - return txin - - -# Decred -class TxInputDcr(namedtuple("TxInput", "prev_hash prev_idx tree sequence")): - """Class representing a Decred transaction input.""" - - def __str__(self): - prev_hash = hash_to_hex_str(self.prev_hash) - return (f"Input({prev_hash}, {self.prev_idx:d}, tree={self.tree}, sequence={self.sequence:d})") - - def is_generation(self): - """Test if an input is generation/coinbase like""" - return self.prev_idx == MINUS_1 and self.prev_hash == ZERO - - -class TxOutputDcr(namedtuple("TxOutput", "value version pk_script")): - """Class representing a Decred transaction output.""" - pass - - -class TxDcr(namedtuple("Tx", "version inputs outputs locktime expiry " - "witness")): - """Class representing a Decred transaction.""" - - -class DeserializerDecred(Deserializer): - @staticmethod - def blake256(data): - from blake256.blake256 import blake_hash - return blake_hash(data) - - @staticmethod - def blake256d(data): - from blake256.blake256 import blake_hash - return blake_hash(blake_hash(data)) - - def read_tx(self): - return self._read_tx_parts(produce_hash=False)[0] - - def read_tx_and_hash(self): - tx, tx_hash, vsize = self._read_tx_parts() - return tx, tx_hash - - def read_tx_and_vsize(self): - tx, tx_hash, vsize = self._read_tx_parts(produce_hash=False) - return tx, vsize - - def read_tx_block(self): - """Returns a list of (deserialized_tx, tx_hash) pairs.""" - read = self.read_tx_and_hash - txs = [read() for _ in range(self._read_varint())] - stxs = [read() for _ in range(self._read_varint())] - return txs + stxs - - def read_tx_tree(self): - """Returns a list of deserialized_tx without tx hashes.""" - read_tx = self.read_tx - return [read_tx() for _ in range(self._read_varint())] - - def _read_input(self): - return TxInputDcr( - self._read_nbytes(32), # prev_hash - self._read_le_uint32(), # prev_idx - self._read_byte(), # tree - self._read_le_uint32(), # sequence - ) - - def _read_output(self): - return TxOutputDcr( - self._read_le_int64(), # value - self._read_le_uint16(), # version - self._read_varbytes(), # pk_script - ) - - def _read_witness(self, fields): - read_witness_field = self._read_witness_field - assert fields == self._read_varint() - return [read_witness_field() for _ in range(fields)] - - def _read_witness_field(self): - value_in = self._read_le_int64() - block_height = self._read_le_uint32() - block_index = self._read_le_uint32() - script = self._read_varbytes() - return value_in, block_height, block_index, script - - def _read_tx_parts(self, produce_hash=True): - start = self.cursor - version = self._read_le_int32() - inputs = self._read_inputs() - outputs = self._read_outputs() - locktime = self._read_le_uint32() - expiry = self._read_le_uint32() - end_prefix = self.cursor - witness = self._read_witness(len(inputs)) - - if produce_hash: - # TxSerializeNoWitness << 16 == 0x10000 - no_witness_header = pack_le_uint32(0x10000 | (version & 0xffff)) - prefix_tx = no_witness_header + self.binary[start+4:end_prefix] - tx_hash = self.blake256(prefix_tx) - else: - tx_hash = None - - return TxDcr( - version, - inputs, - outputs, - locktime, - expiry, - witness - ), tx_hash, self.cursor - start diff --git a/lbry/wallet/server/util.py b/lbry/wallet/server/util.py deleted file mode 100644 index d78b23bb5..000000000 --- a/lbry/wallet/server/util.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# and warranty status of this software. - -"""Miscellaneous utility classes and functions.""" - - -import array -import inspect -from ipaddress import ip_address -import logging -import re -import sys -from collections import Container, Mapping -from struct import pack, Struct - -# Logging utilities - - -class ConnectionLogger(logging.LoggerAdapter): - """Prepends a connection identifier to a logging message.""" - def process(self, msg, kwargs): - conn_id = self.extra.get('conn_id', 'unknown') - return f'[{conn_id}] {msg}', kwargs - - -class CompactFormatter(logging.Formatter): - """Strips the module from the logger name to leave the class only.""" - def format(self, record): - record.name = record.name.rpartition('.')[-1] - return super().format(record) - - -def make_logger(name, *, handler, level): - """Return the root ElectrumX logger.""" - logger = logging.getLogger(name) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - logger.propagate = False - return logger - - -def class_logger(path, classname): - """Return a hierarchical logger for a class.""" - return logging.getLogger(path).getChild(classname) - - -# Method decorator. To be used for calculations that will always -# deliver the same result. The method cannot take any arguments -# and should be accessed as an attribute. -class cachedproperty: - - def __init__(self, f): - self.f = f - - def __get__(self, obj, type): - obj = obj or type - value = self.f(obj) - setattr(obj, self.f.__name__, value) - return value - - -def formatted_time(t, sep=' '): - """Return a number of seconds as a string in days, hours, mins and - maybe secs.""" - t = int(t) - fmts = (('{:d}d', 86400), ('{:02d}h', 3600), ('{:02d}m', 60)) - parts = [] - for fmt, n in fmts: - val = t // n - if parts or val: - parts.append(fmt.format(val)) - t %= n - if len(parts) < 3: - parts.append(f'{t:02d}s') - return sep.join(parts) - - -def deep_getsizeof(obj): - """Find the memory footprint of a Python object. - - Based on code from code.tutsplus.com: http://goo.gl/fZ0DXK - - This is a recursive function that drills down a Python object graph - like a dictionary holding nested dictionaries with lists of lists - and tuples and sets. - - The sys.getsizeof function does a shallow size of only. It counts each - object inside a container as pointer only regardless of how big it - really is. - """ - - ids = set() - - def size(o): - if id(o) in ids: - return 0 - - r = sys.getsizeof(o) - ids.add(id(o)) - - if isinstance(o, (str, bytes, bytearray, array.array)): - return r - - if isinstance(o, Mapping): - return r + sum(size(k) + size(v) for k, v in o.items()) - - if isinstance(o, Container): - return r + sum(size(x) for x in o) - - return r - - return size(obj) - - -def subclasses(base_class, strict=True): - """Return a list of subclasses of base_class in its module.""" - def select(obj): - return (inspect.isclass(obj) and issubclass(obj, base_class) and - (not strict or obj != base_class)) - - pairs = inspect.getmembers(sys.modules[base_class.__module__], select) - return [pair[1] for pair in pairs] - - -def chunks(items, size): - """Break up items, an iterable, into chunks of length size.""" - for i in range(0, len(items), size): - yield items[i: i + size] - - -def resolve_limit(limit): - if limit is None: - return -1 - assert isinstance(limit, int) and limit >= 0 - return limit - - -def bytes_to_int(be_bytes): - """Interprets a big-endian sequence of bytes as an integer""" - return int.from_bytes(be_bytes, 'big') - - -def int_to_bytes(value): - """Converts an integer to a big-endian sequence of bytes""" - return value.to_bytes((value.bit_length() + 7) // 8, 'big') - - -def increment_byte_string(bs): - """Return the lexicographically next byte string of the same length. - - Return None if there is none (when the input is all 0xff bytes).""" - for n in range(1, len(bs) + 1): - if bs[-n] != 0xff: - return bs[:-n] + bytes([bs[-n] + 1]) + bytes(n - 1) - return None - - -class LogicalFile: - """A logical binary file split across several separate files on disk.""" - - def __init__(self, prefix, digits, file_size): - digit_fmt = f'{{:0{digits:d}d}}' - self.filename_fmt = prefix + digit_fmt - self.file_size = file_size - - def read(self, start, size=-1): - """Read up to size bytes from the virtual file, starting at offset - start, and return them. - - If size is -1 all bytes are read.""" - parts = [] - while size != 0: - try: - with self.open_file(start, False) as f: - part = f.read(size) - if not part: - break - except FileNotFoundError: - break - parts.append(part) - start += len(part) - if size > 0: - size -= len(part) - return b''.join(parts) - - def write(self, start, b): - """Write the bytes-like object, b, to the underlying virtual file.""" - while b: - size = min(len(b), self.file_size - (start % self.file_size)) - with self.open_file(start, True) as f: - f.write(b if size == len(b) else b[:size]) - b = b[size:] - start += size - - def open_file(self, start, create): - """Open the virtual file and seek to start. Return a file handle. - Raise FileNotFoundError if the file does not exist and create - is False. - """ - file_num, offset = divmod(start, self.file_size) - filename = self.filename_fmt.format(file_num) - f = open_file(filename, create) - f.seek(offset) - return f - - -def open_file(filename, create=False): - """Open the file name. Return its handle.""" - try: - return open(filename, 'rb+') - except FileNotFoundError: - if create: - return open(filename, 'wb+') - raise - - -def open_truncate(filename): - """Open the file name. Return its handle.""" - return open(filename, 'wb+') - - -def address_string(address): - """Return an address as a correctly formatted string.""" - fmt = '{}:{:d}' - host, port = address - try: - host = ip_address(host) - except ValueError: - pass - else: - if host.version == 6: - fmt = '[{}]:{:d}' - return fmt.format(host, port) - -# See http://stackoverflow.com/questions/2532053/validate-a-hostname-string -# Note underscores are valid in domain names, but strictly invalid in host -# names. We ignore that distinction. - - -SEGMENT_REGEX = re.compile("(?!-)[A-Z_\\d-]{1,63}(? 255: - return False - # strip exactly one dot from the right, if present - if hostname and hostname[-1] == ".": - hostname = hostname[:-1] - return all(SEGMENT_REGEX.match(x) for x in hostname.split(".")) - - -def protocol_tuple(s): - """Converts a protocol version number, such as "1.0" to a tuple (1, 0). - - If the version number is bad, (0, ) indicating version 0 is returned.""" - try: - return tuple(int(part) for part in s.split('.')) - except Exception: - return (0, ) - - -def version_string(ptuple): - """Convert a version tuple such as (1, 2) to "1.2". - There is always at least one dot, so (1, ) becomes "1.0".""" - while len(ptuple) < 2: - ptuple += (0, ) - return '.'.join(str(p) for p in ptuple) - - -def protocol_version(client_req, min_tuple, max_tuple): - """Given a client's protocol version string, return a pair of - protocol tuples: - - (negotiated version, client min request) - - If the request is unsupported, the negotiated protocol tuple is - None. - """ - if client_req is None: - client_min = client_max = min_tuple - else: - if isinstance(client_req, list) and len(client_req) == 2: - client_min, client_max = client_req - else: - client_min = client_max = client_req - client_min = protocol_tuple(client_min) - client_max = protocol_tuple(client_max) - - result = min(client_max, max_tuple) - if result < max(client_min, min_tuple) or result == (0, ): - result = None - - return result, client_min - - -struct_le_i = Struct('H') -struct_be_I = Struct('>I') -structB = Struct('B') - -unpack_le_int32_from = struct_le_i.unpack_from -unpack_le_int64_from = struct_le_q.unpack_from -unpack_le_uint16_from = struct_le_H.unpack_from -unpack_le_uint32_from = struct_le_I.unpack_from -unpack_le_uint64_from = struct_le_Q.unpack_from -unpack_be_uint16_from = struct_be_H.unpack_from -unpack_be_uint32_from = struct_be_I.unpack_from - -unpack_be_uint64 = lambda x: int.from_bytes(x, byteorder='big') - -pack_le_int32 = struct_le_i.pack -pack_le_int64 = struct_le_q.pack -pack_le_uint16 = struct_le_H.pack -pack_le_uint32 = struct_le_I.pack -pack_be_uint64 = lambda x: x.to_bytes(8, byteorder='big') -pack_be_uint16 = lambda x: x.to_bytes(2, byteorder='big') -pack_be_uint32 = struct_be_I.pack -pack_byte = structB.pack - -hex_to_bytes = bytes.fromhex - - -def pack_varint(n): - if n < 253: - return pack_byte(n) - if n < 65536: - return pack_byte(253) + pack_le_uint16(n) - if n < 4294967296: - return pack_byte(254) + pack_le_uint32(n) - return pack_byte(255) + pack_le_uint64(n) - - -def pack_varbytes(data): - return pack_varint(len(data)) + data diff --git a/lbry/wallet/server/version.py b/lbry/wallet/server/version.py deleted file mode 100644 index d6044c3a0..000000000 --- a/lbry/wallet/server/version.py +++ /dev/null @@ -1,3 +0,0 @@ -# need this to avoid circular import -PROTOCOL_MIN = (0, 54, 0) -PROTOCOL_MAX = (0, 199, 0) diff --git a/lbry/wallet/server/websocket.py b/lbry/wallet/server/websocket.py deleted file mode 100644 index 9620918cb..000000000 --- a/lbry/wallet/server/websocket.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio -from weakref import WeakSet - -from aiohttp.web import Application, AppRunner, WebSocketResponse, TCPSite -from aiohttp.http_websocket import WSMsgType, WSCloseCode - - -class AdminWebSocket: - - def __init__(self, manager): - self.manager = manager - self.app = Application() - self.app['websockets'] = WeakSet() - self.app.router.add_get('/', self.on_connect) - self.app.on_shutdown.append(self.on_shutdown) - self.runner = AppRunner(self.app) - - async def on_status(self, _): - if not self.app['websockets']: - return - self.send_message({ - 'type': 'status', - 'height': self.manager.daemon.cached_height(), - }) - - def send_message(self, msg): - for web_socket in self.app['websockets']: - asyncio.create_task(web_socket.send_json(msg)) - - async def start(self): - await self.runner.setup() - await TCPSite(self.runner, self.manager.env.websocket_host, self.manager.env.websocket_port).start() - - async def stop(self): - await self.runner.cleanup() - - async def on_connect(self, request): - web_socket = WebSocketResponse() - await web_socket.prepare(request) - self.app['websockets'].add(web_socket) - try: - async for msg in web_socket: - if msg.type == WSMsgType.TEXT: - await self.on_status(None) - elif msg.type == WSMsgType.ERROR: - print('web socket connection closed with exception %s' % - web_socket.exception()) - finally: - self.app['websockets'].discard(web_socket) - return web_socket - - @staticmethod - async def on_shutdown(app): - for web_socket in set(app['websockets']): - await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown') diff --git a/lbry/wallet/server/udp.py b/lbry/wallet/udp.py similarity index 95% rename from lbry/wallet/server/udp.py rename to lbry/wallet/udp.py index c5520ac6b..c2064c90a 100644 --- a/lbry/wallet/server/udp.py +++ b/lbry/wallet/udp.py @@ -23,7 +23,7 @@ class SPVPing(NamedTuple): pad_bytes: bytes def encode(self): - return struct.pack(b'!lB64s', *self) + return struct.pack(b'!lB64s', *self) # pylint: disable=not-an-iterable @staticmethod def make() -> bytes: @@ -49,7 +49,7 @@ class SPVPong(NamedTuple): country: int def encode(self): - return struct.pack(PONG_ENCODING, *self) + return struct.pack(PONG_ENCODING, *self) # pylint: disable=not-an-iterable @staticmethod def encode_address(address: str): @@ -110,6 +110,7 @@ class SPVServerStatusProtocol(asyncio.DatagramProtocol): self._min_delay = 1 / throttle_reqs_per_sec self._allow_localhost = allow_localhost self._allow_lan = allow_lan + self.closed = asyncio.Event() def update_cached_response(self): self._left_cache, self._right_cache = SPVPong.make_sans_source_address( @@ -160,13 +161,16 @@ class SPVServerStatusProtocol(asyncio.DatagramProtocol): def connection_made(self, transport) -> None: self.transport = transport + self.closed.clear() def connection_lost(self, exc: Optional[Exception]) -> None: self.transport = None + self.closed.set() - def close(self): + async def close(self): if self.transport: self.transport.close() + await self.closed.wait() class StatusServer: @@ -184,9 +188,9 @@ class StatusServer: await loop.create_datagram_endpoint(lambda: self._protocol, (interface, port)) log.info("started udp status server on %s:%i", interface, port) - def stop(self): + async def stop(self): if self.is_running: - self._protocol.close() + await self._protocol.close() self._protocol = None @property diff --git a/scripts/checktrie.py b/scripts/checktrie.py index 98770963a..810267429 100644 --- a/scripts/checktrie.py +++ b/scripts/checktrie.py @@ -40,22 +40,17 @@ def checkrecord(record, expected_winner, expected_claim): async def checkcontrolling(daemon: Daemon, db: SQLDB): - records, claim_ids, names, futs = [], [], [], [] + records, names, futs = [], [], [] for record in db.get_claims('claimtrie.claim_hash as is_controlling, claim.*', is_controlling=True): records.append(record) claim_id = hex_reverted(record['claim_hash']) - claim_ids.append((claim_id,)) - names.append((record['normalized'],)) + names.append((record['normalized'], (claim_id,), "", True)) # last parameter is IncludeValues if len(names) > 50000: - futs.append(daemon._send_vector('getvalueforname', names[:])) - futs.append(daemon._send_vector('getclaimbyid', claim_ids[:])) + futs.append(daemon._send_vector('getclaimsfornamebyid', names)) names.clear() - claim_ids.clear() if names: - futs.append(daemon._send_vector('getvalueforname', names[:])) - futs.append(daemon._send_vector('getclaimbyid', claim_ids[:])) + futs.append(daemon._send_vector('getclaimsfornamebyid', names)) names.clear() - claim_ids.clear() while futs: winners, claims = futs.pop(0), futs.pop(0) diff --git a/scripts/initialize_hub_from_snapshot.sh b/scripts/initialize_hub_from_snapshot.sh index b0c09e1bc..5c7ae4e1b 100755 --- a/scripts/initialize_hub_from_snapshot.sh +++ b/scripts/initialize_hub_from_snapshot.sh @@ -1,12 +1,12 @@ #!/bin/bash -SNAPSHOT_HEIGHT="1049658" +SNAPSHOT_HEIGHT="1072108" HUB_VOLUME_PATH="/var/lib/docker/volumes/${USER}_wallet_server" ES_VOLUME_PATH="/var/lib/docker/volumes/${USER}_es01" -SNAPSHOT_TAR_NAME="wallet_server_snapshot_${SNAPSHOT_HEIGHT}.tar" -ES_SNAPSHOT_TAR_NAME="es_snapshot_${SNAPSHOT_HEIGHT}.tar" +SNAPSHOT_TAR_NAME="wallet_server_snapshot_${SNAPSHOT_HEIGHT}.tar.gz" +ES_SNAPSHOT_TAR_NAME="es_snapshot_${SNAPSHOT_HEIGHT}.tar.gz" SNAPSHOT_URL="https://snapshots.lbry.com/hub/${SNAPSHOT_TAR_NAME}" ES_SNAPSHOT_URL="https://snapshots.lbry.com/hub/${ES_SNAPSHOT_TAR_NAME}" diff --git a/setup.py b/setup.py index da749bfd9..ed59735b1 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,11 @@ BASE = os.path.dirname(__file__) with open(os.path.join(BASE, 'README.md'), encoding='utf-8') as fh: long_description = fh.read() -PLYVEL = [] -if sys.platform.startswith('linux'): - PLYVEL.append('plyvel==1.3.0') + +ROCKSDB = [] +if sys.platform.startswith('linux') or sys.platform.startswith('darwin'): + ROCKSDB.append('lbry-rocksdb==0.8.2') + setup( name=__name__, @@ -28,16 +30,14 @@ setup( entry_points={ 'console_scripts': [ 'lbrynet=lbry.extras.cli:main', - 'lbry-hub=lbry.wallet.server.cli:main', - 'orchstr8=lbry.wallet.orchstr8.cli:main', - 'lbry-hub-elastic-sync=lbry.wallet.server.db.elasticsearch.sync:run_elastic_sync' + 'orchstr8=lbry.wallet.orchstr8.cli:main' ], }, install_requires=[ - 'aiohttp==3.5.4', + 'aiohttp==3.7.4', 'aioupnp==0.0.18', 'appdirs==1.4.3', - 'certifi>=2018.11.29', + 'certifi>=2021.10.08', 'colorama==0.3.7', 'distro==1.4.0', 'base58==1.0.0', @@ -49,7 +49,7 @@ setup( 'ecdsa==0.13.3', 'pyyaml==5.3.1', 'docopt==0.6.2', - 'hachoir', + 'hachoir==3.1.2', 'multidict==4.6.1', 'coincurve==15.0.0', 'pbkdf2==1.3', @@ -57,12 +57,13 @@ setup( 'pylru==1.1.0', 'elasticsearch==7.10.1', 'grpcio==1.38.0', - 'filetype==1.0.9' - ] + PLYVEL, + 'filetype==1.0.9', + ] + ROCKSDB, extras_require={ 'torrent': ['lbry-libtorrent'], 'lint': ['pylint==2.10.0'], 'test': ['coverage'], + 'scribe': ['scribe @ git+https://github.com/lbryio/scribe.git'], }, classifiers=[ 'Framework :: AsyncIO', diff --git a/tests/integration/blockchain/test_account_commands.py b/tests/integration/blockchain/test_account_commands.py index 52eb1bfb4..9209fe0c6 100644 --- a/tests/integration/blockchain/test_account_commands.py +++ b/tests/integration/blockchain/test_account_commands.py @@ -103,7 +103,7 @@ class AccountManagement(CommandTestCase): second_account = await self.daemon.jsonrpc_account_create('second account') tx = await self.daemon.jsonrpc_account_send( - '0.05', await self.daemon.jsonrpc_address_unused(account_id=second_account.id) + '0.05', await self.daemon.jsonrpc_address_unused(account_id=second_account.id), blocking=True ) await self.confirm_tx(tx.id) await self.assertOutputAmount(['0.05', '9.949876'], utxo_list()) diff --git a/tests/integration/blockchain/test_blockchain_reorganization.py b/tests/integration/blockchain/test_blockchain_reorganization.py index 621655add..10ae90e6e 100644 --- a/tests/integration/blockchain/test_blockchain_reorganization.py +++ b/tests/integration/blockchain/test_blockchain_reorganization.py @@ -9,7 +9,7 @@ class BlockchainReorganizationTests(CommandTestCase): VERBOSITY = logging.WARN async def assertBlockHash(self, height): - bp = self.conductor.spv_node.server.bp + bp = self.conductor.spv_node.writer def get_txids(): return [ @@ -29,15 +29,16 @@ class BlockchainReorganizationTests(CommandTestCase): self.assertListEqual(block_txs, list(txs.keys()), msg='leveldb/lbrycrd transactions are of order') async def test_reorg(self): - bp = self.conductor.spv_node.server.bp + bp = self.conductor.spv_node.writer bp.reorg_count_metric.set(0) # invalidate current block, move forward 2 height = 206 self.assertEqual(self.ledger.headers.height, height) await self.assertBlockHash(height) - await self.blockchain.invalidate_block((await self.ledger.headers.hash(206)).decode()) + block_hash = (await self.ledger.headers.hash(206)).decode() + await self.blockchain.invalidate_block(block_hash) await self.blockchain.generate(2) - await self.ledger.on_header.where(lambda e: e.height == 207) + await asyncio.wait_for(self.on_header(207), 3.0) self.assertEqual(self.ledger.headers.height, 207) await self.assertBlockHash(206) await self.assertBlockHash(207) @@ -46,14 +47,14 @@ class BlockchainReorganizationTests(CommandTestCase): # invalidate current block, move forward 3 await self.blockchain.invalidate_block((await self.ledger.headers.hash(206)).decode()) await self.blockchain.generate(3) - await self.ledger.on_header.where(lambda e: e.height == 208) + await asyncio.wait_for(self.on_header(208), 3.0) self.assertEqual(self.ledger.headers.height, 208) await self.assertBlockHash(206) await self.assertBlockHash(207) await self.assertBlockHash(208) self.assertEqual(2, bp.reorg_count_metric._samples()[0][2]) await self.blockchain.generate(3) - await self.ledger.on_header.where(lambda e: e.height == 211) + await asyncio.wait_for(self.on_header(211), 3.0) await self.assertBlockHash(209) await self.assertBlockHash(210) await self.assertBlockHash(211) @@ -62,7 +63,7 @@ class BlockchainReorganizationTests(CommandTestCase): ) await self.ledger.wait(still_valid) await self.blockchain.generate(1) - await self.ledger.on_header.where(lambda e: e.height == 212) + await asyncio.wait_for(self.on_header(212), 1.0) claim_id = still_valid.outputs[0].claim_id c1 = (await self.resolve(f'still-valid#{claim_id}'))['claim_id'] c2 = (await self.resolve(f'still-valid#{claim_id[:2]}'))['claim_id'] @@ -71,7 +72,7 @@ class BlockchainReorganizationTests(CommandTestCase): abandon_tx = await self.daemon.jsonrpc_stream_abandon(claim_id=claim_id) await self.blockchain.generate(1) - await self.ledger.on_header.where(lambda e: e.height == 213) + await asyncio.wait_for(self.on_header(213), 1.0) c1 = await self.resolve(f'still-valid#{still_valid.outputs[0].claim_id}') c2 = await self.daemon.jsonrpc_resolve([f'still-valid#{claim_id[:2]}']) c3 = await self.daemon.jsonrpc_resolve([f'still-valid']) @@ -112,11 +113,10 @@ class BlockchainReorganizationTests(CommandTestCase): # reorg the last block dropping our claim tx await self.blockchain.invalidate_block(invalidated_block_hash) - await self.blockchain.clear_mempool() + await self.conductor.clear_mempool() await self.blockchain.generate(2) - - # wait for the client to catch up and verify the reorg await asyncio.wait_for(self.on_header(209), 3.0) + await self.assertBlockHash(207) await self.assertBlockHash(208) await self.assertBlockHash(209) @@ -142,9 +142,8 @@ class BlockchainReorganizationTests(CommandTestCase): # broadcast the claim in a different block new_txid = await self.blockchain.sendrawtransaction(hexlify(broadcast_tx.raw).decode()) self.assertEqual(broadcast_tx.id, new_txid) - await self.blockchain.generate(1) - # wait for the client to catch up + await self.blockchain.generate(1) await asyncio.wait_for(self.on_header(210), 1.0) # verify the claim is in the new block and that it is returned by claim_search @@ -191,7 +190,7 @@ class BlockchainReorganizationTests(CommandTestCase): # reorg the last block dropping our claim tx await self.blockchain.invalidate_block(invalidated_block_hash) - await self.blockchain.clear_mempool() + await self.conductor.clear_mempool() await self.blockchain.generate(2) # wait for the client to catch up and verify the reorg @@ -222,8 +221,6 @@ class BlockchainReorganizationTests(CommandTestCase): new_txid = await self.blockchain.sendrawtransaction(hexlify(broadcast_tx.raw).decode()) self.assertEqual(broadcast_tx.id, new_txid) await self.blockchain.generate(1) - - # wait for the client to catch up await asyncio.wait_for(self.on_header(210), 1.0) # verify the claim is in the new block and that it is returned by claim_search diff --git a/tests/integration/blockchain/test_network.py b/tests/integration/blockchain/test_network.py index e5cc725cc..0b9fcac68 100644 --- a/tests/integration/blockchain/test_network.py +++ b/tests/integration/blockchain/test_network.py @@ -1,13 +1,16 @@ import asyncio +import scribe -import lbry from unittest.mock import Mock +from scribe.blockchain.network import LBCRegTest +from scribe.hub.udp import StatusServer +from scribe.hub.session import LBRYElectrumX + from lbry.wallet.network import Network from lbry.wallet.orchstr8 import Conductor from lbry.wallet.orchstr8.node import SPVNode from lbry.wallet.rpc import RPCSession -from lbry.wallet.server.udp import StatusServer from lbry.testcase import IntegrationTestCase, AsyncioTestCase from lbry.conf import Config @@ -22,7 +25,7 @@ class NetworkTests(IntegrationTestCase): async def test_server_features(self): self.assertDictEqual({ - 'genesis_hash': self.conductor.spv_node.coin_class.GENESIS_HASH, + 'genesis_hash': LBCRegTest.GENESIS_HASH, 'hash_function': 'sha256', 'hosts': {}, 'protocol_max': '0.199.0', @@ -32,22 +35,27 @@ class NetworkTests(IntegrationTestCase): 'payment_address': '', 'donation_address': '', 'daily_fee': '0', - 'server_version': lbry.__version__, + 'server_version': scribe.__version__, 'trending_algorithm': 'fast_ar', }, await self.ledger.network.get_server_features()) # await self.conductor.spv_node.stop() payment_address, donation_address = await self.account.get_addresses(limit=2) + + original_address = self.conductor.spv_node.server.env.payment_address + original_donation_address = self.conductor.spv_node.server.env.donation_address + original_description = self.conductor.spv_node.server.env.description + original_daily_fee = self.conductor.spv_node.server.env.daily_fee + self.conductor.spv_node.server.env.payment_address = payment_address self.conductor.spv_node.server.env.donation_address = donation_address self.conductor.spv_node.server.env.description = 'Fastest server in the west.' self.conductor.spv_node.server.env.daily_fee = '42' - from lbry.wallet.server.session import LBRYElectrumX LBRYElectrumX.set_server_features(self.conductor.spv_node.server.env) # await self.ledger.network.on_connected.first self.assertDictEqual({ - 'genesis_hash': self.conductor.spv_node.coin_class.GENESIS_HASH, + 'genesis_hash': LBCRegTest.GENESIS_HASH, 'hash_function': 'sha256', 'hosts': {}, 'protocol_max': '0.199.0', @@ -57,16 +65,23 @@ class NetworkTests(IntegrationTestCase): 'payment_address': payment_address, 'donation_address': donation_address, 'daily_fee': '42', - 'server_version': lbry.__version__, + 'server_version': scribe.__version__, 'trending_algorithm': 'fast_ar', }, await self.ledger.network.get_server_features()) + # cleanup the changes since the attributes are set on the class + self.conductor.spv_node.server.env.payment_address = original_address + self.conductor.spv_node.server.env.donation_address = original_donation_address + self.conductor.spv_node.server.env.description = original_description + self.conductor.spv_node.server.env.daily_fee = original_daily_fee + LBRYElectrumX.set_server_features(self.conductor.spv_node.server.env) + class ReconnectTests(IntegrationTestCase): async def test_multiple_servers(self): # we have a secondary node that connects later, so - node2 = SPVNode(self.conductor.spv_module, node_number=2) + node2 = SPVNode(node_number=2) await node2.start(self.blockchain) self.ledger.network.config['explicit_servers'].append((node2.hostname, node2.port)) @@ -86,7 +101,7 @@ class ReconnectTests(IntegrationTestCase): await self.ledger.stop() initial_height = self.ledger.local_height_including_downloaded_height await self.blockchain.generate(100) - while self.conductor.spv_node.server.session_mgr.notified_height < initial_height + 99: # off by 1 + while self.conductor.spv_node.server.session_manager.notified_height < initial_height + 99: # off by 1 await asyncio.sleep(0.1) self.assertEqual(initial_height, self.ledger.local_height_including_downloaded_height) await self.ledger.headers.open() @@ -101,12 +116,7 @@ class ReconnectTests(IntegrationTestCase): self.ledger.network.client.transport.close() self.assertFalse(self.ledger.network.is_connected) await self.ledger.resolve([], 'derp') - sendtxid = await self.blockchain.send_to_address(address1, 1.1337) - # await self.ledger.resolve([], 'derp') - # self.assertTrue(self.ledger.network.is_connected) - await asyncio.wait_for(self.on_transaction_id(sendtxid), 10.0) # mempool - await self.blockchain.generate(1) - await self.on_transaction_id(sendtxid) # confirmed + sendtxid = await self.send_to_address_and_wait(address1, 1.1337, 1) self.assertLess(self.ledger.network.client.response_time, 1) # response time properly set lower, we are fine await self.assertBalance(self.account, '1.1337') @@ -135,7 +145,7 @@ class ReconnectTests(IntegrationTestCase): await self.conductor.spv_node.stop() self.assertFalse(self.ledger.network.is_connected) await asyncio.sleep(0.2) # let it retry and fail once - await self.conductor.spv_node.start(self.conductor.blockchain_node) + await self.conductor.spv_node.start(self.conductor.lbcwallet_node) await self.ledger.network.on_connected.first self.assertTrue(self.ledger.network.is_connected) @@ -161,15 +171,16 @@ class ReconnectTests(IntegrationTestCase): class UDPServerFailDiscoveryTest(AsyncioTestCase): - async def test_wallet_connects_despite_lack_of_udp(self): conductor = Conductor() conductor.spv_node.udp_port = '0' - await conductor.start_blockchain() - self.addCleanup(conductor.stop_blockchain) + await conductor.start_lbcd() + self.addCleanup(conductor.stop_lbcd) + await conductor.start_lbcwallet() + self.addCleanup(conductor.stop_lbcwallet) await conductor.start_spv() self.addCleanup(conductor.stop_spv) - self.assertFalse(conductor.spv_node.server.bp.status_server.is_running) + self.assertFalse(conductor.spv_node.server.status_server.is_running) await asyncio.wait_for(conductor.start_wallet(), timeout=5) self.addCleanup(conductor.stop_wallet) self.assertTrue(conductor.wallet_node.ledger.network.is_connected) diff --git a/tests/integration/blockchain/test_purchase_command.py b/tests/integration/blockchain/test_purchase_command.py index 64e99a7ac..042ae1341 100644 --- a/tests/integration/blockchain/test_purchase_command.py +++ b/tests/integration/blockchain/test_purchase_command.py @@ -103,7 +103,7 @@ class PurchaseCommandTests(CommandTestCase): # purchase non-existent claim fails with self.assertRaisesRegex(Exception, "Could not find claim with claim_id"): - await self.daemon.jsonrpc_purchase_create('abc123') + await self.daemon.jsonrpc_purchase_create('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa') # purchase stream with no price fails no_price_stream = await self.priced_stream('no_price_stream', price=None) @@ -174,8 +174,7 @@ class PurchaseCommandTests(CommandTestCase): self.merchant_address = await self.account.receiving.get_or_create_usable_address() daemon2 = await self.add_daemon() address2 = await daemon2.wallet_manager.default_account.receiving.get_or_create_usable_address() - sendtxid = await self.blockchain.send_to_address(address2, 2) - await self.confirm_tx(sendtxid, daemon2.ledger) + await self.send_to_address_and_wait(address2, 2, 1, ledger=daemon2.ledger) stream = await self.priced_stream('a', '1.0') await self.assertBalance(self.account, '9.987893') diff --git a/tests/integration/blockchain/test_sync.py b/tests/integration/blockchain/test_sync.py index 7af2bd1aa..9fc1f2ed2 100644 --- a/tests/integration/blockchain/test_sync.py +++ b/tests/integration/blockchain/test_sync.py @@ -63,7 +63,7 @@ class SyncTests(IntegrationTestCase): await self.assertBalance(account1, '1.0') await self.assertBalance(account2, '1.0') - await self.blockchain.generate(1) + await self.generate(1) # pay 0.01 from main node to receiving node, would have increased change addresses address0 = (await account0.receiving.get_addresses())[0] @@ -79,7 +79,7 @@ class SyncTests(IntegrationTestCase): account1.ledger.wait(tx), account2.ledger.wait(tx), ]) - await self.blockchain.generate(1) + await self.generate(1) await asyncio.wait([ account0.ledger.wait(tx), account1.ledger.wait(tx), @@ -92,7 +92,7 @@ class SyncTests(IntegrationTestCase): await self.assertBalance(account1, '0.989876') await self.assertBalance(account2, '0.989876') - await self.blockchain.generate(1) + await self.generate(1) # create a new mirror node and see if it syncs to same balance from scratch node3 = await self.make_wallet_node(account1.seed) diff --git a/tests/integration/blockchain/test_wallet_commands.py b/tests/integration/blockchain/test_wallet_commands.py index a1bc6a7cf..3de00dd49 100644 --- a/tests/integration/blockchain/test_wallet_commands.py +++ b/tests/integration/blockchain/test_wallet_commands.py @@ -11,7 +11,7 @@ from lbry.wallet.dewies import dict_values_to_lbc class WalletCommands(CommandTestCase): async def test_wallet_create_and_add_subscribe(self): - session = next(iter(self.conductor.spv_node.server.session_mgr.sessions.values())) + session = next(iter(self.conductor.spv_node.server.session_manager.sessions.values())) self.assertEqual(len(session.hashX_subs), 27) wallet = await self.daemon.jsonrpc_wallet_create('foo', create_account=True, single_key=True) self.assertEqual(len(session.hashX_subs), 28) @@ -23,7 +23,7 @@ class WalletCommands(CommandTestCase): async def test_wallet_syncing_status(self): address = await self.daemon.jsonrpc_address_unused() self.assertFalse(self.daemon.jsonrpc_wallet_status()['is_syncing']) - await self.blockchain.send_to_address(address, 1) + await self.send_to_address_and_wait(address, 1) await self.ledger._update_tasks.started.wait() self.assertTrue(self.daemon.jsonrpc_wallet_status()['is_syncing']) await self.ledger._update_tasks.done.wait() @@ -47,9 +47,9 @@ class WalletCommands(CommandTestCase): status = await self.daemon.jsonrpc_status() self.assertEqual(len(status['wallet']['servers']), 1) self.assertEqual(status['wallet']['servers'][0]['port'], 50002) - await self.conductor.spv_node.stop(True) + await self.conductor.spv_node.stop() self.conductor.spv_node.port = 54320 - await self.conductor.spv_node.start(self.conductor.blockchain_node) + await self.conductor.spv_node.start(self.conductor.lbcwallet_node) status = await self.daemon.jsonrpc_status() self.assertEqual(len(status['wallet']['servers']), 0) self.daemon.jsonrpc_settings_set('lbryum_servers', ['localhost:54320']) @@ -59,23 +59,22 @@ class WalletCommands(CommandTestCase): self.assertEqual(status['wallet']['servers'][0]['port'], 54320) async def test_sending_to_scripthash_address(self): - self.assertEqual(await self.blockchain.get_balance(), '95.99973580') + bal = await self.blockchain.get_balance() await self.assertBalance(self.account, '10.0') p2sh_address1 = await self.blockchain.get_new_address(self.blockchain.P2SH_SEGWIT_ADDRESS) tx = await self.account_send('2.0', p2sh_address1) self.assertEqual(tx['outputs'][0]['address'], p2sh_address1) - self.assertEqual(await self.blockchain.get_balance(), '98.99973580') # +1 lbc for confirm block + self.assertEqual(await self.blockchain.get_balance(), str(float(bal)+3)) # +1 lbc for confirm block await self.assertBalance(self.account, '7.999877') await self.wallet_send('3.0', p2sh_address1) - self.assertEqual(await self.blockchain.get_balance(), '102.99973580') # +1 lbc for confirm block + self.assertEqual(await self.blockchain.get_balance(), str(float(bal)+7)) # +1 lbc for confirm block await self.assertBalance(self.account, '4.999754') async def test_balance_caching(self): account2 = await self.daemon.jsonrpc_account_create("Tip-er") address2 = await self.daemon.jsonrpc_address_unused(account2.id) - sendtxid = await self.blockchain.send_to_address(address2, 10) - await self.confirm_tx(sendtxid) - await self.generate(1) + await self.send_to_address_and_wait(address2, 10, 2) + await self.ledger.tasks_are_done() # don't mess with the query count while we need it wallet_balance = self.daemon.jsonrpc_wallet_balance ledger = self.ledger @@ -90,14 +89,16 @@ class WalletCommands(CommandTestCase): self.assertIsNone(ledger._balance_cache.get(self.account.id)) query_count += 2 - self.assertEqual(await wallet_balance(), expected) + balance = await wallet_balance() self.assertEqual(self.ledger.db.db.query_count, query_count) + self.assertEqual(balance, expected) self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '10.0') self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0') # calling again uses cache - self.assertEqual(await wallet_balance(), expected) + balance = await wallet_balance() self.assertEqual(self.ledger.db.db.query_count, query_count) + self.assertEqual(balance, expected) self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '10.0') self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0') @@ -123,8 +124,7 @@ class WalletCommands(CommandTestCase): wallet2 = await self.daemon.jsonrpc_wallet_create('foo', create_account=True) account3 = wallet2.default_account address3 = await self.daemon.jsonrpc_address_unused(account3.id, wallet2.id) - await self.confirm_tx(await self.blockchain.send_to_address(address3, 1)) - await self.generate(1) + await self.send_to_address_and_wait(address3, 1, 1) account_balance = self.daemon.jsonrpc_account_balance wallet_balance = self.daemon.jsonrpc_wallet_balance @@ -154,7 +154,7 @@ class WalletCommands(CommandTestCase): address2 = await self.daemon.jsonrpc_address_unused(account2.id) # send lbc to someone else - tx = await self.daemon.jsonrpc_account_send('1.0', address2) + tx = await self.daemon.jsonrpc_account_send('1.0', address2, blocking=True) await self.confirm_tx(tx.id) self.assertEqual(await account_balance(), { 'total': '8.97741', @@ -187,7 +187,7 @@ class WalletCommands(CommandTestCase): }) # tip claimed - tx = await self.daemon.jsonrpc_support_abandon(txid=support1['txid'], nout=0) + tx = await self.daemon.jsonrpc_support_abandon(txid=support1['txid'], nout=0, blocking=True) await self.confirm_tx(tx.id) self.assertEqual(await account_balance(), { 'total': '9.277303', @@ -238,8 +238,7 @@ class WalletEncryptionAndSynchronization(CommandTestCase): "carbon smart garage balance margin twelve" ) address = (await self.daemon2.wallet_manager.default_account.receiving.get_addresses(limit=1, only_usable=True))[0] - sendtxid = await self.blockchain.send_to_address(address, 1) - await self.confirm_tx(sendtxid, self.daemon2.ledger) + await self.send_to_address_and_wait(address, 1, 1, ledger=self.daemon2.ledger) def assertWalletEncrypted(self, wallet_path, encrypted): with open(wallet_path) as opened: @@ -294,7 +293,7 @@ class WalletEncryptionAndSynchronization(CommandTestCase): '3056301006072a8648ce3d020106052b8104000a034200049ae7283f3f6723e0a1' '66b7e19e1d1167f6dc5f4af61b4a58066a0d2a8bed2b35c66bccb4ec3eba316b16' 'a97a6d6a4a8effd29d748901bb9789352519cd00b13d' - ), self.daemon2) + ), self.daemon2, blocking=True) await self.confirm_tx(channel['txid'], self.daemon2.ledger) # both daemons will have the channel but only one has the cert so far diff --git a/tests/integration/blockchain/test_wallet_server_sessions.py b/tests/integration/blockchain/test_wallet_server_sessions.py index 139a0bf0b..70e2535d9 100644 --- a/tests/integration/blockchain/test_wallet_server_sessions.py +++ b/tests/integration/blockchain/test_wallet_server_sessions.py @@ -1,12 +1,11 @@ import asyncio -import lbry -import lbry.wallet +import scribe +from scribe.hub.session import LBRYElectrumX + from lbry.error import ServerPaymentFeeAboveMaxAllowedError from lbry.wallet.network import ClientSession from lbry.wallet.rpc import RPCError -from lbry.wallet.server.db.elasticsearch.sync import make_es_index_and_run_sync -from lbry.wallet.server.session import LBRYElectrumX from lbry.testcase import IntegrationTestCase, CommandTestCase from lbry.wallet.orchstr8.node import SPVNode @@ -25,17 +24,17 @@ class TestSessions(IntegrationTestCase): ) await session.create_connection() await session.send_request('server.banner', ()) - self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1) + self.assertEqual(len(self.conductor.spv_node.server.session_manager.sessions), 1) self.assertFalse(session.is_closing()) await asyncio.sleep(1.1) with self.assertRaises(asyncio.TimeoutError): await session.send_request('server.banner', ()) self.assertTrue(session.is_closing()) - self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 0) + self.assertEqual(len(self.conductor.spv_node.server.session_manager.sessions), 0) async def test_proper_version(self): info = await self.ledger.network.get_server_features() - self.assertEqual(lbry.__version__, info['server_version']) + self.assertEqual(scribe.__version__, info['server_version']) async def test_client_errors(self): # Goal is ensuring thsoe are raised and not trapped accidentally @@ -46,7 +45,7 @@ class TestSessions(IntegrationTestCase): class TestUsagePayment(CommandTestCase): - async def _test_single_server_payment(self): + async def test_single_server_payment(self): wallet_pay_service = self.daemon.component_manager.get_component('wallet_server_payments') wallet_pay_service.payment_period = 1 # only starts with a positive max key fee @@ -63,8 +62,8 @@ class TestUsagePayment(CommandTestCase): _, history = await self.ledger.get_local_status_and_history(address) self.assertEqual(history, []) - node = SPVNode(self.conductor.spv_module, node_number=2) - await node.start(self.blockchain, extraconf={"PAYMENT_ADDRESS": address, "DAILY_FEE": "1.1"}) + node = SPVNode(node_number=2) + await node.start(self.blockchain, extraconf={"payment_address": address, "daily_fee": "1.1"}) self.addCleanup(node.stop) self.daemon.jsonrpc_settings_set('lbryum_servers', [f"{node.hostname}:{node.port}"]) await self.daemon.jsonrpc_wallet_reconnect() @@ -90,56 +89,78 @@ class TestUsagePayment(CommandTestCase): class TestESSync(CommandTestCase): async def test_es_sync_utility(self): + es_writer = self.conductor.spv_node.es_writer + server_search_client = self.conductor.spv_node.server.session_manager.search_index + for i in range(10): await self.stream_create(f"stream{i}", bid='0.001') await self.generate(1) self.assertEqual(10, len(await self.claim_search(order_by=['height']))) - db = self.conductor.spv_node.server.db - env = self.conductor.spv_node.server.env - - await db.search_index.delete_index() - db.search_index.clear_caches() - self.assertEqual(0, len(await self.claim_search(order_by=['height']))) - await db.search_index.stop() - - async def resync(): - await db.search_index.start() - db.search_index.clear_caches() - await make_es_index_and_run_sync(env, db=db, index_name=db.search_index.index, force=True) - self.assertEqual(10, len(await self.claim_search(order_by=['height']))) + # delete the index and verify nothing is returned by claim search + await es_writer.delete_index() + server_search_client.clear_caches() self.assertEqual(0, len(await self.claim_search(order_by=['height']))) - await resync() - - # this time we will test a migration from unversioned to v1 - await db.search_index.sync_client.indices.delete_template(db.search_index.index) - await db.search_index.stop() - - await make_es_index_and_run_sync(env, db=db, index_name=db.search_index.index, force=True) - await db.search_index.start() - - await resync() + # reindex, 10 claims should be returned + await es_writer.reindex(force=True) self.assertEqual(10, len(await self.claim_search(order_by=['height']))) + server_search_client.clear_caches() + self.assertEqual(10, len(await self.claim_search(order_by=['height']))) + + # reindex again, this should not appear to do anything but will delete and reinsert the same 10 claims + await es_writer.reindex(force=True) + self.assertEqual(10, len(await self.claim_search(order_by=['height']))) + server_search_client.clear_caches() + self.assertEqual(10, len(await self.claim_search(order_by=['height']))) + + # delete the index again and stop the writer, upon starting it the writer should reindex automatically + await es_writer.delete_index() + await es_writer.stop() + server_search_client.clear_caches() + self.assertEqual(0, len(await self.claim_search(order_by=['height']))) + + await es_writer.start(reindex=True) + self.assertEqual(10, len(await self.claim_search(order_by=['height']))) + + # stop the es writer and advance the chain by 1, adding a new claim. upon resuming the es writer, it should + # add the new claim + await es_writer.stop() + await self.stream_create(f"stream11", bid='0.001', confirm=False) + generate_block_task = asyncio.create_task(self.generate(1)) + await es_writer.start() + await generate_block_task + self.assertEqual(11, len(await self.claim_search(order_by=['height']))) + + + # # this time we will test a migration from unversioned to v1 + # await db.search_index.sync_client.indices.delete_template(db.search_index.index) + # await db.search_index.stop() + # + # await make_es_index_and_run_sync(env, db=db, index_name=db.search_index.index, force=True) + # await db.search_index.start() + # + # await es_writer.reindex() + # self.assertEqual(10, len(await self.claim_search(order_by=['height']))) class TestHubDiscovery(CommandTestCase): async def test_hub_discovery(self): - us_final_node = SPVNode(self.conductor.spv_module, node_number=2) - await us_final_node.start(self.blockchain, extraconf={"COUNTRY": "US"}) + us_final_node = SPVNode(node_number=2) + await us_final_node.start(self.blockchain, extraconf={"country": "US"}) self.addCleanup(us_final_node.stop) final_node_host = f"{us_final_node.hostname}:{us_final_node.port}" - kp_final_node = SPVNode(self.conductor.spv_module, node_number=3) - await kp_final_node.start(self.blockchain, extraconf={"COUNTRY": "KP"}) + kp_final_node = SPVNode(node_number=3) + await kp_final_node.start(self.blockchain, extraconf={"country": "KP"}) self.addCleanup(kp_final_node.stop) kp_final_node_host = f"{kp_final_node.hostname}:{kp_final_node.port}" - relay_node = SPVNode(self.conductor.spv_module, node_number=4) + relay_node = SPVNode(node_number=4) await relay_node.start(self.blockchain, extraconf={ - "COUNTRY": "FR", - "PEER_HUBS": ",".join([kp_final_node_host, final_node_host]) + "country": "FR", + "peer_hubs": ",".join([kp_final_node_host, final_node_host]) }) relay_node_host = f"{relay_node.hostname}:{relay_node.port}" self.addCleanup(relay_node.stop) @@ -186,7 +207,7 @@ class TestHubDiscovery(CommandTestCase): self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', kp_final_node.port) ) - kp_final_node.server.session_mgr._notify_peer('127.0.0.1:9988') + kp_final_node.server.session_manager._notify_peer('127.0.0.1:9988') await self.daemon.ledger.network.on_hub.first await asyncio.sleep(0.5) # wait for above event to be processed by other listeners self.assertEqual( diff --git a/tests/integration/claims/test_claim_commands.py b/tests/integration/claims/test_claim_commands.py index 576886376..a9d5a1a22 100644 --- a/tests/integration/claims/test_claim_commands.py +++ b/tests/integration/claims/test_claim_commands.py @@ -12,7 +12,6 @@ from lbry.error import InsufficientFundsError from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE from lbry.testcase import CommandTestCase from lbry.wallet.orchstr8.node import SPVNode -from lbry.wallet.server.db.common import STREAM_TYPES from lbry.wallet.transaction import Transaction, Output from lbry.wallet.util import satoshis_to_coins as lbc from lbry.crypto.hash import sha256 @@ -20,6 +19,16 @@ from lbry.crypto.hash import sha256 log = logging.getLogger(__name__) +STREAM_TYPES = { + 'video': 1, + 'audio': 2, + 'image': 3, + 'document': 4, + 'binary': 5, + 'model': 6, +} + + def verify(channel, data, signature, channel_hash=None): pieces = [ signature['signing_ts'].encode(), @@ -125,18 +134,6 @@ class ClaimSearchCommand(ClaimTestCase): with self.assertRaises(ConnectionResetError): await self.claim_search(claim_ids=claim_ids) - async def test_claim_search_as_reader_server(self): - node2 = SPVNode(self.conductor.spv_module, node_number=2) - current_prefix = self.conductor.spv_node.server.bp.env.es_index_prefix - await node2.start(self.blockchain, extraconf={'ES_MODE': 'reader', 'ES_INDEX_PREFIX': current_prefix}) - self.addCleanup(node2.stop) - self.ledger.network.config['default_servers'] = [(node2.hostname, node2.port)] - await self.ledger.stop() - await self.ledger.start() - channel2 = await self.channel_create('@abc', '0.1', allow_duplicate_name=True) - await asyncio.sleep(1) # fixme: find a way to block on the writer - await self.assertFindsClaims([channel2], name='@abc') - async def test_basic_claim_search(self): await self.create_channel() channel_txo = self.channel['outputs'][0] @@ -405,6 +402,18 @@ class ClaimSearchCommand(ClaimTestCase): not_channel_ids=[chan2_id], has_channel_signature=True, valid_channel_signature=True) await match([], not_channel_ids=[chan1_id, chan2_id], has_channel_signature=True, valid_channel_signature=True) + @skip + async def test_no_source_and_valid_channel_signature_and_media_type(self): + await self.channel_create('@spam2', '1.0') + await self.stream_create('barrrrrr', '1.0', channel_name='@spam2', file_path=self.video_file_name) + paradox_no_source_claims = await self.claim_search(has_no_source=True, valid_channel_signature=True, + media_type="video/mp4") + mp4_claims = await self.claim_search(media_type="video/mp4") + no_source_claims = await self.claim_search(has_no_source=True, valid_channel_signature=True) + self.assertEqual(0, len(paradox_no_source_claims)) + self.assertEqual(1, len(no_source_claims)) + self.assertEqual(1, len(mp4_claims)) + async def test_limit_claims_per_channel(self): match = self.assertFindsClaims chan1_id = self.get_claim_id(await self.channel_create('@chan1')) @@ -494,8 +503,7 @@ class ClaimSearchCommand(ClaimTestCase): tx = await Transaction.claim_create( 'unknown', b'{"sources":{"lbry_sd_hash":""}}', 1, address, [self.account], self.account) await tx.sign([self.account]) - await self.broadcast(tx) - await self.confirm_tx(tx.id) + await self.broadcast_and_confirm(tx) octet = await self.stream_create() video = await self.stream_create('chrome', file_path=self.video_file_name) @@ -1226,7 +1234,7 @@ class ChannelCommands(CommandTestCase): data_to_sign = "CAFEBABE" # claim new name await self.channel_create('@someotherchan') - channel_tx = await self.daemon.jsonrpc_channel_create('@signer', '0.1') + channel_tx = await self.daemon.jsonrpc_channel_create('@signer', '0.1', blocking=True) await self.confirm_tx(channel_tx.id) channel = channel_tx.outputs[0] signature1 = await self.out(self.daemon.jsonrpc_channel_sign(channel_name='@signer', hexdata=data_to_sign)) @@ -1373,7 +1381,7 @@ class StreamCommands(ClaimTestCase): self.assertEqual('8.989893', (await self.daemon.jsonrpc_account_balance())['available']) result = await self.out(self.daemon.jsonrpc_account_send( - '5.0', await self.daemon.jsonrpc_address_unused(account2_id) + '5.0', await self.daemon.jsonrpc_address_unused(account2_id), blocking=True )) await self.confirm_tx(result['txid']) @@ -1514,10 +1522,13 @@ class StreamCommands(ClaimTestCase): await self.channel_create('@filtering', '0.1') ) self.conductor.spv_node.server.db.filtering_channel_hashes.add(bytes.fromhex(filtering_channel_id)) - self.assertEqual(0, len(self.conductor.spv_node.server.db.filtered_streams)) - await self.stream_repost(bad_content_id, 'filter1', '0.1', channel_name='@filtering') - self.assertEqual(1, len(self.conductor.spv_node.server.db.filtered_streams)) + self.conductor.spv_node.es_writer.db.filtering_channel_hashes.add(bytes.fromhex(filtering_channel_id)) + self.assertEqual(0, len(self.conductor.spv_node.es_writer.db.filtered_streams)) + await self.stream_repost(bad_content_id, 'filter1', '0.1', channel_name='@filtering') + self.assertEqual(1, len(self.conductor.spv_node.es_writer.db.filtered_streams)) + + self.assertEqual('0.1', (await self.out(self.daemon.jsonrpc_resolve('bad_content')))['bad_content']['amount']) # search for filtered content directly result = await self.out(self.daemon.jsonrpc_claim_search(name='bad_content')) blocked = result['blocked'] @@ -1560,14 +1571,14 @@ class StreamCommands(ClaimTestCase): ) # test setting from env vars and starting from scratch await self.conductor.spv_node.stop(False) - await self.conductor.spv_node.start(self.conductor.blockchain_node, - extraconf={'BLOCKING_CHANNEL_IDS': blocking_channel_id, - 'FILTERING_CHANNEL_IDS': filtering_channel_id}) + await self.conductor.spv_node.start(self.conductor.lbcwallet_node, + extraconf={'blocking_channel_ids': [blocking_channel_id], + 'filtering_channel_ids': [filtering_channel_id]}) await self.daemon.wallet_manager.reset() - self.assertEqual(0, len(self.conductor.spv_node.server.db.blocked_streams)) + self.assertEqual(0, len(self.conductor.spv_node.es_writer.db.blocked_streams)) await self.stream_repost(bad_content_id, 'block1', '0.1', channel_name='@blocking') - self.assertEqual(1, len(self.conductor.spv_node.server.db.blocked_streams)) + self.assertEqual(1, len(self.conductor.spv_node.es_writer.db.blocked_streams)) # blocked content is not resolveable error = (await self.resolve('lbry://@some_channel/bad_content'))['error'] @@ -1626,6 +1637,11 @@ class StreamCommands(ClaimTestCase): self.assertEqual((await self.resolve('lbry://worse_content'))['error']['name'], 'BLOCKED') self.assertEqual((await self.resolve('lbry://@bad_channel/worse_content'))['error']['name'], 'BLOCKED') + await self.stream_update(worse_content_id, channel_name='@bad_channel', tags=['bad-stuff']) + self.assertEqual((await self.resolve('lbry://@bad_channel'))['error']['name'], 'BLOCKED') + self.assertEqual((await self.resolve('lbry://worse_content'))['error']['name'], 'BLOCKED') + self.assertEqual((await self.resolve('lbry://@bad_channel/worse_content'))['error']['name'], 'BLOCKED') + async def test_publish_updates_file_list(self): tx = await self.stream_create(title='created') txo = tx['outputs'][0] @@ -1651,6 +1667,7 @@ class StreamCommands(ClaimTestCase): self.assertEqual(tx['txid'], files[0]['txid']) self.assertEqual(expected, files[0]['metadata']) + @skip async def test_setting_stream_fields(self): values = { 'title': "Cool Content", @@ -1791,22 +1808,35 @@ class StreamCommands(ClaimTestCase): self.assertItemCount(await self.daemon.jsonrpc_claim_list(account_id=self.account.id), 3) self.assertItemCount(await self.daemon.jsonrpc_claim_list(account_id=account2_id), 1) - self.assertEqual(3, len(await self.claim_search(release_time='>0', order_by=['release_time']))) - self.assertEqual(3, len(await self.claim_search(release_time='>=0', order_by=['release_time']))) + self.assertEqual(4, len(await self.claim_search(release_time='>0', order_by=['release_time']))) + self.assertEqual(3, len(await self.claim_search(release_time='>0', order_by=['release_time'], claim_type='stream'))) + + self.assertEqual(4, len(await self.claim_search(release_time='>=0', order_by=['release_time']))) self.assertEqual(4, len(await self.claim_search(order_by=['release_time']))) self.assertEqual(3, len(await self.claim_search(claim_type='stream', order_by=['release_time']))) self.assertEqual(1, len(await self.claim_search(claim_type='channel', order_by=['release_time']))) - self.assertEqual(1, len(await self.claim_search(release_time='>=123456', order_by=['release_time']))) - self.assertEqual(1, len(await self.claim_search(release_time='>123456', order_by=['release_time']))) - self.assertEqual(2, len(await self.claim_search(release_time='<123457', order_by=['release_time']))) + self.assertEqual(2, len(await self.claim_search(release_time='>=123456', order_by=['release_time']))) - self.assertEqual(2, len(await self.claim_search(release_time=['<123457'], order_by=['release_time']))) - self.assertEqual(2, len(await self.claim_search(release_time=['>0', '<123457'], order_by=['release_time']))) + self.assertEqual(1, len(await self.claim_search(release_time='>=123456', order_by=['release_time'], claim_type='stream'))) + + self.assertEqual(1, len(await self.claim_search(release_time='>123456', order_by=['release_time'], claim_type='stream'))) + self.assertEqual(2, len(await self.claim_search(release_time='>123456', order_by=['release_time']))) + + self.assertEqual(3, len(await self.claim_search(release_time='<123457', order_by=['release_time']))) + self.assertEqual(2, len(await self.claim_search(release_time='<123457', order_by=['release_time'], claim_type='stream'))) + + self.assertEqual(2, len(await self.claim_search(release_time=['<123457'], order_by=['release_time'], claim_type='stream'))) + self.assertEqual(3, len(await self.claim_search(release_time=['<123457'], order_by=['release_time']))) + self.assertEqual(3, len(await self.claim_search(release_time=['>0', '<123457'], order_by=['release_time']))) + self.assertEqual(2, len(await self.claim_search(release_time=['>0', '<123457'], order_by=['release_time'], claim_type='stream'))) + self.assertEqual(3, len(await self.claim_search(release_time=['<123457'], order_by=['release_time'], height=['>0']))) + self.assertEqual(4, len(await self.claim_search(order_by=['release_time'], height=['>0']))) + self.assertEqual(4, len(await self.claim_search(order_by=['release_time'], height=['>0'], claim_type=['stream', 'channel']))) self.assertEqual( - 2, len(await self.claim_search(release_time=['>=123097', '<123457'], order_by=['release_time'])) + 3, len(await self.claim_search(release_time=['>=123097', '<123457'], order_by=['release_time'])) ) self.assertEqual( - 2, len(await self.claim_search(release_time=['<123457', '>0'], order_by=['release_time'])) + 3, len(await self.claim_search(release_time=['<123457', '>0'], order_by=['release_time'])) ) async def test_setting_fee_fields(self): @@ -2177,7 +2207,7 @@ class SupportCommands(CommandTestCase): tip = await self.out( self.daemon.jsonrpc_support_create( claim_id, '1.0', True, account_id=account2.id, wallet_id='wallet2', - funding_account_ids=[account2.id]) + funding_account_ids=[account2.id], blocking=True) ) await self.confirm_tx(tip['txid']) @@ -2209,7 +2239,7 @@ class SupportCommands(CommandTestCase): support = await self.out( self.daemon.jsonrpc_support_create( claim_id, '2.0', False, account_id=account2.id, wallet_id='wallet2', - funding_account_ids=[account2.id]) + funding_account_ids=[account2.id], blocking=True) ) await self.confirm_tx(support['txid']) diff --git a/tests/integration/datanetwork/test_file_commands.py b/tests/integration/datanetwork/test_file_commands.py index 6772a1ae5..08cf070c8 100644 --- a/tests/integration/datanetwork/test_file_commands.py +++ b/tests/integration/datanetwork/test_file_commands.py @@ -1,3 +1,4 @@ +import unittest from unittest import skipIf import asyncio import os @@ -36,8 +37,7 @@ class FileCommands(CommandTestCase): tx_to_update.outputs[0], claim, 1, address, [self.account], self.account ) await tx.sign([self.account]) - await self.broadcast(tx) - await self.confirm_tx(tx.id) + await self.broadcast_and_confirm(tx) self.client_session = self.daemon.file_manager.source_managers['torrent'].torrent_session self.client_session._session.add_dht_node(('localhost', 4040)) self.client_session.wait_start = False # fixme: this is super slow on tests @@ -216,6 +216,7 @@ class FileCommands(CommandTestCase): await self.wait_files_to_complete() self.assertNotEqual(first_path, second_path) + @unittest.SkipTest # FIXME: claimname/updateclaim is gone. #3480 wip, unblock #3479" async def test_file_list_updated_metadata_on_resolve(self): await self.stream_create('foo', '0.01') txo = (await self.daemon.resolve(self.wallet.accounts, ['lbry://foo']))['lbry://foo'] @@ -504,8 +505,7 @@ class FileCommands(CommandTestCase): tx.outputs[0].claim.stream.fee.address_bytes = b'' tx.outputs[0].script.generate() await tx.sign([self.account]) - await self.broadcast(tx) - await self.confirm_tx(tx.id) + await self.broadcast_and_confirm(tx) async def __raw_value_update_no_fee_amount(self, tx, claim_address): tx = await self.daemon.jsonrpc_stream_update( @@ -515,8 +515,7 @@ class FileCommands(CommandTestCase): tx.outputs[0].claim.stream.fee.message.ClearField('amount') tx.outputs[0].script.generate() await tx.sign([self.account]) - await self.broadcast(tx) - await self.confirm_tx(tx.id) + await self.broadcast_and_confirm(tx) class DiskSpaceManagement(CommandTestCase): diff --git a/tests/integration/other/test_chris45.py b/tests/integration/other/test_chris45.py index bcdbc290b..0e3f35614 100644 --- a/tests/integration/other/test_chris45.py +++ b/tests/integration/other/test_chris45.py @@ -80,7 +80,7 @@ class EpicAdventuresOfChris45(CommandTestCase): # After some soul searching Chris decides that his story needs more # heart and a better ending. He takes down the story and begins the rewrite. - abandon = await self.out(self.daemon.jsonrpc_stream_abandon(claim_id, blocking=False)) + abandon = await self.out(self.daemon.jsonrpc_stream_abandon(claim_id, blocking=True)) self.assertEqual(abandon['inputs'][0]['claim_id'], claim_id) await self.confirm_tx(abandon['txid']) @@ -103,7 +103,7 @@ class EpicAdventuresOfChris45(CommandTestCase): # 1 LBC to which Chris readily obliges ramsey_account_id = (await self.out(self.daemon.jsonrpc_account_create("Ramsey")))['id'] ramsey_address = await self.daemon.jsonrpc_address_unused(ramsey_account_id) - result = await self.out(self.daemon.jsonrpc_account_send('1.0', ramsey_address)) + result = await self.out(self.daemon.jsonrpc_account_send('1.0', ramsey_address, blocking=True)) self.assertIn("txid", result) await self.confirm_tx(result['txid']) @@ -133,7 +133,7 @@ class EpicAdventuresOfChris45(CommandTestCase): # And voila, and bravo and encore! His Best Friend Ramsey read the story and immediately knew this was a hit # Now to keep this claim winning on the lbry blockchain he immediately supports the claim tx = await self.out(self.daemon.jsonrpc_support_create( - claim_id2, '0.2', account_id=ramsey_account_id + claim_id2, '0.2', account_id=ramsey_account_id, blocking=True )) await self.confirm_tx(tx['txid']) @@ -147,7 +147,7 @@ class EpicAdventuresOfChris45(CommandTestCase): # Now he also wanted to support the original creator of the Award Winning Novel # So he quickly decides to send a tip to him tx = await self.out( - self.daemon.jsonrpc_support_create(claim_id2, '0.3', tip=True, account_id=ramsey_account_id) + self.daemon.jsonrpc_support_create(claim_id2, '0.3', tip=True, account_id=ramsey_account_id, blocking=True) ) await self.confirm_tx(tx['txid']) @@ -158,7 +158,7 @@ class EpicAdventuresOfChris45(CommandTestCase): await self.generate(5) # Seeing the ravishing success of his novel Chris adds support to his claim too - tx = await self.out(self.daemon.jsonrpc_support_create(claim_id2, '0.4')) + tx = await self.out(self.daemon.jsonrpc_support_create(claim_id2, '0.4', blocking=True)) await self.confirm_tx(tx['txid']) # And check if his support showed up @@ -183,7 +183,7 @@ class EpicAdventuresOfChris45(CommandTestCase): # But sadly Ramsey wasn't so pleased. It was hard for him to tell Chris... # Chris, though a bit heartbroken, abandoned the claim for now, but instantly started working on new hit lyrics - abandon = await self.out(self.daemon.jsonrpc_stream_abandon(txid=tx['txid'], nout=0, blocking=False)) + abandon = await self.out(self.daemon.jsonrpc_stream_abandon(txid=tx['txid'], nout=0, blocking=True)) self.assertTrue(abandon['inputs'][0]['txid'], tx['txid']) await self.confirm_tx(abandon['txid']) diff --git a/tests/integration/takeovers/test_resolve_command.py b/tests/integration/takeovers/test_resolve_command.py index 7856270ed..ebef0f917 100644 --- a/tests/integration/takeovers/test_resolve_command.py +++ b/tests/integration/takeovers/test_resolve_command.py @@ -1,6 +1,7 @@ import asyncio import json import hashlib +import sys from bisect import bisect_right from binascii import hexlify, unhexlify from collections import defaultdict @@ -23,7 +24,7 @@ class BaseResolveTestCase(CommandTestCase): def assertMatchESClaim(self, claim_from_es, claim_from_db): self.assertEqual(claim_from_es['claim_hash'][::-1].hex(), claim_from_db.claim_hash.hex()) self.assertEqual(claim_from_es['claim_id'], claim_from_db.claim_hash.hex()) - self.assertEqual(claim_from_es['activation_height'], claim_from_db.activation_height) + self.assertEqual(claim_from_es['activation_height'], claim_from_db.activation_height, f"es height: {claim_from_es['activation_height']}, rocksdb height: {claim_from_db.activation_height}") self.assertEqual(claim_from_es['last_take_over_height'], claim_from_db.last_takeover_height) self.assertEqual(claim_from_es['tx_id'], claim_from_db.tx_hash[::-1].hex()) self.assertEqual(claim_from_es['tx_nout'], claim_from_db.position) @@ -31,125 +32,151 @@ class BaseResolveTestCase(CommandTestCase): self.assertEqual(claim_from_es['effective_amount'], claim_from_db.effective_amount) def assertMatchDBClaim(self, expected, claim): - self.assertEqual(expected['claimId'], claim.claim_hash.hex()) - self.assertEqual(expected['validAtHeight'], claim.activation_height) - self.assertEqual(expected['lastTakeoverHeight'], claim.last_takeover_height) - self.assertEqual(expected['txId'], claim.tx_hash[::-1].hex()) + self.assertEqual(expected['claimid'], claim.claim_hash.hex()) + self.assertEqual(expected['validatheight'], claim.activation_height) + self.assertEqual(expected['lasttakeoverheight'], claim.last_takeover_height) + self.assertEqual(expected['txid'], claim.tx_hash[::-1].hex()) self.assertEqual(expected['n'], claim.position) self.assertEqual(expected['amount'], claim.amount) - self.assertEqual(expected['effectiveAmount'], claim.effective_amount) + self.assertEqual(expected['effectiveamount'], claim.effective_amount) async def assertResolvesToClaimId(self, name, claim_id): other = await self.resolve(name) if claim_id is None: self.assertIn('error', other) self.assertEqual(other['error']['name'], 'NOT_FOUND') - claims_from_es = (await self.conductor.spv_node.server.bp.db.search_index.search(name=name))[0] + claims_from_es = (await self.conductor.spv_node.server.session_manager.search_index.search(name=name))[0] claims_from_es = [c['claim_hash'][::-1].hex() for c in claims_from_es] self.assertNotIn(claim_id, claims_from_es) else: - claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search(claim_id=claim_id) + claim_from_es = await self.conductor.spv_node.server.session_manager.search_index.search(claim_id=claim_id) self.assertEqual(claim_id, other['claim_id']) self.assertEqual(claim_id, claim_from_es[0][0]['claim_hash'][::-1].hex()) async def assertNoClaimForName(self, name: str): - lbrycrd_winning = json.loads(await self.blockchain._cli_cmnd('getvalueforname', name)) - stream, channel, _, _ = await self.conductor.spv_node.server.bp.db.resolve(name) - self.assertNotIn('claimId', lbrycrd_winning) + lbrycrd_winning = json.loads(await self.blockchain._cli_cmnd('getclaimsforname', name)) + stream, channel, _, _ = await self.conductor.spv_node.server.db.resolve(name) + if 'claims' in lbrycrd_winning and lbrycrd_winning['claims'] is not None: + self.assertEqual(len(lbrycrd_winning['claims']), 0) if stream is not None: self.assertIsInstance(stream, LookupError) else: self.assertIsInstance(channel, LookupError) - claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search(name=name) + claim_from_es = await self.conductor.spv_node.server.session_manager.search_index.search(name=name) self.assertListEqual([], claim_from_es[0]) - async def assertNoClaim(self, claim_id: str): - self.assertDictEqual( - {}, json.loads(await self.blockchain._cli_cmnd('getclaimbyid', claim_id)) - ) - claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search(claim_id=claim_id) + async def assertNoClaim(self, name: str, claim_id: str): + expected = json.loads(await self.blockchain._cli_cmnd('getclaimsfornamebyid', name, '["' + claim_id + '"]')) + if 'claims' in expected and expected['claims'] is not None: + # ensure that if we do have the matching claim that it is not active + self.assertEqual(expected['claims'][0]['effectiveamount'], 0) + + claim_from_es = await self.conductor.spv_node.server.session_manager.search_index.search(claim_id=claim_id) self.assertListEqual([], claim_from_es[0]) - claim = await self.conductor.spv_node.server.bp.db.fs_getclaimbyid(claim_id) + claim = await self.conductor.spv_node.server.db.fs_getclaimbyid(claim_id) self.assertIsNone(claim) async def assertMatchWinningClaim(self, name): - expected = json.loads(await self.blockchain._cli_cmnd('getvalueforname', name)) - stream, channel, _, _ = await self.conductor.spv_node.server.bp.db.resolve(name) + expected = json.loads(await self.blockchain._cli_cmnd('getclaimsfornamebybid', name, "[0]")) + stream, channel, _, _ = await self.conductor.spv_node.server.db.resolve(name) claim = stream if stream else channel - await self._assertMatchClaim(expected, claim) + expected['claims'][0]['lasttakeoverheight'] = expected['lasttakeoverheight'] + await self._assertMatchClaim(expected['claims'][0], claim) return claim async def _assertMatchClaim(self, expected, claim): self.assertMatchDBClaim(expected, claim) - claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search( + claim_from_es = await self.conductor.spv_node.server.session_manager.search_index.search( claim_id=claim.claim_hash.hex() ) self.assertEqual(len(claim_from_es[0]), 1) self.assertMatchESClaim(claim_from_es[0][0], claim) - self._check_supports(claim.claim_hash.hex(), expected['supports'], claim_from_es[0][0]['support_amount']) + self._check_supports(claim.claim_hash.hex(), expected.get('supports', []), + claim_from_es[0][0]['support_amount']) - async def assertMatchClaim(self, claim_id, is_active_in_lbrycrd=True): - expected = json.loads(await self.blockchain._cli_cmnd('getclaimbyid', claim_id)) - claim = await self.conductor.spv_node.server.bp.db.fs_getclaimbyid(claim_id) - if is_active_in_lbrycrd: - if not expected: - self.assertIsNone(claim) - return - self.assertMatchDBClaim(expected, claim) - else: - self.assertDictEqual({}, expected) - claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search( + async def assertMatchClaim(self, name, claim_id, is_active_in_lbrycrd=True): + claim = await self.conductor.spv_node.server.db.fs_getclaimbyid(claim_id) + claim_from_es = await self.conductor.spv_node.server.session_manager.search_index.search( claim_id=claim.claim_hash.hex() ) self.assertEqual(len(claim_from_es[0]), 1) self.assertEqual(claim_from_es[0][0]['claim_hash'][::-1].hex(), claim.claim_hash.hex()) self.assertMatchESClaim(claim_from_es[0][0], claim) - self._check_supports( - claim.claim_hash.hex(), expected.get('supports', []), claim_from_es[0][0]['support_amount'], - is_active_in_lbrycrd - ) + + expected = json.loads(await self.blockchain._cli_cmnd('getclaimsfornamebyid', name, '["' + claim_id + '"]')) + if is_active_in_lbrycrd: + if not expected: + self.assertIsNone(claim) + return + expected['claims'][0]['lasttakeoverheight'] = expected['lasttakeoverheight'] + self.assertMatchDBClaim(expected['claims'][0], claim) + self._check_supports(claim.claim_hash.hex(), expected['claims'][0].get('supports', []), + claim_from_es[0][0]['support_amount']) + else: + if 'claims' in expected and expected['claims'] is not None: + # ensure that if we do have the matching claim that it is not active + self.assertEqual(expected['claims'][0]['effectiveamount'], 0) return claim async def assertMatchClaimIsWinning(self, name, claim_id): self.assertEqual(claim_id, (await self.assertMatchWinningClaim(name)).claim_hash.hex()) await self.assertMatchClaimsForName(name) - def _check_supports(self, claim_id, lbrycrd_supports, es_support_amount, is_active_in_lbrycrd=True): - total_amount = 0 - db = self.conductor.spv_node.server.bp.db + def _check_supports(self, claim_id, lbrycrd_supports, es_support_amount): + total_lbrycrd_amount = 0.0 + total_es_amount = 0.0 + active_es_amount = 0.0 + db = self.conductor.spv_node.server.db + es_supports = db.get_supports(bytes.fromhex(claim_id)) - for i, (tx_num, position, amount) in enumerate(db.get_supports(bytes.fromhex(claim_id))): - total_amount += amount - if is_active_in_lbrycrd: - support = lbrycrd_supports[i] - self.assertEqual(support['txId'], db.prefix_db.tx_hash.get(tx_num, deserialize_value=False)[::-1].hex()) - self.assertEqual(support['n'], position) - self.assertEqual(support['height'], bisect_right(db.tx_counts, tx_num)) - self.assertEqual(support['validAtHeight'], db.get_activation(tx_num, position, is_support=True)) - self.assertEqual(total_amount, es_support_amount, f"lbrycrd support amount: {total_amount} vs es: {es_support_amount}") + # we're only concerned about active supports here, and they should match + self.assertTrue(len(es_supports) >= len(lbrycrd_supports)) + + for i, (tx_num, position, amount) in enumerate(es_supports): + total_es_amount += amount + valid_height = db.get_activation(tx_num, position, is_support=True) + if valid_height > db.db_height: + continue + active_es_amount += amount + txid = db.prefix_db.tx_hash.get(tx_num, deserialize_value=False)[::-1].hex() + support = next(filter(lambda s: s['txid'] == txid and s['n'] == position, lbrycrd_supports)) + total_lbrycrd_amount += support['amount'] + self.assertEqual(support['height'], bisect_right(db.tx_counts, tx_num)) + self.assertEqual(support['validatheight'], valid_height) + + self.assertEqual(total_es_amount, es_support_amount) + self.assertEqual(active_es_amount, total_lbrycrd_amount) async def assertMatchClaimsForName(self, name): - expected = json.loads(await self.blockchain._cli_cmnd('getclaimsforname', name)) - - db = self.conductor.spv_node.server.bp.db - # self.assertEqual(len(expected['claims']), len(db_claims.claims)) - # self.assertEqual(expected['lastTakeoverHeight'], db_claims.lastTakeoverHeight) - last_takeover = json.loads(await self.blockchain._cli_cmnd('getvalueforname', name))['lastTakeoverHeight'] + expected = json.loads(await self.blockchain._cli_cmnd('getclaimsforname', name, "", "true")) + db = self.conductor.spv_node.server.db for c in expected['claims']: - c['lastTakeoverHeight'] = last_takeover - claim_id = c['claimId'] + c['lasttakeoverheight'] = expected['lasttakeoverheight'] + claim_id = c['claimid'] claim_hash = bytes.fromhex(claim_id) claim = db._fs_get_claim_by_hash(claim_hash) self.assertMatchDBClaim(c, claim) - claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search( - claim_id=c['claimId'] + claim_from_es = await self.conductor.spv_node.server.session_manager.search_index.search( + claim_id=claim_id ) self.assertEqual(len(claim_from_es[0]), 1) - self.assertEqual(claim_from_es[0][0]['claim_hash'][::-1].hex(), c['claimId']) + self.assertEqual(claim_from_es[0][0]['claim_hash'][::-1].hex(), claim_id) self.assertMatchESClaim(claim_from_es[0][0], claim) - self._check_supports(c['claimId'], c['supports'], claim_from_es[0][0]['support_amount']) + self._check_supports(claim_id, c.get('supports', []), + claim_from_es[0][0]['support_amount']) + + async def assertNameState(self, height: int, name: str, winning_claim_id: str, last_takeover_height: int, + non_winning_claims: List[ClaimStateValue]): + self.assertEqual(height, self.conductor.spv_node.server.db.db_height) + await self.assertMatchClaimIsWinning(name, winning_claim_id) + for non_winning in non_winning_claims: + claim = await self.assertMatchClaim( + name, non_winning.claim_id, is_active_in_lbrycrd=non_winning.active_in_lbrycrd + ) + self.assertEqual(non_winning.activation_height, claim.activation_height) + self.assertEqual(last_takeover_height, claim.last_takeover_height) class ResolveCommand(BaseResolveTestCase): @@ -261,19 +288,20 @@ class ResolveCommand(BaseResolveTestCase): tx_details = await self.blockchain.get_raw_transaction(claim['txid']) self.assertEqual(claim['confirmations'], json.loads(tx_details)['confirmations']) + # FIXME : claimname/updateclaim is gone. #3480 wip, unblock #3479" # resolve handles invalid data - await self.blockchain_claim_name("gibberish", hexlify(b"{'invalid':'json'}").decode(), "0.1") - await self.generate(1) - response = await self.out(self.daemon.jsonrpc_resolve("lbry://gibberish")) - self.assertSetEqual({'lbry://gibberish'}, set(response)) - claim = response['lbry://gibberish'] - self.assertEqual(claim['name'], 'gibberish') - self.assertNotIn('value', claim) + # await self.blockchain_claim_name("gibberish", hexlify(b"{'invalid':'json'}").decode(), "0.1") + # await self.generate(1) + # response = await self.out(self.daemon.jsonrpc_resolve("lbry://gibberish")) + # self.assertSetEqual({'lbry://gibberish'}, set(response)) + # claim = response['lbry://gibberish'] + # self.assertEqual(claim['name'], 'gibberish') + # self.assertNotIn('value', claim) # resolve retries await self.conductor.spv_node.stop() resolve_task = asyncio.create_task(self.resolve('foo')) - await self.conductor.spv_node.start(self.conductor.blockchain_node) + await self.conductor.spv_node.start(self.conductor.lbcwallet_node) self.assertIsNotNone((await resolve_task)['claim_id']) async def test_winning_by_effective_amount(self): @@ -443,16 +471,16 @@ class ResolveCommand(BaseResolveTestCase): self.assertEqual(one, claim6['name']) async def test_resolve_old_claim(self): - channel = await self.daemon.jsonrpc_channel_create('@olds', '1.0') + channel = await self.daemon.jsonrpc_channel_create('@olds', '1.0', blocking=True) await self.confirm_tx(channel.id) address = channel.outputs[0].get_address(self.account.ledger) claim = generate_signed_legacy(address, channel.outputs[0]) tx = await Transaction.claim_create('example', claim.SerializeToString(), 1, address, [self.account], self.account) await tx.sign([self.account]) - await self.broadcast(tx) - await self.confirm_tx(tx.id) + await self.broadcast_and_confirm(tx) response = await self.resolve('@olds/example') + self.assertTrue('is_channel_signature_valid' in response, str(response)) self.assertTrue(response['is_channel_signature_valid']) claim.publisherSignature.signature = bytes(reversed(claim.publisherSignature.signature)) @@ -460,8 +488,7 @@ class ResolveCommand(BaseResolveTestCase): 'bad_example', claim.SerializeToString(), 1, address, [self.account], self.account ) await tx.sign([self.account]) - await self.broadcast(tx) - await self.confirm_tx(tx.id) + await self.broadcast_and_confirm(tx) response = await self.resolve('bad_example') self.assertFalse(response['is_channel_signature_valid']) @@ -606,6 +633,12 @@ class ResolveClaimTakeovers(BaseResolveTestCase): self.assertDictEqual(await self.resolve('@other/signed4'), await self.resolve('signed4')) + self.assertEqual(2, len(await self.claim_search(channel_ids=[channel_id2]))) + + await self.channel_update(channel_id2) + await make_claim('third_signed', '0.01', channel_id=channel_id2) + self.assertEqual(3, len(await self.claim_search(channel_ids=[channel_id2]))) + async def _test_activation_delay(self): name = 'derp' # initially claim the name @@ -643,10 +676,10 @@ class ResolveClaimTakeovers(BaseResolveTestCase): async def assertNameState(self, height: int, name: str, winning_claim_id: str, last_takeover_height: int, non_winning_claims: List[ClaimStateValue]): - self.assertEqual(height, self.conductor.spv_node.server.bp.db.db_height) + self.assertEqual(height, self.conductor.spv_node.server.db.db_height) await self.assertMatchClaimIsWinning(name, winning_claim_id) for non_winning in non_winning_claims: - claim = await self.assertMatchClaim( + claim = await self.assertMatchClaim(name, non_winning.claim_id, is_active_in_lbrycrd=non_winning.active_in_lbrycrd ) self.assertEqual(non_winning.activation_height, claim.activation_height) @@ -961,7 +994,7 @@ class ResolveClaimTakeovers(BaseResolveTestCase): ) greater_than_or_equal_to_zero = [ claim['claim_id'] for claim in ( - await self.conductor.spv_node.server.bp.db.search_index.search( + await self.conductor.spv_node.server.session_manager.search_index.search( channel_id=channel_id, fee_amount=">=0" ))[0] ] @@ -969,7 +1002,7 @@ class ResolveClaimTakeovers(BaseResolveTestCase): self.assertSetEqual(set(greater_than_or_equal_to_zero), {stream_with_no_fee, stream_with_fee}) greater_than_zero = [ claim['claim_id'] for claim in ( - await self.conductor.spv_node.server.bp.db.search_index.search( + await self.conductor.spv_node.server.session_manager.search_index.search( channel_id=channel_id, fee_amount=">0" ))[0] ] @@ -977,7 +1010,7 @@ class ResolveClaimTakeovers(BaseResolveTestCase): self.assertSetEqual(set(greater_than_zero), {stream_with_fee}) equal_to_zero = [ claim['claim_id'] for claim in ( - await self.conductor.spv_node.server.bp.db.search_index.search( + await self.conductor.spv_node.server.session_manager.search_index.search( channel_id=channel_id, fee_amount="<=0" ))[0] ] @@ -992,10 +1025,10 @@ class ResolveClaimTakeovers(BaseResolveTestCase): name = 'test' await self.generate(494) address = (await self.account.receiving.get_addresses(True))[0] - await self.blockchain.send_to_address(address, 400.0) + await self.send_to_address_and_wait(address, 400.0) await self.account.ledger.on_address.first await self.generate(100) - self.assertEqual(800, self.conductor.spv_node.server.bp.db.db_height) + self.assertEqual(800, self.conductor.spv_node.server.db.db_height) # Block 801: Claim A for 10 LBC is accepted. # It is the first claim, so it immediately becomes active and controlling. @@ -1007,10 +1040,10 @@ class ResolveClaimTakeovers(BaseResolveTestCase): # Its activation height is 1121 + min(4032, floor((1121-801) / 32)) = 1121 + 10 = 1131. # State: A(10) is controlling, B(20) is accepted. await self.generate(32 * 10 - 1) - self.assertEqual(1120, self.conductor.spv_node.server.bp.db.db_height) + self.assertEqual(1120, self.conductor.spv_node.server.db.db_height) claim_id_B = (await self.stream_create(name, '20.0', allow_duplicate_name=True))['outputs'][0]['claim_id'] - claim_B, _, _, _ = await self.conductor.spv_node.server.bp.db.resolve(f"{name}:{claim_id_B}") - self.assertEqual(1121, self.conductor.spv_node.server.bp.db.db_height) + claim_B, _, _, _ = await self.conductor.spv_node.server.db.resolve(f"{name}:{claim_id_B}") + self.assertEqual(1121, self.conductor.spv_node.server.db.db_height) self.assertEqual(1131, claim_B.activation_height) await self.assertMatchClaimIsWinning(name, claim_id_A) @@ -1018,33 +1051,33 @@ class ResolveClaimTakeovers(BaseResolveTestCase): # Since it is a support for the controlling claim, it activates immediately. # State: A(10+14) is controlling, B(20) is accepted. await self.support_create(claim_id_A, bid='14.0') - self.assertEqual(1122, self.conductor.spv_node.server.bp.db.db_height) + self.assertEqual(1122, self.conductor.spv_node.server.db.db_height) await self.assertMatchClaimIsWinning(name, claim_id_A) # Block 1123: Claim C for 50 LBC is accepted. # The activation height is 1123 + min(4032, floor((1123-801) / 32)) = 1123 + 10 = 1133. # State: A(10+14) is controlling, B(20) is accepted, C(50) is accepted. claim_id_C = (await self.stream_create(name, '50.0', allow_duplicate_name=True))['outputs'][0]['claim_id'] - self.assertEqual(1123, self.conductor.spv_node.server.bp.db.db_height) - claim_C, _, _, _ = await self.conductor.spv_node.server.bp.db.resolve(f"{name}:{claim_id_C}") + self.assertEqual(1123, self.conductor.spv_node.server.db.db_height) + claim_C, _, _, _ = await self.conductor.spv_node.server.db.resolve(f"{name}:{claim_id_C}") self.assertEqual(1133, claim_C.activation_height) await self.assertMatchClaimIsWinning(name, claim_id_A) await self.generate(7) - self.assertEqual(1130, self.conductor.spv_node.server.bp.db.db_height) + self.assertEqual(1130, self.conductor.spv_node.server.db.db_height) await self.assertMatchClaimIsWinning(name, claim_id_A) await self.generate(1) # Block 1131: Claim B activates. It has 20 LBC, while claim A has 24 LBC (10 original + 14 from support X). There is no takeover, and claim A remains controlling. # State: A(10+14) is controlling, B(20) is active, C(50) is accepted. - self.assertEqual(1131, self.conductor.spv_node.server.bp.db.db_height) + self.assertEqual(1131, self.conductor.spv_node.server.db.db_height) await self.assertMatchClaimIsWinning(name, claim_id_A) # Block 1132: Claim D for 300 LBC is accepted. The activation height is 1132 + min(4032, floor((1132-801) / 32)) = 1132 + 10 = 1142. # State: A(10+14) is controlling, B(20) is active, C(50) is accepted, D(300) is accepted. claim_id_D = (await self.stream_create(name, '300.0', allow_duplicate_name=True))['outputs'][0]['claim_id'] - self.assertEqual(1132, self.conductor.spv_node.server.bp.db.db_height) - claim_D, _, _, _ = await self.conductor.spv_node.server.bp.db.resolve(f"{name}:{claim_id_D}") + self.assertEqual(1132, self.conductor.spv_node.server.db.db_height) + claim_D, _, _, _ = await self.conductor.spv_node.server.db.resolve(f"{name}:{claim_id_D}") self.assertEqual(False, claim_D.is_controlling) self.assertEqual(801, claim_D.last_takeover_height) self.assertEqual(1142, claim_D.activation_height) @@ -1053,8 +1086,8 @@ class ResolveClaimTakeovers(BaseResolveTestCase): # Block 1133: Claim C activates. It has 50 LBC, while claim A has 24 LBC, so a takeover is initiated. The takeover height for this name is set to 1133, and therefore the activation delay for all the claims becomes min(4032, floor((1133-1133) / 32)) = 0. All the claims become active. The totals for each claim are recalculated, and claim D becomes controlling because it has the highest total. # State: A(10+14) is active, B(20) is active, C(50) is active, D(300) is controlling await self.generate(1) - self.assertEqual(1133, self.conductor.spv_node.server.bp.db.db_height) - claim_D, _, _, _ = await self.conductor.spv_node.server.bp.db.resolve(f"{name}:{claim_id_D}") + self.assertEqual(1133, self.conductor.spv_node.server.db.db_height) + claim_D, _, _, _ = await self.conductor.spv_node.server.db.resolve(f"{name}:{claim_id_D}") self.assertEqual(True, claim_D.is_controlling) self.assertEqual(1133, claim_D.last_takeover_height) self.assertEqual(1133, claim_D.activation_height) @@ -1327,15 +1360,15 @@ class ResolveClaimTakeovers(BaseResolveTestCase): await self.generate(8) await self.assertMatchClaimIsWinning(name, first_claim_id) # abandon the support that causes the winning claim to have the highest staked - tx = await self.daemon.jsonrpc_txo_spend(type='support', txid=controlling_support_tx.id) + tx = await self.daemon.jsonrpc_txo_spend(type='support', txid=controlling_support_tx.id, blocking=True) await self.generate(1) - await self.assertMatchClaimIsWinning(name, first_claim_id) - # await self.assertMatchClaim(second_claim_id) - + await self.assertNameState(538, name, first_claim_id, last_takeover_height=207, non_winning_claims=[ + ClaimStateValue(second_claim_id, activation_height=539, active_in_lbrycrd=False) + ]) await self.generate(1) - - await self.assertMatchClaim(first_claim_id) - await self.assertMatchClaimIsWinning(name, second_claim_id) + await self.assertNameState(539, name, second_claim_id, last_takeover_height=539, non_winning_claims=[ + ClaimStateValue(first_claim_id, activation_height=207, active_in_lbrycrd=True) + ]) async def test_remove_controlling_support(self): name = 'derp' @@ -1405,14 +1438,14 @@ class ResolveClaimTakeovers(BaseResolveTestCase): await self.generate(32) second_claim_id = (await self.stream_create(name, '0.01', allow_duplicate_name=True))['outputs'][0]['claim_id'] - await self.assertNoClaim(second_claim_id) + await self.assertNoClaim(name, second_claim_id) self.assertEqual( - len((await self.conductor.spv_node.server.bp.db.search_index.search(claim_name=name))[0]), 1 + len((await self.conductor.spv_node.server.session_manager.search_index.search(claim_name=name))[0]), 1 ) await self.generate(1) - await self.assertMatchClaim(second_claim_id) + await self.assertMatchClaim(name, second_claim_id) self.assertEqual( - len((await self.conductor.spv_node.server.bp.db.search_index.search(claim_name=name))[0]), 2 + len((await self.conductor.spv_node.server.session_manager.search_index.search(claim_name=name))[0]), 2 ) async def test_abandon_controlling_same_block_as_new_claim(self): @@ -1428,35 +1461,47 @@ class ResolveClaimTakeovers(BaseResolveTestCase): async def test_trending(self): async def get_trending_score(claim_id): - return (await self.conductor.spv_node.server.bp.db.search_index.search( + return (await self.conductor.spv_node.server.session_manager.search_index.search( claim_id=claim_id ))[0][0]['trending_score'] claim_id1 = (await self.stream_create('derp', '1.0'))['outputs'][0]['claim_id'] - COIN = 1E8 + COIN = int(1E8) - height = 99000 - self.conductor.spv_node.server.bp._add_claim_activation_change_notification( - claim_id1, height, 0, 10 * COIN + self.assertEqual(self.conductor.spv_node.writer.height, 207) + self.conductor.spv_node.writer.db.prefix_db.trending_notification.stage_put( + (208, bytes.fromhex(claim_id1)), (0, 10 * COIN) ) await self.generate(1) - self.assertEqual(172.64252836433135, await get_trending_score(claim_id1)) - self.conductor.spv_node.server.bp._add_claim_activation_change_notification( - claim_id1, height + 1, 10 * COIN, 100 * COIN + self.assertEqual(self.conductor.spv_node.writer.height, 208) + + self.assertEqual(1.7090807854206793, await get_trending_score(claim_id1)) + self.conductor.spv_node.writer.db.prefix_db.trending_notification.stage_put( + (209, bytes.fromhex(claim_id1)), (10 * COIN, 100 * COIN) ) await self.generate(1) - self.assertEqual(173.45931832928875, await get_trending_score(claim_id1)) - self.conductor.spv_node.server.bp._add_claim_activation_change_notification( - claim_id1, height + 100, 100 * COIN, 1000000 * COIN + self.assertEqual(self.conductor.spv_node.writer.height, 209) + self.assertEqual(2.2437974397778886, await get_trending_score(claim_id1)) + self.conductor.spv_node.writer.db.prefix_db.trending_notification.stage_put( + (309, bytes.fromhex(claim_id1)), (100 * COIN, 1000000 * COIN) ) - await self.generate(1) - self.assertEqual(176.65517070393514, await get_trending_score(claim_id1)) - self.conductor.spv_node.server.bp._add_claim_activation_change_notification( - claim_id1, height + 200, 1000000 * COIN, 1 * COIN + await self.generate(100) + self.assertEqual(self.conductor.spv_node.writer.height, 309) + self.assertEqual(5.157053472135866, await get_trending_score(claim_id1)) + + self.conductor.spv_node.writer.db.prefix_db.trending_notification.stage_put( + (409, bytes.fromhex(claim_id1)), (1000000 * COIN, 1 * COIN) ) + + await self.generate(99) + self.assertEqual(self.conductor.spv_node.writer.height, 408) + self.assertEqual(5.157053472135866, await get_trending_score(claim_id1)) + await self.generate(1) - self.assertEqual(-174.951347102643, await get_trending_score(claim_id1)) - search_results = (await self.conductor.spv_node.server.bp.db.search_index.search(claim_name="derp"))[0] + self.assertEqual(self.conductor.spv_node.writer.height, 409) + + self.assertEqual(-3.4256156592205627, await get_trending_score(claim_id1)) + search_results = (await self.conductor.spv_node.server.session_manager.search_index.search(claim_name="derp"))[0] self.assertEqual(1, len(search_results)) self.assertListEqual([claim_id1], [c['claim_id'] for c in search_results]) @@ -1465,22 +1510,31 @@ class ResolveAfterReorg(BaseResolveTestCase): async def reorg(self, start): blocks = self.ledger.headers.height - start self.blockchain.block_expected = start - 1 + + + prepare = self.ledger.on_header.where(self.blockchain.is_expected_block) + self.conductor.spv_node.server.synchronized.clear() + # go back to start await self.blockchain.invalidate_block((await self.ledger.headers.hash(start)).decode()) # go to previous + 1 - await self.generate(blocks + 2) + await self.blockchain.generate(blocks + 2) + + await prepare # no guarantee that it didn't happen already, so start waiting from before calling generate + await self.conductor.spv_node.server.synchronized.wait() + # await asyncio.wait_for(self.on_header(self.blockchain.block_expected), 30.0) async def assertBlockHash(self, height): - bp = self.conductor.spv_node.server.bp + reader_db = self.conductor.spv_node.server.db block_hash = await self.blockchain.get_block_hash(height) self.assertEqual(block_hash, (await self.ledger.headers.hash(height)).decode()) - self.assertEqual(block_hash, (await bp.db.fs_block_hashes(height, 1))[0][::-1].hex()) + self.assertEqual(block_hash, (await reader_db.fs_block_hashes(height, 1))[0][::-1].hex()) txids = [ - tx_hash[::-1].hex() for tx_hash in bp.db.get_block_txs(height) + tx_hash[::-1].hex() for tx_hash in reader_db.get_block_txs(height) ] - txs = await bp.db.get_transactions_and_merkles(txids) - block_txs = (await bp.daemon.deserialised_block(block_hash))['tx'] + txs = await reader_db.get_transactions_and_merkles(txids) + block_txs = (await self.conductor.spv_node.server.daemon.deserialised_block(block_hash))['tx'] self.assertSetEqual(set(block_txs), set(txs.keys()), msg='leveldb/lbrycrd is missing transactions') self.assertListEqual(block_txs, list(txs.keys()), msg='leveldb/lbrycrd transactions are of order') @@ -1491,9 +1545,18 @@ class ResolveAfterReorg(BaseResolveTestCase): channel_id = self.get_claim_id( await self.channel_create(channel_name, '0.01') ) - self.assertEqual(channel_id, (await self.assertMatchWinningClaim(channel_name)).claim_hash.hex()) + + await self.assertNameState( + height=207, name='@abc', winning_claim_id=channel_id, last_takeover_height=207, + non_winning_claims=[] + ) + await self.reorg(206) - self.assertEqual(channel_id, (await self.assertMatchWinningClaim(channel_name)).claim_hash.hex()) + + await self.assertNameState( + height=208, name='@abc', winning_claim_id=channel_id, last_takeover_height=207, + non_winning_claims=[] + ) # await self.assertNoClaimForName(channel_name) # self.assertNotIn('error', await self.resolve(channel_name)) @@ -1502,16 +1565,29 @@ class ResolveAfterReorg(BaseResolveTestCase): stream_id = self.get_claim_id( await self.stream_create(stream_name, '0.01', channel_id=channel_id) ) - self.assertEqual(stream_id, (await self.assertMatchWinningClaim(stream_name)).claim_hash.hex()) + + await self.assertNameState( + height=209, name=stream_name, winning_claim_id=stream_id, last_takeover_height=209, + non_winning_claims=[] + ) await self.reorg(206) - self.assertEqual(stream_id, (await self.assertMatchWinningClaim(stream_name)).claim_hash.hex()) + await self.assertNameState( + height=210, name=stream_name, winning_claim_id=stream_id, last_takeover_height=209, + non_winning_claims=[] + ) await self.support_create(stream_id, '0.01') - self.assertNotIn('error', await self.resolve(stream_name)) - self.assertEqual(stream_id, (await self.assertMatchWinningClaim(stream_name)).claim_hash.hex()) + + await self.assertNameState( + height=211, name=stream_name, winning_claim_id=stream_id, last_takeover_height=209, + non_winning_claims=[] + ) await self.reorg(206) # self.assertNotIn('error', await self.resolve(stream_name)) - self.assertEqual(stream_id, (await self.assertMatchWinningClaim(stream_name)).claim_hash.hex()) + await self.assertNameState( + height=212, name=stream_name, winning_claim_id=stream_id, last_takeover_height=209, + non_winning_claims=[] + ) await self.stream_abandon(stream_id) self.assertNotIn('error', await self.resolve(channel_name)) @@ -1553,7 +1629,6 @@ class ResolveAfterReorg(BaseResolveTestCase): await self.ledger.wait(broadcast_tx) await self.support_create(still_valid.outputs[0].claim_id, '0.01') - # await self.generate(1) await self.ledger.wait(broadcast_tx, self.blockchain.block_expected) self.assertEqual(self.ledger.headers.height, 208) await self.assertBlockHash(208) @@ -1570,7 +1645,7 @@ class ResolveAfterReorg(BaseResolveTestCase): # reorg the last block dropping our claim tx await self.blockchain.invalidate_block(invalidated_block_hash) - await self.blockchain.clear_mempool() + await self.conductor.clear_mempool() await self.blockchain.generate(2) # wait for the client to catch up and verify the reorg @@ -1603,7 +1678,7 @@ class ResolveAfterReorg(BaseResolveTestCase): await self.blockchain.generate(1) # wait for the client to catch up - await asyncio.wait_for(self.on_header(210), 1.0) + await asyncio.wait_for(self.on_header(210), 3.0) # verify the claim is in the new block and that it is returned by claim_search republished = await self.resolve('hovercraft') @@ -1649,11 +1724,11 @@ class ResolveAfterReorg(BaseResolveTestCase): # reorg the last block dropping our claim tx await self.blockchain.invalidate_block(invalidated_block_hash) - await self.blockchain.clear_mempool() + await self.conductor.clear_mempool() await self.blockchain.generate(2) # wait for the client to catch up and verify the reorg - await asyncio.wait_for(self.on_header(209), 3.0) + await asyncio.wait_for(self.on_header(209), 30.0) await self.assertBlockHash(207) await self.assertBlockHash(208) await self.assertBlockHash(209) diff --git a/tests/integration/transactions/test_internal_transaction_api.py b/tests/integration/transactions/test_internal_transaction_api.py index 142009ba4..04817b2be 100644 --- a/tests/integration/transactions/test_internal_transaction_api.py +++ b/tests/integration/transactions/test_internal_transaction_api.py @@ -21,9 +21,8 @@ class BasicTransactionTest(IntegrationTestCase): [asyncio.ensure_future(self.on_address_update(address1)), asyncio.ensure_future(self.on_address_update(address2))] )) - sendtxid1 = await self.blockchain.send_to_address(address1, 5) - sendtxid2 = await self.blockchain.send_to_address(address2, 5) - await self.blockchain.generate(1) + await self.send_to_address_and_wait(address1, 5) + await self.send_to_address_and_wait(address2, 5, 1) await notifications self.assertEqual(d2l(await self.account.get_balance()), '10.0') @@ -57,7 +56,7 @@ class BasicTransactionTest(IntegrationTestCase): notifications = asyncio.create_task(asyncio.wait( [asyncio.ensure_future(self.ledger.wait(channel_tx)), asyncio.ensure_future(self.ledger.wait(stream_tx))] )) - await self.blockchain.generate(1) + await self.generate(1) await notifications self.assertEqual(d2l(await self.account.get_balance()), '7.985786') self.assertEqual(d2l(await self.account.get_balance(include_claims=True)), '9.985786') @@ -70,7 +69,7 @@ class BasicTransactionTest(IntegrationTestCase): await self.broadcast(abandon_tx) await notify notify = asyncio.create_task(self.ledger.wait(abandon_tx)) - await self.blockchain.generate(1) + await self.generate(1) await notify response = await self.ledger.resolve([], ['lbry://@bar/foo']) diff --git a/tests/integration/transactions/test_transaction_commands.py b/tests/integration/transactions/test_transaction_commands.py index d85641847..8252d4bb1 100644 --- a/tests/integration/transactions/test_transaction_commands.py +++ b/tests/integration/transactions/test_transaction_commands.py @@ -1,3 +1,5 @@ +import unittest + from lbry.testcase import CommandTestCase @@ -17,7 +19,7 @@ class TransactionCommandsTestCase(CommandTestCase): async def test_transaction_show(self): # local tx result = await self.out(self.daemon.jsonrpc_account_send( - '5.0', await self.daemon.jsonrpc_address_unused(self.account.id) + '5.0', await self.daemon.jsonrpc_address_unused(self.account.id), blocking=True )) await self.confirm_tx(result['txid']) tx = await self.daemon.jsonrpc_transaction_show(result['txid']) @@ -38,10 +40,9 @@ class TransactionCommandsTestCase(CommandTestCase): self.assertFalse(result['success']) async def test_utxo_release(self): - sendtxid = await self.blockchain.send_to_address( - await self.account.receiving.get_or_create_usable_address(), 1 + await self.send_to_address_and_wait( + await self.account.receiving.get_or_create_usable_address(), 1, 1 ) - await self.confirm_tx(sendtxid) await self.assertBalance(self.account, '11.0') await self.ledger.reserve_outputs(await self.account.get_utxos()) await self.assertBalance(self.account, '0.0') @@ -51,6 +52,7 @@ class TransactionCommandsTestCase(CommandTestCase): class TestSegwit(CommandTestCase): + @unittest.SkipTest async def test_segwit(self): p2sh_address1 = await self.blockchain.get_new_address(self.blockchain.P2SH_SEGWIT_ADDRESS) p2sh_address2 = await self.blockchain.get_new_address(self.blockchain.P2SH_SEGWIT_ADDRESS) @@ -64,14 +66,13 @@ class TestSegwit(CommandTestCase): p2sh_txid2 = await self.blockchain.send_to_address(p2sh_address2, '1.0') bech32_txid1 = await self.blockchain.send_to_address(bech32_address1, '1.0') bech32_txid2 = await self.blockchain.send_to_address(bech32_address2, '1.0') - await self.generate(1) # P2SH & BECH32 can pay to P2SH address tx = await self.blockchain.create_raw_transaction([ {"txid": p2sh_txid1, "vout": 0}, {"txid": bech32_txid1, "vout": 0}, - ], [{p2sh_address3: '1.9'}] + ], {p2sh_address3: 1.9} ) tx = await self.blockchain.sign_raw_transaction_with_wallet(tx) p2sh_txid3 = await self.blockchain.send_raw_transaction(tx) @@ -82,7 +83,7 @@ class TestSegwit(CommandTestCase): tx = await self.blockchain.create_raw_transaction([ {"txid": p2sh_txid2, "vout": 0}, {"txid": bech32_txid2, "vout": 0}, - ], [{bech32_address3: '1.9'}] + ], {bech32_address3: 1.9} ) tx = await self.blockchain.sign_raw_transaction_with_wallet(tx) bech32_txid3 = await self.blockchain.send_raw_transaction(tx) @@ -94,12 +95,9 @@ class TestSegwit(CommandTestCase): tx = await self.blockchain.create_raw_transaction([ {"txid": p2sh_txid3, "vout": 0}, {"txid": bech32_txid3, "vout": 0}, - ], [{address: '3.5'}] + ], {address: 3.5} ) tx = await self.blockchain.sign_raw_transaction_with_wallet(tx) txid = await self.blockchain.send_raw_transaction(tx) - await self.on_transaction_id(txid) - await self.generate(1) - await self.on_transaction_id(txid) - + await self.generate_and_wait(1, [txid]) await self.assertBalance(self.account, '13.5') diff --git a/tests/integration/transactions/test_transactions.py b/tests/integration/transactions/test_transactions.py index fea0b18fb..9e3b49056 100644 --- a/tests/integration/transactions/test_transactions.py +++ b/tests/integration/transactions/test_transactions.py @@ -1,7 +1,7 @@ import asyncio import random -from itertools import chain +import lbry.wallet.rpc.jsonrpc from lbry.wallet.transaction import Transaction, Output, Input from lbry.testcase import IntegrationTestCase from lbry.wallet.util import satoshis_to_coins, coins_to_satoshis @@ -9,9 +9,8 @@ from lbry.wallet.manager import WalletManager class BasicTransactionTests(IntegrationTestCase): - async def test_variety_of_transactions_and_longish_history(self): - await self.blockchain.generate(300) + await self.generate(300) await self.assertBalance(self.account, '0.0') addresses = await self.account.receiving.get_addresses() @@ -19,10 +18,10 @@ class BasicTransactionTests(IntegrationTestCase): # to the 10th receiving address for a total of 30 UTXOs on the entire account for i in range(10): notification = asyncio.ensure_future(self.on_address_update(addresses[i])) - txid = await self.blockchain.send_to_address(addresses[i], 10) + _ = await self.send_to_address_and_wait(addresses[i], 10) await notification notification = asyncio.ensure_future(self.on_address_update(addresses[9])) - txid = await self.blockchain.send_to_address(addresses[9], 10) + _ = await self.send_to_address_and_wait(addresses[9], 10) await notification # use batching to reduce issues with send_to_address on cli @@ -57,7 +56,7 @@ class BasicTransactionTests(IntegrationTestCase): for tx in await self.ledger.db.get_transactions(txid__in=[tx.id for tx in txs]) ])) - await self.blockchain.generate(1) + await self.generate(1) await asyncio.wait([self.ledger.wait(tx) for tx in txs]) await self.assertBalance(self.account, '199.99876') @@ -74,7 +73,7 @@ class BasicTransactionTests(IntegrationTestCase): ) await self.broadcast(tx) await self.ledger.wait(tx) - await self.blockchain.generate(1) + await self.generate(1) await self.ledger.wait(tx) self.assertEqual(2, await self.account.get_utxo_count()) # 199 + change @@ -88,12 +87,10 @@ class BasicTransactionTests(IntegrationTestCase): await self.assertBalance(account2, '0.0') addresses = await account1.receiving.get_addresses() - txids = await asyncio.gather(*( - self.blockchain.send_to_address(address, 1.1) for address in addresses[:5] - )) - await asyncio.wait([self.on_transaction_id(txid) for txid in txids]) # mempool - await self.blockchain.generate(1) - await asyncio.wait([self.on_transaction_id(txid) for txid in txids]) # confirmed + txids = [] + for address in addresses[:5]: + txids.append(await self.send_to_address_and_wait(address, 1.1)) + await self.generate_and_wait(1, txids) await self.assertBalance(account1, '5.5') await self.assertBalance(account2, '0.0') @@ -107,7 +104,7 @@ class BasicTransactionTests(IntegrationTestCase): ) await self.broadcast(tx) await self.ledger.wait(tx) # mempool - await self.blockchain.generate(1) + await self.generate(1) await self.ledger.wait(tx) # confirmed await self.assertBalance(account1, '3.499802') @@ -121,7 +118,7 @@ class BasicTransactionTests(IntegrationTestCase): ) await self.broadcast(tx) await self.ledger.wait(tx) # mempool - await self.blockchain.generate(1) + await self.generate(1) await self.ledger.wait(tx) # confirmed tx = (await account1.get_transactions(include_is_my_input=True, include_is_my_output=True))[1] @@ -133,11 +130,11 @@ class BasicTransactionTests(IntegrationTestCase): self.assertTrue(tx.outputs[1].is_internal_transfer) async def test_history_edge_cases(self): - await self.blockchain.generate(300) + await self.generate(300) await self.assertBalance(self.account, '0.0') address = await self.account.receiving.get_or_create_usable_address() # evil trick: mempool is unsorted on real life, but same order between python instances. reproduce it - original_summary = self.conductor.spv_node.server.bp.mempool.transaction_summaries + original_summary = self.conductor.spv_node.server.mempool.transaction_summaries def random_summary(*args, **kwargs): summary = original_summary(*args, **kwargs) @@ -146,13 +143,10 @@ class BasicTransactionTests(IntegrationTestCase): while summary == ordered: random.shuffle(summary) return summary - self.conductor.spv_node.server.bp.mempool.transaction_summaries = random_summary + self.conductor.spv_node.server.mempool.transaction_summaries = random_summary # 10 unconfirmed txs, all from blockchain wallet - sends = [self.blockchain.send_to_address(address, 10) for _ in range(10)] - # use batching to reduce issues with send_to_address on cli - for batch in range(0, len(sends), 10): - txids = await asyncio.gather(*sends[batch:batch + 10]) - await asyncio.wait([self.on_transaction_id(txid) for txid in txids]) + for i in range(10): + await self.send_to_address_and_wait(address, 10) remote_status = await self.ledger.network.subscribe_address(address) self.assertTrue(await self.ledger.update_history(address, remote_status)) # 20 unconfirmed txs, 10 from blockchain, 10 from local to local @@ -170,8 +164,7 @@ class BasicTransactionTests(IntegrationTestCase): remote_status = await self.ledger.network.subscribe_address(address) self.assertTrue(await self.ledger.update_history(address, remote_status)) # server history grows unordered - txid = await self.blockchain.send_to_address(address, 1) - await self.on_transaction_id(txid) + await self.send_to_address_and_wait(address, 1) self.assertTrue(await self.ledger.update_history(address, remote_status)) self.assertEqual(21, len((await self.ledger.get_local_status_and_history(address))[1])) self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync)) @@ -195,37 +188,37 @@ class BasicTransactionTests(IntegrationTestCase): self.ledger, 2000000000000, [self.account], set_reserved=False, return_insufficient_funds=True ) got_amounts = [estimator.effective_amount for estimator in spendable] - self.assertListEqual(amounts, got_amounts) + self.assertListEqual(sorted(amounts), sorted(got_amounts)) async def test_sqlite_coin_chooser(self): wallet_manager = WalletManager([self.wallet], {self.ledger.get_id(): self.ledger}) - await self.blockchain.generate(300) + await self.generate(300) await self.assertBalance(self.account, '0.0') address = await self.account.receiving.get_or_create_usable_address() other_account = self.wallet.generate_account(self.ledger) other_address = await other_account.receiving.get_or_create_usable_address() self.ledger.coin_selection_strategy = 'sqlite' - await self.ledger.subscribe_account(self.account) + await self.ledger.subscribe_account(other_account) accepted = asyncio.ensure_future(self.on_address_update(address)) - txid = await self.blockchain.send_to_address(address, 1.0) + _ = await self.send_to_address_and_wait(address, 1.0) await accepted accepted = asyncio.ensure_future(self.on_address_update(address)) - txid = await self.blockchain.send_to_address(address, 1.0) + _ = await self.send_to_address_and_wait(address, 1.0) await accepted accepted = asyncio.ensure_future(self.on_address_update(address)) - txid = await self.blockchain.send_to_address(address, 3.0) + _ = await self.send_to_address_and_wait(address, 3.0) await accepted accepted = asyncio.ensure_future(self.on_address_update(address)) - txid = await self.blockchain.send_to_address(address, 5.0) + _ = await self.send_to_address_and_wait(address, 5.0) await accepted accepted = asyncio.ensure_future(self.on_address_update(address)) - txid = await self.blockchain.send_to_address(address, 10.0) + _ = await self.send_to_address_and_wait(address, 10.0) await accepted await self.assertBalance(self.account, '20.0') @@ -266,6 +259,12 @@ class BasicTransactionTests(IntegrationTestCase): async def broadcast(tx): try: return await real_broadcast(tx) + except lbry.wallet.rpc.jsonrpc.RPCError as err: + # this is expected in tests where we try to double spend. + if 'the transaction was rejected by network rules.' in str(err): + pass + else: + raise err finally: e.set() diff --git a/tests/unit/wallet/server/test_metrics.py b/tests/unit/wallet/server/test_metrics.py deleted file mode 100644 index c66d136fd..000000000 --- a/tests/unit/wallet/server/test_metrics.py +++ /dev/null @@ -1,60 +0,0 @@ -import time -import unittest -from lbry.wallet.server.metrics import ServerLoadData, calculate_avg_percentiles - - -class TestPercentileCalculation(unittest.TestCase): - - def test_calculate_percentiles(self): - self.assertEqual(calculate_avg_percentiles([]), (0, 0, 0, 0, 0, 0, 0, 0)) - self.assertEqual(calculate_avg_percentiles([1]), (1, 1, 1, 1, 1, 1, 1, 1)) - self.assertEqual(calculate_avg_percentiles([1, 2]), (1, 1, 1, 1, 1, 2, 2, 2)) - self.assertEqual(calculate_avg_percentiles([1, 2, 3]), (2, 1, 1, 1, 2, 3, 3, 3)) - self.assertEqual(calculate_avg_percentiles([4, 1, 2, 3]), (2, 1, 1, 1, 2, 3, 4, 4)) - self.assertEqual(calculate_avg_percentiles([1, 2, 3, 4, 5, 6]), (3, 1, 1, 2, 3, 5, 6, 6)) - self.assertEqual(calculate_avg_percentiles( - list(range(1, 101))), (50, 1, 5, 25, 50, 75, 95, 100)) - - -class TestCollectingMetrics(unittest.TestCase): - - def test_happy_path(self): - self.maxDiff = None - load = ServerLoadData() - search = load.for_api('search') - self.assertEqual(search.name, 'search') - search.start() - search.cache_response() - search.cache_response() - metrics = { - 'search': [{'total': 40}], - 'execute_query': [ - {'total': 20}, - {'total': 10} - ] - } - for x in range(5): - search.query_response(time.perf_counter() - 0.055 + 0.001*x, metrics) - metrics['execute_query'][0]['total'] = 10 - metrics['execute_query'][0]['sql'] = "select lots, of, stuff FROM claim where something=1" - search.query_interrupt(time.perf_counter() - 0.050, metrics) - search.query_error(time.perf_counter() - 0.050, metrics) - search.query_error(time.perf_counter() - 0.052, {}) - self.assertEqual(load.to_json_and_reset({}), {'status': {}, 'api': {'search': { - "receive_count": 1, - "cache_response_count": 2, - "query_response_count": 5, - "intrp_response_count": 1, - "error_response_count": 2, - "response": (53, 51, 51, 52, 53, 54, 55, 55), - "interrupt": (50, 50, 50, 50, 50, 50, 50, 50), - "error": (51, 50, 50, 50, 50, 52, 52, 52), - "python": (12, 10, 10, 10, 10, 20, 20, 20), - "wait": (12, 10, 10, 10, 12, 14, 15, 15), - "sql": (27, 20, 20, 20, 30, 30, 30, 30), - "individual_sql": (13, 10, 10, 10, 10, 20, 20, 20), - "individual_sql_count": 14, - "errored_queries": ['FROM claim where something=1'], - "interrupted_queries": ['FROM claim where something=1'], - }}}) - self.assertEqual(load.to_json_and_reset({}), {'status': {}, 'api': {}}) diff --git a/tests/unit/wallet/server/test_revertable.py b/tests/unit/wallet/server/test_revertable.py deleted file mode 100644 index 79b4cdb0c..000000000 --- a/tests/unit/wallet/server/test_revertable.py +++ /dev/null @@ -1,153 +0,0 @@ -import unittest -import tempfile -import shutil -from lbry.wallet.server.db.revertable import RevertableOpStack, RevertableDelete, RevertablePut, OpStackIntegrity -from lbry.wallet.server.db.prefixes import ClaimToTXOPrefixRow, HubDB - - -class TestRevertableOpStack(unittest.TestCase): - def setUp(self): - self.fake_db = {} - self.stack = RevertableOpStack(self.fake_db.get) - - def tearDown(self) -> None: - self.stack.clear() - self.fake_db.clear() - - def process_stack(self): - for op in self.stack: - if op.is_put: - self.fake_db[op.key] = op.value - else: - self.fake_db.pop(op.key) - self.stack.clear() - - def update(self, key1: bytes, value1: bytes, key2: bytes, value2: bytes): - self.stack.append_op(RevertableDelete(key1, value1)) - self.stack.append_op(RevertablePut(key2, value2)) - - def test_simplify(self): - key1 = ClaimToTXOPrefixRow.pack_key(b'\x01' * 20) - key2 = ClaimToTXOPrefixRow.pack_key(b'\x02' * 20) - key3 = ClaimToTXOPrefixRow.pack_key(b'\x03' * 20) - key4 = ClaimToTXOPrefixRow.pack_key(b'\x04' * 20) - - val1 = ClaimToTXOPrefixRow.pack_value(1, 0, 1, 0, 1, False, 'derp') - val2 = ClaimToTXOPrefixRow.pack_value(1, 0, 1, 0, 1, False, 'oops') - val3 = ClaimToTXOPrefixRow.pack_value(1, 0, 1, 0, 1, False, 'other') - - # check that we can't delete a non existent value - with self.assertRaises(OpStackIntegrity): - self.stack.append_op(RevertableDelete(key1, val1)) - - self.stack.append_op(RevertablePut(key1, val1)) - self.assertEqual(1, len(self.stack)) - self.stack.append_op(RevertableDelete(key1, val1)) - self.assertEqual(0, len(self.stack)) - - self.stack.append_op(RevertablePut(key1, val1)) - self.assertEqual(1, len(self.stack)) - # try to delete the wrong value - with self.assertRaises(OpStackIntegrity): - self.stack.append_op(RevertableDelete(key2, val2)) - - self.stack.append_op(RevertableDelete(key1, val1)) - self.assertEqual(0, len(self.stack)) - self.stack.append_op(RevertablePut(key2, val3)) - self.assertEqual(1, len(self.stack)) - - self.process_stack() - - self.assertDictEqual({key2: val3}, self.fake_db) - - # check that we can't put on top of the existing stored value - with self.assertRaises(OpStackIntegrity): - self.stack.append_op(RevertablePut(key2, val1)) - - self.assertEqual(0, len(self.stack)) - self.stack.append_op(RevertableDelete(key2, val3)) - self.assertEqual(1, len(self.stack)) - self.stack.append_op(RevertablePut(key2, val3)) - self.assertEqual(0, len(self.stack)) - - self.update(key2, val3, key2, val1) - self.assertEqual(2, len(self.stack)) - - self.process_stack() - self.assertDictEqual({key2: val1}, self.fake_db) - - self.update(key2, val1, key2, val2) - self.assertEqual(2, len(self.stack)) - self.update(key2, val2, key2, val3) - self.update(key2, val3, key2, val2) - self.update(key2, val2, key2, val3) - self.update(key2, val3, key2, val2) - with self.assertRaises(OpStackIntegrity): - self.update(key2, val3, key2, val2) - self.update(key2, val2, key2, val3) - self.assertEqual(2, len(self.stack)) - self.stack.append_op(RevertableDelete(key2, val3)) - self.process_stack() - self.assertDictEqual({}, self.fake_db) - - self.stack.append_op(RevertablePut(key2, val3)) - self.process_stack() - with self.assertRaises(OpStackIntegrity): - self.update(key2, val2, key2, val2) - self.update(key2, val3, key2, val2) - self.assertDictEqual({key2: val3}, self.fake_db) - undo = self.stack.get_undo_ops() - self.process_stack() - self.assertDictEqual({key2: val2}, self.fake_db) - self.stack.apply_packed_undo_ops(undo) - self.process_stack() - self.assertDictEqual({key2: val3}, self.fake_db) - - -class TestRevertablePrefixDB(unittest.TestCase): - def setUp(self): - self.tmp_dir = tempfile.mkdtemp() - self.db = HubDB(self.tmp_dir, cache_mb=1, max_open_files=32) - - def tearDown(self) -> None: - self.db.close() - shutil.rmtree(self.tmp_dir) - - def test_rollback(self): - name = 'derp' - claim_hash1 = 20 * b'\x00' - claim_hash2 = 20 * b'\x01' - claim_hash3 = 20 * b'\x02' - - takeover_height = 10000000 - - self.assertIsNone(self.db.claim_takeover.get(name)) - self.db.claim_takeover.stage_put((name,), (claim_hash1, takeover_height)) - self.assertIsNone(self.db.claim_takeover.get(name)) - self.assertEqual(10000000, self.db.claim_takeover.get_pending(name).height) - - self.db.commit(10000000) - self.assertEqual(10000000, self.db.claim_takeover.get(name).height) - - self.db.claim_takeover.stage_delete((name,), (claim_hash1, takeover_height)) - self.db.claim_takeover.stage_put((name,), (claim_hash2, takeover_height + 1)) - self.db.claim_takeover.stage_delete((name,), (claim_hash2, takeover_height + 1)) - self.db.commit(10000001) - self.assertIsNone(self.db.claim_takeover.get(name)) - self.db.claim_takeover.stage_put((name,), (claim_hash3, takeover_height + 2)) - self.db.commit(10000002) - self.assertEqual(10000002, self.db.claim_takeover.get(name).height) - - self.db.claim_takeover.stage_delete((name,), (claim_hash3, takeover_height + 2)) - self.db.claim_takeover.stage_put((name,), (claim_hash2, takeover_height + 3)) - self.db.commit(10000003) - self.assertEqual(10000003, self.db.claim_takeover.get(name).height) - - self.db.rollback(10000003) - self.assertEqual(10000002, self.db.claim_takeover.get(name).height) - self.db.rollback(10000002) - self.assertIsNone(self.db.claim_takeover.get(name)) - self.db.rollback(10000001) - self.assertEqual(10000000, self.db.claim_takeover.get(name).height) - self.db.rollback(10000000) - self.assertIsNone(self.db.claim_takeover.get(name)) diff --git a/tox.ini b/tox.ini index 8ad5e37a9..fe35b583a 100644 --- a/tox.ini +++ b/tox.ini @@ -3,6 +3,7 @@ deps = coverage extras = test + scribe torrent changedir = {toxinidir}/tests setenv =