diff --git a/lbry/wallet/client/bcd_data_stream.py b/lbry/wallet/bcd_data_stream.py similarity index 100% rename from lbry/wallet/client/bcd_data_stream.py rename to lbry/wallet/bcd_data_stream.py diff --git a/lbry/wallet/client/bip32.py b/lbry/wallet/bip32.py similarity index 100% rename from lbry/wallet/client/bip32.py rename to lbry/wallet/bip32.py diff --git a/lbry/wallet/client/baseaccount.py b/lbry/wallet/client/baseaccount.py deleted file mode 100644 index 26c3451b7..000000000 --- a/lbry/wallet/client/baseaccount.py +++ /dev/null @@ -1,485 +0,0 @@ -import os -import json -import time -import asyncio -import random -import typing -from typing import Dict, Tuple, Type, Optional, Any, List - -from lbry.crypto.hash import sha256 -from lbry.crypto.crypt import aes_encrypt, aes_decrypt -from lbry.wallet.client.bip32 import PrivateKey, PubKey, from_extended_key_string -from lbry.wallet.client.mnemonic import Mnemonic -from lbry.wallet.client.constants import COIN -from lbry.error import InvalidPasswordError - -if typing.TYPE_CHECKING: - from lbry.wallet.client import baseledger, wallet as basewallet - - -class AddressManager: - - name: str - - __slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock' - - def __init__(self, account, public_key, chain_number): - self.account = account - self.public_key = public_key - self.chain_number = chain_number - self.address_generator_lock = asyncio.Lock() - - @classmethod - def from_dict(cls, account: 'BaseAccount', d: dict) \ - -> Tuple['AddressManager', 'AddressManager']: - raise NotImplementedError - - @classmethod - def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> Dict: - d: Dict[str, Any] = {'name': cls.name} - receiving_dict = receiving.to_dict_instance() - if receiving_dict: - d['receiving'] = receiving_dict - change_dict = change.to_dict_instance() - if change_dict: - d['change'] = change_dict - return d - - def merge(self, d: dict): - pass - - def to_dict_instance(self) -> Optional[dict]: - raise NotImplementedError - - def _query_addresses(self, **constraints): - return self.account.ledger.db.get_addresses( - accounts=[self.account], - chain=self.chain_number, - **constraints - ) - - def get_private_key(self, index: int) -> PrivateKey: - raise NotImplementedError - - def get_public_key(self, index: int) -> PubKey: - raise NotImplementedError - - async def get_max_gap(self): - raise NotImplementedError - - async def ensure_address_gap(self): - raise NotImplementedError - - def get_address_records(self, only_usable: bool = False, **constraints): - raise NotImplementedError - - async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]: - records = await self.get_address_records(only_usable=only_usable, **constraints) - return [r['address'] for r in records] - - async def get_or_create_usable_address(self) -> str: - addresses = await self.get_addresses(only_usable=True, limit=10) - if addresses: - return random.choice(addresses) - addresses = await self.ensure_address_gap() - return addresses[0] - - -class HierarchicalDeterministic(AddressManager): - """ Implements simple version of Bitcoin Hierarchical Deterministic key management. """ - - name: str = "deterministic-chain" - - __slots__ = 'gap', 'maximum_uses_per_address' - - def __init__(self, account: 'BaseAccount', chain: int, gap: int, maximum_uses_per_address: int) -> None: - super().__init__(account, account.public_key.child(chain), chain) - self.gap = gap - self.maximum_uses_per_address = maximum_uses_per_address - - @classmethod - def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]: - return ( - cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})), - cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1})) - ) - - def merge(self, d: dict): - self.gap = d.get('gap', self.gap) - self.maximum_uses_per_address = d.get('maximum_uses_per_address', self.maximum_uses_per_address) - - def to_dict_instance(self): - return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address} - - def get_private_key(self, index: int) -> PrivateKey: - return self.account.private_key.child(self.chain_number).child(index) - - def get_public_key(self, index: int) -> PubKey: - return self.account.public_key.child(self.chain_number).child(index) - - async def get_max_gap(self) -> int: - addresses = await self._query_addresses(order_by="n asc") - max_gap = 0 - current_gap = 0 - for address in addresses: - if address['used_times'] == 0: - current_gap += 1 - else: - max_gap = max(max_gap, current_gap) - current_gap = 0 - return max_gap - - async def ensure_address_gap(self) -> List[str]: - async with self.address_generator_lock: - addresses = await self._query_addresses(limit=self.gap, order_by="n desc") - - existing_gap = 0 - for address in addresses: - if address['used_times'] == 0: - existing_gap += 1 - else: - break - - if existing_gap == self.gap: - return [] - - start = addresses[0]['pubkey'].n+1 if addresses else 0 - end = start + (self.gap - existing_gap) - new_keys = await self._generate_keys(start, end-1) - await self.account.ledger.announce_addresses(self, new_keys) - return new_keys - - async def _generate_keys(self, start: int, end: int) -> List[str]: - if not self.address_generator_lock.locked(): - raise RuntimeError('Should not be called outside of address_generator_lock.') - keys = [self.public_key.child(index) for index in range(start, end+1)] - await self.account.ledger.db.add_keys(self.account, self.chain_number, keys) - return [key.address for key in keys] - - def get_address_records(self, only_usable: bool = False, **constraints): - if only_usable: - constraints['used_times__lt'] = self.maximum_uses_per_address - if 'order_by' not in constraints: - constraints['order_by'] = "used_times asc, n asc" - return self._query_addresses(**constraints) - - -class SingleKey(AddressManager): - """ Single Key address manager always returns the same address for all operations. """ - - name: str = "single-address" - - __slots__ = () - - @classmethod - def from_dict(cls, account: 'BaseAccount', d: dict)\ - -> Tuple[AddressManager, AddressManager]: - same_address_manager = cls(account, account.public_key, 0) - return same_address_manager, same_address_manager - - def to_dict_instance(self): - return None - - def get_private_key(self, index: int) -> PrivateKey: - return self.account.private_key - - def get_public_key(self, index: int) -> PubKey: - return self.account.public_key - - async def get_max_gap(self) -> int: - return 0 - - async def ensure_address_gap(self) -> List[str]: - async with self.address_generator_lock: - exists = await self.get_address_records() - if not exists: - await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key]) - new_keys = [self.public_key.address] - await self.account.ledger.announce_addresses(self, new_keys) - return new_keys - return [] - - def get_address_records(self, only_usable: bool = False, **constraints): - return self._query_addresses(**constraints) - - -class BaseAccount: - - mnemonic_class = Mnemonic - private_key_class = PrivateKey - public_key_class = PubKey - address_generators: Dict[str, Type[AddressManager]] = { - SingleKey.name: SingleKey, - HierarchicalDeterministic.name: HierarchicalDeterministic, - } - - def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str, - seed: str, private_key_string: str, encrypted: bool, - private_key: Optional[PrivateKey], public_key: PubKey, - address_generator: dict, modified_on: float) -> None: - self.ledger = ledger - self.wallet = wallet - self.id = public_key.address - self.name = name - self.seed = seed - self.modified_on = modified_on - self.private_key_string = private_key_string - self.init_vectors: Dict[str, bytes] = {} - self.encrypted = encrypted - self.private_key = private_key - self.public_key = public_key - generator_name = address_generator.get('name', HierarchicalDeterministic.name) - self.address_generator = self.address_generators[generator_name] - self.receiving, self.change = self.address_generator.from_dict(self, address_generator) - self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}} - ledger.add_account(self) - wallet.add_account(self) - - def get_init_vector(self, key) -> Optional[bytes]: - init_vector = self.init_vectors.get(key, None) - if init_vector is None: - init_vector = self.init_vectors[key] = os.urandom(16) - return init_vector - - @classmethod - def generate(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', - name: str = None, address_generator: dict = None): - return cls.from_dict(ledger, wallet, { - 'name': name, - 'seed': cls.mnemonic_class().make_seed(), - 'address_generator': address_generator or {} - }) - - @classmethod - def get_private_key_from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str): - return cls.private_key_class.from_seed( - ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password) - ) - - @classmethod - def keys_from_dict(cls, ledger: 'baseledger.BaseLedger', d: dict) \ - -> Tuple[str, Optional[PrivateKey], PubKey]: - seed = d.get('seed', '') - private_key_string = d.get('private_key', '') - private_key = None - public_key = None - encrypted = d.get('encrypted', False) - if not encrypted: - if seed: - private_key = cls.get_private_key_from_seed(ledger, seed, '') - public_key = private_key.public_key - elif private_key_string: - private_key = from_extended_key_string(ledger, private_key_string) - public_key = private_key.public_key - if public_key is None: - public_key = from_extended_key_string(ledger, d['public_key']) - return seed, private_key, public_key - - @classmethod - def from_dict(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', d: dict): - seed, private_key, public_key = cls.keys_from_dict(ledger, d) - name = d.get('name') - if not name: - name = f'Account #{public_key.address}' - return cls( - ledger=ledger, - wallet=wallet, - name=name, - seed=seed, - private_key_string=d.get('private_key', ''), - encrypted=d.get('encrypted', False), - private_key=private_key, - public_key=public_key, - address_generator=d.get('address_generator', {}), - modified_on=d.get('modified_on', time.time()) - ) - - def to_dict(self, encrypt_password: str = None): - private_key_string, seed = self.private_key_string, self.seed - if not self.encrypted and self.private_key: - private_key_string = self.private_key.extended_key_string() - if not self.encrypted and encrypt_password: - if private_key_string: - private_key_string = aes_encrypt( - encrypt_password, private_key_string, self.get_init_vector('private_key') - ) - if seed: - seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed')) - return { - 'ledger': self.ledger.get_id(), - 'name': self.name, - 'seed': seed, - 'encrypted': bool(self.encrypted or encrypt_password), - 'private_key': private_key_string, - 'public_key': self.public_key.extended_key_string(), - 'address_generator': self.address_generator.to_dict(self.receiving, self.change), - 'modified_on': self.modified_on - } - - def merge(self, d: dict): - if d.get('modified_on', 0) > self.modified_on: - self.name = d['name'] - self.modified_on = d.get('modified_on', time.time()) - assert self.address_generator.name == d['address_generator']['name'] - for chain_name in ('change', 'receiving'): - if chain_name in d['address_generator']: - chain_object = getattr(self, chain_name) - chain_object.merge(d['address_generator'][chain_name]) - - @property - def hash(self) -> bytes: - assert not self.encrypted, "Cannot hash an encrypted account." - return sha256(json.dumps(self.to_dict()).encode()) - - async def get_details(self, show_seed=False, **kwargs): - satoshis = await self.get_balance(**kwargs) - details = { - 'id': self.id, - 'name': self.name, - 'ledger': self.ledger.get_id(), - 'coins': round(satoshis/COIN, 2), - 'satoshis': satoshis, - 'encrypted': self.encrypted, - 'public_key': self.public_key.extended_key_string(), - 'address_generator': self.address_generator.to_dict(self.receiving, self.change) - } - if show_seed: - details['seed'] = self.seed - return details - - def decrypt(self, password: str) -> bool: - assert self.encrypted, "Key is not encrypted." - try: - seed = self._decrypt_seed(password) - except (ValueError, InvalidPasswordError): - return False - try: - private_key = self._decrypt_private_key_string(password) - except (TypeError, ValueError, InvalidPasswordError): - return False - self.seed = seed - self.private_key = private_key - self.private_key_string = "" - self.encrypted = False - return True - - def _decrypt_private_key_string(self, password: str) -> Optional[PrivateKey]: - if not self.private_key_string: - return None - private_key_string, self.init_vectors['private_key'] = aes_decrypt(password, self.private_key_string) - if not private_key_string: - return None - return from_extended_key_string( - self.ledger, private_key_string - ) - - def _decrypt_seed(self, password: str) -> str: - if not self.seed: - return "" - seed, self.init_vectors['seed'] = aes_decrypt(password, self.seed) - if not seed: - return "" - try: - Mnemonic().mnemonic_decode(seed) - except IndexError: - # failed to decode the seed, this either means it decrypted and is invalid - # or that we hit an edge case where an incorrect password gave valid padding - raise ValueError("Failed to decode seed.") - return seed - - def encrypt(self, password: str) -> bool: - assert not self.encrypted, "Key is already encrypted." - if self.seed: - self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed')) - if isinstance(self.private_key, PrivateKey): - self.private_key_string = aes_encrypt( - password, self.private_key.extended_key_string(), self.get_init_vector('private_key') - ) - self.private_key = None - self.encrypted = True - return True - - async def ensure_address_gap(self): - addresses = [] - for address_manager in self.address_managers.values(): - new_addresses = await address_manager.ensure_address_gap() - addresses.extend(new_addresses) - return addresses - - async def get_addresses(self, **constraints) -> List[str]: - rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints) - return [r[0] for r in rows] - - def get_address_records(self, **constraints): - return self.ledger.db.get_addresses(accounts=[self], **constraints) - - def get_address_count(self, **constraints): - return self.ledger.db.get_address_count(accounts=[self], **constraints) - - def get_private_key(self, chain: int, index: int) -> PrivateKey: - assert not self.encrypted, "Cannot get private key on encrypted wallet account." - return self.address_managers[chain].get_private_key(index) - - def get_public_key(self, chain: int, index: int) -> PubKey: - return self.address_managers[chain].get_public_key(index) - - def get_balance(self, confirmations: int = 0, **constraints): - if confirmations > 0: - height = self.ledger.headers.height - (confirmations-1) - constraints.update({'height__lte': height, 'height__gt': 0}) - return self.ledger.db.get_balance(accounts=[self], **constraints) - - async def get_max_gap(self): - change_gap = await self.change.get_max_gap() - receiving_gap = await self.receiving.get_max_gap() - return { - 'max_change_gap': change_gap, - 'max_receiving_gap': receiving_gap, - } - - def get_utxos(self, **constraints): - return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints) - - def get_utxo_count(self, **constraints): - return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints) - - def get_transactions(self, **constraints): - return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints) - - def get_transaction_count(self, **constraints): - return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints) - - async def fund(self, to_account, amount=None, everything=False, - outputs=1, broadcast=False, **constraints): - assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.' - tx_class = self.ledger.transaction_class - if everything: - utxos = await self.get_utxos(**constraints) - await self.ledger.reserve_outputs(utxos) - tx = await tx_class.create( - inputs=[tx_class.input_class.spend(txo) for txo in utxos], - outputs=[], - funding_accounts=[self], - change_account=to_account - ) - elif amount > 0: - to_address = await to_account.change.get_or_create_usable_address() - to_hash160 = to_account.ledger.address_to_hash160(to_address) - tx = await tx_class.create( - inputs=[], - outputs=[ - tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160) - for _ in range(outputs) - ], - funding_accounts=[self], - change_account=self - ) - else: - raise ValueError('An amount is required.') - - if broadcast: - await self.ledger.broadcast(tx) - else: - await self.ledger.release_tx(tx) - - return tx diff --git a/lbry/wallet/client/basedatabase.py b/lbry/wallet/client/basedatabase.py deleted file mode 100644 index 7b607c47b..000000000 --- a/lbry/wallet/client/basedatabase.py +++ /dev/null @@ -1,652 +0,0 @@ -import logging -import asyncio -from binascii import hexlify -from concurrent.futures.thread import ThreadPoolExecutor - -from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional - -import sqlite3 - -from lbry.wallet.client.basetransaction import BaseTransaction, TXRefImmutable -from lbry.wallet.client.bip32 import PubKey - -log = logging.getLogger(__name__) -sqlite3.enable_callback_tracebacks(True) - - -class AIOSQLite: - - def __init__(self): - # has to be single threaded as there is no mapping of thread:connection - self.executor = ThreadPoolExecutor(max_workers=1) - self.connection: sqlite3.Connection = None - self._closing = False - self.query_count = 0 - - @classmethod - async def connect(cls, path: Union[bytes, str], *args, **kwargs): - sqlite3.enable_callback_tracebacks(True) - def _connect(): - return sqlite3.connect(path, *args, **kwargs) - db = cls() - db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect) - return db - - async def close(self): - if self._closing: - return - self._closing = True - await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close) - self.executor.shutdown(wait=True) - self.connection = None - - def executemany(self, sql: str, params: Iterable): - params = params if params is not None else [] - # this fetchall is needed to prevent SQLITE_MISUSE - return self.run(lambda conn: conn.executemany(sql, params).fetchall()) - - def executescript(self, script: str) -> Awaitable: - return self.run(lambda conn: conn.executescript(script)) - - def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: - parameters = parameters if parameters is not None else [] - return self.run(lambda conn: conn.execute(sql, parameters).fetchall()) - - def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: - parameters = parameters if parameters is not None else [] - return self.run(lambda conn: conn.execute(sql, parameters).fetchone()) - - def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: - parameters = parameters if parameters is not None else [] - return self.run(lambda conn: conn.execute(sql, parameters)) - - def run(self, fun, *args, **kwargs) -> Awaitable: - return asyncio.get_event_loop().run_in_executor( - self.executor, lambda: self.__run_transaction(fun, *args, **kwargs) - ) - - def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): - self.connection.execute('begin') - try: - self.query_count += 1 - result = fun(self.connection, *args, **kwargs) # type: ignore - self.connection.commit() - return result - except (Exception, OSError) as e: - log.exception('Error running transaction:', exc_info=e) - self.connection.rollback() - log.warning("rolled back") - raise - - def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable: - return asyncio.get_event_loop().run_in_executor( - self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs - ) - - def __run_transaction_with_foreign_keys_disabled(self, - fun: Callable[[sqlite3.Connection, Any, Any], Any], - args, kwargs): - foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone() - if not foreign_keys_enabled: - raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead") - try: - self.connection.execute('pragma foreign_keys=off').fetchone() - return self.__run_transaction(fun, *args, **kwargs) - finally: - self.connection.execute('pragma foreign_keys=on').fetchone() - - -def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): - sql, values = [], {} - for key, constraint in constraints.items(): - tag = '0' - if '#' in key: - key, tag = key[:key.index('#')], key[key.index('#')+1:] - col, op, key = key, '=', key.replace('.', '_') - if not key: - sql.append(constraint) - continue - if key.startswith('$'): - values[key] = constraint - continue - if key.endswith('__not'): - col, op = col[:-len('__not')], '!=' - elif key.endswith('__is_null'): - col = col[:-len('__is_null')] - sql.append(f'{col} IS NULL') - continue - if key.endswith('__is_not_null'): - col = col[:-len('__is_not_null')] - sql.append(f'{col} IS NOT NULL') - continue - if key.endswith('__lt'): - col, op = col[:-len('__lt')], '<' - elif key.endswith('__lte'): - col, op = col[:-len('__lte')], '<=' - elif key.endswith('__gt'): - col, op = col[:-len('__gt')], '>' - elif key.endswith('__gte'): - col, op = col[:-len('__gte')], '>=' - elif key.endswith('__like'): - col, op = col[:-len('__like')], 'LIKE' - elif key.endswith('__not_like'): - col, op = col[:-len('__not_like')], 'NOT LIKE' - elif key.endswith('__in') or key.endswith('__not_in'): - if key.endswith('__in'): - col, op = col[:-len('__in')], 'IN' - else: - col, op = col[:-len('__not_in')], 'NOT IN' - if constraint: - if isinstance(constraint, (list, set, tuple)): - keys = [] - for i, val in enumerate(constraint): - keys.append(f':{key}{tag}_{i}') - values[f'{key}{tag}_{i}'] = val - sql.append(f'{col} {op} ({", ".join(keys)})') - elif isinstance(constraint, str): - sql.append(f'{col} {op} ({constraint})') - else: - raise ValueError(f"{col} requires a list, set or string as constraint value.") - continue - elif key.endswith('__any') or key.endswith('__or'): - where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_') - sql.append(f'({where})') - values.update(subvalues) - continue - if key.endswith('__and'): - where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_') - sql.append(f'({where})') - values.update(subvalues) - continue - sql.append(f'{col} {op} :{prepend_key}{key}{tag}') - values[prepend_key+key+tag] = constraint - return joiner.join(sql) if sql else '', values - - -def query(select, **constraints) -> Tuple[str, Dict[str, Any]]: - sql = [select] - limit = constraints.pop('limit', None) - offset = constraints.pop('offset', None) - order_by = constraints.pop('order_by', None) - - accounts = constraints.pop('accounts', []) - if accounts: - constraints['account__in'] = [a.public_key.address for a in accounts] - - where, values = constraints_to_sql(constraints) - if where: - sql.append('WHERE') - sql.append(where) - - if order_by: - sql.append('ORDER BY') - if isinstance(order_by, str): - sql.append(order_by) - elif isinstance(order_by, list): - sql.append(', '.join(order_by)) - else: - raise ValueError("order_by must be string or list") - - if limit is not None: - sql.append(f'LIMIT {limit}') - - if offset is not None: - sql.append(f'OFFSET {offset}') - - return ' '.join(sql), values - - -def interpolate(sql, values): - for k in sorted(values.keys(), reverse=True): - value = values[k] - if isinstance(value, bytes): - value = f"X'{hexlify(value).decode()}'" - elif isinstance(value, str): - value = f"'{value}'" - else: - value = str(value) - sql = sql.replace(f":{k}", value) - return sql - - -def rows_to_dict(rows, fields): - if rows: - return [dict(zip(fields, r)) for r in rows] - else: - return [] - - -class SQLiteMixin: - - SCHEMA_VERSION: Optional[str] = None - CREATE_TABLES_QUERY: str - MAX_QUERY_VARIABLES = 900 - - CREATE_VERSION_TABLE = """ - create table if not exists version ( - version text - ); - """ - - def __init__(self, path): - self._db_path = path - self.db: AIOSQLite = None - self.ledger = None - - async def open(self): - log.info("connecting to database: %s", self._db_path) - self.db = await AIOSQLite.connect(self._db_path, isolation_level=None) - if self.SCHEMA_VERSION: - tables = [t[0] for t in await self.db.execute_fetchall( - "SELECT name FROM sqlite_master WHERE type='table';" - )] - if tables: - if 'version' in tables: - version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;") - if version == (self.SCHEMA_VERSION,): - return - await self.db.executescript('\n'.join( - f"DROP TABLE {table};" for table in tables - )) - await self.db.execute(self.CREATE_VERSION_TABLE) - await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,)) - await self.db.executescript(self.CREATE_TABLES_QUERY) - - async def close(self): - await self.db.close() - - @staticmethod - def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False, - replace: bool = False) -> Tuple[str, List]: - columns, values = [], [] - for column, value in data.items(): - columns.append(column) - values.append(value) - policy = "" - if ignore_duplicate: - policy = " OR IGNORE" - if replace: - policy = " OR REPLACE" - sql = "INSERT{} INTO {} ({}) VALUES ({})".format( - policy, table, ', '.join(columns), ', '.join(['?'] * len(values)) - ) - return sql, values - - @staticmethod - def _update_sql(table: str, data: dict, where: str, - constraints: Union[list, tuple]) -> Tuple[str, list]: - columns, values = [], [] - for column, value in data.items(): - columns.append(f"{column} = ?") - values.append(value) - values.extend(constraints) - sql = "UPDATE {} SET {} WHERE {}".format( - table, ', '.join(columns), where - ) - return sql, values - - -class BaseDatabase(SQLiteMixin): - - SCHEMA_VERSION = "1.1" - - PRAGMAS = """ - pragma journal_mode=WAL; - """ - - CREATE_ACCOUNT_TABLE = """ - create table if not exists account_address ( - account text not null, - address text not null, - chain integer not null, - pubkey blob not null, - chain_code blob not null, - n integer not null, - depth integer not null, - primary key (account, address) - ); - create index if not exists address_account_idx on account_address (address, account); - """ - - CREATE_PUBKEY_ADDRESS_TABLE = """ - create table if not exists pubkey_address ( - address text primary key, - history text, - used_times integer not null default 0 - ); - """ - - CREATE_TX_TABLE = """ - create table if not exists tx ( - txid text primary key, - raw blob not null, - height integer not null, - position integer not null, - is_verified boolean not null default 0 - ); - """ - - CREATE_TXO_TABLE = """ - create table if not exists txo ( - txid text references tx, - txoid text primary key, - address text references pubkey_address, - position integer not null, - amount integer not null, - script blob not null, - is_reserved boolean not null default 0 - ); - create index if not exists txo_address_idx on txo (address); - """ - - CREATE_TXI_TABLE = """ - create table if not exists txi ( - txid text references tx, - txoid text references txo, - address text references pubkey_address - ); - create index if not exists txi_address_idx on txi (address); - create index if not exists txi_txoid_idx on txi (txoid); - """ - - CREATE_TABLES_QUERY = ( - PRAGMAS + - CREATE_ACCOUNT_TABLE + - CREATE_PUBKEY_ADDRESS_TABLE + - CREATE_TX_TABLE + - CREATE_TXO_TABLE + - CREATE_TXI_TABLE - ) - - @staticmethod - def txo_to_row(tx, address, txo): - return { - 'txid': tx.id, - 'txoid': txo.id, - 'address': address, - 'position': txo.position, - 'amount': txo.amount, - 'script': sqlite3.Binary(txo.script.source) - } - - @staticmethod - def tx_to_row(tx): - return { - 'txid': tx.id, - 'raw': sqlite3.Binary(tx.raw), - 'height': tx.height, - 'position': tx.position, - 'is_verified': tx.is_verified - } - - async def insert_transaction(self, tx): - await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx))) - - async def update_transaction(self, tx): - await self.db.execute_fetchall(*self._update_sql("tx", { - 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified - }, 'txid = ?', (tx.id,))) - - def _transaction_io(self, conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history): - conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)) - - for txo in tx.outputs: - if txo.script.is_pay_pubkey_hash and txo.pubkey_hash == txhash: - conn.execute(*self._insert_sql( - "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True - )).fetchall() - elif txo.script.is_pay_script_hash: - # TODO: implement script hash payments - log.warning('Database.save_transaction_io: pay script hash is not implemented!') - - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - txo = txi.txo_ref.txo - if txo.has_address and txo.get_address(self.ledger) == address: - conn.execute(*self._insert_sql("txi", { - 'txid': tx.id, - 'txoid': txo.id, - 'address': address, - }, ignore_duplicate=True)).fetchall() - - conn.execute( - "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", - (history, history.count(':') // 2, address) - ) - - def save_transaction_io(self, tx: BaseTransaction, address, txhash, history): - return self.db.run(self._transaction_io, tx, address, txhash, history) - - def save_transaction_io_batch(self, txs: Iterable[BaseTransaction], address, txhash, history): - def __many(conn): - for tx in txs: - self._transaction_io(conn, tx, address, txhash, history) - return self.db.run(__many) - - async def reserve_outputs(self, txos, is_reserved=True): - txoids = ((is_reserved, txo.id) for txo in txos) - await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids) - - async def release_outputs(self, txos): - await self.reserve_outputs(txos, is_reserved=False) - - async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use - # TODO: - # 1. delete transactions above_height - # 2. update address histories removing deleted TXs - return True - - async def select_transactions(self, cols, accounts=None, **constraints): - if not {'txid', 'txid__in'}.intersection(constraints): - assert accounts, "'accounts' argument required when no 'txid' constraint is present" - constraints.update({ - f'$account{i}': a.public_key.address for i, a in enumerate(accounts) - }) - account_values = ', '.join([f':$account{i}' for i in range(len(accounts))]) - where = f" WHERE account_address.account IN ({account_values})" - constraints['txid__in'] = f""" - SELECT txo.txid FROM txo JOIN account_address USING (address) {where} - UNION - SELECT txi.txid FROM txi JOIN account_address USING (address) {where} - """ - return await self.db.execute_fetchall( - *query(f"SELECT {cols} FROM tx", **constraints) - ) - - async def get_transactions(self, wallet=None, **constraints): - tx_rows = await self.select_transactions( - 'txid, raw, height, position, is_verified', - order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]), - **constraints - ) - - if not tx_rows: - return [] - - txids, txs, txi_txoids = [], [], [] - for row in tx_rows: - txids.append(row[0]) - txs.append(self.ledger.transaction_class( - raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4]) - )) - for txi in txs[-1].inputs: - txi_txoids.append(txi.txo_ref.id) - - step = self.MAX_QUERY_VARIABLES - annotated_txos = {} - for offset in range(0, len(txids), step): - annotated_txos.update({ - txo.id: txo for txo in - (await self.get_txos( - wallet=wallet, - txid__in=txids[offset:offset+step], - )) - }) - - referenced_txos = {} - for offset in range(0, len(txi_txoids), step): - referenced_txos.update({ - txo.id: txo for txo in - (await self.get_txos( - wallet=wallet, - txoid__in=txi_txoids[offset:offset+step], - )) - }) - - for tx in txs: - for txi in tx.inputs: - txo = referenced_txos.get(txi.txo_ref.id) - if txo: - txi.txo_ref = txo.ref - for txo in tx.outputs: - _txo = annotated_txos.get(txo.id) - if _txo: - txo.update_annotations(_txo) - else: - txo.update_annotations(None) - - return txs - - async def get_transaction_count(self, **constraints): - constraints.pop('wallet', None) - constraints.pop('offset', None) - constraints.pop('limit', None) - constraints.pop('order_by', None) - count = await self.select_transactions('count(*)', **constraints) - return count[0][0] - - async def get_transaction(self, **constraints): - txs = await self.get_transactions(limit=1, **constraints) - if txs: - return txs[0] - - async def select_txos(self, cols, **constraints): - sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)" - if 'accounts' in constraints: - sql += " JOIN account_address USING (address)" - return await self.db.execute_fetchall(*query(sql, **constraints)) - - async def get_txos(self, wallet=None, no_tx=False, **constraints): - my_accounts = {a.public_key.address for a in wallet.accounts} if wallet else set() - if 'order_by' not in constraints: - constraints['order_by'] = [ - "tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position" - ] - rows = await self.select_txos( - """ - tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, ( - select group_concat(account||"|"||chain) from account_address - where account_address.address=txo.address - ) - """, - **constraints - ) - txos = [] - txs = {} - output_class = self.ledger.transaction_class.output_class - for row in rows: - if no_tx: - txo = output_class( - amount=row[6], - script=output_class.script_class(row[7]), - tx_ref=TXRefImmutable.from_id(row[0], row[2]), - position=row[5] - ) - else: - if row[0] not in txs: - txs[row[0]] = self.ledger.transaction_class( - row[1], height=row[2], position=row[3], is_verified=row[4] - ) - txo = txs[row[0]].outputs[row[5]] - row_accounts = dict(a.split('|') for a in row[8].split(',')) - account_match = set(row_accounts) & my_accounts - if account_match: - txo.is_my_account = True - txo.is_change = row_accounts[account_match.pop()] == '1' - else: - txo.is_change = txo.is_my_account = False - txos.append(txo) - return txos - - async def get_txo_count(self, **constraints): - constraints.pop('wallet', None) - constraints.pop('offset', None) - constraints.pop('limit', None) - constraints.pop('order_by', None) - count = await self.select_txos('count(*)', **constraints) - return count[0][0] - - @staticmethod - def constrain_utxo(constraints): - constraints['is_reserved'] = False - constraints['txoid__not_in'] = "SELECT txoid FROM txi" - - def get_utxos(self, **constraints): - self.constrain_utxo(constraints) - return self.get_txos(**constraints) - - def get_utxo_count(self, **constraints): - self.constrain_utxo(constraints) - return self.get_txo_count(**constraints) - - async def get_balance(self, wallet=None, accounts=None, **constraints): - assert wallet or accounts, \ - "'wallet' or 'accounts' constraints required to calculate balance" - constraints['accounts'] = accounts or wallet.accounts - self.constrain_utxo(constraints) - balance = await self.select_txos('SUM(amount)', **constraints) - return balance[0][0] or 0 - - async def select_addresses(self, cols, **constraints): - return await self.db.execute_fetchall(*query( - f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)", - **constraints - )) - - async def get_addresses(self, cols=None, **constraints): - cols = cols or ( - 'address', 'account', 'chain', 'history', 'used_times', - 'pubkey', 'chain_code', 'n', 'depth' - ) - addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols) - if 'pubkey' in cols: - for address in addresses: - address['pubkey'] = PubKey( - self.ledger, address.pop('pubkey'), address.pop('chain_code'), - address.pop('n'), address.pop('depth') - ) - return addresses - - async def get_address_count(self, cols=None, **constraints): - count = await self.select_addresses('count(*)', **constraints) - return count[0][0] - - async def get_address(self, **constraints): - addresses = await self.get_addresses(limit=1, **constraints) - if addresses: - return addresses[0] - - async def add_keys(self, account, chain, pubkeys): - await self.db.executemany( - "insert or ignore into account_address " - "(account, address, chain, pubkey, chain_code, n, depth) values " - "(?, ?, ?, ?, ?, ?, ?)", (( - account.id, k.address, chain, - sqlite3.Binary(k.pubkey_bytes), - sqlite3.Binary(k.chain_code), - k.n, k.depth - ) for k in pubkeys) - ) - await self.db.executemany( - "insert or ignore into pubkey_address (address) values (?)", - ((pubkey.address,) for pubkey in pubkeys) - ) - - async def _set_address_history(self, address, history): - await self.db.execute_fetchall( - "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", - (history, history.count(':')//2, address) - ) - - async def set_address_history(self, address, history): - await self._set_address_history(address, history) diff --git a/lbry/wallet/client/baseheader.py b/lbry/wallet/client/baseheader.py deleted file mode 100644 index 6450804e6..000000000 --- a/lbry/wallet/client/baseheader.py +++ /dev/null @@ -1,246 +0,0 @@ -import asyncio -import hashlib -import os -import logging -from contextlib import asynccontextmanager -from io import BytesIO -from typing import Optional, Iterator, Tuple -from binascii import hexlify - -from lbry.wallet.client.util import ArithUint256 -from lbry.crypto.hash import double_sha256 - -log = logging.getLogger(__name__) - - -class InvalidHeader(Exception): - - def __init__(self, height, message): - super().__init__(message) - self.message = message - self.height = height - - -class BaseHeaders: - - header_size: int - chunk_size: int - - max_target: int - genesis_hash: Optional[bytes] - target_timespan: int - - validate_difficulty: bool = True - checkpoint = None - - def __init__(self, path) -> None: - if path == ':memory:': - self.io = BytesIO() - self.path = path - self._size: Optional[int] = None - - async def open(self): - if self.path != ':memory:': - if not os.path.exists(self.path): - self.io = open(self.path, 'w+b') - else: - self.io = open(self.path, 'r+b') - - async def close(self): - self.io.close() - - @staticmethod - def serialize(header: dict) -> bytes: - raise NotImplementedError - - @staticmethod - def deserialize(height, header): - raise NotImplementedError - - def get_next_chunk_target(self, chunk: int) -> ArithUint256: - return ArithUint256(self.max_target) - - @staticmethod - def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict], - current: Optional[dict]) -> ArithUint256: - return chunk_target - - def __len__(self) -> int: - if self._size is None: - self._size = self.io.seek(0, os.SEEK_END) // self.header_size - return self._size - - def __bool__(self): - return True - - def __getitem__(self, height) -> dict: - if isinstance(height, slice): - raise NotImplementedError("Slicing of header chain has not been implemented yet.") - if not 0 <= height <= self.height: - raise IndexError(f"{height} is out of bounds, current height: {self.height}") - return self.deserialize(height, self.get_raw_header(height)) - - def get_raw_header(self, height) -> bytes: - self.io.seek(height * self.header_size, os.SEEK_SET) - return self.io.read(self.header_size) - - @property - def height(self) -> int: - return len(self)-1 - - @property - def bytes_size(self): - return len(self) * self.header_size - - def hash(self, height=None) -> bytes: - return self.hash_header( - self.get_raw_header(height if height is not None else self.height) - ) - - @staticmethod - def hash_header(header: bytes) -> bytes: - if header is None: - return b'0' * 64 - return hexlify(double_sha256(header)[::-1]) - - @asynccontextmanager - async def checkpointed_connector(self): - buf = BytesIO() - try: - yield buf - finally: - await asyncio.sleep(0) - final_height = len(self) + buf.tell() // self.header_size - verifiable_bytes = (self.checkpoint[0] - len(self)) * self.header_size if self.checkpoint else 0 - if verifiable_bytes > 0 and final_height >= self.checkpoint[0]: - buf.seek(0) - self.io.seek(0) - h = hashlib.sha256() - h.update(self.io.read()) - h.update(buf.read(verifiable_bytes)) - if h.hexdigest().encode() == self.checkpoint[1]: - buf.seek(0) - self._write(len(self), buf.read(verifiable_bytes)) - remaining = buf.read() - buf.seek(0) - buf.write(remaining) - buf.truncate() - else: - log.warning("Checkpoint mismatch, connecting headers through slow method.") - if buf.tell() > 0: - await self.connect(len(self), buf.getvalue()) - - async def connect(self, start: int, headers: bytes) -> int: - added = 0 - bail = False - for height, chunk in self._iterate_chunks(start, headers): - try: - # validate_chunk() is CPU bound and reads previous chunks from file system - self.validate_chunk(height, chunk) - except InvalidHeader as e: - bail = True - chunk = chunk[:(height-e.height)*self.header_size] - added += self._write(height, chunk) if chunk else 0 - if bail: - break - return added - - def _write(self, height, verified_chunk): - self.io.seek(height * self.header_size, os.SEEK_SET) - written = self.io.write(verified_chunk) // self.header_size - self.io.truncate() - # .seek()/.write()/.truncate() might also .flush() when needed - # the goal here is mainly to ensure we're definitely flush()'ing - self.io.flush() - self._size = self.io.tell() // self.header_size - return written - - def validate_chunk(self, height, chunk): - previous_hash, previous_header, previous_previous_header = None, None, None - if height > 0: - previous_header = self[height-1] - previous_hash = self.hash(height-1) - if height > 1: - previous_previous_header = self[height-2] - chunk_target = self.get_next_chunk_target(height // 2016 - 1) - for current_hash, current_header in self._iterate_headers(height, chunk): - block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header) - self.validate_header(height, current_hash, current_header, previous_hash, block_target) - previous_previous_header = previous_header - previous_header = current_header - previous_hash = current_hash - - def validate_header(self, height: int, current_hash: bytes, - header: dict, previous_hash: bytes, target: ArithUint256): - - if previous_hash is None: - if self.genesis_hash is not None and self.genesis_hash != current_hash: - raise InvalidHeader( - height, f"genesis header doesn't match: {current_hash.decode()} " - f"vs expected {self.genesis_hash.decode()}") - return - - if header['prev_block_hash'] != previous_hash: - raise InvalidHeader( - height, "previous hash mismatch: {} vs expected {}".format( - header['prev_block_hash'].decode(), previous_hash.decode()) - ) - - if self.validate_difficulty: - - if header['bits'] != target.compact: - raise InvalidHeader( - height, "bits mismatch: {} vs expected {}".format( - header['bits'], target.compact) - ) - - proof_of_work = self.get_proof_of_work(current_hash) - if proof_of_work > target: - raise InvalidHeader( - height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}" - ) - - async def repair(self): - previous_header_hash = fail = None - batch_size = 36 - for start_height in range(0, self.height, batch_size): - self.io.seek(self.header_size * start_height) - headers = self.io.read(self.header_size*batch_size) - if len(headers) % self.header_size != 0: - headers = headers[:(len(headers) // self.header_size) * self.header_size] - for header_hash, header in self._iterate_headers(start_height, headers): - height = header['block_height'] - if height: - if header['prev_block_hash'] != previous_header_hash: - fail = True - else: - if header_hash != self.genesis_hash: - fail = True - if fail: - log.warning("Header file corrupted at height %s, truncating it.", height - 1) - self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET) - self.io.truncate() - self.io.flush() - self._size = None - return - previous_header_hash = header_hash - - @staticmethod - def get_proof_of_work(header_hash: bytes) -> ArithUint256: - return ArithUint256(int(b'0x' + header_hash, 16)) - - def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]: - assert len(headers) % self.header_size == 0, f"{len(headers)} {len(headers)%self.header_size}" - start = 0 - end = (self.chunk_size - height % self.chunk_size) * self.header_size - while start < end: - yield height + (start // self.header_size), headers[start:end] - start = end - end = min(len(headers), end + self.chunk_size * self.header_size) - - def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]: - assert len(headers) % self.header_size == 0, len(headers) - for idx in range(len(headers) // self.header_size): - start, end = idx * self.header_size, (idx + 1) * self.header_size - header = headers[start:end] - yield self.hash_header(header), self.deserialize(height+idx, header) diff --git a/lbry/wallet/client/baseledger.py b/lbry/wallet/client/baseledger.py deleted file mode 100644 index 1f8e993e0..000000000 --- a/lbry/wallet/client/baseledger.py +++ /dev/null @@ -1,605 +0,0 @@ -import base64 -import os -import asyncio -import logging -import zlib -from functools import partial -from binascii import hexlify, unhexlify -from io import StringIO - -from typing import Dict, Type, Iterable, List, Optional -from operator import itemgetter -from collections import namedtuple - -import pylru -from lbry.wallet.client.basetransaction import BaseTransaction -from lbry.wallet.tasks import TaskGroup -from lbry.wallet.client import baseaccount, basenetwork, basetransaction -from lbry.wallet.client.basedatabase import BaseDatabase -from lbry.wallet.client.baseheader import BaseHeaders -from lbry.wallet.client.coinselection import CoinSelector -from lbry.wallet.client.constants import COIN, NULL_HASH32 -from lbry.wallet.stream import StreamController -from lbry.crypto.hash import hash160, double_sha256, sha256 -from lbry.crypto.base58 import Base58 -from lbry.wallet.client.bip32 import PubKey, PrivateKey - -log = logging.getLogger(__name__) - -LedgerType = Type['BaseLedger'] - - -class LedgerRegistry(type): - - ledgers: Dict[str, LedgerType] = {} - - def __new__(mcs, name, bases, attrs): - cls: LedgerType = super().__new__(mcs, name, bases, attrs) - if not (name == 'BaseLedger' and not bases): - ledger_id = cls.get_id() - assert ledger_id not in mcs.ledgers,\ - f'Ledger with id "{ledger_id}" already registered.' - mcs.ledgers[ledger_id] = cls - return cls - - @classmethod - def get_ledger_class(mcs, ledger_id: str) -> LedgerType: - return mcs.ledgers[ledger_id] - - -class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))): - pass - - -class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))): - pass - - -class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))): - pass - - -class TransactionCacheItem: - __slots__ = '_tx', 'lock', 'has_tx' - - def __init__(self, - tx: Optional[basetransaction.BaseTransaction] = None, - lock: Optional[asyncio.Lock] = None): - self.has_tx = asyncio.Event() - self.lock = lock or asyncio.Lock() - self._tx = self.tx = tx - - @property - def tx(self) -> Optional[basetransaction.BaseTransaction]: - return self._tx - - @tx.setter - def tx(self, tx: basetransaction.BaseTransaction): - self._tx = tx - if tx is not None: - self.has_tx.set() - - -class BaseLedger(metaclass=LedgerRegistry): - - name: str - symbol: str - network_name: str - - database_class = BaseDatabase - account_class = baseaccount.BaseAccount - network_class = basenetwork.BaseNetwork - transaction_class = basetransaction.BaseTransaction - - headers_class: Type[BaseHeaders] - - pubkey_address_prefix: bytes - script_address_prefix: bytes - extended_public_key_prefix: bytes - extended_private_key_prefix: bytes - - default_fee_per_byte = 10 - - def __init__(self, config=None): - self.config = config or {} - self.db: BaseDatabase = self.config.get('db') or self.database_class( - os.path.join(self.path, "blockchain.db") - ) - self.db.ledger = self - self.headers: BaseHeaders = self.config.get('headers') or self.headers_class( - os.path.join(self.path, "headers") - ) - self.network = self.config.get('network') or self.network_class(self) - self.network.on_header.listen(self.receive_header) - self.network.on_status.listen(self.process_status_update) - self.network.on_connected.listen(self.join_network) - - self.accounts = [] - self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte) - - self._on_transaction_controller = StreamController() - self.on_transaction = self._on_transaction_controller.stream - self.on_transaction.listen( - lambda e: log.info( - '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s', - self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id - ) - ) - - self._on_address_controller = StreamController() - self.on_address = self._on_address_controller.stream - self.on_address.listen( - lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses) - ) - - self._on_header_controller = StreamController() - self.on_header = self._on_header_controller.stream - self.on_header.listen( - lambda change: log.info( - '%s: added %s header blocks, final height %s', - self.get_id(), change, self.headers.height - ) - ) - self._download_height = 0 - - self._on_ready_controller = StreamController() - self.on_ready = self._on_ready_controller.stream - - self._tx_cache = pylru.lrucache(100000) - self._update_tasks = TaskGroup() - self._utxo_reservation_lock = asyncio.Lock() - self._header_processing_lock = asyncio.Lock() - self._address_update_locks: Dict[str, asyncio.Lock] = {} - - self.coin_selection_strategy = None - self._known_addresses_out_of_sync = set() - - @classmethod - def get_id(cls): - return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower()) - - @classmethod - def hash160_to_address(cls, h160): - raw_address = cls.pubkey_address_prefix + h160 - return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4])) - - @staticmethod - def address_to_hash160(address): - return Base58.decode(address)[1:21] - - @classmethod - def is_valid_address(cls, address): - decoded = Base58.decode_check(address) - return decoded[0] == cls.pubkey_address_prefix[0] - - @classmethod - def public_key_to_address(cls, public_key): - return cls.hash160_to_address(hash160(public_key)) - - @staticmethod - def private_key_to_wif(private_key): - return b'\x1c' + private_key + b'\x01' - - @property - def path(self): - return os.path.join(self.config['data_path'], self.get_id()) - - def add_account(self, account: baseaccount.BaseAccount): - self.accounts.append(account) - - async def _get_account_and_address_info_for_address(self, wallet, address): - match = await self.db.get_address(accounts=wallet.accounts, address=address) - if match: - for account in wallet.accounts: - if match['account'] == account.public_key.address: - return account, match - - async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]: - match = await self._get_account_and_address_info_for_address(wallet, address) - if match: - account, address_info = match - return account.get_private_key(address_info['chain'], address_info['pubkey'].n) - return None - - async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]: - match = await self._get_account_and_address_info_for_address(wallet, address) - if match: - _, address_info = match - return address_info['pubkey'] - return None - - async def get_account_for_address(self, wallet, address): - match = await self._get_account_and_address_info_for_address(wallet, address) - if match: - return match[0] - - async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]): - estimators = [] - for account in funding_accounts: - utxos = await account.get_utxos() - for utxo in utxos: - estimators.append(utxo.get_estimator(self)) - return estimators - - async def get_addresses(self, **constraints): - return await self.db.get_addresses(**constraints) - - def get_address_count(self, **constraints): - return self.db.get_address_count(**constraints) - - async def get_spendable_utxos(self, amount: int, funding_accounts): - async with self._utxo_reservation_lock: - txos = await self.get_effective_amount_estimators(funding_accounts) - fee = self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self) - selector = CoinSelector(amount, fee) - spendables = selector.select(txos, self.coin_selection_strategy) - if spendables: - await self.reserve_outputs(s.txo for s in spendables) - return spendables - - def reserve_outputs(self, txos): - return self.db.reserve_outputs(txos) - - def release_outputs(self, txos): - return self.db.release_outputs(txos) - - def release_tx(self, tx): - return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs]) - - def get_utxos(self, **constraints): - return self.db.get_utxos(**constraints) - - def get_utxo_count(self, **constraints): - return self.db.get_utxo_count(**constraints) - - def get_transactions(self, **constraints): - return self.db.get_transactions(**constraints) - - def get_transaction_count(self, **constraints): - return self.db.get_transaction_count(**constraints) - - async def get_local_status_and_history(self, address, history=None): - if not history: - address_details = await self.db.get_address(address=address) - history = address_details['history'] or '' - parts = history.split(':')[:-1] - return ( - hexlify(sha256(history.encode())).decode() if history else None, - list(zip(parts[0::2], map(int, parts[1::2]))) - ) - - @staticmethod - def get_root_of_merkle_tree(branches, branch_positions, working_branch): - for i, branch in enumerate(branches): - other_branch = unhexlify(branch)[::-1] - other_branch_on_left = bool((branch_positions >> i) & 1) - if other_branch_on_left: - combined = other_branch + working_branch - else: - combined = working_branch + other_branch - working_branch = double_sha256(combined) - return hexlify(working_branch[::-1]) - - async def start(self): - if not os.path.exists(self.path): - os.mkdir(self.path) - await asyncio.wait([ - self.db.open(), - self.headers.open() - ]) - first_connection = self.network.on_connected.first - asyncio.ensure_future(self.network.start()) - await first_connection - async with self._header_processing_lock: - await self._update_tasks.add(self.initial_headers_sync()) - await self._on_ready_controller.stream.first - - async def join_network(self, *_): - log.info("Subscribing and updating accounts.") - async with self._header_processing_lock: - await self.update_headers() - await self.subscribe_accounts() - await self._update_tasks.done.wait() - self._on_ready_controller.add(True) - - async def stop(self): - self._update_tasks.cancel() - await self._update_tasks.done.wait() - await self.network.stop() - await self.db.close() - await self.headers.close() - - @property - def local_height_including_downloaded_height(self): - return max(self.headers.height, self._download_height) - - async def initial_headers_sync(self): - target = self.network.remote_height + 1 - current = len(self.headers) - get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True) - chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)] - total = 0 - async with self.headers.checkpointed_connector() as buffer: - for chunk in chunks: - headers = await chunk - total += buffer.write( - zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000) - ) - self._download_height = current + total // self.headers.header_size - log.info("Headers sync: %s / %s", self._download_height, target) - - async def update_headers(self, height=None, headers=None, subscription_update=False): - rewound = 0 - while True: - - if height is None or height > len(self.headers): - # sometimes header subscription updates are for a header in the future - # which can't be connected, so we do a normal header sync instead - height = len(self.headers) - headers = None - subscription_update = False - - if not headers: - header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) - headers = header_response['hex'] - - if not headers: - # Nothing to do, network thinks we're already at the latest height. - return - - added = await self.headers.connect(height, unhexlify(headers)) - if added > 0: - height += added - self._on_header_controller.add( - BlockHeightEvent(self.headers.height, added)) - - if rewound > 0: - # we started rewinding blocks and apparently found - # a new chain - rewound = 0 - await self.db.rewind_blockchain(height) - - if subscription_update: - # subscription updates are for latest header already - # so we don't need to check if there are newer / more - # on another loop of update_headers(), just return instead - return - - elif added == 0: - # we had headers to connect but none got connected, probably a reorganization - height -= 1 - rewound += 1 - log.warning( - "Blockchain Reorganization: attempting rewind to height %s from starting height %s", - height, height+rewound - ) - - else: - raise IndexError(f"headers.connect() returned negative number ({added})") - - if height < 0: - raise IndexError( - "Blockchain reorganization rewound all the way back to genesis hash. " - "Something is very wrong. Maybe you are on the wrong blockchain?" - ) - - if rewound >= 100: - raise IndexError( - "Blockchain reorganization dropped {} headers. This is highly unusual. " - "Will not continue to attempt reorganizing. Please, delete the ledger " - "synchronization directory inside your wallet directory (folder: '{}') and " - "restart the program to synchronize from scratch." - .format(rewound, self.get_id()) - ) - - headers = None # ready to download some more headers - - # if we made it this far and this was a subscription_update - # it means something went wrong and now we're doing a more - # robust sync, turn off subscription update shortcut - subscription_update = False - - async def receive_header(self, response): - async with self._header_processing_lock: - header = response[0] - await self.update_headers( - height=header['height'], headers=header['hex'], subscription_update=True - ) - - async def subscribe_accounts(self): - if self.network.is_connected and self.accounts: - await asyncio.wait([ - self.subscribe_account(a) for a in self.accounts - ]) - - async def subscribe_account(self, account: baseaccount.BaseAccount): - for address_manager in account.address_managers.values(): - await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) - await account.ensure_address_gap() - - async def unsubscribe_account(self, account: baseaccount.BaseAccount): - for address in await account.get_addresses(): - await self.network.unsubscribe_address(address) - - async def announce_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]): - await self.subscribe_addresses(address_manager, addresses) - await self._on_address_controller.add( - AddressesGeneratedEvent(address_manager, addresses) - ) - - async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]): - if self.network.is_connected and addresses: - await asyncio.wait([ - self.subscribe_address(address_manager, address) for address in addresses - ]) - - async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str): - remote_status = await self.network.subscribe_address(address) - self._update_tasks.add(self.update_history(address, remote_status, address_manager)) - - def process_status_update(self, update): - address, remote_status = update - self._update_tasks.add(self.update_history(address, remote_status)) - - async def update_history(self, address, remote_status, - address_manager: baseaccount.AddressManager = None): - - async with self._address_update_locks.setdefault(address, asyncio.Lock()): - self._known_addresses_out_of_sync.discard(address) - - local_status, local_history = await self.get_local_status_and_history(address) - - if local_status == remote_status: - return True - - remote_history = await self.network.retriable_call(self.network.get_history, address) - remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) - we_need = set(remote_history) - set(local_history) - if not we_need: - return True - - cache_tasks: List[asyncio.Future[BaseTransaction]] = [] - synced_history = StringIO() - for i, (txid, remote_height) in enumerate(remote_history): - if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks: - synced_history.write(f'{txid}:{remote_height}:') - else: - check_local = (txid, remote_height) not in we_need - cache_tasks.append(asyncio.ensure_future( - self.cache_transaction(txid, remote_height, check_local=check_local) - )) - - synced_txs = [] - for task in cache_tasks: - tx = await task - - check_db_for_txos = [] - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - continue - cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id) - if cache_item is not None: - if cache_item.tx is None: - await cache_item.has_tx.wait() - assert cache_item.tx is not None - txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref - else: - check_db_for_txos.append(txi.txo_ref.id) - - referenced_txos = {} if not check_db_for_txos else { - txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos, no_tx=True) - } - - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - continue - referenced_txo = referenced_txos.get(txi.txo_ref.id) - if referenced_txo is not None: - txi.txo_ref = referenced_txo.ref - - synced_history.write(f'{tx.id}:{tx.height}:') - synced_txs.append(tx) - - await self.db.save_transaction_io_batch( - synced_txs, address, self.address_to_hash160(address), synced_history.getvalue() - ) - await asyncio.wait([ - self._on_transaction_controller.add(TransactionEvent(address, tx)) - for tx in synced_txs - ]) - - if address_manager is None: - address_manager = await self.get_address_manager_for_address(address) - - if address_manager is not None: - await address_manager.ensure_address_gap() - - local_status, local_history = \ - await self.get_local_status_and_history(address, synced_history.getvalue()) - if local_status != remote_status: - if local_history == remote_history: - return True - log.warning( - "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items", - remote_status, len(remote_history), local_status, len(local_history) - ) - log.warning("local: %s", local_history) - log.warning("remote: %s", remote_history) - self._known_addresses_out_of_sync.add(address) - return False - else: - return True - - async def cache_transaction(self, txid, remote_height, check_local=True): - cache_item = self._tx_cache.get(txid) - if cache_item is None: - cache_item = self._tx_cache[txid] = TransactionCacheItem() - elif cache_item.tx is not None and \ - cache_item.tx.height >= remote_height and \ - (cache_item.tx.is_verified or remote_height < 1): - return cache_item.tx # cached tx is already up-to-date - - async with cache_item.lock: - - tx = cache_item.tx - - if tx is None and check_local: - # check local db - tx = cache_item.tx = await self.db.get_transaction(txid=txid) - - if tx is None: - # fetch from network - _raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height) - tx = self.transaction_class(unhexlify(_raw)) - cache_item.tx = tx # make sure it's saved before caching it - - await self.maybe_verify_transaction(tx, remote_height) - return tx - - async def maybe_verify_transaction(self, tx, remote_height): - tx.height = remote_height - if 0 < remote_height < len(self.headers): - merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) - merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) - header = self.headers[remote_height] - tx.position = merkle['pos'] - tx.is_verified = merkle_root == header['merkle_root'] - - async def get_address_manager_for_address(self, address) -> Optional[baseaccount.AddressManager]: - details = await self.db.get_address(address=address) - for account in self.accounts: - if account.id == details['account']: - return account.address_managers[details['chain']] - return None - - def broadcast(self, tx): - # broadcast can't be a retriable call yet - return self.network.broadcast(hexlify(tx.raw).decode()) - - async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=1): - addresses = set() - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - addresses.add( - self.hash160_to_address(txi.txo_ref.txo.pubkey_hash) - ) - for txo in tx.outputs: - if txo.has_address: - addresses.add(self.hash160_to_address(txo.pubkey_hash)) - records = await self.db.get_addresses(address__in=addresses) - _, pending = await asyncio.wait([ - self.on_transaction.where(partial( - lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id, - address_record['address'] - )) for address_record in records - ], timeout=timeout) - if pending: - for record in records: - found = False - _, local_history = await self.get_local_status_and_history(None, history=record['history']) - for txid, local_height in local_history: - if txid == tx.id and local_height >= height: - found = True - if not found: - print(record['history'], addresses, tx.id) - raise asyncio.TimeoutError('Timed out waiting for transaction.') diff --git a/lbry/wallet/client/basemanager.py b/lbry/wallet/client/basemanager.py deleted file mode 100644 index d4a4a6ef7..000000000 --- a/lbry/wallet/client/basemanager.py +++ /dev/null @@ -1,87 +0,0 @@ -import asyncio -import logging -from typing import Type, MutableSequence, MutableMapping, Optional - -from lbry.wallet.client.baseledger import BaseLedger, LedgerRegistry -from lbry.wallet.client.wallet import Wallet, WalletStorage - -log = logging.getLogger(__name__) - - -class BaseWalletManager: - - def __init__(self, wallets: MutableSequence[Wallet] = None, - ledgers: MutableMapping[Type[BaseLedger], BaseLedger] = None) -> None: - self.wallets = wallets or [] - self.ledgers = ledgers or {} - self.running = False - - @classmethod - def from_config(cls, config: dict) -> 'BaseWalletManager': - manager = cls() - for ledger_id, ledger_config in config.get('ledgers', {}).items(): - manager.get_or_create_ledger(ledger_id, ledger_config) - for wallet_path in config.get('wallets', []): - wallet_storage = WalletStorage(wallet_path) - wallet = Wallet.from_storage(wallet_storage, manager) - manager.wallets.append(wallet) - return manager - - def get_or_create_ledger(self, ledger_id, ledger_config=None): - ledger_class = LedgerRegistry.get_ledger_class(ledger_id) - ledger = self.ledgers.get(ledger_class) - if ledger is None: - ledger = ledger_class(ledger_config or {}) - self.ledgers[ledger_class] = ledger - return ledger - - def import_wallet(self, path): - storage = WalletStorage(path) - wallet = Wallet.from_storage(storage, self) - self.wallets.append(wallet) - return wallet - - @property - def default_wallet(self): - for wallet in self.wallets: - return wallet - - @property - def default_account(self): - for wallet in self.wallets: - return wallet.default_account - - @property - def accounts(self): - for wallet in self.wallets: - yield from wallet.accounts - - async def start(self): - self.running = True - await asyncio.gather(*( - l.start() for l in self.ledgers.values() - )) - - async def stop(self): - await asyncio.gather(*( - l.stop() for l in self.ledgers.values() - )) - self.running = False - - def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet: - if wallet_id is None: - return self.default_wallet - return self.get_wallet_or_error(wallet_id) - - def get_wallet_or_error(self, wallet_id: str) -> Wallet: - for wallet in self.wallets: - if wallet.id == wallet_id: - return wallet - raise ValueError(f"Couldn't find wallet: {wallet_id}.") - - @staticmethod - def get_balance(wallet): - accounts = wallet.accounts - if not accounts: - return 0 - return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts) diff --git a/lbry/wallet/client/basenetwork.py b/lbry/wallet/client/basenetwork.py deleted file mode 100644 index 7fa4d0251..000000000 --- a/lbry/wallet/client/basenetwork.py +++ /dev/null @@ -1,364 +0,0 @@ -import logging -import asyncio -from operator import itemgetter -from typing import Dict, Optional, Tuple -from time import perf_counter - -import lbry -from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError -from lbry.wallet.stream import StreamController - -log = logging.getLogger(__name__) - - -class ClientSession(BaseClientSession): - def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): - self.network = network - self.server = server - super().__init__(*args, **kwargs) - self._on_disconnect_controller = StreamController() - self.on_disconnected = self._on_disconnect_controller.stream - self.framer.max_size = self.max_errors = 1 << 32 - self.bw_limit = -1 - self.timeout = timeout - self.max_seconds_idle = timeout * 2 - self.response_time: Optional[float] = None - self.connection_latency: Optional[float] = None - self._response_samples = 0 - self.pending_amount = 0 - self._on_connect_cb = on_connect_callback or (lambda: None) - self.trigger_urgent_reconnect = asyncio.Event() - - @property - def available(self): - return not self.is_closing() and self.response_time is not None - - @property - def server_address_and_port(self) -> Optional[Tuple[str, int]]: - if not self.transport: - return None - return self.transport.get_extra_info('peername') - - async def send_timed_server_version_request(self, args=(), timeout=None): - timeout = timeout or self.timeout - log.debug("send version request to %s:%i", *self.server) - start = perf_counter() - result = await asyncio.wait_for( - super().send_request('server.version', args), timeout=timeout - ) - current_response_time = perf_counter() - start - response_sum = (self.response_time or 0) * self._response_samples + current_response_time - self.response_time = response_sum / (self._response_samples + 1) - self._response_samples += 1 - return result - - async def send_request(self, method, args=()): - self.pending_amount += 1 - log.debug("send %s to %s:%i", method, *self.server) - try: - if method == 'server.version': - return await self.send_timed_server_version_request(args, self.timeout) - request = asyncio.ensure_future(super().send_request(method, args)) - while not request.done(): - done, pending = await asyncio.wait([request], timeout=self.timeout) - if pending: - log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received) - if (perf_counter() - self.last_packet_received) < self.timeout: - continue - log.info("timeout sending %s to %s:%i", method, *self.server) - raise asyncio.TimeoutError - if done: - return request.result() - except (RPCError, ProtocolError) as e: - log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", - *self.server, *e.args) - raise e - except ConnectionError: - log.warning("connection to %s:%i lost", *self.server) - self.synchronous_close() - raise - except asyncio.CancelledError: - log.info("cancelled sending %s to %s:%i", method, *self.server) - self.synchronous_close() - raise - finally: - self.pending_amount -= 1 - - async def ensure_session(self): - # Handles reconnecting and maintaining a session alive - # TODO: change to 'ping' on newer protocol (above 1.2) - retry_delay = default_delay = 1.0 - while True: - try: - if self.is_closing(): - await self.create_connection(self.timeout) - await self.ensure_server_version() - self._on_connect_cb() - if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None: - await self.ensure_server_version() - retry_delay = default_delay - except RPCError as e: - log.warning("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message) - retry_delay = 60 * 60 - except (asyncio.TimeoutError, OSError): - await self.close() - retry_delay = min(60, retry_delay * 2) - log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) - try: - await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) - except asyncio.TimeoutError: - pass - finally: - self.trigger_urgent_reconnect.clear() - - async def ensure_server_version(self, required=None, timeout=3): - required = required or self.network.PROTOCOL_VERSION - return await asyncio.wait_for( - self.send_request('server.version', [lbry.__version__, required]), timeout=timeout - ) - - async def create_connection(self, timeout=6): - connector = Connector(lambda: self, *self.server) - start = perf_counter() - await asyncio.wait_for(connector.create_connection(), timeout=timeout) - self.connection_latency = perf_counter() - start - - async def handle_request(self, request): - controller = self.network.subscription_controllers[request.method] - controller.add(request.args) - - def connection_lost(self, exc): - log.debug("Connection lost: %s:%d", *self.server) - super().connection_lost(exc) - self.response_time = None - self.connection_latency = None - self._response_samples = 0 - self.pending_amount = 0 - self._on_disconnect_controller.add(True) - - -class BaseNetwork: - PROTOCOL_VERSION = '1.2' - - def __init__(self, ledger): - self.ledger = ledger - self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) - self.client: Optional[ClientSession] = None - self._switch_task: Optional[asyncio.Task] = None - self.running = False - self.remote_height: int = 0 - self._concurrency = asyncio.Semaphore(16) - - self._on_connected_controller = StreamController() - self.on_connected = self._on_connected_controller.stream - - self._on_header_controller = StreamController(merge_repeated_events=True) - self.on_header = self._on_header_controller.stream - - self._on_status_controller = StreamController(merge_repeated_events=True) - self.on_status = self._on_status_controller.stream - - self.subscription_controllers = { - 'blockchain.headers.subscribe': self._on_header_controller, - 'blockchain.address.subscribe': self._on_status_controller, - } - - @property - def config(self): - return self.ledger.config - - async def switch_forever(self): - while self.running: - if self.is_connected: - await self.client.on_disconnected.first - self.client = None - continue - self.client = await self.session_pool.wait_for_fastest_session() - log.info("Switching to SPV wallet server: %s:%d", *self.client.server) - try: - self._update_remote_height((await self.subscribe_headers(),)) - self._on_connected_controller.add(True) - log.info("Subscribed to headers: %s:%d", *self.client.server) - except (asyncio.TimeoutError, ConnectionError): - log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server) - self.client.synchronous_close() - self.client = None - - async def start(self): - self.running = True - self._switch_task = asyncio.ensure_future(self.switch_forever()) - # this may become unnecessary when there are no more bugs found, - # but for now it helps understanding log reports - self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) - self.session_pool.start(self.config['default_servers']) - self.on_header.listen(self._update_remote_height) - - async def stop(self): - if self.running: - self.running = False - self._switch_task.cancel() - self.session_pool.stop() - - @property - def is_connected(self): - return self.client and not self.client.is_closing() - - def rpc(self, list_or_method, args, restricted=True): - session = self.client if restricted else self.session_pool.fastest_session - if session and not session.is_closing(): - return session.send_request(list_or_method, args) - else: - self.session_pool.trigger_nodelay_connect() - raise ConnectionError("Attempting to send rpc request when connection is not available.") - - async def retriable_call(self, function, *args, **kwargs): - async with self._concurrency: - while self.running: - if not self.is_connected: - log.warning("Wallet server unavailable, waiting for it to come back and retry.") - await self.on_connected.first - await self.session_pool.wait_for_fastest_session() - try: - return await function(*args, **kwargs) - except asyncio.TimeoutError: - log.warning("Wallet server call timed out, retrying.") - except ConnectionError: - pass - raise asyncio.CancelledError() # if we got here, we are shutting down - - def _update_remote_height(self, header_args): - self.remote_height = header_args[0]["height"] - - def get_transaction(self, tx_hash, known_height=None): - # use any server if its old, otherwise restrict to who gave us the history - restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get', [tx_hash], restricted) - - def get_transaction_height(self, tx_hash, known_height=None): - restricted = not known_height or 0 > known_height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted) - - def get_merkle(self, tx_hash, height): - restricted = 0 > height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted) - - def get_headers(self, height, count=10000, b64=False): - restricted = height >= self.remote_height - 100 - return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted) - - # --- Subscribes, history and broadcasts are always aimed towards the master client directly - def get_history(self, address): - return self.rpc('blockchain.address.get_history', [address], True) - - def broadcast(self, raw_transaction): - return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True) - - def subscribe_headers(self): - return self.rpc('blockchain.headers.subscribe', [True], True) - - async def subscribe_address(self, address): - try: - return await self.rpc('blockchain.address.subscribe', [address], True) - except asyncio.TimeoutError: - # abort and cancel, we can't lose a subscription, it will happen again on reconnect - if self.client: - self.client.abort() - raise asyncio.CancelledError() - - def unsubscribe_address(self, address): - return self.rpc('blockchain.address.unsubscribe', [address], True) - - def get_server_features(self): - return self.rpc('server.features', (), restricted=True) - - -class SessionPool: - - def __init__(self, network: BaseNetwork, timeout: float): - self.network = network - self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() - self.timeout = timeout - self.new_connection_event = asyncio.Event() - - @property - def online(self): - return any(not session.is_closing() for session in self.sessions) - - @property - def available_sessions(self): - return (session for session in self.sessions if session.available) - - @property - def fastest_session(self): - if not self.online: - return None - return min( - [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session) - for session in self.available_sessions] or [(0, None)], - key=itemgetter(0) - )[1] - - def _get_session_connect_callback(self, session: ClientSession): - loop = asyncio.get_event_loop() - - def callback(): - duplicate_connections = [ - s for s in self.sessions - if s is not session and s.server_address_and_port == session.server_address_and_port - ] - already_connected = None if not duplicate_connections else duplicate_connections[0] - if already_connected: - self.sessions.pop(session).cancel() - session.synchronous_close() - log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour", - session.server[0], already_connected.server[0]) - loop.call_later(3600, self._connect_session, session.server) - return - self.new_connection_event.set() - log.info("connected to %s:%i", *session.server) - - return callback - - def _connect_session(self, server: Tuple[str, int]): - session = None - for s in self.sessions: - if s.server == server: - session = s - break - if not session: - session = ClientSession( - network=self.network, server=server - ) - session._on_connect_cb = self._get_session_connect_callback(session) - task = self.sessions.get(session, None) - if not task or task.done(): - task = asyncio.create_task(session.ensure_session()) - task.add_done_callback(lambda _: self.ensure_connections()) - self.sessions[session] = task - - def start(self, default_servers): - for server in default_servers: - self._connect_session(server) - - def stop(self): - for session, task in self.sessions.items(): - task.cancel() - session.synchronous_close() - self.sessions.clear() - - def ensure_connections(self): - for session in self.sessions: - self._connect_session(session.server) - - def trigger_nodelay_connect(self): - # used when other parts of the system sees we might have internet back - # bypasses the retry interval - for session in self.sessions: - session.trigger_urgent_reconnect.set() - - async def wait_for_fastest_session(self): - while not self.fastest_session: - self.trigger_nodelay_connect() - self.new_connection_event.clear() - await self.new_connection_event.wait() - return self.fastest_session diff --git a/lbry/wallet/client/basescript.py b/lbry/wallet/client/basescript.py deleted file mode 100644 index 6299bd909..000000000 --- a/lbry/wallet/client/basescript.py +++ /dev/null @@ -1,450 +0,0 @@ -from itertools import chain -from binascii import hexlify -from collections import namedtuple -from typing import List - -from lbry.wallet.client.bcd_data_stream import BCDataStream -from lbry.wallet.client.util import subclass_tuple - -# bitcoin opcodes -OP_0 = 0x00 -OP_1 = 0x51 -OP_16 = 0x60 -OP_VERIFY = 0x69 -OP_DUP = 0x76 -OP_HASH160 = 0xa9 -OP_EQUALVERIFY = 0x88 -OP_CHECKSIG = 0xac -OP_CHECKMULTISIG = 0xae -OP_EQUAL = 0x87 -OP_PUSHDATA1 = 0x4c -OP_PUSHDATA2 = 0x4d -OP_PUSHDATA4 = 0x4e -OP_RETURN = 0x6a -OP_2DROP = 0x6d -OP_DROP = 0x75 - - -# template matching opcodes (not real opcodes) -# base class for PUSH_DATA related opcodes -# pylint: disable=invalid-name -PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name') -# opcode for variable length strings -# pylint: disable=invalid-name -PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP) -# opcode for variable size integers -# pylint: disable=invalid-name -PUSH_INTEGER = subclass_tuple('PUSH_INTEGER', PUSH_DATA_OP) -# opcode for variable number of variable length strings -# pylint: disable=invalid-name -PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP) -# opcode with embedded subscript parsing -# pylint: disable=invalid-name -PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template') - - -def is_push_data_opcode(opcode): - return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT)) - - -def is_push_data_token(token): - return 1 <= token <= OP_PUSHDATA4 - - -def push_data(data): - size = len(data) - if size < OP_PUSHDATA1: - yield BCDataStream.uint8.pack(size) - elif size <= 0xFF: - yield BCDataStream.uint8.pack(OP_PUSHDATA1) - yield BCDataStream.uint8.pack(size) - elif size <= 0xFFFF: - yield BCDataStream.uint8.pack(OP_PUSHDATA2) - yield BCDataStream.uint16.pack(size) - else: - yield BCDataStream.uint8.pack(OP_PUSHDATA4) - yield BCDataStream.uint32.pack(size) - yield bytes(data) - - -def read_data(token, stream): - if token < OP_PUSHDATA1: - return stream.read(token) - if token == OP_PUSHDATA1: - return stream.read(stream.read_uint8()) - if token == OP_PUSHDATA2: - return stream.read(stream.read_uint16()) - return stream.read(stream.read_uint32()) - - -# opcode for OP_1 - OP_16 -# pylint: disable=invalid-name -SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name') - - -def is_small_integer(token): - return OP_1 <= token <= OP_16 - - -def push_small_integer(num): - assert 1 <= num <= 16 - yield BCDataStream.uint8.pack(OP_1 + (num - 1)) - - -def read_small_integer(token): - return (token - OP_1) + 1 - - -class Token(namedtuple('Token', 'value')): - __slots__ = () - - def __repr__(self): - name = None - for var_name, var_value in globals().items(): - if var_name.startswith('OP_') and var_value == self.value: - name = var_name - break - return name or self.value - - -class DataToken(Token): - __slots__ = () - - def __repr__(self): - return f'"{hexlify(self.value)}"' - - -class SmallIntegerToken(Token): - __slots__ = () - - def __repr__(self): - return f'SmallIntegerToken({self.value})' - - -def token_producer(source): - token = source.read_uint8() - while token is not None: - if is_push_data_token(token): - yield DataToken(read_data(token, source)) - elif is_small_integer(token): - yield SmallIntegerToken(read_small_integer(token)) - else: - yield Token(token) - token = source.read_uint8() - - -def tokenize(source): - return list(token_producer(source)) - - -class ScriptError(Exception): - """ General script handling error. """ - - -class ParseError(ScriptError): - """ Script parsing error. """ - - -class Parser: - - def __init__(self, opcodes, tokens): - self.opcodes = opcodes - self.tokens = tokens - self.values = {} - self.token_index = 0 - self.opcode_index = 0 - - def parse(self): - while self.token_index < len(self.tokens) and self.opcode_index < len(self.opcodes): - token = self.tokens[self.token_index] - opcode = self.opcodes[self.opcode_index] - if token.value == 0 and isinstance(opcode, PUSH_SINGLE): - token = DataToken(b'') - if isinstance(token, DataToken): - if isinstance(opcode, (PUSH_SINGLE, PUSH_INTEGER, PUSH_SUBSCRIPT)): - self.push_single(opcode, token.value) - elif isinstance(opcode, PUSH_MANY): - self.consume_many_non_greedy() - else: - raise ParseError(f"DataToken found but opcode was '{opcode}'.") - elif isinstance(token, SmallIntegerToken): - if isinstance(opcode, SMALL_INTEGER): - self.values[opcode.name] = token.value - else: - raise ParseError(f"SmallIntegerToken found but opcode was '{opcode}'.") - elif token.value == opcode: - pass - else: - raise ParseError(f"Token is '{token.value}' and opcode is '{opcode}'.") - self.token_index += 1 - self.opcode_index += 1 - - if self.token_index < len(self.tokens): - raise ParseError("Parse completed without all tokens being consumed.") - - if self.opcode_index < len(self.opcodes): - raise ParseError("Parse completed without all opcodes being consumed.") - - return self - - def consume_many_non_greedy(self): - """ Allows PUSH_MANY to consume data without being greedy - in cases when one or more PUSH_SINGLEs follow a PUSH_MANY. This will - prioritize giving all PUSH_SINGLEs some data and only after that - subsume the rest into PUSH_MANY. - """ - - token_values = [] - while self.token_index < len(self.tokens): - token = self.tokens[self.token_index] - if not isinstance(token, DataToken): - self.token_index -= 1 - break - token_values.append(token.value) - self.token_index += 1 - - push_opcodes = [] - push_many_count = 0 - while self.opcode_index < len(self.opcodes): - opcode = self.opcodes[self.opcode_index] - if not is_push_data_opcode(opcode): - self.opcode_index -= 1 - break - if isinstance(opcode, PUSH_MANY): - push_many_count += 1 - push_opcodes.append(opcode) - self.opcode_index += 1 - - if push_many_count > 1: - raise ParseError( - "Cannot have more than one consecutive PUSH_MANY, as there is no way to tell which" - " token value should go into which PUSH_MANY." - ) - - if len(push_opcodes) > len(token_values): - raise ParseError( - "Not enough token values to match all of the PUSH_MANY and PUSH_SINGLE opcodes." - ) - - many_opcode = push_opcodes.pop(0) - - # consume data into PUSH_SINGLE opcodes, working backwards - for opcode in reversed(push_opcodes): - self.push_single(opcode, token_values.pop()) - - # finally PUSH_MANY gets everything that's left - self.values[many_opcode.name] = token_values - - def push_single(self, opcode, value): - if isinstance(opcode, PUSH_SINGLE): - self.values[opcode.name] = value - elif isinstance(opcode, PUSH_INTEGER): - self.values[opcode.name] = int.from_bytes(value, 'little') - elif isinstance(opcode, PUSH_SUBSCRIPT): - self.values[opcode.name] = Script.from_source_with_template(value, opcode.template) - else: - raise ParseError(f"Not a push single or subscript: {opcode}") - - -class Template: - - __slots__ = 'name', 'opcodes' - - def __init__(self, name, opcodes): - self.name = name - self.opcodes = opcodes - - def parse(self, tokens): - return Parser(self.opcodes, tokens).parse().values if self.opcodes else {} - - def generate(self, values): - source = BCDataStream() - for opcode in self.opcodes: - if isinstance(opcode, PUSH_SINGLE): - data = values[opcode.name] - source.write_many(push_data(data)) - elif isinstance(opcode, PUSH_INTEGER): - data = values[opcode.name] - source.write_many(push_data( - data.to_bytes((data.bit_length() + 7) // 8, byteorder='little') - )) - elif isinstance(opcode, PUSH_SUBSCRIPT): - data = values[opcode.name] - source.write_many(push_data(data.source)) - elif isinstance(opcode, PUSH_MANY): - for data in values[opcode.name]: - source.write_many(push_data(data)) - elif isinstance(opcode, SMALL_INTEGER): - data = values[opcode.name] - source.write_many(push_small_integer(data)) - else: - source.write_uint8(opcode) - return source.get_bytes() - - -class Script: - - __slots__ = 'source', '_template', '_values', '_template_hint' - - templates: List[Template] = [] - - NO_SCRIPT = Template('no_script', None) # special case - - def __init__(self, source=None, template=None, values=None, template_hint=None): - self.source = source - self._template = template - self._values = values - self._template_hint = template_hint - if source is None and template and values: - self.generate() - - @property - def template(self): - if self._template is None: - self.parse(self._template_hint) - return self._template - - @property - def values(self): - if self._values is None: - self.parse(self._template_hint) - return self._values - - @property - def tokens(self): - return tokenize(BCDataStream(self.source)) - - @classmethod - def from_source_with_template(cls, source, template): - return cls(source, template_hint=template) - - def parse(self, template_hint=None): - tokens = self.tokens - if not tokens and not template_hint: - template_hint = self.NO_SCRIPT - for template in chain((template_hint,), self.templates): - if not template: - continue - try: - self._values = template.parse(tokens) - self._template = template - return - except ParseError: - continue - raise ValueError(f'No matching templates for source: {hexlify(self.source)}') - - def generate(self): - self.source = self.template.generate(self._values) - - -class BaseInputScript(Script): - """ Input / redeem script templates (aka scriptSig) """ - - __slots__ = () - - REDEEM_PUBKEY = Template('pubkey', ( - PUSH_SINGLE('signature'), - )) - REDEEM_PUBKEY_HASH = Template('pubkey_hash', ( - PUSH_SINGLE('signature'), PUSH_SINGLE('pubkey') - )) - REDEEM_SCRIPT = Template('script', ( - SMALL_INTEGER('signatures_count'), PUSH_MANY('pubkeys'), SMALL_INTEGER('pubkeys_count'), - OP_CHECKMULTISIG - )) - REDEEM_SCRIPT_HASH = Template('script_hash', ( - OP_0, PUSH_MANY('signatures'), PUSH_SUBSCRIPT('script', REDEEM_SCRIPT) - )) - - templates = [ - REDEEM_PUBKEY, - REDEEM_PUBKEY_HASH, - REDEEM_SCRIPT_HASH, - REDEEM_SCRIPT - ] - - @classmethod - def redeem_pubkey_hash(cls, signature, pubkey): - return cls(template=cls.REDEEM_PUBKEY_HASH, values={ - 'signature': signature, - 'pubkey': pubkey - }) - - @classmethod - def redeem_script_hash(cls, signatures, pubkeys): - return cls(template=cls.REDEEM_SCRIPT_HASH, values={ - 'signatures': signatures, - 'script': cls.redeem_script(signatures, pubkeys) - }) - - @classmethod - def redeem_script(cls, signatures, pubkeys): - return cls(template=cls.REDEEM_SCRIPT, values={ - 'signatures_count': len(signatures), - 'pubkeys': pubkeys, - 'pubkeys_count': len(pubkeys) - }) - - -class BaseOutputScript(Script): - - __slots__ = () - - # output / payment script templates (aka scriptPubKey) - PAY_PUBKEY_FULL = Template('pay_pubkey_full', ( - PUSH_SINGLE('pubkey'), OP_CHECKSIG - )) - PAY_PUBKEY_HASH = Template('pay_pubkey_hash', ( - OP_DUP, OP_HASH160, PUSH_SINGLE('pubkey_hash'), OP_EQUALVERIFY, OP_CHECKSIG - )) - PAY_SCRIPT_HASH = Template('pay_script_hash', ( - OP_HASH160, PUSH_SINGLE('script_hash'), OP_EQUAL - )) - PAY_SEGWIT = Template('pay_script_hash+segwit', ( - OP_0, PUSH_SINGLE('script_hash') - )) - RETURN_DATA = Template('return_data', ( - OP_RETURN, PUSH_SINGLE('data') - )) - - templates = [ - PAY_PUBKEY_FULL, - PAY_PUBKEY_HASH, - PAY_SCRIPT_HASH, - PAY_SEGWIT, - RETURN_DATA - ] - - @classmethod - def pay_pubkey_hash(cls, pubkey_hash): - return cls(template=cls.PAY_PUBKEY_HASH, values={ - 'pubkey_hash': pubkey_hash - }) - - @classmethod - def pay_script_hash(cls, script_hash): - return cls(template=cls.PAY_SCRIPT_HASH, values={ - 'script_hash': script_hash - }) - - @classmethod - def return_data(cls, data): - return cls(template=cls.RETURN_DATA, values={ - 'data': data - }) - - @property - def is_pay_pubkey(self): - return self.template.name.endswith('pay_pubkey_full') - - @property - def is_pay_pubkey_hash(self): - return self.template.name.endswith('pay_pubkey_hash') - - @property - def is_pay_script_hash(self): - return self.template.name.endswith('pay_script_hash') - - @property - def is_return_data(self): - return self.template.name.endswith('return_data') diff --git a/lbry/wallet/client/basetransaction.py b/lbry/wallet/client/basetransaction.py deleted file mode 100644 index 8429d641a..000000000 --- a/lbry/wallet/client/basetransaction.py +++ /dev/null @@ -1,580 +0,0 @@ -import logging -import typing -from typing import List, Iterable, Optional, Tuple -from binascii import hexlify - -from lbry.crypto.hash import sha256 -from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript -from lbry.wallet.client.baseaccount import BaseAccount -from lbry.wallet.client.constants import COIN, NULL_HASH32 -from lbry.wallet.client.bcd_data_stream import BCDataStream -from lbry.wallet.client.hash import TXRef, TXRefImmutable -from lbry.wallet.client.util import ReadOnlyList -from lbry.error import InsufficientFundsError - -if typing.TYPE_CHECKING: - from lbry.wallet.client import baseledger, wallet as basewallet - -log = logging.getLogger() - - -class TXRefMutable(TXRef): - - __slots__ = ('tx',) - - def __init__(self, tx: 'BaseTransaction') -> None: - super().__init__() - self.tx = tx - - @property - def id(self): - if self._id is None: - self._id = hexlify(self.hash[::-1]).decode() - return self._id - - @property - def hash(self): - if self._hash is None: - self._hash = sha256(sha256(self.tx.raw_sans_segwit)) - return self._hash - - @property - def height(self): - return self.tx.height - - def reset(self): - self._id = None - self._hash = None - - -class TXORef: - - __slots__ = 'tx_ref', 'position' - - def __init__(self, tx_ref: TXRef, position: int) -> None: - self.tx_ref = tx_ref - self.position = position - - @property - def id(self): - return f'{self.tx_ref.id}:{self.position}' - - @property - def hash(self): - return self.tx_ref.hash + BCDataStream.uint32.pack(self.position) - - @property - def is_null(self): - return self.tx_ref.is_null - - @property - def txo(self) -> Optional['BaseOutput']: - return None - - -class TXORefResolvable(TXORef): - - __slots__ = ('_txo',) - - def __init__(self, txo: 'BaseOutput') -> None: - assert txo.tx_ref is not None - assert txo.position is not None - super().__init__(txo.tx_ref, txo.position) - self._txo = txo - - @property - def txo(self): - return self._txo - - -class InputOutput: - - __slots__ = 'tx_ref', 'position' - - def __init__(self, tx_ref: TXRef = None, position: int = None) -> None: - self.tx_ref = tx_ref - self.position = position - - @property - def size(self) -> int: - """ Size of this input / output in bytes. """ - stream = BCDataStream() - self.serialize_to(stream) - return len(stream.get_bytes()) - - def get_fee(self, ledger): - return self.size * ledger.fee_per_byte - - def serialize_to(self, stream, alternate_script=None): - raise NotImplementedError - - -class BaseInput(InputOutput): - - script_class = BaseInputScript - - NULL_SIGNATURE = b'\x00'*72 - NULL_PUBLIC_KEY = b'\x00'*33 - - __slots__ = 'txo_ref', 'sequence', 'coinbase', 'script' - - def __init__(self, txo_ref: TXORef, script: BaseInputScript, sequence: int = 0xFFFFFFFF, - tx_ref: TXRef = None, position: int = None) -> None: - super().__init__(tx_ref, position) - self.txo_ref = txo_ref - self.sequence = sequence - self.coinbase = script if txo_ref.is_null else None - self.script = script if not txo_ref.is_null else None - - @property - def is_coinbase(self): - return self.coinbase is not None - - @classmethod - def spend(cls, txo: 'BaseOutput') -> 'BaseInput': - """ Create an input to spend the output.""" - assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.' - script = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY) - return cls(txo.ref, script) - - @property - def amount(self) -> int: - """ Amount this input adds to the transaction. """ - if self.txo_ref.txo is None: - raise ValueError('Cannot resolve output to get amount.') - return self.txo_ref.txo.amount - - @property - def is_my_account(self) -> Optional[bool]: - """ True if the output this input spends is yours. """ - if self.txo_ref.txo is None: - return False - return self.txo_ref.txo.is_my_account - - @classmethod - def deserialize_from(cls, stream): - tx_ref = TXRefImmutable.from_hash(stream.read(32), -1) - position = stream.read_uint32() - script = stream.read_string() - sequence = stream.read_uint32() - return cls( - TXORef(tx_ref, position), - cls.script_class(script) if not tx_ref.is_null else script, - sequence - ) - - def serialize_to(self, stream, alternate_script=None): - stream.write(self.txo_ref.tx_ref.hash) - stream.write_uint32(self.txo_ref.position) - if alternate_script is not None: - stream.write_string(alternate_script) - else: - if self.is_coinbase: - stream.write_string(self.coinbase) - else: - stream.write_string(self.script.source) - stream.write_uint32(self.sequence) - - -class BaseOutputEffectiveAmountEstimator: - - __slots__ = 'txo', 'txi', 'fee', 'effective_amount' - - def __init__(self, ledger: 'baseledger.BaseLedger', txo: 'BaseOutput') -> None: - self.txo = txo - self.txi = ledger.transaction_class.input_class.spend(txo) - self.fee: int = self.txi.get_fee(ledger) - self.effective_amount: int = txo.amount - self.fee - - def __lt__(self, other): - return self.effective_amount < other.effective_amount - - -class BaseOutput(InputOutput): - - script_class = BaseOutputScript - estimator_class = BaseOutputEffectiveAmountEstimator - - __slots__ = 'amount', 'script', 'is_change', 'is_my_account' - - def __init__(self, amount: int, script: BaseOutputScript, - tx_ref: TXRef = None, position: int = None, - is_change: Optional[bool] = None, is_my_account: Optional[bool] = None - ) -> None: - super().__init__(tx_ref, position) - self.amount = amount - self.script = script - self.is_change = is_change - self.is_my_account = is_my_account - - def update_annotations(self, annotated): - if annotated is None: - self.is_change = False - self.is_my_account = False - else: - self.is_change = annotated.is_change - self.is_my_account = annotated.is_my_account - - @property - def ref(self): - return TXORefResolvable(self) - - @property - def id(self): - return self.ref.id - - @property - def pubkey_hash(self): - return self.script.values['pubkey_hash'] - - @property - def has_address(self): - return 'pubkey_hash' in self.script.values - - def get_address(self, ledger): - return ledger.hash160_to_address(self.pubkey_hash) - - def get_estimator(self, ledger): - return self.estimator_class(ledger, self) - - @classmethod - def pay_pubkey_hash(cls, amount, pubkey_hash): - return cls(amount, cls.script_class.pay_pubkey_hash(pubkey_hash)) - - @classmethod - def deserialize_from(cls, stream): - return cls( - amount=stream.read_uint64(), - script=cls.script_class(stream.read_string()) - ) - - def serialize_to(self, stream, alternate_script=None): - stream.write_uint64(self.amount) - stream.write_string(self.script.source) - - -class BaseTransaction: - - input_class = BaseInput - output_class = BaseOutput - - def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False, - height: int = -2, position: int = -1) -> None: - self._raw = raw - self._raw_sans_segwit = None - self.is_segwit_flag = 0 - self.witnesses: List[bytes] = [] - self.ref = TXRefMutable(self) - self.version = version - self.locktime = locktime - self._inputs: List[BaseInput] = [] - self._outputs: List[BaseOutput] = [] - self.is_verified = is_verified - # Height Progression - # -2: not broadcast - # -1: in mempool but has unconfirmed inputs - # 0: in mempool and all inputs confirmed - # +num: confirmed in a specific block (height) - self.height = height - self.position = position - if raw is not None: - self._deserialize() - - @property - def is_broadcast(self): - return self.height > -2 - - @property - def is_mempool(self): - return self.height in (-1, 0) - - @property - def is_confirmed(self): - return self.height > 0 - - @property - def id(self): - return self.ref.id - - @property - def hash(self): - return self.ref.hash - - @property - def raw(self): - if self._raw is None: - self._raw = self._serialize() - return self._raw - - @property - def raw_sans_segwit(self): - if self.is_segwit_flag: - if self._raw_sans_segwit is None: - self._raw_sans_segwit = self._serialize(sans_segwit=True) - return self._raw_sans_segwit - return self.raw - - def _reset(self): - self._raw = None - self._raw_sans_segwit = None - self.ref.reset() - - @property - def inputs(self) -> ReadOnlyList[BaseInput]: - return ReadOnlyList(self._inputs) - - @property - def outputs(self) -> ReadOnlyList[BaseOutput]: - return ReadOnlyList(self._outputs) - - def _add(self, existing_ios: List, new_ios: Iterable[InputOutput], reset=False) -> 'BaseTransaction': - for txio in new_ios: - txio.tx_ref = self.ref - txio.position = len(existing_ios) - existing_ios.append(txio) - if reset: - self._reset() - return self - - def add_inputs(self, inputs: Iterable[BaseInput]) -> 'BaseTransaction': - return self._add(self._inputs, inputs, True) - - def add_outputs(self, outputs: Iterable[BaseOutput]) -> 'BaseTransaction': - return self._add(self._outputs, outputs, True) - - @property - def size(self) -> int: - """ Size in bytes of the entire transaction. """ - return len(self.raw) - - @property - def base_size(self) -> int: - """ Size of transaction without inputs or outputs in bytes. """ - return ( - self.size - - sum(txi.size for txi in self._inputs) - - sum(txo.size for txo in self._outputs) - ) - - @property - def input_sum(self): - return sum(i.amount for i in self.inputs if i.txo_ref.txo is not None) - - @property - def output_sum(self): - return sum(o.amount for o in self.outputs) - - @property - def net_account_balance(self) -> int: - balance = 0 - for txi in self.inputs: - if txi.txo_ref.txo is None: - continue - if txi.is_my_account is None: - raise ValueError( - "Cannot access net_account_balance if inputs/outputs do not " - "have is_my_account set properly." - ) - if txi.is_my_account: - balance -= txi.amount - for txo in self.outputs: - if txo.is_my_account is None: - raise ValueError( - "Cannot access net_account_balance if inputs/outputs do not " - "have is_my_account set properly." - ) - if txo.is_my_account: - balance += txo.amount - return balance - - @property - def fee(self) -> int: - return self.input_sum - self.output_sum - - def get_base_fee(self, ledger) -> int: - """ Fee for base tx excluding inputs and outputs. """ - return self.base_size * ledger.fee_per_byte - - def get_effective_input_sum(self, ledger) -> int: - """ Sum of input values *minus* the cost involved to spend them. """ - return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs) - - def get_total_output_sum(self, ledger) -> int: - """ Sum of output values *plus* the cost involved to spend them. """ - return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs) - - def _serialize(self, with_inputs: bool = True, sans_segwit: bool = False) -> bytes: - stream = BCDataStream() - stream.write_uint32(self.version) - if with_inputs: - stream.write_compact_size(len(self._inputs)) - for txin in self._inputs: - txin.serialize_to(stream) - stream.write_compact_size(len(self._outputs)) - for txout in self._outputs: - txout.serialize_to(stream) - stream.write_uint32(self.locktime) - return stream.get_bytes() - - def _serialize_for_signature(self, signing_input: int) -> bytes: - stream = BCDataStream() - stream.write_uint32(self.version) - stream.write_compact_size(len(self._inputs)) - for i, txin in enumerate(self._inputs): - if signing_input == i: - assert txin.txo_ref.txo is not None - txin.serialize_to(stream, txin.txo_ref.txo.script.source) - else: - txin.serialize_to(stream, b'') - stream.write_compact_size(len(self._outputs)) - for txout in self._outputs: - txout.serialize_to(stream) - stream.write_uint32(self.locktime) - stream.write_uint32(self.signature_hash_type(1)) # signature hash type: SIGHASH_ALL - return stream.get_bytes() - - def _deserialize(self): - if self._raw is not None: - stream = BCDataStream(self._raw) - self.version = stream.read_uint32() - input_count = stream.read_compact_size() - if input_count == 0: - self.is_segwit_flag = stream.read_uint8() - input_count = stream.read_compact_size() - self._add(self._inputs, [ - self.input_class.deserialize_from(stream) for _ in range(input_count) - ]) - output_count = stream.read_compact_size() - self._add(self._outputs, [ - self.output_class.deserialize_from(stream) for _ in range(output_count) - ]) - if self.is_segwit_flag: - # drain witness portion of transaction - # too many witnesses for no crime - self.witnesses = [] - for _ in range(input_count): - for _ in range(stream.read_compact_size()): - self.witnesses.append(stream.read(stream.read_compact_size())) - self.locktime = stream.read_uint32() - - @classmethod - def ensure_all_have_same_ledger_and_wallet( - cls, funding_accounts: Iterable[BaseAccount], - change_account: BaseAccount = None) -> Tuple['baseledger.BaseLedger', 'basewallet.Wallet']: - ledger = wallet = None - for account in funding_accounts: - if ledger is None: - ledger = account.ledger - wallet = account.wallet - if ledger != account.ledger: - raise ValueError( - 'All funding accounts used to create a transaction must be on the same ledger.' - ) - if wallet != account.wallet: - raise ValueError( - 'All funding accounts used to create a transaction must be from the same wallet.' - ) - if change_account is not None: - if change_account.ledger != ledger: - raise ValueError('Change account must use same ledger as funding accounts.') - if change_account.wallet != wallet: - raise ValueError('Change account must use same wallet as funding accounts.') - if ledger is None: - raise ValueError('No ledger found.') - if wallet is None: - raise ValueError('No wallet found.') - return ledger, wallet - - @classmethod - async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput], - funding_accounts: Iterable[BaseAccount], change_account: BaseAccount, - sign: bool = True): - """ Find optimal set of inputs when only outputs are provided; add change - outputs if only inputs are provided or if inputs are greater than outputs. """ - - tx = cls() \ - .add_inputs(inputs) \ - .add_outputs(outputs) - - ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) - - # value of the outputs plus associated fees - cost = ( - tx.get_base_fee(ledger) + - tx.get_total_output_sum(ledger) - ) - # value of the inputs less the cost to spend those inputs - payment = tx.get_effective_input_sum(ledger) - - try: - - for _ in range(5): - - if payment < cost: - deficit = cost - payment - spendables = await ledger.get_spendable_utxos(deficit, funding_accounts) - if not spendables: - raise InsufficientFundsError() - payment += sum(s.effective_amount for s in spendables) - tx.add_inputs(s.txi for s in spendables) - - cost_of_change = ( - tx.get_base_fee(ledger) + - cls.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger) - ) - if payment > cost: - change = payment - cost - if change > cost_of_change: - change_address = await change_account.change.get_or_create_usable_address() - change_hash160 = change_account.ledger.address_to_hash160(change_address) - change_amount = change - cost_of_change - change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160) - change_output.is_change = True - tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) - - if tx._outputs: - break - # this condition and the outer range(5) loop cover an edge case - # whereby a single input is just enough to cover the fee and - # has some change left over, but the change left over is less - # than the cost_of_change: thus the input is completely - # consumed and no output is added, which is an invalid tx. - # to be able to spend this input we must increase the cost - # of the TX and run through the balance algorithm a second time - # adding an extra input and change output, making tx valid. - # we do this 5 times in case the other UTXOs added are also - # less than the fee, after 5 attempts we give up and go home - cost += cost_of_change + 1 - - if sign: - await tx.sign(funding_accounts) - - except Exception as e: - log.exception('Failed to create transaction:') - await ledger.release_tx(tx) - raise e - - return tx - - @staticmethod - def signature_hash_type(hash_type): - return hash_type - - async def sign(self, funding_accounts: Iterable[BaseAccount]): - ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts) - for i, txi in enumerate(self._inputs): - assert txi.script is not None - assert txi.txo_ref.txo is not None - txo_script = txi.txo_ref.txo.script - if txo_script.is_pay_pubkey_hash: - address = ledger.hash160_to_address(txo_script.values['pubkey_hash']) - private_key = await ledger.get_private_key_for_address(wallet, address) - assert private_key is not None, 'Cannot find private key for signing output.' - tx = self._serialize_for_signature(i) - txi.script.values['signature'] = \ - private_key.sign(tx) + bytes((self.signature_hash_type(1),)) - txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes - txi.script.generate() - else: - raise NotImplementedError("Don't know how to spend this output.") - self._reset() diff --git a/lbry/wallet/client/constants.py b/lbry/wallet/client/constants.py deleted file mode 100644 index 6c790f02c..000000000 --- a/lbry/wallet/client/constants.py +++ /dev/null @@ -1,6 +0,0 @@ -NULL_HASH32 = b'\x00'*32 - -CENT = 1000000 -COIN = 100*CENT - -TIMEOUT = 30.0 diff --git a/lbry/wallet/client/words/__init__.py b/lbry/wallet/client/words/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lbry/wallet/client/coinselection.py b/lbry/wallet/coinselection.py similarity index 100% rename from lbry/wallet/client/coinselection.py rename to lbry/wallet/coinselection.py diff --git a/lbry/wallet/client/hash.py b/lbry/wallet/hash.py similarity index 100% rename from lbry/wallet/client/hash.py rename to lbry/wallet/hash.py diff --git a/lbry/wallet/client/mnemonic.py b/lbry/wallet/mnemonic.py similarity index 100% rename from lbry/wallet/client/mnemonic.py rename to lbry/wallet/mnemonic.py diff --git a/lbry/wallet/client/util.py b/lbry/wallet/util.py similarity index 100% rename from lbry/wallet/client/util.py rename to lbry/wallet/util.py diff --git a/lbry/wallet/client/wallet.py b/lbry/wallet/wallet.py similarity index 100% rename from lbry/wallet/client/wallet.py rename to lbry/wallet/wallet.py diff --git a/lbry/wallet/client/__init__.py b/lbry/wallet/words/__init__.py similarity index 100% rename from lbry/wallet/client/__init__.py rename to lbry/wallet/words/__init__.py diff --git a/lbry/wallet/client/words/chinese_simplified.py b/lbry/wallet/words/chinese_simplified.py similarity index 100% rename from lbry/wallet/client/words/chinese_simplified.py rename to lbry/wallet/words/chinese_simplified.py diff --git a/lbry/wallet/client/words/english.py b/lbry/wallet/words/english.py similarity index 100% rename from lbry/wallet/client/words/english.py rename to lbry/wallet/words/english.py diff --git a/lbry/wallet/client/words/japanese.py b/lbry/wallet/words/japanese.py similarity index 100% rename from lbry/wallet/client/words/japanese.py rename to lbry/wallet/words/japanese.py diff --git a/lbry/wallet/client/words/portuguese.py b/lbry/wallet/words/portuguese.py similarity index 100% rename from lbry/wallet/client/words/portuguese.py rename to lbry/wallet/words/portuguese.py diff --git a/lbry/wallet/client/words/spanish.py b/lbry/wallet/words/spanish.py similarity index 100% rename from lbry/wallet/client/words/spanish.py rename to lbry/wallet/words/spanish.py