From 219c7cf37d9075dc61d1cdcdbccb9257236f73f4 Mon Sep 17 00:00:00 2001
From: Lex Berezhny
Date: Fri, 1 May 2020 23:22:17 -0400
Subject: [PATCH] removed lbry.wallet.server and lbry.wallet.orchstr8

---
 lbry/wallet/orchstr8/__init__.py | 2 -
 lbry/wallet/orchstr8/cli.py | 89 -
 lbry/wallet/orchstr8/node.py | 438 -----
 lbry/wallet/orchstr8/service.py | 137 --
 lbry/wallet/server/__init__.py | 0
 lbry/wallet/server/block_processor.py | 777 --------
 lbry/wallet/server/cli.py | 41 -
 lbry/wallet/server/coin.py | 350 ----
 lbry/wallet/server/daemon.py | 357 ----
 lbry/wallet/server/db/__init__.py | 0
 lbry/wallet/server/db/canonical.py | 26 -
 lbry/wallet/server/db/common.py | 221 ---
 lbry/wallet/server/db/trending/__init__.py | 9 -
 lbry/wallet/server/db/trending/ar.py | 265 ---
 .../server/db/trending/variable_decay.py | 431 -----
 lbry/wallet/server/db/trending/zscore.py | 123 --
 lbry/wallet/server/db/writer.py | 894 ---------
 lbry/wallet/server/env.py | 263 ---
 lbry/wallet/server/hash.py | 159 --
 lbry/wallet/server/history.py | 324 ----
 lbry/wallet/server/leveldb.py | 670 -------
 lbry/wallet/server/mempool.py | 376 ----
 lbry/wallet/server/merkle.py | 253 ---
 lbry/wallet/server/metrics.py | 135 --
 lbry/wallet/server/peer.py | 302 ---
 lbry/wallet/server/peers.py | 506 -----
 lbry/wallet/server/prometheus.py | 89 -
 lbry/wallet/server/script.py | 289 ---
 lbry/wallet/server/server.py | 146 --
 lbry/wallet/server/session.py | 1641 -----------------
 lbry/wallet/server/storage.py | 169 --
 lbry/wallet/server/text.py | 82 -
 lbry/wallet/server/tx.py | 615 ------
 lbry/wallet/server/util.py | 359 ----
 lbry/wallet/server/version.py | 3 -
 lbry/wallet/server/websocket.py | 55 -
 36 files changed, 10596 deletions(-)
 delete mode 100644 lbry/wallet/orchstr8/__init__.py
 delete mode 100644 lbry/wallet/orchstr8/cli.py
 delete mode 100644 lbry/wallet/orchstr8/node.py
 delete mode 100644 lbry/wallet/orchstr8/service.py
 delete mode 100644 lbry/wallet/server/__init__.py
 delete mode 100644 lbry/wallet/server/block_processor.py
 delete mode 100644 lbry/wallet/server/cli.py
 delete mode 100644 lbry/wallet/server/coin.py
 delete mode 100644 lbry/wallet/server/daemon.py
 delete mode 100644 lbry/wallet/server/db/__init__.py
 delete mode 100644 lbry/wallet/server/db/canonical.py
 delete mode 100644 lbry/wallet/server/db/common.py
 delete mode 100644 lbry/wallet/server/db/trending/__init__.py
 delete mode 100644 lbry/wallet/server/db/trending/ar.py
 delete mode 100644 lbry/wallet/server/db/trending/variable_decay.py
 delete mode 100644 lbry/wallet/server/db/trending/zscore.py
 delete mode 100644 lbry/wallet/server/db/writer.py
 delete mode 100644 lbry/wallet/server/env.py
 delete mode 100644 lbry/wallet/server/hash.py
 delete mode 100644 lbry/wallet/server/history.py
 delete mode 100644 lbry/wallet/server/leveldb.py
 delete mode 100644 lbry/wallet/server/mempool.py
 delete mode 100644 lbry/wallet/server/merkle.py
 delete mode 100644 lbry/wallet/server/metrics.py
 delete mode 100644 lbry/wallet/server/peer.py
 delete mode 100644 lbry/wallet/server/peers.py
 delete mode 100644 lbry/wallet/server/prometheus.py
 delete mode 100644 lbry/wallet/server/script.py
 delete mode 100644 lbry/wallet/server/server.py
 delete mode 100644 lbry/wallet/server/session.py
 delete mode 100644 lbry/wallet/server/storage.py
 delete mode 100644 lbry/wallet/server/text.py
 delete mode 100644 lbry/wallet/server/tx.py
 delete mode 100644 lbry/wallet/server/util.py
 delete mode 100644 lbry/wallet/server/version.py
 delete mode 100644 lbry/wallet/server/websocket.py

diff --git 
a/lbry/wallet/orchstr8/__init__.py b/lbry/wallet/orchstr8/__init__.py deleted file mode 100644 index eea5d88be..000000000 --- a/lbry/wallet/orchstr8/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .node import Conductor -from .service import ConductorService diff --git a/lbry/wallet/orchstr8/cli.py b/lbry/wallet/orchstr8/cli.py deleted file mode 100644 index ee4ddc60c..000000000 --- a/lbry/wallet/orchstr8/cli.py +++ /dev/null @@ -1,89 +0,0 @@ -import logging -import argparse -import asyncio -import aiohttp - -from lbry import wallet -from lbry.wallet.orchstr8.node import ( - Conductor, get_blockchain_node_from_ledger -) -from lbry.wallet.orchstr8.service import ConductorService - - -def get_argument_parser(): - parser = argparse.ArgumentParser( - prog="orchstr8" - ) - subparsers = parser.add_subparsers(dest='command', help='sub-command help') - - subparsers.add_parser("download", help="Download blockchain node binary.") - - start = subparsers.add_parser("start", help="Start orchstr8 service.") - start.add_argument("--blockchain", help="Hostname to start blockchain node.") - start.add_argument("--spv", help="Hostname to start SPV server.") - start.add_argument("--wallet", help="Hostname to start wallet daemon.") - - generate = subparsers.add_parser("generate", help="Call generate method on running orchstr8 instance.") - generate.add_argument("blocks", type=int, help="Number of blocks to generate") - - subparsers.add_parser("transfer", help="Call transfer method on running orchstr8 instance.") - return parser - - -async def run_remote_command(command, **kwargs): - async with aiohttp.ClientSession() as session: - async with session.post('http://localhost:7954/'+command, data=kwargs) as resp: - print(resp.status) - print(await resp.text()) - - -def main(): - parser = get_argument_parser() - args = parser.parse_args() - command = getattr(args, 'command', 'help') - - loop = asyncio.get_event_loop() - asyncio.set_event_loop(loop) - - if command == 'download': - logging.getLogger('blockchain').setLevel(logging.INFO) - get_blockchain_node_from_ledger(wallet).ensure() - - elif command == 'generate': - loop.run_until_complete(run_remote_command( - 'generate', blocks=args.blocks - )) - - elif command == 'start': - - conductor = Conductor() - if getattr(args, 'blockchain', False): - conductor.blockchain_node.hostname = args.blockchain - loop.run_until_complete(conductor.start_blockchain()) - if getattr(args, 'spv', False): - conductor.spv_node.hostname = args.spv - loop.run_until_complete(conductor.start_spv()) - if getattr(args, 'wallet', False): - conductor.wallet_node.hostname = args.wallet - loop.run_until_complete(conductor.start_wallet()) - - service = ConductorService(conductor, loop) - loop.run_until_complete(service.start()) - - try: - print('========== Orchstr8 API Service Started ========') - loop.run_forever() - except KeyboardInterrupt: - pass - finally: - loop.run_until_complete(service.stop()) - loop.run_until_complete(conductor.stop()) - - loop.close() - - else: - parser.print_help() - - -if __name__ == "__main__": - main() diff --git a/lbry/wallet/orchstr8/node.py b/lbry/wallet/orchstr8/node.py deleted file mode 100644 index 039fffbee..000000000 --- a/lbry/wallet/orchstr8/node.py +++ /dev/null @@ -1,438 +0,0 @@ -import os -import json -import shutil -import asyncio -import zipfile -import tarfile -import logging -import tempfile -import subprocess -import importlib -from binascii import hexlify -from typing import Type, Optional -import urllib.request - -import lbry -from lbry.db 
import Database -from lbry.wallet.server.server import Server -from lbry.wallet.server.env import Env -from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent - - -log = logging.getLogger(__name__) - - -def get_spvserver_from_ledger(ledger_module): - spvserver_path, regtest_class_name = ledger_module.__spvserver__.rsplit('.', 1) - spvserver_module = importlib.import_module(spvserver_path) - return getattr(spvserver_module, regtest_class_name) - - -def get_blockchain_node_from_ledger(ledger_module): - return BlockchainNode( - ledger_module.__node_url__, - os.path.join(ledger_module.__node_bin__, ledger_module.__node_daemon__), - os.path.join(ledger_module.__node_bin__, ledger_module.__node_cli__) - ) - - -class Conductor: - - def __init__(self, seed=None): - self.manager_module = WalletManager - self.spv_module = get_spvserver_from_ledger(lbry.wallet) - - self.blockchain_node = get_blockchain_node_from_ledger(lbry.wallet) - self.spv_node = SPVNode(self.spv_module) - self.wallet_node = WalletNode( - self.manager_module, RegTestLedger, default_seed=seed - ) - - self.blockchain_started = False - self.spv_started = False - self.wallet_started = False - - self.log = log.getChild('conductor') - - async def start_blockchain(self): - if not self.blockchain_started: - asyncio.create_task(self.blockchain_node.start()) - await self.blockchain_node.running.wait() - await self.blockchain_node.generate(200) - self.blockchain_started = True - - async def stop_blockchain(self): - if self.blockchain_started: - await self.blockchain_node.stop(cleanup=True) - self.blockchain_started = False - - async def start_spv(self): - if not self.spv_started: - await self.spv_node.start(self.blockchain_node) - self.spv_started = True - - async def stop_spv(self): - if self.spv_started: - await self.spv_node.stop(cleanup=True) - self.spv_started = False - - async def start_wallet(self): - if not self.wallet_started: - await self.wallet_node.start(self.spv_node) - self.wallet_started = True - - async def stop_wallet(self): - if self.wallet_started: - await self.wallet_node.stop(cleanup=True) - self.wallet_started = False - - async def start(self): - await self.start_blockchain() - await self.start_spv() - await self.start_wallet() - - async def stop(self): - all_the_stops = [ - self.stop_wallet, - self.stop_spv, - self.stop_blockchain - ] - for stop in all_the_stops: - try: - await stop() - except Exception as e: - log.exception('Exception raised while stopping services:', exc_info=e) - - -class WalletNode: - - def __init__(self, manager_class: Type[WalletManager], ledger_class: Type[Ledger], - verbose: bool = False, port: int = 5280, default_seed: str = None) -> None: - self.manager_class = manager_class - self.ledger_class = ledger_class - self.verbose = verbose - self.manager: Optional[WalletManager] = None - self.ledger: Optional[Ledger] = None - self.wallet: Optional[Wallet] = None - self.account: Optional[Account] = None - self.data_path: Optional[str] = None - self.port = port - self.default_seed = default_seed - - async def start(self, spv_node: 'SPVNode', seed=None, connect=True): - self.data_path = tempfile.mkdtemp() - wallets_dir = os.path.join(self.data_path, 'wallets') - os.mkdir(wallets_dir) - wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json') - with open(wallet_file_name, 'w') as wallet_file: - wallet_file.write('{"version": 1, "accounts": []}\n') - db_driver = os.environ.get('TEST_DB', 'sqlite') - if db_driver == 'sqlite': - db = 
'sqlite:///'+os.path.join(self.data_path, self.ledger_class.get_id(), 'blockchain.db') - elif db_driver == 'postgres': - db_connection = 'postgres:postgres@localhost:5432' - db_name = f'lbry_test_{self.port}' - meta_db = Database(f'postgresql+psycopg2://{db_connection}/postgres') - await meta_db.drop(db_name) - await meta_db.create(db_name) - db = f'postgresql+psycopg2://{db_connection}/{db_name}' - else: - raise RuntimeError(f"Unsupported database driver: {db_driver}") - self.manager = self.manager_class.from_config({ - 'ledgers': { - self.ledger_class.get_id(): { - 'api_port': self.port, - 'default_servers': [(spv_node.hostname, spv_node.port)], - 'data_path': self.data_path, - 'db': Database(db) - } - }, - 'wallets': [wallet_file_name] - }) - self.ledger = self.manager.ledgers[self.ledger_class] - self.wallet = self.manager.default_wallet - if not self.wallet: - raise ValueError('Wallet is required.') - if seed or self.default_seed: - Account.from_dict( - self.ledger, self.wallet, {'seed': seed or self.default_seed} - ) - else: - self.wallet.generate_account(self.ledger) - self.account = self.wallet.default_account - if connect: - await self.manager.start() - - async def stop(self, cleanup=True): - try: - await self.manager.stop() - finally: - cleanup and self.cleanup() - - def cleanup(self): - shutil.rmtree(self.data_path, ignore_errors=True) - - -class SPVNode: - - def __init__(self, coin_class, node_number=1): - self.coin_class = coin_class - self.controller = None - self.data_path = None - self.server = None - self.hostname = 'localhost' - self.port = 50001 + node_number # avoid conflict with default daemon - self.session_timeout = 600 - self.rpc_port = '0' # disabled by default - - async def start(self, blockchain_node: 'BlockchainNode', extraconf=None): - self.data_path = tempfile.mkdtemp() - conf = { - 'DESCRIPTION': '', - 'PAYMENT_ADDRESS': '', - 'DAILY_FEE': '0', - 'DB_DIRECTORY': self.data_path, - 'DAEMON_URL': blockchain_node.rpc_url, - 'REORG_LIMIT': '100', - 'HOST': self.hostname, - 'TCP_PORT': str(self.port), - 'SESSION_TIMEOUT': str(self.session_timeout), - 'MAX_QUERY_WORKERS': '0', - 'INDIVIDUAL_TAG_INDEXES': '', - 'RPC_PORT': self.rpc_port - } - if extraconf: - conf.update(extraconf) - # TODO: don't use os.environ - os.environ.update(conf) - self.server = Server(Env(self.coin_class)) - self.server.mempool.refresh_secs = self.server.bp.prefetcher.polling_delay = 0.5 - await self.server.start() - - async def stop(self, cleanup=True): - try: - await self.server.stop() - finally: - cleanup and self.cleanup() - - def cleanup(self): - shutil.rmtree(self.data_path, ignore_errors=True) - - -class BlockchainProcess(asyncio.SubprocessProtocol): - - IGNORE_OUTPUT = [ - b'keypool keep', - b'keypool reserve', - b'keypool return', - ] - - def __init__(self): - self.ready = asyncio.Event() - self.stopped = asyncio.Event() - self.log = log.getChild('blockchain') - - def pipe_data_received(self, fd, data): - if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT): - if b'Error:' in data: - self.log.error(data.decode()) - else: - self.log.info(data.decode()) - if b'Error:' in data: - self.ready.set() - raise SystemError(data.decode()) - if b'Done loading' in data: - self.ready.set() - - def process_exited(self): - self.stopped.set() - self.ready.set() - - -class BlockchainNode: - - P2SH_SEGWIT_ADDRESS = "p2sh-segwit" - BECH32_ADDRESS = "bech32" - - def __init__(self, url, daemon, cli): - self.latest_release_url = url - self.project_dir = 
os.path.dirname(os.path.dirname(__file__)) - self.bin_dir = os.path.join(self.project_dir, 'bin') - self.daemon_bin = os.path.join(self.bin_dir, daemon) - self.cli_bin = os.path.join(self.bin_dir, cli) - self.log = log.getChild('blockchain') - self.data_path = None - self.protocol = None - self.transport = None - self.block_expected = 0 - self.hostname = 'localhost' - self.peerport = 9246 + 2 # avoid conflict with default peer port - self.rpcport = 9245 + 2 # avoid conflict with default rpc port - self.rpcuser = 'rpcuser' - self.rpcpassword = 'rpcpassword' - self.stopped = False - self.restart_ready = asyncio.Event() - self.restart_ready.set() - self.running = asyncio.Event() - - @property - def rpc_url(self): - return f'http://{self.rpcuser}:{self.rpcpassword}@{self.hostname}:{self.rpcport}/' - - def is_expected_block(self, e: BlockHeightEvent): - return self.block_expected == e.height - - @property - def exists(self): - return ( - os.path.exists(self.cli_bin) and - os.path.exists(self.daemon_bin) - ) - - def download(self): - downloaded_file = os.path.join( - self.bin_dir, - self.latest_release_url[self.latest_release_url.rfind('/')+1:] - ) - - if not os.path.exists(self.bin_dir): - os.mkdir(self.bin_dir) - - if not os.path.exists(downloaded_file): - self.log.info('Downloading: %s', self.latest_release_url) - with urllib.request.urlopen(self.latest_release_url) as response: - with open(downloaded_file, 'wb') as out_file: - shutil.copyfileobj(response, out_file) - - self.log.info('Extracting: %s', downloaded_file) - - if downloaded_file.endswith('.zip'): - with zipfile.ZipFile(downloaded_file) as dotzip: - dotzip.extractall(self.bin_dir) - # zipfile bug https://bugs.python.org/issue15795 - os.chmod(self.cli_bin, 0o755) - os.chmod(self.daemon_bin, 0o755) - - elif downloaded_file.endswith('.tar.gz'): - with tarfile.open(downloaded_file) as tar: - tar.extractall(self.bin_dir) - - return self.exists - - def ensure(self): - return self.exists or self.download() - - async def start(self): - assert self.ensure() - self.data_path = tempfile.mkdtemp() - loop = asyncio.get_event_loop() - asyncio.get_child_watcher().attach_loop(loop) - command = [ - self.daemon_bin, - f'-datadir={self.data_path}', '-printtoconsole', '-regtest', '-server', '-txindex', - f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}', - f'-port={self.peerport}' - ] - self.log.info(' '.join(command)) - while not self.stopped: - if self.running.is_set(): - await asyncio.sleep(1) - continue - await self.restart_ready.wait() - try: - self.transport, self.protocol = await loop.subprocess_exec( - BlockchainProcess, *command - ) - await self.protocol.ready.wait() - assert not self.protocol.stopped.is_set() - self.running.set() - except asyncio.CancelledError: - self.running.clear() - raise - except Exception as e: - self.running.clear() - log.exception('failed to start lbrycrdd', exc_info=e) - - async def stop(self, cleanup=True): - self.stopped = True - try: - self.transport.terminate() - await self.protocol.stopped.wait() - self.transport.close() - finally: - if cleanup: - self.cleanup() - - async def clear_mempool(self): - self.restart_ready.clear() - self.transport.terminate() - await self.protocol.stopped.wait() - self.transport.close() - self.running.clear() - os.remove(os.path.join(self.data_path, 'regtest', 'mempool.dat')) - self.restart_ready.set() - await self.running.wait() - - def cleanup(self): - shutil.rmtree(self.data_path, ignore_errors=True) - - async def _cli_cmnd(self, *args): - 
cmnd_args = [ - self.cli_bin, f'-datadir={self.data_path}', '-regtest', - f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}' - ] + list(args) - self.log.info(' '.join(cmnd_args)) - loop = asyncio.get_event_loop() - asyncio.get_child_watcher().attach_loop(loop) - process = await asyncio.create_subprocess_exec( - *cmnd_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) - out, _ = await process.communicate() - result = out.decode().strip() - self.log.info(result) - if result.startswith('error code'): - raise Exception(result) - return result - - def generate(self, blocks): - self.block_expected += blocks - return self._cli_cmnd('generate', str(blocks)) - - def invalidate_block(self, blockhash): - return self._cli_cmnd('invalidateblock', blockhash) - - def get_block_hash(self, block): - return self._cli_cmnd('getblockhash', str(block)) - - def sendrawtransaction(self, tx): - return self._cli_cmnd('sendrawtransaction', tx) - - async def get_block(self, block_hash): - return json.loads(await self._cli_cmnd('getblock', block_hash, '1')) - - def get_raw_change_address(self): - return self._cli_cmnd('getrawchangeaddress') - - def get_new_address(self, address_type): - return self._cli_cmnd('getnewaddress', "", address_type) - - async def get_balance(self): - return float(await self._cli_cmnd('getbalance')) - - def send_to_address(self, address, amount): - return self._cli_cmnd('sendtoaddress', address, str(amount)) - - def send_raw_transaction(self, tx): - return self._cli_cmnd('sendrawtransaction', tx.decode()) - - def create_raw_transaction(self, inputs, outputs): - return self._cli_cmnd('createrawtransaction', json.dumps(inputs), json.dumps(outputs)) - - async def sign_raw_transaction_with_wallet(self, tx): - return json.loads(await self._cli_cmnd('signrawtransactionwithwallet', tx))['hex'].encode() - - def decode_raw_transaction(self, tx): - return self._cli_cmnd('decoderawtransaction', hexlify(tx.raw).decode()) - - def get_raw_transaction(self, txid): - return self._cli_cmnd('getrawtransaction', txid, '1') diff --git a/lbry/wallet/orchstr8/service.py b/lbry/wallet/orchstr8/service.py deleted file mode 100644 index 495f68a07..000000000 --- a/lbry/wallet/orchstr8/service.py +++ /dev/null @@ -1,137 +0,0 @@ -import asyncio -import logging -from aiohttp.web import Application, WebSocketResponse, json_response -from aiohttp.http_websocket import WSMsgType, WSCloseCode - -from lbry.wallet.util import satoshis_to_coins -from .node import Conductor - - -PORT = 7954 - - -class WebSocketLogHandler(logging.Handler): - - def __init__(self, send_message): - super().__init__() - self.send_message = send_message - - def emit(self, record): - try: - self.send_message({ - 'type': 'log', - 'name': record.name, - 'message': self.format(record) - }) - except Exception: - self.handleError(record) - - -class ConductorService: - - def __init__(self, stack: Conductor, loop: asyncio.AbstractEventLoop) -> None: - self.stack = stack - self.loop = loop - self.app = Application() - self.app.router.add_post('/start', self.start_stack) - self.app.router.add_post('/generate', self.generate) - self.app.router.add_post('/transfer', self.transfer) - self.app.router.add_post('/balance', self.balance) - self.app.router.add_get('/log', self.log) - self.app['websockets'] = set() - self.app.on_shutdown.append(self.on_shutdown) - self.handler = self.app.make_handler() - self.server = None - - async def start(self): - self.server = await self.loop.create_server( - self.handler, '0.0.0.0', 
PORT - ) - print('serving on', self.server.sockets[0].getsockname()) - - async def stop(self): - await self.stack.stop() - self.server.close() - await self.server.wait_closed() - await self.app.shutdown() - await self.handler.shutdown(60.0) - await self.app.cleanup() - - async def start_stack(self, _): - #set_logging( - # self.stack.ledger_module, logging.DEBUG, WebSocketLogHandler(self.send_message) - #) - self.stack.blockchain_started or await self.stack.start_blockchain() - self.send_message({'type': 'service', 'name': 'blockchain', 'port': self.stack.blockchain_node.port}) - self.stack.spv_started or await self.stack.start_spv() - self.send_message({'type': 'service', 'name': 'spv', 'port': self.stack.spv_node.port}) - self.stack.wallet_started or await self.stack.start_wallet() - self.send_message({'type': 'service', 'name': 'wallet', 'port': self.stack.wallet_node.port}) - self.stack.wallet_node.ledger.on_header.listen(self.on_status) - self.stack.wallet_node.ledger.on_transaction.listen(self.on_status) - return json_response({'started': True}) - - async def generate(self, request): - data = await request.post() - blocks = data.get('blocks', 1) - await self.stack.blockchain_node.generate(int(blocks)) - return json_response({'blocks': blocks}) - - async def transfer(self, request): - data = await request.post() - address = data.get('address') - if not address and self.stack.wallet_started: - address = await self.stack.wallet_node.account.receiving.get_or_create_usable_address() - if not address: - raise ValueError("No address was provided.") - amount = data.get('amount', 1) - txid = await self.stack.blockchain_node.send_to_address(address, amount) - if self.stack.wallet_started: - await self.stack.wallet_node.ledger.on_transaction.where( - lambda e: e.tx.id == txid and e.address == address - ) - return json_response({ - 'address': address, - 'amount': amount, - 'txid': txid - }) - - async def balance(self, _): - return json_response({ - 'balance': await self.stack.blockchain_node.get_balance() - }) - - async def log(self, request): - web_socket = WebSocketResponse() - await web_socket.prepare(request) - self.app['websockets'].add(web_socket) - try: - async for msg in web_socket: - if msg.type == WSMsgType.TEXT: - if msg.data == 'close': - await web_socket.close() - elif msg.type == WSMsgType.ERROR: - print('web socket connection closed with exception %s' % - web_socket.exception()) - finally: - self.app['websockets'].remove(web_socket) - return web_socket - - @staticmethod - async def on_shutdown(app): - for web_socket in app['websockets']: - await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown') - - async def on_status(self, _): - if not self.app['websockets']: - return - self.send_message({ - 'type': 'status', - 'height': self.stack.wallet_node.ledger.headers.height, - 'balance': satoshis_to_coins(await self.stack.wallet_node.account.get_balance()), - 'miner': await self.stack.blockchain_node.get_balance() - }) - - def send_message(self, msg): - for web_socket in self.app['websockets']: - self.loop.create_task(web_socket.send_json(msg)) diff --git a/lbry/wallet/server/__init__.py b/lbry/wallet/server/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lbry/wallet/server/block_processor.py b/lbry/wallet/server/block_processor.py deleted file mode 100644 index 44eba7d1a..000000000 --- a/lbry/wallet/server/block_processor.py +++ /dev/null @@ -1,777 +0,0 @@ -import time -import asyncio -from struct import pack, unpack -from 
concurrent.futures.thread import ThreadPoolExecutor -from typing import Optional -import lbry -from lbry.schema.claim import Claim -from lbry.wallet.server.db.writer import SQLDB -from lbry.wallet.server.daemon import DaemonError -from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN -from lbry.wallet.server.util import chunks, class_logger -from lbry.wallet.server.leveldb import FlushData -from lbry.wallet.server.prometheus import BLOCK_COUNT, BLOCK_UPDATE_TIMES, REORG_COUNT - - -class Prefetcher: - """Prefetches blocks (in the forward direction only).""" - - def __init__(self, daemon, coin, blocks_event): - self.logger = class_logger(__name__, self.__class__.__name__) - self.daemon = daemon - self.coin = coin - self.blocks_event = blocks_event - self.blocks = [] - self.caught_up = False - # Access to fetched_height should be protected by the semaphore - self.fetched_height = None - self.semaphore = asyncio.Semaphore() - self.refill_event = asyncio.Event() - # The prefetched block cache size. The min cache size has - # little effect on sync time. - self.cache_size = 0 - self.min_cache_size = 10 * 1024 * 1024 - # This makes the first fetch be 10 blocks - self.ave_size = self.min_cache_size // 10 - self.polling_delay = 5 - - async def main_loop(self, bp_height): - """Loop forever polling for more blocks.""" - await self.reset_height(bp_height) - while True: - try: - # Sleep a while if there is nothing to prefetch - await self.refill_event.wait() - if not await self._prefetch_blocks(): - await asyncio.sleep(self.polling_delay) - except DaemonError as e: - self.logger.info(f'ignoring daemon error: {e}') - - def get_prefetched_blocks(self): - """Called by block processor when it is processing queued blocks.""" - blocks = self.blocks - self.blocks = [] - self.cache_size = 0 - self.refill_event.set() - return blocks - - async def reset_height(self, height): - """Reset to prefetch blocks from the block processor's height. - - Used in blockchain reorganisations. This coroutine can be - called asynchronously to the _prefetch_blocks coroutine so we - must synchronize with a semaphore. - """ - async with self.semaphore: - self.blocks.clear() - self.cache_size = 0 - self.fetched_height = height - self.refill_event.set() - - daemon_height = await self.daemon.height() - behind = daemon_height - height - if behind > 0: - self.logger.info(f'catching up to daemon height {daemon_height:,d} ' - f'({behind:,d} blocks behind)') - else: - self.logger.info(f'caught up to daemon height {daemon_height:,d}') - - async def _prefetch_blocks(self): - """Prefetch some blocks and put them on the queue. - - Repeats until the queue is full or caught up. - """ - daemon = self.daemon - daemon_height = await daemon.height() - async with self.semaphore: - while self.cache_size < self.min_cache_size: - # Try and catch up all blocks but limit to room in cache. - # Constrain fetch count to between 0 and 500 regardless; - # testnet can be lumpy. 
- cache_room = self.min_cache_size // self.ave_size - count = min(daemon_height - self.fetched_height, cache_room) - count = min(500, max(count, 0)) - if not count: - self.caught_up = True - return False - - first = self.fetched_height + 1 - hex_hashes = await daemon.block_hex_hashes(first, count) - if self.caught_up: - self.logger.info('new block height {:,d} hash {}' - .format(first + count-1, hex_hashes[-1])) - blocks = await daemon.raw_blocks(hex_hashes) - - assert count == len(blocks) - - # Special handling for genesis block - if first == 0: - blocks[0] = self.coin.genesis_block(blocks[0]) - self.logger.info(f'verified genesis block with hash {hex_hashes[0]}') - - # Update our recent average block size estimate - size = sum(len(block) for block in blocks) - if count >= 10: - self.ave_size = size // count - else: - self.ave_size = (size + (10 - count) * self.ave_size) // 10 - - self.blocks.extend(blocks) - self.cache_size += size - self.fetched_height += count - self.blocks_event.set() - - self.refill_event.clear() - return True - - -class ChainError(Exception): - """Raised on error processing blocks.""" - - -class BlockProcessor: - """Process blocks and update the DB state to match. - - Employ a prefetcher to prefetch blocks in batches for processing. - Coordinate backing up in case of chain reorganisations. - """ - - def __init__(self, env, db, daemon, notifications): - self.env = env - self.db = db - self.daemon = daemon - self.notifications = notifications - - self.coin = env.coin - self.blocks_event = asyncio.Event() - self.prefetcher = Prefetcher(daemon, env.coin, self.blocks_event) - self.logger = class_logger(__name__, self.__class__.__name__) - self.executor = ThreadPoolExecutor(1) - - # Meta - self.next_cache_check = 0 - self.touched = set() - self.reorg_count = 0 - - # Caches of unflushed items. - self.headers = [] - self.tx_hashes = [] - self.undo_infos = [] - - # UTXO cache - self.utxo_cache = {} - self.db_deletes = [] - - # If the lock is successfully acquired, in-memory chain state - # is consistent with self.height - self.state_lock = asyncio.Lock() - - self.search_cache = {} - - async def run_in_thread_with_lock(self, func, *args): - # Run in a thread to prevent blocking. Shielded so that - # cancellations from shutdown don't lose work - when the task - # completes the data will be flushed and then we shut down. - # Take the state lock to be certain in-memory state is - # consistent and not being updated elsewhere. - async def run_in_thread_locked(): - async with self.state_lock: - return await asyncio.get_event_loop().run_in_executor(self.executor, func, *args) - return await asyncio.shield(run_in_thread_locked()) - - async def check_and_advance_blocks(self, raw_blocks): - """Process the list of raw blocks passed. Detects and handles - reorgs. 
- """ - if not raw_blocks: - return - first = self.height + 1 - blocks = [self.coin.block(raw_block, first + n) - for n, raw_block in enumerate(raw_blocks)] - headers = [block.header for block in blocks] - hprevs = [self.coin.header_prevhash(h) for h in headers] - chain = [self.tip] + [self.coin.header_hash(h) for h in headers[:-1]] - - if hprevs == chain: - start = time.perf_counter() - await self.run_in_thread_with_lock(self.advance_blocks, blocks) - for cache in self.search_cache.values(): - cache.clear() - await self._maybe_flush() - processed_time = time.perf_counter() - start - BLOCK_COUNT.set(self.height) - BLOCK_UPDATE_TIMES.observe(processed_time) - if not self.db.first_sync: - s = '' if len(blocks) == 1 else 's' - self.logger.info('processed {:,d} block{} in {:.1f}s'.format(len(blocks), s, processed_time)) - if self._caught_up_event.is_set(): - await self.notifications.on_block(self.touched, self.height) - self.touched = set() - elif hprevs[0] != chain[0]: - await self.reorg_chain() - else: - # It is probably possible but extremely rare that what - # bitcoind returns doesn't form a chain because it - # reorg-ed the chain as it was processing the batched - # block hash requests. Should this happen it's simplest - # just to reset the prefetcher and try again. - self.logger.warning('daemon blocks do not form a chain; ' - 'resetting the prefetcher') - await self.prefetcher.reset_height(self.height) - - async def reorg_chain(self, count: Optional[int] = None): - """Handle a chain reorganisation. - - Count is the number of blocks to simulate a reorg, or None for - a real reorg.""" - if count is None: - self.logger.info('chain reorg detected') - else: - self.logger.info(f'faking a reorg of {count:,d} blocks') - await self.flush(True) - - async def get_raw_blocks(last_height, hex_hashes): - heights = range(last_height, last_height - len(hex_hashes), -1) - try: - blocks = [self.db.read_raw_block(height) for height in heights] - self.logger.info(f'read {len(blocks)} blocks from disk') - return blocks - except FileNotFoundError: - return await self.daemon.raw_blocks(hex_hashes) - - def flush_backup(): - # self.touched can include other addresses which is - # harmless, but remove None. - self.touched.discard(None) - self.db.flush_backup(self.flush_data(), self.touched) - - start, last, hashes = await self.reorg_hashes(count) - # Reverse and convert to hex strings. - hashes = [hash_to_hex_str(hash) for hash in reversed(hashes)] - for hex_hashes in chunks(hashes, 50): - raw_blocks = await get_raw_blocks(last, hex_hashes) - await self.run_in_thread_with_lock(self.backup_blocks, raw_blocks) - await self.run_in_thread_with_lock(flush_backup) - last -= len(raw_blocks) - await self.run_in_thread_with_lock(self.db.sql.delete_claims_above_height, self.height) - await self.prefetcher.reset_height(self.height) - REORG_COUNT.inc() - - async def reorg_hashes(self, count): - """Return a pair (start, last, hashes) of blocks to back up during a - reorg. - - The hashes are returned in order of increasing height. Start - is the height of the first hash, last of the last. 
- """ - start, count = await self.calc_reorg_range(count) - last = start + count - 1 - s = '' if count == 1 else 's' - self.logger.info(f'chain was reorganised replacing {count:,d} ' - f'block{s} at heights {start:,d}-{last:,d}') - - return start, last, await self.db.fs_block_hashes(start, count) - - async def calc_reorg_range(self, count: Optional[int]): - """Calculate the reorg range""" - - def diff_pos(hashes1, hashes2): - """Returns the index of the first difference in the hash lists. - If both lists match returns their length.""" - for n, (hash1, hash2) in enumerate(zip(hashes1, hashes2)): - if hash1 != hash2: - return n - return len(hashes) - - if count is None: - # A real reorg - start = self.height - 1 - count = 1 - while start > 0: - hashes = await self.db.fs_block_hashes(start, count) - hex_hashes = [hash_to_hex_str(hash) for hash in hashes] - d_hex_hashes = await self.daemon.block_hex_hashes(start, count) - n = diff_pos(hex_hashes, d_hex_hashes) - if n > 0: - start += n - break - count = min(count * 2, start) - start -= count - - count = (self.height - start) + 1 - else: - start = (self.height - count) + 1 - - return start, count - - def estimate_txs_remaining(self): - # Try to estimate how many txs there are to go - daemon_height = self.daemon.cached_height() - coin = self.coin - tail_count = daemon_height - max(self.height, coin.TX_COUNT_HEIGHT) - # Damp the initial enthusiasm - realism = max(2.0 - 0.9 * self.height / coin.TX_COUNT_HEIGHT, 1.0) - return (tail_count * coin.TX_PER_BLOCK + - max(coin.TX_COUNT - self.tx_count, 0)) * realism - - # - Flushing - def flush_data(self): - """The data for a flush. The lock must be taken.""" - assert self.state_lock.locked() - return FlushData(self.height, self.tx_count, self.headers, - self.tx_hashes, self.undo_infos, self.utxo_cache, - self.db_deletes, self.tip) - - async def flush(self, flush_utxos): - def flush(): - self.db.flush_dbs(self.flush_data(), flush_utxos, - self.estimate_txs_remaining) - await self.run_in_thread_with_lock(flush) - - async def _maybe_flush(self): - # If caught up, flush everything as client queries are - # performed on the DB. - if self._caught_up_event.is_set(): - await self.flush(True) - elif time.time() > self.next_cache_check: - flush_arg = self.check_cache_size() - if flush_arg is not None: - await self.flush(flush_arg) - self.next_cache_check = time.time() + 30 - - def check_cache_size(self): - """Flush a cache if it gets too big.""" - # Good average estimates based on traversal of subobjects and - # requesting size from Python (see deep_getsizeof). - one_MB = 1000*1000 - utxo_cache_size = len(self.utxo_cache) * 205 - db_deletes_size = len(self.db_deletes) * 57 - hist_cache_size = self.db.history.unflushed_memsize() - # Roughly ntxs * 32 + nblocks * 42 - tx_hash_size = ((self.tx_count - self.db.fs_tx_count) * 32 - + (self.height - self.db.fs_height) * 42) - utxo_MB = (db_deletes_size + utxo_cache_size) // one_MB - hist_MB = (hist_cache_size + tx_hash_size) // one_MB - - self.logger.info('our height: {:,d} daemon: {:,d} ' - 'UTXOs {:,d}MB hist {:,d}MB' - .format(self.height, self.daemon.cached_height(), - utxo_MB, hist_MB)) - - # Flush history if it takes up over 20% of cache memory. - # Flush UTXOs once they take up 80% of cache memory. - cache_MB = self.env.cache_MB - if utxo_MB + hist_MB >= cache_MB or hist_MB >= cache_MB // 5: - return utxo_MB >= cache_MB * 4 // 5 - return None - - def advance_blocks(self, blocks): - """Synchronously advance the blocks. 
- - It is already verified they correctly connect onto our tip. - """ - min_height = self.db.min_undo_height(self.daemon.cached_height()) - height = self.height - - for block in blocks: - height += 1 - undo_info = self.advance_txs( - height, block.transactions, self.coin.electrum_header(block.header, height) - ) - if height >= min_height: - self.undo_infos.append((undo_info, height)) - self.db.write_raw_block(block.raw, height) - - headers = [block.header for block in blocks] - self.height = height - self.headers.extend(headers) - self.tip = self.coin.header_hash(headers[-1]) - - def advance_txs(self, height, txs, header): - self.tx_hashes.append(b''.join(tx_hash for tx, tx_hash in txs)) - - # Use local vars for speed in the loops - undo_info = [] - tx_num = self.tx_count - script_hashX = self.coin.hashX_from_script - s_pack = pack - put_utxo = self.utxo_cache.__setitem__ - spend_utxo = self.spend_utxo - undo_info_append = undo_info.append - update_touched = self.touched.update - hashXs_by_tx = [] - append_hashXs = hashXs_by_tx.append - - for tx, tx_hash in txs: - hashXs = [] - append_hashX = hashXs.append - tx_numb = s_pack('= len(raw_blocks) - - coin = self.coin - for raw_block in raw_blocks: - # Check and update self.tip - block = coin.block(raw_block, self.height) - header_hash = coin.header_hash(block.header) - if header_hash != self.tip: - raise ChainError('backup block {} not tip {} at height {:,d}' - .format(hash_to_hex_str(header_hash), - hash_to_hex_str(self.tip), - self.height)) - self.tip = coin.header_prevhash(block.header) - self.backup_txs(block.transactions) - self.height -= 1 - self.db.tx_counts.pop() - - self.logger.info(f'backed up to height {self.height:,d}') - - def backup_txs(self, txs): - # Prevout values, in order down the block (coinbase first if present) - # undo_info is in reverse block order - undo_info = self.db.read_undo_info(self.height) - if undo_info is None: - raise ChainError(f'no undo information found for height {self.height:,d}') - n = len(undo_info) - - # Use local vars for speed in the loops - s_pack = pack - put_utxo = self.utxo_cache.__setitem__ - spend_utxo = self.spend_utxo - script_hashX = self.coin.hashX_from_script - touched = self.touched - undo_entry_len = 12 + HASHX_LEN - - for tx, tx_hash in reversed(txs): - for idx, txout in enumerate(tx.outputs): - # Spend the TX outputs. Be careful with unspendable - # outputs - we didn't save those in the first place. - hashX = script_hashX(txout.pk_script) - if hashX: - cache_value = spend_utxo(tx_hash, idx) - touched.add(cache_value[:-12]) - - # Restore the inputs - for txin in reversed(tx.inputs): - if txin.is_generation(): - continue - n -= undo_entry_len - undo_item = undo_info[n:n + undo_entry_len] - put_utxo(txin.prev_hash + s_pack(' 1: - tx_num, = unpack('False state. - first_sync = self.db.first_sync - self.db.first_sync = False - await self.flush(True) - if first_sync: - self.logger.info(f'{lbry.__version__} synced to ' - f'height {self.height:,d}') - # Reopen for serving - await self.db.open_for_serving() - - async def _first_open_dbs(self): - await self.db.open_for_sync() - self.height = self.db.db_height - self.tip = self.db.db_tip - self.tx_count = self.db.db_tx_count - - # --- External API - - async def fetch_and_process_blocks(self, caught_up_event): - """Fetch, process and index blocks from the daemon. - - Sets caught_up_event when first caught up. Flushes to disk - and shuts down cleanly if cancelled. 
- - This is mainly because if, during initial sync ElectrumX is - asked to shut down when a large number of blocks have been - processed but not written to disk, it should write those to - disk before exiting, as otherwise a significant amount of work - could be lost. - """ - self._caught_up_event = caught_up_event - try: - await self._first_open_dbs() - await asyncio.wait([ - self.prefetcher.main_loop(self.height), - self._process_prefetched_blocks() - ]) - except asyncio.CancelledError: - raise - except: - self.logger.exception("Block processing failed!") - raise - finally: - # Shut down block processing - self.logger.info('flushing to DB for a clean shutdown...') - await self.flush(True) - self.db.close() - self.executor.shutdown(wait=True) - - def force_chain_reorg(self, count): - """Force a reorg of the given number of blocks. - - Returns True if a reorg is queued, false if not caught up. - """ - if self._caught_up_event.is_set(): - self.reorg_count = count - self.blocks_event.set() - return True - return False - - -class Timer: - - def __init__(self, name): - self.name = name - self.total = 0 - self.count = 0 - self.sub_timers = {} - self._last_start = None - - def add_timer(self, name): - if name not in self.sub_timers: - self.sub_timers[name] = Timer(name) - return self.sub_timers[name] - - def run(self, func, *args, forward_timer=False, timer_name=None, **kwargs): - t = self.add_timer(timer_name or func.__name__) - t.start() - try: - if forward_timer: - return func(*args, **kwargs, timer=t) - else: - return func(*args, **kwargs) - finally: - t.stop() - - def start(self): - self._last_start = time.time() - return self - - def stop(self): - self.total += (time.time() - self._last_start) - self.count += 1 - self._last_start = None - return self - - def show(self, depth=0, height=None): - if depth == 0: - print('='*100) - if height is not None: - print(f'STATISTICS AT HEIGHT {height}') - print('='*100) - else: - print( - f"{' '*depth} {self.total/60:4.2f}mins {self.name}" - # f"{self.total/self.count:.5f}sec/call, " - ) - for sub_timer in self.sub_timers.values(): - sub_timer.show(depth+1) - if depth == 0: - print('='*100) - - -class LBRYBlockProcessor(BlockProcessor): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.env.coin.NET == "regtest": - self.prefetcher.polling_delay = 0.5 - self.should_validate_signatures = self.env.boolean('VALIDATE_CLAIM_SIGNATURES', False) - self.logger.info(f"LbryumX Block Processor - Validating signatures: {self.should_validate_signatures}") - self.sql: SQLDB = self.db.sql - self.timer = Timer('BlockProcessor') - - def advance_blocks(self, blocks): - self.sql.begin() - try: - self.timer.run(super().advance_blocks, blocks) - except: - self.logger.exception(f'Error while advancing transaction in new block.') - raise - finally: - self.sql.commit() - if self.db.first_sync and self.height == self.daemon.cached_height(): - self.timer.run(self.sql.execute, self.sql.SEARCH_INDEXES, timer_name='executing SEARCH_INDEXES') - if self.env.individual_tag_indexes: - self.timer.run(self.sql.execute, self.sql.TAG_INDEXES, timer_name='executing TAG_INDEXES') - - def advance_txs(self, height, txs, header): - timer = self.timer.sub_timers['advance_blocks'] - undo = timer.run(super().advance_txs, height, txs, header, timer_name='super().advance_txs') - timer.run(self.sql.advance_txs, height, txs, header, self.daemon.cached_height(), forward_timer=True) - if (height % 10000 == 0 or not self.db.first_sync) and self.logger.isEnabledFor(10): 
- self.timer.show(height=height) - return undo - - def _checksig(self, value, address): - try: - claim_dict = Claim.from_bytes(value) - cert_id = claim_dict.signing_channel_hash - if not self.should_validate_signatures: - return cert_id - if cert_id: - cert_claim = self.db.get_claim_info(cert_id) - if cert_claim: - certificate = Claim.from_bytes(cert_claim.value) - claim_dict.validate_signature(address, certificate) - return cert_id - except Exception: - pass diff --git a/lbry/wallet/server/cli.py b/lbry/wallet/server/cli.py deleted file mode 100644 index 5cb15fc5e..000000000 --- a/lbry/wallet/server/cli.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging -import traceback -import argparse -import importlib -from lbry.wallet.server.env import Env -from lbry.wallet.server.server import Server - - -def get_argument_parser(): - parser = argparse.ArgumentParser( - prog="torba-server" - ) - parser.add_argument("spvserver", type=str, help="Python class path to SPV server implementation.", - nargs="?", default="lbry.wallet.server.coin.LBC") - return parser - - -def get_coin_class(spvserver): - spvserver_path, coin_class_name = spvserver.rsplit('.', 1) - spvserver_module = importlib.import_module(spvserver_path) - return getattr(spvserver_module, coin_class_name) - - -def main(): - parser = get_argument_parser() - args = parser.parse_args() - coin_class = get_coin_class(args.spvserver) - logging.basicConfig(level=logging.INFO) - logging.info('lbry.server starting') - try: - server = Server(Env(coin_class)) - server.run() - except Exception: - traceback.print_exc() - logging.critical('lbry.server terminated abnormally') - else: - logging.info('lbry.server terminated normally') - - -if __name__ == "__main__": - main() diff --git a/lbry/wallet/server/coin.py b/lbry/wallet/server/coin.py deleted file mode 100644 index 3b7598eb3..000000000 --- a/lbry/wallet/server/coin.py +++ /dev/null @@ -1,350 +0,0 @@ -import re -import struct -from typing import List -from hashlib import sha256 -from decimal import Decimal -from collections import namedtuple - -import lbry.wallet.server.tx as lib_tx -from lbry.wallet.script import OutputScript, OP_CLAIM_NAME, OP_UPDATE_CLAIM, OP_SUPPORT_CLAIM -from lbry.wallet.server.tx import DeserializerSegWit -from lbry.wallet.server.util import cachedproperty, subclasses -from lbry.wallet.server.hash import Base58, hash160, double_sha256, hash_to_hex_str, HASHX_LEN -from lbry.wallet.server.daemon import Daemon, LBCDaemon -from lbry.wallet.server.script import ScriptPubKey, OpCodes -from lbry.wallet.server.leveldb import LevelDB -from lbry.wallet.server.session import LBRYElectrumX, LBRYSessionManager -from lbry.wallet.server.db.writer import LBRYLevelDB -from lbry.wallet.server.block_processor import LBRYBlockProcessor - - -Block = namedtuple("Block", "raw header transactions") -OP_RETURN = OpCodes.OP_RETURN - - -class CoinError(Exception): - """Exception raised for coin-related errors.""" - - -class Coin: - """Base class of coin hierarchy.""" - - REORG_LIMIT = 200 - # Not sure if these are coin-specific - RPC_URL_REGEX = re.compile('.+@(\\[[0-9a-fA-F:]+\\]|[^:]+)(:[0-9]+)?') - VALUE_PER_COIN = 100000000 - CHUNK_SIZE = 2016 - BASIC_HEADER_SIZE = 80 - STATIC_BLOCK_HEADERS = True - SESSIONCLS = LBRYElectrumX - DESERIALIZER = lib_tx.Deserializer - DAEMON = Daemon - BLOCK_PROCESSOR = LBRYBlockProcessor - SESSION_MANAGER = LBRYSessionManager - DB = LevelDB - HEADER_VALUES = [ - 'version', 'prev_block_hash', 'merkle_root', 'timestamp', 'bits', 'nonce' - ] - HEADER_UNPACK = 
struct.Struct('< I 32s 32s I I I').unpack_from - MEMPOOL_HISTOGRAM_REFRESH_SECS = 500 - XPUB_VERBYTES = bytes('????', 'utf-8') - XPRV_VERBYTES = bytes('????', 'utf-8') - ENCODE_CHECK = Base58.encode_check - DECODE_CHECK = Base58.decode_check - # Peer discovery - PEER_DEFAULT_PORTS = {'t': '50001', 's': '50002'} - PEERS: List[str] = [] - - @classmethod - def lookup_coin_class(cls, name, net): - """Return a coin class given name and network. - - Raise an exception if unrecognised.""" - req_attrs = ['TX_COUNT', 'TX_COUNT_HEIGHT', 'TX_PER_BLOCK'] - for coin in subclasses(Coin): - if (coin.NAME.lower() == name.lower() and - coin.NET.lower() == net.lower()): - coin_req_attrs = req_attrs.copy() - missing = [attr for attr in coin_req_attrs - if not hasattr(coin, attr)] - if missing: - raise CoinError(f'coin {name} missing {missing} attributes') - return coin - raise CoinError(f'unknown coin {name} and network {net} combination') - - @classmethod - def sanitize_url(cls, url): - # Remove surrounding ws and trailing /s - url = url.strip().rstrip('/') - match = cls.RPC_URL_REGEX.match(url) - if not match: - raise CoinError(f'invalid daemon URL: "{url}"') - if match.groups()[1] is None: - url += f':{cls.RPC_PORT:d}' - if not url.startswith('http://') and not url.startswith('https://'): - url = 'http://' + url - return url + '/' - - @classmethod - def genesis_block(cls, block): - """Check the Genesis block is the right one for this coin. - - Return the block less its unspendable coinbase. - """ - header = cls.block_header(block, 0) - header_hex_hash = hash_to_hex_str(cls.header_hash(header)) - if header_hex_hash != cls.GENESIS_HASH: - raise CoinError(f'genesis block has hash {header_hex_hash} expected {cls.GENESIS_HASH}') - - return header + bytes(1) - - @classmethod - def hashX_from_script(cls, script): - """Returns a hashX from a script, or None if the script is provably - unspendable so the output can be dropped. - """ - if script and script[0] == OP_RETURN: - return None - return sha256(script).digest()[:HASHX_LEN] - - @staticmethod - def lookup_xverbytes(verbytes): - """Return a (is_xpub, coin_class) pair given xpub/xprv verbytes.""" - # Order means BTC testnet will override NMC testnet - for coin in subclasses(Coin): - if verbytes == coin.XPUB_VERBYTES: - return True, coin - if verbytes == coin.XPRV_VERBYTES: - return False, coin - raise CoinError('version bytes unrecognised') - - @classmethod - def address_to_hashX(cls, address): - """Return a hashX given a coin address.""" - return cls.hashX_from_script(cls.pay_to_address_script(address)) - - @classmethod - def P2PKH_address_from_hash160(cls, hash160): - """Return a P2PKH address given a public key.""" - assert len(hash160) == 20 - return cls.ENCODE_CHECK(cls.P2PKH_VERBYTE + hash160) - - @classmethod - def P2PKH_address_from_pubkey(cls, pubkey): - """Return a coin address given a public key.""" - return cls.P2PKH_address_from_hash160(hash160(pubkey)) - - @classmethod - def P2SH_address_from_hash160(cls, hash160): - """Return a coin address given a hash160.""" - assert len(hash160) == 20 - return cls.ENCODE_CHECK(cls.P2SH_VERBYTES[0] + hash160) - - @classmethod - def hash160_to_P2PKH_script(cls, hash160): - return ScriptPubKey.P2PKH_script(hash160) - - @classmethod - def hash160_to_P2PKH_hashX(cls, hash160): - return cls.hashX_from_script(cls.hash160_to_P2PKH_script(hash160)) - - @classmethod - def pay_to_address_script(cls, address): - """Return a pubkey script that pays to a pubkey hash. 
- - Pass the address (either P2PKH or P2SH) in base58 form. - """ - raw = cls.DECODE_CHECK(address) - - # Require version byte(s) plus hash160. - verbyte = -1 - verlen = len(raw) - 20 - if verlen > 0: - verbyte, hash160 = raw[:verlen], raw[verlen:] - - if verbyte == cls.P2PKH_VERBYTE: - return cls.hash160_to_P2PKH_script(hash160) - if verbyte in cls.P2SH_VERBYTES: - return ScriptPubKey.P2SH_script(hash160) - - raise CoinError(f'invalid address: {address}') - - @classmethod - def privkey_WIF(cls, privkey_bytes, compressed): - """Return the private key encoded in Wallet Import Format.""" - payload = bytearray(cls.WIF_BYTE) + privkey_bytes - if compressed: - payload.append(0x01) - return cls.ENCODE_CHECK(payload) - - @classmethod - def header_hash(cls, header): - """Given a header return hash""" - return double_sha256(header) - - @classmethod - def header_prevhash(cls, header): - """Given a header return previous hash""" - return header[4:36] - - @classmethod - def static_header_offset(cls, height): - """Given a header height return its offset in the headers file. - - If header sizes change at some point, this is the only code - that needs updating.""" - assert cls.STATIC_BLOCK_HEADERS - return height * cls.BASIC_HEADER_SIZE - - @classmethod - def static_header_len(cls, height): - """Given a header height return its length.""" - return (cls.static_header_offset(height + 1) - - cls.static_header_offset(height)) - - @classmethod - def block_header(cls, block, height): - """Returns the block header given a block and its height.""" - return block[:cls.static_header_len(height)] - - @classmethod - def block(cls, raw_block, height): - """Return a Block namedtuple given a raw block and its height.""" - header = cls.block_header(raw_block, height) - txs = cls.DESERIALIZER(raw_block, start=len(header)).read_tx_block() - return Block(raw_block, header, txs) - - @classmethod - def decimal_value(cls, value): - """Return the number of standard coin units as a Decimal given a - quantity of smallest units. - - For example 1 BTC is returned for 100 million satoshis. - """ - return Decimal(value) / cls.VALUE_PER_COIN - - @classmethod - def electrum_header(cls, header, height): - h = dict(zip(cls.HEADER_VALUES, cls.HEADER_UNPACK(header))) - # Add the height that is not present in the header itself - h['block_height'] = height - # Convert bytes to str - h['prev_block_hash'] = hash_to_hex_str(h['prev_block_hash']) - h['merkle_root'] = hash_to_hex_str(h['merkle_root']) - return h - - -class LBC(Coin): - DAEMON = LBCDaemon - SESSIONCLS = LBRYElectrumX - BLOCK_PROCESSOR = LBRYBlockProcessor - SESSION_MANAGER = LBRYSessionManager - DESERIALIZER = DeserializerSegWit - DB = LBRYLevelDB - NAME = "LBRY" - SHORTNAME = "LBC" - NET = "mainnet" - BASIC_HEADER_SIZE = 112 - CHUNK_SIZE = 96 - XPUB_VERBYTES = bytes.fromhex("0488b21e") - XPRV_VERBYTES = bytes.fromhex("0488ade4") - P2PKH_VERBYTE = bytes.fromhex("55") - P2SH_VERBYTES = bytes.fromhex("7A") - WIF_BYTE = bytes.fromhex("1C") - GENESIS_HASH = ('9c89283ba0f3227f6c03b70216b9f665' - 'f0118d5e0fa729cedf4fb34d6a34f463') - TX_COUNT = 2716936 - TX_COUNT_HEIGHT = 329554 - TX_PER_BLOCK = 1 - RPC_PORT = 9245 - REORG_LIMIT = 200 - PEERS = [ - ] - - @classmethod - def genesis_block(cls, block): - '''Check the Genesis block is the right one for this coin. - - Return the block less its unspendable coinbase. 
- ''' - header = cls.block_header(block, 0) - header_hex_hash = hash_to_hex_str(cls.header_hash(header)) - if header_hex_hash != cls.GENESIS_HASH: - raise CoinError(f'genesis block has hash {header_hex_hash} expected {cls.GENESIS_HASH}') - - return block - - @classmethod - def electrum_header(cls, header, height): - version, = struct.unpack(' 1: - self.url_index = (self.url_index + 1) % len(self.urls) - self.logger.info(f'failing over to {self.logged_url()}') - return True - return False - - def client_session(self): - """An aiohttp client session.""" - return aiohttp.ClientSession(connector=self.connector, connector_owner=False) - - async def _send_data(self, data): - if not self.connector: - raise asyncio.CancelledError('Tried to send request during shutdown.') - async with self.workqueue_semaphore: - async with self.client_session() as session: - async with session.post(self.current_url(), data=data) as resp: - kind = resp.headers.get('Content-Type', None) - if kind == 'application/json': - return await resp.json() - # bitcoind's HTTP protocol "handling" is a bad joke - text = await resp.text() - if 'Work queue depth exceeded' in text: - raise WorkQueueFullError - text = text.strip() or resp.reason - self.logger.error(text) - raise DaemonError(text) - - async def _send(self, payload, processor): - """Send a payload to be converted to JSON. - - Handles temporary connection issues. Daemon response errors - are raise through DaemonError. - """ - - def log_error(error): - nonlocal last_error_log, retry - now = time.time() - if now - last_error_log > 60: - last_error_log = now - self.logger.error(f'{error} Retrying occasionally...') - if retry == self.max_retry and self.failover(): - retry = 0 - - on_good_message = None - last_error_log = 0 - data = json.dumps(payload) - retry = self.init_retry - methods = tuple( - [payload['method']] if isinstance(payload, dict) else [request['method'] for request in payload] - ) - while True: - try: - for method in methods: - LBRYCRD_PENDING_COUNT.labels(method=method).inc() - result = await self._send_data(data) - result = processor(result) - if on_good_message: - self.logger.info(on_good_message) - return result - except asyncio.TimeoutError: - log_error('timeout error.') - except aiohttp.ServerDisconnectedError: - log_error('disconnected.') - on_good_message = 'connection restored' - except aiohttp.ClientConnectionError: - log_error('connection problem - is your daemon running?') - on_good_message = 'connection restored' - except aiohttp.ClientError as e: - log_error(f'daemon error: {e}') - on_good_message = 'running normally' - except WarmingUpError: - log_error('starting up checking blocks.') - on_good_message = 'running normally' - except WorkQueueFullError: - log_error('work queue full.') - on_good_message = 'running normally' - finally: - for method in methods: - LBRYCRD_PENDING_COUNT.labels(method=method).dec() - await asyncio.sleep(retry) - retry = max(min(self.max_retry, retry * 2), self.init_retry) - - async def _send_single(self, method, params=None): - """Send a single request to the daemon.""" - - start = time.perf_counter() - - def processor(result): - err = result['error'] - if not err: - return result['result'] - if err.get('code') == self.WARMING_UP: - raise WarmingUpError - raise DaemonError(err) - - payload = {'method': method, 'id': next(self.id_counter)} - if params: - payload['params'] = params - result = await self._send(payload, processor) - LBRYCRD_REQUEST_TIMES.labels(method=method).observe(time.perf_counter() - start) - return 
result - - async def _send_vector(self, method, params_iterable, replace_errs=False): - """Send several requests of the same method. - - The result will be an array of the same length as params_iterable. - If replace_errs is true, any item with an error is returned as None, - otherwise an exception is raised.""" - - start = time.perf_counter() - - def processor(result): - errs = [item['error'] for item in result if item['error']] - if any(err.get('code') == self.WARMING_UP for err in errs): - raise WarmingUpError - if not errs or replace_errs: - return [item['result'] for item in result] - raise DaemonError(errs) - - payload = [{'method': method, 'params': p, 'id': next(self.id_counter)} - for p in params_iterable] - result = [] - if payload: - result = await self._send(payload, processor) - LBRYCRD_REQUEST_TIMES.labels(method=method).observe(time.perf_counter()-start) - return result - - async def _is_rpc_available(self, method): - """Return whether given RPC method is available in the daemon. - - Results are cached and the daemon will generally not be queried with - the same method more than once.""" - available = self.available_rpcs.get(method) - if available is None: - available = True - try: - await self._send_single(method) - except DaemonError as e: - err = e.args[0] - error_code = err.get("code") - available = error_code != JSONRPC.METHOD_NOT_FOUND - self.available_rpcs[method] = available - return available - - async def block_hex_hashes(self, first, count): - """Return the hex hashes of count block starting at height first.""" - if first + count < (self.cached_height() or 0) - 200: - return await self._cached_block_hex_hashes(first, count) - params_iterable = ((h, ) for h in range(first, first + count)) - return await self._send_vector('getblockhash', params_iterable) - - async def _cached_block_hex_hashes(self, first, count): - """Return the hex hashes of count block starting at height first.""" - cached = self._block_hash_cache.get((first, count)) - if cached: - return cached - params_iterable = ((h, ) for h in range(first, first + count)) - self._block_hash_cache[(first, count)] = await self._send_vector('getblockhash', params_iterable) - return self._block_hash_cache[(first, count)] - - async def deserialised_block(self, hex_hash): - """Return the deserialised block with the given hex hash.""" - if not self._block_cache.get(hex_hash): - self._block_cache[hex_hash] = await self._send_single('getblock', (hex_hash, True)) - return self._block_cache[hex_hash] - - async def raw_blocks(self, hex_hashes): - """Return the raw binary blocks with the given hex hashes.""" - params_iterable = ((h, False) for h in hex_hashes) - blocks = await self._send_vector('getblock', params_iterable) - # Convert hex string to bytes - return [hex_to_bytes(block) for block in blocks] - - async def mempool_hashes(self): - """Update our record of the daemon's mempool hashes.""" - return await self._send_single('getrawmempool') - - async def estimatefee(self, block_count): - """Return the fee estimate for the block count. Units are whole - currency units per KB, e.g. 0.00000995, or -1 if no estimate - is available. 
- """ - args = (block_count, ) - if await self._is_rpc_available('estimatesmartfee'): - estimate = await self._send_single('estimatesmartfee', args) - return estimate.get('feerate', -1) - return await self._send_single('estimatefee', args) - - async def getnetworkinfo(self): - """Return the result of the 'getnetworkinfo' RPC call.""" - return await self._send_single('getnetworkinfo') - - async def relayfee(self): - """The minimum fee a low-priority tx must pay in order to be accepted - to the daemon's memory pool.""" - network_info = await self.getnetworkinfo() - return network_info['relayfee'] - - async def getrawtransaction(self, hex_hash, verbose=False): - """Return the serialized raw transaction with the given hash.""" - # Cast to int because some coin daemons are old and require it - return await self._send_single('getrawtransaction', - (hex_hash, int(verbose))) - - async def getrawtransactions(self, hex_hashes, replace_errs=True): - """Return the serialized raw transactions with the given hashes. - - Replaces errors with None by default.""" - params_iterable = ((hex_hash, 0) for hex_hash in hex_hashes) - txs = await self._send_vector('getrawtransaction', params_iterable, - replace_errs=replace_errs) - # Convert hex strings to bytes - return [hex_to_bytes(tx) if tx else None for tx in txs] - - async def broadcast_transaction(self, raw_tx): - """Broadcast a transaction to the network.""" - return await self._send_single('sendrawtransaction', (raw_tx, )) - - async def height(self): - """Query the daemon for its current height.""" - self._height = await self._send_single('getblockcount') - return self._height - - def cached_height(self): - """Return the cached daemon height. - - If the daemon has not been queried yet this returns None.""" - return self._height - - -def handles_errors(decorated_function): - @wraps(decorated_function) - async def wrapper(*args, **kwargs): - try: - return await decorated_function(*args, **kwargs) - except DaemonError as daemon_error: - raise RPCError(1, daemon_error.args[0]) - return wrapper - - -class LBCDaemon(Daemon): - @handles_errors - async def getrawtransaction(self, hex_hash, verbose=False): - return await super().getrawtransaction(hex_hash=hex_hash, verbose=verbose) - - @handles_errors - async def getclaimbyid(self, claim_id): - '''Given a claim id, retrieves claim information.''' - return await self._send_single('getclaimbyid', (claim_id,)) - - @handles_errors - async def getclaimsbyids(self, claim_ids): - '''Given a list of claim ids, batches calls to retrieve claim information.''' - return await self._send_vector('getclaimbyid', ((claim_id,) for claim_id in claim_ids)) - - @handles_errors - async def getclaimsforname(self, name): - '''Given a name, retrieves all claims matching that name.''' - return await self._send_single('getclaimsforname', (name,)) - - @handles_errors - async def getclaimsfortx(self, txid): - '''Given a txid, returns the claims it make.''' - return await self._send_single('getclaimsfortx', (txid,)) or [] - - @handles_errors - async def getnameproof(self, name, block_hash=None): - '''Given a name and optional block_hash, returns a name proof and winner, if any.''' - return await self._send_single('getnameproof', (name, block_hash,) if block_hash else (name,)) - - @handles_errors - async def getvalueforname(self, name): - '''Given a name, returns the winning claim value.''' - return await self._send_single('getvalueforname', (name,)) - - @handles_errors - async def claimname(self, name, hexvalue, amount): - '''Claim a name, 
used for functional tests only.''' - return await self._send_single('claimname', (name, hexvalue, float(amount))) diff --git a/lbry/wallet/server/db/__init__.py b/lbry/wallet/server/db/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lbry/wallet/server/db/canonical.py b/lbry/wallet/server/db/canonical.py deleted file mode 100644 index a85fc8369..000000000 --- a/lbry/wallet/server/db/canonical.py +++ /dev/null @@ -1,26 +0,0 @@ -class FindShortestID: - __slots__ = 'short_id', 'new_id' - - def __init__(self): - self.short_id = '' - self.new_id = None - - def step(self, other_id, new_id): - self.new_id = new_id - for i in range(len(self.new_id)): - if other_id[i] != self.new_id[i]: - if i > len(self.short_id)-1: - self.short_id = self.new_id[:i+1] - break - - def finalize(self): - if self.short_id: - return '#'+self.short_id - - @classmethod - def factory(cls): - return cls(), cls.step, cls.finalize - - -def register_canonical_functions(connection): - connection.createaggregatefunction("shortest_id", FindShortestID.factory, 2) diff --git a/lbry/wallet/server/db/common.py b/lbry/wallet/server/db/common.py deleted file mode 100644 index 8f75737e7..000000000 --- a/lbry/wallet/server/db/common.py +++ /dev/null @@ -1,221 +0,0 @@ -CLAIM_TYPES = { - 'stream': 1, - 'channel': 2, - 'repost': 3 -} - -STREAM_TYPES = { - 'video': 1, - 'audio': 2, - 'image': 3, - 'document': 4, - 'binary': 5, - 'model': 6 -} - -MATURE_TAGS = [ - 'nsfw', 'porn', 'xxx', 'mature', 'adult', 'sex' -] - -COMMON_TAGS = { - "gaming": "gaming", - "people & blogs": "people_and_blogs", - "pop culture": "pop_culture", - "entertainment": "entertainment", - "technology": "technology", - "music": "music", - "funny": "funny", - "education": "education", - "learning": "learning", - "news": "news", - "gameplay": "gameplay", - "science & technology": "science_and_technology", - "playstation 4": "playstation_4", - "beliefs": "beliefs", - "nature": "nature", - "news & politics": "news_and_politics", - "comedy": "comedy", - "games": "games", - "sony interactive entertainment": "sony_interactive_entertainment", - "film & animation": "film_and_animation", - "game": "game", - "howto & style": "howto_and_style", - "weapons": "weapons", - "blockchain": "blockchain", - "video game": "video_game", - "sports": "sports", - "walkthrough": "walkthrough", - "ps4live": "ps4live", - "art": "art", - "pc": "pc", - "economics": "economics", - "automotive": "automotive", - "minecraft": "minecraft", - "playthrough": "playthrough", - "ps4share": "ps4share", - "tutorial": "tutorial", - "play": "play", - "twitch": "twitch", - "how to": "how_to", - "ps4": "ps4", - "bitcoin": "bitcoin", - "fortnite": "fortnite", - "commentary": "commentary", - "lets play": "lets_play", - "fun": "fun", - "politics": "politics", - "xbox": "xbox", - "autos & vehicles": "autos_and_vehicles", - "travel & events": "travel_and_events", - "food": "food", - "science": "science", - "mature": "mature", - "xbox one": "xbox_one", - "liberal": "liberal", - "democrat": "democrat", - "progressive": "progressive", - "survival": "survival", - "nonprofits & activism": "nonprofits_and_activism", - "cryptocurrency": "cryptocurrency", - "playstation": "playstation", - "nintendo": "nintendo", - "government": "government", - "steam": "steam", - "podcast": "podcast", - "horror": "horror", - "conservative": "conservative", - "reaction": "reaction", - "trailer": "trailer", - "love": "love", - "cnn": "cnn", - "republican": "republican", - "gamer": "gamer", - "political": "political", 
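# Usage sketch for the FindShortestID aggregate defined above (it normally runs
# inside SQLite as the shortest_id() aggregate registered by
# register_canonical_functions, so calling it directly like this is only for
# illustration): fed existing claim ids plus a new id, it yields the shortest
# '#'-prefixed prefix of the new id that no existing id shares.
finder = FindShortestID()
finder.step('abcdef012345', 'abc123456789')   # first difference at index 3 -> short_id 'abc1'
finder.step('abd999999999', 'abc123456789')   # differs earlier, so the longer prefix is kept
assert finder.finalize() == '#abc1'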
- "hangoutsonair": "hangoutsonair", - "hoa": "hoa", - "msnbc": "msnbc", - "cbs": "cbs", - "donald trump": "donald_trump", - "fiction": "fiction", - "fox news": "fox_news", - "anime": "anime", - "crypto": "crypto", - "ethereum": "ethereum", - "call of duty": "call_of_duty", - "multiplayer": "multiplayer", - "android": "android", - "epic": "epic", - "rpg": "rpg", - "adventure": "adventure", - "secular talk": "secular_talk", - "btc": "btc", - "atheist": "atheist", - "atheism": "atheism", - "ps3": "ps3", - "video games": "video_games", - "cod": "cod", - "agnostic": "agnostic", - "movie": "movie", - "online": "online", - "fps": "fps", - "mod": "mod", - "reviews": "reviews", - "sharefactory": "sharefactory", - "world": "world", - "space": "space", - "hilarious": "hilarious", - "stream": "stream", - "lol": "lol", - "sony": "sony", - "god": "god", - "lets": "lets", - "dance": "dance", - "pvp": "pvp", - "tech": "tech", - "zombies": "zombies", - "pokemon": "pokemon", - "fail": "fail", - "xbox 360": "xbox_360", - "film": "film", - "unboxing": "unboxing", - "animation": "animation", - "travel": "travel", - "money": "money", - "wwe": "wwe", - "how": "how", - "mods": "mods", - "pubg": "pubg", - "indie": "indie", - "strategy": "strategy", - "history": "history", - "rap": "rap", - "ios": "ios", - "sony computer entertainment": "sony_computer_entertainment", - "mobile": "mobile", - "trump": "trump", - "flat earth": "flat_earth", - "hack": "hack", - "trap": "trap", - "fox": "fox", - "vlogging": "vlogging", - "news radio": "news_radio", - "humor": "humor", - "facebook": "facebook", - "edm": "edm", - "fitness": "fitness", - "vaping": "vaping", - "hip hop": "hip_hop", - "secular": "secular", - "jesus": "jesus", - "vape": "vape", - "song": "song", - "remix": "remix", - "guitar": "guitar", - "daily": "daily", - "mining": "mining", - "diy": "diy", - "videogame": "videogame", - "pets & animals": "pets_and_animals", - "funny moments": "funny_moments", - "religion": "religion", - "death": "death", - "media": "media", - "nbc": "nbc", - "war": "war", - "freedom": "freedom", - "viral": "viral", - "meme": "meme", - "family": "family", - "gold": "gold", - "photography": "photography", - "chill": "chill", - "zombie": "zombie", - "computer": "computer", - "sniper": "sniper", - "bible": "bible", - "linux": "linux", - "overwatch": "overwatch", - "pro": "pro", - "dragon": "dragon", - "litecoin": "litecoin", - "gta": "gta", - "iphone": "iphone", - "house": "house", - "bass": "bass", - "bitcoin news": "bitcoin_news", - "wii": "wii", - "crash": "crash", - "league of legends": "league_of_legends", - "grand theft auto v": "grand_theft_auto_v", - "mario": "mario", - "mmorpg": "mmorpg", - "satire": "satire", - "fire": "fire", - "racing": "racing", - "apple": "apple", - "health": "health", - "instrumental": "instrumental", - "destiny": "destiny", - "truth": "truth", - "race": "race" -} diff --git a/lbry/wallet/server/db/trending/__init__.py b/lbry/wallet/server/db/trending/__init__.py deleted file mode 100644 index 86d94bdc3..000000000 --- a/lbry/wallet/server/db/trending/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from . import zscore -from . import ar -from . 
import variable_decay - -TRENDING_ALGORITHMS = { - 'zscore': zscore, - 'ar': ar, - 'variable_decay': variable_decay -} diff --git a/lbry/wallet/server/db/trending/ar.py b/lbry/wallet/server/db/trending/ar.py deleted file mode 100644 index e53579005..000000000 --- a/lbry/wallet/server/db/trending/ar.py +++ /dev/null @@ -1,265 +0,0 @@ -import copy -import math -import time - -# Half life in blocks -HALF_LIFE = 134 - -# Decay coefficient per block -DECAY = 0.5**(1.0/HALF_LIFE) - -# How frequently to write trending values to the db -SAVE_INTERVAL = 10 - -# Renormalisation interval -RENORM_INTERVAL = 1000 - -# Assertion -assert RENORM_INTERVAL % SAVE_INTERVAL == 0 - -# Decay coefficient per renormalisation interval -DECAY_PER_RENORM = DECAY**(RENORM_INTERVAL) - -# Log trending calculations? -TRENDING_LOG = True - - -def install(connection): - """ - Install the AR trending algorithm. - """ - check_trending_values(connection) - - if TRENDING_LOG: - f = open("trending_ar.log", "a") - f.close() - -# Stub -CREATE_TREND_TABLE = "" - - -def check_trending_values(connection): - """ - If the trending values appear to be based on the zscore algorithm, - reset them. This will allow resyncing from a standard snapshot. - """ - c = connection.cursor() - needs_reset = False - for row in c.execute("SELECT COUNT(*) num FROM claim WHERE trending_global <> 0;"): - if row[0] != 0: - needs_reset = True - break - - if needs_reset: - print("Resetting some columns. This might take a while...", flush=True, end="") - c.execute(""" BEGIN; - UPDATE claim SET trending_group = 0; - UPDATE claim SET trending_mixed = 0; - UPDATE claim SET trending_global = 0; - UPDATE claim SET trending_local = 0; - COMMIT;""") - print("done.") - - -def spike_height(trending_score, x, x_old, time_boost=1.0): - """ - Compute the size of a trending spike. - """ - - # Change in softened amount - change_in_softened_amount = x**0.25 - x_old**0.25 - - # Softened change in amount - delta = x - x_old - softened_change_in_amount = abs(delta)**0.25 - - # Softened change in amount counts more for minnows - if delta > 0.0: - if trending_score >= 0.0: - multiplier = 0.1/((trending_score/time_boost + softened_change_in_amount) + 1.0) - softened_change_in_amount *= multiplier - else: - softened_change_in_amount *= -1.0 - - return time_boost*(softened_change_in_amount + change_in_softened_amount) - - -def get_time_boost(height): - """ - Return the time boost at a given height. - """ - return 1.0/DECAY**(height % RENORM_INTERVAL) - - -def trending_log(s): - """ - Log a string. - """ - if TRENDING_LOG: - fout = open("trending_ar.log", "a") - fout.write(s) - fout.flush() - fout.close() - -class TrendingData: - """ - An object of this class holds trending data - """ - def __init__(self): - self.claims = {} - - # Have all claims been read from db yet? - self.initialised = False - - def insert_claim_from_load(self, claim_hash, trending_score, total_amount): - assert not self.initialised - self.claims[claim_hash] = {"trending_score": trending_score, - "total_amount": total_amount, - "changed": False} - - - def update_claim(self, claim_hash, total_amount, time_boost=1.0): - """ - Update trending data for a claim, given its new total amount. 
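# Back-of-the-envelope check of the AR decay constants above: with a half life of
# 134 blocks, DECAY**134 comes out to 0.5, and DECAY_PER_RENORM is DECAY**1000.
# get_time_boost() grows by 1/DECAY per block and resets at each renormalisation
# boundary, which is why run() multiplies scores by DECAY_PER_RENORM there.
HALF_LIFE = 134
DECAY = 0.5 ** (1.0 / HALF_LIFE)
RENORM_INTERVAL = 1000

assert abs(DECAY ** HALF_LIFE - 0.5) < 1e-12           # one half life of decay
print(DECAY ** RENORM_INTERVAL)                         # ~0.0057, i.e. DECAY_PER_RENORM
print(1.0 / DECAY ** (999 % RENORM_INTERVAL))           # time boost just before a renorm, ~176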
- """ - assert self.initialised - - # Extract existing total amount and trending score - # or use starting values if the claim is new - if claim_hash in self.claims: - old_state = copy.deepcopy(self.claims[claim_hash]) - else: - old_state = {"trending_score": 0.0, - "total_amount": 0.0, - "changed": False} - - # Calculate LBC change - change = total_amount - old_state["total_amount"] - - # Modify data if there was an LBC change - if change != 0.0: - spike = spike_height(old_state["trending_score"], - total_amount, - old_state["total_amount"], - time_boost) - trending_score = old_state["trending_score"] + spike - self.claims[claim_hash] = {"total_amount": total_amount, - "trending_score": trending_score, - "changed": True} - - - -def test_trending(): - """ - Quick trending test for something receiving 10 LBC per block - """ - data = TrendingData() - data.insert_claim_from_load("abc", 10.0, 1.0) - data.initialised = True - - for height in range(1, 5000): - - if height % RENORM_INTERVAL == 0: - data.claims["abc"]["trending_score"] *= DECAY_PER_RENORM - - time_boost = get_time_boost(height) - data.update_claim("abc", data.claims["abc"]["total_amount"] + 10.0, - time_boost=time_boost) - - - print(str(height) + " " + str(time_boost) + " " \ - + str(data.claims["abc"]["trending_score"])) - - - -# One global instance -# pylint: disable=C0103 -trending_data = TrendingData() - -def run(db, height, final_height, recalculate_claim_hashes): - - if height < final_height - 5*HALF_LIFE: - trending_log("Skipping AR trending at block {h}.\n".format(h=height)) - return - - start = time.time() - - trending_log("Calculating AR trending at block {h}.\n".format(h=height)) - trending_log(" Length of trending data = {l}.\n"\ - .format(l=len(trending_data.claims))) - - # Renormalise trending scores and mark all as having changed - if height % RENORM_INTERVAL == 0: - trending_log(" Renormalising trending scores...") - - keys = trending_data.claims.keys() - for key in keys: - if trending_data.claims[key]["trending_score"] != 0.0: - trending_data.claims[key]["trending_score"] *= DECAY_PER_RENORM - trending_data.claims[key]["changed"] = True - - # Tiny becomes zero - if abs(trending_data.claims[key]["trending_score"]) < 1E-9: - trending_data.claims[key]["trending_score"] = 0.0 - - trending_log("done.\n") - - - # Regular message. - trending_log(" Reading total_amounts from db and updating"\ - + " trending scores in RAM...") - - # Get the value of the time boost - time_boost = get_time_boost(height) - - # Update claims from db - if not trending_data.initialised: - # On fresh launch - for row in db.execute(""" - SELECT claim_hash, trending_mixed, - (amount + support_amount) - AS total_amount - FROM claim; - """): - trending_data.insert_claim_from_load(row[0], row[1], 1E-8*row[2]) - trending_data.initialised = True - else: - for row in db.execute(f""" - SELECT claim_hash, - (amount + support_amount) - AS total_amount - FROM claim - WHERE claim_hash IN - ({','.join('?' 
for _ in recalculate_claim_hashes)}); - """, recalculate_claim_hashes): - trending_data.update_claim(row[0], 1E-8*row[1], time_boost) - - trending_log("done.\n") - - - # Write trending scores to DB - if height % SAVE_INTERVAL == 0: - - trending_log(" Writing trending scores to db...") - - the_list = [] - keys = trending_data.claims.keys() - for key in keys: - if trending_data.claims[key]["changed"]: - the_list.append((trending_data.claims[key]["trending_score"], - key)) - trending_data.claims[key]["changed"] = False - - trending_log("{n} scores to write...".format(n=len(the_list))) - - db.executemany("UPDATE claim SET trending_mixed=? WHERE claim_hash=?;", - the_list) - - trending_log("done.\n") - - trending_log("Trending operations took {time} seconds.\n\n"\ - .format(time=time.time() - start)) - - -if __name__ == "__main__": - test_trending() diff --git a/lbry/wallet/server/db/trending/variable_decay.py b/lbry/wallet/server/db/trending/variable_decay.py deleted file mode 100644 index a64c0ff7c..000000000 --- a/lbry/wallet/server/db/trending/variable_decay.py +++ /dev/null @@ -1,431 +0,0 @@ -""" -Delayed AR with variable decay rate. - -The spike height function is also simpler. -""" - -import copy -import time -import apsw - -# Half life in blocks *for lower LBC claims* (it's shorter for whale claims) -HALF_LIFE = 200 - -# Whale threshold (higher -> less DB writing) -WHALE_THRESHOLD = 3.0 - -# Decay coefficient per block -DECAY = 0.5**(1.0/HALF_LIFE) - -# How frequently to write trending values to the db -SAVE_INTERVAL = 10 - -# Renormalisation interval -RENORM_INTERVAL = 1000 - -# Assertion -assert RENORM_INTERVAL % SAVE_INTERVAL == 0 - -# Decay coefficient per renormalisation interval -DECAY_PER_RENORM = DECAY**(RENORM_INTERVAL) - -# Log trending calculations? -TRENDING_LOG = True - - -def install(connection): - """ - Install the trending algorithm. - """ - check_trending_values(connection) - - if TRENDING_LOG: - f = open("trending_variable_decay.log", "a") - f.close() - -# Stub -CREATE_TREND_TABLE = "" - - -def check_trending_values(connection): - """ - If the trending values appear to be based on the zscore algorithm, - reset them. This will allow resyncing from a standard snapshot. - """ - c = connection.cursor() - needs_reset = False - for row in c.execute("SELECT COUNT(*) num FROM claim WHERE trending_global <> 0;"): - if row[0] != 0: - needs_reset = True - break - - if needs_reset: - print("Resetting some columns. This might take a while...", flush=True, end="") - c.execute(""" BEGIN; - UPDATE claim SET trending_group = 0; - UPDATE claim SET trending_mixed = 0; - UPDATE claim SET trending_global = 0; - UPDATE claim SET trending_local = 0; - COMMIT;""") - print("done.") - - -def spike_height(x, x_old): - """ - Compute the size of a trending spike (normed - constant units). - """ - - # Sign of trending spike - sign = 1.0 - if x < x_old: - sign = -1.0 - - # Magnitude - mag = abs(x**0.25 - x_old**0.25) - - # Minnow boost - mag *= 1.0 + 2E4/(x + 100.0)**2 - - return sign*mag - - - -def get_time_boost(height): - """ - Return the time boost at a given height. - """ - return 1.0/DECAY**(height % RENORM_INTERVAL) - - -def trending_log(s): - """ - Log a string. - """ - if TRENDING_LOG: - fout = open("trending_variable_decay.log", "a") - fout.write(s) - fout.flush() - fout.close() - -class TrendingData: - """ - An object of this class holds trending data - """ - def __init__(self): - - # Dict from claim id to some trending info. 
- # Units are TIME VARIABLE in here - self.claims = {} - - # Claims with >= WHALE_THRESHOLD LBC total amount - self.whales = set([]) - - # Have all claims been read from db yet? - self.initialised = False - - # List of pending spikes. - # Units are CONSTANT in here - self.pending_spikes = [] - - def insert_claim_from_load(self, height, claim_hash, trending_score, total_amount): - assert not self.initialised - self.claims[claim_hash] = {"trending_score": trending_score, - "total_amount": total_amount, - "changed": False} - - if trending_score >= WHALE_THRESHOLD*get_time_boost(height): - self.add_whale(claim_hash) - - def add_whale(self, claim_hash): - self.whales.add(claim_hash) - - def apply_spikes(self, height): - """ - Apply all pending spikes that are due at this height. - Apply with time boost ON. - """ - time_boost = get_time_boost(height) - - for spike in self.pending_spikes: - if spike["height"] > height: - # Ignore - pass - if spike["height"] == height: - # Apply - self.claims[spike["claim_hash"]]["trending_score"] += time_boost*spike["size"] - self.claims[spike["claim_hash"]]["changed"] = True - - if self.claims[spike["claim_hash"]]["trending_score"] >= WHALE_THRESHOLD*time_boost: - self.add_whale(spike["claim_hash"]) - if spike["claim_hash"] in self.whales and \ - self.claims[spike["claim_hash"]]["trending_score"] < WHALE_THRESHOLD*time_boost: - self.whales.remove(spike["claim_hash"]) - - - # Keep only future spikes - self.pending_spikes = [s for s in self.pending_spikes \ - if s["height"] > height] - - - - - def update_claim(self, height, claim_hash, total_amount): - """ - Update trending data for a claim, given its new total amount. - """ - assert self.initialised - - # Extract existing total amount and trending score - # or use starting values if the claim is new - if claim_hash in self.claims: - old_state = copy.deepcopy(self.claims[claim_hash]) - else: - old_state = {"trending_score": 0.0, - "total_amount": 0.0, - "changed": False} - - # Calculate LBC change - change = total_amount - old_state["total_amount"] - - # Modify data if there was an LBC change - if change != 0.0: - spike = spike_height(total_amount, - old_state["total_amount"]) - delay = min(int((total_amount + 1E-8)**0.4), HALF_LIFE) - - if change < 0.0: - - # How big would the spike be for the inverse movement? - reverse_spike = spike_height(old_state["total_amount"], total_amount) - - # Remove that much spike from future pending ones - for future_spike in self.pending_spikes: - if future_spike["claim_hash"] == claim_hash: - if reverse_spike >= future_spike["size"]: - reverse_spike -= future_spike["size"] - future_spike["size"] = 0.0 - elif reverse_spike > 0.0: - future_spike["size"] -= reverse_spike - reverse_spike = 0.0 - - delay = 0 - spike = -reverse_spike - - self.pending_spikes.append({"height": height + delay, - "claim_hash": claim_hash, - "size": spike}) - - self.claims[claim_hash] = {"total_amount": total_amount, - "trending_score": old_state["trending_score"], - "changed": False} - - def process_whales(self, height): - """ - Whale claims decay faster. 
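# Rough sketch of how update_claim above schedules spikes in variable_decay:
# larger total amounts (in LBC) have their trending spike applied further in the
# future, capped at HALF_LIFE (200) blocks.
HALF_LIFE = 200

def spike_delay(total_amount_lbc: float) -> int:
    return min(int((total_amount_lbc + 1E-8) ** 0.4), HALF_LIFE)

for amount in (1.0, 100.0, 10_000.0, 1_000_000.0):
    print(amount, spike_delay(amount))   # 1 -> 1, 100 -> 6, 10k -> 39, 1M -> capped at 200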
- """ - if height % SAVE_INTERVAL != 0: - return - - for claim_hash in self.whales: - trending_normed = self.claims[claim_hash]["trending_score"]/get_time_boost(height) - - # Overall multiplication factor for decay rate - decay_rate_factor = trending_normed/WHALE_THRESHOLD - - # The -1 is because this is just the *extra* part being applied - factor = (DECAY**SAVE_INTERVAL)**(decay_rate_factor - 1.0) - # print(claim_hash, trending_normed, decay_rate_factor) - self.claims[claim_hash]["trending_score"] *= factor - self.claims[claim_hash]["changed"] = True - - -def test_trending(): - """ - Quick trending test for claims with different support patterns. - Actually use the run() function. - """ - - # Create a fake "claims.db" for testing - # pylint: disable=I1101 - dbc = apsw.Connection(":memory:") - db = dbc.cursor() - - # Create table - db.execute(""" - BEGIN; - CREATE TABLE claim (claim_hash TEXT PRIMARY KEY, - amount REAL NOT NULL DEFAULT 0.0, - support_amount REAL NOT NULL DEFAULT 0.0, - trending_mixed REAL NOT NULL DEFAULT 0.0); - COMMIT; - """) - - # Insert initial states of claims - everything = {"huge_whale": 0.01, - "huge_whale_botted": 0.01, - "medium_whale": 0.01, - "small_whale": 0.01, - "minnow": 0.01} - - def to_list_of_tuples(stuff): - l = [] - for key in stuff: - l.append((key, stuff[key])) - return l - - db.executemany(""" - INSERT INTO claim (claim_hash, amount) VALUES (?, 1E8*?); - """, to_list_of_tuples(everything)) - - height = 0 - run(db, height, height, everything.keys()) - - # Save trajectories for plotting - trajectories = {} - for key in trending_data.claims: - trajectories[key] = [trending_data.claims[key]["trending_score"]] - - # Main loop - for height in range(1, 1000): - - # One-off supports - if height == 1: - everything["huge_whale"] += 5E5 - everything["medium_whale"] += 5E4 - everything["small_whale"] += 5E3 - - # Every block - if height < 500: - everything["huge_whale_botted"] += 5E5/500 - everything["minnow"] += 1 - - # Remove supports - if height == 500: - for key in everything: - everything[key] = 0.01 - - # Whack into the db - db.executemany(""" - UPDATE claim SET amount = 1E8*? 
WHERE claim_hash = ?; - """, [(y, x) for (x, y) in to_list_of_tuples(everything)]) - - # Call run() - run(db, height, height, everything.keys()) - - for key in trending_data.claims: - trajectories[key].append(trending_data.claims[key]["trending_score"]\ - /get_time_boost(height)) - - dbc.close() - - # pylint: disable=C0415 - import matplotlib.pyplot as plt - for key in trending_data.claims: - plt.plot(trajectories[key], label=key) - plt.legend() - plt.show() - - -# One global instance -# pylint: disable=C0103 -trending_data = TrendingData() - -def run(db, height, final_height, recalculate_claim_hashes): - - if height < final_height - 5*HALF_LIFE: - trending_log("Skipping variable_decay trending at block {h}.\n".format(h=height)) - return - - start = time.time() - - trending_log("Calculating variable_decay trending at block {h}.\n".format(h=height)) - trending_log(" Length of trending data = {l}.\n"\ - .format(l=len(trending_data.claims))) - - # Renormalise trending scores and mark all as having changed - if height % RENORM_INTERVAL == 0: - trending_log(" Renormalising trending scores...") - - keys = trending_data.claims.keys() - trending_data.whales = set([]) - for key in keys: - if trending_data.claims[key]["trending_score"] != 0.0: - trending_data.claims[key]["trending_score"] *= DECAY_PER_RENORM - trending_data.claims[key]["changed"] = True - - # Tiny becomes zero - if abs(trending_data.claims[key]["trending_score"]) < 1E-3: - trending_data.claims[key]["trending_score"] = 0.0 - - # Re-mark whales - if trending_data.claims[key]["trending_score"] >= WHALE_THRESHOLD*get_time_boost(height): - trending_data.add_whale(key) - - trending_log("done.\n") - - - # Regular message. - trending_log(" Reading total_amounts from db and updating"\ - + " trending scores in RAM...") - - # Update claims from db - if not trending_data.initialised: - - trending_log("initial load...") - # On fresh launch - for row in db.execute(""" - SELECT claim_hash, trending_mixed, - (amount + support_amount) - AS total_amount - FROM claim; - """): - trending_data.insert_claim_from_load(height, row[0], row[1], 1E-8*row[2]) - trending_data.initialised = True - else: - for row in db.execute(f""" - SELECT claim_hash, - (amount + support_amount) - AS total_amount - FROM claim - WHERE claim_hash IN - ({','.join('?' for _ in recalculate_claim_hashes)}); - """, recalculate_claim_hashes): - trending_data.update_claim(height, row[0], 1E-8*row[1]) - - # Apply pending spikes - trending_data.apply_spikes(height) - - trending_log("done.\n") - - - # Write trending scores to DB - if height % SAVE_INTERVAL == 0: - - trending_log(" Finding and processing whales...") - trending_log(str(len(trending_data.whales)) + " whales found...") - trending_data.process_whales(height) - trending_log("done.\n") - - trending_log(" Writing trending scores to db...") - - the_list = [] - keys = trending_data.claims.keys() - - for key in keys: - if trending_data.claims[key]["changed"]: - the_list.append((trending_data.claims[key]["trending_score"], key)) - trending_data.claims[key]["changed"] = False - - trending_log("{n} scores to write...".format(n=len(the_list))) - - db.executemany("UPDATE claim SET trending_mixed=? 
WHERE claim_hash=?;", - the_list) - - trending_log("done.\n") - - trending_log("Trending operations took {time} seconds.\n\n"\ - .format(time=time.time() - start)) - - -if __name__ == "__main__": - test_trending() diff --git a/lbry/wallet/server/db/trending/zscore.py b/lbry/wallet/server/db/trending/zscore.py deleted file mode 100644 index bc0987d96..000000000 --- a/lbry/wallet/server/db/trending/zscore.py +++ /dev/null @@ -1,123 +0,0 @@ -from math import sqrt - -# TRENDING_WINDOW is the number of blocks in ~6hr period (21600 seconds / 161 seconds per block) -TRENDING_WINDOW = 134 - -# TRENDING_DATA_POINTS says how many samples to use for the trending algorithm -# i.e. only consider claims from the most recent (TRENDING_WINDOW * TRENDING_DATA_POINTS) blocks -TRENDING_DATA_POINTS = 28 - -CREATE_TREND_TABLE = """ - create table if not exists trend ( - claim_hash bytes not null, - height integer not null, - amount integer not null, - primary key (claim_hash, height) - ) without rowid; -""" - - -class ZScore: - __slots__ = 'count', 'total', 'power', 'last' - - def __init__(self): - self.count = 0 - self.total = 0 - self.power = 0 - self.last = None - - def step(self, value): - if self.last is not None: - self.count += 1 - self.total += self.last - self.power += self.last ** 2 - self.last = value - - @property - def mean(self): - return self.total / self.count - - @property - def standard_deviation(self): - value = (self.power / self.count) - self.mean ** 2 - return sqrt(value) if value > 0 else 0 - - def finalize(self): - if self.count == 0: - return self.last - return (self.last - self.mean) / (self.standard_deviation or 1) - - @classmethod - def factory(cls): - return cls(), cls.step, cls.finalize - - -def install(connection): - connection.createaggregatefunction("zscore", ZScore.factory, 1) - connection.cursor().execute(CREATE_TREND_TABLE) - - -def run(db, height, final_height, affected_claims): - # don't start tracking until we're at the end of initial sync - if height < (final_height - (TRENDING_WINDOW * TRENDING_DATA_POINTS)): - return - - if height % TRENDING_WINDOW != 0: - return - - db.execute(f""" - DELETE FROM trend WHERE height < {height - (TRENDING_WINDOW * TRENDING_DATA_POINTS)} - """) - - start = (height - TRENDING_WINDOW) + 1 - db.execute(f""" - INSERT OR IGNORE INTO trend (claim_hash, height, amount) - SELECT claim_hash, {start}, COALESCE( - (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash - AND height >= {start}), 0 - ) AS support_sum - FROM claim WHERE support_sum > 0 - """) - - zscore = ZScore() - for global_sum in db.execute("SELECT AVG(amount) AS avg_amount FROM trend GROUP BY height"): - zscore.step(global_sum.avg_amount) - global_mean, global_deviation = 0, 1 - if zscore.count > 0: - global_mean = zscore.mean - global_deviation = zscore.standard_deviation - - db.execute(f""" - UPDATE claim SET - trending_local = COALESCE(( - SELECT zscore(amount) FROM trend - WHERE claim_hash=claim.claim_hash ORDER BY height DESC - ), 0), - trending_global = COALESCE(( - SELECT (amount - {global_mean}) / {global_deviation} FROM trend - WHERE claim_hash=claim.claim_hash AND height = {start} - ), 0), - trending_group = 0, - trending_mixed = 0 - """) - - # trending_group and trending_mixed determine how trending will show in query results - # normally the SQL will be: "ORDER BY trending_group, trending_mixed" - # changing the trending_group will have significant impact on trending results - # changing the value used for trending_mixed will only impact trending within a 
trending_group - db.execute(f""" - UPDATE claim SET - trending_group = CASE - WHEN trending_local > 0 AND trending_global > 0 THEN 4 - WHEN trending_local <= 0 AND trending_global > 0 THEN 3 - WHEN trending_local > 0 AND trending_global <= 0 THEN 2 - WHEN trending_local <= 0 AND trending_global <= 0 THEN 1 - END, - trending_mixed = CASE - WHEN trending_local > 0 AND trending_global > 0 THEN trending_global - WHEN trending_local <= 0 AND trending_global > 0 THEN trending_local - WHEN trending_local > 0 AND trending_global <= 0 THEN trending_local - WHEN trending_local <= 0 AND trending_global <= 0 THEN trending_global - END - WHERE trending_local <> 0 OR trending_global <> 0 - """) diff --git a/lbry/wallet/server/db/writer.py b/lbry/wallet/server/db/writer.py deleted file mode 100644 index 988b7b266..000000000 --- a/lbry/wallet/server/db/writer.py +++ /dev/null @@ -1,894 +0,0 @@ -import os -import apsw -from typing import Union, Tuple, Set, List -from itertools import chain -from decimal import Decimal -from collections import namedtuple -from multiprocessing import Manager -from binascii import unhexlify - -from lbry.wallet.server.leveldb import LevelDB -from lbry.wallet.server.util import class_logger -from lbry.wallet.database import query, constraints_to_sql - -from lbry.schema.tags import clean_tags -from lbry.schema.mime_types import guess_stream_type -from lbry.wallet import Ledger, RegTestLedger -from lbry.wallet.transaction import Transaction, Output -from lbry.wallet.server.db.canonical import register_canonical_functions -from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished -from lbry.wallet.server.db.trending import TRENDING_ALGORITHMS - -from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS - - -ATTRIBUTE_ARRAY_MAX_LENGTH = 100 - - -class SQLDB: - - PRAGMAS = """ - pragma journal_mode=WAL; - """ - - CREATE_CLAIM_TABLE = """ - create table if not exists claim ( - claim_hash bytes primary key, - claim_id text not null, - claim_name text not null, - normalized text not null, - txo_hash bytes not null, - tx_position integer not null, - amount integer not null, - timestamp integer not null, -- last updated timestamp - creation_timestamp integer not null, - height integer not null, -- last updated height - creation_height integer not null, - activation_height integer, - expiration_height integer not null, - release_time integer not null, - - short_url text not null, -- normalized#shortest-unique-claim_id - canonical_url text, -- channel's-short_url/normalized#shortest-unique-claim_id-within-channel - - title text, - author text, - description text, - - claim_type integer, - reposted integer default 0, - - -- streams - stream_type text, - media_type text, - fee_amount integer default 0, - fee_currency text, - duration integer, - - -- reposts - reposted_claim_hash bytes, - - -- claims which are channels - public_key_bytes bytes, - public_key_hash bytes, - claims_in_channel integer, - - -- claims which are inside channels - channel_hash bytes, - channel_join integer, -- height at which claim got valid signature / joined channel - signature bytes, - signature_digest bytes, - signature_valid bool, - - effective_amount integer not null default 0, - support_amount integer not null default 0, - trending_group integer not null default 0, - trending_mixed integer not null default 0, - trending_local integer not null default 0, - trending_global integer not null default 0 - ); - - create index if not exists 
claim_normalized_idx on claim (normalized, activation_height); - create index if not exists claim_channel_hash_idx on claim (channel_hash, signature, claim_hash); - create index if not exists claim_claims_in_channel_idx on claim (signature_valid, channel_hash, normalized); - create index if not exists claim_txo_hash_idx on claim (txo_hash); - create index if not exists claim_activation_height_idx on claim (activation_height, claim_hash); - create index if not exists claim_expiration_height_idx on claim (expiration_height); - create index if not exists claim_reposted_claim_hash_idx on claim (reposted_claim_hash); - """ - - CREATE_SUPPORT_TABLE = """ - create table if not exists support ( - txo_hash bytes primary key, - tx_position integer not null, - height integer not null, - claim_hash bytes not null, - amount integer not null - ); - create index if not exists support_claim_hash_idx on support (claim_hash, height); - """ - - CREATE_TAG_TABLE = """ - create table if not exists tag ( - tag text not null, - claim_hash bytes not null, - height integer not null - ); - create unique index if not exists tag_claim_hash_tag_idx on tag (claim_hash, tag); - """ - - CREATE_CLAIMTRIE_TABLE = """ - create table if not exists claimtrie ( - normalized text primary key, - claim_hash bytes not null, - last_take_over_height integer not null - ); - create index if not exists claimtrie_claim_hash_idx on claimtrie (claim_hash); - """ - - SEARCH_INDEXES = """ - -- used by any tag clouds - create index if not exists tag_tag_idx on tag (tag, claim_hash); - - -- naked order bys (no filters) - create unique index if not exists claim_release_idx on claim (release_time, claim_hash); - create unique index if not exists claim_trending_idx on claim (trending_group, trending_mixed, claim_hash); - create unique index if not exists claim_effective_amount_idx on claim (effective_amount, claim_hash); - - -- claim_type filter + order by - create unique index if not exists claim_type_release_idx on claim (release_time, claim_type, claim_hash); - create unique index if not exists claim_type_trending_idx on claim (trending_group, trending_mixed, claim_type, claim_hash); - create unique index if not exists claim_type_effective_amount_idx on claim (effective_amount, claim_type, claim_hash); - - -- stream_type filter + order by - create unique index if not exists stream_type_release_idx on claim (stream_type, release_time, claim_hash); - create unique index if not exists stream_type_trending_idx on claim (stream_type, trending_group, trending_mixed, claim_hash); - create unique index if not exists stream_type_effective_amount_idx on claim (stream_type, effective_amount, claim_hash); - - -- channel_hash filter + order by - create unique index if not exists channel_hash_release_idx on claim (channel_hash, release_time, claim_hash); - create unique index if not exists channel_hash_trending_idx on claim (channel_hash, trending_group, trending_mixed, claim_hash); - create unique index if not exists channel_hash_effective_amount_idx on claim (channel_hash, effective_amount, claim_hash); - - -- duration filter + order by - create unique index if not exists duration_release_idx on claim (duration, release_time, claim_hash); - create unique index if not exists duration_trending_idx on claim (duration, trending_group, trending_mixed, claim_hash); - create unique index if not exists duration_effective_amount_idx on claim (duration, effective_amount, claim_hash); - - -- fee_amount + order by - create unique index if not exists 
fee_amount_release_idx on claim (fee_amount, release_time, claim_hash); - create unique index if not exists fee_amount_trending_idx on claim (fee_amount, trending_group, trending_mixed, claim_hash); - create unique index if not exists fee_amount_effective_amount_idx on claim (fee_amount, effective_amount, claim_hash); - - -- TODO: verify that all indexes below are used - create index if not exists claim_height_normalized_idx on claim (height, normalized asc); - create index if not exists claim_resolve_idx on claim (normalized, claim_id); - create index if not exists claim_id_idx on claim (claim_id, claim_hash); - create index if not exists claim_timestamp_idx on claim (timestamp); - create index if not exists claim_public_key_hash_idx on claim (public_key_hash); - create index if not exists claim_signature_valid_idx on claim (signature_valid); - """ - - TAG_INDEXES = '\n'.join( - f"create unique index if not exists tag_{tag_key}_idx on tag (tag, claim_hash) WHERE tag='{tag_value}';" - for tag_value, tag_key in COMMON_TAGS.items() - ) - - CREATE_TABLES_QUERY = ( - CREATE_CLAIM_TABLE + - CREATE_FULL_TEXT_SEARCH + - CREATE_SUPPORT_TABLE + - CREATE_CLAIMTRIE_TABLE + - CREATE_TAG_TABLE - ) - - def __init__( - self, main, path: str, blocking_channels: list, filtering_channels: list, trending: list): - self.main = main - self._db_path = path - self.db = None - self.logger = class_logger(__name__, self.__class__.__name__) - self.ledger = Ledger if main.coin.NET == 'mainnet' else RegTestLedger - self._fts_synced = False - self.state_manager = None - self.blocked_streams = None - self.blocked_channels = None - self.blocking_channel_hashes = { - unhexlify(channel_id)[::-1] for channel_id in blocking_channels if channel_id - } - self.filtered_streams = None - self.filtered_channels = None - self.filtering_channel_hashes = { - unhexlify(channel_id)[::-1] for channel_id in filtering_channels if channel_id - } - self.trending = trending - - def open(self): - self.db = apsw.Connection( - self._db_path, - flags=( - apsw.SQLITE_OPEN_READWRITE | - apsw.SQLITE_OPEN_CREATE | - apsw.SQLITE_OPEN_URI - ) - ) - def exec_factory(cursor, statement, bindings): - tpl = namedtuple('row', (d[0] for d in cursor.getdescription())) - cursor.setrowtrace(lambda cursor, row: tpl(*row)) - return True - self.db.setexectrace(exec_factory) - self.execute(self.PRAGMAS) - self.execute(self.CREATE_TABLES_QUERY) - register_canonical_functions(self.db) - self.state_manager = Manager() - self.blocked_streams = self.state_manager.dict() - self.blocked_channels = self.state_manager.dict() - self.filtered_streams = self.state_manager.dict() - self.filtered_channels = self.state_manager.dict() - self.update_blocked_and_filtered_claims() - for algorithm in self.trending: - algorithm.install(self.db) - - def close(self): - if self.db is not None: - self.db.close() - if self.state_manager is not None: - self.state_manager.shutdown() - - def update_blocked_and_filtered_claims(self): - self.update_claims_from_channel_hashes( - self.blocked_streams, self.blocked_channels, self.blocking_channel_hashes - ) - self.update_claims_from_channel_hashes( - self.filtered_streams, self.filtered_channels, self.filtering_channel_hashes - ) - self.filtered_streams.update(self.blocked_streams) - self.filtered_channels.update(self.blocked_channels) - - def update_claims_from_channel_hashes(self, shared_streams, shared_channels, channel_hashes): - streams, channels = {}, {} - if channel_hashes: - sql = query( - "SELECT repost.channel_hash, 
repost.reposted_claim_hash, target.claim_type " - "FROM claim as repost JOIN claim AS target ON (target.claim_hash=repost.reposted_claim_hash)", **{ - 'repost.reposted_claim_hash__is_not_null': 1, - 'repost.channel_hash__in': channel_hashes - } - ) - for blocked_claim in self.execute(*sql): - if blocked_claim.claim_type == CLAIM_TYPES['stream']: - streams[blocked_claim.reposted_claim_hash] = blocked_claim.channel_hash - elif blocked_claim.claim_type == CLAIM_TYPES['channel']: - channels[blocked_claim.reposted_claim_hash] = blocked_claim.channel_hash - shared_streams.clear() - shared_streams.update(streams) - shared_channels.clear() - shared_channels.update(channels) - - @staticmethod - def _insert_sql(table: str, data: dict) -> Tuple[str, list]: - columns, values = [], [] - for column, value in data.items(): - columns.append(column) - values.append(value) - sql = ( - f"INSERT INTO {table} ({', '.join(columns)}) " - f"VALUES ({', '.join(['?'] * len(values))})" - ) - return sql, values - - @staticmethod - def _update_sql(table: str, data: dict, where: str, - constraints: Union[list, tuple]) -> Tuple[str, list]: - columns, values = [], [] - for column, value in data.items(): - columns.append(f"{column} = ?") - values.append(value) - values.extend(constraints) - return f"UPDATE {table} SET {', '.join(columns)} WHERE {where}", values - - @staticmethod - def _delete_sql(table: str, constraints: dict) -> Tuple[str, dict]: - where, values = constraints_to_sql(constraints) - return f"DELETE FROM {table} WHERE {where}", values - - def execute(self, *args): - return self.db.cursor().execute(*args) - - def executemany(self, *args): - return self.db.cursor().executemany(*args) - - def begin(self): - self.execute('begin;') - - def commit(self): - self.execute('commit;') - - def _upsertable_claims(self, txos: List[Output], header, clear_first=False): - claim_hashes, claims, tags = set(), [], {} - for txo in txos: - tx = txo.tx_ref.tx - - try: - assert txo.claim_name - assert txo.normalized_name - except: - #self.logger.exception(f"Could not decode claim name for {tx.id}:{txo.position}.") - continue - - claim_hash = txo.claim_hash - claim_hashes.add(claim_hash) - claim_record = { - 'claim_hash': claim_hash, - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'normalized': txo.normalized_name, - 'txo_hash': txo.ref.hash, - 'tx_position': tx.position, - 'amount': txo.amount, - 'timestamp': header['timestamp'], - 'height': tx.height, - 'title': None, - 'description': None, - 'author': None, - 'duration': None, - 'claim_type': None, - 'stream_type': None, - 'media_type': None, - 'release_time': None, - 'fee_currency': None, - 'fee_amount': 0, - 'reposted_claim_hash': None - } - claims.append(claim_record) - - try: - claim = txo.claim - except: - #self.logger.exception(f"Could not parse claim protobuf for {tx.id}:{txo.position}.") - continue - - if claim.is_stream: - claim_record['claim_type'] = CLAIM_TYPES['stream'] - claim_record['media_type'] = claim.stream.source.media_type - claim_record['stream_type'] = STREAM_TYPES[guess_stream_type(claim_record['media_type'])] - claim_record['title'] = claim.stream.title - claim_record['description'] = claim.stream.description - claim_record['author'] = claim.stream.author - if claim.stream.video and claim.stream.video.duration: - claim_record['duration'] = claim.stream.video.duration - if claim.stream.audio and claim.stream.audio.duration: - claim_record['duration'] = claim.stream.audio.duration - if claim.stream.release_time: - claim_record['release_time'] 
= claim.stream.release_time - if claim.stream.has_fee: - fee = claim.stream.fee - if isinstance(fee.currency, str): - claim_record['fee_currency'] = fee.currency.lower() - if isinstance(fee.amount, Decimal): - claim_record['fee_amount'] = int(fee.amount*1000) - elif claim.is_repost: - claim_record['claim_type'] = CLAIM_TYPES['repost'] - claim_record['reposted_claim_hash'] = claim.repost.reference.claim_hash - elif claim.is_channel: - claim_record['claim_type'] = CLAIM_TYPES['channel'] - - for tag in clean_tags(claim.message.tags): - tags[(tag, claim_hash)] = (tag, claim_hash, tx.height) - - if clear_first: - self._clear_claim_metadata(claim_hashes) - - if tags: - self.executemany( - "INSERT OR IGNORE INTO tag (tag, claim_hash, height) VALUES (?, ?, ?)", tags.values() - ) - - return claims - - def insert_claims(self, txos: List[Output], header): - claims = self._upsertable_claims(txos, header) - if claims: - self.executemany(""" - INSERT OR IGNORE INTO claim ( - claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position, amount, - claim_type, media_type, stream_type, timestamp, creation_timestamp, - fee_currency, fee_amount, title, description, author, duration, height, reposted_claim_hash, - creation_height, release_time, activation_height, expiration_height, short_url) - VALUES ( - :claim_hash, :claim_id, :claim_name, :normalized, :txo_hash, :tx_position, :amount, - :claim_type, :media_type, :stream_type, :timestamp, :timestamp, - :fee_currency, :fee_amount, :title, :description, :author, :duration, :height, :reposted_claim_hash, :height, - CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE :timestamp END, - CASE WHEN :normalized NOT IN (SELECT normalized FROM claimtrie) THEN :height END, - CASE WHEN :height >= 137181 THEN :height+2102400 ELSE :height+262974 END, - :claim_name||COALESCE( - (SELECT shortest_id(claim_id, :claim_id) FROM claim WHERE normalized = :normalized), - '#'||substr(:claim_id, 1, 1) - ) - )""", claims) - - def update_claims(self, txos: List[Output], header): - claims = self._upsertable_claims(txos, header, clear_first=True) - if claims: - self.executemany(""" - UPDATE claim SET - txo_hash=:txo_hash, tx_position=:tx_position, amount=:amount, height=:height, - claim_type=:claim_type, media_type=:media_type, stream_type=:stream_type, - timestamp=:timestamp, fee_amount=:fee_amount, fee_currency=:fee_currency, - title=:title, duration=:duration, description=:description, author=:author, reposted_claim_hash=:reposted_claim_hash, - release_time=CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE release_time END - WHERE claim_hash=:claim_hash; - """, claims) - - def delete_claims(self, claim_hashes: Set[bytes]): - """ Deletes claim supports and from claimtrie in case of an abandon. 
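# The expiration_height CASE expression in insert_claims above, written out in
# Python: claims created at or after block 137181 expire 2,102,400 blocks later,
# while older claims expire after 262,974 blocks.
def expiration_height(creation_height: int) -> int:
    if creation_height >= 137181:
        return creation_height + 2102400
    return creation_height + 262974

print(expiration_height(100000))   # 362974
print(expiration_height(700000))   # 2802400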
""" - if claim_hashes: - affected_channels = self.execute(*query( - "SELECT channel_hash FROM claim", channel_hash__is_not_null=1, claim_hash__in=claim_hashes - )).fetchall() - for table in ('claim', 'support', 'claimtrie'): - self.execute(*self._delete_sql(table, {'claim_hash__in': claim_hashes})) - self._clear_claim_metadata(claim_hashes) - return {r.channel_hash for r in affected_channels} - return set() - - def delete_claims_above_height(self, height: int): - claim_hashes = [x[0] for x in self.execute( - "SELECT claim_hash FROM claim WHERE height>?", (height, ) - ).fetchall()] - while claim_hashes: - batch = set(claim_hashes[:500]) - claim_hashes = claim_hashes[500:] - self.delete_claims(batch) - - def _clear_claim_metadata(self, claim_hashes: Set[bytes]): - if claim_hashes: - for table in ('tag',): # 'language', 'location', etc - self.execute(*self._delete_sql(table, {'claim_hash__in': claim_hashes})) - - def split_inputs_into_claims_supports_and_other(self, txis): - txo_hashes = {txi.txo_ref.hash for txi in txis} - claims = self.execute(*query( - "SELECT txo_hash, claim_hash, normalized FROM claim", txo_hash__in=txo_hashes - )).fetchall() - txo_hashes -= {r.txo_hash for r in claims} - supports = {} - if txo_hashes: - supports = self.execute(*query( - "SELECT txo_hash, claim_hash FROM support", txo_hash__in=txo_hashes - )).fetchall() - txo_hashes -= {r.txo_hash for r in supports} - return claims, supports, txo_hashes - - def insert_supports(self, txos: List[Output]): - supports = [] - for txo in txos: - tx = txo.tx_ref.tx - supports.append(( - txo.ref.hash, tx.position, tx.height, - txo.claim_hash, txo.amount - )) - if supports: - self.executemany( - "INSERT OR IGNORE INTO support (" - " txo_hash, tx_position, height, claim_hash, amount" - ") " - "VALUES (?, ?, ?, ?, ?)", supports - ) - - def delete_supports(self, txo_hashes: Set[bytes]): - if txo_hashes: - self.execute(*self._delete_sql('support', {'txo_hash__in': txo_hashes})) - - def calculate_reposts(self, txos: List[Output]): - targets = set() - for txo in txos: - try: - claim = txo.claim - except: - continue - if claim.is_repost: - targets.add((claim.repost.reference.claim_hash,)) - if targets: - self.executemany( - """ - UPDATE claim SET reposted = ( - SELECT count(*) FROM claim AS repost WHERE repost.reposted_claim_hash = claim.claim_hash - ) - WHERE claim_hash = ? 
- """, targets - ) - - def validate_channel_signatures(self, height, new_claims, updated_claims, spent_claims, affected_channels, timer): - if not new_claims and not updated_claims and not spent_claims: - return - - sub_timer = timer.add_timer('segregate channels and signables') - sub_timer.start() - channels, new_channel_keys, signables = {}, {}, {} - for txo in chain(new_claims, updated_claims): - try: - claim = txo.claim - except: - continue - if claim.is_channel: - channels[txo.claim_hash] = txo - new_channel_keys[txo.claim_hash] = claim.channel.public_key_bytes - else: - signables[txo.claim_hash] = txo - sub_timer.stop() - - sub_timer = timer.add_timer('make list of channels we need to lookup') - sub_timer.start() - missing_channel_keys = set() - for txo in signables.values(): - claim = txo.claim - if claim.is_signed and claim.signing_channel_hash not in new_channel_keys: - missing_channel_keys.add(claim.signing_channel_hash) - sub_timer.stop() - - sub_timer = timer.add_timer('lookup missing channels') - sub_timer.start() - all_channel_keys = {} - if new_channel_keys or missing_channel_keys or affected_channels: - all_channel_keys = dict(self.execute(*query( - "SELECT claim_hash, public_key_bytes FROM claim", - claim_hash__in=set(new_channel_keys) | missing_channel_keys | affected_channels - ))) - sub_timer.stop() - - sub_timer = timer.add_timer('prepare for updating claims') - sub_timer.start() - changed_channel_keys = {} - for claim_hash, new_key in new_channel_keys.items(): - if claim_hash not in all_channel_keys or all_channel_keys[claim_hash] != new_key: - all_channel_keys[claim_hash] = new_key - changed_channel_keys[claim_hash] = new_key - - claim_updates = [] - - for claim_hash, txo in signables.items(): - claim = txo.claim - update = { - 'claim_hash': claim_hash, - 'channel_hash': None, - 'signature': None, - 'signature_digest': None, - 'signature_valid': None - } - if claim.is_signed: - update.update({ - 'channel_hash': claim.signing_channel_hash, - 'signature': txo.get_encoded_signature(), - 'signature_digest': txo.get_signature_digest(self.ledger), - 'signature_valid': 0 - }) - claim_updates.append(update) - sub_timer.stop() - - sub_timer = timer.add_timer('find claims affected by a change in channel key') - sub_timer.start() - if changed_channel_keys: - sql = f""" - SELECT * FROM claim WHERE - channel_hash IN ({','.join('?' 
for _ in changed_channel_keys)}) AND - signature IS NOT NULL - """ - for affected_claim in self.execute(sql, changed_channel_keys.keys()): - if affected_claim.claim_hash not in signables: - claim_updates.append({ - 'claim_hash': affected_claim.claim_hash, - 'channel_hash': affected_claim.channel_hash, - 'signature': affected_claim.signature, - 'signature_digest': affected_claim.signature_digest, - 'signature_valid': 0 - }) - sub_timer.stop() - - sub_timer = timer.add_timer('verify signatures') - sub_timer.start() - for update in claim_updates: - channel_pub_key = all_channel_keys.get(update['channel_hash']) - if channel_pub_key and update['signature']: - update['signature_valid'] = Output.is_signature_valid( - bytes(update['signature']), bytes(update['signature_digest']), channel_pub_key - ) - sub_timer.stop() - - sub_timer = timer.add_timer('update claims') - sub_timer.start() - if claim_updates: - self.executemany(f""" - UPDATE claim SET - channel_hash=:channel_hash, signature=:signature, signature_digest=:signature_digest, - signature_valid=:signature_valid, - channel_join=CASE - WHEN signature_valid=1 AND :signature_valid=1 AND channel_hash=:channel_hash THEN channel_join - WHEN :signature_valid=1 THEN {height} - END, - canonical_url=CASE - WHEN signature_valid=1 AND :signature_valid=1 AND channel_hash=:channel_hash THEN canonical_url - WHEN :signature_valid=1 THEN - (SELECT short_url FROM claim WHERE claim_hash=:channel_hash)||'/'|| - claim_name||COALESCE( - (SELECT shortest_id(other_claim.claim_id, claim.claim_id) FROM claim AS other_claim - WHERE other_claim.signature_valid = 1 AND - other_claim.channel_hash = :channel_hash AND - other_claim.normalized = claim.normalized), - '#'||substr(claim_id, 1, 1) - ) - END - WHERE claim_hash=:claim_hash; - """, claim_updates) - sub_timer.stop() - - sub_timer = timer.add_timer('update claims affected by spent channels') - sub_timer.start() - if spent_claims: - self.execute( - f""" - UPDATE claim SET - signature_valid=CASE WHEN signature IS NOT NULL THEN 0 END, - channel_join=NULL, canonical_url=NULL - WHERE channel_hash IN ({','.join('?' for _ in spent_claims)}) - """, spent_claims - ) - sub_timer.stop() - - sub_timer = timer.add_timer('update channels') - sub_timer.start() - if channels: - self.executemany( - """ - UPDATE claim SET - public_key_bytes=:public_key_bytes, - public_key_hash=:public_key_hash - WHERE claim_hash=:claim_hash""", [{ - 'claim_hash': claim_hash, - 'public_key_bytes': txo.claim.channel.public_key_bytes, - 'public_key_hash': self.ledger.address_to_hash160( - self.ledger.public_key_to_address(txo.claim.channel.public_key_bytes) - ) - } for claim_hash, txo in channels.items()] - ) - sub_timer.stop() - - sub_timer = timer.add_timer('update claims_in_channel counts') - sub_timer.start() - if all_channel_keys: - self.executemany(f""" - UPDATE claim SET - claims_in_channel=( - SELECT COUNT(*) FROM claim AS claim_in_channel - WHERE claim_in_channel.signature_valid=1 AND - claim_in_channel.channel_hash=claim.claim_hash - ) - WHERE claim_hash = ? 
- """, [(channel_hash,) for channel_hash in all_channel_keys.keys()]) - sub_timer.stop() - - sub_timer = timer.add_timer('update blocked claims list') - sub_timer.start() - if (self.blocking_channel_hashes.intersection(all_channel_keys) or - self.filtering_channel_hashes.intersection(all_channel_keys)): - self.update_blocked_and_filtered_claims() - sub_timer.stop() - - def _update_support_amount(self, claim_hashes): - if claim_hashes: - self.execute(f""" - UPDATE claim SET - support_amount = COALESCE( - (SELECT SUM(amount) FROM support WHERE support.claim_hash=claim.claim_hash), 0 - ) - WHERE claim_hash IN ({','.join('?' for _ in claim_hashes)}) - """, claim_hashes) - - def _update_effective_amount(self, height, claim_hashes=None): - self.execute( - f"UPDATE claim SET effective_amount = amount + support_amount " - f"WHERE activation_height = {height}" - ) - if claim_hashes: - self.execute( - f"UPDATE claim SET effective_amount = amount + support_amount " - f"WHERE activation_height < {height} " - f" AND claim_hash IN ({','.join('?' for _ in claim_hashes)})", - claim_hashes - ) - - def _calculate_activation_height(self, height): - last_take_over_height = f"""COALESCE( - (SELECT last_take_over_height FROM claimtrie - WHERE claimtrie.normalized=claim.normalized), - {height} - ) - """ - self.execute(f""" - UPDATE claim SET activation_height = - {height} + min(4032, cast(({height} - {last_take_over_height}) / 32 AS INT)) - WHERE activation_height IS NULL - """) - - def _perform_overtake(self, height, changed_claim_hashes, deleted_names): - deleted_names_sql = claim_hashes_sql = "" - if changed_claim_hashes: - claim_hashes_sql = f"OR claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})" - if deleted_names: - deleted_names_sql = f"OR normalized IN ({','.join('?' for _ in deleted_names)})" - overtakes = self.execute(f""" - SELECT winner.normalized, winner.claim_hash, - claimtrie.claim_hash AS current_winner, - MAX(winner.effective_amount) AS max_winner_effective_amount - FROM ( - SELECT normalized, claim_hash, effective_amount FROM claim - WHERE normalized IN ( - SELECT normalized FROM claim WHERE activation_height={height} {claim_hashes_sql} - ) {deleted_names_sql} - ORDER BY effective_amount DESC, height ASC, tx_position ASC - ) AS winner LEFT JOIN claimtrie USING (normalized) - GROUP BY winner.normalized - HAVING current_winner IS NULL OR current_winner <> winner.claim_hash - """, list(changed_claim_hashes)+deleted_names) - for overtake in overtakes: - if overtake.current_winner: - self.execute( - f"UPDATE claimtrie SET claim_hash = ?, last_take_over_height = {height} " - f"WHERE normalized = ?", - (overtake.claim_hash, overtake.normalized) - ) - else: - self.execute( - f"INSERT INTO claimtrie (claim_hash, normalized, last_take_over_height) " - f"VALUES (?, ?, {height})", - (overtake.claim_hash, overtake.normalized) - ) - self.execute( - f"UPDATE claim SET activation_height = {height} WHERE normalized = ? 
" - f"AND (activation_height IS NULL OR activation_height > {height})", - (overtake.normalized,) - ) - - def _copy(self, height): - if height > 50: - self.execute(f"DROP TABLE claimtrie{height-50}") - self.execute(f"CREATE TABLE claimtrie{height} AS SELECT * FROM claimtrie") - - def update_claimtrie(self, height, changed_claim_hashes, deleted_names, timer): - r = timer.run - - r(self._calculate_activation_height, height) - r(self._update_support_amount, changed_claim_hashes) - - r(self._update_effective_amount, height, changed_claim_hashes) - r(self._perform_overtake, height, changed_claim_hashes, list(deleted_names)) - - r(self._update_effective_amount, height) - r(self._perform_overtake, height, [], []) - - def get_expiring(self, height): - return self.execute( - f"SELECT claim_hash, normalized FROM claim WHERE expiration_height = {height}" - ) - - def advance_txs(self, height, all_txs, header, daemon_height, timer): - insert_claims = [] - update_claims = [] - update_claim_hashes = set() - delete_claim_hashes = set() - insert_supports = [] - delete_support_txo_hashes = set() - recalculate_claim_hashes = set() # added/deleted supports, added/updated claim - deleted_claim_names = set() - delete_others = set() - body_timer = timer.add_timer('body') - for position, (etx, txid) in enumerate(all_txs): - tx = timer.run( - Transaction, etx.raw, height=height, position=position - ) - # Inputs - spent_claims, spent_supports, spent_others = timer.run( - self.split_inputs_into_claims_supports_and_other, tx.inputs - ) - body_timer.start() - delete_claim_hashes.update({r.claim_hash for r in spent_claims}) - deleted_claim_names.update({r.normalized for r in spent_claims}) - delete_support_txo_hashes.update({r.txo_hash for r in spent_supports}) - recalculate_claim_hashes.update({r.claim_hash for r in spent_supports}) - delete_others.update(spent_others) - # Outputs - for output in tx.outputs: - if output.is_support: - insert_supports.append(output) - recalculate_claim_hashes.add(output.claim_hash) - elif output.script.is_claim_name: - insert_claims.append(output) - recalculate_claim_hashes.add(output.claim_hash) - elif output.script.is_update_claim: - claim_hash = output.claim_hash - update_claims.append(output) - recalculate_claim_hashes.add(claim_hash) - body_timer.stop() - - skip_update_claim_timer = timer.add_timer('skip update of abandoned claims') - skip_update_claim_timer.start() - for updated_claim in list(update_claims): - if updated_claim.ref.hash in delete_others: - update_claims.remove(updated_claim) - for updated_claim in update_claims: - claim_hash = updated_claim.claim_hash - delete_claim_hashes.discard(claim_hash) - update_claim_hashes.add(claim_hash) - skip_update_claim_timer.stop() - - skip_insert_claim_timer = timer.add_timer('skip insertion of abandoned claims') - skip_insert_claim_timer.start() - for new_claim in list(insert_claims): - if new_claim.ref.hash in delete_others: - if new_claim.claim_hash not in update_claim_hashes: - insert_claims.remove(new_claim) - skip_insert_claim_timer.stop() - - skip_insert_support_timer = timer.add_timer('skip insertion of abandoned supports') - skip_insert_support_timer.start() - for new_support in list(insert_supports): - if new_support.ref.hash in delete_others: - insert_supports.remove(new_support) - skip_insert_support_timer.stop() - - expire_timer = timer.add_timer('recording expired claims') - expire_timer.start() - for expired in self.get_expiring(height): - delete_claim_hashes.add(expired.claim_hash) - 
deleted_claim_names.add(expired.normalized) - expire_timer.stop() - - r = timer.run - r(update_full_text_search, 'before-delete', - delete_claim_hashes, self.db.cursor(), self.main.first_sync) - affected_channels = r(self.delete_claims, delete_claim_hashes) - r(self.delete_supports, delete_support_txo_hashes) - r(self.insert_claims, insert_claims, header) - r(self.calculate_reposts, insert_claims) - r(update_full_text_search, 'after-insert', - [txo.claim_hash for txo in insert_claims], self.db.cursor(), self.main.first_sync) - r(update_full_text_search, 'before-update', - [txo.claim_hash for txo in update_claims], self.db.cursor(), self.main.first_sync) - r(self.update_claims, update_claims, header) - r(update_full_text_search, 'after-update', - [txo.claim_hash for txo in update_claims], self.db.cursor(), self.main.first_sync) - r(self.validate_channel_signatures, height, insert_claims, - update_claims, delete_claim_hashes, affected_channels, forward_timer=True) - r(self.insert_supports, insert_supports) - r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) - for algorithm in self.trending: - r(algorithm.run, self.db.cursor(), height, daemon_height, recalculate_claim_hashes) - if not self._fts_synced and self.main.first_sync and height == daemon_height: - r(first_sync_finished, self.db.cursor()) - self._fts_synced = True - - -class LBRYLevelDB(LevelDB): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - path = os.path.join(self.env.db_dir, 'claims.db') - trending = [] - for algorithm_name in self.env.trending_algorithms: - if algorithm_name in TRENDING_ALGORITHMS: - trending.append(TRENDING_ALGORITHMS[algorithm_name]) - self.sql = SQLDB( - self, path, - self.env.default('BLOCKING_CHANNEL_IDS', '').split(' '), - self.env.default('FILTERING_CHANNEL_IDS', '').split(' '), - trending - ) - - def close(self): - super().close() - self.sql.close() - - async def _open_dbs(self, *args, **kwargs): - await super()._open_dbs(*args, **kwargs) - self.sql.open() diff --git a/lbry/wallet/server/env.py b/lbry/wallet/server/env.py deleted file mode 100644 index 173137257..000000000 --- a/lbry/wallet/server/env.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (c) 2016, Neil Booth -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. 
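A minimal standalone sketch (the helper name and example numbers are illustrative assumptions, not part of the deleted writer.py above) of the activation-delay rule that the _calculate_activation_height SQL encodes: a claim touching a name at `height` activates after one block of delay per 32 blocks since that name's last takeover, capped at 4032 blocks.

def activation_height(height, last_take_over_height):
    # Mirrors: height + min(4032, cast((height - last_take_over_height) / 32 AS INT))
    delay = min(4032, (height - last_take_over_height) // 32)
    return height + delay

# A name last contested 130,000 blocks ago hits the 4032-block cap;
# a takeover only 32 blocks ago yields a single block of delay.
assert activation_height(600_000, 470_000) == 604_032
assert activation_height(600_000, 599_968) == 600_001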
- - -import re -import resource -from os import environ -from collections import namedtuple -from ipaddress import ip_address - -from lbry.wallet.server.util import class_logger -from lbry.wallet.server.coin import Coin -import lbry.wallet.server.util as lib_util - - -NetIdentity = namedtuple('NetIdentity', 'host tcp_port ssl_port nick_suffix') - - -class Env: - - # Peer discovery - PD_OFF, PD_SELF, PD_ON = range(3) - - class Error(Exception): - pass - - def __init__(self, coin=None): - self.logger = class_logger(__name__, self.__class__.__name__) - self.allow_root = self.boolean('ALLOW_ROOT', False) - self.host = self.default('HOST', 'localhost') - self.rpc_host = self.default('RPC_HOST', 'localhost') - self.loop_policy = self.set_event_loop_policy() - self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK']) - self.db_dir = self.required('DB_DIRECTORY') - self.db_engine = self.default('DB_ENGINE', 'leveldb') - self.trending_algorithms = list(set(self.default('TRENDING_ALGORITHMS', 'zscore').split(' '))) - self.max_query_workers = self.integer('MAX_QUERY_WORKERS', None) - self.individual_tag_indexes = self.boolean('INDIVIDUAL_TAG_INDEXES', True) - self.track_metrics = self.boolean('TRACK_METRICS', False) - self.websocket_host = self.default('WEBSOCKET_HOST', self.host) - self.websocket_port = self.integer('WEBSOCKET_PORT', None) - self.daemon_url = self.required('DAEMON_URL') - if coin is not None: - assert issubclass(coin, Coin) - self.coin = coin - else: - coin_name = self.required('COIN').strip() - network = self.default('NET', 'mainnet').strip() - self.coin = Coin.lookup_coin_class(coin_name, network) - self.cache_MB = self.integer('CACHE_MB', 1200) - self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT) - # Server stuff - self.tcp_port = self.integer('TCP_PORT', None) - self.ssl_port = self.integer('SSL_PORT', None) - if self.ssl_port: - self.ssl_certfile = self.required('SSL_CERTFILE') - self.ssl_keyfile = self.required('SSL_KEYFILE') - self.rpc_port = self.integer('RPC_PORT', 8000) - self.prometheus_port = self.integer('PROMETHEUS_PORT', 0) - self.max_subscriptions = self.integer('MAX_SUBSCRIPTIONS', 10000) - self.banner_file = self.default('BANNER_FILE', None) - self.tor_banner_file = self.default('TOR_BANNER_FILE', self.banner_file) - self.anon_logs = self.boolean('ANON_LOGS', False) - self.log_sessions = self.integer('LOG_SESSIONS', 3600) - # Peer discovery - self.peer_discovery = self.peer_discovery_enum() - self.peer_announce = self.boolean('PEER_ANNOUNCE', True) - self.force_proxy = self.boolean('FORCE_PROXY', False) - self.tor_proxy_host = self.default('TOR_PROXY_HOST', 'localhost') - self.tor_proxy_port = self.integer('TOR_PROXY_PORT', None) - # The electrum client takes the empty string as unspecified - self.payment_address = self.default('PAYMENT_ADDRESS', '') - self.donation_address = self.default('DONATION_ADDRESS', '') - # Server limits to help prevent DoS - self.max_send = self.integer('MAX_SEND', 1000000) - self.max_receive = self.integer('MAX_RECEIVE', 1000000) - self.max_subs = self.integer('MAX_SUBS', 250000) - self.max_sessions = self.sane_max_sessions() - self.max_session_subs = self.integer('MAX_SESSION_SUBS', 50000) - self.session_timeout = self.integer('SESSION_TIMEOUT', 600) - self.drop_client = self.custom("DROP_CLIENT", None, re.compile) - self.description = self.default('DESCRIPTION', '') - self.daily_fee = self.string_amount('DAILY_FEE', '0') - - # Identities - clearnet_identity = self.clearnet_identity() - tor_identity = 
self.tor_identity(clearnet_identity) - self.identities = [identity - for identity in (clearnet_identity, tor_identity) - if identity is not None] - self.database_query_timeout = float(self.integer('QUERY_TIMEOUT_MS', 250)) / 1000.0 - - @classmethod - def default(cls, envvar, default): - return environ.get(envvar, default) - - @classmethod - def boolean(cls, envvar, default): - default = 'Yes' if default else '' - return bool(cls.default(envvar, default).strip()) - - @classmethod - def required(cls, envvar): - value = environ.get(envvar) - if value is None: - raise cls.Error(f'required envvar {envvar} not set') - return value - - @classmethod - def string_amount(cls, envvar, default): - value = environ.get(envvar, default) - amount_pattern = re.compile("[0-9]{0,10}(\.[0-9]{1,8})?") - if len(value) > 0 and not amount_pattern.fullmatch(value): - raise cls.Error(f'{value} is not a valid amount for {envvar}') - return value - - @classmethod - def integer(cls, envvar, default): - value = environ.get(envvar) - if value is None: - return default - try: - return int(value) - except Exception: - raise cls.Error(f'cannot convert envvar {envvar} value {value} to an integer') - - @classmethod - def custom(cls, envvar, default, parse): - value = environ.get(envvar) - if value is None: - return default - try: - return parse(value) - except Exception as e: - raise cls.Error(f'cannot parse envvar {envvar} value {value}') from e - - @classmethod - def obsolete(cls, envvars): - bad = [envvar for envvar in envvars if environ.get(envvar)] - if bad: - raise cls.Error(f'remove obsolete environment variables {bad}') - - def set_event_loop_policy(self): - policy_name = self.default('EVENT_LOOP_POLICY', None) - if not policy_name: - import asyncio - return asyncio.get_event_loop_policy() - elif policy_name == 'uvloop': - import uvloop - import asyncio - loop_policy = uvloop.EventLoopPolicy() - asyncio.set_event_loop_policy(loop_policy) - return loop_policy - raise self.Error(f'unknown event loop policy "{policy_name}"') - - def cs_host(self, *, for_rpc): - """Returns the 'host' argument to pass to asyncio's create_server - call. The result can be a single host name string, a list of - host name strings, or an empty string to bind to all interfaces. - - If rpc is True the host to use for the RPC server is returned. - Otherwise the host to use for SSL/TCP servers is returned. - """ - host = self.rpc_host if for_rpc else self.host - result = [part.strip() for part in host.split(',')] - if len(result) == 1: - result = result[0] - # An empty result indicates all interfaces, which we do not - permit for an RPC server. - if for_rpc and not result: - result = 'localhost' - if result == 'localhost': - # 'localhost' resolves to ::1 (ipv6) on many systems, which fails on default setup of - # docker, using 127.0.0.1 instead forces ipv4 - result = '127.0.0.1' - return result - - def sane_max_sessions(self): - """Return the maximum number of sessions to permit. Normally this - is MAX_SESSIONS.
However, to prevent open file exhaustion, adjust - downwards if running with a small open file rlimit.""" - env_value = self.integer('MAX_SESSIONS', 1000) - nofile_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0] - # We give the DB 250 files; allow ElectrumX 100 for itself - value = max(0, min(env_value, nofile_limit - 350)) - if value < env_value: - self.logger.warning(f'lowered maximum sessions from {env_value:,d} to {value:,d} ' - f'because your open file limit is {nofile_limit:,d}') - return value - - def clearnet_identity(self): - host = self.default('REPORT_HOST', None) - if host is None: - return None - try: - ip = ip_address(host) - except ValueError: - bad = (not lib_util.is_valid_hostname(host) - or host.lower() == 'localhost') - else: - bad = (ip.is_multicast or ip.is_unspecified - or (ip.is_private and self.peer_announce)) - if bad: - raise self.Error(f'"{host}" is not a valid REPORT_HOST') - tcp_port = self.integer('REPORT_TCP_PORT', self.tcp_port) or None - ssl_port = self.integer('REPORT_SSL_PORT', self.ssl_port) or None - if tcp_port == ssl_port: - raise self.Error('REPORT_TCP_PORT and REPORT_SSL_PORT ' - f'both resolve to {tcp_port}') - return NetIdentity( - host, - tcp_port, - ssl_port, - '' - ) - - def tor_identity(self, clearnet): - host = self.default('REPORT_HOST_TOR', None) - if host is None: - return None - if not host.endswith('.onion'): - raise self.Error(f'tor host "{host}" must end with ".onion"') - - def port(port_kind): - """Returns the clearnet identity port, if any and not zero, - otherwise the listening port.""" - result = 0 - if clearnet: - result = getattr(clearnet, port_kind) - return result or getattr(self, port_kind) - - tcp_port = self.integer('REPORT_TCP_PORT_TOR', - port('tcp_port')) or None - ssl_port = self.integer('REPORT_SSL_PORT_TOR', - port('ssl_port')) or None - if tcp_port == ssl_port: - raise self.Error('REPORT_TCP_PORT_TOR and REPORT_SSL_PORT_TOR ' - f'both resolve to {tcp_port}') - - return NetIdentity( - host, - tcp_port, - ssl_port, - '_tor', - ) - - def hosts_dict(self): - return {identity.host: {'tcp_port': identity.tcp_port, - 'ssl_port': identity.ssl_port} - for identity in self.identities} - - def peer_discovery_enum(self): - pd = self.default('PEER_DISCOVERY', 'on').strip().lower() - if pd in ('off', ''): - return self.PD_OFF - elif pd == 'self': - return self.PD_SELF - else: - return self.PD_ON diff --git a/lbry/wallet/server/hash.py b/lbry/wallet/server/hash.py deleted file mode 100644 index a1a5d8068..000000000 --- a/lbry/wallet/server/hash.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -"""Cryptographic hash functions and related classes.""" - - -import hashlib -import hmac - -from lbry.wallet.server.util import bytes_to_int, int_to_bytes, hex_to_bytes - -_sha256 = hashlib.sha256 -_sha512 = hashlib.sha512 -_new_hash = hashlib.new -_new_hmac = hmac.new -HASHX_LEN = 11 - - -def sha256(x): - """Simple wrapper of hashlib sha256.""" - return _sha256(x).digest() - - -def ripemd160(x): - """Simple wrapper of hashlib ripemd160.""" - h = _new_hash('ripemd160') - h.update(x) - return h.digest() - - -def double_sha256(x): - """SHA-256 of SHA-256, as used extensively in bitcoin.""" - return sha256(sha256(x)) - - -def hmac_sha512(key, msg): - """Use SHA-512 to provide an HMAC.""" - return _new_hmac(key, msg, _sha512).digest() - - -def hash160(x): - """RIPEMD-160 of SHA-256. - - Used to make bitcoin addresses from pubkeys.""" - return ripemd160(sha256(x)) - - -def hash_to_hex_str(x): - """Convert a big-endian binary hash to displayed hex string. - - Display form of a binary hash is reversed and converted to hex. - """ - return bytes(reversed(x)).hex() - - -def hex_str_to_hash(x): - """Convert a displayed hex string to a binary hash.""" - return bytes(reversed(hex_to_bytes(x))) - - -class Base58Error(Exception): - """Exception used for Base58 errors.""" - - -class Base58: - """Class providing base 58 functionality.""" - - chars = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' - assert len(chars) == 58 - cmap = {c: n for n, c in enumerate(chars)} - - @staticmethod - def char_value(c): - val = Base58.cmap.get(c) - if val is None: - raise Base58Error(f'invalid base 58 character "{c}"') - return val - - @staticmethod - def decode(txt): - """Decodes txt into a big-endian bytearray.""" - if not isinstance(txt, str): - raise TypeError('a string is required') - - if not txt: - raise Base58Error('string cannot be empty') - - value = 0 - for c in txt: - value = value * 58 + Base58.char_value(c) - - result = int_to_bytes(value) - - # Prepend leading zero bytes if necessary - count = 0 - for c in txt: - if c != '1': - break - count += 1 - if count: - result = bytes(count) + result - - return result - - @staticmethod - def encode(be_bytes): - """Converts a big-endian bytearray into a base58 string.""" - value = bytes_to_int(be_bytes) - - txt = '' - while value: - value, mod = divmod(value, 58) - txt += Base58.chars[mod] - - for byte in be_bytes: - if byte != 0: - break - txt += '1' - - return txt[::-1] - - @staticmethod - def decode_check(txt, *, hash_fn=double_sha256): - """Decodes a Base58Check-encoded string to a payload.
The version - prefixes it.""" - be_bytes = Base58.decode(txt) - result, check = be_bytes[:-4], be_bytes[-4:] - if check != hash_fn(result)[:4]: - raise Base58Error(f'invalid base 58 checksum for {txt}') - return result - - @staticmethod - def encode_check(payload, *, hash_fn=double_sha256): - """Encodes a payload bytearray (which includes the version byte(s)) - into a Base58Check string.""" - be_bytes = payload + hash_fn(payload)[:4] - return Base58.encode(be_bytes) diff --git a/lbry/wallet/server/history.py b/lbry/wallet/server/history.py deleted file mode 100644 index f810a1045..000000000 --- a/lbry/wallet/server/history.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright (c) 2016-2018, Neil Booth -# Copyright (c) 2017, the ElectrumX authors -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - -"""History by script hash (address).""" - -import array -import ast -import bisect -import time -from collections import defaultdict -from functools import partial - -from lbry.wallet.server import util -from lbry.wallet.server.util import pack_be_uint16, unpack_be_uint16_from -from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN - - -class History: - - DB_VERSIONS = [0] - - def __init__(self): - self.logger = util.class_logger(__name__, self.__class__.__name__) - # For history compaction - self.max_hist_row_entries = 12500 - self.unflushed = defaultdict(partial(array.array, 'I')) - self.unflushed_count = 0 - self.db = None - - def open_db(self, db_class, for_sync, utxo_flush_count, compacting): - self.db = db_class('hist', for_sync) - self.read_state() - self.clear_excess(utxo_flush_count) - # An incomplete compaction needs to be cancelled otherwise - # restarting it will corrupt the history - if not compacting: - self._cancel_compaction() - return self.flush_count - - def close_db(self): - if self.db: - self.db.close() - self.db = None - - def read_state(self): - state = self.db.get(b'state\0\0') - if state: - state = ast.literal_eval(state.decode()) - if not isinstance(state, dict): - raise RuntimeError('failed reading state from history DB') - self.flush_count = state['flush_count'] - self.comp_flush_count = state.get('comp_flush_count', -1) - self.comp_cursor = state.get('comp_cursor', -1) - self.db_version = state.get('db_version', 0) - else: - self.flush_count = 0 - self.comp_flush_count = -1 - self.comp_cursor = -1 - self.db_version = max(self.DB_VERSIONS) - - self.logger.info(f'history DB version: {self.db_version}') - if self.db_version not in self.DB_VERSIONS: - msg = f'this software only handles DB versions {self.DB_VERSIONS}' - self.logger.error(msg) - raise RuntimeError(msg) - self.logger.info(f'flush count: {self.flush_count:,d}') - - def clear_excess(self, utxo_flush_count): - # < might happen at end of compaction as both DBs cannot be - # updated atomically - if self.flush_count <= utxo_flush_count: - return - - self.logger.info('DB shut down uncleanly. 
Scanning for ' - 'excess history flushes...') - - keys = [] - for key, hist in self.db.iterator(prefix=b''): - flush_id, = unpack_be_uint16_from(key[-2:]) - if flush_id > utxo_flush_count: - keys.append(key) - - self.logger.info(f'deleting {len(keys):,d} history entries') - - self.flush_count = utxo_flush_count - with self.db.write_batch() as batch: - for key in keys: - batch.delete(key) - self.write_state(batch) - - self.logger.info('deleted excess history entries') - - def write_state(self, batch): - """Write state to the history DB.""" - state = { - 'flush_count': self.flush_count, - 'comp_flush_count': self.comp_flush_count, - 'comp_cursor': self.comp_cursor, - 'db_version': self.db_version, - } - # History entries are not prefixed; the suffix \0\0 ensures we - # look similar to other entries and aren't interfered with - batch.put(b'state\0\0', repr(state).encode()) - - def add_unflushed(self, hashXs_by_tx, first_tx_num): - unflushed = self.unflushed - count = 0 - for tx_num, hashXs in enumerate(hashXs_by_tx, start=first_tx_num): - hashXs = set(hashXs) - for hashX in hashXs: - unflushed[hashX].append(tx_num) - count += len(hashXs) - self.unflushed_count += count - - def unflushed_memsize(self): - return len(self.unflushed) * 180 + self.unflushed_count * 4 - - def assert_flushed(self): - assert not self.unflushed - - def flush(self): - start_time = time.time() - self.flush_count += 1 - flush_id = pack_be_uint16(self.flush_count) - unflushed = self.unflushed - - with self.db.write_batch() as batch: - for hashX in sorted(unflushed): - key = hashX + flush_id - batch.put(key, unflushed[hashX].tobytes()) - self.write_state(batch) - - count = len(unflushed) - unflushed.clear() - self.unflushed_count = 0 - - if self.db.for_sync: - elapsed = time.time() - start_time - self.logger.info(f'flushed history in {elapsed:.1f}s ' - f'for {count:,d} addrs') - - def backup(self, hashXs, tx_count): - # Not certain this is needed, but it doesn't hurt - self.flush_count += 1 - nremoves = 0 - bisect_left = bisect.bisect_left - - with self.db.write_batch() as batch: - for hashX in sorted(hashXs): - deletes = [] - puts = {} - for key, hist in self.db.iterator(prefix=hashX, reverse=True): - a = array.array('I') - a.frombytes(hist) - # Remove all history entries >= tx_count - idx = bisect_left(a, tx_count) - nremoves += len(a) - idx - if idx > 0: - puts[key] = a[:idx].tobytes() - break - deletes.append(key) - - for key in deletes: - batch.delete(key) - for key, value in puts.items(): - batch.put(key, value) - self.write_state(batch) - - self.logger.info(f'backing up removed {nremoves:,d} history entries') - - def get_txnums(self, hashX, limit=1000): - """Generator that returns an unpruned, sorted list of tx_nums in the - history of a hashX. Includes both spending and receiving - transactions. By default yields at most 1000 entries. Set - limit to None to get them all. """ - limit = util.resolve_limit(limit) - for key, hist in self.db.iterator(prefix=hashX): - a = array.array('I') - a.frombytes(hist) - for tx_num in a: - if limit == 0: - return - yield tx_num - limit -= 1 - - # - # History compaction - # - - # comp_cursor is a cursor into compaction progress. - # -1: no compaction in progress - # 0-65535: Compaction in progress; all prefixes < comp_cursor have - # been compacted, and later ones have not. - # 65536: compaction complete in-memory but not flushed - # - # comp_flush_count applies during compaction, and is a flush count - # for history with prefix < comp_cursor. 
flush_count applies - # to still uncompacted history. It is -1 when no compaction is - # taking place. Key suffixes up to and including comp_flush_count - # are used, so a parallel history flush must first increment this - # - # When compaction is complete and the final flush takes place, - # flush_count is reset to comp_flush_count, and comp_flush_count to -1 - - def _flush_compaction(self, cursor, write_items, keys_to_delete): - """Flush a single compaction pass as a batch.""" - # Update compaction state - if cursor == 65536: - self.flush_count = self.comp_flush_count - self.comp_cursor = -1 - self.comp_flush_count = -1 - else: - self.comp_cursor = cursor - - # History DB. Flush compacted history and updated state - with self.db.write_batch() as batch: - # Important: delete first! The keyspace may overlap. - for key in keys_to_delete: - batch.delete(key) - for key, value in write_items: - batch.put(key, value) - self.write_state(batch) - - def _compact_hashX(self, hashX, hist_map, hist_list, - write_items, keys_to_delete): - """Compress history for a hashX. hist_list is an ordered list of - the histories to be compressed.""" - # History entries (tx numbers) are 4 bytes each. Distribute - # over rows of up to 50KB in size. A fixed row size means - # future compactions will not need to update the first N - 1 - # rows. - max_row_size = self.max_hist_row_entries * 4 - full_hist = b''.join(hist_list) - nrows = (len(full_hist) + max_row_size - 1) // max_row_size - if nrows > 4: - self.logger.info('hashX {} is large: {:,d} entries across ' - '{:,d} rows' - .format(hash_to_hex_str(hashX), - len(full_hist) // 4, nrows)) - - # Find what history needs to be written, and what keys need to - # be deleted. Start by assuming all keys are to be deleted, - # and then remove those that are the same on-disk as when - # compacted. - write_size = 0 - keys_to_delete.update(hist_map) - for n, chunk in enumerate(util.chunks(full_hist, max_row_size)): - key = hashX + pack_be_uint16(n) - if hist_map.get(key) == chunk: - keys_to_delete.remove(key) - else: - write_items.append((key, chunk)) - write_size += len(chunk) - - assert n + 1 == nrows - self.comp_flush_count = max(self.comp_flush_count, n) - - return write_size - - def _compact_prefix(self, prefix, write_items, keys_to_delete): - """Compact all history entries for hashXs beginning with the - given prefix. Update keys_to_delete and write.""" - prior_hashX = None - hist_map = {} - hist_list = [] - - key_len = HASHX_LEN + 2 - write_size = 0 - for key, hist in self.db.iterator(prefix=prefix): - # Ignore non-history entries - if len(key) != key_len: - continue - hashX = key[:-2] - if hashX != prior_hashX and prior_hashX: - write_size += self._compact_hashX(prior_hashX, hist_map, - hist_list, write_items, - keys_to_delete) - hist_map.clear() - hist_list.clear() - prior_hashX = hashX - hist_map[key] = hist - hist_list.append(hist) - - if prior_hashX: - write_size += self._compact_hashX(prior_hashX, hist_map, hist_list, - write_items, keys_to_delete) - return write_size - - def _compact_history(self, limit): - """Inner loop of history compaction. Loops until limit bytes have - been processed. 
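A minimal sketch (standalone helper with assumed names, not part of the deleted history.py) of the fixed-size row layout that _compact_hashX above produces: tx numbers are 4 bytes each and a row holds at most max_hist_row_entries (12,500) of them, i.e. 50KB, so earlier rows never need rewriting in later compactions.

import array
from struct import pack

def compacted_rows(hashX: bytes, tx_nums, max_row_entries=12500):
    # Serialize the full history, then cut it into fixed-size rows keyed by
    # hashX + a big-endian 2-byte row index, mirroring the compaction above.
    full_hist = array.array('I', tx_nums).tobytes()
    max_row_size = max_row_entries * 4
    return {
        hashX + pack('>H', n // max_row_size): full_hist[n:n + max_row_size]
        for n in range(0, len(full_hist), max_row_size)
    }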
- """ - keys_to_delete = set() - write_items = [] # A list of (key, value) pairs - write_size = 0 - - # Loop over 2-byte prefixes - cursor = self.comp_cursor - while write_size < limit and cursor < 65536: - prefix = pack_be_uint16(cursor) - write_size += self._compact_prefix(prefix, write_items, - keys_to_delete) - cursor += 1 - - max_rows = self.comp_flush_count + 1 - self._flush_compaction(cursor, write_items, keys_to_delete) - - self.logger.info('history compaction: wrote {:,d} rows ({:.1f} MB), ' - 'removed {:,d} rows, largest: {:,d}, {:.1f}% complete' - .format(len(write_items), write_size / 1000000, - len(keys_to_delete), max_rows, - 100 * cursor / 65536)) - return write_size - - def _cancel_compaction(self): - if self.comp_cursor != -1: - self.logger.warning('cancelling in-progress history compaction') - self.comp_flush_count = -1 - self.comp_cursor = -1 diff --git a/lbry/wallet/server/leveldb.py b/lbry/wallet/server/leveldb.py deleted file mode 100644 index cdcbbe50a..000000000 --- a/lbry/wallet/server/leveldb.py +++ /dev/null @@ -1,670 +0,0 @@ -# Copyright (c) 2016, Neil Booth -# Copyright (c) 2017, the ElectrumX authors -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - -"""Interface to the blockchain database.""" - - -import asyncio -import array -import ast -import os -import time -from asyncio import sleep -from bisect import bisect_right -from collections import namedtuple -from functools import partial -from glob import glob -from struct import pack, unpack -from concurrent.futures.thread import ThreadPoolExecutor -import attr - -from lbry.wallet.server import util -from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN -from lbry.wallet.server.merkle import Merkle, MerkleCache -from lbry.wallet.server.util import formatted_time -from lbry.wallet.server.storage import db_class -from lbry.wallet.server.history import History - - -UTXO = namedtuple("UTXO", "tx_num tx_pos tx_hash height value") - - -@attr.s(slots=True) -class FlushData: - height = attr.ib() - tx_count = attr.ib() - headers = attr.ib() - block_tx_hashes = attr.ib() - # The following are flushed to the UTXO DB if undo_infos is not None - undo_infos = attr.ib() - adds = attr.ib() - deletes = attr.ib() - tip = attr.ib() - - -class LevelDB: - """Simple wrapper of the backend database for querying. - - Performs no DB update, though the DB will be cleaned on opening if - it was shutdown uncleanly. 
- """ - - DB_VERSIONS = [6] - - class DBError(Exception): - """Raised on general DB errors generally indicating corruption.""" - - def __init__(self, env): - self.logger = util.class_logger(__name__, self.__class__.__name__) - self.env = env - self.coin = env.coin - self.executor = ThreadPoolExecutor(max(1, os.cpu_count() - 1)) - - # Setup block header size handlers - if self.coin.STATIC_BLOCK_HEADERS: - self.header_offset = self.coin.static_header_offset - self.header_len = self.coin.static_header_len - else: - self.header_offset = self.dynamic_header_offset - self.header_len = self.dynamic_header_len - - self.logger.info(f'switching current directory to {env.db_dir}') - - self.db_class = db_class(env.db_dir, self.env.db_engine) - self.history = History() - self.utxo_db = None - self.tx_counts = None - self.last_flush = time.time() - - self.logger.info(f'using {self.env.db_engine} for DB backend') - - # Header merkle cache - self.merkle = Merkle() - self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes) - - path = partial(os.path.join, self.env.db_dir) - self.headers_file = util.LogicalFile(path('meta/headers'), 2, 16000000) - self.tx_counts_file = util.LogicalFile(path('meta/txcounts'), 2, 2000000) - self.hashes_file = util.LogicalFile(path('meta/hashes'), 4, 16000000) - if not self.coin.STATIC_BLOCK_HEADERS: - self.headers_offsets_file = util.LogicalFile( - path('meta/headers_offsets'), 2, 16000000) - - async def _read_tx_counts(self): - if self.tx_counts is not None: - return - # tx_counts[N] has the cumulative number of txs at the end of - # height N. So tx_counts[0] is 1 - the genesis coinbase - size = (self.db_height + 1) * 4 - tx_counts = self.tx_counts_file.read(0, size) - assert len(tx_counts) == size - self.tx_counts = array.array('I', tx_counts) - if self.tx_counts: - assert self.db_tx_count == self.tx_counts[-1] - else: - assert self.db_tx_count == 0 - - async def _open_dbs(self, for_sync, compacting): - assert self.utxo_db is None - - # First UTXO DB - self.utxo_db = self.db_class('utxo', for_sync) - if self.utxo_db.is_new: - self.logger.info('created new database') - self.logger.info('creating metadata directory') - os.mkdir(os.path.join(self.env.db_dir, 'meta')) - coin_path = os.path.join(self.env.db_dir, 'meta', 'COIN') - with util.open_file(coin_path, create=True) as f: - f.write(f'ElectrumX databases and metadata for ' - f'{self.coin.NAME} {self.coin.NET}'.encode()) - if not self.coin.STATIC_BLOCK_HEADERS: - self.headers_offsets_file.write(0, bytes(8)) - else: - self.logger.info(f'opened UTXO DB (for sync: {for_sync})') - self.read_utxo_state() - - # Then history DB - self.utxo_flush_count = self.history.open_db(self.db_class, for_sync, - self.utxo_flush_count, - compacting) - self.clear_excess_undo_info() - - # Read TX counts (requires meta directory) - await self._read_tx_counts() - - def close(self): - self.utxo_db.close() - self.history.close_db() - self.executor.shutdown(wait=True) - - async def open_for_compacting(self): - await self._open_dbs(True, True) - - async def open_for_sync(self): - """Open the databases to sync to the daemon. - - When syncing we want to reserve a lot of open files for the - synchronization. When serving clients we want the open files for - serving network connections. - """ - await self._open_dbs(True, False) - - async def open_for_serving(self): - """Open the databases for serving. If they are already open they are - closed first. 
- """ - if self.utxo_db: - self.logger.info('closing DBs to re-open for serving') - self.utxo_db.close() - self.history.close_db() - self.utxo_db = None - await self._open_dbs(False, False) - - # Header merkle cache - - async def populate_header_merkle_cache(self): - self.logger.info('populating header merkle cache...') - length = max(1, self.db_height - self.env.reorg_limit) - start = time.time() - await self.header_mc.initialize(length) - elapsed = time.time() - start - self.logger.info(f'header merkle cache populated in {elapsed:.1f}s') - - async def header_branch_and_root(self, length, height): - return await self.header_mc.branch_and_root(length, height) - - # Flushing - def assert_flushed(self, flush_data): - """Asserts state is fully flushed.""" - assert flush_data.tx_count == self.fs_tx_count == self.db_tx_count - assert flush_data.height == self.fs_height == self.db_height - assert flush_data.tip == self.db_tip - assert not flush_data.headers - assert not flush_data.block_tx_hashes - assert not flush_data.adds - assert not flush_data.deletes - assert not flush_data.undo_infos - self.history.assert_flushed() - - def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining): - """Flush out cached state. History is always flushed; UTXOs are - flushed if flush_utxos.""" - if flush_data.height == self.db_height: - self.assert_flushed(flush_data) - return - - start_time = time.time() - prior_flush = self.last_flush - tx_delta = flush_data.tx_count - self.last_flush_tx_count - - # Flush to file system - self.flush_fs(flush_data) - - # Then history - self.flush_history() - - # Flush state last as it reads the wall time. - with self.utxo_db.write_batch() as batch: - if flush_utxos: - self.flush_utxo_db(batch, flush_data) - self.flush_state(batch) - - # Update and put the wall time again - otherwise we drop the - # time it took to commit the batch - self.flush_state(self.utxo_db) - - elapsed = self.last_flush - start_time - self.logger.info(f'flush #{self.history.flush_count:,d} took ' - f'{elapsed:.1f}s. Height {flush_data.height:,d} ' - f'txs: {flush_data.tx_count:,d} ({tx_delta:+,d})') - - # Catch-up stats - if self.utxo_db.for_sync: - flush_interval = self.last_flush - prior_flush - tx_per_sec_gen = int(flush_data.tx_count / self.wall_time) - tx_per_sec_last = 1 + int(tx_delta / flush_interval) - eta = estimate_txs_remaining() / tx_per_sec_last - self.logger.info(f'tx/sec since genesis: {tx_per_sec_gen:,d}, ' - f'since last flush: {tx_per_sec_last:,d}') - self.logger.info(f'sync time: {formatted_time(self.wall_time)} ' - f'ETA: {formatted_time(eta)}') - - def flush_fs(self, flush_data): - """Write headers, tx counts and block tx hashes to the filesystem. - - The first height to write is self.fs_height + 1. The FS - metadata is all append-only, so in a crash we just pick up - again from the height stored in the DB. 
- """ - prior_tx_count = (self.tx_counts[self.fs_height] - if self.fs_height >= 0 else 0) - assert len(flush_data.block_tx_hashes) == len(flush_data.headers) - assert flush_data.height == self.fs_height + len(flush_data.headers) - assert flush_data.tx_count == (self.tx_counts[-1] if self.tx_counts - else 0) - assert len(self.tx_counts) == flush_data.height + 1 - hashes = b''.join(flush_data.block_tx_hashes) - flush_data.block_tx_hashes.clear() - assert len(hashes) % 32 == 0 - assert len(hashes) // 32 == flush_data.tx_count - prior_tx_count - - # Write the headers, tx counts, and tx hashes - start_time = time.time() - height_start = self.fs_height + 1 - offset = self.header_offset(height_start) - self.headers_file.write(offset, b''.join(flush_data.headers)) - self.fs_update_header_offsets(offset, height_start, flush_data.headers) - flush_data.headers.clear() - - offset = height_start * self.tx_counts.itemsize - self.tx_counts_file.write(offset, - self.tx_counts[height_start:].tobytes()) - offset = prior_tx_count * 32 - self.hashes_file.write(offset, hashes) - - self.fs_height = flush_data.height - self.fs_tx_count = flush_data.tx_count - - if self.utxo_db.for_sync: - elapsed = time.time() - start_time - self.logger.info(f'flushed filesystem data in {elapsed:.2f}s') - - def flush_history(self): - self.history.flush() - - def flush_utxo_db(self, batch, flush_data): - """Flush the cached DB writes and UTXO set to the batch.""" - # Care is needed because the writes generated by flushing the - # UTXO state may have keys in common with our write cache or - # may be in the DB already. - start_time = time.time() - add_count = len(flush_data.adds) - spend_count = len(flush_data.deletes) // 2 - - # Spends - batch_delete = batch.delete - for key in sorted(flush_data.deletes): - batch_delete(key) - flush_data.deletes.clear() - - # New UTXOs - batch_put = batch.put - for key, value in flush_data.adds.items(): - # suffix = tx_idx + tx_num - hashX = value[:-12] - suffix = key[-2:] + value[-12:-8] - batch_put(b'h' + key[:4] + suffix, hashX) - batch_put(b'u' + hashX + suffix, value[-8:]) - flush_data.adds.clear() - - # New undo information - self.flush_undo_infos(batch_put, flush_data.undo_infos) - flush_data.undo_infos.clear() - - if self.utxo_db.for_sync: - block_count = flush_data.height - self.db_height - tx_count = flush_data.tx_count - self.db_tx_count - elapsed = time.time() - start_time - self.logger.info(f'flushed {block_count:,d} blocks with ' - f'{tx_count:,d} txs, {add_count:,d} UTXO adds, ' - f'{spend_count:,d} spends in ' - f'{elapsed:.1f}s, committing...') - - self.utxo_flush_count = self.history.flush_count - self.db_height = flush_data.height - self.db_tx_count = flush_data.tx_count - self.db_tip = flush_data.tip - - def flush_state(self, batch): - """Flush chain state to the batch.""" - now = time.time() - self.wall_time += now - self.last_flush - self.last_flush = now - self.last_flush_tx_count = self.fs_tx_count - self.write_utxo_state(batch) - - def flush_backup(self, flush_data, touched): - """Like flush_dbs() but when backing up. 
All UTXOs are flushed.""" - assert not flush_data.headers - assert not flush_data.block_tx_hashes - assert flush_data.height < self.db_height - self.history.assert_flushed() - - start_time = time.time() - tx_delta = flush_data.tx_count - self.last_flush_tx_count - - self.backup_fs(flush_data.height, flush_data.tx_count) - self.history.backup(touched, flush_data.tx_count) - with self.utxo_db.write_batch() as batch: - self.flush_utxo_db(batch, flush_data) - # Flush state last as it reads the wall time. - self.flush_state(batch) - - elapsed = self.last_flush - start_time - self.logger.info(f'backup flush #{self.history.flush_count:,d} took ' - f'{elapsed:.1f}s. Height {flush_data.height:,d} ' - f'txs: {flush_data.tx_count:,d} ({tx_delta:+,d})') - - def fs_update_header_offsets(self, offset_start, height_start, headers): - if self.coin.STATIC_BLOCK_HEADERS: - return - offset = offset_start - offsets = [] - for h in headers: - offset += len(h) - offsets.append(pack("= 0, count >= 0. Reads as many headers as - are available starting at start_height up to count. This - would be zero if start_height is beyond self.db_height, for - example. - - Returns a (binary, n) pair where binary is the concatenated - binary headers, and n is the count of headers returned. - """ - if start_height < 0 or count < 0: - raise self.DBError(f'{count:,d} headers starting at ' - f'{start_height:,d} not on disk') - - def read_headers(): - # Read some from disk - disk_count = max(0, min(count, self.db_height + 1 - start_height)) - if disk_count: - offset = self.header_offset(start_height) - size = self.header_offset(start_height + disk_count) - offset - return self.headers_file.read(offset, size), disk_count - return b'', 0 - - return await asyncio.get_event_loop().run_in_executor(self.executor, read_headers) - - def fs_tx_hash(self, tx_num): - """Return a pair (tx_hash, tx_height) for the given tx number. - - If the tx_height is not on disk, returns (None, tx_height).""" - tx_height = bisect_right(self.tx_counts, tx_num) - if tx_height > self.db_height: - tx_hash = None - else: - tx_hash = self.hashes_file.read(tx_num * 32, 32) - return tx_hash, tx_height - - async def fs_block_hashes(self, height, count): - headers_concat, headers_count = await self.read_headers(height, count) - if headers_count != count: - raise self.DBError(f'only got {headers_count:,d} headers starting at {height:,d}, not {count:,d}') - offset = 0 - headers = [] - for n in range(count): - hlen = self.header_len(height + n) - headers.append(headers_concat[offset:offset + hlen]) - offset += hlen - - return [self.coin.header_hash(header) for header in headers] - - async def limited_history(self, hashX, *, limit=1000): - """Return an unpruned, sorted list of (tx_hash, height) tuples of - confirmed transactions that touched the address, earliest in - the blockchain first. Includes both spending and receiving - transactions. By default returns at most 1000 entries. Set - limit to None to get them all.
- """ - def read_history(): - tx_nums = list(self.history.get_txnums(hashX, limit)) - fs_tx_hash = self.fs_tx_hash - return [fs_tx_hash(tx_num) for tx_num in tx_nums] - - while True: - history = await asyncio.get_event_loop().run_in_executor(self.executor, read_history) - if all(hash is not None for hash, height in history): - return history - self.logger.warning(f'limited_history: tx hash ' - f'not found (reorg?), retrying...') - await sleep(0.25) - - # -- Undo information - - def min_undo_height(self, max_height): - """Returns a height from which we should store undo info.""" - return max_height - self.env.reorg_limit + 1 - - def undo_key(self, height): - """DB key for undo information at the given height.""" - return b'U' + pack('>I', height) - - def read_undo_info(self, height): - """Read undo information from a file for the current height.""" - return self.utxo_db.get(self.undo_key(height)) - - def flush_undo_infos(self, batch_put, undo_infos): - """undo_infos is a list of (undo_info, height) pairs.""" - for undo_info, height in undo_infos: - batch_put(self.undo_key(height), b''.join(undo_info)) - - def raw_block_prefix(self): - return 'meta/block' - - def raw_block_path(self, height): - return os.path.join(self.env.db_dir, f'{self.raw_block_prefix()}{height:d}') - - def read_raw_block(self, height): - """Returns a raw block read from disk. Raises FileNotFoundError - if the block isn't on-disk.""" - with util.open_file(self.raw_block_path(height)) as f: - return f.read(-1) - - def write_raw_block(self, block, height): - """Write a raw block to disk.""" - with util.open_truncate(self.raw_block_path(height)) as f: - f.write(block) - # Delete old blocks to prevent them accumulating - try: - del_height = self.min_undo_height(height) - 1 - os.remove(self.raw_block_path(del_height)) - except FileNotFoundError: - pass - - def clear_excess_undo_info(self): - """Clear excess undo info. 
Only most recent N are kept.""" - prefix = b'U' - min_height = self.min_undo_height(self.db_height) - keys = [] - for key, hist in self.utxo_db.iterator(prefix=prefix): - height, = unpack('>I', key[-4:]) - if height >= min_height: - break - keys.append(key) - - if keys: - with self.utxo_db.write_batch() as batch: - for key in keys: - batch.delete(key) - self.logger.info(f'deleted {len(keys):,d} stale undo entries') - - # delete old block files - prefix = self.raw_block_prefix() - paths = [path for path in glob(f'{prefix}[0-9]*') - if len(path) > len(prefix) - and int(path[len(prefix):]) < min_height] - if paths: - for path in paths: - try: - os.remove(path) - except FileNotFoundError: - pass - self.logger.info(f'deleted {len(paths):,d} stale block files') - - # -- UTXO database - - def read_utxo_state(self): - state = self.utxo_db.get(b'state') - if not state: - self.db_height = -1 - self.db_tx_count = 0 - self.db_tip = b'\0' * 32 - self.db_version = max(self.DB_VERSIONS) - self.utxo_flush_count = 0 - self.wall_time = 0 - self.first_sync = True - else: - state = ast.literal_eval(state.decode()) - if not isinstance(state, dict): - raise self.DBError('failed reading state from DB') - self.db_version = state['db_version'] - if self.db_version not in self.DB_VERSIONS: - raise self.DBError(f'your UTXO DB version is {self.db_version} but this ' - f'software only handles versions {self.DB_VERSIONS}') - # backwards compat - genesis_hash = state['genesis'] - if isinstance(genesis_hash, bytes): - genesis_hash = genesis_hash.decode() - if genesis_hash != self.coin.GENESIS_HASH: - raise self.DBError(f'DB genesis hash {genesis_hash} does not ' - f'match coin {self.coin.GENESIS_HASH}') - self.db_height = state['height'] - self.db_tx_count = state['tx_count'] - self.db_tip = state['tip'] - self.utxo_flush_count = state['utxo_flush_count'] - self.wall_time = state['wall_time'] - self.first_sync = state['first_sync'] - - # These are our state as we move ahead of DB state - self.fs_height = self.db_height - self.fs_tx_count = self.db_tx_count - self.last_flush_tx_count = self.fs_tx_count - - # Log some stats - self.logger.info(f'DB version: {self.db_version:d}') - self.logger.info(f'coin: {self.coin.NAME}') - self.logger.info(f'network: {self.coin.NET}') - self.logger.info(f'height: {self.db_height:,d}') - self.logger.info(f'tip: {hash_to_hex_str(self.db_tip)}') - self.logger.info(f'tx count: {self.db_tx_count:,d}') - if self.utxo_db.for_sync: - self.logger.info(f'flushing DB cache at {self.env.cache_MB:,d} MB') - if self.first_sync: - self.logger.info(f'sync time so far: {util.formatted_time(self.wall_time)}') - - def write_utxo_state(self, batch): - """Write (UTXO) state to the batch.""" - state = { - 'genesis': self.coin.GENESIS_HASH, - 'height': self.db_height, - 'tx_count': self.db_tx_count, - 'tip': self.db_tip, - 'utxo_flush_count': self.utxo_flush_count, - 'wall_time': self.wall_time, - 'first_sync': self.first_sync, - 'db_version': self.db_version, - } - batch.put(b'state', repr(state).encode()) - - def set_flush_count(self, count): - self.utxo_flush_count = count - with self.utxo_db.write_batch() as batch: - self.write_utxo_state(batch) - - async def all_utxos(self, hashX): - """Return all UTXOs for an address sorted in no particular order.""" - def read_utxos(): - utxos = [] - utxos_append = utxos.append - s_unpack = unpack - # Key: b'u' + address_hashX + tx_idx + tx_num - # Value: the UTXO value as a 64-bit unsigned integer - prefix = b'u' + hashX - for db_key, db_value in 
self.utxo_db.iterator(prefix=prefix): - tx_pos, tx_num = s_unpack(' MemPoolTx - hashXs: hashX -> set of all hashes of txs touching the hashX - """ - - def __init__(self, coin, api, refresh_secs=5.0, log_status_secs=120.0): - assert isinstance(api, MemPoolAPI) - self.coin = coin - self.api = api - self.logger = class_logger(__name__, self.__class__.__name__) - self.txs = {} - self.hashXs = defaultdict(set) # None can be a key - self.cached_compact_histogram = [] - self.refresh_secs = refresh_secs - self.log_status_secs = log_status_secs - # Prevents mempool refreshes during fee histogram calculation - self.lock = asyncio.Lock() - self.wakeup = asyncio.Event() - self.executor = ThreadPoolExecutor(max(os.cpu_count() - 1, 1)) - - async def _logging(self, synchronized_event): - """Print regular logs of mempool stats.""" - self.logger.info('beginning processing of daemon mempool. ' - 'This can take some time...') - start = time.perf_counter() - await synchronized_event.wait() - elapsed = time.perf_counter() - start - self.logger.info(f'synced in {elapsed:.2f}s') - while True: - self.logger.info(f'{len(self.txs):,d} txs ' - f'touching {len(self.hashXs):,d} addresses') - await asyncio.sleep(self.log_status_secs) - await synchronized_event.wait() - - async def _refresh_histogram(self, synchronized_event): - try: - while True: - await synchronized_event.wait() - async with self.lock: - # Threaded as can be expensive - await asyncio.get_event_loop().run_in_executor(self.executor, self._update_histogram, 100_000) - await asyncio.sleep(self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS) - finally: - self.executor.shutdown(wait=True) - - def _update_histogram(self, bin_size): - # Build a histogram by fee rate - histogram = defaultdict(int) - for tx in self.txs.values(): - histogram[tx.fee // tx.size] += tx.size - - # Now compact it. For efficiency, get_fees returns a - # compact histogram with variable bin size. The compact - # histogram is an array of (fee_rate, vsize) values. - # vsize_n is the cumulative virtual size of mempool - # transactions with a fee rate in the interval - # [rate_(n-1), rate_n)], and rate_(n-1) > rate_n. - # Intervals are chosen to create tranches containing at - # least 100kb of transactions - compact = [] - cum_size = 0 - r = 0 # ? - for fee_rate, size in sorted(histogram.items(), reverse=True): - cum_size += size - if cum_size + r > bin_size: - compact.append((fee_rate, cum_size)) - r += cum_size - bin_size - cum_size = 0 - bin_size *= 1.1 - self.logger.info(f'compact fee histogram: {compact}') - self.cached_compact_histogram = compact - - def _accept_transactions(self, tx_map, utxo_map, touched): - """Accept transactions in tx_map to the mempool if all their inputs - can be found in the existing mempool or a utxo_map from the - DB. - - Returns an (unprocessed tx_map, unspent utxo_map) pair. 
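A standalone restatement (assumed function name; behaviour mirrors _update_histogram above) of the variable bin-size compaction: walk fee rates from highest to lowest and emit a (fee_rate, cumulative vsize) point once at least bin_size bytes have accumulated, growing bin_size by 10% per emitted tranche.

def compact_histogram(histogram, bin_size=100_000):
    compact = []
    cum_size = 0
    carried = 0   # bytes carried over from tranches that overshot their bin
    for fee_rate, size in sorted(histogram.items(), reverse=True):
        cum_size += size
        if cum_size + carried > bin_size:
            compact.append((fee_rate, cum_size))
            carried += cum_size - bin_size
            cum_size = 0
            bin_size *= 1.1
    return compact

# e.g. a {fee_rate: vsize} dict in -> [(lowest rate of first ~100kB tranche, vsize), ...] out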
- """ - hashXs = self.hashXs - txs = self.txs - - deferred = {} - unspent = set(utxo_map) - # Try to find all prevouts so we can accept the TX - for hash, tx in tx_map.items(): - in_pairs = [] - try: - for prevout in tx.prevouts: - utxo = utxo_map.get(prevout) - if not utxo: - prev_hash, prev_index = prevout - # Raises KeyError if prev_hash is not in txs - utxo = txs[prev_hash].out_pairs[prev_index] - in_pairs.append(utxo) - except KeyError: - deferred[hash] = tx - continue - - # Spend the prevouts - unspent.difference_update(tx.prevouts) - - # Save the in_pairs, compute the fee and accept the TX - tx.in_pairs = tuple(in_pairs) - # Avoid negative fees if dealing with generation-like transactions - # because some in_parts would be missing - tx.fee = max(0, (sum(v for _, v in tx.in_pairs) - - sum(v for _, v in tx.out_pairs))) - txs[hash] = tx - - for hashX, value in itertools.chain(tx.in_pairs, tx.out_pairs): - touched.add(hashX) - hashXs[hashX].add(hash) - - return deferred, {prevout: utxo_map[prevout] for prevout in unspent} - - async def _refresh_hashes(self, synchronized_event): - """Refresh our view of the daemon's mempool.""" - while True: - height = self.api.cached_height() - hex_hashes = await self.api.mempool_hashes() - if height != await self.api.height(): - continue - hashes = {hex_str_to_hash(hh) for hh in hex_hashes} - async with self.lock: - touched = await self._process_mempool(hashes) - synchronized_event.set() - synchronized_event.clear() - await self.api.on_mempool(touched, height) - try: - # we wait up to `refresh_secs` but go early if a broadcast happens (which triggers wakeup event) - await asyncio.wait_for(self.wakeup.wait(), timeout=self.refresh_secs) - except asyncio.TimeoutError: - pass - finally: - self.wakeup.clear() - - async def _process_mempool(self, all_hashes): - # Re-sync with the new set of hashes - txs = self.txs - hashXs = self.hashXs - touched = set() - - # First handle txs that have disappeared - for tx_hash in set(txs).difference(all_hashes): - tx = txs.pop(tx_hash) - tx_hashXs = {hashX for hashX, value in tx.in_pairs} - tx_hashXs.update(hashX for hashX, value in tx.out_pairs) - for hashX in tx_hashXs: - hashXs[hashX].remove(tx_hash) - if not hashXs[hashX]: - del hashXs[hashX] - touched.update(tx_hashXs) - - # Process new transactions - new_hashes = list(all_hashes.difference(txs)) - if new_hashes: - fetches = [] - for hashes in chunks(new_hashes, 200): - fetches.append(self._fetch_and_accept(hashes, all_hashes, touched)) - tx_map = {} - utxo_map = {} - for fetch in asyncio.as_completed(fetches): - deferred, unspent = await fetch - tx_map.update(deferred) - utxo_map.update(unspent) - - prior_count = 0 - # FIXME: this is not particularly efficient - while tx_map and len(tx_map) != prior_count: - prior_count = len(tx_map) - tx_map, utxo_map = self._accept_transactions(tx_map, utxo_map, - touched) - if tx_map: - self.logger.info(f'{len(tx_map)} txs dropped') - - return touched - - async def _fetch_and_accept(self, hashes, all_hashes, touched): - """Fetch a list of mempool transactions.""" - hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes) - raw_txs = await self.api.raw_transactions(hex_hashes_iter) - - def deserialize_txs(): # This function is pure - to_hashX = self.coin.hashX_from_script - deserializer = self.coin.DESERIALIZER - - txs = {} - for hash, raw_tx in zip(hashes, raw_txs): - # The daemon may have evicted the tx from its - # mempool or it may have gotten in a block - if not raw_tx: - continue - tx, tx_size = 
deserializer(raw_tx).read_tx_and_vsize() - # Convert the inputs and outputs into (hashX, value) pairs - # Drop generation-like inputs from MemPoolTx.prevouts - txin_pairs = tuple((txin.prev_hash, txin.prev_idx) - for txin in tx.inputs - if not txin.is_generation()) - txout_pairs = tuple((to_hashX(txout.pk_script), txout.value) - for txout in tx.outputs) - txs[hash] = MemPoolTx(txin_pairs, None, txout_pairs, - 0, tx_size) - return txs - - # Thread this potentially slow operation so as not to block - tx_map = await asyncio.get_event_loop().run_in_executor(self.executor, deserialize_txs) - - # Determine all prevouts not in the mempool, and fetch the - # UTXO information from the database. Failed prevout lookups - # return None - concurrent database updates happen - which is - # relied upon by _accept_transactions. Ignore prevouts that are - # generation-like. - prevouts = tuple(prevout for tx in tx_map.values() - for prevout in tx.prevouts - if prevout[0] not in all_hashes) - utxos = await self.api.lookup_utxos(prevouts) - utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)} - - return self._accept_transactions(tx_map, utxo_map, touched) - - # - # External interface - # - - async def keep_synchronized(self, synchronized_event): - """Keep the mempool synchronized with the daemon.""" - await asyncio.wait([ - self._refresh_hashes(synchronized_event), - self._refresh_histogram(synchronized_event), - self._logging(synchronized_event) - ]) - - async def balance_delta(self, hashX): - """Return the unconfirmed amount in the mempool for hashX. - - Can be positive or negative. - """ - value = 0 - if hashX in self.hashXs: - for hash in self.hashXs[hashX]: - tx = self.txs[hash] - value -= sum(v for h168, v in tx.in_pairs if h168 == hashX) - value += sum(v for h168, v in tx.out_pairs if h168 == hashX) - return value - - async def compact_fee_histogram(self): - """Return a compact fee histogram of the current mempool.""" - return self.cached_compact_histogram - - async def potential_spends(self, hashX): - """Return a set of (prev_hash, prev_idx) pairs from mempool - transactions that touch hashX. - - None, some or all of these may be spends of the hashX, but all - actual spends of it (in the DB or mempool) will be included. - """ - result = set() - for tx_hash in self.hashXs.get(hashX, ()): - tx = self.txs[tx_hash] - result.update(tx.prevouts) - return result - - async def transaction_summaries(self, hashX): - """Return a list of MemPoolTxSummary objects for the hashX.""" - result = [] - for tx_hash in self.hashXs.get(hashX, ()): - tx = self.txs[tx_hash] - has_ui = any(hash in self.txs for hash, idx in tx.prevouts) - result.append(MemPoolTxSummary(tx_hash, tx.fee, has_ui)) - return result - - async def unordered_UTXOs(self, hashX): - """Return an unordered list of UTXO named tuples from mempool - transactions that pay to hashX. - - This does not consider if any other mempool transactions spend - the outputs. - """ - utxos = [] - for tx_hash in self.hashXs.get(hashX, ()): - tx = self.txs.get(tx_hash) - for pos, (hX, value) in enumerate(tx.out_pairs): - if hX == hashX: - utxos.append(UTXO(-1, pos, tx_hash, 0, value)) - return utxos diff --git a/lbry/wallet/server/merkle.py b/lbry/wallet/server/merkle.py deleted file mode 100644 index 174e77b8e..000000000 --- a/lbry/wallet/server/merkle.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright (c) 2018, Neil Booth -# -# All rights reserved. 
-# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# and warranty status of this software. - -"""Merkle trees, branches, proofs and roots.""" - -from asyncio import Event -from math import ceil, log - -from lbry.wallet.server.hash import double_sha256 - - -class Merkle: - """Perform merkle tree calculations on binary hashes using a given hash - function. - - If the hash count is not even, the final hash is repeated when - calculating the next merkle layer up the tree. - """ - - def __init__(self, hash_func=double_sha256): - self.hash_func = hash_func - - def tree_depth(self, hash_count): - return self.branch_length(hash_count) + 1 - - def branch_length(self, hash_count): - """Return the length of a merkle branch given the number of hashes.""" - if not isinstance(hash_count, int): - raise TypeError('hash_count must be an integer') - if hash_count < 1: - raise ValueError('hash_count must be at least 1') - return ceil(log(hash_count, 2)) - - def branch_and_root(self, hashes, index, length=None): - """Return a (merkle branch, merkle_root) pair given hashes, and the - index of one of those hashes. - """ - hashes = list(hashes) - if not isinstance(index, int): - raise TypeError('index must be an integer') - # This also asserts hashes is not empty - if not 0 <= index < len(hashes): - raise ValueError('index out of range') - natural_length = self.branch_length(len(hashes)) - if length is None: - length = natural_length - else: - if not isinstance(length, int): - raise TypeError('length must be an integer') - if length < natural_length: - raise ValueError('length out of range') - - hash_func = self.hash_func - branch = [] - for _ in range(length): - if len(hashes) & 1: - hashes.append(hashes[-1]) - branch.append(hashes[index ^ 1]) - index >>= 1 - hashes = [hash_func(hashes[n] + hashes[n + 1]) - for n in range(0, len(hashes), 2)] - - return branch, hashes[0] - - def root(self, hashes, length=None): - """Return the merkle root of a non-empty iterable of binary hashes.""" - branch, root = self.branch_and_root(hashes, 0, length) - return root - - def root_from_proof(self, hash, branch, index): - """Return the merkle root given a hash, a merkle branch to it, and - its index in the hashes array. - - branch is an iterable sorted deepest to shallowest. If the - returned root is the expected value then the merkle proof is - verified. - - The caller should have confirmed the length of the branch with - branch_length(). 
Unfortunately this is not easily done for - bitcoin transactions as the number of transactions in a block - is unknown to an SPV client. - """ - hash_func = self.hash_func - for elt in branch: - if index & 1: - hash = hash_func(elt + hash) - else: - hash = hash_func(hash + elt) - index >>= 1 - if index: - raise ValueError('index out of range for branch') - return hash - - def level(self, hashes, depth_higher): - """Return a level of the merkle tree of hashes the given depth - higher than the bottom row of the original tree.""" - size = 1 << depth_higher - root = self.root - return [root(hashes[n: n + size], depth_higher) - for n in range(0, len(hashes), size)] - - def branch_and_root_from_level(self, level, leaf_hashes, index, - depth_higher): - """Return a (merkle branch, merkle_root) pair when a merkle-tree has a - level cached. - - To maximally reduce the amount of data hashed in computing a - markle branch, cache a tree of depth N at level N // 2. - - level is a list of hashes in the middle of the tree (returned - by level()) - - leaf_hashes are the leaves needed to calculate a partial branch - up to level. - - depth_higher is how much higher level is than the leaves of the tree - - index is the index in the full list of hashes of the hash whose - merkle branch we want. - """ - if not isinstance(level, list): - raise TypeError("level must be a list") - if not isinstance(leaf_hashes, list): - raise TypeError("leaf_hashes must be a list") - leaf_index = (index >> depth_higher) << depth_higher - leaf_branch, leaf_root = self.branch_and_root( - leaf_hashes, index - leaf_index, depth_higher) - index >>= depth_higher - level_branch, root = self.branch_and_root(level, index) - # Check last so that we know index is in-range - if leaf_root != level[index]: - raise ValueError('leaf hashes inconsistent with level') - return leaf_branch + level_branch, root - - -class MerkleCache: - """A cache to calculate merkle branches efficiently.""" - - def __init__(self, merkle, source_func): - """Initialise a cache hashes taken from source_func: - - async def source_func(index, count): - ... - """ - self.merkle = merkle - self.source_func = source_func - self.length = 0 - self.depth_higher = 0 - self.initialized = Event() - - def _segment_length(self): - return 1 << self.depth_higher - - def _leaf_start(self, index): - """Given a level's depth higher and a hash index, return the leaf - index and leaf hash count needed to calculate a merkle branch. - """ - depth_higher = self.depth_higher - return (index >> depth_higher) << depth_higher - - def _level(self, hashes): - return self.merkle.level(hashes, self.depth_higher) - - async def _extend_to(self, length): - """Extend the length of the cache if necessary.""" - if length <= self.length: - return - # Start from the beginning of any final partial segment. 
- # Retain the value of depth_higher; in practice this is fine - start = self._leaf_start(self.length) - hashes = await self.source_func(start, length - start) - self.level[start >> self.depth_higher:] = self._level(hashes) - self.length = length - - async def _level_for(self, length): - """Return a (level_length, final_hash) pair for a truncation - of the hashes to the given length.""" - if length == self.length: - return self.level - level = self.level[:length >> self.depth_higher] - leaf_start = self._leaf_start(length) - count = min(self._segment_length(), length - leaf_start) - hashes = await self.source_func(leaf_start, count) - level += self._level(hashes) - return level - - async def initialize(self, length): - """Call to initialize the cache to a source of given length.""" - self.length = length - self.depth_higher = self.merkle.tree_depth(length) // 2 - self.level = self._level(await self.source_func(0, length)) - self.initialized.set() - - def truncate(self, length): - """Truncate the cache so it covers no more than length underlying - hashes.""" - if not isinstance(length, int): - raise TypeError('length must be an integer') - if length <= 0: - raise ValueError('length must be positive') - if length >= self.length: - return - length = self._leaf_start(length) - self.length = length - self.level[length >> self.depth_higher:] = [] - - async def branch_and_root(self, length, index): - """Return a merkle branch and root. Length is the number of - hashes used to calculate the merkle root, index is the position - of the hash to calculate the branch of. - - index must be less than length, which must be at least 1.""" - if not isinstance(length, int): - raise TypeError('length must be an integer') - if not isinstance(index, int): - raise TypeError('index must be an integer') - if length <= 0: - raise ValueError('length must be positive') - if index >= length: - raise ValueError('index must be less than length') - await self.initialized.wait() - await self._extend_to(length) - leaf_start = self._leaf_start(index) - count = min(self._segment_length(), length - leaf_start) - leaf_hashes = await self.source_func(leaf_start, count) - if length < self._segment_length(): - return self.merkle.branch_and_root(leaf_hashes, index) - level = await self._level_for(length) - return self.merkle.branch_and_root_from_level( - level, leaf_hashes, index, self.depth_higher) diff --git a/lbry/wallet/server/metrics.py b/lbry/wallet/server/metrics.py deleted file mode 100644 index f1ee4e5d1..000000000 --- a/lbry/wallet/server/metrics.py +++ /dev/null @@ -1,135 +0,0 @@ -import time -import math -from typing import Tuple - - -def calculate_elapsed(start) -> int: - return int((time.perf_counter() - start) * 1000) - - -def calculate_avg_percentiles(data) -> Tuple[int, int, int, int, int, int, int, int]: - if not data: - return 0, 0, 0, 0, 0, 0, 0, 0 - data.sort() - size = len(data) - return ( - int(sum(data) / size), - data[0], - data[math.ceil(size * .05) - 1], - data[math.ceil(size * .25) - 1], - data[math.ceil(size * .50) - 1], - data[math.ceil(size * .75) - 1], - data[math.ceil(size * .95) - 1], - data[-1] - ) - - -def remove_select_list(sql) -> str: - return sql[sql.index('FROM'):] - - -class APICallMetrics: - - def __init__(self, name): - self.name = name - - # total requests received - self.receive_count = 0 - self.cache_response_count = 0 - - # millisecond timings for query based responses - self.query_response_times = [] - self.query_intrp_times = [] - self.query_error_times = [] - - 
self.query_python_times = [] - self.query_wait_times = [] - self.query_sql_times = [] # aggregate total of multiple SQL calls made per request - - self.individual_sql_times = [] # every SQL query run on server - - # actual queries - self.errored_queries = set() - self.interrupted_queries = set() - - def to_json(self): - return { - # total requests received - "receive_count": self.receive_count, - # sum of these is total responses made - "cache_response_count": self.cache_response_count, - "query_response_count": len(self.query_response_times), - "intrp_response_count": len(self.query_intrp_times), - "error_response_count": len(self.query_error_times), - # millisecond timings for non-cache responses - "response": calculate_avg_percentiles(self.query_response_times), - "interrupt": calculate_avg_percentiles(self.query_intrp_times), - "error": calculate_avg_percentiles(self.query_error_times), - # response, interrupt and error each also report the python, wait and sql stats: - "python": calculate_avg_percentiles(self.query_python_times), - "wait": calculate_avg_percentiles(self.query_wait_times), - "sql": calculate_avg_percentiles(self.query_sql_times), - # extended timings for individual sql executions - "individual_sql": calculate_avg_percentiles(self.individual_sql_times), - "individual_sql_count": len(self.individual_sql_times), - # actual queries - "errored_queries": list(self.errored_queries), - "interrupted_queries": list(self.interrupted_queries), - } - - def start(self): - self.receive_count += 1 - - def cache_response(self): - self.cache_response_count += 1 - - def _add_query_timings(self, request_total_time, metrics): - if metrics and 'execute_query' in metrics: - sub_process_total = metrics[self.name][0]['total'] - individual_query_times = [f['total'] for f in metrics['execute_query']] - aggregated_query_time = sum(individual_query_times) - self.individual_sql_times.extend(individual_query_times) - self.query_sql_times.append(aggregated_query_time) - self.query_python_times.append(sub_process_total - aggregated_query_time) - self.query_wait_times.append(request_total_time - sub_process_total) - - @staticmethod - def _add_queries(query_set, metrics): - if metrics and 'execute_query' in metrics: - for execute_query in metrics['execute_query']: - if 'sql' in execute_query: - query_set.add(remove_select_list(execute_query['sql'])) - - def query_response(self, start, metrics): - self.query_response_times.append(calculate_elapsed(start)) - self._add_query_timings(self.query_response_times[-1], metrics) - - def query_interrupt(self, start, metrics): - self.query_intrp_times.append(calculate_elapsed(start)) - self._add_queries(self.interrupted_queries, metrics) - self._add_query_timings(self.query_intrp_times[-1], metrics) - - def query_error(self, start, metrics): - self.query_error_times.append(calculate_elapsed(start)) - self._add_queries(self.errored_queries, metrics) - self._add_query_timings(self.query_error_times[-1], metrics) - - -class ServerLoadData: - - def __init__(self): - self._apis = {} - - def for_api(self, name) -> APICallMetrics: - if name not in self._apis: - self._apis[name] = APICallMetrics(name) - return self._apis[name] - - def to_json_and_reset(self, status): - try: - return { - 'api': {name: api.to_json() for name, api in self._apis.items()}, - 'status': status - } - finally: - self._apis = {} diff --git a/lbry/wallet/server/peer.py b/lbry/wallet/server/peer.py deleted file mode 100644 index 078fda9e4..000000000 --- a/lbry/wallet/server/peer.py +++ /dev/null @@ 
-1,302 +0,0 @@ -# Copyright (c) 2017, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -"""Representation of a peer server.""" - -from ipaddress import ip_address - -from lbry.wallet.server import util -from lbry.wallet.server.util import cachedproperty - -from typing import Dict - - -class Peer: - - # Protocol version - ATTRS = ('host', 'features', - # metadata - 'source', 'ip_addr', - 'last_good', 'last_try', 'try_count') - FEATURES = ('pruning', 'server_version', 'protocol_min', 'protocol_max', - 'ssl_port', 'tcp_port') - # This should be set by the application - DEFAULT_PORTS: Dict[str, int] = {} - - def __init__(self, host, features, source='unknown', ip_addr=None, - last_good=0, last_try=0, try_count=0): - """Create a peer given a host name (or IP address as a string), - a dictionary of features, and a record of the source.""" - assert isinstance(host, str) - assert isinstance(features, dict) - assert host in features.get('hosts', {}) - self.host = host - self.features = features.copy() - # Canonicalize / clean-up - for feature in self.FEATURES: - self.features[feature] = getattr(self, feature) - # Metadata - self.source = source - self.ip_addr = ip_addr - # last_good represents the last connection that was - # successful *and* successfully verified, at which point - # try_count is set to 0. Failure to connect or failure to - # verify increment the try_count. - self.last_good = last_good - self.last_try = last_try - self.try_count = try_count - # Transient, non-persisted metadata - self.bad = False - self.other_port_pairs = set() - self.status = 2 - - @classmethod - def peers_from_features(cls, features, source): - peers = [] - if isinstance(features, dict): - hosts = features.get('hosts') - if isinstance(hosts, dict): - peers = [Peer(host, features, source=source) - for host in hosts if isinstance(host, str)] - return peers - - @classmethod - def deserialize(cls, item): - """Deserialize from a dictionary.""" - return cls(**item) - - def matches(self, peers): - """Return peers whose host matches our hostname or IP address. - Additionally include all peers whose IP address matches our - hostname if that is an IP address. 
- """ - candidates = (self.host.lower(), self.ip_addr) - return [peer for peer in peers - if peer.host.lower() in candidates - or peer.ip_addr == self.host] - - def __str__(self): - return self.host - - def update_features(self, features): - """Update features in-place.""" - try: - tmp = Peer(self.host, features) - except Exception: - pass - else: - self.update_features_from_peer(tmp) - - def update_features_from_peer(self, peer): - if peer != self: - self.features = peer.features - for feature in self.FEATURES: - setattr(self, feature, getattr(peer, feature)) - - def connection_port_pairs(self): - """Return a list of (kind, port) pairs to try when making a - connection.""" - # Use a list not a set - it's important to try the registered - # ports first. - pairs = [('SSL', self.ssl_port), ('TCP', self.tcp_port)] - while self.other_port_pairs: - pairs.append(self.other_port_pairs.pop()) - return [pair for pair in pairs if pair[1]] - - def mark_bad(self): - """Mark as bad to avoid reconnects but also to remember for a - while.""" - self.bad = True - - def check_ports(self, other): - """Remember differing ports in case server operator changed them - or removed one.""" - if other.ssl_port != self.ssl_port: - self.other_port_pairs.add(('SSL', other.ssl_port)) - if other.tcp_port != self.tcp_port: - self.other_port_pairs.add(('TCP', other.tcp_port)) - return bool(self.other_port_pairs) - - @cachedproperty - def is_tor(self): - return self.host.endswith('.onion') - - @cachedproperty - def is_valid(self): - ip = self.ip_address - if ip: - return ((ip.is_global or ip.is_private) - and not (ip.is_multicast or ip.is_unspecified)) - return util.is_valid_hostname(self.host) - - @cachedproperty - def is_public(self): - ip = self.ip_address - if ip: - return self.is_valid and not ip.is_private - else: - return self.is_valid and self.host != 'localhost' - - @cachedproperty - def ip_address(self): - """The host as a python ip_address object, or None.""" - try: - return ip_address(self.host) - except ValueError: - return None - - def bucket(self): - if self.is_tor: - return 'onion' - if not self.ip_addr: - return '' - return tuple(self.ip_addr.split('.')[:2]) - - def serialize(self): - """Serialize to a dictionary.""" - return {attr: getattr(self, attr) for attr in self.ATTRS} - - def _port(self, key): - hosts = self.features.get('hosts') - if isinstance(hosts, dict): - host = hosts.get(self.host) - port = self._integer(key, host) - if port and 0 < port < 65536: - return port - return None - - def _integer(self, key, d=None): - d = d or self.features - result = d.get(key) if isinstance(d, dict) else None - if isinstance(result, str): - try: - result = int(result) - except ValueError: - pass - return result if isinstance(result, int) else None - - def _string(self, key): - result = self.features.get(key) - return result if isinstance(result, str) else None - - @cachedproperty - def genesis_hash(self): - """Returns None if no SSL port, otherwise the port as an integer.""" - return self._string('genesis_hash') - - @cachedproperty - def ssl_port(self): - """Returns None if no SSL port, otherwise the port as an integer.""" - return self._port('ssl_port') - - @cachedproperty - def tcp_port(self): - """Returns None if no TCP port, otherwise the port as an integer.""" - return self._port('tcp_port') - - @cachedproperty - def server_version(self): - """Returns the server version as a string if known, otherwise None.""" - return self._string('server_version') - - @cachedproperty - def pruning(self): - """Returns the 
pruning level as an integer. None indicates no - pruning.""" - pruning = self._integer('pruning') - if pruning and pruning > 0: - return pruning - return None - - def _protocol_version_string(self, key): - version_str = self.features.get(key) - ptuple = util.protocol_tuple(version_str) - return util.version_string(ptuple) - - @cachedproperty - def protocol_min(self): - """Minimum protocol version as a string, e.g., 1.0""" - return self._protocol_version_string('protocol_min') - - @cachedproperty - def protocol_max(self): - """Maximum protocol version as a string, e.g., 1.1""" - return self._protocol_version_string('protocol_max') - - def to_tuple(self): - """The tuple ((ip, host, details) expected in response - to a peers subscription.""" - details = self.real_name().split()[1:] - return (self.ip_addr or self.host, self.host, details) - - def real_name(self): - """Real name of this peer as used on IRC.""" - def port_text(letter, port): - if port == self.DEFAULT_PORTS.get(letter): - return letter - else: - return letter + str(port) - - parts = [self.host, 'v' + self.protocol_max] - if self.pruning: - parts.append(f'p{self.pruning:d}') - for letter, port in (('s', self.ssl_port), ('t', self.tcp_port)): - if port: - parts.append(port_text(letter, port)) - return ' '.join(parts) - - @classmethod - def from_real_name(cls, real_name, source): - """Real name is a real name as on IRC, such as - - "erbium1.sytes.net v1.0 s t" - - Returns an instance of this Peer class. - """ - host = 'nohost' - features = {} - ports = {} - for n, part in enumerate(real_name.split()): - if n == 0: - host = part - continue - if part[0] in ('s', 't'): - if len(part) == 1: - port = cls.DEFAULT_PORTS[part[0]] - else: - port = part[1:] - if part[0] == 's': - ports['ssl_port'] = port - else: - ports['tcp_port'] = port - elif part[0] == 'v': - features['protocol_max'] = features['protocol_min'] = part[1:] - elif part[0] == 'p': - features['pruning'] = part[1:] - - features.update(ports) - features['hosts'] = {host: ports} - - return cls(host, features, source) diff --git a/lbry/wallet/server/peers.py b/lbry/wallet/server/peers.py deleted file mode 100644 index f1407339b..000000000 --- a/lbry/wallet/server/peers.py +++ /dev/null @@ -1,506 +0,0 @@ -# Copyright (c) 2017-2018, Neil Booth -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - -"""Peer management.""" - -import asyncio -import random -import socket -import ssl -import time -import typing -from asyncio import Event, sleep -from collections import defaultdict, Counter - -from lbry.wallet.tasks import TaskGroup -from lbry.wallet.rpc import ( - Connector, RPCSession, SOCKSProxy, Notification, handler_invocation, - SOCKSError, RPCError -) -from lbry.wallet.server.peer import Peer -from lbry.wallet.server.util import class_logger, protocol_tuple - -PEER_GOOD, PEER_STALE, PEER_NEVER, PEER_BAD = range(4) -STALE_SECS = 24 * 3600 -WAKEUP_SECS = 300 - - -class BadPeerError(Exception): - pass - - -def assert_good(message, result, instance): - if not isinstance(result, instance): - raise BadPeerError(f'{message} returned bad result type ' - f'{type(result).__name__}') - - -class PeerSession(RPCSession): - """An outgoing session to a peer.""" - - async def handle_request(self, request): - # We subscribe so might be unlucky enough to get a notification... 
- if (isinstance(request, Notification) and - request.method == 'blockchain.headers.subscribe'): - pass - else: - await handler_invocation(None, request) # Raises - - -class PeerManager: - """Looks after the DB of peer network servers. - - Attempts to maintain a connection with up to 8 peers. - Issues a 'peers.subscribe' RPC to them and tells them our data. - """ - def __init__(self, env, db): - self.logger = class_logger(__name__, self.__class__.__name__) - # Initialise the Peer class - Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS - self.env = env - self.db = db - - # Our clearnet and Tor Peers, if any - sclass = env.coin.SESSIONCLS - self.myselves = [Peer(ident.host, sclass.server_features(env), 'env') - for ident in env.identities] - self.server_version_args = sclass.server_version_args() - # Peers have one entry per hostname. Once connected, the - # ip_addr property is either None, an onion peer, or the - # IP address that was connected to. Adding a peer will evict - # any other peers with the same host name or IP address. - self.peers: typing.Set[Peer] = set() - self.permit_onion_peer_time = time.time() - self.proxy = None - self.group = TaskGroup() - - def _my_clearnet_peer(self): - """Returns the clearnet peer representing this server, if any.""" - clearnet = [peer for peer in self.myselves if not peer.is_tor] - return clearnet[0] if clearnet else None - - def _set_peer_statuses(self): - """Set peer statuses.""" - cutoff = time.time() - STALE_SECS - for peer in self.peers: - if peer.bad: - peer.status = PEER_BAD - elif peer.last_good > cutoff: - peer.status = PEER_GOOD - elif peer.last_good: - peer.status = PEER_STALE - else: - peer.status = PEER_NEVER - - def _features_to_register(self, peer, remote_peers): - """If we should register ourselves to the remote peer, which has - reported the given list of known peers, return the clearnet - identity features to register, otherwise None. - """ - # Announce ourself if not present. Don't if disabled, we - # are a non-public IP address, or to ourselves. - if not self.env.peer_announce or peer in self.myselves: - return None - my = self._my_clearnet_peer() - if not my or not my.is_public: - return None - # Register if no matches, or ports have changed - for peer in my.matches(remote_peers): - if peer.tcp_port == my.tcp_port and peer.ssl_port == my.ssl_port: - return None - return my.features - - def _permit_new_onion_peer(self): - """Accept a new onion peer only once per random time interval.""" - now = time.time() - if now < self.permit_onion_peer_time: - return False - self.permit_onion_peer_time = now + random.randrange(0, 1200) - return True - - async def _import_peers(self): - """Import hard-coded peers from a file or the coin defaults.""" - imported_peers = self.myselves.copy() - # Add the hard-coded ones unless only reporting ourself - if self.env.peer_discovery != self.env.PD_SELF: - imported_peers.extend(Peer.from_real_name(real_name, 'coins.py') - for real_name in self.env.coin.PEERS) - await self._note_peers(imported_peers, limit=None) - - async def _detect_proxy(self): - """Detect a proxy if we don't have one and some time has passed since - the last attempt. - - If found self.proxy is set to a SOCKSProxy instance, otherwise - None. 
- """ - host = self.env.tor_proxy_host - if self.env.tor_proxy_port is None: - ports = [9050, 9150, 1080] - else: - ports = [self.env.tor_proxy_port] - while True: - self.logger.info(f'trying to detect proxy on "{host}" ' - f'ports {ports}') - proxy = await SOCKSProxy.auto_detect_host(host, ports, None) - if proxy: - self.proxy = proxy - self.logger.info(f'detected {proxy}') - return - self.logger.info('no proxy detected, will try later') - await sleep(900) - - async def _note_peers(self, peers, limit=2, check_ports=False, - source=None): - """Add a limited number of peers that are not already present.""" - new_peers = [] - for peer in peers: - if not peer.is_public or (peer.is_tor and not self.proxy): - continue - - matches = peer.matches(self.peers) - if not matches: - new_peers.append(peer) - elif check_ports: - for match in matches: - if match.check_ports(peer): - self.logger.info(f'ports changed for {peer}') - match.retry_event.set() - - if new_peers: - source = source or new_peers[0].source - if limit: - random.shuffle(new_peers) - use_peers = new_peers[:limit] - else: - use_peers = new_peers - for peer in use_peers: - self.logger.info(f'accepted new peer {peer} from {source}') - peer.retry_event = Event() - self.peers.add(peer) - await self.group.add(self._monitor_peer(peer)) - - async def _monitor_peer(self, peer): - # Stop monitoring if we were dropped (a duplicate peer) - while peer in self.peers: - if await self._should_drop_peer(peer): - self.peers.discard(peer) - break - # Figure out how long to sleep before retrying. Retry a - # good connection when it is about to turn stale, otherwise - # exponentially back off retries. - if peer.try_count == 0: - pause = STALE_SECS - WAKEUP_SECS * 2 - else: - pause = WAKEUP_SECS * 2 ** peer.try_count - pending, done = await asyncio.wait([peer.retry_event.wait()], timeout=pause) - if done: - peer.retry_event.clear() - - async def _should_drop_peer(self, peer): - peer.try_count += 1 - is_good = False - for kind, port in peer.connection_port_pairs(): - peer.last_try = time.time() - - kwargs = {} - if kind == 'SSL': - kwargs['ssl'] = ssl.SSLContext(ssl.PROTOCOL_TLS) - - host = self.env.cs_host(for_rpc=False) - if isinstance(host, list): - host = host[0] - - if self.env.force_proxy or peer.is_tor: - if not self.proxy: - return - kwargs['proxy'] = self.proxy - kwargs['resolve'] = not peer.is_tor - elif host: - # Use our listening Host/IP for outgoing non-proxy - # connections so our peers see the correct source. - kwargs['local_addr'] = (host, None) - - peer_text = f'[{peer}:{port} {kind}]' - try: - async with Connector(PeerSession, peer.host, port, - **kwargs) as session: - await asyncio.wait_for( - self._verify_peer(session, peer), - 120 if peer.is_tor else 30 - ) - is_good = True - break - except BadPeerError as e: - self.logger.error(f'{peer_text} marking bad: ({e})') - peer.mark_bad() - break - except RPCError as e: - self.logger.error(f'{peer_text} RPC error: {e.message} ' - f'({e.code})') - except (OSError, SOCKSError, ConnectionError, asyncio.TimeoutError) as e: - self.logger.info(f'{peer_text} {e}') - - if is_good: - now = time.time() - elapsed = now - peer.last_try - self.logger.info(f'{peer_text} verified in {elapsed:.1f}s') - peer.try_count = 0 - peer.last_good = now - peer.source = 'peer' - # At most 2 matches if we're a host name, potentially - # several if we're an IP address (several instances - # can share a NAT). 
- matches = peer.matches(self.peers) - for match in matches: - if match.ip_address: - if len(matches) > 1: - self.peers.remove(match) - # Force the peer's monitoring task to exit - match.retry_event.set() - elif peer.host in match.features['hosts']: - match.update_features_from_peer(peer) - else: - # Forget the peer if long-term unreachable - if peer.last_good and not peer.bad: - try_limit = 10 - else: - try_limit = 3 - if peer.try_count >= try_limit: - desc = 'bad' if peer.bad else 'unreachable' - self.logger.info(f'forgetting {desc} peer: {peer}') - return True - return False - - async def _verify_peer(self, session, peer): - if not peer.is_tor: - address = session.peer_address() - if address: - peer.ip_addr = address[0] - - # server.version goes first - message = 'server.version' - result = await session.send_request(message, self.server_version_args) - assert_good(message, result, list) - - # Protocol version 1.1 returns a pair with the version first - if len(result) != 2 or not all(isinstance(x, str) for x in result): - raise BadPeerError(f'bad server.version result: {result}') - server_version, protocol_version = result - peer.server_version = server_version - peer.features['server_version'] = server_version - ptuple = protocol_tuple(protocol_version) - - await asyncio.wait([ - self._send_headers_subscribe(session, peer, ptuple), - self._send_server_features(session, peer), - self._send_peers_subscribe(session, peer) - ]) - - async def _send_headers_subscribe(self, session, peer, ptuple): - message = 'blockchain.headers.subscribe' - result = await session.send_request(message) - assert_good(message, result, dict) - - our_height = self.db.db_height - if ptuple < (1, 3): - their_height = result.get('block_height') - else: - their_height = result.get('height') - if not isinstance(their_height, int): - raise BadPeerError(f'invalid height {their_height}') - if abs(our_height - their_height) > 5: - raise BadPeerError(f'bad height {their_height:,d} ' - f'(ours: {our_height:,d})') - - # Check prior header too in case of hard fork. 
- check_height = min(our_height, their_height) - raw_header = await self.db.raw_header(check_height) - if ptuple >= (1, 4): - ours = raw_header.hex() - message = 'blockchain.block.header' - theirs = await session.send_request(message, [check_height]) - assert_good(message, theirs, str) - if ours != theirs: - raise BadPeerError(f'our header {ours} and ' - f'theirs {theirs} differ') - else: - ours = self.env.coin.electrum_header(raw_header, check_height) - ours = ours.get('prev_block_hash') - message = 'blockchain.block.get_header' - theirs = await session.send_request(message, [check_height]) - assert_good(message, theirs, dict) - theirs = theirs.get('prev_block_hash') - if ours != theirs: - raise BadPeerError(f'our header hash {ours} and ' - f'theirs {theirs} differ') - - async def _send_server_features(self, session, peer): - message = 'server.features' - features = await session.send_request(message) - assert_good(message, features, dict) - hosts = [host.lower() for host in features.get('hosts', {})] - if self.env.coin.GENESIS_HASH != features.get('genesis_hash'): - raise BadPeerError('incorrect genesis hash') - elif peer.host.lower() in hosts: - peer.update_features(features) - else: - raise BadPeerError(f'not listed in own hosts list {hosts}') - - async def _send_peers_subscribe(self, session, peer): - message = 'server.peers.subscribe' - raw_peers = await session.send_request(message) - assert_good(message, raw_peers, list) - - # Check the peers list we got from a remote peer. - # Each is expected to be of the form: - # [ip_addr, hostname, ['v1.0', 't51001', 's51002']] - # Call add_peer if the remote doesn't appear to know about us. - try: - real_names = [' '.join([u[1]] + u[2]) for u in raw_peers] - peers = [Peer.from_real_name(real_name, str(peer)) - for real_name in real_names] - except Exception: - raise BadPeerError('bad server.peers.subscribe response') - - await self._note_peers(peers) - features = self._features_to_register(peer, peers) - if not features: - return - self.logger.info(f'registering ourself with {peer}') - # We only care to wait for the response - await session.send_request('server.add_peer', [features]) - - # - # External interface - # - async def discover_peers(self): - """Perform peer maintenance. This includes - - 1) Forgetting unreachable peers. - 2) Verifying connectivity of new peers. - 3) Retrying old peers at regular intervals. - """ - if self.env.peer_discovery != self.env.PD_ON: - self.logger.info('peer discovery is disabled') - return - - self.logger.info(f'beginning peer discovery. 
Force use of ' - f'proxy: {self.env.force_proxy}') - - self.group.add(self._detect_proxy()) - self.group.add(self._import_peers()) - - def info(self) -> typing.Dict[str, int]: - """The number of peers.""" - self._set_peer_statuses() - counter = Counter(peer.status for peer in self.peers) - return { - 'bad': counter[PEER_BAD], - 'good': counter[PEER_GOOD], - 'never': counter[PEER_NEVER], - 'stale': counter[PEER_STALE], - 'total': len(self.peers), - } - - async def add_localRPC_peer(self, real_name): - """Add a peer passed by the admin over LocalRPC.""" - await self._note_peers([Peer.from_real_name(real_name, 'RPC')]) - - async def on_add_peer(self, features, source_info): - """Add a peer (but only if the peer resolves to the source).""" - if not source_info: - self.logger.info('ignored add_peer request: no source info') - return False - source = source_info[0] - peers = Peer.peers_from_features(features, source) - if not peers: - self.logger.info('ignored add_peer request: no peers given') - return False - - # Just look at the first peer, require it - peer = peers[0] - host = peer.host - if peer.is_tor: - permit = self._permit_new_onion_peer() - reason = 'rate limiting' - else: - getaddrinfo = asyncio.get_event_loop().getaddrinfo - try: - infos = await getaddrinfo(host, 80, type=socket.SOCK_STREAM) - except socket.gaierror: - permit = False - reason = 'address resolution failure' - else: - permit = any(source == info[-1][0] for info in infos) - reason = 'source-destination mismatch' - - if permit: - self.logger.info(f'accepted add_peer request from {source} ' - f'for {host}') - await self._note_peers([peer], check_ports=True) - else: - self.logger.warning(f'rejected add_peer request from {source} ' - f'for {host} ({reason})') - - return permit - - def on_peers_subscribe(self, is_tor): - """Returns the server peers as a list of (ip, host, details) tuples. - - We return all peers we've connected to in the last day. - Additionally, if we don't have onion routing, we return a few - hard-coded onion servers. 
- """ - cutoff = time.time() - STALE_SECS - recent = [peer for peer in self.peers - if peer.last_good > cutoff and - not peer.bad and peer.is_public] - onion_peers = [] - - # Always report ourselves if valid (even if not public) - peers = {myself for myself in self.myselves - if myself.last_good > cutoff} - - # Bucket the clearnet peers and select up to two from each - buckets = defaultdict(list) - for peer in recent: - if peer.is_tor: - onion_peers.append(peer) - else: - buckets[peer.bucket()].append(peer) - for bucket_peers in buckets.values(): - random.shuffle(bucket_peers) - peers.update(bucket_peers[:2]) - - # Add up to 20% onion peers (but up to 10 is OK anyway) - random.shuffle(onion_peers) - max_onion = 50 if is_tor else max(10, len(peers) // 4) - - peers.update(onion_peers[:max_onion]) - - return [peer.to_tuple() for peer in peers] - - def proxy_peername(self): - """Return the peername of the proxy, if there is a proxy, otherwise - None.""" - return self.proxy.peername if self.proxy else None - - def rpc_data(self): - """Peer data for the peers RPC method.""" - self._set_peer_statuses() - descs = ['good', 'stale', 'never', 'bad'] - - def peer_data(peer): - data = peer.serialize() - data['status'] = descs[peer.status] - return data - - def peer_key(peer): - return (peer.bad, -peer.last_good) - - return [peer_data(peer) for peer in sorted(self.peers, key=peer_key)] diff --git a/lbry/wallet/server/prometheus.py b/lbry/wallet/server/prometheus.py deleted file mode 100644 index e28976bf9..000000000 --- a/lbry/wallet/server/prometheus.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -from aiohttp import web -from prometheus_client import Counter, Info, generate_latest as prom_generate_latest, Histogram, Gauge -from lbry import __version__ as version -from lbry.build_info import BUILD, COMMIT_HASH, DOCKER_TAG -from lbry.wallet.server import util -import lbry.wallet.server.version as wallet_server_version - -NAMESPACE = "wallet_server" -CPU_COUNT = f"{os.cpu_count()}" -VERSION_INFO = Info('build', 'Wallet server build info (e.g. 
version, commit hash)', namespace=NAMESPACE) -VERSION_INFO.info({ - 'build': BUILD, - "commit": COMMIT_HASH, - "docker_tag": DOCKER_TAG, - 'version': version, - "min_version": util.version_string(wallet_server_version.PROTOCOL_MIN), - "cpu_count": CPU_COUNT -}) -SESSIONS_COUNT = Gauge("session_count", "Number of connected client sessions", namespace=NAMESPACE, - labelnames=("version", )) -REQUESTS_COUNT = Counter("requests_count", "Number of requests received", namespace=NAMESPACE, - labelnames=("method", "version")) -RESPONSE_TIMES = Histogram("response_time", "Response times", namespace=NAMESPACE, labelnames=("method", "version")) -NOTIFICATION_COUNT = Counter("notification", "Number of notifications sent (for subscriptions)", - namespace=NAMESPACE, labelnames=("method", "version")) -REQUEST_ERRORS_COUNT = Counter("request_error", "Number of requests that returned errors", namespace=NAMESPACE, - labelnames=("method", "version")) -SQLITE_INTERRUPT_COUNT = Counter("interrupt", "Number of interrupted queries", namespace=NAMESPACE) -SQLITE_OPERATIONAL_ERROR_COUNT = Counter( - "operational_error", "Number of queries that raised operational errors", namespace=NAMESPACE -) -SQLITE_INTERNAL_ERROR_COUNT = Counter( - "internal_error", "Number of queries raising unexpected errors", namespace=NAMESPACE -) -SQLITE_EXECUTOR_TIMES = Histogram("executor_time", "SQLite executor times", namespace=NAMESPACE) -SQLITE_PENDING_COUNT = Gauge( - "pending_queries_count", "Number of pending and running sqlite queries", namespace=NAMESPACE -) -LBRYCRD_REQUEST_TIMES = Histogram( - "lbrycrd_request", "lbrycrd requests count", namespace=NAMESPACE, labelnames=("method",) -) -LBRYCRD_PENDING_COUNT = Gauge( - "lbrycrd_pending_count", "Number of lbrycrd rpcs that are in flight", namespace=NAMESPACE, labelnames=("method",) -) -CLIENT_VERSIONS = Counter( - "clients", "Number of connections received per client version", - namespace=NAMESPACE, labelnames=("version",) -) -BLOCK_COUNT = Gauge( - "block_count", "Number of processed blocks", namespace=NAMESPACE -) -BLOCK_UPDATE_TIMES = Histogram("block_time", "Block update times", namespace=NAMESPACE) -REORG_COUNT = Gauge( - "reorg_count", "Number of reorgs", namespace=NAMESPACE -) -RESET_CONNECTIONS = Counter( - "reset_clients", "Number of reset connections by client version", - namespace=NAMESPACE, labelnames=("version",) -) - - -class PrometheusServer: - def __init__(self): - self.logger = util.class_logger(__name__, self.__class__.__name__) - self.runner = None - - async def start(self, port: int): - prom_app = web.Application() - prom_app.router.add_get('/metrics', self.handle_metrics_get_request) - self.runner = web.AppRunner(prom_app) - await self.runner.setup() - - metrics_site = web.TCPSite(self.runner, "0.0.0.0", port, shutdown_timeout=.5) - await metrics_site.start() - self.logger.info('metrics server listening on %s:%i', *metrics_site._server.sockets[0].getsockname()[:2]) - - async def handle_metrics_get_request(self, request: web.Request): - try: - return web.Response( - text=prom_generate_latest().decode(), - content_type='text/plain; version=0.0.4' - ) - except Exception: - self.logger.exception('could not generate prometheus data') - raise - - async def stop(self): - await self.runner.cleanup() diff --git a/lbry/wallet/server/script.py b/lbry/wallet/server/script.py deleted file mode 100644 index ce8d5e5a1..000000000 --- a/lbry/wallet/server/script.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# -# All rights reserved. 
-# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# and warranty status of this software. - -"""Script-related classes and functions.""" - - -from collections import namedtuple - -from lbry.wallet.server.util import unpack_le_uint16_from, unpack_le_uint32_from, \ - pack_le_uint16, pack_le_uint32 - - -class EnumError(Exception): - pass - - -class Enumeration: - - def __init__(self, name, enumList): - self.__doc__ = name - - lookup = {} - reverseLookup = {} - i = 0 - uniqueNames = set() - uniqueValues = set() - for x in enumList: - if isinstance(x, tuple): - x, i = x - if not isinstance(x, str): - raise EnumError(f"enum name {x} not a string") - if not isinstance(i, int): - raise EnumError(f"enum value {i} not an integer") - if x in uniqueNames: - raise EnumError(f"enum name {x} not unique") - if i in uniqueValues: - raise EnumError(f"enum value {i} not unique") - uniqueNames.add(x) - uniqueValues.add(i) - lookup[x] = i - reverseLookup[i] = x - i = i + 1 - self.lookup = lookup - self.reverseLookup = reverseLookup - - def __getattr__(self, attr): - result = self.lookup.get(attr) - if result is None: - raise AttributeError(f'enumeration has no member {attr}') - return result - - def whatis(self, value): - return self.reverseLookup[value] - - -class ScriptError(Exception): - """Exception used for script errors.""" - - -OpCodes = Enumeration("Opcodes", [ - ("OP_0", 0), ("OP_PUSHDATA1", 76), - "OP_PUSHDATA2", "OP_PUSHDATA4", "OP_1NEGATE", - "OP_RESERVED", - "OP_1", "OP_2", "OP_3", "OP_4", "OP_5", "OP_6", "OP_7", "OP_8", - "OP_9", "OP_10", "OP_11", "OP_12", "OP_13", "OP_14", "OP_15", "OP_16", - "OP_NOP", "OP_VER", "OP_IF", "OP_NOTIF", "OP_VERIF", "OP_VERNOTIF", - "OP_ELSE", "OP_ENDIF", "OP_VERIFY", "OP_RETURN", - "OP_TOALTSTACK", "OP_FROMALTSTACK", "OP_2DROP", "OP_2DUP", "OP_3DUP", - "OP_2OVER", "OP_2ROT", "OP_2SWAP", "OP_IFDUP", "OP_DEPTH", "OP_DROP", - "OP_DUP", "OP_NIP", "OP_OVER", "OP_PICK", "OP_ROLL", "OP_ROT", - "OP_SWAP", "OP_TUCK", - "OP_CAT", "OP_SUBSTR", "OP_LEFT", "OP_RIGHT", "OP_SIZE", - "OP_INVERT", "OP_AND", "OP_OR", "OP_XOR", "OP_EQUAL", "OP_EQUALVERIFY", - "OP_RESERVED1", "OP_RESERVED2", - "OP_1ADD", "OP_1SUB", "OP_2MUL", "OP_2DIV", "OP_NEGATE", "OP_ABS", - "OP_NOT", "OP_0NOTEQUAL", "OP_ADD", "OP_SUB", "OP_MUL", "OP_DIV", "OP_MOD", - "OP_LSHIFT", "OP_RSHIFT", "OP_BOOLAND", "OP_BOOLOR", "OP_NUMEQUAL", - "OP_NUMEQUALVERIFY", "OP_NUMNOTEQUAL", "OP_LESSTHAN", "OP_GREATERTHAN", - "OP_LESSTHANOREQUAL", "OP_GREATERTHANOREQUAL", "OP_MIN", "OP_MAX", - 
"OP_WITHIN", - "OP_RIPEMD160", "OP_SHA1", "OP_SHA256", "OP_HASH160", "OP_HASH256", - "OP_CODESEPARATOR", "OP_CHECKSIG", "OP_CHECKSIGVERIFY", "OP_CHECKMULTISIG", - "OP_CHECKMULTISIGVERIFY", - "OP_NOP1", - "OP_CHECKLOCKTIMEVERIFY", "OP_CHECKSEQUENCEVERIFY" -]) - - -# Paranoia to make it hard to create bad scripts -assert OpCodes.OP_DUP == 0x76 -assert OpCodes.OP_HASH160 == 0xa9 -assert OpCodes.OP_EQUAL == 0x87 -assert OpCodes.OP_EQUALVERIFY == 0x88 -assert OpCodes.OP_CHECKSIG == 0xac -assert OpCodes.OP_CHECKMULTISIG == 0xae - - -def _match_ops(ops, pattern): - if len(ops) != len(pattern): - return False - for op, pop in zip(ops, pattern): - if pop != op: - # -1 means 'data push', whose op is an (op, data) tuple - if pop == -1 and isinstance(op, tuple): - continue - return False - - return True - - -class ScriptPubKey: - """A class for handling a tx output script that gives conditions - necessary for spending. - """ - - TO_ADDRESS_OPS = [OpCodes.OP_DUP, OpCodes.OP_HASH160, -1, - OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG] - TO_P2SH_OPS = [OpCodes.OP_HASH160, -1, OpCodes.OP_EQUAL] - TO_PUBKEY_OPS = [-1, OpCodes.OP_CHECKSIG] - - PayToHandlers = namedtuple('PayToHandlers', 'address script_hash pubkey ' - 'unspendable strange') - - @classmethod - def pay_to(cls, handlers, script): - """Parse a script, invoke the appropriate handler and - return the result. - - One of the following handlers is invoked: - handlers.address(hash160) - handlers.script_hash(hash160) - handlers.pubkey(pubkey) - handlers.unspendable() - handlers.strange(script) - """ - try: - ops = Script.get_ops(script) - except ScriptError: - return handlers.unspendable() - - match = _match_ops - - if match(ops, cls.TO_ADDRESS_OPS): - return handlers.address(ops[2][-1]) - if match(ops, cls.TO_P2SH_OPS): - return handlers.script_hash(ops[1][-1]) - if match(ops, cls.TO_PUBKEY_OPS): - return handlers.pubkey(ops[0][-1]) - if ops and ops[0] == OpCodes.OP_RETURN: - return handlers.unspendable() - return handlers.strange(script) - - @classmethod - def P2SH_script(cls, hash160): - return (bytes([OpCodes.OP_HASH160]) - + Script.push_data(hash160) - + bytes([OpCodes.OP_EQUAL])) - - @classmethod - def P2PKH_script(cls, hash160): - return (bytes([OpCodes.OP_DUP, OpCodes.OP_HASH160]) - + Script.push_data(hash160) - + bytes([OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG])) - - @classmethod - def validate_pubkey(cls, pubkey, req_compressed=False): - if isinstance(pubkey, (bytes, bytearray)): - if len(pubkey) == 33 and pubkey[0] in (2, 3): - return # Compressed - if len(pubkey) == 65 and pubkey[0] == 4: - if not req_compressed: - return - raise PubKeyError('uncompressed pubkeys are invalid') - raise PubKeyError(f'invalid pubkey {pubkey}') - - @classmethod - def pubkey_script(cls, pubkey): - cls.validate_pubkey(pubkey) - return Script.push_data(pubkey) + bytes([OpCodes.OP_CHECKSIG]) - - @classmethod - def multisig_script(cls, m, pubkeys): - """Returns the script for a pay-to-multisig transaction.""" - n = len(pubkeys) - if not 1 <= m <= n <= 15: - raise ScriptError(f'{m:d} of {n:d} multisig script not possible') - for pubkey in pubkeys: - cls.validate_pubkey(pubkey, req_compressed=True) - # See https://bitcoin.org/en/developer-guide - # 2 of 3 is: OP_2 pubkey1 pubkey2 pubkey3 OP_3 OP_CHECKMULTISIG - return (bytes([OP_1 + m - 1]) - + b''.join(cls.push_data(pubkey) for pubkey in pubkeys) - + bytes([OP_1 + n - 1, OP_CHECK_MULTISIG])) - - -class Script: - - @classmethod - def get_ops(cls, script): - ops = [] - - # The unpacks or script[n] below throw on 
truncated scripts - try: - n = 0 - while n < len(script): - op = script[n] - n += 1 - - if op <= OpCodes.OP_PUSHDATA4: - # Raw bytes follow - if op < OpCodes.OP_PUSHDATA1: - dlen = op - elif op == OpCodes.OP_PUSHDATA1: - dlen = script[n] - n += 1 - elif op == OpCodes.OP_PUSHDATA2: - dlen, = unpack_le_uint16_from(script[n: n + 2]) - n += 2 - else: - dlen, = unpack_le_uint32_from(script[n: n + 4]) - n += 4 - if n + dlen > len(script): - raise IndexError - op = (op, script[n:n + dlen]) - n += dlen - - ops.append(op) - except Exception: - # Truncated script; e.g. tx_hash - # ebc9fa1196a59e192352d76c0f6e73167046b9d37b8302b6bb6968dfd279b767 - raise ScriptError('truncated script') - - return ops - - @classmethod - def push_data(cls, data): - """Returns the opcodes to push the data on the stack.""" - assert isinstance(data, (bytes, bytearray)) - - n = len(data) - if n < OpCodes.OP_PUSHDATA1: - return bytes([n]) + data - if n < 256: - return bytes([OpCodes.OP_PUSHDATA1, n]) + data - if n < 65536: - return bytes([OpCodes.OP_PUSHDATA2]) + pack_le_uint16(n) + data - return bytes([OpCodes.OP_PUSHDATA4]) + pack_le_uint32(n) + data - - @classmethod - def opcode_name(cls, opcode): - if OpCodes.OP_0 < opcode < OpCodes.OP_PUSHDATA1: - return f'OP_{opcode:d}' - try: - return OpCodes.whatis(opcode) - except KeyError: - return f'OP_UNKNOWN:{opcode:d}' - - @classmethod - def dump(cls, script): - opcodes, datas = cls.get_ops(script) - for opcode, data in zip(opcodes, datas): - name = cls.opcode_name(opcode) - if data is None: - print(name) - else: - print(f'{name} {data.hex()} ({len(data):d} bytes)') diff --git a/lbry/wallet/server/server.py b/lbry/wallet/server/server.py deleted file mode 100644 index 4d0374ba4..000000000 --- a/lbry/wallet/server/server.py +++ /dev/null @@ -1,146 +0,0 @@ -import signal -import logging -import asyncio -from concurrent.futures.thread import ThreadPoolExecutor -import typing - -import lbry -from lbry.wallet.server.mempool import MemPool, MemPoolAPI -from lbry.wallet.server.prometheus import PrometheusServer - - -class Notifications: - # hashX notifications come from two sources: new blocks and - # mempool refreshes. - # - # A user with a pending transaction is notified after the block it - # gets in is processed. Block processing can take an extended - # time, and the prefetcher might poll the daemon after the mempool - # code in any case. In such cases the transaction will not be in - # the mempool after the mempool refresh. We want to avoid - # notifying clients twice - for the mempool refresh and when the - # block is done. This object handles that logic by deferring - # notifications appropriately. 
- - def __init__(self): - self._touched_mp = {} - self._touched_bp = {} - self._highest_block = -1 - - async def _maybe_notify(self): - tmp, tbp = self._touched_mp, self._touched_bp - common = set(tmp).intersection(tbp) - if common: - height = max(common) - elif tmp and max(tmp) == self._highest_block: - height = self._highest_block - else: - # Either we are processing a block and waiting for it to - # come in, or we have not yet had a mempool update for the - # new block height - return - touched = tmp.pop(height) - for old in [h for h in tmp if h <= height]: - del tmp[old] - for old in [h for h in tbp if h <= height]: - touched.update(tbp.pop(old)) - await self.notify(height, touched) - - async def notify(self, height, touched): - pass - - async def start(self, height, notify_func): - self._highest_block = height - self.notify = notify_func - await self.notify(height, set()) - - async def on_mempool(self, touched, height): - self._touched_mp[height] = touched - await self._maybe_notify() - - async def on_block(self, touched, height): - self._touched_bp[height] = touched - self._highest_block = height - await self._maybe_notify() - - -class Server: - - def __init__(self, env): - self.env = env - self.log = logging.getLogger(__name__).getChild(self.__class__.__name__) - self.shutdown_event = asyncio.Event() - self.cancellable_tasks = [] - - self.notifications = notifications = Notifications() - self.daemon = daemon = env.coin.DAEMON(env.coin, env.daemon_url) - self.db = db = env.coin.DB(env) - self.bp = bp = env.coin.BLOCK_PROCESSOR(env, db, daemon, notifications) - self.prometheus_server: typing.Optional[PrometheusServer] = None - - # Set notifications up to implement the MemPoolAPI - notifications.height = daemon.height - notifications.cached_height = daemon.cached_height - notifications.mempool_hashes = daemon.mempool_hashes - notifications.raw_transactions = daemon.getrawtransactions - notifications.lookup_utxos = db.lookup_utxos - MemPoolAPI.register(Notifications) - self.mempool = mempool = MemPool(env.coin, notifications) - - self.session_mgr = env.coin.SESSION_MANAGER( - env, db, bp, daemon, mempool, self.shutdown_event - ) - - async def start(self): - env = self.env - min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() - self.log.info(f'software version: {lbry.__version__}') - self.log.info(f'supported protocol versions: {min_str}-{max_str}') - self.log.info(f'event loop policy: {env.loop_policy}') - self.log.info(f'reorg limit is {env.reorg_limit:,d} blocks') - - await self.daemon.height() - - def _start_cancellable(run, *args): - _flag = asyncio.Event() - self.cancellable_tasks.append(asyncio.ensure_future(run(*args, _flag))) - return _flag.wait() - - await _start_cancellable(self.bp.fetch_and_process_blocks) - await self.db.populate_header_merkle_cache() - await _start_cancellable(self.mempool.keep_synchronized) - await _start_cancellable(self.session_mgr.serve, self.notifications) - await self.start_prometheus() - - async def stop(self): - for task in reversed(self.cancellable_tasks): - task.cancel() - await asyncio.wait(self.cancellable_tasks) - if self.prometheus_server: - await self.prometheus_server.stop() - self.prometheus_server = None - self.shutdown_event.set() - await self.daemon.close() - - def run(self): - loop = asyncio.get_event_loop() - executor = ThreadPoolExecutor(1) - loop.set_default_executor(executor) - - def __exit(): - raise SystemExit() - try: - loop.add_signal_handler(signal.SIGINT, __exit) - loop.add_signal_handler(signal.SIGTERM, __exit) 
- loop.run_until_complete(self.start()) - loop.run_until_complete(self.shutdown_event.wait()) - except (SystemExit, KeyboardInterrupt): - pass - finally: - loop.run_until_complete(self.stop()) - executor.shutdown(True) - - async def start_prometheus(self): - if not self.prometheus_server and self.env.prometheus_port: - self.prometheus_server = PrometheusServer() - await self.prometheus_server.start(self.env.prometheus_port) diff --git a/lbry/wallet/server/session.py b/lbry/wallet/server/session.py deleted file mode 100644 index 9a9e23558..000000000 --- a/lbry/wallet/server/session.py +++ /dev/null @@ -1,1641 +0,0 @@ -import os -import ssl -import math -import time -import json -import zlib -import pylru -import base64 -import codecs -import typing -import asyncio -import logging -import itertools -import collections - -from asyncio import Event, sleep -from collections import defaultdict -from functools import partial - -from binascii import hexlify -from pylru import lrucache -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor - -import lbry -from lbry.wallet.server.block_processor import LBRYBlockProcessor -from lbry.wallet.server.db.writer import LBRYLevelDB -from lbry.wallet.server.db import reader -from lbry.wallet.server.websocket import AdminWebSocket -from lbry.wallet.server.metrics import ServerLoadData, APICallMetrics -from lbry.wallet.server.prometheus import REQUESTS_COUNT, SQLITE_INTERRUPT_COUNT, SQLITE_INTERNAL_ERROR_COUNT -from lbry.wallet.server.prometheus import SQLITE_OPERATIONAL_ERROR_COUNT, SQLITE_EXECUTOR_TIMES, SESSIONS_COUNT -from lbry.wallet.server.prometheus import SQLITE_PENDING_COUNT, CLIENT_VERSIONS -from lbry.wallet.rpc.framing import NewlineFramer -import lbry.wallet.server.version as VERSION - -from lbry.wallet.rpc import ( - RPCSession, JSONRPCAutoDetect, JSONRPCConnection, - handler_invocation, RPCError, Request, JSONRPC -) -from lbry.wallet.server import text -from lbry.wallet.server import util -from lbry.wallet.server.hash import sha256, hash_to_hex_str, hex_str_to_hash, HASHX_LEN, Base58Error -from lbry.wallet.server.daemon import DaemonError -from lbry.wallet.server.peers import PeerManager -if typing.TYPE_CHECKING: - from lbry.wallet.server.env import Env - from lbry.wallet.server.mempool import MemPool - from lbry.wallet.server.daemon import Daemon - -BAD_REQUEST = 1 -DAEMON_ERROR = 2 - -log = logging.getLogger(__name__) - - -def scripthash_to_hashX(scripthash: str) -> bytes: - try: - bin_hash = hex_str_to_hash(scripthash) - if len(bin_hash) == 32: - return bin_hash[:HASHX_LEN] - except Exception: - pass - raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash') - - -def non_negative_integer(value) -> int: - """Return param value it is or can be converted to a non-negative - integer, otherwise raise an RPCError.""" - try: - value = int(value) - if value >= 0: - return value - except ValueError: - pass - raise RPCError(BAD_REQUEST, - f'{value} should be a non-negative integer') - - -def assert_boolean(value) -> bool: - """Return param value it is boolean otherwise raise an RPCError.""" - if value in (False, True): - return value - raise RPCError(BAD_REQUEST, f'{value} should be a boolean value') - - -def assert_tx_hash(value: str) -> None: - """Raise an RPCError if the value is not a valid transaction - hash.""" - try: - if len(util.hex_to_bytes(value)) == 32: - return - except Exception: - pass - raise RPCError(BAD_REQUEST, f'{value} should be a transaction hash') - - -class Semaphores: - """For aiorpcX's semaphore 
handling.""" - - def __init__(self, semaphores): - self.semaphores = semaphores - self.acquired = [] - - async def __aenter__(self): - for semaphore in self.semaphores: - await semaphore.acquire() - self.acquired.append(semaphore) - - async def __aexit__(self, exc_type, exc_value, traceback): - for semaphore in self.acquired: - semaphore.release() - - -class SessionGroup: - - def __init__(self, gid: int): - self.gid = gid - # Concurrency per group - self.semaphore = asyncio.Semaphore(20) - - -class SessionManager: - """Holds global state about all sessions.""" - - def __init__(self, env: 'Env', db: LBRYLevelDB, bp: LBRYBlockProcessor, daemon: 'Daemon', mempool: 'MemPool', - shutdown_event: asyncio.Event): - env.max_send = max(350000, env.max_send) - self.env = env - self.db = db - self.bp = bp - self.daemon = daemon - self.mempool = mempool - self.peer_mgr = PeerManager(env, db) - self.shutdown_event = shutdown_event - self.logger = util.class_logger(__name__, self.__class__.__name__) - self.servers: typing.Dict[str, asyncio.AbstractServer] = {} - self.sessions: typing.Set['SessionBase'] = set() - self.cur_group = SessionGroup(0) - self.txs_sent = 0 - self.start_time = time.time() - self.history_cache = pylru.lrucache(256) - self.notified_height: typing.Optional[int] = None - # Cache some idea of room to avoid recounting on each subscription - self.subs_room = 0 - - self.session_event = Event() - - async def _start_server(self, kind, *args, **kw_args): - loop = asyncio.get_event_loop() - if kind == 'RPC': - protocol_class = LocalRPC - else: - protocol_class = self.env.coin.SESSIONCLS - protocol_factory = partial(protocol_class, self, self.db, - self.mempool, self.peer_mgr, kind) - - host, port = args[:2] - try: - self.servers[kind] = await loop.create_server(protocol_factory, *args, **kw_args) - except OSError as e: # don't suppress CancelledError - self.logger.error(f'{kind} server failed to listen on {host}:' - f'{port:d} :{e!r}') - else: - self.logger.info(f'{kind} server listening on {host}:{port:d}') - - async def _start_external_servers(self): - """Start listening on TCP and SSL ports, but only if the respective - port was given in the environment. 
- """ - env = self.env - host = env.cs_host(for_rpc=False) - if env.tcp_port is not None: - await self._start_server('TCP', host, env.tcp_port) - if env.ssl_port is not None: - sslc = ssl.SSLContext(ssl.PROTOCOL_TLS) - sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile) - await self._start_server('SSL', host, env.ssl_port, ssl=sslc) - - async def _close_servers(self, kinds): - """Close the servers of the given kinds (TCP etc.).""" - if kinds: - self.logger.info('closing down {} listening servers' - .format(', '.join(kinds))) - for kind in kinds: - server = self.servers.pop(kind, None) - if server: - server.close() - await server.wait_closed() - - async def _manage_servers(self): - paused = False - max_sessions = self.env.max_sessions - low_watermark = max_sessions * 19 // 20 - while True: - await self.session_event.wait() - self.session_event.clear() - if not paused and len(self.sessions) >= max_sessions: - self.logger.info(f'maximum sessions {max_sessions:,d} ' - f'reached, stopping new connections until ' - f'count drops to {low_watermark:,d}') - await self._close_servers(['TCP', 'SSL']) - paused = True - # Start listening for incoming connections if paused and - # session count has fallen - if paused and len(self.sessions) <= low_watermark: - self.logger.info('resuming listening for incoming connections') - await self._start_external_servers() - paused = False - - async def _log_sessions(self): - """Periodically log sessions.""" - log_interval = self.env.log_sessions - if log_interval: - while True: - await sleep(log_interval) - data = self._session_data(for_log=True) - for line in text.sessions_lines(data): - self.logger.info(line) - self.logger.info(json.dumps(self._get_info())) - - def _group_map(self): - group_map = defaultdict(list) - for session in self.sessions: - group_map[session.group].append(session) - return group_map - - def _sub_count(self) -> int: - return sum(s.sub_count() for s in self.sessions) - - def _lookup_session(self, session_id): - try: - session_id = int(session_id) - except Exception: - pass - else: - for session in self.sessions: - if session.session_id == session_id: - return session - return None - - async def _for_each_session(self, session_ids, operation): - if not isinstance(session_ids, list): - raise RPCError(BAD_REQUEST, 'expected a list of session IDs') - - result = [] - for session_id in session_ids: - session = self._lookup_session(session_id) - if session: - result.append(await operation(session)) - else: - result.append(f'unknown session: {session_id}') - return result - - async def _clear_stale_sessions(self): - """Cut off sessions that haven't done anything for 10 minutes.""" - session_timeout = self.env.session_timeout - while True: - await sleep(session_timeout // 10) - stale_cutoff = time.perf_counter() - session_timeout - stale_sessions = [session for session in self.sessions - if session.last_recv < stale_cutoff] - if stale_sessions: - text = ', '.join(str(session.session_id) - for session in stale_sessions) - self.logger.info(f'closing stale connections {text}') - # Give the sockets some time to close gracefully - if stale_sessions: - await asyncio.wait([ - session.close(force_after=session_timeout // 10) for session in stale_sessions - ]) - - # Consolidate small groups - group_map = self._group_map() - groups = [group for group, sessions in group_map.items() - if len(sessions) <= 5] # fixme: apply session cost here - if len(groups) > 1: - new_group = groups[-1] - for group in groups: - for session in group_map[group]: - 
session.group = new_group - - def _get_info(self): - """A summary of server state.""" - group_map = self._group_map() - method_counts = collections.defaultdict(int) - error_count = 0 - logged = 0 - paused = 0 - pending_requests = 0 - closing = 0 - - for s in self.sessions: - error_count += s.errors - if s.log_me: - logged += 1 - if not s._can_send.is_set(): - paused += 1 - pending_requests += s.count_pending_items() - if s.is_closing(): - closing += 1 - for request, _ in s.connection._requests.values(): - method_counts[request.method] += 1 - return { - 'closing': closing, - 'daemon': self.daemon.logged_url(), - 'daemon_height': self.daemon.cached_height(), - 'db_height': self.db.db_height, - 'errors': error_count, - 'groups': len(group_map), - 'logged': logged, - 'paused': paused, - 'pid': os.getpid(), - 'peers': self.peer_mgr.info(), - 'requests': pending_requests, - 'method_counts': method_counts, - 'sessions': self.session_count(), - 'subs': self._sub_count(), - 'txs_sent': self.txs_sent, - 'uptime': util.formatted_time(time.time() - self.start_time), - 'version': lbry.__version__, - } - - def _session_data(self, for_log): - """Returned to the RPC 'sessions' call.""" - now = time.time() - sessions = sorted(self.sessions, key=lambda s: s.start_time) - return [(session.session_id, - session.flags(), - session.peer_address_str(for_log=for_log), - session.client_version, - session.protocol_version_string(), - session.count_pending_items(), - session.txs_sent, - session.sub_count(), - session.recv_count, session.recv_size, - session.send_count, session.send_size, - now - session.start_time) - for session in sessions] - - def _group_data(self): - """Returned to the RPC 'groups' call.""" - result = [] - group_map = self._group_map() - for group, sessions in group_map.items(): - result.append([group.gid, - len(sessions), - sum(s.bw_charge for s in sessions), - sum(s.count_pending_items() for s in sessions), - sum(s.txs_sent for s in sessions), - sum(s.sub_count() for s in sessions), - sum(s.recv_count for s in sessions), - sum(s.recv_size for s in sessions), - sum(s.send_count for s in sessions), - sum(s.send_size for s in sessions), - ]) - return result - - async def _electrum_and_raw_headers(self, height): - raw_header = await self.raw_header(height) - electrum_header = self.env.coin.electrum_header(raw_header, height) - return electrum_header, raw_header - - async def _refresh_hsub_results(self, height): - """Refresh the cached header subscription responses to be for height, - and record that as notified_height. - """ - # Paranoia: a reorg could race and leave db_height lower - height = min(height, self.db.db_height) - electrum, raw = await self._electrum_and_raw_headers(height) - self.hsub_results = (electrum, {'hex': raw.hex(), 'height': height}) - self.notified_height = height - - # --- LocalRPC command handlers - - async def rpc_add_peer(self, real_name): - """Add a peer. - - real_name: "bch.electrumx.cash t50001 s50002" for example - """ - await self.peer_mgr.add_localRPC_peer(real_name) - return f"peer '{real_name}' added" - - async def rpc_disconnect(self, session_ids): - """Disconnect sessions. - - session_ids: array of session IDs - """ - async def close(session): - """Close the session's transport.""" - await session.close(force_after=2) - return f'disconnected {session.session_id}' - - return await self._for_each_session(session_ids, close) - - async def rpc_log(self, session_ids): - """Toggle logging of sessions. 
- - session_ids: array of session IDs - """ - async def toggle_logging(session): - """Toggle logging of the session.""" - session.toggle_logging() - return f'log {session.session_id}: {session.log_me}' - - return await self._for_each_session(session_ids, toggle_logging) - - async def rpc_daemon_url(self, daemon_url): - """Replace the daemon URL.""" - daemon_url = daemon_url or self.env.daemon_url - try: - self.daemon.set_url(daemon_url) - except Exception as e: - raise RPCError(BAD_REQUEST, f'an error occurred: {e!r}') - return f'now using daemon at {self.daemon.logged_url()}' - - async def rpc_stop(self): - """Shut down the server cleanly.""" - self.shutdown_event.set() - return 'stopping' - - async def rpc_getinfo(self): - """Return summary information about the server process.""" - return self._get_info() - - async def rpc_groups(self): - """Return statistics about the session groups.""" - return self._group_data() - - async def rpc_peers(self): - """Return a list of data about server peers.""" - return self.peer_mgr.rpc_data() - - async def rpc_query(self, items, limit): - """Return a list of data about server peers.""" - coin = self.env.coin - db = self.db - lines = [] - - def arg_to_hashX(arg): - try: - script = bytes.fromhex(arg) - lines.append(f'Script: {arg}') - return coin.hashX_from_script(script) - except ValueError: - pass - - try: - hashX = coin.address_to_hashX(arg) - except Base58Error as e: - lines.append(e.args[0]) - return None - lines.append(f'Address: {arg}') - return hashX - - for arg in items: - hashX = arg_to_hashX(arg) - if not hashX: - continue - n = None - history = await db.limited_history(hashX, limit=limit) - for n, (tx_hash, height) in enumerate(history): - lines.append(f'History #{n:,d}: height {height:,d} ' - f'tx_hash {hash_to_hex_str(tx_hash)}') - if n is None: - lines.append('No history found') - n = None - utxos = await db.all_utxos(hashX) - for n, utxo in enumerate(utxos, start=1): - lines.append(f'UTXO #{n:,d}: tx_hash ' - f'{hash_to_hex_str(utxo.tx_hash)} ' - f'tx_pos {utxo.tx_pos:,d} height ' - f'{utxo.height:,d} value {utxo.value:,d}') - if n == limit: - break - if n is None: - lines.append('No UTXOs found') - - balance = sum(utxo.value for utxo in utxos) - lines.append(f'Balance: {coin.decimal_value(balance):,f} ' - f'{coin.SHORTNAME}') - - return lines - - async def rpc_sessions(self): - """Return statistics about connected sessions.""" - return self._session_data(for_log=False) - - async def rpc_reorg(self, count): - """Force a reorg of the given number of blocks. - - count: number of blocks to reorg - """ - count = non_negative_integer(count) - if not self.bp.force_chain_reorg(count): - raise RPCError(BAD_REQUEST, 'still catching up with daemon') - return f'scheduled a reorg of {count:,d} blocks' - - # --- External Interface - - async def serve(self, notifications, server_listening_event): - """Start the RPC server if enabled. 
When the event is triggered, - start TCP and SSL servers.""" - try: - if self.env.rpc_port is not None: - await self._start_server('RPC', self.env.cs_host(for_rpc=True), - self.env.rpc_port) - self.logger.info(f'max session count: {self.env.max_sessions:,d}') - self.logger.info(f'session timeout: ' - f'{self.env.session_timeout:,d} seconds') - self.logger.info(f'max response size {self.env.max_send:,d} bytes') - if self.env.drop_client is not None: - self.logger.info(f'drop clients matching: {self.env.drop_client.pattern}') - # Start notifications; initialize hsub_results - await notifications.start(self.db.db_height, self._notify_sessions) - await self.start_other() - await self._start_external_servers() - server_listening_event.set() - # Peer discovery should start after the external servers - # because we connect to ourself - await asyncio.wait([ - self.peer_mgr.discover_peers(), - self._clear_stale_sessions(), - self._log_sessions(), - self._manage_servers() - ]) - finally: - await self._close_servers(list(self.servers.keys())) - if self.sessions: - await asyncio.wait([ - session.close(force_after=1) for session in self.sessions - ]) - await self.stop_other() - - async def start_other(self): - pass - - async def stop_other(self): - pass - - def session_count(self) -> int: - """The number of connections that we've sent something to.""" - return len(self.sessions) - - async def daemon_request(self, method, *args): - """Catch a DaemonError and convert it to an RPCError.""" - try: - return await getattr(self.daemon, method)(*args) - except DaemonError as e: - raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None - - async def raw_header(self, height): - """Return the binary header at the given height.""" - try: - return await self.db.raw_header(height) - except IndexError: - raise RPCError(BAD_REQUEST, f'height {height:,d} ' - 'out of range') from None - - async def electrum_header(self, height): - """Return the deserialized header at the given height.""" - electrum_header, _ = await self._electrum_and_raw_headers(height) - return electrum_header - - async def broadcast_transaction(self, raw_tx): - hex_hash = await self.daemon.broadcast_transaction(raw_tx) - self.mempool.wakeup.set() - self.txs_sent += 1 - return hex_hash - - async def limited_history(self, hashX): - """A caching layer.""" - hc = self.history_cache - if hashX not in hc: - # History DoS limit. Each element of history is about 99 - # bytes when encoded as JSON. This limits resource usage - # on bloated history requests, and uses a smaller divisor - # so large requests are logged before refusing them. 
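# (Worked example of the divisor below, assuming only what SessionManager.__init__
# above guarantees: env.max_send is floored at 350,000 bytes, so the smallest
# possible cap is 350_000 // 97 == 3_608 history items; a larger configured
# max_send raises the cap proportionally.)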
- limit = self.env.max_send // 97 - hc[hashX] = await self.db.limited_history(hashX, limit=limit) - return hc[hashX] - - async def _notify_sessions(self, height, touched): - """Notify sessions about height changes and touched addresses.""" - height_changed = height != self.notified_height - if height_changed: - await self._refresh_hsub_results(height) - # Invalidate our history cache for touched hashXs - hc = self.history_cache - for hashX in set(hc).intersection(touched): - del hc[hashX] - - if self.sessions: - await asyncio.wait([ - session.notify(touched, height_changed) for session in self.sessions - ]) - - def add_session(self, session): - self.sessions.add(session) - self.session_event.set() - gid = int(session.start_time - self.start_time) // 900 - if self.cur_group.gid != gid: - self.cur_group = SessionGroup(gid) - return self.cur_group - - def remove_session(self, session): - """Remove a session from our sessions list if there.""" - self.sessions.remove(session) - self.session_event.set() - - -class SessionBase(RPCSession): - """Base class of ElectrumX JSON sessions. - - Each session runs its tasks in asynchronous parallelism with other - sessions. - """ - - MAX_CHUNK_SIZE = 40960 - session_counter = itertools.count() - request_handlers: typing.Dict[str, typing.Callable] = {} - version = '0.5.7' - - def __init__(self, session_mgr, db, mempool, peer_mgr, kind): - connection = JSONRPCConnection(JSONRPCAutoDetect) - self.env = session_mgr.env - super().__init__(connection=connection) - self.logger = util.class_logger(__name__, self.__class__.__name__) - self.session_mgr = session_mgr - self.db = db - self.mempool = mempool - self.peer_mgr = peer_mgr - self.kind = kind # 'RPC', 'TCP' etc. - self.coin = self.env.coin - self.anon_logs = self.env.anon_logs - self.txs_sent = 0 - self.log_me = False - self.daemon_request = self.session_mgr.daemon_request - # Hijack the connection so we can log messages - self._receive_message_orig = self.connection.receive_message - self.connection.receive_message = self.receive_message - - async def notify(self, touched, height_changed): - pass - - def default_framer(self): - return NewlineFramer(self.env.max_receive) - - def peer_address_str(self, *, for_log=True): - """Returns the peer's IP address and port as a human-readable - string, respecting anon logs if the output is for a log.""" - if for_log and self.anon_logs: - return 'xx.xx.xx.xx:xx' - return super().peer_address_str() - - def receive_message(self, message): - if self.log_me: - self.logger.info(f'processing {message}') - return self._receive_message_orig(message) - - def toggle_logging(self): - self.log_me = not self.log_me - - def flags(self): - """Status flags.""" - status = self.kind[0] - if self.is_closing(): - status += 'C' - if self.log_me: - status += 'L' - status += str(self._concurrency.max_concurrent) - return status - - def connection_made(self, transport): - """Handle an incoming client connection.""" - super().connection_made(transport) - self.session_id = next(self.session_counter) - context = {'conn_id': f'{self.session_id}'} - self.logger = util.ConnectionLogger(self.logger, context) - self.group = self.session_mgr.add_session(self) - SESSIONS_COUNT.labels(version=self.client_version).inc() - peer_addr_str = self.peer_address_str() - self.logger.info(f'{self.kind} {peer_addr_str}, ' - f'{self.session_mgr.session_count():,d} total') - - def connection_lost(self, exc): - """Handle client disconnection.""" - super().connection_lost(exc) - self.session_mgr.remove_session(self) - 
SESSIONS_COUNT.labels(version=self.client_version).dec() - msg = '' - if not self._can_send.is_set(): - msg += ' whilst paused' - if self.send_size >= 1024*1024: - msg += ('. Sent {:,d} bytes in {:,d} messages' - .format(self.send_size, self.send_count)) - if msg: - msg = 'disconnected' + msg - self.logger.info(msg) - - def count_pending_items(self): - return len(self.connection.pending_requests()) - - def semaphore(self): - return Semaphores([self.group.semaphore]) - - def sub_count(self): - return 0 - - async def handle_request(self, request): - """Handle an incoming request. ElectrumX doesn't receive - notifications from client sessions. - """ - REQUESTS_COUNT.labels(method=request.method, version=self.client_version).inc() - if isinstance(request, Request): - handler = self.request_handlers.get(request.method) - handler = partial(handler, self) - else: - handler = None - coro = handler_invocation(handler, request)() - return await coro - - -class LBRYSessionManager(SessionManager): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.query_executor = None - self.websocket = None - self.metrics = ServerLoadData() - self.metrics_loop = None - self.running = False - if self.env.websocket_host is not None and self.env.websocket_port is not None: - self.websocket = AdminWebSocket(self) - self.search_cache = self.bp.search_cache - self.search_cache['search'] = lrucache(10000) - self.search_cache['resolve'] = lrucache(10000) - - async def process_metrics(self): - while self.running: - data = self.metrics.to_json_and_reset({ - 'sessions': self.session_count(), - 'height': self.db.db_height, - }) - if self.websocket is not None: - self.websocket.send_message(data) - await asyncio.sleep(1) - - async def start_other(self): - self.running = True - path = os.path.join(self.env.db_dir, 'claims.db') - args = dict( - initializer=reader.initializer, - initargs=( - self.logger, path, self.env.coin.NET, self.env.database_query_timeout, - self.env.track_metrics, ( - self.db.sql.blocked_streams, self.db.sql.blocked_channels, - self.db.sql.filtered_streams, self.db.sql.filtered_channels - ) - ) - ) - if self.env.max_query_workers is not None and self.env.max_query_workers == 0: - self.query_executor = ThreadPoolExecutor(max_workers=1, **args) - else: - self.query_executor = ProcessPoolExecutor( - max_workers=self.env.max_query_workers or max(os.cpu_count(), 4), **args - ) - if self.websocket is not None: - await self.websocket.start() - if self.env.track_metrics: - self.metrics_loop = asyncio.create_task(self.process_metrics()) - - async def stop_other(self): - self.running = False - if self.env.track_metrics: - self.metrics_loop.cancel() - if self.websocket is not None: - await self.websocket.stop() - self.query_executor.shutdown() - - -class LBRYElectrumX(SessionBase): - """A TCP server that handles incoming Electrum connections.""" - - PROTOCOL_MIN = VERSION.PROTOCOL_MIN - PROTOCOL_MAX = VERSION.PROTOCOL_MAX - max_errors = math.inf # don't disconnect people for errors! let them happen... 
- session_mgr: LBRYSessionManager - version = lbry.__version__ - - @classmethod - def initialize_request_handlers(cls): - cls.request_handlers.update({ - 'blockchain.block.get_chunk': cls.block_get_chunk, - 'blockchain.block.get_header': cls.block_get_header, - 'blockchain.estimatefee': cls.estimatefee, - 'blockchain.relayfee': cls.relayfee, - 'blockchain.scripthash.get_balance': cls.scripthash_get_balance, - 'blockchain.scripthash.get_history': cls.scripthash_get_history, - 'blockchain.scripthash.get_mempool': cls.scripthash_get_mempool, - 'blockchain.scripthash.listunspent': cls.scripthash_listunspent, - 'blockchain.scripthash.subscribe': cls.scripthash_subscribe, - 'blockchain.transaction.broadcast': cls.transaction_broadcast, - 'blockchain.transaction.get': cls.transaction_get, - 'blockchain.transaction.get_batch': cls.transaction_get_batch, - 'blockchain.transaction.info': cls.transaction_info, - 'blockchain.transaction.get_merkle': cls.transaction_merkle, - 'server.add_peer': cls.add_peer, - 'server.banner': cls.banner, - 'server.payment_address': cls.payment_address, - 'server.donation_address': cls.donation_address, - 'server.features': cls.server_features_async, - 'server.peers.subscribe': cls.peers_subscribe, - 'server.version': cls.server_version, - 'blockchain.transaction.get_height': cls.transaction_get_height, - 'blockchain.claimtrie.search': cls.claimtrie_search, - 'blockchain.claimtrie.resolve': cls.claimtrie_resolve, - 'blockchain.claimtrie.getclaimsbyids': cls.claimtrie_getclaimsbyids, - 'blockchain.block.get_server_height': cls.get_server_height, - 'mempool.get_fee_histogram': cls.mempool_compact_histogram, - 'blockchain.block.headers': cls.block_headers, - 'server.ping': cls.ping, - 'blockchain.headers.subscribe': cls.headers_subscribe_False, - 'blockchain.address.get_balance': cls.address_get_balance, - 'blockchain.address.get_history': cls.address_get_history, - 'blockchain.address.get_mempool': cls.address_get_mempool, - 'blockchain.address.listunspent': cls.address_listunspent, - 'blockchain.address.subscribe': cls.address_subscribe, - 'blockchain.address.unsubscribe': cls.address_unsubscribe, - }) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if not LBRYElectrumX.request_handlers: - LBRYElectrumX.initialize_request_handlers() - self.subscribe_headers = False - self.subscribe_headers_raw = False - self.connection.max_response_size = self.env.max_send - self.hashX_subs = {} - self.sv_seen = False - self.mempool_statuses = {} - self.protocol_tuple = self.PROTOCOL_MIN - - self.daemon = self.session_mgr.daemon - self.bp: LBRYBlockProcessor = self.session_mgr.bp - self.db: LBRYLevelDB = self.bp.db - - @classmethod - def protocol_min_max_strings(cls): - return [util.version_string(ver) - for ver in (cls.PROTOCOL_MIN, cls.PROTOCOL_MAX)] - - @classmethod - def server_features(cls, env): - """Return the server features dictionary.""" - min_str, max_str = cls.protocol_min_max_strings() - return { - 'hosts': env.hosts_dict(), - 'pruning': None, - 'server_version': cls.version, - 'protocol_min': min_str, - 'protocol_max': max_str, - 'genesis_hash': env.coin.GENESIS_HASH, - 'description': env.description, - 'payment_address': env.payment_address, - 'donation_address': env.donation_address, - 'daily_fee': env.daily_fee, - 'hash_function': 'sha256', - 'trending_algorithm': env.trending_algorithms[0] - } - - async def server_features_async(self): - return self.server_features(self.env) - - @classmethod - def server_version_args(cls): - """The 
arguments to a server.version RPC call to a peer.""" - return [cls.version, cls.protocol_min_max_strings()] - - def protocol_version_string(self): - return util.version_string(self.protocol_tuple) - - def sub_count(self): - return len(self.hashX_subs) - - async def notify(self, touched, height_changed): - """Notify the client about changes to touched addresses (from mempool - updates or new blocks) and height. - """ - if height_changed and self.subscribe_headers: - args = (await self.subscribe_headers_result(), ) - try: - await self.send_notification('blockchain.headers.subscribe', args) - except asyncio.TimeoutError: - self.logger.info("timeout sending headers notification to %s", self.peer_address_str(for_log=True)) - self.abort() - return - - touched = touched.intersection(self.hashX_subs) - if touched or (height_changed and self.mempool_statuses): - changed = {} - - for hashX in touched: - alias = self.hashX_subs[hashX] - status = await self.address_status(hashX) - changed[alias] = status - - # Check mempool hashXs - the status is a function of the - # confirmed state of other transactions. Note: we cannot - # iterate over mempool_statuses as it changes size. - for hashX in tuple(self.mempool_statuses): - # Items can be evicted whilst await-ing status; False - # ensures such hashXs are notified - old_status = self.mempool_statuses.get(hashX, False) - status = await self.address_status(hashX) - if status != old_status: - alias = self.hashX_subs[hashX] - changed[alias] = status - - for alias, status in changed.items(): - if len(alias) == 64: - method = 'blockchain.scripthash.subscribe' - else: - method = 'blockchain.address.subscribe' - - try: - await self.send_notification(method, (alias, status)) - except asyncio.TimeoutError: - self.logger.info("timeout sending address notification to %s", self.peer_address_str(for_log=True)) - self.abort() - return - - if changed: - es = '' if len(changed) == 1 else 'es' - self.logger.info(f'notified of {len(changed):,d} address{es}') - - def get_metrics_or_placeholder_for_api(self, query_name): - """ Do not hold on to a reference to the metrics - returned by this method past an `await` or - you may be working with a stale metrics object. 
- """ - if self.env.track_metrics: - return self.session_mgr.metrics.for_api(query_name) - else: - return APICallMetrics(query_name) - - async def run_in_executor(self, query_name, func, kwargs): - start = time.perf_counter() - try: - SQLITE_PENDING_COUNT.inc() - result = await asyncio.get_running_loop().run_in_executor( - self.session_mgr.query_executor, func, kwargs - ) - except asyncio.CancelledError: - raise - except reader.SQLiteInterruptedError as error: - metrics = self.get_metrics_or_placeholder_for_api(query_name) - metrics.query_interrupt(start, error.metrics) - SQLITE_INTERRUPT_COUNT.inc() - raise RPCError(JSONRPC.QUERY_TIMEOUT, 'sqlite query timed out') - except reader.SQLiteOperationalError as error: - metrics = self.get_metrics_or_placeholder_for_api(query_name) - metrics.query_error(start, error.metrics) - SQLITE_OPERATIONAL_ERROR_COUNT.inc() - raise RPCError(JSONRPC.INTERNAL_ERROR, 'query failed to execute') - except Exception: - log.exception("dear devs, please handle this exception better") - metrics = self.get_metrics_or_placeholder_for_api(query_name) - metrics.query_error(start, {}) - SQLITE_INTERNAL_ERROR_COUNT.inc() - raise RPCError(JSONRPC.INTERNAL_ERROR, 'unknown server error') - else: - if self.env.track_metrics: - metrics = self.get_metrics_or_placeholder_for_api(query_name) - (result, metrics_data) = result - metrics.query_response(start, metrics_data) - return base64.b64encode(result).decode() - finally: - SQLITE_PENDING_COUNT.dec() - SQLITE_EXECUTOR_TIMES.observe(time.perf_counter() - start) - - async def run_and_cache_query(self, query_name, function, kwargs): - metrics = self.get_metrics_or_placeholder_for_api(query_name) - metrics.start() - cache = self.session_mgr.search_cache[query_name] - cache_key = str(kwargs) - cache_item = cache.get(cache_key) - if cache_item is None: - cache_item = cache[cache_key] = ResultCacheItem() - elif cache_item.result is not None: - metrics.cache_response() - return cache_item.result - async with cache_item.lock: - if cache_item.result is None: - cache_item.result = await self.run_in_executor( - query_name, function, kwargs - ) - else: - metrics = self.get_metrics_or_placeholder_for_api(query_name) - metrics.cache_response() - return cache_item.result - - async def mempool_compact_histogram(self): - return self.mempool.compact_fee_histogram() - - async def claimtrie_search(self, **kwargs): - if kwargs: - return await self.run_and_cache_query('search', reader.search_to_bytes, kwargs) - - async def claimtrie_resolve(self, *urls): - if urls: - return await self.run_and_cache_query('resolve', reader.resolve_to_bytes, urls) - - async def get_server_height(self): - return self.bp.height - - async def transaction_get_height(self, tx_hash): - self.assert_tx_hash(tx_hash) - transaction_info = await self.daemon.getrawtransaction(tx_hash, True) - if transaction_info and 'hex' in transaction_info and 'confirmations' in transaction_info: - # an unconfirmed transaction from lbrycrdd will not have a 'confirmations' field - return (self.db.db_height - transaction_info['confirmations']) + 1 - elif transaction_info and 'hex' in transaction_info: - return -1 - return None - - async def claimtrie_getclaimsbyids(self, *claim_ids): - claims = await self.batched_formatted_claims_from_daemon(claim_ids) - return dict(zip(claim_ids, claims)) - - async def batched_formatted_claims_from_daemon(self, claim_ids): - claims = await self.daemon.getclaimsbyids(claim_ids) - result = [] - for claim in claims: - if claim and claim.get('value'): - 
result.append(self.format_claim_from_daemon(claim)) - return result - - def format_claim_from_daemon(self, claim, name=None): - """Changes the returned claim data to the format expected by lbry and adds missing fields.""" - - if not claim: - return {} - - # this ISO-8859 nonsense stems from a nasty form of encoding extended characters in lbrycrd - # it will be fixed after the lbrycrd upstream merge to v17 is done - # it originated as a fear of terminals not supporting unicode. alas, they all do - - if 'name' in claim: - name = claim['name'].encode('ISO-8859-1').decode() - info = self.db.sql.get_claims(claim_id=claim['claimId']) - if not info: - # raise RPCError("Lbrycrd has {} but not lbryumx, please submit a bug report.".format(claim_id)) - return {} - address = info.address.decode() - # fixme: temporary - #supports = self.format_supports_from_daemon(claim.get('supports', [])) - supports = [] - - amount = get_from_possible_keys(claim, 'amount', 'nAmount') - height = get_from_possible_keys(claim, 'height', 'nHeight') - effective_amount = get_from_possible_keys(claim, 'effective amount', 'nEffectiveAmount') - valid_at_height = get_from_possible_keys(claim, 'valid at height', 'nValidAtHeight') - - result = { - "name": name, - "claim_id": claim['claimId'], - "txid": claim['txid'], - "nout": claim['n'], - "amount": amount, - "depth": self.db.db_height - height + 1, - "height": height, - "value": hexlify(claim['value'].encode('ISO-8859-1')).decode(), - "address": address, # from index - "supports": supports, - "effective_amount": effective_amount, - "valid_at_height": valid_at_height - } - if 'claim_sequence' in claim: - # TODO: ensure that lbrycrd #209 fills in this value - result['claim_sequence'] = claim['claim_sequence'] - else: - result['claim_sequence'] = -1 - if 'normalized_name' in claim: - result['normalized_name'] = claim['normalized_name'].encode('ISO-8859-1').decode() - return result - - def assert_tx_hash(self, value): - '''Raise an RPCError if the value is not a valid transaction - hash.''' - try: - if len(util.hex_to_bytes(value)) == 32: - return - except Exception: - pass - raise RPCError(1, f'{value} should be a transaction hash') - - def assert_claim_id(self, value): - '''Raise an RPCError if the value is not a valid claim id - hash.''' - try: - if len(util.hex_to_bytes(value)) == 20: - return - except Exception: - pass - raise RPCError(1, f'{value} should be a claim id hash') - - async def subscribe_headers_result(self): - """The result of a header subscription or notification.""" - return self.session_mgr.hsub_results[self.subscribe_headers_raw] - - async def _headers_subscribe(self, raw): - """Subscribe to get headers of new blocks.""" - self.subscribe_headers_raw = assert_boolean(raw) - self.subscribe_headers = True - return await self.subscribe_headers_result() - - async def headers_subscribe(self): - """Subscribe to get raw headers of new blocks.""" - return await self._headers_subscribe(True) - - async def headers_subscribe_True(self, raw=True): - """Subscribe to get headers of new blocks.""" - return await self._headers_subscribe(raw) - - async def headers_subscribe_False(self, raw=False): - """Subscribe to get headers of new blocks.""" - return await self._headers_subscribe(raw) - - async def add_peer(self, features): - """Add a peer (but only if the peer resolves to the source).""" - return await self.peer_mgr.on_add_peer(features, self.peer_address()) - - async def peers_subscribe(self): - """Return the server peers as a list of (ip, host, details) tuples.""" - 
return self.peer_mgr.on_peers_subscribe(self.is_tor()) - - async def address_status(self, hashX): - """Returns an address status. - - Status is a hex string, but must be None if there is no history. - """ - # Note history is ordered and mempool unordered in electrum-server - # For mempool, height is -1 if it has unconfirmed inputs, otherwise 0 - db_history = await self.session_mgr.limited_history(hashX) - mempool = await self.mempool.transaction_summaries(hashX) - - status = ''.join(f'{hash_to_hex_str(tx_hash)}:' - f'{height:d}:' - for tx_hash, height in db_history) - status += ''.join(f'{hash_to_hex_str(tx.hash)}:' - f'{-tx.has_unconfirmed_inputs:d}:' - for tx in mempool) - if status: - status = sha256(status.encode()).hex() - else: - status = None - - if mempool: - self.mempool_statuses[hashX] = status - else: - self.mempool_statuses.pop(hashX, None) - - return status - - async def hashX_listunspent(self, hashX): - """Return the list of UTXOs of a script hash, including mempool - effects.""" - utxos = await self.db.all_utxos(hashX) - utxos = sorted(utxos) - utxos.extend(await self.mempool.unordered_UTXOs(hashX)) - spends = await self.mempool.potential_spends(hashX) - - return [{'tx_hash': hash_to_hex_str(utxo.tx_hash), - 'tx_pos': utxo.tx_pos, - 'height': utxo.height, 'value': utxo.value} - for utxo in utxos - if (utxo.tx_hash, utxo.tx_pos) not in spends] - - async def hashX_subscribe(self, hashX, alias): - self.hashX_subs[hashX] = alias - return await self.address_status(hashX) - - async def hashX_unsubscribe(self, hashX, alias): - del self.hashX_subs[hashX] - - def address_to_hashX(self, address): - try: - return self.coin.address_to_hashX(address) - except Exception: - pass - raise RPCError(BAD_REQUEST, f'{address} is not a valid address') - - async def address_get_balance(self, address): - """Return the confirmed and unconfirmed balance of an address.""" - hashX = self.address_to_hashX(address) - return await self.get_balance(hashX) - - async def address_get_history(self, address): - """Return the confirmed and unconfirmed history of an address.""" - hashX = self.address_to_hashX(address) - return await self.confirmed_and_unconfirmed_history(hashX) - - async def address_get_mempool(self, address): - """Return the mempool transactions touching an address.""" - hashX = self.address_to_hashX(address) - return await self.unconfirmed_history(hashX) - - async def address_listunspent(self, address): - """Return the list of UTXOs of an address.""" - hashX = self.address_to_hashX(address) - return await self.hashX_listunspent(hashX) - - async def address_subscribe(self, *addresses): - """Subscribe to an address. - - address: the address to subscribe to""" - if len(addresses) > 1000: - raise RPCError(BAD_REQUEST, f'too many addresses in subscription request: {len(addresses)}') - hashXes = [ - (self.address_to_hashX(address), address) for address in addresses - ] - return await asyncio.gather(*(self.hashX_subscribe(*args) for args in hashXes)) - - async def address_unsubscribe(self, address): - """Unsubscribe an address. 
- - address: the address to unsubscribe""" - hashX = self.address_to_hashX(address) - return await self.hashX_unsubscribe(hashX, address) - - async def get_balance(self, hashX): - utxos = await self.db.all_utxos(hashX) - confirmed = sum(utxo.value for utxo in utxos) - unconfirmed = await self.mempool.balance_delta(hashX) - return {'confirmed': confirmed, 'unconfirmed': unconfirmed} - - async def scripthash_get_balance(self, scripthash): - """Return the confirmed and unconfirmed balance of a scripthash.""" - hashX = scripthash_to_hashX(scripthash) - return await self.get_balance(hashX) - - async def unconfirmed_history(self, hashX): - # Note unconfirmed history is unordered in electrum-server - # height is -1 if it has unconfirmed inputs, otherwise 0 - return [{'tx_hash': hash_to_hex_str(tx.hash), - 'height': -tx.has_unconfirmed_inputs, - 'fee': tx.fee} - for tx in await self.mempool.transaction_summaries(hashX)] - - async def confirmed_and_unconfirmed_history(self, hashX): - # Note history is ordered but unconfirmed is unordered in e-s - history = await self.session_mgr.limited_history(hashX) - conf = [{'tx_hash': hash_to_hex_str(tx_hash), 'height': height} - for tx_hash, height in history] - return conf + await self.unconfirmed_history(hashX) - - async def scripthash_get_history(self, scripthash): - """Return the confirmed and unconfirmed history of a scripthash.""" - hashX = scripthash_to_hashX(scripthash) - return await self.confirmed_and_unconfirmed_history(hashX) - - async def scripthash_get_mempool(self, scripthash): - """Return the mempool transactions touching a scripthash.""" - hashX = scripthash_to_hashX(scripthash) - return await self.unconfirmed_history(hashX) - - async def scripthash_listunspent(self, scripthash): - """Return the list of UTXOs of a scripthash.""" - hashX = scripthash_to_hashX(scripthash) - return await self.hashX_listunspent(hashX) - - async def scripthash_subscribe(self, scripthash): - """Subscribe to a script hash. - - scripthash: the SHA256 hash of the script to subscribe to""" - hashX = scripthash_to_hashX(scripthash) - return await self.hashX_subscribe(hashX, scripthash) - - async def _merkle_proof(self, cp_height, height): - max_height = self.db.db_height - if not height <= cp_height <= max_height: - raise RPCError(BAD_REQUEST, - f'require header height {height:,d} <= ' - f'cp_height {cp_height:,d} <= ' - f'chain height {max_height:,d}') - branch, root = await self.db.header_branch_and_root(cp_height + 1, - height) - return { - 'branch': [hash_to_hex_str(elt) for elt in branch], - 'root': hash_to_hex_str(root), - } - - async def block_header(self, height, cp_height=0): - """Return a raw block header as a hexadecimal string, or as a - dictionary with a merkle proof.""" - height = non_negative_integer(height) - cp_height = non_negative_integer(cp_height) - raw_header_hex = (await self.session_mgr.raw_header(height)).hex() - if cp_height == 0: - return raw_header_hex - result = {'header': raw_header_hex} - result.update(await self._merkle_proof(cp_height, height)) - return result - - async def block_header_13(self, height): - """Return a raw block header as a hexadecimal string. - - height: the header's height""" - return await self.block_header(height) - - async def block_headers(self, start_height, count, cp_height=0, b64=False): - """Return count concatenated block headers as hex for the main chain; - starting at start_height. - - start_height and count must be non-negative integers. At most - MAX_CHUNK_SIZE headers will be returned. 
- """ - start_height = non_negative_integer(start_height) - count = non_negative_integer(count) - cp_height = non_negative_integer(cp_height) - - max_size = self.MAX_CHUNK_SIZE - count = min(count, max_size) - headers, count = await self.db.read_headers(start_height, count) - compressobj = zlib.compressobj(wbits=-15, level=1, memLevel=9) - headers = base64.b64encode(compressobj.compress(headers) + compressobj.flush()).decode() if b64 else headers.hex() - result = { - 'base64' if b64 else 'hex': headers, - 'count': count, - 'max': max_size - } - if count and cp_height: - last_height = start_height + count - 1 - result.update(await self._merkle_proof(cp_height, last_height)) - return result - - async def block_get_chunk(self, index): - """Return a chunk of block headers as a hexadecimal string. - - index: the chunk index""" - index = non_negative_integer(index) - size = self.coin.CHUNK_SIZE - start_height = index * size - headers, _ = await self.db.read_headers(start_height, size) - return headers.hex() - - async def block_get_header(self, height): - """The deserialized header at a given height. - - height: the header's height""" - height = non_negative_integer(height) - return await self.session_mgr.electrum_header(height) - - def is_tor(self): - """Try to detect if the connection is to a tor hidden service we are - running.""" - peername = self.peer_mgr.proxy_peername() - if not peername: - return False - peer_address = self.peer_address() - return peer_address and peer_address[0] == peername[0] - - async def replaced_banner(self, banner): - network_info = await self.daemon_request('getnetworkinfo') - ni_version = network_info['version'] - major, minor = divmod(ni_version, 1000000) - minor, revision = divmod(minor, 10000) - revision //= 100 - daemon_version = f'{major:d}.{minor:d}.{revision:d}' - for pair in [ - ('$SERVER_VERSION', self.version), - ('$DAEMON_VERSION', daemon_version), - ('$DAEMON_SUBVERSION', network_info['subversion']), - ('$PAYMENT_ADDRESS', self.env.payment_address), - ('$DONATION_ADDRESS', self.env.donation_address), - ]: - banner = banner.replace(*pair) - return banner - - async def payment_address(self): - """Return the payment address as a string, empty if there is none.""" - return self.env.payment_address - - async def donation_address(self): - """Return the donation address as a string, empty if there is none.""" - return self.env.donation_address - - async def banner(self): - """Return the server banner text.""" - banner = f'You are connected to an {self.version} server.' - - if self.is_tor(): - banner_file = self.env.tor_banner_file - else: - banner_file = self.env.banner_file - if banner_file: - try: - with codecs.open(banner_file, 'r', 'utf-8') as f: - banner = f.read() - except Exception as e: - self.logger.error(f'reading banner file {banner_file}: {e!r}') - else: - banner = await self.replaced_banner(banner) - - return banner - - async def relayfee(self): - """The minimum fee a low-priority tx must pay in order to be accepted - to the daemon's memory pool.""" - return await self.daemon_request('relayfee') - - async def estimatefee(self, number): - """The estimated transaction fee per kilobyte to be paid for a - transaction to be included within a certain number of blocks. - - number: the number of blocks - """ - number = non_negative_integer(number) - return await self.daemon_request('estimatefee', number) - - async def ping(self): - """Serves as a connection keep-alive mechanism and for the client to - confirm the server is still responding. 
- """ - return None - - async def server_version(self, client_name='', protocol_version=None): - """Returns the server version as a string. - - client_name: a string identifying the client - protocol_version: the protocol version spoken by the client - """ - - if self.sv_seen and self.protocol_tuple >= (1, 4): - raise RPCError(BAD_REQUEST, f'server.version already sent') - self.sv_seen = True - - if client_name: - client_name = str(client_name) - if self.env.drop_client is not None and \ - self.env.drop_client.match(client_name): - self.close_after_send = True - raise RPCError(BAD_REQUEST, - f'unsupported client: {client_name}') - if self.client_version != client_name[:17]: - SESSIONS_COUNT.labels(version=self.client_version).dec() - self.client_version = client_name[:17] - SESSIONS_COUNT.labels(version=self.client_version).inc() - CLIENT_VERSIONS.labels(version=self.client_version).inc() - - # Find the highest common protocol version. Disconnect if - # that protocol version in unsupported. - ptuple, client_min = util.protocol_version( - protocol_version, self.PROTOCOL_MIN, self.PROTOCOL_MAX) - if ptuple is None: - # FIXME: this fills the logs - # if client_min > self.PROTOCOL_MIN: - # self.logger.info(f'client requested future protocol version ' - # f'{util.version_string(client_min)} ' - # f'- is your software out of date?') - self.close_after_send = True - raise RPCError(BAD_REQUEST, - f'unsupported protocol version: {protocol_version}') - self.protocol_tuple = ptuple - return self.version, self.protocol_version_string() - - async def transaction_broadcast(self, raw_tx): - """Broadcast a raw transaction to the network. - - raw_tx: the raw transaction as a hexadecimal string""" - # This returns errors as JSON RPC errors, as is natural - try: - hex_hash = await self.session_mgr.broadcast_transaction(raw_tx) - self.txs_sent += 1 - self.logger.info(f'sent tx: {hex_hash}') - return hex_hash - except DaemonError as e: - error, = e.args - message = error['message'] - self.logger.info(f'error sending transaction: {message}') - raise RPCError(BAD_REQUEST, 'the transaction was rejected by ' - f'network rules.\n\n{message}\n[{raw_tx}]') - - async def transaction_info(self, tx_hash: str): - assert_tx_hash(tx_hash) - tx_info = await self.daemon_request('getrawtransaction', tx_hash, True) - raw_tx = tx_info['hex'] - block_hash = tx_info.get('blockhash') - if not block_hash: - return raw_tx, {'block_height': -1} - merkle_height = (await self.daemon_request('deserialised_block', block_hash))['height'] - merkle = await self.transaction_merkle(tx_hash, merkle_height) - return raw_tx, merkle - - async def transaction_get_batch(self, *tx_hashes): - if len(tx_hashes) > 100: - raise RPCError(BAD_REQUEST, f'too many tx hashes in request: {len(tx_hashes)}') - for tx_hash in tx_hashes: - assert_tx_hash(tx_hash) - batch_result = {} - height = None - block_hash = None - block = None - for tx_hash in tx_hashes: - tx_info = await self.daemon_request('getrawtransaction', tx_hash, True) - raw_tx = tx_info['hex'] - if height is None: - if 'blockhash' in tx_info: - block_hash = tx_info['blockhash'] - block = await self.daemon_request('deserialised_block', block_hash) - height = block['height'] - else: - height = -1 - if block_hash != tx_info.get('blockhash'): - raise RPCError(BAD_REQUEST, f'request contains a mix of transaction heights') - else: - if not block_hash: - merkle = {'block_height': -1} - else: - try: - pos = block['tx'].index(tx_hash) - except ValueError: - raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} 
not in ' - f'block {block_hash} at height {height:,d}') - merkle = { - "merkle": self._get_merkle_branch(block['tx'], pos), - "pos": pos - } - batch_result[tx_hash] = [raw_tx, merkle] - return batch_result - - async def transaction_get(self, tx_hash, verbose=False): - """Return the serialized raw transaction given its hash - - tx_hash: the transaction hash as a hexadecimal string - verbose: passed on to the daemon - """ - assert_tx_hash(tx_hash) - if verbose not in (True, False): - raise RPCError(BAD_REQUEST, f'"verbose" must be a boolean') - - return await self.daemon_request('getrawtransaction', tx_hash, verbose) - - async def _block_hash_and_tx_hashes(self, height): - """Returns a pair (block_hash, tx_hashes) for the main chain block at - the given height. - - block_hash is a hexadecimal string, and tx_hashes is an - ordered list of hexadecimal strings. - """ - height = non_negative_integer(height) - hex_hashes = await self.daemon_request('block_hex_hashes', height, 1) - block_hash = hex_hashes[0] - block = await self.daemon_request('deserialised_block', block_hash) - return block_hash, block['tx'] - - def _get_merkle_branch(self, tx_hashes, tx_pos): - """Return a merkle branch to a transaction. - - tx_hashes: ordered list of hex strings of tx hashes in a block - tx_pos: index of transaction in tx_hashes to create branch for - """ - hashes = [hex_str_to_hash(hash) for hash in tx_hashes] - branch, root = self.db.merkle.branch_and_root(hashes, tx_pos) - branch = [hash_to_hex_str(hash) for hash in branch] - return branch - - async def transaction_merkle(self, tx_hash, height): - """Return the markle branch to a confirmed transaction given its hash - and height. - - tx_hash: the transaction hash as a hexadecimal string - height: the height of the block it is in - """ - assert_tx_hash(tx_hash) - block_hash, tx_hashes = await self._block_hash_and_tx_hashes(height) - try: - pos = tx_hashes.index(tx_hash) - except ValueError: - raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in ' - f'block {block_hash} at height {height:,d}') - branch = self._get_merkle_branch(tx_hashes, pos) - return {"block_height": height, "merkle": branch, "pos": pos} - - async def transaction_id_from_pos(self, height, tx_pos, merkle=False): - """Return the txid and optionally a merkle proof, given - a block height and position in the block. 
- """ - tx_pos = non_negative_integer(tx_pos) - if merkle not in (True, False): - raise RPCError(BAD_REQUEST, f'"merkle" must be a boolean') - - block_hash, tx_hashes = await self._block_hash_and_tx_hashes(height) - try: - tx_hash = tx_hashes[tx_pos] - except IndexError: - raise RPCError(BAD_REQUEST, f'no tx at position {tx_pos:,d} in ' - f'block {block_hash} at height {height:,d}') - - if merkle: - branch = self._get_merkle_branch(tx_hashes, tx_pos) - return {"tx_hash": tx_hash, "merkle": branch} - else: - return tx_hash - - -class LocalRPC(SessionBase): - """A local TCP RPC server session.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.client = 'RPC' - self.connection._max_response_size = 0 - - def protocol_version_string(self): - return 'RPC' - - -class ResultCacheItem: - __slots__ = '_result', 'lock', 'has_result' - - def __init__(self): - self.has_result = asyncio.Event() - self.lock = asyncio.Lock() - self._result = None - - @property - def result(self) -> str: - return self._result - - @result.setter - def result(self, result: str): - self._result = result - if result is not None: - self.has_result.set() - - -def get_from_possible_keys(dictionary, *keys): - for key in keys: - if key in dictionary: - return dictionary[key] diff --git a/lbry/wallet/server/storage.py b/lbry/wallet/server/storage.py deleted file mode 100644 index 5e7db97dd..000000000 --- a/lbry/wallet/server/storage.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) 2016-2017, the ElectrumX authors -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - -"""Backend database abstraction.""" - -import os -from functools import partial - -from lbry.wallet.server import util - - -def db_class(db_dir, name): - """Returns a DB engine class.""" - for db_class in util.subclasses(Storage): - if db_class.__name__.lower() == name.lower(): - db_class.import_module() - return partial(db_class, db_dir) - raise RuntimeError(f'unrecognised DB engine "{name}"') - - -class Storage: - """Abstract base class of the DB backend abstraction.""" - - def __init__(self, db_dir, name, for_sync): - self.db_dir = db_dir - self.is_new = not os.path.exists(os.path.join(db_dir, name)) - self.for_sync = for_sync or self.is_new - self.open(name, create=self.is_new) - - @classmethod - def import_module(cls): - """Import the DB engine module.""" - raise NotImplementedError - - def open(self, name, create): - """Open an existing database or create a new one.""" - raise NotImplementedError - - def close(self): - """Close an existing database.""" - raise NotImplementedError - - def get(self, key): - raise NotImplementedError - - def put(self, key, value): - raise NotImplementedError - - def write_batch(self): - """Return a context manager that provides `put` and `delete`. - - Changes should only be committed when the context manager - closes without an exception. - """ - raise NotImplementedError - - def iterator(self, prefix=b'', reverse=False): - """Return an iterator that yields (key, value) pairs from the - database sorted by key. - - If `prefix` is set, only keys starting with `prefix` will be - included. If `reverse` is True the items are returned in - reverse order. 
- """ - raise NotImplementedError - - -class LevelDB(Storage): - """LevelDB database engine.""" - - @classmethod - def import_module(cls): - import plyvel - cls.module = plyvel - - def open(self, name, create): - mof = 512 if self.for_sync else 128 - path = os.path.join(self.db_dir, name) - # Use snappy compression (the default) - self.db = self.module.DB(path, create_if_missing=create, - max_open_files=mof) - self.close = self.db.close - self.get = self.db.get - self.put = self.db.put - self.iterator = self.db.iterator - self.write_batch = partial(self.db.write_batch, transaction=True, - sync=True) - - -class RocksDB(Storage): - """RocksDB database engine.""" - - @classmethod - def import_module(cls): - import rocksdb - cls.module = rocksdb - - def open(self, name, create): - mof = 512 if self.for_sync else 128 - path = os.path.join(self.db_dir, name) - # Use snappy compression (the default) - options = self.module.Options(create_if_missing=create, - use_fsync=True, - target_file_size_base=33554432, - max_open_files=mof) - self.db = self.module.DB(path, options) - self.get = self.db.get - self.put = self.db.put - - def close(self): - # PyRocksDB doesn't provide a close method; hopefully this is enough - self.db = self.get = self.put = None - import gc - gc.collect() - - def write_batch(self): - return RocksDBWriteBatch(self.db) - - def iterator(self, prefix=b'', reverse=False): - return RocksDBIterator(self.db, prefix, reverse) - - -class RocksDBWriteBatch: - """A write batch for RocksDB.""" - - def __init__(self, db): - self.batch = RocksDB.module.WriteBatch() - self.db = db - - def __enter__(self): - return self.batch - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_val: - self.db.write(self.batch) - - -class RocksDBIterator: - """An iterator for RocksDB.""" - - def __init__(self, db, prefix, reverse): - self.prefix = prefix - if reverse: - self.iterator = reversed(db.iteritems()) - nxt_prefix = util.increment_byte_string(prefix) - if nxt_prefix: - self.iterator.seek(nxt_prefix) - try: - next(self.iterator) - except StopIteration: - self.iterator.seek(nxt_prefix) - else: - self.iterator.seek_to_last() - else: - self.iterator = db.iteritems() - self.iterator.seek(prefix) - - def __iter__(self): - return self - - def __next__(self): - k, v = next(self.iterator) - if not k.startswith(self.prefix): - raise StopIteration - return k, v diff --git a/lbry/wallet/server/text.py b/lbry/wallet/server/text.py deleted file mode 100644 index 4919b0c01..000000000 --- a/lbry/wallet/server/text.py +++ /dev/null @@ -1,82 +0,0 @@ -import time - -from lbry.wallet.server import util - - -def sessions_lines(data): - """A generator returning lines for a list of sessions. - - data is the return value of rpc_sessions().""" - fmt = ('{:<6} {:<5} {:>17} {:>5} {:>5} {:>5} ' - '{:>7} {:>7} {:>7} {:>7} {:>7} {:>9} {:>21}') - yield fmt.format('ID', 'Flags', 'Client', 'Proto', - 'Reqs', 'Txs', 'Subs', - 'Recv', 'Recv KB', 'Sent', 'Sent KB', 'Time', 'Peer') - for (id_, flags, peer, client, proto, reqs, txs_sent, subs, - recv_count, recv_size, send_count, send_size, time) in data: - yield fmt.format(id_, flags, client, proto, - f'{reqs:,d}', - f'{txs_sent:,d}', - f'{subs:,d}', - f'{recv_count:,d}', - '{:,d}'.format(recv_size // 1024), - f'{send_count:,d}', - '{:,d}'.format(send_size // 1024), - util.formatted_time(time, sep=''), peer) - - -def groups_lines(data): - """A generator returning lines for a list of groups. 
- - data is the return value of rpc_groups().""" - - fmt = ('{:<6} {:>9} {:>9} {:>6} {:>6} {:>8}' - '{:>7} {:>9} {:>7} {:>9}') - yield fmt.format('ID', 'Sessions', 'Bwidth KB', 'Reqs', 'Txs', 'Subs', - 'Recv', 'Recv KB', 'Sent', 'Sent KB') - for (id_, session_count, bandwidth, reqs, txs_sent, subs, - recv_count, recv_size, send_count, send_size) in data: - yield fmt.format(id_, - f'{session_count:,d}', - '{:,d}'.format(bandwidth // 1024), - f'{reqs:,d}', - f'{txs_sent:,d}', - f'{subs:,d}', - f'{recv_count:,d}', - '{:,d}'.format(recv_size // 1024), - f'{send_count:,d}', - '{:,d}'.format(send_size // 1024)) - - -def peers_lines(data): - """A generator returning lines for a list of peers. - - data is the return value of rpc_peers().""" - def time_fmt(t): - if not t: - return 'Never' - return util.formatted_time(now - t) - - now = time.time() - fmt = ('{:<30} {:<6} {:>5} {:>5} {:<17} {:>4} ' - '{:>4} {:>8} {:>11} {:>11} {:>5} {:>20} {:<15}') - yield fmt.format('Host', 'Status', 'TCP', 'SSL', 'Server', 'Min', - 'Max', 'Pruning', 'Last Good', 'Last Try', - 'Tries', 'Source', 'IP Address') - for item in data: - features = item['features'] - hostname = item['host'] - host = features['hosts'][hostname] - yield fmt.format(hostname[:30], - item['status'], - host.get('tcp_port') or '', - host.get('ssl_port') or '', - features['server_version'] or 'unknown', - features['protocol_min'], - features['protocol_max'], - features['pruning'] or '', - time_fmt(item['last_good']), - time_fmt(item['last_try']), - item['try_count'], - item['source'][:20], - item['ip_addr'] or '') diff --git a/lbry/wallet/server/tx.py b/lbry/wallet/server/tx.py deleted file mode 100644 index 411162155..000000000 --- a/lbry/wallet/server/tx.py +++ /dev/null @@ -1,615 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# Copyright (c) 2017, the ElectrumX authors -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# and warranty status of this software. 
- -"""Transaction-related classes and functions.""" - -from collections import namedtuple - -from lbry.wallet.server.hash import sha256, double_sha256, hash_to_hex_str -from lbry.wallet.server.script import OpCodes -from lbry.wallet.server.util import ( - unpack_le_int32_from, unpack_le_int64_from, unpack_le_uint16_from, - unpack_le_uint32_from, unpack_le_uint64_from, pack_le_int32, pack_varint, - pack_le_uint32, pack_le_int64, pack_varbytes, -) - -ZERO = bytes(32) -MINUS_1 = 4294967295 - - -class Tx(namedtuple("Tx", "version inputs outputs locktime raw")): - """Class representing a transaction.""" - - -class TxInput(namedtuple("TxInput", "prev_hash prev_idx script sequence")): - """Class representing a transaction input.""" - def __str__(self): - script = self.script.hex() - prev_hash = hash_to_hex_str(self.prev_hash) - return (f"Input({prev_hash}, {self.prev_idx:d}, script={script}, sequence={self.sequence:d})") - - def is_generation(self): - """Test if an input is generation/coinbase like""" - return self.prev_idx == MINUS_1 and self.prev_hash == ZERO - - def serialize(self): - return b''.join(( - self.prev_hash, - pack_le_uint32(self.prev_idx), - pack_varbytes(self.script), - pack_le_uint32(self.sequence), - )) - - -class TxOutput(namedtuple("TxOutput", "value pk_script")): - - def serialize(self): - return b''.join(( - pack_le_int64(self.value), - pack_varbytes(self.pk_script), - )) - - -class Deserializer: - """Deserializes blocks into transactions. - - External entry points are read_tx(), read_tx_and_hash(), - read_tx_and_vsize() and read_block(). - - This code is performance sensitive as it is executed 100s of - millions of times during sync. - """ - - TX_HASH_FN = staticmethod(double_sha256) - - def __init__(self, binary, start=0): - assert isinstance(binary, bytes) - self.binary = binary - self.binary_length = len(binary) - self.cursor = start - self.flags = 0 - - def read_tx(self): - """Return a deserialized transaction.""" - start = self.cursor - return Tx( - self._read_le_int32(), # version - self._read_inputs(), # inputs - self._read_outputs(), # outputs - self._read_le_uint32(), # locktime - self.binary[start:self.cursor], - ) - - def read_tx_and_hash(self): - """Return a (deserialized TX, tx_hash) pair. - - The hash needs to be reversed for human display; for efficiency - we process it in the natural serialized order. 
- """ - start = self.cursor - return self.read_tx(), self.TX_HASH_FN(self.binary[start:self.cursor]) - - def read_tx_and_vsize(self): - """Return a (deserialized TX, vsize) pair.""" - return self.read_tx(), self.binary_length - - def read_tx_block(self): - """Returns a list of (deserialized_tx, tx_hash) pairs.""" - read = self.read_tx_and_hash - # Some coins have excess data beyond the end of the transactions - return [read() for _ in range(self._read_varint())] - - def _read_inputs(self): - read_input = self._read_input - return [read_input() for i in range(self._read_varint())] - - def _read_input(self): - return TxInput( - self._read_nbytes(32), # prev_hash - self._read_le_uint32(), # prev_idx - self._read_varbytes(), # script - self._read_le_uint32() # sequence - ) - - def _read_outputs(self): - read_output = self._read_output - return [read_output() for i in range(self._read_varint())] - - def _read_output(self): - return TxOutput( - self._read_le_int64(), # value - self._read_varbytes(), # pk_script - ) - - def _read_byte(self): - cursor = self.cursor - self.cursor += 1 - return self.binary[cursor] - - def _read_nbytes(self, n): - cursor = self.cursor - self.cursor = end = cursor + n - assert self.binary_length >= end - return self.binary[cursor:end] - - def _read_varbytes(self): - return self._read_nbytes(self._read_varint()) - - def _read_varint(self): - n = self.binary[self.cursor] - self.cursor += 1 - if n < 253: - return n - if n == 253: - return self._read_le_uint16() - if n == 254: - return self._read_le_uint32() - return self._read_le_uint64() - - def _read_le_int32(self): - result, = unpack_le_int32_from(self.binary, self.cursor) - self.cursor += 4 - return result - - def _read_le_int64(self): - result, = unpack_le_int64_from(self.binary, self.cursor) - self.cursor += 8 - return result - - def _read_le_uint16(self): - result, = unpack_le_uint16_from(self.binary, self.cursor) - self.cursor += 2 - return result - - def _read_le_uint32(self): - result, = unpack_le_uint32_from(self.binary, self.cursor) - self.cursor += 4 - return result - - def _read_le_uint64(self): - result, = unpack_le_uint64_from(self.binary, self.cursor) - self.cursor += 8 - return result - - -class TxSegWit(namedtuple("Tx", "version marker flag inputs outputs " - "witness locktime raw")): - """Class representing a SegWit transaction.""" - - -class DeserializerSegWit(Deserializer): - - # https://bitcoincore.org/en/segwit_wallet_dev/#transaction-serialization - - def _read_witness(self, fields): - read_witness_field = self._read_witness_field - return [read_witness_field() for i in range(fields)] - - def _read_witness_field(self): - read_varbytes = self._read_varbytes - return [read_varbytes() for i in range(self._read_varint())] - - def _read_tx_parts(self): - """Return a (deserialized TX, tx_hash, vsize) tuple.""" - start = self.cursor - marker = self.binary[self.cursor + 4] - if marker: - tx = super().read_tx() - tx_hash = self.TX_HASH_FN(self.binary[start:self.cursor]) - return tx, tx_hash, self.binary_length - - # Ugh, this is nasty. 
- version = self._read_le_int32() - orig_ser = self.binary[start:self.cursor] - - marker = self._read_byte() - flag = self._read_byte() - - start = self.cursor - inputs = self._read_inputs() - outputs = self._read_outputs() - orig_ser += self.binary[start:self.cursor] - - base_size = self.cursor - start - witness = self._read_witness(len(inputs)) - - start = self.cursor - locktime = self._read_le_uint32() - orig_ser += self.binary[start:self.cursor] - vsize = (3 * base_size + self.binary_length) // 4 - - return TxSegWit(version, marker, flag, inputs, outputs, witness, - locktime, orig_ser), self.TX_HASH_FN(orig_ser), vsize - - def read_tx(self): - return self._read_tx_parts()[0] - - def read_tx_and_hash(self): - tx, tx_hash, vsize = self._read_tx_parts() - return tx, tx_hash - - def read_tx_and_vsize(self): - tx, tx_hash, vsize = self._read_tx_parts() - return tx, vsize - - -class DeserializerAuxPow(Deserializer): - VERSION_AUXPOW = (1 << 8) - - def read_header(self, height, static_header_size): - """Return the AuxPow block header bytes""" - start = self.cursor - version = self._read_le_uint32() - if version & self.VERSION_AUXPOW: - # We are going to calculate the block size then read it as bytes - self.cursor = start - self.cursor += static_header_size # Block normal header - self.read_tx() # AuxPow transaction - self.cursor += 32 # Parent block hash - merkle_size = self._read_varint() - self.cursor += 32 * merkle_size # Merkle branch - self.cursor += 4 # Index - merkle_size = self._read_varint() - self.cursor += 32 * merkle_size # Chain merkle branch - self.cursor += 4 # Chain index - self.cursor += 80 # Parent block header - header_end = self.cursor - else: - header_end = static_header_size - self.cursor = start - return self._read_nbytes(header_end) - - -class DeserializerAuxPowSegWit(DeserializerSegWit, DeserializerAuxPow): - pass - - -class DeserializerEquihash(Deserializer): - def read_header(self, height, static_header_size): - """Return the block header bytes""" - start = self.cursor - # We are going to calculate the block size then read it as bytes - self.cursor += static_header_size - solution_size = self._read_varint() - self.cursor += solution_size - header_end = self.cursor - self.cursor = start - return self._read_nbytes(header_end) - - -class DeserializerEquihashSegWit(DeserializerSegWit, DeserializerEquihash): - pass - - -class TxJoinSplit(namedtuple("Tx", "version inputs outputs locktime")): - """Class representing a JoinSplit transaction.""" - - -class DeserializerZcash(DeserializerEquihash): - def read_tx(self): - header = self._read_le_uint32() - overwintered = ((header >> 31) == 1) - if overwintered: - version = header & 0x7fffffff - self.cursor += 4 # versionGroupId - else: - version = header - - is_overwinter_v3 = version == 3 - is_sapling_v4 = version == 4 - - base_tx = TxJoinSplit( - version, - self._read_inputs(), # inputs - self._read_outputs(), # outputs - self._read_le_uint32() # locktime - ) - - if is_overwinter_v3 or is_sapling_v4: - self.cursor += 4 # expiryHeight - - has_shielded = False - if is_sapling_v4: - self.cursor += 8 # valueBalance - shielded_spend_size = self._read_varint() - self.cursor += shielded_spend_size * 384 # vShieldedSpend - shielded_output_size = self._read_varint() - self.cursor += shielded_output_size * 948 # vShieldedOutput - has_shielded = shielded_spend_size > 0 or shielded_output_size > 0 - - if base_tx.version >= 2: - joinsplit_size = self._read_varint() - if joinsplit_size > 0: - joinsplit_desc_len = 1506 + (192 if 
is_sapling_v4 else 296) - # JSDescription - self.cursor += joinsplit_size * joinsplit_desc_len - self.cursor += 32 # joinSplitPubKey - self.cursor += 64 # joinSplitSig - - if is_sapling_v4 and has_shielded: - self.cursor += 64 # bindingSig - - return base_tx - - -class TxTime(namedtuple("Tx", "version time inputs outputs locktime")): - """Class representing transaction that has a time field.""" - - -class DeserializerTxTime(Deserializer): - def read_tx(self): - return TxTime( - self._read_le_int32(), # version - self._read_le_uint32(), # time - self._read_inputs(), # inputs - self._read_outputs(), # outputs - self._read_le_uint32(), # locktime - ) - - -class DeserializerReddcoin(Deserializer): - def read_tx(self): - version = self._read_le_int32() - inputs = self._read_inputs() - outputs = self._read_outputs() - locktime = self._read_le_uint32() - if version > 1: - time = self._read_le_uint32() - else: - time = 0 - - return TxTime(version, time, inputs, outputs, locktime) - - -class DeserializerTxTimeAuxPow(DeserializerTxTime): - VERSION_AUXPOW = (1 << 8) - - def is_merged_block(self): - start = self.cursor - self.cursor = 0 - version = self._read_le_uint32() - self.cursor = start - if version & self.VERSION_AUXPOW: - return True - return False - - def read_header(self, height, static_header_size): - """Return the AuxPow block header bytes""" - start = self.cursor - version = self._read_le_uint32() - if version & self.VERSION_AUXPOW: - # We are going to calculate the block size then read it as bytes - self.cursor = start - self.cursor += static_header_size # Block normal header - self.read_tx() # AuxPow transaction - self.cursor += 32 # Parent block hash - merkle_size = self._read_varint() - self.cursor += 32 * merkle_size # Merkle branch - self.cursor += 4 # Index - merkle_size = self._read_varint() - self.cursor += 32 * merkle_size # Chain merkle branch - self.cursor += 4 # Chain index - self.cursor += 80 # Parent block header - header_end = self.cursor - else: - header_end = static_header_size - self.cursor = start - return self._read_nbytes(header_end) - - -class DeserializerBitcoinAtom(DeserializerSegWit): - FORK_BLOCK_HEIGHT = 505888 - - def read_header(self, height, static_header_size): - """Return the block header bytes""" - header_len = static_header_size - if height >= self.FORK_BLOCK_HEIGHT: - header_len += 4 # flags - return self._read_nbytes(header_len) - - -class DeserializerGroestlcoin(DeserializerSegWit): - TX_HASH_FN = staticmethod(sha256) - - -class TxInputTokenPay(TxInput): - """Class representing a TokenPay transaction input.""" - - OP_ANON_MARKER = 0xb9 - # 2byte marker (cpubkey + sigc + sigr) - MIN_ANON_IN_SIZE = 2 + (33 + 32 + 32) - - def _is_anon_input(self): - return (len(self.script) >= self.MIN_ANON_IN_SIZE and - self.script[0] == OpCodes.OP_RETURN and - self.script[1] == self.OP_ANON_MARKER) - - def is_generation(self): - # Transactions coming in from stealth addresses are seen by - # the blockchain as newly minted coins. The reverse, where coins - # are sent TO a stealth address, are seen by the blockchain as - # a coin burn. 
- if self._is_anon_input(): - return True - return super().is_generation() - - -class TxInputTokenPayStealth( - namedtuple("TxInput", "keyimage ringsize script sequence")): - """Class representing a TokenPay stealth transaction input.""" - - def __str__(self): - script = self.script.hex() - keyimage = bytes(self.keyimage).hex() - return (f"Input({keyimage}, {self.ringsize[1]:d}, script={script}, sequence={self.sequence:d})") - - def is_generation(self): - return True - - def serialize(self): - return b''.join(( - self.keyimage, - self.ringsize, - pack_varbytes(self.script), - pack_le_uint32(self.sequence), - )) - - -class DeserializerTokenPay(DeserializerTxTime): - - def _read_input(self): - txin = TxInputTokenPay( - self._read_nbytes(32), # prev_hash - self._read_le_uint32(), # prev_idx - self._read_varbytes(), # script - self._read_le_uint32(), # sequence - ) - if txin._is_anon_input(): - # Not sure if this is actually needed, and seems - # extra work for no immediate benefit, but it at - # least correctly represents a stealth input - raw = txin.serialize() - deserializer = Deserializer(raw) - txin = TxInputTokenPayStealth( - deserializer._read_nbytes(33), # keyimage - deserializer._read_nbytes(3), # ringsize - deserializer._read_varbytes(), # script - deserializer._read_le_uint32() # sequence - ) - return txin - - -# Decred -class TxInputDcr(namedtuple("TxInput", "prev_hash prev_idx tree sequence")): - """Class representing a Decred transaction input.""" - - def __str__(self): - prev_hash = hash_to_hex_str(self.prev_hash) - return (f"Input({prev_hash}, {self.prev_idx:d}, tree={self.tree}, sequence={self.sequence:d})") - - def is_generation(self): - """Test if an input is generation/coinbase like""" - return self.prev_idx == MINUS_1 and self.prev_hash == ZERO - - -class TxOutputDcr(namedtuple("TxOutput", "value version pk_script")): - """Class representing a Decred transaction output.""" - pass - - -class TxDcr(namedtuple("Tx", "version inputs outputs locktime expiry " - "witness")): - """Class representing a Decred transaction.""" - - -class DeserializerDecred(Deserializer): - @staticmethod - def blake256(data): - from blake256.blake256 import blake_hash - return blake_hash(data) - - @staticmethod - def blake256d(data): - from blake256.blake256 import blake_hash - return blake_hash(blake_hash(data)) - - def read_tx(self): - return self._read_tx_parts(produce_hash=False)[0] - - def read_tx_and_hash(self): - tx, tx_hash, vsize = self._read_tx_parts() - return tx, tx_hash - - def read_tx_and_vsize(self): - tx, tx_hash, vsize = self._read_tx_parts(produce_hash=False) - return tx, vsize - - def read_tx_block(self): - """Returns a list of (deserialized_tx, tx_hash) pairs.""" - read = self.read_tx_and_hash - txs = [read() for _ in range(self._read_varint())] - stxs = [read() for _ in range(self._read_varint())] - return txs + stxs - - def read_tx_tree(self): - """Returns a list of deserialized_tx without tx hashes.""" - read_tx = self.read_tx - return [read_tx() for _ in range(self._read_varint())] - - def _read_input(self): - return TxInputDcr( - self._read_nbytes(32), # prev_hash - self._read_le_uint32(), # prev_idx - self._read_byte(), # tree - self._read_le_uint32(), # sequence - ) - - def _read_output(self): - return TxOutputDcr( - self._read_le_int64(), # value - self._read_le_uint16(), # version - self._read_varbytes(), # pk_script - ) - - def _read_witness(self, fields): - read_witness_field = self._read_witness_field - assert fields == self._read_varint() - return 
[read_witness_field() for _ in range(fields)] - - def _read_witness_field(self): - value_in = self._read_le_int64() - block_height = self._read_le_uint32() - block_index = self._read_le_uint32() - script = self._read_varbytes() - return value_in, block_height, block_index, script - - def _read_tx_parts(self, produce_hash=True): - start = self.cursor - version = self._read_le_int32() - inputs = self._read_inputs() - outputs = self._read_outputs() - locktime = self._read_le_uint32() - expiry = self._read_le_uint32() - end_prefix = self.cursor - witness = self._read_witness(len(inputs)) - - if produce_hash: - # TxSerializeNoWitness << 16 == 0x10000 - no_witness_header = pack_le_uint32(0x10000 | (version & 0xffff)) - prefix_tx = no_witness_header + self.binary[start+4:end_prefix] - tx_hash = self.blake256(prefix_tx) - else: - tx_hash = None - - return TxDcr( - version, - inputs, - outputs, - locktime, - expiry, - witness - ), tx_hash, self.cursor - start diff --git a/lbry/wallet/server/util.py b/lbry/wallet/server/util.py deleted file mode 100644 index 915a6975c..000000000 --- a/lbry/wallet/server/util.py +++ /dev/null @@ -1,359 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# and warranty status of this software. - -"""Miscellaneous utility classes and functions.""" - - -import array -import inspect -from ipaddress import ip_address -import logging -import re -import sys -from collections import Container, Mapping -from struct import pack, Struct - -# Logging utilities - - -class ConnectionLogger(logging.LoggerAdapter): - """Prepends a connection identifier to a logging message.""" - def process(self, msg, kwargs): - conn_id = self.extra.get('conn_id', 'unknown') - return f'[{conn_id}] {msg}', kwargs - - -class CompactFormatter(logging.Formatter): - """Strips the module from the logger name to leave the class only.""" - def format(self, record): - record.name = record.name.rpartition('.')[-1] - return super().format(record) - - -def make_logger(name, *, handler, level): - """Return the root ElectrumX logger.""" - logger = logging.getLogger(name) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - logger.propagate = False - return logger - - -def class_logger(path, classname): - """Return a hierarchical logger for a class.""" - return logging.getLogger(path).getChild(classname) - - -# Method decorator. 
To be used for calculations that will always -# deliver the same result. The method cannot take any arguments -# and should be accessed as an attribute. -class cachedproperty: - - def __init__(self, f): - self.f = f - - def __get__(self, obj, type): - obj = obj or type - value = self.f(obj) - setattr(obj, self.f.__name__, value) - return value - - -def formatted_time(t, sep=' '): - """Return a number of seconds as a string in days, hours, mins and - maybe secs.""" - t = int(t) - fmts = (('{:d}d', 86400), ('{:02d}h', 3600), ('{:02d}m', 60)) - parts = [] - for fmt, n in fmts: - val = t // n - if parts or val: - parts.append(fmt.format(val)) - t %= n - if len(parts) < 3: - parts.append(f'{t:02d}s') - return sep.join(parts) - - -def deep_getsizeof(obj): - """Find the memory footprint of a Python object. - - Based on code from code.tutsplus.com: http://goo.gl/fZ0DXK - - This is a recursive function that drills down a Python object graph - like a dictionary holding nested dictionaries with lists of lists - and tuples and sets. - - The sys.getsizeof function does a shallow size of only. It counts each - object inside a container as pointer only regardless of how big it - really is. - """ - - ids = set() - - def size(o): - if id(o) in ids: - return 0 - - r = sys.getsizeof(o) - ids.add(id(o)) - - if isinstance(o, (str, bytes, bytearray, array.array)): - return r - - if isinstance(o, Mapping): - return r + sum(size(k) + size(v) for k, v in o.items()) - - if isinstance(o, Container): - return r + sum(size(x) for x in o) - - return r - - return size(obj) - - -def subclasses(base_class, strict=True): - """Return a list of subclasses of base_class in its module.""" - def select(obj): - return (inspect.isclass(obj) and issubclass(obj, base_class) and - (not strict or obj != base_class)) - - pairs = inspect.getmembers(sys.modules[base_class.__module__], select) - return [pair[1] for pair in pairs] - - -def chunks(items, size): - """Break up items, an iterable, into chunks of length size.""" - for i in range(0, len(items), size): - yield items[i: i + size] - - -def resolve_limit(limit): - if limit is None: - return -1 - assert isinstance(limit, int) and limit >= 0 - return limit - - -def bytes_to_int(be_bytes): - """Interprets a big-endian sequence of bytes as an integer""" - return int.from_bytes(be_bytes, 'big') - - -def int_to_bytes(value): - """Converts an integer to a big-endian sequence of bytes""" - return value.to_bytes((value.bit_length() + 7) // 8, 'big') - - -def increment_byte_string(bs): - """Return the lexicographically next byte string of the same length. - - Return None if there is none (when the input is all 0xff bytes).""" - for n in range(1, len(bs) + 1): - if bs[-n] != 0xff: - return bs[:-n] + bytes([bs[-n] + 1]) + bytes(n - 1) - return None - - -class LogicalFile: - """A logical binary file split across several separate files on disk.""" - - def __init__(self, prefix, digits, file_size): - digit_fmt = f'{{:0{digits:d}d}}' - self.filename_fmt = prefix + digit_fmt - self.file_size = file_size - - def read(self, start, size=-1): - """Read up to size bytes from the virtual file, starting at offset - start, and return them. 
- - If size is -1 all bytes are read.""" - parts = [] - while size != 0: - try: - with self.open_file(start, False) as f: - part = f.read(size) - if not part: - break - except FileNotFoundError: - break - parts.append(part) - start += len(part) - if size > 0: - size -= len(part) - return b''.join(parts) - - def write(self, start, b): - """Write the bytes-like object, b, to the underlying virtual file.""" - while b: - size = min(len(b), self.file_size - (start % self.file_size)) - with self.open_file(start, True) as f: - f.write(b if size == len(b) else b[:size]) - b = b[size:] - start += size - - def open_file(self, start, create): - """Open the virtual file and seek to start. Return a file handle. - Raise FileNotFoundError if the file does not exist and create - is False. - """ - file_num, offset = divmod(start, self.file_size) - filename = self.filename_fmt.format(file_num) - f = open_file(filename, create) - f.seek(offset) - return f - - -def open_file(filename, create=False): - """Open the file name. Return its handle.""" - try: - return open(filename, 'rb+') - except FileNotFoundError: - if create: - return open(filename, 'wb+') - raise - - -def open_truncate(filename): - """Open the file name. Return its handle.""" - return open(filename, 'wb+') - - -def address_string(address): - """Return an address as a correctly formatted string.""" - fmt = '{}:{:d}' - host, port = address - try: - host = ip_address(host) - except ValueError: - pass - else: - if host.version == 6: - fmt = '[{}]:{:d}' - return fmt.format(host, port) - -# See http://stackoverflow.com/questions/2532053/validate-a-hostname-string -# Note underscores are valid in domain names, but strictly invalid in host -# names. We ignore that distinction. - - -SEGMENT_REGEX = re.compile("(?!-)[A-Z_\\d-]{1,63}(?<!-)$", re.IGNORECASE) - - -def is_valid_hostname(hostname): - if len(hostname) > 255: - return False - # strip exactly one dot from the right, if present - if hostname and hostname[-1] == ".": - hostname = hostname[:-1] - return all(SEGMENT_REGEX.match(x) for x in hostname.split(".")) - - -def protocol_tuple(s): - """Converts a protocol version number, such as "1.0" to a tuple (1, 0). - - If the version number is bad, (0, ) indicating version 0 is returned.""" - try: - return tuple(int(part) for part in s.split('.')) - except Exception: - return (0, ) - - -def version_string(ptuple): - """Convert a version tuple such as (1, 2) to "1.2". - There is always at least one dot, so (1, ) becomes "1.0".""" - while len(ptuple) < 2: - ptuple += (0, ) - return '.'.join(str(p) for p in ptuple) - - -def protocol_version(client_req, min_tuple, max_tuple): - """Given a client's protocol version string, return a pair of - protocol tuples: - - (negotiated version, client min request) - - If the request is unsupported, the negotiated protocol tuple is - None. 
- """ - if client_req is None: - client_min = client_max = min_tuple - else: - if isinstance(client_req, list) and len(client_req) == 2: - client_min, client_max = client_req - else: - client_min = client_max = client_req - client_min = protocol_tuple(client_min) - client_max = protocol_tuple(client_max) - - result = min(client_max, max_tuple) - if result < max(client_min, min_tuple) or result == (0, ): - result = None - - return result, client_min - - -struct_le_i = Struct('<i') -struct_le_q = Struct('<q') -struct_le_H = Struct('<H') -struct_le_I = Struct('<I') -struct_le_Q = Struct('<Q') -struct_be_H = Struct('>H') -struct_be_I = Struct('>I') -structB = Struct('B') - -unpack_le_int32_from = struct_le_i.unpack_from -unpack_le_int64_from = struct_le_q.unpack_from -unpack_le_uint16_from = struct_le_H.unpack_from -unpack_le_uint32_from = struct_le_I.unpack_from -unpack_le_uint64_from = struct_le_Q.unpack_from -unpack_be_uint16_from = struct_be_H.unpack_from -unpack_be_uint32_from = struct_be_I.unpack_from - -pack_le_int32 = struct_le_i.pack -pack_le_int64 = struct_le_q.pack -pack_le_uint16 = struct_le_H.pack -pack_le_uint32 = struct_le_I.pack -pack_le_uint64 = struct_le_Q.pack -pack_be_uint16 = struct_be_H.pack -pack_be_uint32 = struct_be_I.pack -pack_byte = structB.pack - -hex_to_bytes = bytes.fromhex - - -def pack_varint(n): - if n < 253: - return pack_byte(n) - if n < 65536: - return pack_byte(253) + pack_le_uint16(n) - if n < 4294967296: - return pack_byte(254) + pack_le_uint32(n) - return pack_byte(255) + pack_le_uint64(n) - - -def pack_varbytes(data): - return pack_varint(len(data)) + data diff --git a/lbry/wallet/server/version.py b/lbry/wallet/server/version.py deleted file mode 100644 index 7d640996d..000000000 --- a/lbry/wallet/server/version.py +++ /dev/null @@ -1,3 +0,0 @@ -# need this to avoid circular import -PROTOCOL_MIN = (0, 54, 0) -PROTOCOL_MAX = (0, 99, 0) diff --git a/lbry/wallet/server/websocket.py b/lbry/wallet/server/websocket.py deleted file mode 100644 index 9620918cb..000000000 --- a/lbry/wallet/server/websocket.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio -from weakref import WeakSet - -from aiohttp.web import Application, AppRunner, WebSocketResponse, TCPSite -from aiohttp.http_websocket import WSMsgType, WSCloseCode - - -class AdminWebSocket: - - def __init__(self, manager): - self.manager = manager - self.app = Application() - self.app['websockets'] = WeakSet() - self.app.router.add_get('/', self.on_connect) - self.app.on_shutdown.append(self.on_shutdown) - self.runner = AppRunner(self.app) - - async def on_status(self, _): - if not self.app['websockets']: - return - self.send_message({ - 'type': 'status', - 'height': self.manager.daemon.cached_height(), - }) - - def send_message(self, msg): - for web_socket in self.app['websockets']: - asyncio.create_task(web_socket.send_json(msg)) - - async def start(self): - await self.runner.setup() - await TCPSite(self.runner, self.manager.env.websocket_host, self.manager.env.websocket_port).start() - - async def stop(self): - await self.runner.cleanup() - - async def on_connect(self, request): - web_socket = WebSocketResponse() - await web_socket.prepare(request) - self.app['websockets'].add(web_socket) - try: - async for msg in web_socket: - if msg.type == WSMsgType.TEXT: - await self.on_status(None) - elif msg.type == WSMsgType.ERROR: - print('web socket connection closed with exception %s' % - web_socket.exception()) - finally: - self.app['websockets'].discard(web_socket) - return web_socket - - @staticmethod - async def on_shutdown(app): - for web_socket in set(app['websockets']): - await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server 
shutdown')
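
Note for reviewers: both of the removed modules lbry/wallet/server/util.py (pack_varint) and lbry/wallet/server/tx.py (Deserializer._read_varint) implement Bitcoin's CompactSize ("varint") integer encoding. The sketch below is a minimal standalone illustration of that encoding and is not part of the patch; it assumes nothing from the deleted modules, and the function names and round-trip check are illustrative only.

# Illustrative sketch (assumed names, not from the deleted code): the
# CompactSize/varint scheme used by the removed pack_varint/_read_varint.
from struct import pack, unpack_from


def pack_varint(n):
    # Values below 253 fit in one byte; larger values take a one-byte
    # prefix (253/254/255) followed by a little-endian uint16/uint32/uint64.
    if n < 253:
        return pack('B', n)
    if n < 65536:
        return pack('B', 253) + pack('<H', n)
    if n < 4294967296:
        return pack('B', 254) + pack('<I', n)
    return pack('B', 255) + pack('<Q', n)


def read_varint(data, offset=0):
    """Return (value, new_offset); mirrors the deleted deserializer logic."""
    n = data[offset]
    offset += 1
    if n < 253:
        return n, offset
    width, fmt = {253: (2, '<H'), 254: (4, '<I'), 255: (8, '<Q')}[n]
    return unpack_from(fmt, data, offset)[0], offset + width


if __name__ == '__main__':
    # Round-trip a few boundary values to show the three prefix cases.
    for value in (0, 252, 253, 70015, 2**32, 2**63):
        encoded = pack_varint(value)
        assert read_varint(encoded) == (value, len(encoded))

The single-byte fast path for values under 253 is why transaction input and output counts almost always cost one byte on the wire.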