diff --git a/.gitignore b/.gitignore index 51b7f31fc..338420313 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ lbry.egg-info __pycache__ _trial_temp/ -/tests/integration/files +/tests/integration/blockchain/files /tests/.coverage.* /lbry/wallet/bin diff --git a/lbry/conf.py b/lbry/conf.py index e98f3262f..40eaabce2 100644 --- a/lbry/conf.py +++ b/lbry/conf.py @@ -9,7 +9,7 @@ from contextlib import contextmanager from appdirs import user_data_dir, user_config_dir from lbry.error import InvalidCurrencyError from lbry.dht import constants -from lbry.wallet.client.coinselection import STRATEGIES +from lbry.wallet.coinselection import STRATEGIES log = logging.getLogger(__name__) diff --git a/lbry/extras/daemon/Components.py b/lbry/extras/daemon/Components.py index 73dd2cd79..ba3ca5b9e 100644 --- a/lbry/extras/daemon/Components.py +++ b/lbry/extras/daemon/Components.py @@ -20,7 +20,7 @@ from lbry.stream.stream_manager import StreamManager from lbry.extras.daemon.Component import Component from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.storage import SQLiteStorage -from lbry.wallet import LbryWalletManager +from lbry.wallet import WalletManager log = logging.getLogger(__name__) diff --git a/lbry/extras/daemon/Daemon.py b/lbry/extras/daemon/Daemon.py index 7d911405d..1cd086bda 100644 --- a/lbry/extras/daemon/Daemon.py +++ b/lbry/extras/daemon/Daemon.py @@ -17,8 +17,11 @@ from traceback import format_exc from aiohttp import web from functools import wraps, partial from google.protobuf.message import DecodeError -from lbry.wallet.client.wallet import Wallet, ENCRYPT_ON_DISK -from lbry.wallet.client.baseaccount import SingleKey, HierarchicalDeterministic +from lbry.wallet import ( + Wallet, WalletManager, ENCRYPT_ON_DISK, SingleKey, HierarchicalDeterministic, + Ledger, Transaction, Output, Input, Account +) +from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies, dict_values_to_lbc from lbry import utils from lbry.conf import Config, Setting, NOT_SET @@ -39,9 +42,6 @@ from lbry.extras.daemon.ComponentManager import ComponentManager from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder from lbry.extras.daemon import comment_client from lbry.extras.daemon.undecorated import undecorated -from lbry.wallet.transaction import Transaction, Output, Input -from lbry.wallet.account import Account as LBCAccount -from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies, dict_values_to_lbc from lbry.schema.claim import Claim from lbry.schema.url import URL @@ -51,8 +51,6 @@ if typing.TYPE_CHECKING: from lbry.extras.daemon.Components import UPnPComponent from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.storage import SQLiteStorage - from lbry.wallet.manager import LbryWalletManager - from lbry.wallet.ledger import MainNetLedger from lbry.stream.stream_manager import StreamManager log = logging.getLogger(__name__) @@ -322,7 +320,7 @@ class Daemon(metaclass=JSONRPCServerType): return self.component_manager.get_component(DHT_COMPONENT) @property - def wallet_manager(self) -> typing.Optional['LbryWalletManager']: + def wallet_manager(self) -> typing.Optional['WalletManager']: return self.component_manager.get_component(WALLET_COMPONENT) @property @@ -676,7 +674,7 @@ class Daemon(metaclass=JSONRPCServerType): return None, None @property - def ledger(self) -> Optional['MainNetLedger']: + def ledger(self) -> Optional['Ledger']: try: return self.wallet_manager.default_account.ledger except AttributeError: @@ -1161,7 +1159,7 @@ class Daemon(metaclass=JSONRPCServerType): wallet = self.wallet_manager.import_wallet(wallet_path) if not wallet.accounts and create_account: - account = LBCAccount.generate( + account = Account.generate( self.ledger, wallet, address_generator={ 'name': SingleKey.name if single_key else HierarchicalDeterministic.name } @@ -1464,7 +1462,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: {Account} """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - account = LBCAccount.from_dict( + account = Account.from_dict( self.ledger, wallet, { 'name': account_name, 'seed': seed, @@ -1498,7 +1496,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: {Account} """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - account = LBCAccount.generate( + account = Account.generate( self.ledger, wallet, account_name, { 'name': SingleKey.name if single_key else HierarchicalDeterministic.name } @@ -2134,7 +2132,7 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) if account_id: - account: LBCAccount = wallet.get_account_or_error(account_id) + account = wallet.get_account_or_error(account_id) claims = account.get_claims claim_count = account.get_claim_count else: @@ -2657,7 +2655,7 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) if account_id: - account: LBCAccount = wallet.get_account_or_error(account_id) + account = wallet.get_account_or_error(account_id) channels = account.get_channels channel_count = account.get_channel_count else: @@ -2732,7 +2730,7 @@ class Daemon(metaclass=JSONRPCServerType): if channels and channels[0].get_address(self.ledger) != holding_address: holding_address = channels[0].get_address(self.ledger) - account: LBCAccount = await self.ledger.get_account_for_address(wallet, holding_address) + account = await self.ledger.get_account_for_address(wallet, holding_address) if account: # Case 1: channel holding address is in one of the accounts we already have # simply add the certificate to existing account @@ -2741,7 +2739,7 @@ class Daemon(metaclass=JSONRPCServerType): # Case 2: channel holding address hasn't changed and thus is in the bundled read-only account # create a single-address holding account to manage the channel if holding_address == data['holding_address']: - account = LBCAccount.from_dict(self.ledger, wallet, { + account = Account.from_dict(self.ledger, wallet, { 'name': f"Holding Account For Channel {data['name']}", 'public_key': data['holding_public_key'], 'address_generator': {'name': 'single-address'} @@ -3384,7 +3382,7 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) if account_id: - account: LBCAccount = wallet.get_account_or_error(account_id) + account = wallet.get_account_or_error(account_id) streams = account.get_streams stream_count = account.get_stream_count else: @@ -3727,7 +3725,7 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) if account_id: - account: LBCAccount = wallet.get_account_or_error(account_id) + account = wallet.get_account_or_error(account_id) collections = account.get_collections collection_count = account.get_collection_count else: @@ -3854,7 +3852,7 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) if account_id: - account: LBCAccount = wallet.get_account_or_error(account_id) + account = wallet.get_account_or_error(account_id) supports = account.get_supports support_count = account.get_support_count else: @@ -4002,7 +4000,7 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) if account_id: - account: LBCAccount = wallet.get_account_or_error(account_id) + account = wallet.get_account_or_error(account_id) transactions = account.get_transaction_history transaction_count = account.get_transaction_history_count else: @@ -4696,7 +4694,7 @@ class Daemon(metaclass=JSONRPCServerType): if 'fee_currency' in kwargs or 'fee_amount' in kwargs: return claim_address - async def get_receiving_address(self, address: str, account: Optional[LBCAccount]) -> str: + async def get_receiving_address(self, address: str, account: Optional[Account]) -> str: if address is None and account is not None: return await account.receiving.get_or_create_usable_address() self.valid_address_or_error(address) diff --git a/lbry/extras/daemon/json_response_encoder.py b/lbry/extras/daemon/json_response_encoder.py index c6b944acd..864ce16b2 100644 --- a/lbry/extras/daemon/json_response_encoder.py +++ b/lbry/extras/daemon/json_response_encoder.py @@ -6,11 +6,9 @@ from json import JSONEncoder from google.protobuf.message import DecodeError -from lbry.wallet.client.wallet import Wallet -from lbry.wallet.client.bip32 import PubKey from lbry.schema.claim import Claim -from lbry.wallet.ledger import MainNetLedger, Account -from lbry.wallet.transaction import Transaction, Output +from lbry.wallet import Wallet, Ledger, Account, Transaction, Output +from lbry.wallet.bip32 import PubKey from lbry.wallet.dewies import dewies_to_lbc from lbry.stream.managed_stream import ManagedStream @@ -114,7 +112,7 @@ def encode_file_doc(): class JSONResponseEncoder(JSONEncoder): - def __init__(self, *args, ledger: MainNetLedger, include_protobuf=False, **kwargs): + def __init__(self, *args, ledger: Ledger, include_protobuf=False, **kwargs): super().__init__(*args, **kwargs) self.ledger = ledger self.include_protobuf = include_protobuf diff --git a/lbry/extras/daemon/storage.py b/lbry/extras/daemon/storage.py index 86be3b67f..268f249f6 100644 --- a/lbry/extras/daemon/storage.py +++ b/lbry/extras/daemon/storage.py @@ -5,7 +5,7 @@ import typing import asyncio import binascii import time -from lbry.wallet.client.basedatabase import SQLiteMixin +from lbry.wallet import SQLiteMixin from lbry.conf import Config from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies from lbry.wallet.transaction import Transaction diff --git a/lbry/stream/stream_manager.py b/lbry/stream/stream_manager.py index d019502e9..fc79242a8 100644 --- a/lbry/stream/stream_manager.py +++ b/lbry/stream/stream_manager.py @@ -14,17 +14,15 @@ from lbry.stream.managed_stream import ManagedStream from lbry.schema.claim import Claim from lbry.schema.url import URL from lbry.wallet.dewies import dewies_to_lbc -from lbry.wallet.transaction import Output +from lbry.wallet import WalletManager, Wallet, Transaction, Output + if typing.TYPE_CHECKING: from lbry.conf import Config from lbry.blob.blob_manager import BlobManager from lbry.dht.node import Node from lbry.extras.daemon.analytics import AnalyticsManager from lbry.extras.daemon.storage import SQLiteStorage, StoredContentClaim - from lbry.wallet import LbryWalletManager - from lbry.wallet.transaction import Transaction from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager - from lbry.wallet.client.wallet import Wallet log = logging.getLogger(__name__) @@ -66,7 +64,7 @@ def path_or_none(p) -> Optional[str]: class StreamManager: def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager', - wallet_manager: 'LbryWalletManager', storage: 'SQLiteStorage', node: Optional['Node'], + wallet_manager: 'WalletManager', storage: 'SQLiteStorage', node: Optional['Node'], analytics_manager: Optional['AnalyticsManager'] = None): self.loop = loop self.config = config diff --git a/lbry/testcase.py b/lbry/testcase.py index d66c945cc..216e26002 100644 --- a/lbry/testcase.py +++ b/lbry/testcase.py @@ -14,18 +14,11 @@ from time import time from binascii import unhexlify from functools import partial -import lbry.wallet +from lbry.wallet import WalletManager, Wallet, Ledger, Account, Transaction from lbry.conf import Config -from lbry.wallet import LbryWalletManager -from lbry.wallet.account import Account +from lbry.wallet.util import satoshis_to_coins from lbry.wallet.orchstr8 import Conductor -from lbry.wallet.transaction import Transaction -from lbry.wallet.client.wallet import Wallet -from lbry.wallet.client.util import satoshis_to_coins from lbry.wallet.orchstr8.node import BlockchainNode, WalletNode -from lbry.wallet.client.baseledger import BaseLedger -from lbry.wallet.client.baseaccount import BaseAccount -from lbry.wallet.client.basemanager import BaseWalletManager from lbry.extras.daemon.Daemon import Daemon, jsonrpc_dumps_pretty from lbry.extras.daemon.Components import Component, WalletComponent @@ -215,25 +208,19 @@ class AdvanceTimeTestCase(AsyncioTestCase): class IntegrationTestCase(AsyncioTestCase): SEED = None - LEDGER = lbry.wallet - MANAGER = LbryWalletManager - ENABLE_SEGWIT = False def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.conductor: Optional[Conductor] = None self.blockchain: Optional[BlockchainNode] = None self.wallet_node: Optional[WalletNode] = None - self.manager: Optional[BaseWalletManager] = None - self.ledger: Optional[BaseLedger] = None + self.manager: Optional[WalletManager] = None + self.ledger: Optional[Ledger] = None self.wallet: Optional[Wallet] = None - self.account: Optional[BaseAccount] = None + self.account: Optional[Account] = None async def asyncSetUp(self): - self.conductor = Conductor( - ledger_module=self.LEDGER, manager_module=self.MANAGER, - enable_segwit=self.ENABLE_SEGWIT, seed=self.SEED - ) + self.conductor = Conductor(seed=self.SEED) await self.conductor.start_blockchain() self.addCleanup(self.conductor.stop_blockchain) await self.conductor.start_spv() @@ -317,14 +304,13 @@ class CommandTestCase(IntegrationTestCase): VERBOSITY = logging.WARN blob_lru_cache_size = 0 - account: Account - async def asyncSetUp(self): await super().asyncSetUp() logging.getLogger('lbry.blob_exchange').setLevel(self.VERBOSITY) logging.getLogger('lbry.daemon').setLevel(self.VERBOSITY) logging.getLogger('lbry.stream').setLevel(self.VERBOSITY) + logging.getLogger('lbry.wallet').setLevel(self.VERBOSITY) self.daemons = [] self.extra_wallet_nodes = [] @@ -419,9 +405,7 @@ class CommandTestCase(IntegrationTestCase): return txid async def on_transaction_dict(self, tx): - await self.ledger.wait( - self.ledger.transaction_class(unhexlify(tx['hex'])) - ) + await self.ledger.wait(Transaction(unhexlify(tx['hex']))) @staticmethod def get_all_addresses(tx): diff --git a/lbry/wallet/__init__.py b/lbry/wallet/__init__.py index 6da86b5e7..05d070f26 100644 --- a/lbry/wallet/__init__.py +++ b/lbry/wallet/__init__.py @@ -6,6 +6,12 @@ __node_url__ = ( ) __spvserver__ = 'lbry.wallet.server.coin.LBCRegTest' -from lbry.wallet.manager import LbryWalletManager -from lbry.wallet.network import Network -from lbry.wallet.ledger import MainNetLedger, RegTestLedger, TestNetLedger +from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK +from .manager import WalletManager +from .network import Network +from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent +from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic +from .transaction import Transaction, Output, Input +from .script import OutputScript, InputScript +from .database import SQLiteMixin, Database +from .header import Headers diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index 4bed8eaa2..6991eb1d7 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -1,14 +1,28 @@ +import os +import time import json +import ecdsa import logging +import typing +import asyncio +import random + from functools import partial from hashlib import sha256 from string import hexdigits +from typing import Type, Dict, Tuple, Optional, Any, List -import ecdsa -from lbry.wallet.constants import CLAIM_TYPES, TXO_TYPES +from lbry.error import InvalidPasswordError +from lbry.crypto.crypt import aes_encrypt, aes_decrypt -from lbry.wallet.client.baseaccount import BaseAccount, HierarchicalDeterministic +from .bip32 import PrivateKey, PubKey, from_extended_key_string +from .mnemonic import Mnemonic +from .constants import COIN, CLAIM_TYPES, TXO_TYPES +from .transaction import Transaction, Input, Output +if typing.TYPE_CHECKING: + from .ledger import Ledger + from .wallet import Wallet log = logging.getLogger(__name__) @@ -22,22 +36,483 @@ def validate_claim_id(claim_id): raise Exception("Claim id is not hex encoded") -class Account(BaseAccount): +class AddressManager: - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.channel_keys = {} + name: str + + __slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock' + + def __init__(self, account, public_key, chain_number): + self.account = account + self.public_key = public_key + self.chain_number = chain_number + self.address_generator_lock = asyncio.Lock() + + @classmethod + def from_dict(cls, account: 'Account', d: dict) \ + -> Tuple['AddressManager', 'AddressManager']: + raise NotImplementedError + + @classmethod + def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> Dict: + d: Dict[str, Any] = {'name': cls.name} + receiving_dict = receiving.to_dict_instance() + if receiving_dict: + d['receiving'] = receiving_dict + change_dict = change.to_dict_instance() + if change_dict: + d['change'] = change_dict + return d + + def merge(self, d: dict): + pass + + def to_dict_instance(self) -> Optional[dict]: + raise NotImplementedError + + def _query_addresses(self, **constraints): + return self.account.ledger.db.get_addresses( + accounts=[self.account], + chain=self.chain_number, + **constraints + ) + + def get_private_key(self, index: int) -> PrivateKey: + raise NotImplementedError + + def get_public_key(self, index: int) -> PubKey: + raise NotImplementedError + + async def get_max_gap(self): + raise NotImplementedError + + async def ensure_address_gap(self): + raise NotImplementedError + + def get_address_records(self, only_usable: bool = False, **constraints): + raise NotImplementedError + + async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]: + records = await self.get_address_records(only_usable=only_usable, **constraints) + return [r['address'] for r in records] + + async def get_or_create_usable_address(self) -> str: + addresses = await self.get_addresses(only_usable=True, limit=10) + if addresses: + return random.choice(addresses) + addresses = await self.ensure_address_gap() + return addresses[0] + + +class HierarchicalDeterministic(AddressManager): + """ Implements simple version of Bitcoin Hierarchical Deterministic key management. """ + + name: str = "deterministic-chain" + + __slots__ = 'gap', 'maximum_uses_per_address' + + def __init__(self, account: 'Account', chain: int, gap: int, maximum_uses_per_address: int) -> None: + super().__init__(account, account.public_key.child(chain), chain) + self.gap = gap + self.maximum_uses_per_address = maximum_uses_per_address + + @classmethod + def from_dict(cls, account: 'Account', d: dict) -> Tuple[AddressManager, AddressManager]: + return ( + cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})), + cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1})) + ) + + def merge(self, d: dict): + self.gap = d.get('gap', self.gap) + self.maximum_uses_per_address = d.get('maximum_uses_per_address', self.maximum_uses_per_address) + + def to_dict_instance(self): + return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address} + + def get_private_key(self, index: int) -> PrivateKey: + return self.account.private_key.child(self.chain_number).child(index) + + def get_public_key(self, index: int) -> PubKey: + return self.account.public_key.child(self.chain_number).child(index) + + async def get_max_gap(self) -> int: + addresses = await self._query_addresses(order_by="n asc") + max_gap = 0 + current_gap = 0 + for address in addresses: + if address['used_times'] == 0: + current_gap += 1 + else: + max_gap = max(max_gap, current_gap) + current_gap = 0 + return max_gap + + async def ensure_address_gap(self) -> List[str]: + async with self.address_generator_lock: + addresses = await self._query_addresses(limit=self.gap, order_by="n desc") + + existing_gap = 0 + for address in addresses: + if address['used_times'] == 0: + existing_gap += 1 + else: + break + + if existing_gap == self.gap: + return [] + + start = addresses[0]['pubkey'].n+1 if addresses else 0 + end = start + (self.gap - existing_gap) + new_keys = await self._generate_keys(start, end-1) + await self.account.ledger.announce_addresses(self, new_keys) + return new_keys + + async def _generate_keys(self, start: int, end: int) -> List[str]: + if not self.address_generator_lock.locked(): + raise RuntimeError('Should not be called outside of address_generator_lock.') + keys = [self.public_key.child(index) for index in range(start, end+1)] + await self.account.ledger.db.add_keys(self.account, self.chain_number, keys) + return [key.address for key in keys] + + def get_address_records(self, only_usable: bool = False, **constraints): + if only_usable: + constraints['used_times__lt'] = self.maximum_uses_per_address + if 'order_by' not in constraints: + constraints['order_by'] = "used_times asc, n asc" + return self._query_addresses(**constraints) + + +class SingleKey(AddressManager): + """ Single Key address manager always returns the same address for all operations. """ + + name: str = "single-address" + + __slots__ = () + + @classmethod + def from_dict(cls, account: 'Account', d: dict) \ + -> Tuple[AddressManager, AddressManager]: + same_address_manager = cls(account, account.public_key, 0) + return same_address_manager, same_address_manager + + def to_dict_instance(self): + return None + + def get_private_key(self, index: int) -> PrivateKey: + return self.account.private_key + + def get_public_key(self, index: int) -> PubKey: + return self.account.public_key + + async def get_max_gap(self) -> int: + return 0 + + async def ensure_address_gap(self) -> List[str]: + async with self.address_generator_lock: + exists = await self.get_address_records() + if not exists: + await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key]) + new_keys = [self.public_key.address] + await self.account.ledger.announce_addresses(self, new_keys) + return new_keys + return [] + + def get_address_records(self, only_usable: bool = False, **constraints): + return self._query_addresses(**constraints) + + +class Account: + + mnemonic_class = Mnemonic + private_key_class = PrivateKey + public_key_class = PubKey + address_generators: Dict[str, Type[AddressManager]] = { + SingleKey.name: SingleKey, + HierarchicalDeterministic.name: HierarchicalDeterministic, + } + + def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str, + seed: str, private_key_string: str, encrypted: bool, + private_key: Optional[PrivateKey], public_key: PubKey, + address_generator: dict, modified_on: float, channel_keys: dict) -> None: + self.ledger = ledger + self.wallet = wallet + self.id = public_key.address + self.name = name + self.seed = seed + self.modified_on = modified_on + self.private_key_string = private_key_string + self.init_vectors: Dict[str, bytes] = {} + self.encrypted = encrypted + self.private_key = private_key + self.public_key = public_key + generator_name = address_generator.get('name', HierarchicalDeterministic.name) + self.address_generator = self.address_generators[generator_name] + self.receiving, self.change = self.address_generator.from_dict(self, address_generator) + self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}} + self.channel_keys = channel_keys + ledger.add_account(self) + wallet.add_account(self) + + def get_init_vector(self, key) -> Optional[bytes]: + init_vector = self.init_vectors.get(key, None) + if init_vector is None: + init_vector = self.init_vectors[key] = os.urandom(16) + return init_vector + + @classmethod + def generate(cls, ledger: 'Ledger', wallet: 'Wallet', + name: str = None, address_generator: dict = None): + return cls.from_dict(ledger, wallet, { + 'name': name, + 'seed': cls.mnemonic_class().make_seed(), + 'address_generator': address_generator or {} + }) + + @classmethod + def get_private_key_from_seed(cls, ledger: 'Ledger', seed: str, password: str): + return cls.private_key_class.from_seed( + ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password or 'lbryum') + ) + + @classmethod + def keys_from_dict(cls, ledger: 'Ledger', d: dict) \ + -> Tuple[str, Optional[PrivateKey], PubKey]: + seed = d.get('seed', '') + private_key_string = d.get('private_key', '') + private_key = None + public_key = None + encrypted = d.get('encrypted', False) + if not encrypted: + if seed: + private_key = cls.get_private_key_from_seed(ledger, seed, '') + public_key = private_key.public_key + elif private_key_string: + private_key = from_extended_key_string(ledger, private_key_string) + public_key = private_key.public_key + if public_key is None: + public_key = from_extended_key_string(ledger, d['public_key']) + return seed, private_key, public_key + + @classmethod + def from_dict(cls, ledger: 'Ledger', wallet: 'Wallet', d: dict): + seed, private_key, public_key = cls.keys_from_dict(ledger, d) + name = d.get('name') + if not name: + name = f'Account #{public_key.address}' + return cls( + ledger=ledger, + wallet=wallet, + name=name, + seed=seed, + private_key_string=d.get('private_key', ''), + encrypted=d.get('encrypted', False), + private_key=private_key, + public_key=public_key, + address_generator=d.get('address_generator', {}), + modified_on=d.get('modified_on', time.time()), + channel_keys=d.get('certificates', {}) + ) + + def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True): + private_key_string, seed = self.private_key_string, self.seed + if not self.encrypted and self.private_key: + private_key_string = self.private_key.extended_key_string() + if not self.encrypted and encrypt_password: + if private_key_string: + private_key_string = aes_encrypt( + encrypt_password, private_key_string, self.get_init_vector('private_key') + ) + if seed: + seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed')) + d = { + 'ledger': self.ledger.get_id(), + 'name': self.name, + 'seed': seed, + 'encrypted': bool(self.encrypted or encrypt_password), + 'private_key': private_key_string, + 'public_key': self.public_key.extended_key_string(), + 'address_generator': self.address_generator.to_dict(self.receiving, self.change), + 'modified_on': self.modified_on + } + if include_channel_keys: + d['certificates'] = self.channel_keys + return d + + def merge(self, d: dict): + if d.get('modified_on', 0) > self.modified_on: + self.name = d['name'] + self.modified_on = d.get('modified_on', time.time()) + assert self.address_generator.name == d['address_generator']['name'] + for chain_name in ('change', 'receiving'): + if chain_name in d['address_generator']: + chain_object = getattr(self, chain_name) + chain_object.merge(d['address_generator'][chain_name]) + self.channel_keys.update(d.get('certificates', {})) @property def hash(self) -> bytes: + assert not self.encrypted, "Cannot hash an encrypted account." h = sha256(json.dumps(self.to_dict(include_channel_keys=False)).encode()) for cert in sorted(self.channel_keys.keys()): h.update(cert.encode()) return h.digest() - def merge(self, d: dict): - super().merge(d) - self.channel_keys.update(d.get('certificates', {})) + async def get_details(self, show_seed=False, **kwargs): + satoshis = await self.get_balance(**kwargs) + details = { + 'id': self.id, + 'name': self.name, + 'ledger': self.ledger.get_id(), + 'coins': round(satoshis/COIN, 2), + 'satoshis': satoshis, + 'encrypted': self.encrypted, + 'public_key': self.public_key.extended_key_string(), + 'address_generator': self.address_generator.to_dict(self.receiving, self.change) + } + if show_seed: + details['seed'] = self.seed + details['certificates'] = len(self.channel_keys) + return details + + def decrypt(self, password: str) -> bool: + assert self.encrypted, "Key is not encrypted." + try: + seed = self._decrypt_seed(password) + except (ValueError, InvalidPasswordError): + return False + try: + private_key = self._decrypt_private_key_string(password) + except (TypeError, ValueError, InvalidPasswordError): + return False + self.seed = seed + self.private_key = private_key + self.private_key_string = "" + self.encrypted = False + return True + + def _decrypt_private_key_string(self, password: str) -> Optional[PrivateKey]: + if not self.private_key_string: + return None + private_key_string, self.init_vectors['private_key'] = aes_decrypt(password, self.private_key_string) + if not private_key_string: + return None + return from_extended_key_string( + self.ledger, private_key_string + ) + + def _decrypt_seed(self, password: str) -> str: + if not self.seed: + return "" + seed, self.init_vectors['seed'] = aes_decrypt(password, self.seed) + if not seed: + return "" + try: + Mnemonic().mnemonic_decode(seed) + except IndexError: + # failed to decode the seed, this either means it decrypted and is invalid + # or that we hit an edge case where an incorrect password gave valid padding + raise ValueError("Failed to decode seed.") + return seed + + def encrypt(self, password: str) -> bool: + assert not self.encrypted, "Key is already encrypted." + if self.seed: + self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed')) + if isinstance(self.private_key, PrivateKey): + self.private_key_string = aes_encrypt( + password, self.private_key.extended_key_string(), self.get_init_vector('private_key') + ) + self.private_key = None + self.encrypted = True + return True + + async def ensure_address_gap(self): + addresses = [] + for address_manager in self.address_managers.values(): + new_addresses = await address_manager.ensure_address_gap() + addresses.extend(new_addresses) + return addresses + + async def get_addresses(self, **constraints) -> List[str]: + rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints) + return [r[0] for r in rows] + + def get_address_records(self, **constraints): + return self.ledger.db.get_addresses(accounts=[self], **constraints) + + def get_address_count(self, **constraints): + return self.ledger.db.get_address_count(accounts=[self], **constraints) + + def get_private_key(self, chain: int, index: int) -> PrivateKey: + assert not self.encrypted, "Cannot get private key on encrypted wallet account." + return self.address_managers[chain].get_private_key(index) + + def get_public_key(self, chain: int, index: int) -> PubKey: + return self.address_managers[chain].get_public_key(index) + + def get_balance(self, confirmations: int = 0, include_claims=False, **constraints): + if not include_claims: + constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])}) + if confirmations > 0: + height = self.ledger.headers.height - (confirmations-1) + constraints.update({'height__lte': height, 'height__gt': 0}) + return self.ledger.db.get_balance(accounts=[self], **constraints) + + async def get_max_gap(self): + change_gap = await self.change.get_max_gap() + receiving_gap = await self.receiving.get_max_gap() + return { + 'max_change_gap': change_gap, + 'max_receiving_gap': receiving_gap, + } + + def get_utxos(self, **constraints): + return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints) + + def get_utxo_count(self, **constraints): + return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints) + + def get_transactions(self, **constraints): + return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints) + + def get_transaction_count(self, **constraints): + return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints) + + async def fund(self, to_account, amount=None, everything=False, + outputs=1, broadcast=False, **constraints): + assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.' + if everything: + utxos = await self.get_utxos(**constraints) + await self.ledger.reserve_outputs(utxos) + tx = await Transaction.create( + inputs=[Input.spend(txo) for txo in utxos], + outputs=[], + funding_accounts=[self], + change_account=to_account + ) + elif amount > 0: + to_address = await to_account.change.get_or_create_usable_address() + to_hash160 = to_account.ledger.address_to_hash160(to_address) + tx = await Transaction.create( + inputs=[], + outputs=[ + Output.pay_pubkey_hash(amount//outputs, to_hash160) + for _ in range(outputs) + ], + funding_accounts=[self], + change_account=self + ) + else: + raise ValueError('An amount is required.') + + if broadcast: + await self.ledger.broadcast(tx) + else: + await self.ledger.release_tx(tx) + + return tx def add_channel_private_key(self, private_key): public_key_bytes = private_key.get_verifying_key().to_der() @@ -81,11 +556,6 @@ class Account(BaseAccount): if gap_changed: self.wallet.save() - def get_balance(self, confirmations=0, include_claims=False, **constraints): - if not include_claims: - constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])}) - return super().get_balance(confirmations, **constraints) - async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False): tips_balance, supports_balance, claims_balance = 0, 0, 0 get_total_balance = partial(self.get_balance, confirmations=confirmations, include_claims=True) @@ -116,29 +586,6 @@ class Account(BaseAccount): } if reserved_subtotals else None } - @classmethod - def get_private_key_from_seed(cls, ledger, seed: str, password: str): - return super().get_private_key_from_seed( - ledger, seed, password or 'lbryum' - ) - - @classmethod - def from_dict(cls, ledger, wallet, d: dict) -> 'Account': - account = super().from_dict(ledger, wallet, d) - account.channel_keys = d.get('certificates', {}) - return account - - def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True): - d = super().to_dict(encrypt_password) - if include_channel_keys: - d['certificates'] = self.channel_keys - return d - - async def get_details(self, **kwargs): - details = await super().get_details(**kwargs) - details['certificates'] = len(self.channel_keys) - return details - def get_transaction_history(self, **constraints): return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints) diff --git a/lbry/wallet/bip32.py b/lbry/wallet/bip32.py index fc9a590d1..3e6bc3a7f 100644 --- a/lbry/wallet/bip32.py +++ b/lbry/wallet/bip32.py @@ -2,7 +2,7 @@ from coincurve import PublicKey, PrivateKey as _PrivateKey from lbry.crypto.hash import hmac_sha512, hash160, double_sha256 from lbry.crypto.base58 import Base58 -from lbry.wallet.client.util import cachedproperty +from .util import cachedproperty class DerivationError(Exception): diff --git a/lbry/wallet/coinselection.py b/lbry/wallet/coinselection.py index 253d31a21..63bbb6977 100644 --- a/lbry/wallet/coinselection.py +++ b/lbry/wallet/coinselection.py @@ -1,7 +1,7 @@ from random import Random from typing import List -from lbry.wallet.client import basetransaction +from lbry.wallet.transaction import OutputEffectiveAmountEstimator MAXIMUM_TRIES = 100000 @@ -25,8 +25,8 @@ class CoinSelector: self.random.seed(seed, version=1) def select( - self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], - strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + self, txos: List[OutputEffectiveAmountEstimator], + strategy_name: str = None) -> List[OutputEffectiveAmountEstimator]: if not txos: return [] available = sum(c.effective_amount for c in txos) @@ -35,16 +35,16 @@ class CoinSelector: return getattr(self, strategy_name or "standard")(txos, available) @strategy - def prefer_confirmed(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], - available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def prefer_confirmed(self, txos: List[OutputEffectiveAmountEstimator], + available: int) -> List[OutputEffectiveAmountEstimator]: return ( self.only_confirmed(txos, available) or self.standard(txos, available) ) @strategy - def only_confirmed(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], - _) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def only_confirmed(self, txos: List[OutputEffectiveAmountEstimator], + _) -> List[OutputEffectiveAmountEstimator]: confirmed = [t for t in txos if t.txo.tx_ref and t.txo.tx_ref.height > 0] if not confirmed: return [] @@ -54,8 +54,8 @@ class CoinSelector: return self.standard(confirmed, confirmed_available) @strategy - def standard(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], - available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def standard(self, txos: List[OutputEffectiveAmountEstimator], + available: int) -> List[OutputEffectiveAmountEstimator]: return ( self.branch_and_bound(txos, available) or self.closest_match(txos, available) or @@ -63,8 +63,8 @@ class CoinSelector: ) @strategy - def branch_and_bound(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], - available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def branch_and_bound(self, txos: List[OutputEffectiveAmountEstimator], + available: int) -> List[OutputEffectiveAmountEstimator]: # see bitcoin implementation for more info: # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp @@ -123,8 +123,8 @@ class CoinSelector: return [] @strategy - def closest_match(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], - _) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def closest_match(self, txos: List[OutputEffectiveAmountEstimator], + _) -> List[OutputEffectiveAmountEstimator]: """ Pick one UTXOs that is larger than the target but with the smallest change. """ target = self.target + self.cost_of_change smallest_change = None @@ -137,8 +137,8 @@ class CoinSelector: return [best_match] if best_match else [] @strategy - def random_draw(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], - _) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def random_draw(self, txos: List[OutputEffectiveAmountEstimator], + _) -> List[OutputEffectiveAmountEstimator]: """ Accumulate UTXOs at random until there is enough to cover the target. """ target = self.target + self.cost_of_change self.random.shuffle(txos, self.random.random) diff --git a/lbry/wallet/constants.py b/lbry/wallet/constants.py index 52f3454b6..5935a6717 100644 --- a/lbry/wallet/constants.py +++ b/lbry/wallet/constants.py @@ -1,3 +1,10 @@ +NULL_HASH32 = b'\x00'*32 + +CENT = 1000000 +COIN = 100*CENT + +TIMEOUT = 30.0 + TXO_TYPES = { "stream": 1, "channel": 2, diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py index 948e063bb..bcfaa8ce4 100644 --- a/lbry/wallet/database.py +++ b/lbry/wallet/database.py @@ -1,14 +1,321 @@ -from typing import List +import logging +import asyncio +import sqlite3 -from lbry.wallet.client.basedatabase import BaseDatabase +from binascii import hexlify +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional -from lbry.wallet.transaction import Output -from lbry.wallet.constants import TXO_TYPES, CLAIM_TYPES +from .bip32 import PubKey +from .transaction import Transaction, Output, OutputScript, TXRefImmutable +from .constants import TXO_TYPES, CLAIM_TYPES -class WalletDatabase(BaseDatabase): +log = logging.getLogger(__name__) +sqlite3.enable_callback_tracebacks(True) - SCHEMA_VERSION = f"{BaseDatabase.SCHEMA_VERSION}+1" + +class AIOSQLite: + + def __init__(self): + # has to be single threaded as there is no mapping of thread:connection + self.executor = ThreadPoolExecutor(max_workers=1) + self.connection: sqlite3.Connection = None + self._closing = False + self.query_count = 0 + + @classmethod + async def connect(cls, path: Union[bytes, str], *args, **kwargs): + sqlite3.enable_callback_tracebacks(True) + def _connect(): + return sqlite3.connect(path, *args, **kwargs) + db = cls() + db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect) + return db + + async def close(self): + if self._closing: + return + self._closing = True + await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close) + self.executor.shutdown(wait=True) + self.connection = None + + def executemany(self, sql: str, params: Iterable): + params = params if params is not None else [] + # this fetchall is needed to prevent SQLITE_MISUSE + return self.run(lambda conn: conn.executemany(sql, params).fetchall()) + + def executescript(self, script: str) -> Awaitable: + return self.run(lambda conn: conn.executescript(script)) + + def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: + parameters = parameters if parameters is not None else [] + return self.run(lambda conn: conn.execute(sql, parameters).fetchall()) + + def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: + parameters = parameters if parameters is not None else [] + return self.run(lambda conn: conn.execute(sql, parameters).fetchone()) + + def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: + parameters = parameters if parameters is not None else [] + return self.run(lambda conn: conn.execute(sql, parameters)) + + def run(self, fun, *args, **kwargs) -> Awaitable: + return asyncio.get_event_loop().run_in_executor( + self.executor, lambda: self.__run_transaction(fun, *args, **kwargs) + ) + + def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): + self.connection.execute('begin') + try: + self.query_count += 1 + result = fun(self.connection, *args, **kwargs) # type: ignore + self.connection.commit() + return result + except (Exception, OSError) as e: + log.exception('Error running transaction:', exc_info=e) + self.connection.rollback() + log.warning("rolled back") + raise + + def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable: + return asyncio.get_event_loop().run_in_executor( + self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs + ) + + def __run_transaction_with_foreign_keys_disabled(self, + fun: Callable[[sqlite3.Connection, Any, Any], Any], + args, kwargs): + foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone() + if not foreign_keys_enabled: + raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead") + try: + self.connection.execute('pragma foreign_keys=off').fetchone() + return self.__run_transaction(fun, *args, **kwargs) + finally: + self.connection.execute('pragma foreign_keys=on').fetchone() + + +def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): + sql, values = [], {} + for key, constraint in constraints.items(): + tag = '0' + if '#' in key: + key, tag = key[:key.index('#')], key[key.index('#')+1:] + col, op, key = key, '=', key.replace('.', '_') + if not key: + sql.append(constraint) + continue + if key.startswith('$'): + values[key] = constraint + continue + if key.endswith('__not'): + col, op = col[:-len('__not')], '!=' + elif key.endswith('__is_null'): + col = col[:-len('__is_null')] + sql.append(f'{col} IS NULL') + continue + if key.endswith('__is_not_null'): + col = col[:-len('__is_not_null')] + sql.append(f'{col} IS NOT NULL') + continue + if key.endswith('__lt'): + col, op = col[:-len('__lt')], '<' + elif key.endswith('__lte'): + col, op = col[:-len('__lte')], '<=' + elif key.endswith('__gt'): + col, op = col[:-len('__gt')], '>' + elif key.endswith('__gte'): + col, op = col[:-len('__gte')], '>=' + elif key.endswith('__like'): + col, op = col[:-len('__like')], 'LIKE' + elif key.endswith('__not_like'): + col, op = col[:-len('__not_like')], 'NOT LIKE' + elif key.endswith('__in') or key.endswith('__not_in'): + if key.endswith('__in'): + col, op = col[:-len('__in')], 'IN' + else: + col, op = col[:-len('__not_in')], 'NOT IN' + if constraint: + if isinstance(constraint, (list, set, tuple)): + keys = [] + for i, val in enumerate(constraint): + keys.append(f':{key}{tag}_{i}') + values[f'{key}{tag}_{i}'] = val + sql.append(f'{col} {op} ({", ".join(keys)})') + elif isinstance(constraint, str): + sql.append(f'{col} {op} ({constraint})') + else: + raise ValueError(f"{col} requires a list, set or string as constraint value.") + continue + elif key.endswith('__any') or key.endswith('__or'): + where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_') + sql.append(f'({where})') + values.update(subvalues) + continue + if key.endswith('__and'): + where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_') + sql.append(f'({where})') + values.update(subvalues) + continue + sql.append(f'{col} {op} :{prepend_key}{key}{tag}') + values[prepend_key+key+tag] = constraint + return joiner.join(sql) if sql else '', values + + +def query(select, **constraints) -> Tuple[str, Dict[str, Any]]: + sql = [select] + limit = constraints.pop('limit', None) + offset = constraints.pop('offset', None) + order_by = constraints.pop('order_by', None) + + accounts = constraints.pop('accounts', []) + if accounts: + constraints['account__in'] = [a.public_key.address for a in accounts] + + where, values = constraints_to_sql(constraints) + if where: + sql.append('WHERE') + sql.append(where) + + if order_by: + sql.append('ORDER BY') + if isinstance(order_by, str): + sql.append(order_by) + elif isinstance(order_by, list): + sql.append(', '.join(order_by)) + else: + raise ValueError("order_by must be string or list") + + if limit is not None: + sql.append(f'LIMIT {limit}') + + if offset is not None: + sql.append(f'OFFSET {offset}') + + return ' '.join(sql), values + + +def interpolate(sql, values): + for k in sorted(values.keys(), reverse=True): + value = values[k] + if isinstance(value, bytes): + value = f"X'{hexlify(value).decode()}'" + elif isinstance(value, str): + value = f"'{value}'" + else: + value = str(value) + sql = sql.replace(f":{k}", value) + return sql + + +def rows_to_dict(rows, fields): + if rows: + return [dict(zip(fields, r)) for r in rows] + else: + return [] + + +class SQLiteMixin: + + SCHEMA_VERSION: Optional[str] = None + CREATE_TABLES_QUERY: str + MAX_QUERY_VARIABLES = 900 + + CREATE_VERSION_TABLE = """ + create table if not exists version ( + version text + ); + """ + + def __init__(self, path): + self._db_path = path + self.db: AIOSQLite = None + self.ledger = None + + async def open(self): + log.info("connecting to database: %s", self._db_path) + self.db = await AIOSQLite.connect(self._db_path, isolation_level=None) + if self.SCHEMA_VERSION: + tables = [t[0] for t in await self.db.execute_fetchall( + "SELECT name FROM sqlite_master WHERE type='table';" + )] + if tables: + if 'version' in tables: + version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;") + if version == (self.SCHEMA_VERSION,): + return + await self.db.executescript('\n'.join( + f"DROP TABLE {table};" for table in tables + )) + await self.db.execute(self.CREATE_VERSION_TABLE) + await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,)) + await self.db.executescript(self.CREATE_TABLES_QUERY) + + async def close(self): + await self.db.close() + + @staticmethod + def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False, + replace: bool = False) -> Tuple[str, List]: + columns, values = [], [] + for column, value in data.items(): + columns.append(column) + values.append(value) + policy = "" + if ignore_duplicate: + policy = " OR IGNORE" + if replace: + policy = " OR REPLACE" + sql = "INSERT{} INTO {} ({}) VALUES ({})".format( + policy, table, ', '.join(columns), ', '.join(['?'] * len(values)) + ) + return sql, values + + @staticmethod + def _update_sql(table: str, data: dict, where: str, + constraints: Union[list, tuple]) -> Tuple[str, list]: + columns, values = [], [] + for column, value in data.items(): + columns.append(f"{column} = ?") + values.append(value) + values.extend(constraints) + sql = "UPDATE {} SET {} WHERE {}".format( + table, ', '.join(columns), where + ) + return sql, values + + +class Database(SQLiteMixin): + + SCHEMA_VERSION = "1.1" + + PRAGMAS = """ + pragma journal_mode=WAL; + """ + + CREATE_ACCOUNT_TABLE = """ + create table if not exists account_address ( + account text not null, + address text not null, + chain integer not null, + pubkey blob not null, + chain_code blob not null, + n integer not null, + depth integer not null, + primary key (account, address) + ); + create index if not exists address_account_idx on account_address (address, account); + """ + + CREATE_PUBKEY_ADDRESS_TABLE = """ + create table if not exists pubkey_address ( + address text primary key, + history text, + used_times integer not null default 0 + ); + """ CREATE_TX_TABLE = """ create table if not exists tx ( @@ -42,25 +349,35 @@ class WalletDatabase(BaseDatabase): create index if not exists txo_txo_type_idx on txo (txo_type); """ + CREATE_TXI_TABLE = """ + create table if not exists txi ( + txid text references tx, + txoid text references txo, + address text references pubkey_address + ); + create index if not exists txi_address_idx on txi (address); + create index if not exists txi_txoid_idx on txi (txoid); + """ + CREATE_TABLES_QUERY = ( - BaseDatabase.PRAGMAS + - BaseDatabase.CREATE_ACCOUNT_TABLE + - BaseDatabase.CREATE_PUBKEY_ADDRESS_TABLE + - CREATE_TX_TABLE + - CREATE_TXO_TABLE + - BaseDatabase.CREATE_TXI_TABLE + PRAGMAS + + CREATE_ACCOUNT_TABLE + + CREATE_PUBKEY_ADDRESS_TABLE + + CREATE_TX_TABLE + + CREATE_TXO_TABLE + + CREATE_TXI_TABLE ) - def tx_to_row(self, tx): - row = super().tx_to_row(tx) - txos = tx.outputs - if len(txos) >= 2 and txos[1].can_decode_purchase_data: - txos[0].purchase = txos[1] - row['purchased_claim_id'] = txos[1].purchase_data.claim_id - return row - - def txo_to_row(self, tx, address, txo): - row = super().txo_to_row(tx, address, txo) + @staticmethod + def txo_to_row(tx, address, txo): + row = { + 'txid': tx.id, + 'txoid': txo.id, + 'address': address, + 'position': txo.position, + 'amount': txo.amount, + 'script': sqlite3.Binary(txo.script.source) + } if txo.is_claim: if txo.can_decode_claim: row['txo_type'] = TXO_TYPES.get(txo.claim.claim_type, TXO_TYPES['stream']) @@ -76,39 +393,212 @@ class WalletDatabase(BaseDatabase): row['claim_name'] = txo.claim_name return row - async def get_transactions(self, **constraints): - txs = await super().get_transactions(**constraints) + @staticmethod + def tx_to_row(tx): + row = { + 'txid': tx.id, + 'raw': sqlite3.Binary(tx.raw), + 'height': tx.height, + 'position': tx.position, + 'is_verified': tx.is_verified + } + txos = tx.outputs + if len(txos) >= 2 and txos[1].can_decode_purchase_data: + txos[0].purchase = txos[1] + row['purchased_claim_id'] = txos[1].purchase_data.claim_id + return row + + async def insert_transaction(self, tx): + await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx))) + + async def update_transaction(self, tx): + await self.db.execute_fetchall(*self._update_sql("tx", { + 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified + }, 'txid = ?', (tx.id,))) + + def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash, history): + conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)) + + for txo in tx.outputs: + if txo.script.is_pay_pubkey_hash and txo.pubkey_hash == txhash: + conn.execute(*self._insert_sql( + "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True + )).fetchall() + elif txo.script.is_pay_script_hash: + # TODO: implement script hash payments + log.warning('Database.save_transaction_io: pay script hash is not implemented!') + + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + txo = txi.txo_ref.txo + if txo.has_address and txo.get_address(self.ledger) == address: + conn.execute(*self._insert_sql("txi", { + 'txid': tx.id, + 'txoid': txo.id, + 'address': address, + }, ignore_duplicate=True)).fetchall() + + conn.execute( + "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", + (history, history.count(':') // 2, address) + ) + + def save_transaction_io(self, tx: Transaction, address, txhash, history): + return self.db.run(self._transaction_io, tx, address, txhash, history) + + def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history): + def __many(conn): + for tx in txs: + self._transaction_io(conn, tx, address, txhash, history) + return self.db.run(__many) + + async def reserve_outputs(self, txos, is_reserved=True): + txoids = ((is_reserved, txo.id) for txo in txos) + await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids) + + async def release_outputs(self, txos): + await self.reserve_outputs(txos, is_reserved=False) + + async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use + # TODO: + # 1. delete transactions above_height + # 2. update address histories removing deleted TXs + return True + + async def select_transactions(self, cols, accounts=None, **constraints): + if not {'txid', 'txid__in'}.intersection(constraints): + assert accounts, "'accounts' argument required when no 'txid' constraint is present" + constraints.update({ + f'$account{i}': a.public_key.address for i, a in enumerate(accounts) + }) + account_values = ', '.join([f':$account{i}' for i in range(len(accounts))]) + where = f" WHERE account_address.account IN ({account_values})" + constraints['txid__in'] = f""" + SELECT txo.txid FROM txo JOIN account_address USING (address) {where} + UNION + SELECT txi.txid FROM txi JOIN account_address USING (address) {where} + """ + return await self.db.execute_fetchall( + *query(f"SELECT {cols} FROM tx", **constraints) + ) + + async def get_transactions(self, wallet=None, **constraints): + tx_rows = await self.select_transactions( + 'txid, raw, height, position, is_verified', + order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]), + **constraints + ) + + if not tx_rows: + return [] + + txids, txs, txi_txoids = [], [], [] + for row in tx_rows: + txids.append(row[0]) + txs.append(Transaction( + raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4]) + )) + for txi in txs[-1].inputs: + txi_txoids.append(txi.txo_ref.id) + + step = self.MAX_QUERY_VARIABLES + annotated_txos = {} + for offset in range(0, len(txids), step): + annotated_txos.update({ + txo.id: txo for txo in + (await self.get_txos( + wallet=wallet, + txid__in=txids[offset:offset+step], + )) + }) + + referenced_txos = {} + for offset in range(0, len(txi_txoids), step): + referenced_txos.update({ + txo.id: txo for txo in + (await self.get_txos( + wallet=wallet, + txoid__in=txi_txoids[offset:offset+step], + )) + }) + + for tx in txs: + for txi in tx.inputs: + txo = referenced_txos.get(txi.txo_ref.id) + if txo: + txi.txo_ref = txo.ref + for txo in tx.outputs: + _txo = annotated_txos.get(txo.id) + if _txo: + txo.update_annotations(_txo) + else: + txo.update_annotations(None) + for tx in txs: txos = tx.outputs if len(txos) >= 2 and txos[1].can_decode_purchase_data: txos[0].purchase = txos[1] + return txs - @staticmethod - def constrain_purchases(constraints): - accounts = constraints.pop('accounts', None) - assert accounts, "'accounts' argument required to find purchases" - if not {'purchased_claim_id', 'purchased_claim_id__in'}.intersection(constraints): - constraints['purchased_claim_id__is_not_null'] = True - constraints.update({ - f'$account{i}': a.public_key.address for i, a in enumerate(accounts) - }) - account_values = ', '.join([f':$account{i}' for i in range(len(accounts))]) - constraints['txid__in'] = f""" - SELECT txid FROM txi JOIN account_address USING (address) - WHERE account_address.account IN ({account_values}) - """ + async def get_transaction_count(self, **constraints): + constraints.pop('wallet', None) + constraints.pop('offset', None) + constraints.pop('limit', None) + constraints.pop('order_by', None) + count = await self.select_transactions('count(*)', **constraints) + return count[0][0] - async def get_purchases(self, **constraints): - self.constrain_purchases(constraints) - return [tx.outputs[0] for tx in await self.get_transactions(**constraints)] + async def get_transaction(self, **constraints): + txs = await self.get_transactions(limit=1, **constraints) + if txs: + return txs[0] - def get_purchase_count(self, **constraints): - self.constrain_purchases(constraints) - return self.get_transaction_count(**constraints) + async def select_txos(self, cols, **constraints): + sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)" + if 'accounts' in constraints: + sql += " JOIN account_address USING (address)" + return await self.db.execute_fetchall(*query(sql, **constraints)) - async def get_txos(self, wallet=None, no_tx=False, **constraints) -> List[Output]: - txos = await super().get_txos(wallet=wallet, no_tx=no_tx, **constraints) + async def get_txos(self, wallet=None, no_tx=False, **constraints): + my_accounts = {a.public_key.address for a in wallet.accounts} if wallet else set() + if 'order_by' not in constraints: + constraints['order_by'] = [ + "tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position" + ] + rows = await self.select_txos( + """ + tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, ( + select group_concat(account||"|"||chain) from account_address + where account_address.address=txo.address + ) + """, + **constraints + ) + txos = [] + txs = {} + for row in rows: + if no_tx: + txo = Output( + amount=row[6], + script=OutputScript(row[7]), + tx_ref=TXRefImmutable.from_id(row[0], row[2]), + position=row[5] + ) + else: + if row[0] not in txs: + txs[row[0]] = Transaction( + row[1], height=row[2], position=row[3], is_verified=row[4] + ) + txo = txs[row[0]].outputs[row[5]] + row_accounts = dict(a.split('|') for a in row[8].split(',')) + account_match = set(row_accounts) & my_accounts + if account_match: + txo.is_my_account = True + txo.is_change = row_accounts[account_match.pop()] == '1' + else: + txo.is_change = txo.is_my_account = False + txos.append(txo) channel_ids = set() for txo in txos: @@ -138,6 +628,112 @@ class WalletDatabase(BaseDatabase): return txos + async def get_txo_count(self, **constraints): + constraints.pop('wallet', None) + constraints.pop('offset', None) + constraints.pop('limit', None) + constraints.pop('order_by', None) + count = await self.select_txos('count(*)', **constraints) + return count[0][0] + + @staticmethod + def constrain_utxo(constraints): + constraints['is_reserved'] = False + constraints['txoid__not_in'] = "SELECT txoid FROM txi" + + def get_utxos(self, **constraints): + self.constrain_utxo(constraints) + return self.get_txos(**constraints) + + def get_utxo_count(self, **constraints): + self.constrain_utxo(constraints) + return self.get_txo_count(**constraints) + + async def get_balance(self, wallet=None, accounts=None, **constraints): + assert wallet or accounts, \ + "'wallet' or 'accounts' constraints required to calculate balance" + constraints['accounts'] = accounts or wallet.accounts + self.constrain_utxo(constraints) + balance = await self.select_txos('SUM(amount)', **constraints) + return balance[0][0] or 0 + + async def select_addresses(self, cols, **constraints): + return await self.db.execute_fetchall(*query( + f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)", + **constraints + )) + + async def get_addresses(self, cols=None, **constraints): + cols = cols or ( + 'address', 'account', 'chain', 'history', 'used_times', + 'pubkey', 'chain_code', 'n', 'depth' + ) + addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols) + if 'pubkey' in cols: + for address in addresses: + address['pubkey'] = PubKey( + self.ledger, address.pop('pubkey'), address.pop('chain_code'), + address.pop('n'), address.pop('depth') + ) + return addresses + + async def get_address_count(self, cols=None, **constraints): + count = await self.select_addresses('count(*)', **constraints) + return count[0][0] + + async def get_address(self, **constraints): + addresses = await self.get_addresses(limit=1, **constraints) + if addresses: + return addresses[0] + + async def add_keys(self, account, chain, pubkeys): + await self.db.executemany( + "insert or ignore into account_address " + "(account, address, chain, pubkey, chain_code, n, depth) values " + "(?, ?, ?, ?, ?, ?, ?)", (( + account.id, k.address, chain, + sqlite3.Binary(k.pubkey_bytes), + sqlite3.Binary(k.chain_code), + k.n, k.depth + ) for k in pubkeys) + ) + await self.db.executemany( + "insert or ignore into pubkey_address (address) values (?)", + ((pubkey.address,) for pubkey in pubkeys) + ) + + async def _set_address_history(self, address, history): + await self.db.execute_fetchall( + "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", + (history, history.count(':')//2, address) + ) + + async def set_address_history(self, address, history): + await self._set_address_history(address, history) + + @staticmethod + def constrain_purchases(constraints): + accounts = constraints.pop('accounts', None) + assert accounts, "'accounts' argument required to find purchases" + if not {'purchased_claim_id', 'purchased_claim_id__in'}.intersection(constraints): + constraints['purchased_claim_id__is_not_null'] = True + constraints.update({ + f'$account{i}': a.public_key.address for i, a in enumerate(accounts) + }) + account_values = ', '.join([f':$account{i}' for i in range(len(accounts))]) + constraints['txid__in'] = f""" + SELECT txid FROM txi JOIN account_address USING (address) + WHERE account_address.account IN ({account_values}) + """ + + async def get_purchases(self, **constraints): + self.constrain_purchases(constraints) + return [tx.outputs[0] for tx in await self.get_transactions(**constraints)] + + def get_purchase_count(self, **constraints): + self.constrain_purchases(constraints) + return self.get_transaction_count(**constraints) + @staticmethod def constrain_claims(constraints): claim_type = constraints.pop('claim_type', None) diff --git a/lbry/wallet/dewies.py b/lbry/wallet/dewies.py index 9562d0ee5..8244712b5 100644 --- a/lbry/wallet/dewies.py +++ b/lbry/wallet/dewies.py @@ -1,5 +1,5 @@ import textwrap -from lbry.wallet.client.util import coins_to_satoshis, satoshis_to_coins +from .util import coins_to_satoshis, satoshis_to_coins def lbc_to_dewies(lbc: str) -> int: diff --git a/lbry/wallet/hash.py b/lbry/wallet/hash.py index 4fdbf306a..08a0aee82 100644 --- a/lbry/wallet/hash.py +++ b/lbry/wallet/hash.py @@ -1,5 +1,5 @@ from binascii import hexlify, unhexlify -from lbry.wallet.client.constants import NULL_HASH32 +from .constants import NULL_HASH32 class TXRef: diff --git a/lbry/wallet/header.py b/lbry/wallet/header.py index bba22db3f..afae36cfc 100644 --- a/lbry/wallet/header.py +++ b/lbry/wallet/header.py @@ -1,10 +1,252 @@ +import os import struct -from typing import Optional +import asyncio +import hashlib +import logging + +from io import BytesIO +from contextlib import asynccontextmanager +from typing import Optional, Iterator, Tuple from binascii import hexlify, unhexlify from lbry.crypto.hash import sha512, double_sha256, ripemd160 -from lbry.wallet.client.baseheader import BaseHeaders -from lbry.wallet.client.util import ArithUint256 +from lbry.wallet.util import ArithUint256 + + +log = logging.getLogger(__name__) + + +class InvalidHeader(Exception): + + def __init__(self, height, message): + super().__init__(message) + self.message = message + self.height = height + + +class BaseHeaders: + + header_size: int + chunk_size: int + + max_target: int + genesis_hash: Optional[bytes] + target_timespan: int + + validate_difficulty: bool = True + checkpoint = None + + def __init__(self, path) -> None: + if path == ':memory:': + self.io = BytesIO() + self.path = path + self._size: Optional[int] = None + + async def open(self): + if self.path != ':memory:': + if not os.path.exists(self.path): + self.io = open(self.path, 'w+b') + else: + self.io = open(self.path, 'r+b') + + async def close(self): + self.io.close() + + @staticmethod + def serialize(header: dict) -> bytes: + raise NotImplementedError + + @staticmethod + def deserialize(height, header): + raise NotImplementedError + + def get_next_chunk_target(self, chunk: int) -> ArithUint256: + return ArithUint256(self.max_target) + + @staticmethod + def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict], + current: Optional[dict]) -> ArithUint256: + return chunk_target + + def __len__(self) -> int: + if self._size is None: + self._size = self.io.seek(0, os.SEEK_END) // self.header_size + return self._size + + def __bool__(self): + return True + + def __getitem__(self, height) -> dict: + if isinstance(height, slice): + raise NotImplementedError("Slicing of header chain has not been implemented yet.") + if not 0 <= height <= self.height: + raise IndexError(f"{height} is out of bounds, current height: {self.height}") + return self.deserialize(height, self.get_raw_header(height)) + + def get_raw_header(self, height) -> bytes: + self.io.seek(height * self.header_size, os.SEEK_SET) + return self.io.read(self.header_size) + + @property + def height(self) -> int: + return len(self)-1 + + @property + def bytes_size(self): + return len(self) * self.header_size + + def hash(self, height=None) -> bytes: + return self.hash_header( + self.get_raw_header(height if height is not None else self.height) + ) + + @staticmethod + def hash_header(header: bytes) -> bytes: + if header is None: + return b'0' * 64 + return hexlify(double_sha256(header)[::-1]) + + @asynccontextmanager + async def checkpointed_connector(self): + buf = BytesIO() + try: + yield buf + finally: + await asyncio.sleep(0) + final_height = len(self) + buf.tell() // self.header_size + verifiable_bytes = (self.checkpoint[0] - len(self)) * self.header_size if self.checkpoint else 0 + if verifiable_bytes > 0 and final_height >= self.checkpoint[0]: + buf.seek(0) + self.io.seek(0) + h = hashlib.sha256() + h.update(self.io.read()) + h.update(buf.read(verifiable_bytes)) + if h.hexdigest().encode() == self.checkpoint[1]: + buf.seek(0) + self._write(len(self), buf.read(verifiable_bytes)) + remaining = buf.read() + buf.seek(0) + buf.write(remaining) + buf.truncate() + else: + log.warning("Checkpoint mismatch, connecting headers through slow method.") + if buf.tell() > 0: + await self.connect(len(self), buf.getvalue()) + + async def connect(self, start: int, headers: bytes) -> int: + added = 0 + bail = False + for height, chunk in self._iterate_chunks(start, headers): + try: + # validate_chunk() is CPU bound and reads previous chunks from file system + self.validate_chunk(height, chunk) + except InvalidHeader as e: + bail = True + chunk = chunk[:(height-e.height)*self.header_size] + added += self._write(height, chunk) if chunk else 0 + if bail: + break + return added + + def _write(self, height, verified_chunk): + self.io.seek(height * self.header_size, os.SEEK_SET) + written = self.io.write(verified_chunk) // self.header_size + self.io.truncate() + # .seek()/.write()/.truncate() might also .flush() when needed + # the goal here is mainly to ensure we're definitely flush()'ing + self.io.flush() + self._size = self.io.tell() // self.header_size + return written + + def validate_chunk(self, height, chunk): + previous_hash, previous_header, previous_previous_header = None, None, None + if height > 0: + previous_header = self[height-1] + previous_hash = self.hash(height-1) + if height > 1: + previous_previous_header = self[height-2] + chunk_target = self.get_next_chunk_target(height // 2016 - 1) + for current_hash, current_header in self._iterate_headers(height, chunk): + block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header) + self.validate_header(height, current_hash, current_header, previous_hash, block_target) + previous_previous_header = previous_header + previous_header = current_header + previous_hash = current_hash + + def validate_header(self, height: int, current_hash: bytes, + header: dict, previous_hash: bytes, target: ArithUint256): + + if previous_hash is None: + if self.genesis_hash is not None and self.genesis_hash != current_hash: + raise InvalidHeader( + height, f"genesis header doesn't match: {current_hash.decode()} " + f"vs expected {self.genesis_hash.decode()}") + return + + if header['prev_block_hash'] != previous_hash: + raise InvalidHeader( + height, "previous hash mismatch: {} vs expected {}".format( + header['prev_block_hash'].decode(), previous_hash.decode()) + ) + + if self.validate_difficulty: + + if header['bits'] != target.compact: + raise InvalidHeader( + height, "bits mismatch: {} vs expected {}".format( + header['bits'], target.compact) + ) + + proof_of_work = self.get_proof_of_work(current_hash) + if proof_of_work > target: + raise InvalidHeader( + height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}" + ) + + async def repair(self): + previous_header_hash = fail = None + batch_size = 36 + for start_height in range(0, self.height, batch_size): + self.io.seek(self.header_size * start_height) + headers = self.io.read(self.header_size*batch_size) + if len(headers) % self.header_size != 0: + headers = headers[:(len(headers) // self.header_size) * self.header_size] + for header_hash, header in self._iterate_headers(start_height, headers): + height = header['block_height'] + if height: + if header['prev_block_hash'] != previous_header_hash: + fail = True + else: + if header_hash != self.genesis_hash: + fail = True + if fail: + log.warning("Header file corrupted at height %s, truncating it.", height - 1) + self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET) + self.io.truncate() + self.io.flush() + self._size = None + return + previous_header_hash = header_hash + + @staticmethod + def get_proof_of_work(header_hash: bytes) -> ArithUint256: + return ArithUint256(int(b'0x' + header_hash, 16)) + + def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]: + assert len(headers) % self.header_size == 0, f"{len(headers)} {len(headers)%self.header_size}" + start = 0 + end = (self.chunk_size - height % self.chunk_size) * self.header_size + while start < end: + yield height + (start // self.header_size), headers[start:end] + start = end + end = min(len(headers), end + self.chunk_size * self.header_size) + + def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]: + assert len(headers) % self.header_size == 0, len(headers) + for idx in range(len(headers) // self.header_size): + start, end = idx * self.header_size, (idx + 1) * self.header_size + header = headers[start:end] + yield self.hash_header(header), self.deserialize(height+idx, header) class Headers(BaseHeaders): diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index 1e11d2b16..37afe857e 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -1,41 +1,96 @@ +import os +import zlib +import pylru +import base64 import asyncio import logging -from binascii import unhexlify -from functools import partial -from typing import Tuple, List -from datetime import datetime -import pylru -from lbry.wallet.client.baseledger import BaseLedger, TransactionEvent -from lbry.wallet.client.baseaccount import SingleKey +from io import StringIO +from datetime import datetime +from functools import partial +from operator import itemgetter +from collections import namedtuple +from binascii import hexlify, unhexlify +from typing import Dict, Tuple, Type, Iterable, List, Optional + from lbry.schema.result import Outputs from lbry.schema.url import URL -from lbry.wallet.dewies import dewies_to_lbc -from lbry.wallet.account import Account -from lbry.wallet.network import Network -from lbry.wallet.database import WalletDatabase -from lbry.wallet.transaction import Transaction, Output -from lbry.wallet.header import Headers, UnvalidatedHeaders -from lbry.wallet.constants import TXO_TYPES +from lbry.crypto.hash import hash160, double_sha256, sha256 +from lbry.crypto.base58 import Base58 + +from .tasks import TaskGroup +from .database import Database +from .stream import StreamController +from .dewies import dewies_to_lbc +from .account import Account, AddressManager, SingleKey +from .network import Network +from .transaction import Transaction, Output +from .header import Headers, UnvalidatedHeaders +from .constants import TXO_TYPES, COIN, NULL_HASH32 +from .bip32 import PubKey, PrivateKey +from .coinselection import CoinSelector log = logging.getLogger(__name__) +LedgerType = Type['BaseLedger'] -class MainNetLedger(BaseLedger): + +class LedgerRegistry(type): + + ledgers: Dict[str, LedgerType] = {} + + def __new__(mcs, name, bases, attrs): + cls: LedgerType = super().__new__(mcs, name, bases, attrs) + if not (name == 'BaseLedger' and not bases): + ledger_id = cls.get_id() + assert ledger_id not in mcs.ledgers, \ + f'Ledger with id "{ledger_id}" already registered.' + mcs.ledgers[ledger_id] = cls + return cls + + @classmethod + def get_ledger_class(mcs, ledger_id: str) -> LedgerType: + return mcs.ledgers[ledger_id] + + +class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))): + pass + + +class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))): + pass + + +class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))): + pass + + +class TransactionCacheItem: + __slots__ = '_tx', 'lock', 'has_tx' + + def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None): + self.has_tx = asyncio.Event() + self.lock = lock or asyncio.Lock() + self._tx = self.tx = tx + + @property + def tx(self) -> Optional[Transaction]: + return self._tx + + @tx.setter + def tx(self, tx: Transaction): + self._tx = tx + if tx is not None: + self.has_tx.set() + + +class Ledger(metaclass=LedgerRegistry): name = 'LBRY Credits' symbol = 'LBC' network_name = 'mainnet' - headers: Headers - - account_class = Account - database_class = WalletDatabase headers_class = Headers - network_class = Network - transaction_class = Transaction - - db: WalletDatabase secret_prefix = bytes((0x1c,)) pubkey_address_prefix = bytes((0x55,)) @@ -51,11 +106,522 @@ class MainNetLedger(BaseLedger): default_fee_per_byte = 50 default_fee_per_name_char = 200000 - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, config=None): + self.config = config or {} + self.db: Database = self.config.get('db') or Database( + os.path.join(self.path, "blockchain.db") + ) + self.db.ledger = self + self.headers: Headers = self.config.get('headers') or self.headers_class( + os.path.join(self.path, "headers") + ) + self.network: Network = self.config.get('network') or Network(self) + self.network.on_header.listen(self.receive_header) + self.network.on_status.listen(self.process_status_update) + self.network.on_connected.listen(self.join_network) + + self.accounts = [] + self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte) + + self._on_transaction_controller = StreamController() + self.on_transaction = self._on_transaction_controller.stream + self.on_transaction.listen( + lambda e: log.info( + '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s', + self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id + ) + ) + + self._on_address_controller = StreamController() + self.on_address = self._on_address_controller.stream + self.on_address.listen( + lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses) + ) + + self._on_header_controller = StreamController() + self.on_header = self._on_header_controller.stream + self.on_header.listen( + lambda change: log.info( + '%s: added %s header blocks, final height %s', + self.get_id(), change, self.headers.height + ) + ) + self._download_height = 0 + + self._on_ready_controller = StreamController() + self.on_ready = self._on_ready_controller.stream + + self._tx_cache = pylru.lrucache(100000) + self._update_tasks = TaskGroup() + self._utxo_reservation_lock = asyncio.Lock() + self._header_processing_lock = asyncio.Lock() + self._address_update_locks: Dict[str, asyncio.Lock] = {} + + self.coin_selection_strategy = None + self._known_addresses_out_of_sync = set() + self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char) self._balance_cache = pylru.lrucache(100000) + @classmethod + def get_id(cls): + return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower()) + + @classmethod + def hash160_to_address(cls, h160): + raw_address = cls.pubkey_address_prefix + h160 + return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4])) + + @staticmethod + def address_to_hash160(address): + return Base58.decode(address)[1:21] + + @classmethod + def is_valid_address(cls, address): + decoded = Base58.decode_check(address) + return decoded[0] == cls.pubkey_address_prefix[0] + + @classmethod + def public_key_to_address(cls, public_key): + return cls.hash160_to_address(hash160(public_key)) + + @staticmethod + def private_key_to_wif(private_key): + return b'\x1c' + private_key + b'\x01' + + @property + def path(self): + return os.path.join(self.config['data_path'], self.get_id()) + + def add_account(self, account: Account): + self.accounts.append(account) + + async def _get_account_and_address_info_for_address(self, wallet, address): + match = await self.db.get_address(accounts=wallet.accounts, address=address) + if match: + for account in wallet.accounts: + if match['account'] == account.public_key.address: + return account, match + + async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]: + match = await self._get_account_and_address_info_for_address(wallet, address) + if match: + account, address_info = match + return account.get_private_key(address_info['chain'], address_info['pubkey'].n) + return None + + async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]: + match = await self._get_account_and_address_info_for_address(wallet, address) + if match: + _, address_info = match + return address_info['pubkey'] + return None + + async def get_account_for_address(self, wallet, address): + match = await self._get_account_and_address_info_for_address(wallet, address) + if match: + return match[0] + + async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]): + estimators = [] + for account in funding_accounts: + utxos = await account.get_utxos() + for utxo in utxos: + estimators.append(utxo.get_estimator(self)) + return estimators + + async def get_addresses(self, **constraints): + return await self.db.get_addresses(**constraints) + + def get_address_count(self, **constraints): + return self.db.get_address_count(**constraints) + + async def get_spendable_utxos(self, amount: int, funding_accounts): + async with self._utxo_reservation_lock: + txos = await self.get_effective_amount_estimators(funding_accounts) + fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self) + selector = CoinSelector(amount, fee) + spendables = selector.select(txos, self.coin_selection_strategy) + if spendables: + await self.reserve_outputs(s.txo for s in spendables) + return spendables + + def reserve_outputs(self, txos): + return self.db.reserve_outputs(txos) + + def release_outputs(self, txos): + return self.db.release_outputs(txos) + + def release_tx(self, tx): + return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs]) + + def get_utxos(self, **constraints): + self.constraint_spending_utxos(constraints) + return self.db.get_utxos(**constraints) + + def get_utxo_count(self, **constraints): + self.constraint_spending_utxos(constraints) + return self.db.get_utxo_count(**constraints) + + def get_transactions(self, **constraints): + return self.db.get_transactions(**constraints) + + def get_transaction_count(self, **constraints): + return self.db.get_transaction_count(**constraints) + + async def get_local_status_and_history(self, address, history=None): + if not history: + address_details = await self.db.get_address(address=address) + history = address_details['history'] or '' + parts = history.split(':')[:-1] + return ( + hexlify(sha256(history.encode())).decode() if history else None, + list(zip(parts[0::2], map(int, parts[1::2]))) + ) + + @staticmethod + def get_root_of_merkle_tree(branches, branch_positions, working_branch): + for i, branch in enumerate(branches): + other_branch = unhexlify(branch)[::-1] + other_branch_on_left = bool((branch_positions >> i) & 1) + if other_branch_on_left: + combined = other_branch + working_branch + else: + combined = working_branch + other_branch + working_branch = double_sha256(combined) + return hexlify(working_branch[::-1]) + + async def start(self): + if not os.path.exists(self.path): + os.mkdir(self.path) + await asyncio.wait([ + self.db.open(), + self.headers.open() + ]) + first_connection = self.network.on_connected.first + asyncio.ensure_future(self.network.start()) + await first_connection + async with self._header_processing_lock: + await self._update_tasks.add(self.initial_headers_sync()) + await self._on_ready_controller.stream.first + await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) + await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) + if len(self.accounts) > 10: + log.info("Loaded %i accounts", len(self.accounts)) + else: + await self._report_state() + self.on_transaction.listen(self._reset_balance_cache) + + async def join_network(self, *_): + log.info("Subscribing and updating accounts.") + async with self._header_processing_lock: + await self.update_headers() + await self.subscribe_accounts() + await self._update_tasks.done.wait() + self._on_ready_controller.add(True) + + async def stop(self): + self._update_tasks.cancel() + await self._update_tasks.done.wait() + await self.network.stop() + await self.db.close() + await self.headers.close() + + @property + def local_height_including_downloaded_height(self): + return max(self.headers.height, self._download_height) + + async def initial_headers_sync(self): + target = self.network.remote_height + 1 + current = len(self.headers) + get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True) + chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)] + total = 0 + async with self.headers.checkpointed_connector() as buffer: + for chunk in chunks: + headers = await chunk + total += buffer.write( + zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000) + ) + self._download_height = current + total // self.headers.header_size + log.info("Headers sync: %s / %s", self._download_height, target) + + async def update_headers(self, height=None, headers=None, subscription_update=False): + rewound = 0 + while True: + + if height is None or height > len(self.headers): + # sometimes header subscription updates are for a header in the future + # which can't be connected, so we do a normal header sync instead + height = len(self.headers) + headers = None + subscription_update = False + + if not headers: + header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) + headers = header_response['hex'] + + if not headers: + # Nothing to do, network thinks we're already at the latest height. + return + + added = await self.headers.connect(height, unhexlify(headers)) + if added > 0: + height += added + self._on_header_controller.add( + BlockHeightEvent(self.headers.height, added)) + + if rewound > 0: + # we started rewinding blocks and apparently found + # a new chain + rewound = 0 + await self.db.rewind_blockchain(height) + + if subscription_update: + # subscription updates are for latest header already + # so we don't need to check if there are newer / more + # on another loop of update_headers(), just return instead + return + + elif added == 0: + # we had headers to connect but none got connected, probably a reorganization + height -= 1 + rewound += 1 + log.warning( + "Blockchain Reorganization: attempting rewind to height %s from starting height %s", + height, height+rewound + ) + + else: + raise IndexError(f"headers.connect() returned negative number ({added})") + + if height < 0: + raise IndexError( + "Blockchain reorganization rewound all the way back to genesis hash. " + "Something is very wrong. Maybe you are on the wrong blockchain?" + ) + + if rewound >= 100: + raise IndexError( + "Blockchain reorganization dropped {} headers. This is highly unusual. " + "Will not continue to attempt reorganizing. Please, delete the ledger " + "synchronization directory inside your wallet directory (folder: '{}') and " + "restart the program to synchronize from scratch." + .format(rewound, self.get_id()) + ) + + headers = None # ready to download some more headers + + # if we made it this far and this was a subscription_update + # it means something went wrong and now we're doing a more + # robust sync, turn off subscription update shortcut + subscription_update = False + + async def receive_header(self, response): + async with self._header_processing_lock: + header = response[0] + await self.update_headers( + height=header['height'], headers=header['hex'], subscription_update=True + ) + + async def subscribe_accounts(self): + if self.network.is_connected and self.accounts: + await asyncio.wait([ + self.subscribe_account(a) for a in self.accounts + ]) + + async def subscribe_account(self, account: Account): + for address_manager in account.address_managers.values(): + await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) + await account.ensure_address_gap() + + async def unsubscribe_account(self, account: Account): + for address in await account.get_addresses(): + await self.network.unsubscribe_address(address) + + async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]): + await self.subscribe_addresses(address_manager, addresses) + await self._on_address_controller.add( + AddressesGeneratedEvent(address_manager, addresses) + ) + + async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str]): + if self.network.is_connected and addresses: + await asyncio.wait([ + self.subscribe_address(address_manager, address) for address in addresses + ]) + + async def subscribe_address(self, address_manager: AddressManager, address: str): + remote_status = await self.network.subscribe_address(address) + self._update_tasks.add(self.update_history(address, remote_status, address_manager)) + + def process_status_update(self, update): + address, remote_status = update + self._update_tasks.add(self.update_history(address, remote_status)) + + async def update_history(self, address, remote_status, + address_manager: AddressManager = None): + + async with self._address_update_locks.setdefault(address, asyncio.Lock()): + self._known_addresses_out_of_sync.discard(address) + + local_status, local_history = await self.get_local_status_and_history(address) + + if local_status == remote_status: + return True + + remote_history = await self.network.retriable_call(self.network.get_history, address) + remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) + we_need = set(remote_history) - set(local_history) + if not we_need: + return True + + cache_tasks: List[asyncio.Future[Transaction]] = [] + synced_history = StringIO() + for i, (txid, remote_height) in enumerate(remote_history): + if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks: + synced_history.write(f'{txid}:{remote_height}:') + else: + check_local = (txid, remote_height) not in we_need + cache_tasks.append(asyncio.ensure_future( + self.cache_transaction(txid, remote_height, check_local=check_local) + )) + + synced_txs = [] + for task in cache_tasks: + tx = await task + + check_db_for_txos = [] + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + continue + cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id) + if cache_item is not None: + if cache_item.tx is None: + await cache_item.has_tx.wait() + assert cache_item.tx is not None + txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref + else: + check_db_for_txos.append(txi.txo_ref.id) + + referenced_txos = {} if not check_db_for_txos else { + txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos, no_tx=True) + } + + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + continue + referenced_txo = referenced_txos.get(txi.txo_ref.id) + if referenced_txo is not None: + txi.txo_ref = referenced_txo.ref + + synced_history.write(f'{tx.id}:{tx.height}:') + synced_txs.append(tx) + + await self.db.save_transaction_io_batch( + synced_txs, address, self.address_to_hash160(address), synced_history.getvalue() + ) + await asyncio.wait([ + self._on_transaction_controller.add(TransactionEvent(address, tx)) + for tx in synced_txs + ]) + + if address_manager is None: + address_manager = await self.get_address_manager_for_address(address) + + if address_manager is not None: + await address_manager.ensure_address_gap() + + local_status, local_history = \ + await self.get_local_status_and_history(address, synced_history.getvalue()) + if local_status != remote_status: + if local_history == remote_history: + return True + log.warning( + "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items", + remote_status, len(remote_history), local_status, len(local_history) + ) + log.warning("local: %s", local_history) + log.warning("remote: %s", remote_history) + self._known_addresses_out_of_sync.add(address) + return False + else: + return True + + async def cache_transaction(self, txid, remote_height, check_local=True): + cache_item = self._tx_cache.get(txid) + if cache_item is None: + cache_item = self._tx_cache[txid] = TransactionCacheItem() + elif cache_item.tx is not None and \ + cache_item.tx.height >= remote_height and \ + (cache_item.tx.is_verified or remote_height < 1): + return cache_item.tx # cached tx is already up-to-date + + async with cache_item.lock: + + tx = cache_item.tx + + if tx is None and check_local: + # check local db + tx = cache_item.tx = await self.db.get_transaction(txid=txid) + + if tx is None: + # fetch from network + _raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height) + tx = Transaction(unhexlify(_raw)) + cache_item.tx = tx # make sure it's saved before caching it + + await self.maybe_verify_transaction(tx, remote_height) + return tx + + async def maybe_verify_transaction(self, tx, remote_height): + tx.height = remote_height + if 0 < remote_height < len(self.headers): + merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) + merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) + header = self.headers[remote_height] + tx.position = merkle['pos'] + tx.is_verified = merkle_root == header['merkle_root'] + + async def get_address_manager_for_address(self, address) -> Optional[AddressManager]: + details = await self.db.get_address(address=address) + for account in self.accounts: + if account.id == details['account']: + return account.address_managers[details['chain']] + return None + + def broadcast(self, tx): + # broadcast can't be a retriable call yet + return self.network.broadcast(hexlify(tx.raw).decode()) + + async def wait(self, tx: Transaction, height=-1, timeout=1): + addresses = set() + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + addresses.add( + self.hash160_to_address(txi.txo_ref.txo.pubkey_hash) + ) + for txo in tx.outputs: + if txo.has_address: + addresses.add(self.hash160_to_address(txo.pubkey_hash)) + records = await self.db.get_addresses(address__in=addresses) + _, pending = await asyncio.wait([ + self.on_transaction.where(partial( + lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id, + address_record['address'] + )) for address_record in records + ], timeout=timeout) + if pending: + for record in records: + found = False + _, local_history = await self.get_local_status_and_history(None, history=record['history']) + for txid, local_height in local_history: + if txid == tx.id and local_height >= height: + found = True + if not found: + print(record['history'], addresses, tx.id) + raise asyncio.TimeoutError('Timed out waiting for transaction.') + async def _inflate_outputs(self, query, accounts): outputs = Outputs.from_base64(await query) txs = [] @@ -103,16 +669,6 @@ class MainNetLedger(BaseLedger): for claim in (await self.claim_search(accounts, claim_id=claim_id))[0]: return claim - async def start(self): - await super().start() - await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) - await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) - if len(self.accounts) > 10: - log.info("Loaded %i accounts", len(self.accounts)) - else: - await self._report_state() - self.on_transaction.listen(self._reset_balance_cache) - async def _report_state(self): try: for account in self.accounts: @@ -147,14 +703,6 @@ class MainNetLedger(BaseLedger): def constraint_spending_utxos(constraints): constraints['txo_type__in'] = (0, TXO_TYPES['purchase']) - def get_utxos(self, **constraints): - self.constraint_spending_utxos(constraints) - return super().get_utxos(**constraints) - - def get_utxo_count(self, **constraints): - self.constraint_spending_utxos(constraints) - return super().get_utxo_count(**constraints) - async def get_purchases(self, resolve=False, **constraints): purchases = await self.db.get_purchases(**constraints) if resolve: @@ -357,7 +905,7 @@ class MainNetLedger(BaseLedger): return result -class TestNetLedger(MainNetLedger): +class TestNetLedger(Ledger): network_name = 'testnet' pubkey_address_prefix = bytes((111,)) script_address_prefix = bytes((196,)) @@ -365,7 +913,7 @@ class TestNetLedger(MainNetLedger): extended_private_key_prefix = unhexlify('04358394') -class RegTestLedger(MainNetLedger): +class RegTestLedger(Ledger): network_name = 'regtest' headers_class = UnvalidatedHeaders pubkey_address_prefix = bytes((111,)) diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index 883a7d9a9..1aa72e2ac 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -1,43 +1,112 @@ import os import json -import typing import logging +import asyncio from binascii import unhexlify -from typing import Optional, List from decimal import Decimal - -from lbry.wallet.client.basemanager import BaseWalletManager -from lbry.wallet.client.wallet import ENCRYPT_ON_DISK -from lbry.wallet.rpc.jsonrpc import CodeMessageError +from typing import List, Type, MutableSequence, MutableMapping, Optional from lbry.error import KeyFeeAboveMaxAllowedError -from lbry.wallet.dewies import dewies_to_lbc -from lbry.wallet.account import Account -from lbry.wallet.ledger import MainNetLedger -from lbry.wallet.transaction import Transaction, Output -from lbry.wallet.database import WalletDatabase +from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.conf import Config +from .dewies import dewies_to_lbc +from .account import Account +from .ledger import Ledger, LedgerRegistry +from .transaction import Transaction, Output +from .database import Database +from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK +from .rpc.jsonrpc import CodeMessageError + log = logging.getLogger(__name__) -if typing.TYPE_CHECKING: - from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager +class WalletManager: - -class LbryWalletManager(BaseWalletManager): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, wallets: MutableSequence[Wallet] = None, + ledgers: MutableMapping[Type[Ledger], Ledger] = None) -> None: + self.wallets = wallets or [] + self.ledgers = ledgers or {} + self.running = False self.config: Optional[Config] = None + @classmethod + def from_config(cls, config: dict) -> 'WalletManager': + manager = cls() + for ledger_id, ledger_config in config.get('ledgers', {}).items(): + manager.get_or_create_ledger(ledger_id, ledger_config) + for wallet_path in config.get('wallets', []): + wallet_storage = WalletStorage(wallet_path) + wallet = Wallet.from_storage(wallet_storage, manager) + manager.wallets.append(wallet) + return manager + + def get_or_create_ledger(self, ledger_id, ledger_config=None): + ledger_class = LedgerRegistry.get_ledger_class(ledger_id) + ledger = self.ledgers.get(ledger_class) + if ledger is None: + ledger = ledger_class(ledger_config or {}) + self.ledgers[ledger_class] = ledger + return ledger + + def import_wallet(self, path): + storage = WalletStorage(path) + wallet = Wallet.from_storage(storage, self) + self.wallets.append(wallet) + return wallet + @property - def ledger(self) -> MainNetLedger: + def default_wallet(self): + for wallet in self.wallets: + return wallet + + @property + def default_account(self): + for wallet in self.wallets: + return wallet.default_account + + @property + def accounts(self): + for wallet in self.wallets: + yield from wallet.accounts + + async def start(self): + self.running = True + await asyncio.gather(*( + l.start() for l in self.ledgers.values() + )) + + async def stop(self): + await asyncio.gather(*( + l.stop() for l in self.ledgers.values() + )) + self.running = False + + def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet: + if wallet_id is None: + return self.default_wallet + return self.get_wallet_or_error(wallet_id) + + def get_wallet_or_error(self, wallet_id: str) -> Wallet: + for wallet in self.wallets: + if wallet.id == wallet_id: + return wallet + raise ValueError(f"Couldn't find wallet: {wallet_id}.") + + @staticmethod + def get_balance(wallet): + accounts = wallet.accounts + if not accounts: + return 0 + return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts) + + @property + def ledger(self) -> Ledger: return self.default_account.ledger @property - def db(self) -> WalletDatabase: + def db(self) -> Database: return self.ledger.db def check_locked(self): @@ -194,7 +263,7 @@ class LbryWalletManager(BaseWalletManager): if 'No such mempool or blockchain transaction.' in e.message: return {'success': False, 'code': 404, 'message': 'transaction not found'} return {'success': False, 'code': e.code, 'message': e.message} - tx = self.ledger.transaction_class(unhexlify(raw)) + tx = Transaction(unhexlify(raw)) await self.ledger.maybe_verify_transaction(tx, height) return tx diff --git a/lbry/wallet/mnemonic.py b/lbry/wallet/mnemonic.py index 0fcfe5b57..885d4649c 100644 --- a/lbry/wallet/mnemonic.py +++ b/lbry/wallet/mnemonic.py @@ -13,7 +13,7 @@ from secrets import randbelow import pbkdf2 from lbry.crypto.hash import hmac_sha512 -from lbry.wallet.client.words import english +from .words import english # The hash of the mnemonic seed must begin with this SEED_PREFIX = b'01' # Standard wallet diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index 170885e4f..f2b8deba9 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -1,9 +1,276 @@ -import lbry -from lbry.wallet.client.basenetwork import BaseNetwork +import logging +import asyncio +from time import perf_counter +from operator import itemgetter +from typing import Dict, Optional, Tuple + +from lbry import __version__ +from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError +from lbry.wallet.stream import StreamController + +log = logging.getLogger(__name__) -class Network(BaseNetwork): - PROTOCOL_VERSION = lbry.__version__ +class ClientSession(BaseClientSession): + def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): + self.network = network + self.server = server + super().__init__(*args, **kwargs) + self._on_disconnect_controller = StreamController() + self.on_disconnected = self._on_disconnect_controller.stream + self.framer.max_size = self.max_errors = 1 << 32 + self.bw_limit = -1 + self.timeout = timeout + self.max_seconds_idle = timeout * 2 + self.response_time: Optional[float] = None + self.connection_latency: Optional[float] = None + self._response_samples = 0 + self.pending_amount = 0 + self._on_connect_cb = on_connect_callback or (lambda: None) + self.trigger_urgent_reconnect = asyncio.Event() + + @property + def available(self): + return not self.is_closing() and self.response_time is not None + + @property + def server_address_and_port(self) -> Optional[Tuple[str, int]]: + if not self.transport: + return None + return self.transport.get_extra_info('peername') + + async def send_timed_server_version_request(self, args=(), timeout=None): + timeout = timeout or self.timeout + log.debug("send version request to %s:%i", *self.server) + start = perf_counter() + result = await asyncio.wait_for( + super().send_request('server.version', args), timeout=timeout + ) + current_response_time = perf_counter() - start + response_sum = (self.response_time or 0) * self._response_samples + current_response_time + self.response_time = response_sum / (self._response_samples + 1) + self._response_samples += 1 + return result + + async def send_request(self, method, args=()): + self.pending_amount += 1 + log.debug("send %s to %s:%i", method, *self.server) + try: + if method == 'server.version': + return await self.send_timed_server_version_request(args, self.timeout) + request = asyncio.ensure_future(super().send_request(method, args)) + while not request.done(): + done, pending = await asyncio.wait([request], timeout=self.timeout) + if pending: + log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received) + if (perf_counter() - self.last_packet_received) < self.timeout: + continue + log.info("timeout sending %s to %s:%i", method, *self.server) + raise asyncio.TimeoutError + if done: + return request.result() + except (RPCError, ProtocolError) as e: + log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", + *self.server, *e.args) + raise e + except ConnectionError: + log.warning("connection to %s:%i lost", *self.server) + self.synchronous_close() + raise + except asyncio.CancelledError: + log.info("cancelled sending %s to %s:%i", method, *self.server) + self.synchronous_close() + raise + finally: + self.pending_amount -= 1 + + async def ensure_session(self): + # Handles reconnecting and maintaining a session alive + # TODO: change to 'ping' on newer protocol (above 1.2) + retry_delay = default_delay = 1.0 + while True: + try: + if self.is_closing(): + await self.create_connection(self.timeout) + await self.ensure_server_version() + self._on_connect_cb() + if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None: + await self.ensure_server_version() + retry_delay = default_delay + except RPCError as e: + log.warning("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message) + retry_delay = 60 * 60 + except (asyncio.TimeoutError, OSError): + await self.close() + retry_delay = min(60, retry_delay * 2) + log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) + try: + await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) + except asyncio.TimeoutError: + pass + finally: + self.trigger_urgent_reconnect.clear() + + async def ensure_server_version(self, required=None, timeout=3): + required = required or self.network.PROTOCOL_VERSION + return await asyncio.wait_for( + self.send_request('server.version', [__version__, required]), timeout=timeout + ) + + async def create_connection(self, timeout=6): + connector = Connector(lambda: self, *self.server) + start = perf_counter() + await asyncio.wait_for(connector.create_connection(), timeout=timeout) + self.connection_latency = perf_counter() - start + + async def handle_request(self, request): + controller = self.network.subscription_controllers[request.method] + controller.add(request.args) + + def connection_lost(self, exc): + log.debug("Connection lost: %s:%d", *self.server) + super().connection_lost(exc) + self.response_time = None + self.connection_latency = None + self._response_samples = 0 + self.pending_amount = 0 + self._on_disconnect_controller.add(True) + + +class Network: + + PROTOCOL_VERSION = __version__ + + def __init__(self, ledger): + self.ledger = ledger + self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) + self.client: Optional[ClientSession] = None + self._switch_task: Optional[asyncio.Task] = None + self.running = False + self.remote_height: int = 0 + self._concurrency = asyncio.Semaphore(16) + + self._on_connected_controller = StreamController() + self.on_connected = self._on_connected_controller.stream + + self._on_header_controller = StreamController(merge_repeated_events=True) + self.on_header = self._on_header_controller.stream + + self._on_status_controller = StreamController(merge_repeated_events=True) + self.on_status = self._on_status_controller.stream + + self.subscription_controllers = { + 'blockchain.headers.subscribe': self._on_header_controller, + 'blockchain.address.subscribe': self._on_status_controller, + } + + @property + def config(self): + return self.ledger.config + + async def switch_forever(self): + while self.running: + if self.is_connected: + await self.client.on_disconnected.first + self.client = None + continue + self.client = await self.session_pool.wait_for_fastest_session() + log.info("Switching to SPV wallet server: %s:%d", *self.client.server) + try: + self._update_remote_height((await self.subscribe_headers(),)) + self._on_connected_controller.add(True) + log.info("Subscribed to headers: %s:%d", *self.client.server) + except (asyncio.TimeoutError, ConnectionError): + log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server) + self.client.synchronous_close() + self.client = None + + async def start(self): + self.running = True + self._switch_task = asyncio.ensure_future(self.switch_forever()) + # this may become unnecessary when there are no more bugs found, + # but for now it helps understanding log reports + self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) + self.session_pool.start(self.config['default_servers']) + self.on_header.listen(self._update_remote_height) + + async def stop(self): + if self.running: + self.running = False + self._switch_task.cancel() + self.session_pool.stop() + + @property + def is_connected(self): + return self.client and not self.client.is_closing() + + def rpc(self, list_or_method, args, restricted=True): + session = self.client if restricted else self.session_pool.fastest_session + if session and not session.is_closing(): + return session.send_request(list_or_method, args) + else: + self.session_pool.trigger_nodelay_connect() + raise ConnectionError("Attempting to send rpc request when connection is not available.") + + async def retriable_call(self, function, *args, **kwargs): + async with self._concurrency: + while self.running: + if not self.is_connected: + log.warning("Wallet server unavailable, waiting for it to come back and retry.") + await self.on_connected.first + await self.session_pool.wait_for_fastest_session() + try: + return await function(*args, **kwargs) + except asyncio.TimeoutError: + log.warning("Wallet server call timed out, retrying.") + except ConnectionError: + pass + raise asyncio.CancelledError() # if we got here, we are shutting down + + def _update_remote_height(self, header_args): + self.remote_height = header_args[0]["height"] + + def get_transaction(self, tx_hash, known_height=None): + # use any server if its old, otherwise restrict to who gave us the history + restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 + return self.rpc('blockchain.transaction.get', [tx_hash], restricted) + + def get_transaction_height(self, tx_hash, known_height=None): + restricted = not known_height or 0 > known_height > self.remote_height - 10 + return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted) + + def get_merkle(self, tx_hash, height): + restricted = 0 > height > self.remote_height - 10 + return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted) + + def get_headers(self, height, count=10000, b64=False): + restricted = height >= self.remote_height - 100 + return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted) + + # --- Subscribes, history and broadcasts are always aimed towards the master client directly + def get_history(self, address): + return self.rpc('blockchain.address.get_history', [address], True) + + def broadcast(self, raw_transaction): + return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True) + + def subscribe_headers(self): + return self.rpc('blockchain.headers.subscribe', [True], True) + + async def subscribe_address(self, address): + try: + return await self.rpc('blockchain.address.subscribe', [address], True) + except asyncio.TimeoutError: + # abort and cancel, we can't lose a subscription, it will happen again on reconnect + if self.client: + self.client.abort() + raise asyncio.CancelledError() + + def unsubscribe_address(self, address): + return self.rpc('blockchain.address.unsubscribe', [address], True) + + def get_server_features(self): + return self.rpc('server.features', (), restricted=True) def get_claims_by_ids(self, claim_ids): return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) @@ -13,3 +280,95 @@ class Network(BaseNetwork): def claim_search(self, **kwargs): return self.rpc('blockchain.claimtrie.search', kwargs) + + +class SessionPool: + + def __init__(self, network: Network, timeout: float): + self.network = network + self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() + self.timeout = timeout + self.new_connection_event = asyncio.Event() + + @property + def online(self): + return any(not session.is_closing() for session in self.sessions) + + @property + def available_sessions(self): + return (session for session in self.sessions if session.available) + + @property + def fastest_session(self): + if not self.online: + return None + return min( + [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session) + for session in self.available_sessions] or [(0, None)], + key=itemgetter(0) + )[1] + + def _get_session_connect_callback(self, session: ClientSession): + loop = asyncio.get_event_loop() + + def callback(): + duplicate_connections = [ + s for s in self.sessions + if s is not session and s.server_address_and_port == session.server_address_and_port + ] + already_connected = None if not duplicate_connections else duplicate_connections[0] + if already_connected: + self.sessions.pop(session).cancel() + session.synchronous_close() + log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour", + session.server[0], already_connected.server[0]) + loop.call_later(3600, self._connect_session, session.server) + return + self.new_connection_event.set() + log.info("connected to %s:%i", *session.server) + + return callback + + def _connect_session(self, server: Tuple[str, int]): + session = None + for s in self.sessions: + if s.server == server: + session = s + break + if not session: + session = ClientSession( + network=self.network, server=server + ) + session._on_connect_cb = self._get_session_connect_callback(session) + task = self.sessions.get(session, None) + if not task or task.done(): + task = asyncio.create_task(session.ensure_session()) + task.add_done_callback(lambda _: self.ensure_connections()) + self.sessions[session] = task + + def start(self, default_servers): + for server in default_servers: + self._connect_session(server) + + def stop(self): + for session, task in self.sessions.items(): + task.cancel() + session.synchronous_close() + self.sessions.clear() + + def ensure_connections(self): + for session in self.sessions: + self._connect_session(session.server) + + def trigger_nodelay_connect(self): + # used when other parts of the system sees we might have internet back + # bypasses the retry interval + for session in self.sessions: + session.trigger_urgent_reconnect.set() + + async def wait_for_fastest_session(self): + while not self.fastest_session: + self.trigger_nodelay_connect() + self.new_connection_event.clear() + await self.new_connection_event.wait() + return self.fastest_session diff --git a/lbry/wallet/orchstr8/node.py b/lbry/wallet/orchstr8/node.py index b496ba9db..24dfb6def 100644 --- a/lbry/wallet/orchstr8/node.py +++ b/lbry/wallet/orchstr8/node.py @@ -12,28 +12,15 @@ from binascii import hexlify from typing import Type, Optional import urllib.request +import lbry from lbry.wallet.server.server import Server from lbry.wallet.server.env import Env -from lbry.wallet.client.wallet import Wallet -from lbry.wallet.client.baseledger import BaseLedger, BlockHeightEvent -from lbry.wallet.client.basemanager import BaseWalletManager -from lbry.wallet.client.baseaccount import BaseAccount +from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent log = logging.getLogger(__name__) -def get_manager_from_environment(default_manager=BaseWalletManager): - if 'TORBA_MANAGER' not in os.environ: - return default_manager - module_name = os.environ['TORBA_MANAGER'].split('-')[-1] # tox support - return importlib.import_module(module_name) - - -def get_ledger_from_environment(): - return importlib.import_module('lbry.wallet') - - def get_spvserver_from_ledger(ledger_module): spvserver_path, regtest_class_name = ledger_module.__spvserver__.rsplit('.', 1) spvserver_module = importlib.import_module(spvserver_path) @@ -50,16 +37,14 @@ def get_blockchain_node_from_ledger(ledger_module): class Conductor: - def __init__(self, ledger_module=None, manager_module=None, enable_segwit=False, seed=None): - self.ledger_module = ledger_module or get_ledger_from_environment() - self.manager_module = manager_module or get_manager_from_environment() - self.spv_module = get_spvserver_from_ledger(self.ledger_module) + def __init__(self, seed=None): + self.manager_module = WalletManager + self.spv_module = get_spvserver_from_ledger(lbry.wallet) - self.blockchain_node = get_blockchain_node_from_ledger(self.ledger_module) - self.blockchain_node.segwit_enabled = enable_segwit + self.blockchain_node = get_blockchain_node_from_ledger(lbry.wallet) self.spv_node = SPVNode(self.spv_module) self.wallet_node = WalletNode( - self.manager_module, self.ledger_module.RegTestLedger, default_seed=seed + self.manager_module, RegTestLedger, default_seed=seed ) self.blockchain_started = False @@ -119,15 +104,15 @@ class Conductor: class WalletNode: - def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger], + def __init__(self, manager_class: Type[WalletManager], ledger_class: Type[Ledger], verbose: bool = False, port: int = 5280, default_seed: str = None) -> None: self.manager_class = manager_class self.ledger_class = ledger_class self.verbose = verbose - self.manager: Optional[BaseWalletManager] = None - self.ledger: Optional[BaseLedger] = None + self.manager: Optional[WalletManager] = None + self.ledger: Optional[Ledger] = None self.wallet: Optional[Wallet] = None - self.account: Optional[BaseAccount] = None + self.account: Optional[Account] = None self.data_path: Optional[str] = None self.port = port self.default_seed = default_seed @@ -154,7 +139,7 @@ class WalletNode: if not self.wallet: raise ValueError('Wallet is required.') if seed or self.default_seed: - self.ledger.account_class.from_dict( + Account.from_dict( self.ledger, self.wallet, {'seed': seed or self.default_seed} ) else: @@ -250,7 +235,7 @@ class BlockchainNode: P2SH_SEGWIT_ADDRESS = "p2sh-segwit" BECH32_ADDRESS = "bech32" - def __init__(self, url, daemon, cli, segwit_enabled=False): + def __init__(self, url, daemon, cli): self.latest_release_url = url self.project_dir = os.path.dirname(os.path.dirname(__file__)) self.bin_dir = os.path.join(self.project_dir, 'bin') @@ -266,7 +251,6 @@ class BlockchainNode: self.rpcport = 9245 + 2 # avoid conflict with default rpc port self.rpcuser = 'rpcuser' self.rpcpassword = 'rpcpassword' - self.segwit_enabled = segwit_enabled @property def rpc_url(self): @@ -326,8 +310,6 @@ class BlockchainNode: f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}', f'-port={self.peerport}' ] - if not self.segwit_enabled: - command.extend(['-addresstype=legacy', '-vbparams=segwit:0:999999999999']) self.log.info(' '.join(command)) self.transport, self.protocol = await loop.subprocess_exec( BlockchainProcess, *command diff --git a/lbry/wallet/orchstr8/service.py b/lbry/wallet/orchstr8/service.py index 25a62081d..495f68a07 100644 --- a/lbry/wallet/orchstr8/service.py +++ b/lbry/wallet/orchstr8/service.py @@ -3,7 +3,7 @@ import logging from aiohttp.web import Application, WebSocketResponse, json_response from aiohttp.http_websocket import WSMsgType, WSCloseCode -from lbry.wallet.client.util import satoshis_to_coins +from lbry.wallet.util import satoshis_to_coins from .node import Conductor diff --git a/lbry/wallet/script.py b/lbry/wallet/script.py index 4c0ca9509..f49b193cd 100644 --- a/lbry/wallet/script.py +++ b/lbry/wallet/script.py @@ -1,34 +1,430 @@ -from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript, Template -from lbry.wallet.client.basescript import PUSH_SINGLE, PUSH_INTEGER, OP_DROP, OP_2DROP, PUSH_SUBSCRIPT, OP_VERIFY +from typing import List +from itertools import chain +from binascii import hexlify +from collections import namedtuple + +from .bcd_data_stream import BCDataStream +from .util import subclass_tuple -class InputScript(BaseInputScript): - pass +# bitcoin opcodes +OP_0 = 0x00 +OP_1 = 0x51 +OP_16 = 0x60 +OP_VERIFY = 0x69 +OP_DUP = 0x76 +OP_HASH160 = 0xa9 +OP_EQUALVERIFY = 0x88 +OP_CHECKSIG = 0xac +OP_CHECKMULTISIG = 0xae +OP_EQUAL = 0x87 +OP_PUSHDATA1 = 0x4c +OP_PUSHDATA2 = 0x4d +OP_PUSHDATA4 = 0x4e +OP_RETURN = 0x6a +OP_2DROP = 0x6d +OP_DROP = 0x75 + +# lbry custom opcodes +# checks +OP_PRICECHECK = 0xb0 # checks that the BUY output is >= SELL price +# tx types +OP_CLAIM_NAME = 0xb5 +OP_SUPPORT_CLAIM = 0xb6 +OP_UPDATE_CLAIM = 0xb7 +OP_SELL_CLAIM = 0xb8 +OP_BUY_CLAIM = 0xb9 + +# template matching opcodes (not real opcodes) +# base class for PUSH_DATA related opcodes +# pylint: disable=invalid-name +PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name') +# opcode for variable length strings +# pylint: disable=invalid-name +PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP) +# opcode for variable size integers +# pylint: disable=invalid-name +PUSH_INTEGER = subclass_tuple('PUSH_INTEGER', PUSH_DATA_OP) +# opcode for variable number of variable length strings +# pylint: disable=invalid-name +PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP) +# opcode with embedded subscript parsing +# pylint: disable=invalid-name +PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template') -class OutputScript(BaseOutputScript): +def is_push_data_opcode(opcode): + return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT)) - # lbry custom opcodes - # checks - OP_PRICECHECK = 0xb0 # checks that the BUY output is >= SELL price +def is_push_data_token(token): + return 1 <= token <= OP_PUSHDATA4 - # tx types - OP_CLAIM_NAME = 0xb5 - OP_SUPPORT_CLAIM = 0xb6 - OP_UPDATE_CLAIM = 0xb7 - OP_SELL_CLAIM = 0xb8 - OP_BUY_CLAIM = 0xb9 + +def push_data(data): + size = len(data) + if size < OP_PUSHDATA1: + yield BCDataStream.uint8.pack(size) + elif size <= 0xFF: + yield BCDataStream.uint8.pack(OP_PUSHDATA1) + yield BCDataStream.uint8.pack(size) + elif size <= 0xFFFF: + yield BCDataStream.uint8.pack(OP_PUSHDATA2) + yield BCDataStream.uint16.pack(size) + else: + yield BCDataStream.uint8.pack(OP_PUSHDATA4) + yield BCDataStream.uint32.pack(size) + yield bytes(data) + + +def read_data(token, stream): + if token < OP_PUSHDATA1: + return stream.read(token) + if token == OP_PUSHDATA1: + return stream.read(stream.read_uint8()) + if token == OP_PUSHDATA2: + return stream.read(stream.read_uint16()) + return stream.read(stream.read_uint32()) + + +# opcode for OP_1 - OP_16 +# pylint: disable=invalid-name +SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name') + + +def is_small_integer(token): + return OP_1 <= token <= OP_16 + + +def push_small_integer(num): + assert 1 <= num <= 16 + yield BCDataStream.uint8.pack(OP_1 + (num - 1)) + + +def read_small_integer(token): + return (token - OP_1) + 1 + + +class Token(namedtuple('Token', 'value')): + __slots__ = () + + def __repr__(self): + name = None + for var_name, var_value in globals().items(): + if var_name.startswith('OP_') and var_value == self.value: + name = var_name + break + return name or self.value + + +class DataToken(Token): + __slots__ = () + + def __repr__(self): + return f'"{hexlify(self.value)}"' + + +class SmallIntegerToken(Token): + __slots__ = () + + def __repr__(self): + return f'SmallIntegerToken({self.value})' + + +def token_producer(source): + token = source.read_uint8() + while token is not None: + if is_push_data_token(token): + yield DataToken(read_data(token, source)) + elif is_small_integer(token): + yield SmallIntegerToken(read_small_integer(token)) + else: + yield Token(token) + token = source.read_uint8() + + +def tokenize(source): + return list(token_producer(source)) + + +class ScriptError(Exception): + """ General script handling error. """ + + +class ParseError(ScriptError): + """ Script parsing error. """ + + +class Parser: + + def __init__(self, opcodes, tokens): + self.opcodes = opcodes + self.tokens = tokens + self.values = {} + self.token_index = 0 + self.opcode_index = 0 + + def parse(self): + while self.token_index < len(self.tokens) and self.opcode_index < len(self.opcodes): + token = self.tokens[self.token_index] + opcode = self.opcodes[self.opcode_index] + if token.value == 0 and isinstance(opcode, PUSH_SINGLE): + token = DataToken(b'') + if isinstance(token, DataToken): + if isinstance(opcode, (PUSH_SINGLE, PUSH_INTEGER, PUSH_SUBSCRIPT)): + self.push_single(opcode, token.value) + elif isinstance(opcode, PUSH_MANY): + self.consume_many_non_greedy() + else: + raise ParseError(f"DataToken found but opcode was '{opcode}'.") + elif isinstance(token, SmallIntegerToken): + if isinstance(opcode, SMALL_INTEGER): + self.values[opcode.name] = token.value + else: + raise ParseError(f"SmallIntegerToken found but opcode was '{opcode}'.") + elif token.value == opcode: + pass + else: + raise ParseError(f"Token is '{token.value}' and opcode is '{opcode}'.") + self.token_index += 1 + self.opcode_index += 1 + + if self.token_index < len(self.tokens): + raise ParseError("Parse completed without all tokens being consumed.") + + if self.opcode_index < len(self.opcodes): + raise ParseError("Parse completed without all opcodes being consumed.") + + return self + + def consume_many_non_greedy(self): + """ Allows PUSH_MANY to consume data without being greedy + in cases when one or more PUSH_SINGLEs follow a PUSH_MANY. This will + prioritize giving all PUSH_SINGLEs some data and only after that + subsume the rest into PUSH_MANY. + """ + + token_values = [] + while self.token_index < len(self.tokens): + token = self.tokens[self.token_index] + if not isinstance(token, DataToken): + self.token_index -= 1 + break + token_values.append(token.value) + self.token_index += 1 + + push_opcodes = [] + push_many_count = 0 + while self.opcode_index < len(self.opcodes): + opcode = self.opcodes[self.opcode_index] + if not is_push_data_opcode(opcode): + self.opcode_index -= 1 + break + if isinstance(opcode, PUSH_MANY): + push_many_count += 1 + push_opcodes.append(opcode) + self.opcode_index += 1 + + if push_many_count > 1: + raise ParseError( + "Cannot have more than one consecutive PUSH_MANY, as there is no way to tell which" + " token value should go into which PUSH_MANY." + ) + + if len(push_opcodes) > len(token_values): + raise ParseError( + "Not enough token values to match all of the PUSH_MANY and PUSH_SINGLE opcodes." + ) + + many_opcode = push_opcodes.pop(0) + + # consume data into PUSH_SINGLE opcodes, working backwards + for opcode in reversed(push_opcodes): + self.push_single(opcode, token_values.pop()) + + # finally PUSH_MANY gets everything that's left + self.values[many_opcode.name] = token_values + + def push_single(self, opcode, value): + if isinstance(opcode, PUSH_SINGLE): + self.values[opcode.name] = value + elif isinstance(opcode, PUSH_INTEGER): + self.values[opcode.name] = int.from_bytes(value, 'little') + elif isinstance(opcode, PUSH_SUBSCRIPT): + self.values[opcode.name] = Script.from_source_with_template(value, opcode.template) + else: + raise ParseError(f"Not a push single or subscript: {opcode}") + + +class Template: + + __slots__ = 'name', 'opcodes' + + def __init__(self, name, opcodes): + self.name = name + self.opcodes = opcodes + + def parse(self, tokens): + return Parser(self.opcodes, tokens).parse().values if self.opcodes else {} + + def generate(self, values): + source = BCDataStream() + for opcode in self.opcodes: + if isinstance(opcode, PUSH_SINGLE): + data = values[opcode.name] + source.write_many(push_data(data)) + elif isinstance(opcode, PUSH_INTEGER): + data = values[opcode.name] + source.write_many(push_data( + data.to_bytes((data.bit_length() + 7) // 8, byteorder='little') + )) + elif isinstance(opcode, PUSH_SUBSCRIPT): + data = values[opcode.name] + source.write_many(push_data(data.source)) + elif isinstance(opcode, PUSH_MANY): + for data in values[opcode.name]: + source.write_many(push_data(data)) + elif isinstance(opcode, SMALL_INTEGER): + data = values[opcode.name] + source.write_many(push_small_integer(data)) + else: + source.write_uint8(opcode) + return source.get_bytes() + + +class Script: + + __slots__ = 'source', '_template', '_values', '_template_hint' + + templates: List[Template] = [] + + NO_SCRIPT = Template('no_script', None) # special case + + def __init__(self, source=None, template=None, values=None, template_hint=None): + self.source = source + self._template = template + self._values = values + self._template_hint = template_hint + if source is None and template and values: + self.generate() + + @property + def template(self): + if self._template is None: + self.parse(self._template_hint) + return self._template + + @property + def values(self): + if self._values is None: + self.parse(self._template_hint) + return self._values + + @property + def tokens(self): + return tokenize(BCDataStream(self.source)) + + @classmethod + def from_source_with_template(cls, source, template): + return cls(source, template_hint=template) + + def parse(self, template_hint=None): + tokens = self.tokens + if not tokens and not template_hint: + template_hint = self.NO_SCRIPT + for template in chain((template_hint,), self.templates): + if not template: + continue + try: + self._values = template.parse(tokens) + self._template = template + return + except ParseError: + continue + raise ValueError(f'No matching templates for source: {hexlify(self.source)}') + + def generate(self): + self.source = self.template.generate(self._values) + + +class InputScript(Script): + + __slots__ = () + + REDEEM_PUBKEY = Template('pubkey', ( + PUSH_SINGLE('signature'), + )) + REDEEM_PUBKEY_HASH = Template('pubkey_hash', ( + PUSH_SINGLE('signature'), PUSH_SINGLE('pubkey') + )) + REDEEM_SCRIPT = Template('script', ( + SMALL_INTEGER('signatures_count'), PUSH_MANY('pubkeys'), SMALL_INTEGER('pubkeys_count'), + OP_CHECKMULTISIG + )) + REDEEM_SCRIPT_HASH = Template('script_hash', ( + OP_0, PUSH_MANY('signatures'), PUSH_SUBSCRIPT('script', REDEEM_SCRIPT) + )) + + templates = [ + REDEEM_PUBKEY, + REDEEM_PUBKEY_HASH, + REDEEM_SCRIPT_HASH, + REDEEM_SCRIPT + ] + + @classmethod + def redeem_pubkey_hash(cls, signature, pubkey): + return cls(template=cls.REDEEM_PUBKEY_HASH, values={ + 'signature': signature, + 'pubkey': pubkey + }) + + @classmethod + def redeem_script_hash(cls, signatures, pubkeys): + return cls(template=cls.REDEEM_SCRIPT_HASH, values={ + 'signatures': signatures, + 'script': cls.redeem_script(signatures, pubkeys) + }) + + @classmethod + def redeem_script(cls, signatures, pubkeys): + return cls(template=cls.REDEEM_SCRIPT, values={ + 'signatures_count': len(signatures), + 'pubkeys': pubkeys, + 'pubkeys_count': len(pubkeys) + }) + + +class OutputScript(Script): + + __slots__ = () + + # output / payment script templates (aka scriptPubKey) + PAY_PUBKEY_FULL = Template('pay_pubkey_full', ( + PUSH_SINGLE('pubkey'), OP_CHECKSIG + )) + PAY_PUBKEY_HASH = Template('pay_pubkey_hash', ( + OP_DUP, OP_HASH160, PUSH_SINGLE('pubkey_hash'), OP_EQUALVERIFY, OP_CHECKSIG + )) + PAY_SCRIPT_HASH = Template('pay_script_hash', ( + OP_HASH160, PUSH_SINGLE('script_hash'), OP_EQUAL + )) + PAY_SEGWIT = Template('pay_script_hash+segwit', ( + OP_0, PUSH_SINGLE('script_hash') + )) + RETURN_DATA = Template('return_data', ( + OP_RETURN, PUSH_SINGLE('data') + )) CLAIM_NAME_OPCODES = ( OP_CLAIM_NAME, PUSH_SINGLE('claim_name'), PUSH_SINGLE('claim'), OP_2DROP, OP_DROP ) CLAIM_NAME_PUBKEY = Template('claim_name+pay_pubkey_hash', ( - CLAIM_NAME_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes + CLAIM_NAME_OPCODES + PAY_PUBKEY_HASH.opcodes )) CLAIM_NAME_SCRIPT = Template('claim_name+pay_script_hash', ( - CLAIM_NAME_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes + CLAIM_NAME_OPCODES + PAY_SCRIPT_HASH.opcodes )) SUPPORT_CLAIM_OPCODES = ( @@ -36,10 +432,10 @@ class OutputScript(BaseOutputScript): OP_2DROP, OP_DROP ) SUPPORT_CLAIM_PUBKEY = Template('support_claim+pay_pubkey_hash', ( - SUPPORT_CLAIM_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes + SUPPORT_CLAIM_OPCODES + PAY_PUBKEY_HASH.opcodes )) SUPPORT_CLAIM_SCRIPT = Template('support_claim+pay_script_hash', ( - SUPPORT_CLAIM_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes + SUPPORT_CLAIM_OPCODES + PAY_SCRIPT_HASH.opcodes )) UPDATE_CLAIM_OPCODES = ( @@ -47,10 +443,10 @@ class OutputScript(BaseOutputScript): OP_2DROP, OP_2DROP ) UPDATE_CLAIM_PUBKEY = Template('update_claim+pay_pubkey_hash', ( - UPDATE_CLAIM_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes + UPDATE_CLAIM_OPCODES + PAY_PUBKEY_HASH.opcodes )) UPDATE_CLAIM_SCRIPT = Template('update_claim+pay_script_hash', ( - UPDATE_CLAIM_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes + UPDATE_CLAIM_OPCODES + PAY_SCRIPT_HASH.opcodes )) SELL_SCRIPT = Template('sell_script', ( @@ -58,17 +454,22 @@ class OutputScript(BaseOutputScript): )) SELL_CLAIM = Template('sell_claim+pay_script_hash', ( OP_SELL_CLAIM, PUSH_SINGLE('claim_id'), PUSH_SUBSCRIPT('sell_script', SELL_SCRIPT), - PUSH_SUBSCRIPT('receive_script', BaseInputScript.REDEEM_SCRIPT), OP_2DROP, OP_2DROP - ) + BaseOutputScript.PAY_SCRIPT_HASH.opcodes) + PUSH_SUBSCRIPT('receive_script', InputScript.REDEEM_SCRIPT), OP_2DROP, OP_2DROP + ) + PAY_SCRIPT_HASH.opcodes) BUY_CLAIM = Template('buy_claim+pay_script_hash', ( OP_BUY_CLAIM, PUSH_SINGLE('sell_id'), PUSH_SINGLE('claim_id'), PUSH_SINGLE('claim_version'), PUSH_SINGLE('owner_pubkey_hash'), PUSH_SINGLE('negotiation_signature'), OP_2DROP, OP_2DROP, OP_2DROP, - ) + BaseOutputScript.PAY_SCRIPT_HASH.opcodes) + ) + PAY_SCRIPT_HASH.opcodes) - templates = BaseOutputScript.templates + [ + templates = [ + PAY_PUBKEY_FULL, + PAY_PUBKEY_HASH, + PAY_SCRIPT_HASH, + PAY_SEGWIT, + RETURN_DATA, CLAIM_NAME_PUBKEY, CLAIM_NAME_SCRIPT, SUPPORT_CLAIM_PUBKEY, @@ -79,6 +480,28 @@ class OutputScript(BaseOutputScript): BUY_CLAIM, ] + @classmethod + def pay_pubkey_hash(cls, pubkey_hash): + return cls(template=cls.PAY_PUBKEY_HASH, values={ + 'pubkey_hash': pubkey_hash + }) + + @classmethod + def pay_script_hash(cls, script_hash): + return cls(template=cls.PAY_SCRIPT_HASH, values={ + 'script_hash': script_hash + }) + + @classmethod + def return_data(cls, data): + return cls(template=cls.RETURN_DATA, values={ + 'data': data + }) + + @property + def is_pay_pubkey(self): + return self.template.name.endswith('pay_pubkey_full') + @classmethod def pay_claim_name_pubkey_hash(cls, claim_name, claim, pubkey_hash): return cls(template=cls.CLAIM_NAME_PUBKEY, values={ @@ -128,6 +551,18 @@ class OutputScript(BaseOutputScript): 'negotiation_signature': negotiation_signature, }) + @property + def is_pay_pubkey_hash(self): + return self.template.name.endswith('pay_pubkey_hash') + + @property + def is_pay_script_hash(self): + return self.template.name.endswith('pay_script_hash') + + @property + def is_return_data(self): + return self.template.name.endswith('return_data') + @property def is_claim_name(self): return self.template.name.startswith('claim_name+') diff --git a/lbry/wallet/server/coin.py b/lbry/wallet/server/coin.py index f1663769b..146d5b646 100644 --- a/lbry/wallet/server/coin.py +++ b/lbry/wallet/server/coin.py @@ -6,7 +6,7 @@ from decimal import Decimal from collections import namedtuple import lbry.wallet.server.tx as lib_tx -from lbry.wallet.script import OutputScript +from lbry.wallet.script import OutputScript, OP_CLAIM_NAME, OP_UPDATE_CLAIM, OP_SUPPORT_CLAIM from lbry.wallet.server.tx import DeserializerSegWit from lbry.wallet.server.util import cachedproperty, subclasses from lbry.wallet.server.hash import Base58, hash160, double_sha256, hash_to_hex_str, HASHX_LEN @@ -327,9 +327,9 @@ class LBC(Coin): if script and script[0] == OpCodes.OP_RETURN or not script: return None if script[0] in [ - OutputScript.OP_CLAIM_NAME, - OutputScript.OP_UPDATE_CLAIM, - OutputScript.OP_SUPPORT_CLAIM, + OP_CLAIM_NAME, + OP_UPDATE_CLAIM, + OP_SUPPORT_CLAIM, ]: return cls.address_to_hashX(cls.claim_address_handler(script)) else: diff --git a/lbry/wallet/server/db/full_text_search.py b/lbry/wallet/server/db/full_text_search.py index 3a88775c7..3f82fbf6d 100644 --- a/lbry/wallet/server/db/full_text_search.py +++ b/lbry/wallet/server/db/full_text_search.py @@ -1,4 +1,4 @@ -from lbry.wallet.client.basedatabase import constraints_to_sql +from lbry.wallet.database import constraints_to_sql CREATE_FULL_TEXT_SEARCH = """ create virtual table if not exists search using fts5( diff --git a/lbry/wallet/server/db/reader.py b/lbry/wallet/server/db/reader.py index 0d8858d80..dd5495e67 100644 --- a/lbry/wallet/server/db/reader.py +++ b/lbry/wallet/server/db/reader.py @@ -10,12 +10,12 @@ from contextvars import ContextVar from functools import wraps from dataclasses import dataclass -from lbry.wallet.client.basedatabase import query, interpolate +from lbry.wallet.database import query, interpolate from lbry.schema.url import URL, normalize_name from lbry.schema.tags import clean_tags from lbry.schema.result import Outputs -from lbry.wallet.ledger import BaseLedger, MainNetLedger, RegTestLedger +from lbry.wallet import Ledger, RegTestLedger from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS from .full_text_search import FTS_ORDER_BY @@ -67,7 +67,7 @@ class ReaderState: stack: List[List] metrics: Dict is_tracking_metrics: bool - ledger: Type[BaseLedger] + ledger: Type[Ledger] query_timeout: float log: logging.Logger @@ -100,7 +100,7 @@ def initializer(log, _path, _ledger_name, query_timeout, _measure=False): ctx.set( ReaderState( db=db, stack=[], metrics={}, is_tracking_metrics=_measure, - ledger=MainNetLedger if _ledger_name == 'mainnet' else RegTestLedger, + ledger=Ledger if _ledger_name == 'mainnet' else RegTestLedger, query_timeout=query_timeout, log=log ) ) diff --git a/lbry/wallet/server/db/writer.py b/lbry/wallet/server/db/writer.py index 6dca73c22..094694f34 100644 --- a/lbry/wallet/server/db/writer.py +++ b/lbry/wallet/server/db/writer.py @@ -7,11 +7,11 @@ from collections import namedtuple from lbry.wallet.server.leveldb import DB from lbry.wallet.server.util import class_logger -from lbry.wallet.client.basedatabase import query, constraints_to_sql +from lbry.wallet.database import query, constraints_to_sql from lbry.schema.tags import clean_tags from lbry.schema.mime_types import guess_stream_type -from lbry.wallet.ledger import MainNetLedger, RegTestLedger +from lbry.wallet import Ledger, RegTestLedger from lbry.wallet.transaction import Transaction, Output from lbry.wallet.server.db.canonical import register_canonical_functions from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished @@ -171,7 +171,7 @@ class SQLDB: self._db_path = path self.db = None self.logger = class_logger(__name__, self.__class__.__name__) - self.ledger = MainNetLedger if self.main.coin.NET == 'mainnet' else RegTestLedger + self.ledger = Ledger if self.main.coin.NET == 'mainnet' else RegTestLedger self._fts_synced = False def open(self): diff --git a/lbry/wallet/transaction.py b/lbry/wallet/transaction.py index 74ec9bdd2..5e800fdd3 100644 --- a/lbry/wallet/transaction.py +++ b/lbry/wallet/transaction.py @@ -1,9 +1,12 @@ +import ecdsa import struct import hashlib -from binascii import hexlify, unhexlify -from typing import List, Optional +import logging +import typing + +from binascii import hexlify, unhexlify +from typing import List, Iterable, Optional, Tuple -import ecdsa from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_der_public_key from cryptography.hazmat.primitives import hashes @@ -11,34 +14,216 @@ from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric.utils import Prehashed from cryptography.exceptions import InvalidSignature -from lbry.crypto.base58 import Base58 +from lbry.error import InsufficientFundsError from lbry.crypto.hash import hash160, sha256 -from lbry.wallet.client.basetransaction import BaseTransaction, BaseInput, BaseOutput, ReadOnlyList +from lbry.crypto.base58 import Base58 +from lbry.schema.url import normalize_name from lbry.schema.claim import Claim from lbry.schema.purchase import Purchase -from lbry.schema.url import normalize_name -from lbry.wallet.account import Account -from lbry.wallet.script import InputScript, OutputScript + +from .script import InputScript, OutputScript +from .constants import COIN, NULL_HASH32 +from .bcd_data_stream import BCDataStream +from .hash import TXRef, TXRefImmutable +from .util import ReadOnlyList + +if typing.TYPE_CHECKING: + from lbry.wallet.account import Account + from lbry.wallet.ledger import Ledger + from lbry.wallet.wallet import Wallet + +log = logging.getLogger() -class Input(BaseInput): - script: InputScript - script_class = InputScript +class TXRefMutable(TXRef): + + __slots__ = ('tx',) + + def __init__(self, tx: 'Transaction') -> None: + super().__init__() + self.tx = tx + + @property + def id(self): + if self._id is None: + self._id = hexlify(self.hash[::-1]).decode() + return self._id + + @property + def hash(self): + if self._hash is None: + self._hash = sha256(sha256(self.tx.raw_sans_segwit)) + return self._hash + + @property + def height(self): + return self.tx.height + + def reset(self): + self._id = None + self._hash = None -class Output(BaseOutput): - script: OutputScript - script_class = OutputScript +class TXORef: + + __slots__ = 'tx_ref', 'position' + + def __init__(self, tx_ref: TXRef, position: int) -> None: + self.tx_ref = tx_ref + self.position = position + + @property + def id(self): + return f'{self.tx_ref.id}:{self.position}' + + @property + def hash(self): + return self.tx_ref.hash + BCDataStream.uint32.pack(self.position) + + @property + def is_null(self): + return self.tx_ref.is_null + + @property + def txo(self) -> Optional['Output']: + return None + + +class TXORefResolvable(TXORef): + + __slots__ = ('_txo',) + + def __init__(self, txo: 'Output') -> None: + assert txo.tx_ref is not None + assert txo.position is not None + super().__init__(txo.tx_ref, txo.position) + self._txo = txo + + @property + def txo(self): + return self._txo + + +class InputOutput: + + __slots__ = 'tx_ref', 'position' + + def __init__(self, tx_ref: TXRef = None, position: int = None) -> None: + self.tx_ref = tx_ref + self.position = position + + @property + def size(self) -> int: + """ Size of this input / output in bytes. """ + stream = BCDataStream() + self.serialize_to(stream) + return len(stream.get_bytes()) + + def get_fee(self, ledger): + return self.size * ledger.fee_per_byte + + def serialize_to(self, stream, alternate_script=None): + raise NotImplementedError + + +class Input(InputOutput): + + NULL_SIGNATURE = b'\x00'*72 + NULL_PUBLIC_KEY = b'\x00'*33 + + __slots__ = 'txo_ref', 'sequence', 'coinbase', 'script' + + def __init__(self, txo_ref: TXORef, script: InputScript, sequence: int = 0xFFFFFFFF, + tx_ref: TXRef = None, position: int = None) -> None: + super().__init__(tx_ref, position) + self.txo_ref = txo_ref + self.sequence = sequence + self.coinbase = script if txo_ref.is_null else None + self.script = script if not txo_ref.is_null else None + + @property + def is_coinbase(self): + return self.coinbase is not None + + @classmethod + def spend(cls, txo: 'Output') -> 'Input': + """ Create an input to spend the output.""" + assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.' + script = InputScript.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY) + return cls(txo.ref, script) + + @property + def amount(self) -> int: + """ Amount this input adds to the transaction. """ + if self.txo_ref.txo is None: + raise ValueError('Cannot resolve output to get amount.') + return self.txo_ref.txo.amount + + @property + def is_my_account(self) -> Optional[bool]: + """ True if the output this input spends is yours. """ + if self.txo_ref.txo is None: + return False + return self.txo_ref.txo.is_my_account + + @classmethod + def deserialize_from(cls, stream): + tx_ref = TXRefImmutable.from_hash(stream.read(32), -1) + position = stream.read_uint32() + script = stream.read_string() + sequence = stream.read_uint32() + return cls( + TXORef(tx_ref, position), + InputScript(script) if not tx_ref.is_null else script, + sequence + ) + + def serialize_to(self, stream, alternate_script=None): + stream.write(self.txo_ref.tx_ref.hash) + stream.write_uint32(self.txo_ref.position) + if alternate_script is not None: + stream.write_string(alternate_script) + else: + if self.is_coinbase: + stream.write_string(self.coinbase) + else: + stream.write_string(self.script.source) + stream.write_uint32(self.sequence) + + +class OutputEffectiveAmountEstimator: + + __slots__ = 'txo', 'txi', 'fee', 'effective_amount' + + def __init__(self, ledger: 'Ledger', txo: 'Output') -> None: + self.txo = txo + self.txi = Input.spend(txo) + self.fee: int = self.txi.get_fee(ledger) + self.effective_amount: int = txo.amount - self.fee + + def __lt__(self, other): + return self.effective_amount < other.effective_amount + + +class Output(InputOutput): __slots__ = ( + 'amount', 'script', 'is_change', 'is_my_account', 'channel', 'private_key', 'meta', 'purchase', 'purchased_claim', 'purchase_receipt', 'reposted_claim', 'claims', ) - def __init__(self, *args, channel: Optional['Output'] = None, - private_key: Optional[str] = None, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, amount: int, script: OutputScript, + tx_ref: TXRef = None, position: int = None, + is_change: Optional[bool] = None, is_my_account: Optional[bool] = None, + channel: Optional['Output'] = None, private_key: Optional[str] = None + ) -> None: + super().__init__(tx_ref, position) + self.amount = amount + self.script = script + self.is_change = is_change + self.is_my_account = is_my_account self.channel = channel self.private_key = private_key self.purchase: 'Output' = None # txo containing purchase metadata @@ -49,10 +234,52 @@ class Output(BaseOutput): self.meta = {} def update_annotations(self, annotated): - super().update_annotations(annotated) + if annotated is None: + self.is_change = False + self.is_my_account = False + else: + self.is_change = annotated.is_change + self.is_my_account = annotated.is_my_account self.channel = annotated.channel if annotated else None self.private_key = annotated.private_key if annotated else None + @property + def ref(self): + return TXORefResolvable(self) + + @property + def id(self): + return self.ref.id + + @property + def pubkey_hash(self): + return self.script.values['pubkey_hash'] + + @property + def has_address(self): + return 'pubkey_hash' in self.script.values + + def get_address(self, ledger): + return ledger.hash160_to_address(self.pubkey_hash) + + def get_estimator(self, ledger): + return OutputEffectiveAmountEstimator(ledger, self) + + @classmethod + def pay_pubkey_hash(cls, amount, pubkey_hash): + return cls(amount, OutputScript.pay_pubkey_hash(pubkey_hash)) + + @classmethod + def deserialize_from(cls, stream): + return cls( + amount=stream.read_uint64(), + script=OutputScript(stream.read_string()) + ) + + def serialize_to(self, stream, alternate_script=None): + stream.write_uint64(self.amount) + stream.write_string(self.script.source) + def get_fee(self, ledger): name_fee = 0 if self.script.is_claim_name: @@ -180,34 +407,35 @@ class Output(BaseOutput): @classmethod def pay_claim_name_pubkey_hash( cls, amount: int, claim_name: str, claim: Claim, pubkey_hash: bytes) -> 'Output': - script = cls.script_class.pay_claim_name_pubkey_hash( + script = OutputScript.pay_claim_name_pubkey_hash( claim_name.encode(), claim, pubkey_hash) - txo = cls(amount, script) - return txo + return cls(amount, script) @classmethod def pay_update_claim_pubkey_hash( cls, amount: int, claim_name: str, claim_id: str, claim: Claim, pubkey_hash: bytes) -> 'Output': - script = cls.script_class.pay_update_claim_pubkey_hash( - claim_name.encode(), unhexlify(claim_id)[::-1], claim, pubkey_hash) - txo = cls(amount, script) - return txo + script = OutputScript.pay_update_claim_pubkey_hash( + claim_name.encode(), unhexlify(claim_id)[::-1], claim, pubkey_hash + ) + return cls(amount, script) @classmethod def pay_support_pubkey_hash(cls, amount: int, claim_name: str, claim_id: str, pubkey_hash: bytes) -> 'Output': - script = cls.script_class.pay_support_pubkey_hash(claim_name.encode(), unhexlify(claim_id)[::-1], pubkey_hash) + script = OutputScript.pay_support_pubkey_hash( + claim_name.encode(), unhexlify(claim_id)[::-1], pubkey_hash + ) return cls(amount, script) @classmethod def add_purchase_data(cls, purchase: Purchase) -> 'Output': - script = cls.script_class.return_data(purchase) + script = OutputScript.return_data(purchase) return cls(0, script) @property def is_purchase_data(self) -> bool: return self.script.is_return_data and ( - isinstance(self.script.values['data'], Purchase) or - Purchase.has_start_byte(self.script.values['data']) + isinstance(self.script.values['data'], Purchase) or + Purchase.has_start_byte(self.script.values['data']) ) @property @@ -246,16 +474,331 @@ class Output(BaseOutput): return self.claim.stream.fee -class Transaction(BaseTransaction): +class Transaction: - input_class = Input - output_class = Output + def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False, + height: int = -2, position: int = -1) -> None: + self._raw = raw + self._raw_sans_segwit = None + self.is_segwit_flag = 0 + self.witnesses: List[bytes] = [] + self.ref = TXRefMutable(self) + self.version = version + self.locktime = locktime + self._inputs: List[Input] = [] + self._outputs: List[Output] = [] + self.is_verified = is_verified + # Height Progression + # -2: not broadcast + # -1: in mempool but has unconfirmed inputs + # 0: in mempool and all inputs confirmed + # +num: confirmed in a specific block (height) + self.height = height + self.position = position + if raw is not None: + self._deserialize() - outputs: ReadOnlyList[Output] - inputs: ReadOnlyList[Input] + @property + def is_broadcast(self): + return self.height > -2 + + @property + def is_mempool(self): + return self.height in (-1, 0) + + @property + def is_confirmed(self): + return self.height > 0 + + @property + def id(self): + return self.ref.id + + @property + def hash(self): + return self.ref.hash + + @property + def raw(self): + if self._raw is None: + self._raw = self._serialize() + return self._raw + + @property + def raw_sans_segwit(self): + if self.is_segwit_flag: + if self._raw_sans_segwit is None: + self._raw_sans_segwit = self._serialize(sans_segwit=True) + return self._raw_sans_segwit + return self.raw + + def _reset(self): + self._raw = None + self._raw_sans_segwit = None + self.ref.reset() + + @property + def inputs(self) -> ReadOnlyList[Input]: + return ReadOnlyList(self._inputs) + + @property + def outputs(self) -> ReadOnlyList[Output]: + return ReadOnlyList(self._outputs) + + def _add(self, existing_ios: List, new_ios: Iterable[InputOutput], reset=False) -> 'Transaction': + for txio in new_ios: + txio.tx_ref = self.ref + txio.position = len(existing_ios) + existing_ios.append(txio) + if reset: + self._reset() + return self + + def add_inputs(self, inputs: Iterable[Input]) -> 'Transaction': + return self._add(self._inputs, inputs, True) + + def add_outputs(self, outputs: Iterable[Output]) -> 'Transaction': + return self._add(self._outputs, outputs, True) + + @property + def size(self) -> int: + """ Size in bytes of the entire transaction. """ + return len(self.raw) + + @property + def base_size(self) -> int: + """ Size of transaction without inputs or outputs in bytes. """ + return ( + self.size + - sum(txi.size for txi in self._inputs) + - sum(txo.size for txo in self._outputs) + ) + + @property + def input_sum(self): + return sum(i.amount for i in self.inputs if i.txo_ref.txo is not None) + + @property + def output_sum(self): + return sum(o.amount for o in self.outputs) + + @property + def net_account_balance(self) -> int: + balance = 0 + for txi in self.inputs: + if txi.txo_ref.txo is None: + continue + if txi.is_my_account is None: + raise ValueError( + "Cannot access net_account_balance if inputs/outputs do not " + "have is_my_account set properly." + ) + if txi.is_my_account: + balance -= txi.amount + for txo in self.outputs: + if txo.is_my_account is None: + raise ValueError( + "Cannot access net_account_balance if inputs/outputs do not " + "have is_my_account set properly." + ) + if txo.is_my_account: + balance += txo.amount + return balance + + @property + def fee(self) -> int: + return self.input_sum - self.output_sum + + def get_base_fee(self, ledger) -> int: + """ Fee for base tx excluding inputs and outputs. """ + return self.base_size * ledger.fee_per_byte + + def get_effective_input_sum(self, ledger) -> int: + """ Sum of input values *minus* the cost involved to spend them. """ + return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs) + + def get_total_output_sum(self, ledger) -> int: + """ Sum of output values *plus* the cost involved to spend them. """ + return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs) + + def _serialize(self, with_inputs: bool = True, sans_segwit: bool = False) -> bytes: + stream = BCDataStream() + stream.write_uint32(self.version) + if with_inputs: + stream.write_compact_size(len(self._inputs)) + for txin in self._inputs: + txin.serialize_to(stream) + stream.write_compact_size(len(self._outputs)) + for txout in self._outputs: + txout.serialize_to(stream) + stream.write_uint32(self.locktime) + return stream.get_bytes() + + def _serialize_for_signature(self, signing_input: int) -> bytes: + stream = BCDataStream() + stream.write_uint32(self.version) + stream.write_compact_size(len(self._inputs)) + for i, txin in enumerate(self._inputs): + if signing_input == i: + assert txin.txo_ref.txo is not None + txin.serialize_to(stream, txin.txo_ref.txo.script.source) + else: + txin.serialize_to(stream, b'') + stream.write_compact_size(len(self._outputs)) + for txout in self._outputs: + txout.serialize_to(stream) + stream.write_uint32(self.locktime) + stream.write_uint32(self.signature_hash_type(1)) # signature hash type: SIGHASH_ALL + return stream.get_bytes() + + def _deserialize(self): + if self._raw is not None: + stream = BCDataStream(self._raw) + self.version = stream.read_uint32() + input_count = stream.read_compact_size() + if input_count == 0: + self.is_segwit_flag = stream.read_uint8() + input_count = stream.read_compact_size() + self._add(self._inputs, [ + Input.deserialize_from(stream) for _ in range(input_count) + ]) + output_count = stream.read_compact_size() + self._add(self._outputs, [ + Output.deserialize_from(stream) for _ in range(output_count) + ]) + if self.is_segwit_flag: + # drain witness portion of transaction + # too many witnesses for no crime + self.witnesses = [] + for _ in range(input_count): + for _ in range(stream.read_compact_size()): + self.witnesses.append(stream.read(stream.read_compact_size())) + self.locktime = stream.read_uint32() @classmethod - def pay(cls, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account): + def ensure_all_have_same_ledger_and_wallet( + cls, funding_accounts: Iterable['Account'], + change_account: 'Account' = None) -> Tuple['Ledger', 'Wallet']: + ledger = wallet = None + for account in funding_accounts: + if ledger is None: + ledger = account.ledger + wallet = account.wallet + if ledger != account.ledger: + raise ValueError( + 'All funding accounts used to create a transaction must be on the same ledger.' + ) + if wallet != account.wallet: + raise ValueError( + 'All funding accounts used to create a transaction must be from the same wallet.' + ) + if change_account is not None: + if change_account.ledger != ledger: + raise ValueError('Change account must use same ledger as funding accounts.') + if change_account.wallet != wallet: + raise ValueError('Change account must use same wallet as funding accounts.') + if ledger is None: + raise ValueError('No ledger found.') + if wallet is None: + raise ValueError('No wallet found.') + return ledger, wallet + + @classmethod + async def create(cls, inputs: Iterable[Input], outputs: Iterable[Output], + funding_accounts: Iterable['Account'], change_account: 'Account', + sign: bool = True): + """ Find optimal set of inputs when only outputs are provided; add change + outputs if only inputs are provided or if inputs are greater than outputs. """ + + tx = cls() \ + .add_inputs(inputs) \ + .add_outputs(outputs) + + ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) + + # value of the outputs plus associated fees + cost = ( + tx.get_base_fee(ledger) + + tx.get_total_output_sum(ledger) + ) + # value of the inputs less the cost to spend those inputs + payment = tx.get_effective_input_sum(ledger) + + try: + + for _ in range(5): + + if payment < cost: + deficit = cost - payment + spendables = await ledger.get_spendable_utxos(deficit, funding_accounts) + if not spendables: + raise InsufficientFundsError() + payment += sum(s.effective_amount for s in spendables) + tx.add_inputs(s.txi for s in spendables) + + cost_of_change = ( + tx.get_base_fee(ledger) + + Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger) + ) + if payment > cost: + change = payment - cost + if change > cost_of_change: + change_address = await change_account.change.get_or_create_usable_address() + change_hash160 = change_account.ledger.address_to_hash160(change_address) + change_amount = change - cost_of_change + change_output = Output.pay_pubkey_hash(change_amount, change_hash160) + change_output.is_change = True + tx.add_outputs([Output.pay_pubkey_hash(change_amount, change_hash160)]) + + if tx._outputs: + break + # this condition and the outer range(5) loop cover an edge case + # whereby a single input is just enough to cover the fee and + # has some change left over, but the change left over is less + # than the cost_of_change: thus the input is completely + # consumed and no output is added, which is an invalid tx. + # to be able to spend this input we must increase the cost + # of the TX and run through the balance algorithm a second time + # adding an extra input and change output, making tx valid. + # we do this 5 times in case the other UTXOs added are also + # less than the fee, after 5 attempts we give up and go home + cost += cost_of_change + 1 + + if sign: + await tx.sign(funding_accounts) + + except Exception as e: + log.exception('Failed to create transaction:') + await ledger.release_tx(tx) + raise e + + return tx + + @staticmethod + def signature_hash_type(hash_type): + return hash_type + + async def sign(self, funding_accounts: Iterable['Account']): + ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts) + for i, txi in enumerate(self._inputs): + assert txi.script is not None + assert txi.txo_ref.txo is not None + txo_script = txi.txo_ref.txo.script + if txo_script.is_pay_pubkey_hash: + address = ledger.hash160_to_address(txo_script.values['pubkey_hash']) + private_key = await ledger.get_private_key_for_address(wallet, address) + assert private_key is not None, 'Cannot find private key for signing output.' + tx = self._serialize_for_signature(i) + txi.script.values['signature'] = \ + private_key.sign(tx) + bytes((self.signature_hash_type(1),)) + txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes + txi.script.generate() + else: + raise NotImplementedError("Don't know how to spend this output.") + self._reset() + + @classmethod + def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'): ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address)) return cls.create([], [output], funding_accounts, change_account) @@ -263,7 +806,7 @@ class Transaction(BaseTransaction): @classmethod def claim_create( cls, name: str, claim: Claim, amount: int, holding_address: str, - funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): + funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) claim_output = Output.pay_claim_name_pubkey_hash( amount, name, claim, ledger.address_to_hash160(holding_address) @@ -275,7 +818,7 @@ class Transaction(BaseTransaction): @classmethod def claim_update( cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str, - funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): + funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) updated_claim = Output.pay_update_claim_pubkey_hash( amount, previous_claim.claim_name, previous_claim.claim_id, @@ -291,7 +834,7 @@ class Transaction(BaseTransaction): @classmethod def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str, - funding_accounts: List[Account], change_account: Account): + funding_accounts: List['Account'], change_account: 'Account'): ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) support_output = Output.pay_support_pubkey_hash( amount, claim_name, claim_id, ledger.address_to_hash160(holding_address) @@ -300,7 +843,7 @@ class Transaction(BaseTransaction): @classmethod def purchase(cls, claim_id: str, amount: int, merchant_address: bytes, - funding_accounts: List[Account], change_account: Account): + funding_accounts: List['Account'], change_account: 'Account'): ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) payment = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(merchant_address)) data = Output.add_purchase_data(Purchase(claim_id)) diff --git a/lbry/wallet/util.py b/lbry/wallet/util.py index f30b4f532..a9504bff1 100644 --- a/lbry/wallet/util.py +++ b/lbry/wallet/util.py @@ -1,6 +1,6 @@ import re from typing import TypeVar, Sequence, Optional -from lbry.wallet.client.constants import COIN +from .constants import COIN def coins_to_satoshis(coins): diff --git a/lbry/wallet/wallet.py b/lbry/wallet/wallet.py index 18e9a72d0..72ba6f9d6 100644 --- a/lbry/wallet/wallet.py +++ b/lbry/wallet/wallet.py @@ -10,9 +10,11 @@ from collections import UserDict from hashlib import sha256 from operator import attrgetter from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt +from .account import Account if typing.TYPE_CHECKING: - from lbry.wallet.client import basemanager, baseaccount, baseledger + from lbry.wallet.manager import WalletManager + from lbry.wallet.ledger import Ledger log = logging.getLogger(__name__) @@ -65,7 +67,7 @@ class Wallet: preferences: TimestampedPreferences encryption_password: Optional[str] - def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None, + def __init__(self, name: str = 'Wallet', accounts: MutableSequence['Account'] = None, storage: 'WalletStorage' = None, preferences: dict = None) -> None: self.name = name self.accounts = accounts or [] @@ -79,30 +81,30 @@ class Wallet: return os.path.basename(self.storage.path) return self.name - def add_account(self, account: 'baseaccount.BaseAccount'): + def add_account(self, account: 'Account'): self.accounts.append(account) - def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount': - return ledger.account_class.generate(ledger, self) + def generate_account(self, ledger: 'Ledger') -> 'Account': + return Account.generate(ledger, self) @property - def default_account(self) -> Optional['baseaccount.BaseAccount']: + def default_account(self) -> Optional['Account']: for account in self.accounts: return account return None - def get_account_or_default(self, account_id: str) -> Optional['baseaccount.BaseAccount']: + def get_account_or_default(self, account_id: str) -> Optional['Account']: if account_id is None: return self.default_account return self.get_account_or_error(account_id) - def get_account_or_error(self, account_id: str) -> 'baseaccount.BaseAccount': + def get_account_or_error(self, account_id: str) -> 'Account': for account in self.accounts: if account.id == account_id: return account raise ValueError(f"Couldn't find account: {account_id}.") - def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['baseaccount.BaseAccount']: + def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['Account']: return [ self.get_account_or_error(account_id) for account_id in account_ids @@ -117,7 +119,7 @@ class Wallet: return accounts @classmethod - def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet': + def from_storage(cls, storage: 'WalletStorage', manager: 'WalletManager') -> 'Wallet': json_dict = storage.read() wallet = cls( name=json_dict.get('name', 'Wallet'), @@ -127,7 +129,7 @@ class Wallet: account_dicts: Sequence[dict] = json_dict.get('accounts', []) for account_dict in account_dicts: ledger = manager.get_or_create_ledger(account_dict['ledger']) - ledger.account_class.from_dict(ledger, wallet, account_dict) + Account.from_dict(ledger, wallet, account_dict) return wallet def to_dict(self, encrypt_password: str = None): @@ -173,15 +175,15 @@ class Wallet: decompressed = zlib.decompress(decrypted) return json.loads(decompressed) - def merge(self, manager: 'basemanager.BaseWalletManager', - password: str, data: str) -> List['baseaccount.BaseAccount']: + def merge(self, manager: 'WalletManager', + password: str, data: str) -> List['Account']: assert not self.is_locked, "Cannot sync apply on a locked wallet." added_accounts = [] decrypted_data = self.unpack(password, data) self.preferences.merge(decrypted_data.get('preferences', {})) for account_dict in decrypted_data['accounts']: ledger = manager.get_or_create_ledger(account_dict['ledger']) - _, _, pubkey = ledger.account_class.keys_from_dict(ledger, account_dict) + _, _, pubkey = Account.keys_from_dict(ledger, account_dict) account_id = pubkey.address local_match = None for local_account in self.accounts: @@ -191,7 +193,7 @@ class Wallet: if local_match is not None: local_match.merge(account_dict) else: - new_account = ledger.account_class.from_dict(ledger, self, account_dict) + new_account = Account.from_dict(ledger, self, account_dict) added_accounts.append(new_account) return added_accounts diff --git a/tests/integration/blockchain/test_wallet_commands.py b/tests/integration/blockchain/test_wallet_commands.py index a34197d84..126f3fd9a 100644 --- a/tests/integration/blockchain/test_wallet_commands.py +++ b/tests/integration/blockchain/test_wallet_commands.py @@ -1,7 +1,7 @@ import asyncio import json -from lbry.wallet.client.wallet import ENCRYPT_ON_DISK +from lbry.wallet import ENCRYPT_ON_DISK from lbry.error import InvalidPasswordError from lbry.testcase import CommandTestCase from lbry.wallet.dewies import dict_values_to_lbc diff --git a/tests/unit/lbrynet_daemon/test_Daemon.py b/tests/unit/lbrynet_daemon/test_Daemon.py index d247f6900..f22deeeec 100644 --- a/tests/unit/lbrynet_daemon/test_Daemon.py +++ b/tests/unit/lbrynet_daemon/test_Daemon.py @@ -2,7 +2,6 @@ import unittest from unittest import mock import json -import lbry.wallet from lbry.conf import Config from lbry.extras.daemon.storage import SQLiteStorage from lbry.extras.daemon.ComponentManager import ComponentManager @@ -11,8 +10,7 @@ from lbry.extras.daemon.Components import HASH_ANNOUNCER_COMPONENT from lbry.extras.daemon.Components import UPNP_COMPONENT, BLOB_COMPONENT from lbry.extras.daemon.Components import PEER_PROTOCOL_SERVER_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT from lbry.extras.daemon.Daemon import Daemon as LBRYDaemon -from lbry.wallet import LbryWalletManager -from lbry.wallet.client.wallet import Wallet +from lbry.wallet import WalletManager, Wallet from tests import test_utils # from tests.mocks import mock_conf_settings, FakeNetwork, FakeFileManager @@ -37,7 +35,7 @@ def get_test_daemon(conf: Config, with_fee=False): ) daemon = LBRYDaemon(conf, component_manager=component_manager) daemon.payment_rate_manager = OnlyFreePaymentsManager() - daemon.wallet_manager = mock.Mock(spec=LbryWalletManager) + daemon.wallet_manager = mock.Mock(spec=WalletManager) daemon.wallet_manager.wallet = mock.Mock(spec=Wallet) daemon.wallet_manager.use_encryption = False daemon.wallet_manager.network = FakeNetwork() diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index d34fa8159..ae4f83782 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -10,13 +10,10 @@ from lbry.testcase import get_fake_exchange_rate_manager from lbry.utils import generate_id from lbry.error import InsufficientFundsError from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError -from lbry.wallet.client.wallet import Wallet +from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output, Database from lbry.wallet.client.constants import CENT, NULL_HASH32 from lbry.wallet.client.basenetwork import ClientSession from lbry.conf import Config -from lbry.wallet.ledger import MainNetLedger -from lbry.wallet.transaction import Transaction, Input, Output -from lbry.wallet.manager import LbryWalletManager from lbry.extras.daemon.analytics import AnalyticsManager from lbry.stream.stream_manager import StreamManager from lbry.stream.descriptor import StreamDescriptor @@ -94,16 +91,16 @@ async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None): return {'timestamp': 1984} wallet = Wallet() - ledger = MainNetLedger({ - 'db': MainNetLedger.database_class(':memory:'), + ledger = Ledger({ + 'db': Database(':memory:'), 'headers': FakeHeaders(514082) }) await ledger.db.open() wallet.generate_account(ledger) - manager = LbryWalletManager() + manager = WalletManager() manager.config = Config() manager.wallets.append(wallet) - manager.ledgers[MainNetLedger] = ledger + manager.ledgers[Ledger] = ledger manager.ledger.network.client = ClientSession( network=manager.ledger.network, server=('fakespv.lbry.com', 50001) ) diff --git a/tests/unit/wallet/test_account.py b/tests/unit/wallet/test_account.py index b07e56dec..a54962ed6 100644 --- a/tests/unit/wallet/test_account.py +++ b/tests/unit/wallet/test_account.py @@ -1,17 +1,13 @@ from binascii import hexlify from lbry.testcase import AsyncioTestCase -from lbry.wallet.client.wallet import Wallet -from lbry.wallet.ledger import MainNetLedger, WalletDatabase -from lbry.wallet.header import Headers -from lbry.wallet.account import Account -from lbry.wallet.client.baseaccount import SingleKey, HierarchicalDeterministic +from lbry.wallet import Wallet, Ledger, Database, Headers, Account, SingleKey, HierarchicalDeterministic class TestAccount(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = MainNetLedger({ - 'db': WalletDatabase(':memory:'), + self.ledger = Ledger({ + 'db': Database(':memory:'), 'headers': Headers(':memory:') }) await self.ledger.db.open() @@ -236,8 +232,8 @@ class TestAccount(AsyncioTestCase): class TestSingleKeyAccount(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = MainNetLedger({ - 'db': WalletDatabase(':memory:'), + self.ledger = Ledger({ + 'db': Database(':memory:'), 'headers': Headers(':memory:') }) await self.ledger.db.open() @@ -327,7 +323,7 @@ class TestSingleKeyAccount(AsyncioTestCase): self.assertEqual(len(keys), 1) async def test_generate_account_from_seed(self): - account = self.ledger.account_class.from_dict( + account = Account.from_dict( self.ledger, Wallet(), { "seed": "carbon smart garage balance margin twelve chest sword toas" @@ -432,8 +428,8 @@ class AccountEncryptionTests(AsyncioTestCase): } async def asyncSetUp(self): - self.ledger = MainNetLedger({ - 'db': WalletDatabase(':memory:'), + self.ledger = Ledger({ + 'db': Database(':memory:'), 'headers': Headers(':memory:') }) @@ -489,7 +485,7 @@ class AccountEncryptionTests(AsyncioTestCase): account_data = self.unencrypted_account.copy() del account_data['seed'] del account_data['private_key'] - account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data) + account = Account.from_dict(self.ledger, Wallet(), account_data) encrypted = account.to_dict('password') self.assertFalse(encrypted['seed']) self.assertFalse(encrypted['private_key']) diff --git a/tests/unit/wallet/test_bcd_data_stream.py b/tests/unit/wallet/test_bcd_data_stream.py index 5e38c0b3d..ab2095e45 100644 --- a/tests/unit/wallet/test_bcd_data_stream.py +++ b/tests/unit/wallet/test_bcd_data_stream.py @@ -1,6 +1,6 @@ import unittest -from lbry.wallet.client.bcd_data_stream import BCDataStream +from lbry.wallet.bcd_data_stream import BCDataStream class TestBCDataStream(unittest.TestCase): diff --git a/tests/unit/wallet/test_bip32.py b/tests/unit/wallet/test_bip32.py index 87c804b01..92d99325f 100644 --- a/tests/unit/wallet/test_bip32.py +++ b/tests/unit/wallet/test_bip32.py @@ -1,10 +1,10 @@ from binascii import unhexlify, hexlify from lbry.testcase import AsyncioTestCase +from lbry.wallet.client.bip32 import PubKey, PrivateKey, from_extended_key_string +from lbry.wallet import Ledger, Database, Headers from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys -from lbry.wallet.client.bip32 import PubKey, PrivateKey, from_extended_key_string -from lbry.wallet import MainNetLedger as ledger_class class BIP32Tests(AsyncioTestCase): @@ -46,9 +46,9 @@ class BIP32Tests(AsyncioTestCase): with self.assertRaisesRegex(ValueError, 'private key must be 32 bytes'): PrivateKey(None, b'abcd', b'abcd'*8, 0, 255) private_key = PrivateKey( - ledger_class({ - 'db': ledger_class.database_class(':memory:'), - 'headers': ledger_class.headers_class(':memory:'), + Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:'), }), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), b'abcd'*8, 0, 1 @@ -67,9 +67,9 @@ class BIP32Tests(AsyncioTestCase): async def test_private_key_derivation(self): private_key = PrivateKey( - ledger_class({ - 'db': ledger_class.database_class(':memory:'), - 'headers': ledger_class.headers_class(':memory:'), + Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:'), }), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), b'abcd'*8, 0, 1 @@ -84,9 +84,9 @@ class BIP32Tests(AsyncioTestCase): self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED]) async def test_from_extended_keys(self): - ledger = ledger_class({ - 'db': ledger_class.database_class(':memory:'), - 'headers': ledger_class.headers_class(':memory:'), + ledger = Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:'), }) self.assertIsInstance( from_extended_key_string( diff --git a/tests/unit/wallet/test_coinselection.py b/tests/unit/wallet/test_coinselection.py index 237a407be..d6dec1d4d 100644 --- a/tests/unit/wallet/test_coinselection.py +++ b/tests/unit/wallet/test_coinselection.py @@ -2,9 +2,9 @@ from types import GeneratorType from lbry.testcase import AsyncioTestCase -from lbry.wallet import MainNetLedger as ledger_class +from lbry.wallet import Ledger, Database, Headers from lbry.wallet.client.coinselection import CoinSelector, MAXIMUM_TRIES -from lbry.wallet.client.constants import CENT +from lbry.constants import CENT from tests.unit.wallet.test_transaction import get_output as utxo @@ -20,9 +20,9 @@ def search(*args, **kwargs): class BaseSelectionTestCase(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = ledger_class({ - 'db': ledger_class.database_class(':memory:'), - 'headers': ledger_class.headers_class(':memory:'), + self.ledger = Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:'), }) await self.ledger.db.open() diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py index ba4809dd4..ddb1a4003 100644 --- a/tests/unit/wallet/test_database.py +++ b/tests/unit/wallet/test_database.py @@ -6,11 +6,11 @@ import tempfile import asyncio from concurrent.futures.thread import ThreadPoolExecutor -from lbry.wallet import MainNetLedger -from lbry.wallet.transaction import Transaction -from lbry.wallet.client.wallet import Wallet +from lbry.wallet import ( + Wallet, Account, Ledger, Database, Headers, Transaction, Input +) from lbry.wallet.client.constants import COIN -from lbry.wallet.client.basedatabase import query, interpolate, constraints_to_sql, AIOSQLite +from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite from lbry.crypto.hash import sha256 from lbry.testcase import AsyncioTestCase @@ -195,9 +195,9 @@ class TestQueryBuilder(unittest.TestCase): class TestQueries(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = MainNetLedger({ - 'db': MainNetLedger.database_class(':memory:'), - 'headers': MainNetLedger.headers_class(':memory:') + self.ledger = Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:') }) self.wallet = Wallet() await self.ledger.db.open() @@ -206,13 +206,13 @@ class TestQueries(AsyncioTestCase): await self.ledger.db.close() async def create_account(self, wallet=None): - account = self.ledger.account_class.generate(self.ledger, wallet or self.wallet) + account = Account.generate(self.ledger, wallet or self.wallet) await account.ensure_address_gap() return account async def create_tx_from_nothing(self, my_account, height): to_address = await my_account.receiving.get_or_create_usable_address() - to_hash = MainNetLedger.address_to_hash160(to_address) + to_hash = Ledger.address_to_hash160(to_address) tx = Transaction(height=height, is_verified=True) \ .add_inputs([self.txi(self.txo(1, sha256(str(height).encode())))]) \ .add_outputs([self.txo(1, to_hash)]) @@ -224,7 +224,7 @@ class TestQueries(AsyncioTestCase): from_hash = txo.script.values['pubkey_hash'] from_address = self.ledger.hash160_to_address(from_hash) to_address = await to_account.receiving.get_or_create_usable_address() - to_hash = MainNetLedger.address_to_hash160(to_address) + to_hash = Ledger.address_to_hash160(to_address) tx = Transaction(height=height, is_verified=True) \ .add_inputs([self.txi(txo)]) \ .add_outputs([self.txo(1, to_hash)]) @@ -248,7 +248,7 @@ class TestQueries(AsyncioTestCase): return get_output(int(amount*COIN), address) def txi(self, txo): - return Transaction.input_class.spend(txo) + return Input.spend(txo) async def test_large_tx_doesnt_hit_variable_limits(self): # SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html @@ -408,9 +408,9 @@ class TestUpgrade(AsyncioTestCase): return [col[0] for col in conn.execute(sql).fetchall()] async def test_reset_on_version_change(self): - self.ledger = MainNetLedger({ - 'db': MainNetLedger.database_class(self.path), - 'headers': MainNetLedger.headers_class(':memory:') + self.ledger = Ledger({ + 'db': Database(self.path), + 'headers': Headers(':memory:') }) # initial open, pre-version enabled db diff --git a/tests/unit/wallet/test_ledger.py b/tests/unit/wallet/test_ledger.py index ac9e1e206..13957d459 100644 --- a/tests/unit/wallet/test_ledger.py +++ b/tests/unit/wallet/test_ledger.py @@ -2,10 +2,7 @@ import os from binascii import hexlify from lbry.testcase import AsyncioTestCase -from lbry.wallet.client.wallet import Wallet -from lbry.wallet.account import Account -from lbry.wallet.transaction import Transaction, Output, Input -from lbry.wallet.ledger import MainNetLedger +from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Database, Headers from tests.unit.wallet.test_transaction import get_transaction, get_output from tests.unit.wallet.test_headers import HEADERS, block_bytes @@ -40,9 +37,9 @@ class MockNetwork: class LedgerTestCase(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = MainNetLedger({ - 'db': MainNetLedger.database_class(':memory:'), - 'headers': MainNetLedger.headers_class(':memory:') + self.ledger = Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:') }) self.account = Account.generate(self.ledger, Wallet(), "lbryum") await self.ledger.db.open() @@ -76,7 +73,7 @@ class LedgerTestCase(AsyncioTestCase): class TestSynchronization(LedgerTestCase): async def test_update_history(self): - account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba") + account = Account.generate(self.ledger, Wallet(), "torba") address = await account.receiving.get_or_create_usable_address() address_details = await self.ledger.db.get_address(address=address) self.assertIsNone(address_details['history']) diff --git a/tests/unit/wallet/test_schema_signing.py b/tests/unit/wallet/test_schema_signing.py index 5fc72f6ff..2a3acc1c0 100644 --- a/tests/unit/wallet/test_schema_signing.py +++ b/tests/unit/wallet/test_schema_signing.py @@ -3,9 +3,7 @@ from binascii import unhexlify from lbry.testcase import AsyncioTestCase from lbry.wallet.client.constants import CENT, NULL_HASH32 -from lbry.wallet.ledger import MainNetLedger -from lbry.wallet.transaction import Transaction, Input, Output - +from lbry.wallet import Ledger, Database, Headers, Transaction, Input, Output from lbry.schema.claim import Claim @@ -110,9 +108,9 @@ class TestValidatingOldSignatures(AsyncioTestCase): )) channel = channel_tx.outputs[0] - ledger = MainNetLedger({ - 'db': MainNetLedger.database_class(':memory:'), - 'headers': MainNetLedger.headers_class(':memory:') + ledger = Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:') }) self.assertTrue(stream.is_signed_by(channel, ledger)) diff --git a/tests/unit/wallet/test_script.py b/tests/unit/wallet/test_script.py index c0956ae0f..7333e1133 100644 --- a/tests/unit/wallet/test_script.py +++ b/tests/unit/wallet/test_script.py @@ -1,11 +1,11 @@ -from lbry.wallet.script import OutputScript import unittest from binascii import hexlify, unhexlify -from lbry.wallet.client.bcd_data_stream import BCDataStream -from lbry.wallet.client.basescript import Template, ParseError, tokenize, push_data -from lbry.wallet.client.basescript import PUSH_SINGLE, PUSH_INTEGER, PUSH_MANY, OP_HASH160, OP_EQUAL -from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript +from lbry.wallet.bcd_data_stream import BCDataStream +from lbry.wallet.script import ( + InputScript, OutputScript, Template, ParseError, tokenize, push_data, + PUSH_SINGLE, PUSH_INTEGER, PUSH_MANY, OP_HASH160, OP_EQUAL +) def parse(opcodes, source): @@ -102,12 +102,12 @@ class TestRedeemPubKeyHash(unittest.TestCase): def redeem_pubkey_hash(self, sig, pubkey): # this checks that factory function correctly sets up the script - src1 = BaseInputScript.redeem_pubkey_hash(unhexlify(sig), unhexlify(pubkey)) + src1 = InputScript.redeem_pubkey_hash(unhexlify(sig), unhexlify(pubkey)) self.assertEqual(src1.template.name, 'pubkey_hash') self.assertEqual(hexlify(src1.values['signature']), sig) self.assertEqual(hexlify(src1.values['pubkey']), pubkey) # now we test that it will round trip - src2 = BaseInputScript(src1.source) + src2 = InputScript(src1.source) self.assertEqual(src2.template.name, 'pubkey_hash') self.assertEqual(hexlify(src2.values['signature']), sig) self.assertEqual(hexlify(src2.values['pubkey']), pubkey) @@ -130,7 +130,7 @@ class TestRedeemScriptHash(unittest.TestCase): def redeem_script_hash(self, sigs, pubkeys): # this checks that factory function correctly sets up the script - src1 = BaseInputScript.redeem_script_hash( + src1 = InputScript.redeem_script_hash( [unhexlify(sig) for sig in sigs], [unhexlify(pubkey) for pubkey in pubkeys] ) @@ -141,7 +141,7 @@ class TestRedeemScriptHash(unittest.TestCase): self.assertEqual(subscript1.values['signatures_count'], len(sigs)) self.assertEqual(subscript1.values['pubkeys_count'], len(pubkeys)) # now we test that it will round trip - src2 = BaseInputScript(src1.source) + src2 = InputScript(src1.source) subscript2 = src2.values['script'] self.assertEqual(src2.template.name, 'script_hash') self.assertListEqual([hexlify(v) for v in src2.values['signatures']], sigs) @@ -183,11 +183,11 @@ class TestPayPubKeyHash(unittest.TestCase): def pay_pubkey_hash(self, pubkey_hash): # this checks that factory function correctly sets up the script - src1 = BaseOutputScript.pay_pubkey_hash(unhexlify(pubkey_hash)) + src1 = OutputScript.pay_pubkey_hash(unhexlify(pubkey_hash)) self.assertEqual(src1.template.name, 'pay_pubkey_hash') self.assertEqual(hexlify(src1.values['pubkey_hash']), pubkey_hash) # now we test that it will round trip - src2 = BaseOutputScript(src1.source) + src2 = OutputScript(src1.source) self.assertEqual(src2.template.name, 'pay_pubkey_hash') self.assertEqual(hexlify(src2.values['pubkey_hash']), pubkey_hash) return hexlify(src1.source) @@ -203,11 +203,11 @@ class TestPayScriptHash(unittest.TestCase): def pay_script_hash(self, script_hash): # this checks that factory function correctly sets up the script - src1 = BaseOutputScript.pay_script_hash(unhexlify(script_hash)) + src1 = OutputScript.pay_script_hash(unhexlify(script_hash)) self.assertEqual(src1.template.name, 'pay_script_hash') self.assertEqual(hexlify(src1.values['script_hash']), script_hash) # now we test that it will round trip - src2 = BaseOutputScript(src1.source) + src2 = OutputScript(src1.source) self.assertEqual(src2.template.name, 'pay_script_hash') self.assertEqual(hexlify(src2.values['script_hash']), script_hash) return hexlify(src1.source) diff --git a/tests/unit/wallet/test_transaction.py b/tests/unit/wallet/test_transaction.py index 67768aec5..54e4e361f 100644 --- a/tests/unit/wallet/test_transaction.py +++ b/tests/unit/wallet/test_transaction.py @@ -4,10 +4,7 @@ from itertools import cycle from lbry.testcase import AsyncioTestCase from lbry.wallet.client.constants import CENT, COIN, NULL_HASH32 -from lbry.wallet.client.wallet import Wallet - -from lbry.wallet.ledger import MainNetLedger -from lbry.wallet.transaction import Transaction, Output, Input +from lbry.wallet import Wallet, Account, Ledger, Database, Headers, Transaction, Output, Input NULL_HASH = b'\x00'*32 @@ -40,9 +37,9 @@ def get_claim_transaction(claim_name, claim=b''): class TestSizeAndFeeEstimation(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = MainNetLedger({ - 'db': MainNetLedger.database_class(':memory:'), - 'headers': MainNetLedger.headers_class(':memory:') + self.ledger = Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:') }) await self.ledger.db.open() @@ -266,9 +263,9 @@ class TestTransactionSerialization(unittest.TestCase): class TestTransactionSigning(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = MainNetLedger({ - 'db': MainNetLedger.database_class(':memory:'), - 'headers': MainNetLedger.headers_class(':memory:') + self.ledger = Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:') }) await self.ledger.db.open() @@ -276,7 +273,7 @@ class TestTransactionSigning(AsyncioTestCase): await self.ledger.db.close() async def test_sign(self): - account = self.ledger.account_class.from_dict( + account = Account.from_dict( self.ledger, Wallet(), { "seed": "carbon smart garage balance margin twelve chest sword toas" @@ -305,12 +302,12 @@ class TestTransactionSigning(AsyncioTestCase): class TransactionIOBalancing(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = MainNetLedger({ - 'db': MainNetLedger.database_class(':memory:'), - 'headers': MainNetLedger.headers_class(':memory:') + self.ledger = Ledger({ + 'db': Database(':memory:'), + 'headers': Headers(':memory:') }) await self.ledger.db.open() - self.account = self.ledger.account_class.from_dict( + self.account = Account.from_dict( self.ledger, Wallet(), { "seed": "carbon smart garage balance margin twelve chest sword " "toast envelope bottom stomach absent" @@ -328,7 +325,7 @@ class TransactionIOBalancing(AsyncioTestCase): return get_output(int(amount*COIN), address or next(self.hash_cycler)) def txi(self, txo): - return Transaction.input_class.spend(txo) + return Input.spend(txo) def tx(self, inputs, outputs): return Transaction.create(inputs, outputs, [self.account], self.account) diff --git a/tests/unit/wallet/test_wallet.py b/tests/unit/wallet/test_wallet.py index 9f1b18da0..dc6e4d3e8 100644 --- a/tests/unit/wallet/test_wallet.py +++ b/tests/unit/wallet/test_wallet.py @@ -3,18 +3,18 @@ from binascii import hexlify from unittest import TestCase, mock from lbry.testcase import AsyncioTestCase - -from lbry.wallet.ledger import MainNetLedger, RegTestLedger -from lbry.wallet.client.basemanager import BaseWalletManager -from lbry.wallet.client.wallet import Wallet, WalletStorage, TimestampedPreferences +from lbry.wallet import ( + Ledger, RegTestLedger, WalletManager, Account, + Wallet, WalletStorage, TimestampedPreferences +) class TestWalletCreation(AsyncioTestCase): async def asyncSetUp(self): - self.manager = BaseWalletManager() + self.manager = WalletManager() config = {'data_path': '/tmp/wallet'} - self.main_ledger = self.manager.get_or_create_ledger(MainNetLedger.get_id(), config) + self.main_ledger = self.manager.get_or_create_ledger(Ledger.get_id(), config) self.test_ledger = self.manager.get_or_create_ledger(RegTestLedger.get_id(), config) def test_create_wallet_and_accounts(self): @@ -66,7 +66,7 @@ class TestWalletCreation(AsyncioTestCase): ) self.assertEqual(len(wallet.accounts), 1) account = wallet.default_account - self.assertIsInstance(account, MainNetLedger.account_class) + self.assertIsInstance(account, Account) self.maxDiff = None self.assertDictEqual(wallet_dict, wallet.to_dict()) @@ -75,9 +75,9 @@ class TestWalletCreation(AsyncioTestCase): self.assertEqual(decrypted['accounts'][0]['name'], 'An Account') def test_read_write(self): - manager = BaseWalletManager() + manager = WalletManager() config = {'data_path': '/tmp/wallet'} - ledger = manager.get_or_create_ledger(MainNetLedger.get_id(), config) + ledger = manager.get_or_create_ledger(Ledger.get_id(), config) with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file: wallet_file.write(b'{"version": 1}')