update imports and more merging

This commit is contained in:
Lex Berezhny 2020-01-02 22:18:49 -05:00
parent c9e410a6f4
commit fb1af9e3d2
44 changed files with 3667 additions and 470 deletions

2
.gitignore vendored
View file

@ -10,7 +10,7 @@ lbry.egg-info
__pycache__ __pycache__
_trial_temp/ _trial_temp/
/tests/integration/files /tests/integration/blockchain/files
/tests/.coverage.* /tests/.coverage.*
/lbry/wallet/bin /lbry/wallet/bin

View file

@ -9,7 +9,7 @@ from contextlib import contextmanager
from appdirs import user_data_dir, user_config_dir from appdirs import user_data_dir, user_config_dir
from lbry.error import InvalidCurrencyError from lbry.error import InvalidCurrencyError
from lbry.dht import constants from lbry.dht import constants
from lbry.wallet.client.coinselection import STRATEGIES from lbry.wallet.coinselection import STRATEGIES
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -20,7 +20,7 @@ from lbry.stream.stream_manager import StreamManager
from lbry.extras.daemon.Component import Component from lbry.extras.daemon.Component import Component
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
from lbry.extras.daemon.storage import SQLiteStorage from lbry.extras.daemon.storage import SQLiteStorage
from lbry.wallet import LbryWalletManager from lbry.wallet import WalletManager
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -17,8 +17,11 @@ from traceback import format_exc
from aiohttp import web from aiohttp import web
from functools import wraps, partial from functools import wraps, partial
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
from lbry.wallet.client.wallet import Wallet, ENCRYPT_ON_DISK from lbry.wallet import (
from lbry.wallet.client.baseaccount import SingleKey, HierarchicalDeterministic Wallet, WalletManager, ENCRYPT_ON_DISK, SingleKey, HierarchicalDeterministic,
Ledger, Transaction, Output, Input, Account
)
from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies, dict_values_to_lbc
from lbry import utils from lbry import utils
from lbry.conf import Config, Setting, NOT_SET from lbry.conf import Config, Setting, NOT_SET
@ -39,9 +42,6 @@ from lbry.extras.daemon.ComponentManager import ComponentManager
from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder
from lbry.extras.daemon import comment_client from lbry.extras.daemon import comment_client
from lbry.extras.daemon.undecorated import undecorated from lbry.extras.daemon.undecorated import undecorated
from lbry.wallet.transaction import Transaction, Output, Input
from lbry.wallet.account import Account as LBCAccount
from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies, dict_values_to_lbc
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
from lbry.schema.url import URL from lbry.schema.url import URL
@ -51,8 +51,6 @@ if typing.TYPE_CHECKING:
from lbry.extras.daemon.Components import UPnPComponent from lbry.extras.daemon.Components import UPnPComponent
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
from lbry.extras.daemon.storage import SQLiteStorage from lbry.extras.daemon.storage import SQLiteStorage
from lbry.wallet.manager import LbryWalletManager
from lbry.wallet.ledger import MainNetLedger
from lbry.stream.stream_manager import StreamManager from lbry.stream.stream_manager import StreamManager
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -322,7 +320,7 @@ class Daemon(metaclass=JSONRPCServerType):
return self.component_manager.get_component(DHT_COMPONENT) return self.component_manager.get_component(DHT_COMPONENT)
@property @property
def wallet_manager(self) -> typing.Optional['LbryWalletManager']: def wallet_manager(self) -> typing.Optional['WalletManager']:
return self.component_manager.get_component(WALLET_COMPONENT) return self.component_manager.get_component(WALLET_COMPONENT)
@property @property
@ -676,7 +674,7 @@ class Daemon(metaclass=JSONRPCServerType):
return None, None return None, None
@property @property
def ledger(self) -> Optional['MainNetLedger']: def ledger(self) -> Optional['Ledger']:
try: try:
return self.wallet_manager.default_account.ledger return self.wallet_manager.default_account.ledger
except AttributeError: except AttributeError:
@ -1161,7 +1159,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet = self.wallet_manager.import_wallet(wallet_path) wallet = self.wallet_manager.import_wallet(wallet_path)
if not wallet.accounts and create_account: if not wallet.accounts and create_account:
account = LBCAccount.generate( account = Account.generate(
self.ledger, wallet, address_generator={ self.ledger, wallet, address_generator={
'name': SingleKey.name if single_key else HierarchicalDeterministic.name 'name': SingleKey.name if single_key else HierarchicalDeterministic.name
} }
@ -1464,7 +1462,7 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: {Account} Returns: {Account}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
account = LBCAccount.from_dict( account = Account.from_dict(
self.ledger, wallet, { self.ledger, wallet, {
'name': account_name, 'name': account_name,
'seed': seed, 'seed': seed,
@ -1498,7 +1496,7 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: {Account} Returns: {Account}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
account = LBCAccount.generate( account = Account.generate(
self.ledger, wallet, account_name, { self.ledger, wallet, account_name, {
'name': SingleKey.name if single_key else HierarchicalDeterministic.name 'name': SingleKey.name if single_key else HierarchicalDeterministic.name
} }
@ -2134,7 +2132,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
if account_id: if account_id:
account: LBCAccount = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
claims = account.get_claims claims = account.get_claims
claim_count = account.get_claim_count claim_count = account.get_claim_count
else: else:
@ -2657,7 +2655,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
if account_id: if account_id:
account: LBCAccount = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
channels = account.get_channels channels = account.get_channels
channel_count = account.get_channel_count channel_count = account.get_channel_count
else: else:
@ -2732,7 +2730,7 @@ class Daemon(metaclass=JSONRPCServerType):
if channels and channels[0].get_address(self.ledger) != holding_address: if channels and channels[0].get_address(self.ledger) != holding_address:
holding_address = channels[0].get_address(self.ledger) holding_address = channels[0].get_address(self.ledger)
account: LBCAccount = await self.ledger.get_account_for_address(wallet, holding_address) account = await self.ledger.get_account_for_address(wallet, holding_address)
if account: if account:
# Case 1: channel holding address is in one of the accounts we already have # Case 1: channel holding address is in one of the accounts we already have
# simply add the certificate to existing account # simply add the certificate to existing account
@ -2741,7 +2739,7 @@ class Daemon(metaclass=JSONRPCServerType):
# Case 2: channel holding address hasn't changed and thus is in the bundled read-only account # Case 2: channel holding address hasn't changed and thus is in the bundled read-only account
# create a single-address holding account to manage the channel # create a single-address holding account to manage the channel
if holding_address == data['holding_address']: if holding_address == data['holding_address']:
account = LBCAccount.from_dict(self.ledger, wallet, { account = Account.from_dict(self.ledger, wallet, {
'name': f"Holding Account For Channel {data['name']}", 'name': f"Holding Account For Channel {data['name']}",
'public_key': data['holding_public_key'], 'public_key': data['holding_public_key'],
'address_generator': {'name': 'single-address'} 'address_generator': {'name': 'single-address'}
@ -3384,7 +3382,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
if account_id: if account_id:
account: LBCAccount = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
streams = account.get_streams streams = account.get_streams
stream_count = account.get_stream_count stream_count = account.get_stream_count
else: else:
@ -3727,7 +3725,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
if account_id: if account_id:
account: LBCAccount = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
collections = account.get_collections collections = account.get_collections
collection_count = account.get_collection_count collection_count = account.get_collection_count
else: else:
@ -3854,7 +3852,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
if account_id: if account_id:
account: LBCAccount = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
supports = account.get_supports supports = account.get_supports
support_count = account.get_support_count support_count = account.get_support_count
else: else:
@ -4002,7 +4000,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
if account_id: if account_id:
account: LBCAccount = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
transactions = account.get_transaction_history transactions = account.get_transaction_history
transaction_count = account.get_transaction_history_count transaction_count = account.get_transaction_history_count
else: else:
@ -4696,7 +4694,7 @@ class Daemon(metaclass=JSONRPCServerType):
if 'fee_currency' in kwargs or 'fee_amount' in kwargs: if 'fee_currency' in kwargs or 'fee_amount' in kwargs:
return claim_address return claim_address
async def get_receiving_address(self, address: str, account: Optional[LBCAccount]) -> str: async def get_receiving_address(self, address: str, account: Optional[Account]) -> str:
if address is None and account is not None: if address is None and account is not None:
return await account.receiving.get_or_create_usable_address() return await account.receiving.get_or_create_usable_address()
self.valid_address_or_error(address) self.valid_address_or_error(address)

View file

@ -6,11 +6,9 @@ from json import JSONEncoder
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
from lbry.wallet.client.wallet import Wallet
from lbry.wallet.client.bip32 import PubKey
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
from lbry.wallet.ledger import MainNetLedger, Account from lbry.wallet import Wallet, Ledger, Account, Transaction, Output
from lbry.wallet.transaction import Transaction, Output from lbry.wallet.bip32 import PubKey
from lbry.wallet.dewies import dewies_to_lbc from lbry.wallet.dewies import dewies_to_lbc
from lbry.stream.managed_stream import ManagedStream from lbry.stream.managed_stream import ManagedStream
@ -114,7 +112,7 @@ def encode_file_doc():
class JSONResponseEncoder(JSONEncoder): class JSONResponseEncoder(JSONEncoder):
def __init__(self, *args, ledger: MainNetLedger, include_protobuf=False, **kwargs): def __init__(self, *args, ledger: Ledger, include_protobuf=False, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.ledger = ledger self.ledger = ledger
self.include_protobuf = include_protobuf self.include_protobuf = include_protobuf

View file

@ -5,7 +5,7 @@ import typing
import asyncio import asyncio
import binascii import binascii
import time import time
from lbry.wallet.client.basedatabase import SQLiteMixin from lbry.wallet import SQLiteMixin
from lbry.conf import Config from lbry.conf import Config
from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies
from lbry.wallet.transaction import Transaction from lbry.wallet.transaction import Transaction

View file

@ -14,17 +14,15 @@ from lbry.stream.managed_stream import ManagedStream
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
from lbry.schema.url import URL from lbry.schema.url import URL
from lbry.wallet.dewies import dewies_to_lbc from lbry.wallet.dewies import dewies_to_lbc
from lbry.wallet.transaction import Output from lbry.wallet import WalletManager, Wallet, Transaction, Output
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.conf import Config from lbry.conf import Config
from lbry.blob.blob_manager import BlobManager from lbry.blob.blob_manager import BlobManager
from lbry.dht.node import Node from lbry.dht.node import Node
from lbry.extras.daemon.analytics import AnalyticsManager from lbry.extras.daemon.analytics import AnalyticsManager
from lbry.extras.daemon.storage import SQLiteStorage, StoredContentClaim from lbry.extras.daemon.storage import SQLiteStorage, StoredContentClaim
from lbry.wallet import LbryWalletManager
from lbry.wallet.transaction import Transaction
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
from lbry.wallet.client.wallet import Wallet
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -66,7 +64,7 @@ def path_or_none(p) -> Optional[str]:
class StreamManager: class StreamManager:
def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager', def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager',
wallet_manager: 'LbryWalletManager', storage: 'SQLiteStorage', node: Optional['Node'], wallet_manager: 'WalletManager', storage: 'SQLiteStorage', node: Optional['Node'],
analytics_manager: Optional['AnalyticsManager'] = None): analytics_manager: Optional['AnalyticsManager'] = None):
self.loop = loop self.loop = loop
self.config = config self.config = config

View file

@ -14,18 +14,11 @@ from time import time
from binascii import unhexlify from binascii import unhexlify
from functools import partial from functools import partial
import lbry.wallet from lbry.wallet import WalletManager, Wallet, Ledger, Account, Transaction
from lbry.conf import Config from lbry.conf import Config
from lbry.wallet import LbryWalletManager from lbry.wallet.util import satoshis_to_coins
from lbry.wallet.account import Account
from lbry.wallet.orchstr8 import Conductor from lbry.wallet.orchstr8 import Conductor
from lbry.wallet.transaction import Transaction
from lbry.wallet.client.wallet import Wallet
from lbry.wallet.client.util import satoshis_to_coins
from lbry.wallet.orchstr8.node import BlockchainNode, WalletNode from lbry.wallet.orchstr8.node import BlockchainNode, WalletNode
from lbry.wallet.client.baseledger import BaseLedger
from lbry.wallet.client.baseaccount import BaseAccount
from lbry.wallet.client.basemanager import BaseWalletManager
from lbry.extras.daemon.Daemon import Daemon, jsonrpc_dumps_pretty from lbry.extras.daemon.Daemon import Daemon, jsonrpc_dumps_pretty
from lbry.extras.daemon.Components import Component, WalletComponent from lbry.extras.daemon.Components import Component, WalletComponent
@ -215,25 +208,19 @@ class AdvanceTimeTestCase(AsyncioTestCase):
class IntegrationTestCase(AsyncioTestCase): class IntegrationTestCase(AsyncioTestCase):
SEED = None SEED = None
LEDGER = lbry.wallet
MANAGER = LbryWalletManager
ENABLE_SEGWIT = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.conductor: Optional[Conductor] = None self.conductor: Optional[Conductor] = None
self.blockchain: Optional[BlockchainNode] = None self.blockchain: Optional[BlockchainNode] = None
self.wallet_node: Optional[WalletNode] = None self.wallet_node: Optional[WalletNode] = None
self.manager: Optional[BaseWalletManager] = None self.manager: Optional[WalletManager] = None
self.ledger: Optional[BaseLedger] = None self.ledger: Optional[Ledger] = None
self.wallet: Optional[Wallet] = None self.wallet: Optional[Wallet] = None
self.account: Optional[BaseAccount] = None self.account: Optional[Account] = None
async def asyncSetUp(self): async def asyncSetUp(self):
self.conductor = Conductor( self.conductor = Conductor(seed=self.SEED)
ledger_module=self.LEDGER, manager_module=self.MANAGER,
enable_segwit=self.ENABLE_SEGWIT, seed=self.SEED
)
await self.conductor.start_blockchain() await self.conductor.start_blockchain()
self.addCleanup(self.conductor.stop_blockchain) self.addCleanup(self.conductor.stop_blockchain)
await self.conductor.start_spv() await self.conductor.start_spv()
@ -317,14 +304,13 @@ class CommandTestCase(IntegrationTestCase):
VERBOSITY = logging.WARN VERBOSITY = logging.WARN
blob_lru_cache_size = 0 blob_lru_cache_size = 0
account: Account
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp() await super().asyncSetUp()
logging.getLogger('lbry.blob_exchange').setLevel(self.VERBOSITY) logging.getLogger('lbry.blob_exchange').setLevel(self.VERBOSITY)
logging.getLogger('lbry.daemon').setLevel(self.VERBOSITY) logging.getLogger('lbry.daemon').setLevel(self.VERBOSITY)
logging.getLogger('lbry.stream').setLevel(self.VERBOSITY) logging.getLogger('lbry.stream').setLevel(self.VERBOSITY)
logging.getLogger('lbry.wallet').setLevel(self.VERBOSITY)
self.daemons = [] self.daemons = []
self.extra_wallet_nodes = [] self.extra_wallet_nodes = []
@ -419,9 +405,7 @@ class CommandTestCase(IntegrationTestCase):
return txid return txid
async def on_transaction_dict(self, tx): async def on_transaction_dict(self, tx):
await self.ledger.wait( await self.ledger.wait(Transaction(unhexlify(tx['hex'])))
self.ledger.transaction_class(unhexlify(tx['hex']))
)
@staticmethod @staticmethod
def get_all_addresses(tx): def get_all_addresses(tx):

View file

@ -6,6 +6,12 @@ __node_url__ = (
) )
__spvserver__ = 'lbry.wallet.server.coin.LBCRegTest' __spvserver__ = 'lbry.wallet.server.coin.LBCRegTest'
from lbry.wallet.manager import LbryWalletManager from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK
from lbry.wallet.network import Network from .manager import WalletManager
from lbry.wallet.ledger import MainNetLedger, RegTestLedger, TestNetLedger from .network import Network
from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent
from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic
from .transaction import Transaction, Output, Input
from .script import OutputScript, InputScript
from .database import SQLiteMixin, Database
from .header import Headers

View file

@ -1,14 +1,28 @@
import os
import time
import json import json
import ecdsa
import logging import logging
import typing
import asyncio
import random
from functools import partial from functools import partial
from hashlib import sha256 from hashlib import sha256
from string import hexdigits from string import hexdigits
from typing import Type, Dict, Tuple, Optional, Any, List
import ecdsa from lbry.error import InvalidPasswordError
from lbry.wallet.constants import CLAIM_TYPES, TXO_TYPES from lbry.crypto.crypt import aes_encrypt, aes_decrypt
from lbry.wallet.client.baseaccount import BaseAccount, HierarchicalDeterministic from .bip32 import PrivateKey, PubKey, from_extended_key_string
from .mnemonic import Mnemonic
from .constants import COIN, CLAIM_TYPES, TXO_TYPES
from .transaction import Transaction, Input, Output
if typing.TYPE_CHECKING:
from .ledger import Ledger
from .wallet import Wallet
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -22,22 +36,483 @@ def validate_claim_id(claim_id):
raise Exception("Claim id is not hex encoded") raise Exception("Claim id is not hex encoded")
class Account(BaseAccount): class AddressManager:
def __init__(self, *args, **kwargs): name: str
super().__init__(*args, **kwargs)
self.channel_keys = {} __slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'
def __init__(self, account, public_key, chain_number):
self.account = account
self.public_key = public_key
self.chain_number = chain_number
self.address_generator_lock = asyncio.Lock()
@classmethod
def from_dict(cls, account: 'Account', d: dict) \
-> Tuple['AddressManager', 'AddressManager']:
raise NotImplementedError
@classmethod
def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> Dict:
d: Dict[str, Any] = {'name': cls.name}
receiving_dict = receiving.to_dict_instance()
if receiving_dict:
d['receiving'] = receiving_dict
change_dict = change.to_dict_instance()
if change_dict:
d['change'] = change_dict
return d
def merge(self, d: dict):
pass
def to_dict_instance(self) -> Optional[dict]:
raise NotImplementedError
def _query_addresses(self, **constraints):
return self.account.ledger.db.get_addresses(
accounts=[self.account],
chain=self.chain_number,
**constraints
)
def get_private_key(self, index: int) -> PrivateKey:
raise NotImplementedError
def get_public_key(self, index: int) -> PubKey:
raise NotImplementedError
async def get_max_gap(self):
raise NotImplementedError
async def ensure_address_gap(self):
raise NotImplementedError
def get_address_records(self, only_usable: bool = False, **constraints):
raise NotImplementedError
async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
records = await self.get_address_records(only_usable=only_usable, **constraints)
return [r['address'] for r in records]
async def get_or_create_usable_address(self) -> str:
addresses = await self.get_addresses(only_usable=True, limit=10)
if addresses:
return random.choice(addresses)
addresses = await self.ensure_address_gap()
return addresses[0]
class HierarchicalDeterministic(AddressManager):
""" Implements simple version of Bitcoin Hierarchical Deterministic key management. """
name: str = "deterministic-chain"
__slots__ = 'gap', 'maximum_uses_per_address'
def __init__(self, account: 'Account', chain: int, gap: int, maximum_uses_per_address: int) -> None:
super().__init__(account, account.public_key.child(chain), chain)
self.gap = gap
self.maximum_uses_per_address = maximum_uses_per_address
@classmethod
def from_dict(cls, account: 'Account', d: dict) -> Tuple[AddressManager, AddressManager]:
return (
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
)
def merge(self, d: dict):
self.gap = d.get('gap', self.gap)
self.maximum_uses_per_address = d.get('maximum_uses_per_address', self.maximum_uses_per_address)
def to_dict_instance(self):
return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address}
def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key.child(self.chain_number).child(index)
def get_public_key(self, index: int) -> PubKey:
return self.account.public_key.child(self.chain_number).child(index)
async def get_max_gap(self) -> int:
addresses = await self._query_addresses(order_by="n asc")
max_gap = 0
current_gap = 0
for address in addresses:
if address['used_times'] == 0:
current_gap += 1
else:
max_gap = max(max_gap, current_gap)
current_gap = 0
return max_gap
async def ensure_address_gap(self) -> List[str]:
async with self.address_generator_lock:
addresses = await self._query_addresses(limit=self.gap, order_by="n desc")
existing_gap = 0
for address in addresses:
if address['used_times'] == 0:
existing_gap += 1
else:
break
if existing_gap == self.gap:
return []
start = addresses[0]['pubkey'].n+1 if addresses else 0
end = start + (self.gap - existing_gap)
new_keys = await self._generate_keys(start, end-1)
await self.account.ledger.announce_addresses(self, new_keys)
return new_keys
async def _generate_keys(self, start: int, end: int) -> List[str]:
if not self.address_generator_lock.locked():
raise RuntimeError('Should not be called outside of address_generator_lock.')
keys = [self.public_key.child(index) for index in range(start, end+1)]
await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
return [key.address for key in keys]
def get_address_records(self, only_usable: bool = False, **constraints):
if only_usable:
constraints['used_times__lt'] = self.maximum_uses_per_address
if 'order_by' not in constraints:
constraints['order_by'] = "used_times asc, n asc"
return self._query_addresses(**constraints)
class SingleKey(AddressManager):
""" Single Key address manager always returns the same address for all operations. """
name: str = "single-address"
__slots__ = ()
@classmethod
def from_dict(cls, account: 'Account', d: dict) \
-> Tuple[AddressManager, AddressManager]:
same_address_manager = cls(account, account.public_key, 0)
return same_address_manager, same_address_manager
def to_dict_instance(self):
return None
def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key
def get_public_key(self, index: int) -> PubKey:
return self.account.public_key
async def get_max_gap(self) -> int:
return 0
async def ensure_address_gap(self) -> List[str]:
async with self.address_generator_lock:
exists = await self.get_address_records()
if not exists:
await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
new_keys = [self.public_key.address]
await self.account.ledger.announce_addresses(self, new_keys)
return new_keys
return []
def get_address_records(self, only_usable: bool = False, **constraints):
return self._query_addresses(**constraints)
class Account:
mnemonic_class = Mnemonic
private_key_class = PrivateKey
public_key_class = PubKey
address_generators: Dict[str, Type[AddressManager]] = {
SingleKey.name: SingleKey,
HierarchicalDeterministic.name: HierarchicalDeterministic,
}
def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str,
seed: str, private_key_string: str, encrypted: bool,
private_key: Optional[PrivateKey], public_key: PubKey,
address_generator: dict, modified_on: float, channel_keys: dict) -> None:
self.ledger = ledger
self.wallet = wallet
self.id = public_key.address
self.name = name
self.seed = seed
self.modified_on = modified_on
self.private_key_string = private_key_string
self.init_vectors: Dict[str, bytes] = {}
self.encrypted = encrypted
self.private_key = private_key
self.public_key = public_key
generator_name = address_generator.get('name', HierarchicalDeterministic.name)
self.address_generator = self.address_generators[generator_name]
self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
self.channel_keys = channel_keys
ledger.add_account(self)
wallet.add_account(self)
def get_init_vector(self, key) -> Optional[bytes]:
init_vector = self.init_vectors.get(key, None)
if init_vector is None:
init_vector = self.init_vectors[key] = os.urandom(16)
return init_vector
@classmethod
def generate(cls, ledger: 'Ledger', wallet: 'Wallet',
name: str = None, address_generator: dict = None):
return cls.from_dict(ledger, wallet, {
'name': name,
'seed': cls.mnemonic_class().make_seed(),
'address_generator': address_generator or {}
})
@classmethod
def get_private_key_from_seed(cls, ledger: 'Ledger', seed: str, password: str):
return cls.private_key_class.from_seed(
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password or 'lbryum')
)
@classmethod
def keys_from_dict(cls, ledger: 'Ledger', d: dict) \
-> Tuple[str, Optional[PrivateKey], PubKey]:
seed = d.get('seed', '')
private_key_string = d.get('private_key', '')
private_key = None
public_key = None
encrypted = d.get('encrypted', False)
if not encrypted:
if seed:
private_key = cls.get_private_key_from_seed(ledger, seed, '')
public_key = private_key.public_key
elif private_key_string:
private_key = from_extended_key_string(ledger, private_key_string)
public_key = private_key.public_key
if public_key is None:
public_key = from_extended_key_string(ledger, d['public_key'])
return seed, private_key, public_key
@classmethod
def from_dict(cls, ledger: 'Ledger', wallet: 'Wallet', d: dict):
seed, private_key, public_key = cls.keys_from_dict(ledger, d)
name = d.get('name')
if not name:
name = f'Account #{public_key.address}'
return cls(
ledger=ledger,
wallet=wallet,
name=name,
seed=seed,
private_key_string=d.get('private_key', ''),
encrypted=d.get('encrypted', False),
private_key=private_key,
public_key=public_key,
address_generator=d.get('address_generator', {}),
modified_on=d.get('modified_on', time.time()),
channel_keys=d.get('certificates', {})
)
def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True):
private_key_string, seed = self.private_key_string, self.seed
if not self.encrypted and self.private_key:
private_key_string = self.private_key.extended_key_string()
if not self.encrypted and encrypt_password:
if private_key_string:
private_key_string = aes_encrypt(
encrypt_password, private_key_string, self.get_init_vector('private_key')
)
if seed:
seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed'))
d = {
'ledger': self.ledger.get_id(),
'name': self.name,
'seed': seed,
'encrypted': bool(self.encrypted or encrypt_password),
'private_key': private_key_string,
'public_key': self.public_key.extended_key_string(),
'address_generator': self.address_generator.to_dict(self.receiving, self.change),
'modified_on': self.modified_on
}
if include_channel_keys:
d['certificates'] = self.channel_keys
return d
def merge(self, d: dict):
if d.get('modified_on', 0) > self.modified_on:
self.name = d['name']
self.modified_on = d.get('modified_on', time.time())
assert self.address_generator.name == d['address_generator']['name']
for chain_name in ('change', 'receiving'):
if chain_name in d['address_generator']:
chain_object = getattr(self, chain_name)
chain_object.merge(d['address_generator'][chain_name])
self.channel_keys.update(d.get('certificates', {}))
@property @property
def hash(self) -> bytes: def hash(self) -> bytes:
assert not self.encrypted, "Cannot hash an encrypted account."
h = sha256(json.dumps(self.to_dict(include_channel_keys=False)).encode()) h = sha256(json.dumps(self.to_dict(include_channel_keys=False)).encode())
for cert in sorted(self.channel_keys.keys()): for cert in sorted(self.channel_keys.keys()):
h.update(cert.encode()) h.update(cert.encode())
return h.digest() return h.digest()
def merge(self, d: dict): async def get_details(self, show_seed=False, **kwargs):
super().merge(d) satoshis = await self.get_balance(**kwargs)
self.channel_keys.update(d.get('certificates', {})) details = {
'id': self.id,
'name': self.name,
'ledger': self.ledger.get_id(),
'coins': round(satoshis/COIN, 2),
'satoshis': satoshis,
'encrypted': self.encrypted,
'public_key': self.public_key.extended_key_string(),
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
}
if show_seed:
details['seed'] = self.seed
details['certificates'] = len(self.channel_keys)
return details
def decrypt(self, password: str) -> bool:
assert self.encrypted, "Key is not encrypted."
try:
seed = self._decrypt_seed(password)
except (ValueError, InvalidPasswordError):
return False
try:
private_key = self._decrypt_private_key_string(password)
except (TypeError, ValueError, InvalidPasswordError):
return False
self.seed = seed
self.private_key = private_key
self.private_key_string = ""
self.encrypted = False
return True
def _decrypt_private_key_string(self, password: str) -> Optional[PrivateKey]:
if not self.private_key_string:
return None
private_key_string, self.init_vectors['private_key'] = aes_decrypt(password, self.private_key_string)
if not private_key_string:
return None
return from_extended_key_string(
self.ledger, private_key_string
)
def _decrypt_seed(self, password: str) -> str:
if not self.seed:
return ""
seed, self.init_vectors['seed'] = aes_decrypt(password, self.seed)
if not seed:
return ""
try:
Mnemonic().mnemonic_decode(seed)
except IndexError:
# failed to decode the seed, this either means it decrypted and is invalid
# or that we hit an edge case where an incorrect password gave valid padding
raise ValueError("Failed to decode seed.")
return seed
def encrypt(self, password: str) -> bool:
assert not self.encrypted, "Key is already encrypted."
if self.seed:
self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed'))
if isinstance(self.private_key, PrivateKey):
self.private_key_string = aes_encrypt(
password, self.private_key.extended_key_string(), self.get_init_vector('private_key')
)
self.private_key = None
self.encrypted = True
return True
async def ensure_address_gap(self):
addresses = []
for address_manager in self.address_managers.values():
new_addresses = await address_manager.ensure_address_gap()
addresses.extend(new_addresses)
return addresses
async def get_addresses(self, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints)
return [r[0] for r in rows]
def get_address_records(self, **constraints):
return self.ledger.db.get_addresses(accounts=[self], **constraints)
def get_address_count(self, **constraints):
return self.ledger.db.get_address_count(accounts=[self], **constraints)
def get_private_key(self, chain: int, index: int) -> PrivateKey:
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
return self.address_managers[chain].get_private_key(index)
def get_public_key(self, chain: int, index: int) -> PubKey:
return self.address_managers[chain].get_public_key(index)
def get_balance(self, confirmations: int = 0, include_claims=False, **constraints):
if not include_claims:
constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])})
if confirmations > 0:
height = self.ledger.headers.height - (confirmations-1)
constraints.update({'height__lte': height, 'height__gt': 0})
return self.ledger.db.get_balance(accounts=[self], **constraints)
async def get_max_gap(self):
change_gap = await self.change.get_max_gap()
receiving_gap = await self.receiving.get_max_gap()
return {
'max_change_gap': change_gap,
'max_receiving_gap': receiving_gap,
}
def get_utxos(self, **constraints):
return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)
def get_utxo_count(self, **constraints):
return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)
def get_transactions(self, **constraints):
return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)
def get_transaction_count(self, **constraints):
return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)
async def fund(self, to_account, amount=None, everything=False,
outputs=1, broadcast=False, **constraints):
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
if everything:
utxos = await self.get_utxos(**constraints)
await self.ledger.reserve_outputs(utxos)
tx = await Transaction.create(
inputs=[Input.spend(txo) for txo in utxos],
outputs=[],
funding_accounts=[self],
change_account=to_account
)
elif amount > 0:
to_address = await to_account.change.get_or_create_usable_address()
to_hash160 = to_account.ledger.address_to_hash160(to_address)
tx = await Transaction.create(
inputs=[],
outputs=[
Output.pay_pubkey_hash(amount//outputs, to_hash160)
for _ in range(outputs)
],
funding_accounts=[self],
change_account=self
)
else:
raise ValueError('An amount is required.')
if broadcast:
await self.ledger.broadcast(tx)
else:
await self.ledger.release_tx(tx)
return tx
def add_channel_private_key(self, private_key): def add_channel_private_key(self, private_key):
public_key_bytes = private_key.get_verifying_key().to_der() public_key_bytes = private_key.get_verifying_key().to_der()
@ -81,11 +556,6 @@ class Account(BaseAccount):
if gap_changed: if gap_changed:
self.wallet.save() self.wallet.save()
def get_balance(self, confirmations=0, include_claims=False, **constraints):
if not include_claims:
constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])})
return super().get_balance(confirmations, **constraints)
async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False): async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False):
tips_balance, supports_balance, claims_balance = 0, 0, 0 tips_balance, supports_balance, claims_balance = 0, 0, 0
get_total_balance = partial(self.get_balance, confirmations=confirmations, include_claims=True) get_total_balance = partial(self.get_balance, confirmations=confirmations, include_claims=True)
@ -116,29 +586,6 @@ class Account(BaseAccount):
} if reserved_subtotals else None } if reserved_subtotals else None
} }
@classmethod
def get_private_key_from_seed(cls, ledger, seed: str, password: str):
return super().get_private_key_from_seed(
ledger, seed, password or 'lbryum'
)
@classmethod
def from_dict(cls, ledger, wallet, d: dict) -> 'Account':
account = super().from_dict(ledger, wallet, d)
account.channel_keys = d.get('certificates', {})
return account
def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True):
d = super().to_dict(encrypt_password)
if include_channel_keys:
d['certificates'] = self.channel_keys
return d
async def get_details(self, **kwargs):
details = await super().get_details(**kwargs)
details['certificates'] = len(self.channel_keys)
return details
def get_transaction_history(self, **constraints): def get_transaction_history(self, **constraints):
return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints) return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints)

View file

@ -2,7 +2,7 @@ from coincurve import PublicKey, PrivateKey as _PrivateKey
from lbry.crypto.hash import hmac_sha512, hash160, double_sha256 from lbry.crypto.hash import hmac_sha512, hash160, double_sha256
from lbry.crypto.base58 import Base58 from lbry.crypto.base58 import Base58
from lbry.wallet.client.util import cachedproperty from .util import cachedproperty
class DerivationError(Exception): class DerivationError(Exception):

View file

@ -1,7 +1,7 @@
from random import Random from random import Random
from typing import List from typing import List
from lbry.wallet.client import basetransaction from lbry.wallet.transaction import OutputEffectiveAmountEstimator
MAXIMUM_TRIES = 100000 MAXIMUM_TRIES = 100000
@ -25,8 +25,8 @@ class CoinSelector:
self.random.seed(seed, version=1) self.random.seed(seed, version=1)
def select( def select(
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], self, txos: List[OutputEffectiveAmountEstimator],
strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: strategy_name: str = None) -> List[OutputEffectiveAmountEstimator]:
if not txos: if not txos:
return [] return []
available = sum(c.effective_amount for c in txos) available = sum(c.effective_amount for c in txos)
@ -35,16 +35,16 @@ class CoinSelector:
return getattr(self, strategy_name or "standard")(txos, available) return getattr(self, strategy_name or "standard")(txos, available)
@strategy @strategy
def prefer_confirmed(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], def prefer_confirmed(self, txos: List[OutputEffectiveAmountEstimator],
available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: available: int) -> List[OutputEffectiveAmountEstimator]:
return ( return (
self.only_confirmed(txos, available) or self.only_confirmed(txos, available) or
self.standard(txos, available) self.standard(txos, available)
) )
@strategy @strategy
def only_confirmed(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], def only_confirmed(self, txos: List[OutputEffectiveAmountEstimator],
_) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: _) -> List[OutputEffectiveAmountEstimator]:
confirmed = [t for t in txos if t.txo.tx_ref and t.txo.tx_ref.height > 0] confirmed = [t for t in txos if t.txo.tx_ref and t.txo.tx_ref.height > 0]
if not confirmed: if not confirmed:
return [] return []
@ -54,8 +54,8 @@ class CoinSelector:
return self.standard(confirmed, confirmed_available) return self.standard(confirmed, confirmed_available)
@strategy @strategy
def standard(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], def standard(self, txos: List[OutputEffectiveAmountEstimator],
available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: available: int) -> List[OutputEffectiveAmountEstimator]:
return ( return (
self.branch_and_bound(txos, available) or self.branch_and_bound(txos, available) or
self.closest_match(txos, available) or self.closest_match(txos, available) or
@ -63,8 +63,8 @@ class CoinSelector:
) )
@strategy @strategy
def branch_and_bound(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], def branch_and_bound(self, txos: List[OutputEffectiveAmountEstimator],
available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: available: int) -> List[OutputEffectiveAmountEstimator]:
# see bitcoin implementation for more info: # see bitcoin implementation for more info:
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp
@ -123,8 +123,8 @@ class CoinSelector:
return [] return []
@strategy @strategy
def closest_match(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], def closest_match(self, txos: List[OutputEffectiveAmountEstimator],
_) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: _) -> List[OutputEffectiveAmountEstimator]:
""" Pick one UTXOs that is larger than the target but with the smallest change. """ """ Pick one UTXOs that is larger than the target but with the smallest change. """
target = self.target + self.cost_of_change target = self.target + self.cost_of_change
smallest_change = None smallest_change = None
@ -137,8 +137,8 @@ class CoinSelector:
return [best_match] if best_match else [] return [best_match] if best_match else []
@strategy @strategy
def random_draw(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], def random_draw(self, txos: List[OutputEffectiveAmountEstimator],
_) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: _) -> List[OutputEffectiveAmountEstimator]:
""" Accumulate UTXOs at random until there is enough to cover the target. """ """ Accumulate UTXOs at random until there is enough to cover the target. """
target = self.target + self.cost_of_change target = self.target + self.cost_of_change
self.random.shuffle(txos, self.random.random) self.random.shuffle(txos, self.random.random)

View file

@ -1,3 +1,10 @@
NULL_HASH32 = b'\x00'*32
CENT = 1000000
COIN = 100*CENT
TIMEOUT = 30.0
TXO_TYPES = { TXO_TYPES = {
"stream": 1, "stream": 1,
"channel": 2, "channel": 2,

View file

@ -1,14 +1,321 @@
from typing import List import logging
import asyncio
import sqlite3
from lbry.wallet.client.basedatabase import BaseDatabase from binascii import hexlify
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
from lbry.wallet.transaction import Output from .bip32 import PubKey
from lbry.wallet.constants import TXO_TYPES, CLAIM_TYPES from .transaction import Transaction, Output, OutputScript, TXRefImmutable
from .constants import TXO_TYPES, CLAIM_TYPES
class WalletDatabase(BaseDatabase): log = logging.getLogger(__name__)
sqlite3.enable_callback_tracebacks(True)
SCHEMA_VERSION = f"{BaseDatabase.SCHEMA_VERSION}+1"
class AIOSQLite:
def __init__(self):
# has to be single threaded as there is no mapping of thread:connection
self.executor = ThreadPoolExecutor(max_workers=1)
self.connection: sqlite3.Connection = None
self._closing = False
self.query_count = 0
@classmethod
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
sqlite3.enable_callback_tracebacks(True)
def _connect():
return sqlite3.connect(path, *args, **kwargs)
db = cls()
db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect)
return db
async def close(self):
if self._closing:
return
self._closing = True
await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close)
self.executor.shutdown(wait=True)
self.connection = None
def executemany(self, sql: str, params: Iterable):
params = params if params is not None else []
# this fetchall is needed to prevent SQLITE_MISUSE
return self.run(lambda conn: conn.executemany(sql, params).fetchall())
def executescript(self, script: str) -> Awaitable:
return self.run(lambda conn: conn.executescript(script))
def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters))
def run(self, fun, *args, **kwargs) -> Awaitable:
return asyncio.get_event_loop().run_in_executor(
self.executor, lambda: self.__run_transaction(fun, *args, **kwargs)
)
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
self.connection.execute('begin')
try:
self.query_count += 1
result = fun(self.connection, *args, **kwargs) # type: ignore
self.connection.commit()
return result
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
self.connection.rollback()
log.warning("rolled back")
raise
def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
return asyncio.get_event_loop().run_in_executor(
self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
)
def __run_transaction_with_foreign_keys_disabled(self,
fun: Callable[[sqlite3.Connection, Any, Any], Any],
args, kwargs):
foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone()
if not foreign_keys_enabled:
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
try:
self.connection.execute('pragma foreign_keys=off').fetchone()
return self.__run_transaction(fun, *args, **kwargs)
finally:
self.connection.execute('pragma foreign_keys=on').fetchone()
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
sql, values = [], {}
for key, constraint in constraints.items():
tag = '0'
if '#' in key:
key, tag = key[:key.index('#')], key[key.index('#')+1:]
col, op, key = key, '=', key.replace('.', '_')
if not key:
sql.append(constraint)
continue
if key.startswith('$'):
values[key] = constraint
continue
if key.endswith('__not'):
col, op = col[:-len('__not')], '!='
elif key.endswith('__is_null'):
col = col[:-len('__is_null')]
sql.append(f'{col} IS NULL')
continue
if key.endswith('__is_not_null'):
col = col[:-len('__is_not_null')]
sql.append(f'{col} IS NOT NULL')
continue
if key.endswith('__lt'):
col, op = col[:-len('__lt')], '<'
elif key.endswith('__lte'):
col, op = col[:-len('__lte')], '<='
elif key.endswith('__gt'):
col, op = col[:-len('__gt')], '>'
elif key.endswith('__gte'):
col, op = col[:-len('__gte')], '>='
elif key.endswith('__like'):
col, op = col[:-len('__like')], 'LIKE'
elif key.endswith('__not_like'):
col, op = col[:-len('__not_like')], 'NOT LIKE'
elif key.endswith('__in') or key.endswith('__not_in'):
if key.endswith('__in'):
col, op = col[:-len('__in')], 'IN'
else:
col, op = col[:-len('__not_in')], 'NOT IN'
if constraint:
if isinstance(constraint, (list, set, tuple)):
keys = []
for i, val in enumerate(constraint):
keys.append(f':{key}{tag}_{i}')
values[f'{key}{tag}_{i}'] = val
sql.append(f'{col} {op} ({", ".join(keys)})')
elif isinstance(constraint, str):
sql.append(f'{col} {op} ({constraint})')
else:
raise ValueError(f"{col} requires a list, set or string as constraint value.")
continue
elif key.endswith('__any') or key.endswith('__or'):
where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_')
sql.append(f'({where})')
values.update(subvalues)
continue
if key.endswith('__and'):
where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_')
sql.append(f'({where})')
values.update(subvalues)
continue
sql.append(f'{col} {op} :{prepend_key}{key}{tag}')
values[prepend_key+key+tag] = constraint
return joiner.join(sql) if sql else '', values
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
sql = [select]
limit = constraints.pop('limit', None)
offset = constraints.pop('offset', None)
order_by = constraints.pop('order_by', None)
accounts = constraints.pop('accounts', [])
if accounts:
constraints['account__in'] = [a.public_key.address for a in accounts]
where, values = constraints_to_sql(constraints)
if where:
sql.append('WHERE')
sql.append(where)
if order_by:
sql.append('ORDER BY')
if isinstance(order_by, str):
sql.append(order_by)
elif isinstance(order_by, list):
sql.append(', '.join(order_by))
else:
raise ValueError("order_by must be string or list")
if limit is not None:
sql.append(f'LIMIT {limit}')
if offset is not None:
sql.append(f'OFFSET {offset}')
return ' '.join(sql), values
def interpolate(sql, values):
for k in sorted(values.keys(), reverse=True):
value = values[k]
if isinstance(value, bytes):
value = f"X'{hexlify(value).decode()}'"
elif isinstance(value, str):
value = f"'{value}'"
else:
value = str(value)
sql = sql.replace(f":{k}", value)
return sql
def rows_to_dict(rows, fields):
if rows:
return [dict(zip(fields, r)) for r in rows]
else:
return []
class SQLiteMixin:
SCHEMA_VERSION: Optional[str] = None
CREATE_TABLES_QUERY: str
MAX_QUERY_VARIABLES = 900
CREATE_VERSION_TABLE = """
create table if not exists version (
version text
);
"""
def __init__(self, path):
self._db_path = path
self.db: AIOSQLite = None
self.ledger = None
async def open(self):
log.info("connecting to database: %s", self._db_path)
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
if self.SCHEMA_VERSION:
tables = [t[0] for t in await self.db.execute_fetchall(
"SELECT name FROM sqlite_master WHERE type='table';"
)]
if tables:
if 'version' in tables:
version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
if version == (self.SCHEMA_VERSION,):
return
await self.db.executescript('\n'.join(
f"DROP TABLE {table};" for table in tables
))
await self.db.execute(self.CREATE_VERSION_TABLE)
await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
await self.db.executescript(self.CREATE_TABLES_QUERY)
async def close(self):
await self.db.close()
@staticmethod
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
replace: bool = False) -> Tuple[str, List]:
columns, values = [], []
for column, value in data.items():
columns.append(column)
values.append(value)
policy = ""
if ignore_duplicate:
policy = " OR IGNORE"
if replace:
policy = " OR REPLACE"
sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
)
return sql, values
@staticmethod
def _update_sql(table: str, data: dict, where: str,
constraints: Union[list, tuple]) -> Tuple[str, list]:
columns, values = [], []
for column, value in data.items():
columns.append(f"{column} = ?")
values.append(value)
values.extend(constraints)
sql = "UPDATE {} SET {} WHERE {}".format(
table, ', '.join(columns), where
)
return sql, values
class Database(SQLiteMixin):
SCHEMA_VERSION = "1.1"
PRAGMAS = """
pragma journal_mode=WAL;
"""
CREATE_ACCOUNT_TABLE = """
create table if not exists account_address (
account text not null,
address text not null,
chain integer not null,
pubkey blob not null,
chain_code blob not null,
n integer not null,
depth integer not null,
primary key (account, address)
);
create index if not exists address_account_idx on account_address (address, account);
"""
CREATE_PUBKEY_ADDRESS_TABLE = """
create table if not exists pubkey_address (
address text primary key,
history text,
used_times integer not null default 0
);
"""
CREATE_TX_TABLE = """ CREATE_TX_TABLE = """
create table if not exists tx ( create table if not exists tx (
@ -42,25 +349,35 @@ class WalletDatabase(BaseDatabase):
create index if not exists txo_txo_type_idx on txo (txo_type); create index if not exists txo_txo_type_idx on txo (txo_type);
""" """
CREATE_TXI_TABLE = """
create table if not exists txi (
txid text references tx,
txoid text references txo,
address text references pubkey_address
);
create index if not exists txi_address_idx on txi (address);
create index if not exists txi_txoid_idx on txi (txoid);
"""
CREATE_TABLES_QUERY = ( CREATE_TABLES_QUERY = (
BaseDatabase.PRAGMAS + PRAGMAS +
BaseDatabase.CREATE_ACCOUNT_TABLE + CREATE_ACCOUNT_TABLE +
BaseDatabase.CREATE_PUBKEY_ADDRESS_TABLE + CREATE_PUBKEY_ADDRESS_TABLE +
CREATE_TX_TABLE + CREATE_TX_TABLE +
CREATE_TXO_TABLE + CREATE_TXO_TABLE +
BaseDatabase.CREATE_TXI_TABLE CREATE_TXI_TABLE
) )
def tx_to_row(self, tx): @staticmethod
row = super().tx_to_row(tx) def txo_to_row(tx, address, txo):
txos = tx.outputs row = {
if len(txos) >= 2 and txos[1].can_decode_purchase_data: 'txid': tx.id,
txos[0].purchase = txos[1] 'txoid': txo.id,
row['purchased_claim_id'] = txos[1].purchase_data.claim_id 'address': address,
return row 'position': txo.position,
'amount': txo.amount,
def txo_to_row(self, tx, address, txo): 'script': sqlite3.Binary(txo.script.source)
row = super().txo_to_row(tx, address, txo) }
if txo.is_claim: if txo.is_claim:
if txo.can_decode_claim: if txo.can_decode_claim:
row['txo_type'] = TXO_TYPES.get(txo.claim.claim_type, TXO_TYPES['stream']) row['txo_type'] = TXO_TYPES.get(txo.claim.claim_type, TXO_TYPES['stream'])
@ -76,39 +393,212 @@ class WalletDatabase(BaseDatabase):
row['claim_name'] = txo.claim_name row['claim_name'] = txo.claim_name
return row return row
async def get_transactions(self, **constraints): @staticmethod
txs = await super().get_transactions(**constraints) def tx_to_row(tx):
row = {
'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified
}
txos = tx.outputs
if len(txos) >= 2 and txos[1].can_decode_purchase_data:
txos[0].purchase = txos[1]
row['purchased_claim_id'] = txos[1].purchase_data.claim_id
return row
async def insert_transaction(self, tx):
await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))
async def update_transaction(self, tx):
await self.db.execute_fetchall(*self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash, history):
conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True))
for txo in tx.outputs:
if txo.script.is_pay_pubkey_hash and txo.pubkey_hash == txhash:
conn.execute(*self._insert_sql(
"txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
)).fetchall()
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
txo = txi.txo_ref.txo
if txo.has_address and txo.get_address(self.ledger) == address:
conn.execute(*self._insert_sql("txi", {
'txid': tx.id,
'txoid': txo.id,
'address': address,
}, ignore_duplicate=True)).fetchall()
conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':') // 2, address)
)
def save_transaction_io(self, tx: Transaction, address, txhash, history):
return self.db.run(self._transaction_io, tx, address, txhash, history)
def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history):
def __many(conn):
for tx in txs:
self._transaction_io(conn, tx, address, txhash, history)
return self.db.run(__many)
async def reserve_outputs(self, txos, is_reserved=True):
txoids = ((is_reserved, txo.id) for txo in txos)
await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids)
async def release_outputs(self, txos):
await self.reserve_outputs(txos, is_reserved=False)
async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
# TODO:
# 1. delete transactions above_height
# 2. update address histories removing deleted TXs
return True
async def select_transactions(self, cols, accounts=None, **constraints):
if not {'txid', 'txid__in'}.intersection(constraints):
assert accounts, "'accounts' argument required when no 'txid' constraint is present"
constraints.update({
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
})
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
where = f" WHERE account_address.account IN ({account_values})"
constraints['txid__in'] = f"""
SELECT txo.txid FROM txo JOIN account_address USING (address) {where}
UNION
SELECT txi.txid FROM txi JOIN account_address USING (address) {where}
"""
return await self.db.execute_fetchall(
*query(f"SELECT {cols} FROM tx", **constraints)
)
async def get_transactions(self, wallet=None, **constraints):
tx_rows = await self.select_transactions(
'txid, raw, height, position, is_verified',
order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
**constraints
)
if not tx_rows:
return []
txids, txs, txi_txoids = [], [], []
for row in tx_rows:
txids.append(row[0])
txs.append(Transaction(
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
))
for txi in txs[-1].inputs:
txi_txoids.append(txi.txo_ref.id)
step = self.MAX_QUERY_VARIABLES
annotated_txos = {}
for offset in range(0, len(txids), step):
annotated_txos.update({
txo.id: txo for txo in
(await self.get_txos(
wallet=wallet,
txid__in=txids[offset:offset+step],
))
})
referenced_txos = {}
for offset in range(0, len(txi_txoids), step):
referenced_txos.update({
txo.id: txo for txo in
(await self.get_txos(
wallet=wallet,
txoid__in=txi_txoids[offset:offset+step],
))
})
for tx in txs:
for txi in tx.inputs:
txo = referenced_txos.get(txi.txo_ref.id)
if txo:
txi.txo_ref = txo.ref
for txo in tx.outputs:
_txo = annotated_txos.get(txo.id)
if _txo:
txo.update_annotations(_txo)
else:
txo.update_annotations(None)
for tx in txs: for tx in txs:
txos = tx.outputs txos = tx.outputs
if len(txos) >= 2 and txos[1].can_decode_purchase_data: if len(txos) >= 2 and txos[1].can_decode_purchase_data:
txos[0].purchase = txos[1] txos[0].purchase = txos[1]
return txs return txs
@staticmethod async def get_transaction_count(self, **constraints):
def constrain_purchases(constraints): constraints.pop('wallet', None)
accounts = constraints.pop('accounts', None) constraints.pop('offset', None)
assert accounts, "'accounts' argument required to find purchases" constraints.pop('limit', None)
if not {'purchased_claim_id', 'purchased_claim_id__in'}.intersection(constraints): constraints.pop('order_by', None)
constraints['purchased_claim_id__is_not_null'] = True count = await self.select_transactions('count(*)', **constraints)
constraints.update({ return count[0][0]
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
})
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
constraints['txid__in'] = f"""
SELECT txid FROM txi JOIN account_address USING (address)
WHERE account_address.account IN ({account_values})
"""
async def get_purchases(self, **constraints): async def get_transaction(self, **constraints):
self.constrain_purchases(constraints) txs = await self.get_transactions(limit=1, **constraints)
return [tx.outputs[0] for tx in await self.get_transactions(**constraints)] if txs:
return txs[0]
def get_purchase_count(self, **constraints): async def select_txos(self, cols, **constraints):
self.constrain_purchases(constraints) sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)"
return self.get_transaction_count(**constraints) if 'accounts' in constraints:
sql += " JOIN account_address USING (address)"
return await self.db.execute_fetchall(*query(sql, **constraints))
async def get_txos(self, wallet=None, no_tx=False, **constraints) -> List[Output]: async def get_txos(self, wallet=None, no_tx=False, **constraints):
txos = await super().get_txos(wallet=wallet, no_tx=no_tx, **constraints) my_accounts = {a.public_key.address for a in wallet.accounts} if wallet else set()
if 'order_by' not in constraints:
constraints['order_by'] = [
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
]
rows = await self.select_txos(
"""
tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, (
select group_concat(account||"|"||chain) from account_address
where account_address.address=txo.address
)
""",
**constraints
)
txos = []
txs = {}
for row in rows:
if no_tx:
txo = Output(
amount=row[6],
script=OutputScript(row[7]),
tx_ref=TXRefImmutable.from_id(row[0], row[2]),
position=row[5]
)
else:
if row[0] not in txs:
txs[row[0]] = Transaction(
row[1], height=row[2], position=row[3], is_verified=row[4]
)
txo = txs[row[0]].outputs[row[5]]
row_accounts = dict(a.split('|') for a in row[8].split(','))
account_match = set(row_accounts) & my_accounts
if account_match:
txo.is_my_account = True
txo.is_change = row_accounts[account_match.pop()] == '1'
else:
txo.is_change = txo.is_my_account = False
txos.append(txo)
channel_ids = set() channel_ids = set()
for txo in txos: for txo in txos:
@ -138,6 +628,112 @@ class WalletDatabase(BaseDatabase):
return txos return txos
async def get_txo_count(self, **constraints):
constraints.pop('wallet', None)
constraints.pop('offset', None)
constraints.pop('limit', None)
constraints.pop('order_by', None)
count = await self.select_txos('count(*)', **constraints)
return count[0][0]
@staticmethod
def constrain_utxo(constraints):
constraints['is_reserved'] = False
constraints['txoid__not_in'] = "SELECT txoid FROM txi"
def get_utxos(self, **constraints):
self.constrain_utxo(constraints)
return self.get_txos(**constraints)
def get_utxo_count(self, **constraints):
self.constrain_utxo(constraints)
return self.get_txo_count(**constraints)
async def get_balance(self, wallet=None, accounts=None, **constraints):
assert wallet or accounts, \
"'wallet' or 'accounts' constraints required to calculate balance"
constraints['accounts'] = accounts or wallet.accounts
self.constrain_utxo(constraints)
balance = await self.select_txos('SUM(amount)', **constraints)
return balance[0][0] or 0
async def select_addresses(self, cols, **constraints):
return await self.db.execute_fetchall(*query(
f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)",
**constraints
))
async def get_addresses(self, cols=None, **constraints):
cols = cols or (
'address', 'account', 'chain', 'history', 'used_times',
'pubkey', 'chain_code', 'n', 'depth'
)
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols)
if 'pubkey' in cols:
for address in addresses:
address['pubkey'] = PubKey(
self.ledger, address.pop('pubkey'), address.pop('chain_code'),
address.pop('n'), address.pop('depth')
)
return addresses
async def get_address_count(self, cols=None, **constraints):
count = await self.select_addresses('count(*)', **constraints)
return count[0][0]
async def get_address(self, **constraints):
addresses = await self.get_addresses(limit=1, **constraints)
if addresses:
return addresses[0]
async def add_keys(self, account, chain, pubkeys):
await self.db.executemany(
"insert or ignore into account_address "
"(account, address, chain, pubkey, chain_code, n, depth) values "
"(?, ?, ?, ?, ?, ?, ?)", ((
account.id, k.address, chain,
sqlite3.Binary(k.pubkey_bytes),
sqlite3.Binary(k.chain_code),
k.n, k.depth
) for k in pubkeys)
)
await self.db.executemany(
"insert or ignore into pubkey_address (address) values (?)",
((pubkey.address,) for pubkey in pubkeys)
)
async def _set_address_history(self, address, history):
await self.db.execute_fetchall(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address)
)
async def set_address_history(self, address, history):
await self._set_address_history(address, history)
@staticmethod
def constrain_purchases(constraints):
accounts = constraints.pop('accounts', None)
assert accounts, "'accounts' argument required to find purchases"
if not {'purchased_claim_id', 'purchased_claim_id__in'}.intersection(constraints):
constraints['purchased_claim_id__is_not_null'] = True
constraints.update({
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
})
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
constraints['txid__in'] = f"""
SELECT txid FROM txi JOIN account_address USING (address)
WHERE account_address.account IN ({account_values})
"""
async def get_purchases(self, **constraints):
self.constrain_purchases(constraints)
return [tx.outputs[0] for tx in await self.get_transactions(**constraints)]
def get_purchase_count(self, **constraints):
self.constrain_purchases(constraints)
return self.get_transaction_count(**constraints)
@staticmethod @staticmethod
def constrain_claims(constraints): def constrain_claims(constraints):
claim_type = constraints.pop('claim_type', None) claim_type = constraints.pop('claim_type', None)

View file

@ -1,5 +1,5 @@
import textwrap import textwrap
from lbry.wallet.client.util import coins_to_satoshis, satoshis_to_coins from .util import coins_to_satoshis, satoshis_to_coins
def lbc_to_dewies(lbc: str) -> int: def lbc_to_dewies(lbc: str) -> int:

View file

@ -1,5 +1,5 @@
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from lbry.wallet.client.constants import NULL_HASH32 from .constants import NULL_HASH32
class TXRef: class TXRef:

View file

@ -1,10 +1,252 @@
import os
import struct import struct
from typing import Optional import asyncio
import hashlib
import logging
from io import BytesIO
from contextlib import asynccontextmanager
from typing import Optional, Iterator, Tuple
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from lbry.crypto.hash import sha512, double_sha256, ripemd160 from lbry.crypto.hash import sha512, double_sha256, ripemd160
from lbry.wallet.client.baseheader import BaseHeaders from lbry.wallet.util import ArithUint256
from lbry.wallet.client.util import ArithUint256
log = logging.getLogger(__name__)
class InvalidHeader(Exception):
def __init__(self, height, message):
super().__init__(message)
self.message = message
self.height = height
class BaseHeaders:
header_size: int
chunk_size: int
max_target: int
genesis_hash: Optional[bytes]
target_timespan: int
validate_difficulty: bool = True
checkpoint = None
def __init__(self, path) -> None:
if path == ':memory:':
self.io = BytesIO()
self.path = path
self._size: Optional[int] = None
async def open(self):
if self.path != ':memory:':
if not os.path.exists(self.path):
self.io = open(self.path, 'w+b')
else:
self.io = open(self.path, 'r+b')
async def close(self):
self.io.close()
@staticmethod
def serialize(header: dict) -> bytes:
raise NotImplementedError
@staticmethod
def deserialize(height, header):
raise NotImplementedError
def get_next_chunk_target(self, chunk: int) -> ArithUint256:
return ArithUint256(self.max_target)
@staticmethod
def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict],
current: Optional[dict]) -> ArithUint256:
return chunk_target
def __len__(self) -> int:
if self._size is None:
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
return self._size
def __bool__(self):
return True
def __getitem__(self, height) -> dict:
if isinstance(height, slice):
raise NotImplementedError("Slicing of header chain has not been implemented yet.")
if not 0 <= height <= self.height:
raise IndexError(f"{height} is out of bounds, current height: {self.height}")
return self.deserialize(height, self.get_raw_header(height))
def get_raw_header(self, height) -> bytes:
self.io.seek(height * self.header_size, os.SEEK_SET)
return self.io.read(self.header_size)
@property
def height(self) -> int:
return len(self)-1
@property
def bytes_size(self):
return len(self) * self.header_size
def hash(self, height=None) -> bytes:
return self.hash_header(
self.get_raw_header(height if height is not None else self.height)
)
@staticmethod
def hash_header(header: bytes) -> bytes:
if header is None:
return b'0' * 64
return hexlify(double_sha256(header)[::-1])
@asynccontextmanager
async def checkpointed_connector(self):
buf = BytesIO()
try:
yield buf
finally:
await asyncio.sleep(0)
final_height = len(self) + buf.tell() // self.header_size
verifiable_bytes = (self.checkpoint[0] - len(self)) * self.header_size if self.checkpoint else 0
if verifiable_bytes > 0 and final_height >= self.checkpoint[0]:
buf.seek(0)
self.io.seek(0)
h = hashlib.sha256()
h.update(self.io.read())
h.update(buf.read(verifiable_bytes))
if h.hexdigest().encode() == self.checkpoint[1]:
buf.seek(0)
self._write(len(self), buf.read(verifiable_bytes))
remaining = buf.read()
buf.seek(0)
buf.write(remaining)
buf.truncate()
else:
log.warning("Checkpoint mismatch, connecting headers through slow method.")
if buf.tell() > 0:
await self.connect(len(self), buf.getvalue())
async def connect(self, start: int, headers: bytes) -> int:
added = 0
bail = False
for height, chunk in self._iterate_chunks(start, headers):
try:
# validate_chunk() is CPU bound and reads previous chunks from file system
self.validate_chunk(height, chunk)
except InvalidHeader as e:
bail = True
chunk = chunk[:(height-e.height)*self.header_size]
added += self._write(height, chunk) if chunk else 0
if bail:
break
return added
def _write(self, height, verified_chunk):
self.io.seek(height * self.header_size, os.SEEK_SET)
written = self.io.write(verified_chunk) // self.header_size
self.io.truncate()
# .seek()/.write()/.truncate() might also .flush() when needed
# the goal here is mainly to ensure we're definitely flush()'ing
self.io.flush()
self._size = self.io.tell() // self.header_size
return written
def validate_chunk(self, height, chunk):
previous_hash, previous_header, previous_previous_header = None, None, None
if height > 0:
previous_header = self[height-1]
previous_hash = self.hash(height-1)
if height > 1:
previous_previous_header = self[height-2]
chunk_target = self.get_next_chunk_target(height // 2016 - 1)
for current_hash, current_header in self._iterate_headers(height, chunk):
block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
self.validate_header(height, current_hash, current_header, previous_hash, block_target)
previous_previous_header = previous_header
previous_header = current_header
previous_hash = current_hash
def validate_header(self, height: int, current_hash: bytes,
header: dict, previous_hash: bytes, target: ArithUint256):
if previous_hash is None:
if self.genesis_hash is not None and self.genesis_hash != current_hash:
raise InvalidHeader(
height, f"genesis header doesn't match: {current_hash.decode()} "
f"vs expected {self.genesis_hash.decode()}")
return
if header['prev_block_hash'] != previous_hash:
raise InvalidHeader(
height, "previous hash mismatch: {} vs expected {}".format(
header['prev_block_hash'].decode(), previous_hash.decode())
)
if self.validate_difficulty:
if header['bits'] != target.compact:
raise InvalidHeader(
height, "bits mismatch: {} vs expected {}".format(
header['bits'], target.compact)
)
proof_of_work = self.get_proof_of_work(current_hash)
if proof_of_work > target:
raise InvalidHeader(
height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}"
)
async def repair(self):
previous_header_hash = fail = None
batch_size = 36
for start_height in range(0, self.height, batch_size):
self.io.seek(self.header_size * start_height)
headers = self.io.read(self.header_size*batch_size)
if len(headers) % self.header_size != 0:
headers = headers[:(len(headers) // self.header_size) * self.header_size]
for header_hash, header in self._iterate_headers(start_height, headers):
height = header['block_height']
if height:
if header['prev_block_hash'] != previous_header_hash:
fail = True
else:
if header_hash != self.genesis_hash:
fail = True
if fail:
log.warning("Header file corrupted at height %s, truncating it.", height - 1)
self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET)
self.io.truncate()
self.io.flush()
self._size = None
return
previous_header_hash = header_hash
@staticmethod
def get_proof_of_work(header_hash: bytes) -> ArithUint256:
return ArithUint256(int(b'0x' + header_hash, 16))
def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
assert len(headers) % self.header_size == 0, f"{len(headers)} {len(headers)%self.header_size}"
start = 0
end = (self.chunk_size - height % self.chunk_size) * self.header_size
while start < end:
yield height + (start // self.header_size), headers[start:end]
start = end
end = min(len(headers), end + self.chunk_size * self.header_size)
def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
assert len(headers) % self.header_size == 0, len(headers)
for idx in range(len(headers) // self.header_size):
start, end = idx * self.header_size, (idx + 1) * self.header_size
header = headers[start:end]
yield self.hash_header(header), self.deserialize(height+idx, header)
class Headers(BaseHeaders): class Headers(BaseHeaders):

View file

@ -1,41 +1,96 @@
import os
import zlib
import pylru
import base64
import asyncio import asyncio
import logging import logging
from binascii import unhexlify
from functools import partial
from typing import Tuple, List
from datetime import datetime
import pylru from io import StringIO
from lbry.wallet.client.baseledger import BaseLedger, TransactionEvent from datetime import datetime
from lbry.wallet.client.baseaccount import SingleKey from functools import partial
from operator import itemgetter
from collections import namedtuple
from binascii import hexlify, unhexlify
from typing import Dict, Tuple, Type, Iterable, List, Optional
from lbry.schema.result import Outputs from lbry.schema.result import Outputs
from lbry.schema.url import URL from lbry.schema.url import URL
from lbry.wallet.dewies import dewies_to_lbc from lbry.crypto.hash import hash160, double_sha256, sha256
from lbry.wallet.account import Account from lbry.crypto.base58 import Base58
from lbry.wallet.network import Network
from lbry.wallet.database import WalletDatabase from .tasks import TaskGroup
from lbry.wallet.transaction import Transaction, Output from .database import Database
from lbry.wallet.header import Headers, UnvalidatedHeaders from .stream import StreamController
from lbry.wallet.constants import TXO_TYPES from .dewies import dewies_to_lbc
from .account import Account, AddressManager, SingleKey
from .network import Network
from .transaction import Transaction, Output
from .header import Headers, UnvalidatedHeaders
from .constants import TXO_TYPES, COIN, NULL_HASH32
from .bip32 import PubKey, PrivateKey
from .coinselection import CoinSelector
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
LedgerType = Type['BaseLedger']
class MainNetLedger(BaseLedger):
class LedgerRegistry(type):
ledgers: Dict[str, LedgerType] = {}
def __new__(mcs, name, bases, attrs):
cls: LedgerType = super().__new__(mcs, name, bases, attrs)
if not (name == 'BaseLedger' and not bases):
ledger_id = cls.get_id()
assert ledger_id not in mcs.ledgers, \
f'Ledger with id "{ledger_id}" already registered.'
mcs.ledgers[ledger_id] = cls
return cls
@classmethod
def get_ledger_class(mcs, ledger_id: str) -> LedgerType:
return mcs.ledgers[ledger_id]
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
pass
class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))):
pass
class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
pass
class TransactionCacheItem:
__slots__ = '_tx', 'lock', 'has_tx'
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
self.has_tx = asyncio.Event()
self.lock = lock or asyncio.Lock()
self._tx = self.tx = tx
@property
def tx(self) -> Optional[Transaction]:
return self._tx
@tx.setter
def tx(self, tx: Transaction):
self._tx = tx
if tx is not None:
self.has_tx.set()
class Ledger(metaclass=LedgerRegistry):
name = 'LBRY Credits' name = 'LBRY Credits'
symbol = 'LBC' symbol = 'LBC'
network_name = 'mainnet' network_name = 'mainnet'
headers: Headers
account_class = Account
database_class = WalletDatabase
headers_class = Headers headers_class = Headers
network_class = Network
transaction_class = Transaction
db: WalletDatabase
secret_prefix = bytes((0x1c,)) secret_prefix = bytes((0x1c,))
pubkey_address_prefix = bytes((0x55,)) pubkey_address_prefix = bytes((0x55,))
@ -51,11 +106,522 @@ class MainNetLedger(BaseLedger):
default_fee_per_byte = 50 default_fee_per_byte = 50
default_fee_per_name_char = 200000 default_fee_per_name_char = 200000
def __init__(self, *args, **kwargs): def __init__(self, config=None):
super().__init__(*args, **kwargs) self.config = config or {}
self.db: Database = self.config.get('db') or Database(
os.path.join(self.path, "blockchain.db")
)
self.db.ledger = self
self.headers: Headers = self.config.get('headers') or self.headers_class(
os.path.join(self.path, "headers")
)
self.network: Network = self.config.get('network') or Network(self)
self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status_update)
self.network.on_connected.listen(self.join_network)
self.accounts = []
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
self._on_transaction_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream
self.on_transaction.listen(
lambda e: log.info(
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id
)
)
self._on_address_controller = StreamController()
self.on_address = self._on_address_controller.stream
self.on_address.listen(
lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses)
)
self._on_header_controller = StreamController()
self.on_header = self._on_header_controller.stream
self.on_header.listen(
lambda change: log.info(
'%s: added %s header blocks, final height %s',
self.get_id(), change, self.headers.height
)
)
self._download_height = 0
self._on_ready_controller = StreamController()
self.on_ready = self._on_ready_controller.stream
self._tx_cache = pylru.lrucache(100000)
self._update_tasks = TaskGroup()
self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock()
self._address_update_locks: Dict[str, asyncio.Lock] = {}
self.coin_selection_strategy = None
self._known_addresses_out_of_sync = set()
self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char) self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char)
self._balance_cache = pylru.lrucache(100000) self._balance_cache = pylru.lrucache(100000)
@classmethod
def get_id(cls):
return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower())
@classmethod
def hash160_to_address(cls, h160):
raw_address = cls.pubkey_address_prefix + h160
return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4]))
@staticmethod
def address_to_hash160(address):
return Base58.decode(address)[1:21]
@classmethod
def is_valid_address(cls, address):
decoded = Base58.decode_check(address)
return decoded[0] == cls.pubkey_address_prefix[0]
@classmethod
def public_key_to_address(cls, public_key):
return cls.hash160_to_address(hash160(public_key))
@staticmethod
def private_key_to_wif(private_key):
return b'\x1c' + private_key + b'\x01'
@property
def path(self):
return os.path.join(self.config['data_path'], self.get_id())
def add_account(self, account: Account):
self.accounts.append(account)
async def _get_account_and_address_info_for_address(self, wallet, address):
match = await self.db.get_address(accounts=wallet.accounts, address=address)
if match:
for account in wallet.accounts:
if match['account'] == account.public_key.address:
return account, match
async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
match = await self._get_account_and_address_info_for_address(wallet, address)
if match:
account, address_info = match
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
return None
async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(wallet, address)
if match:
_, address_info = match
return address_info['pubkey']
return None
async def get_account_for_address(self, wallet, address):
match = await self._get_account_and_address_info_for_address(wallet, address)
if match:
return match[0]
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
estimators = []
for account in funding_accounts:
utxos = await account.get_utxos()
for utxo in utxos:
estimators.append(utxo.get_estimator(self))
return estimators
async def get_addresses(self, **constraints):
return await self.db.get_addresses(**constraints)
def get_address_count(self, **constraints):
return self.db.get_address_count(**constraints)
async def get_spendable_utxos(self, amount: int, funding_accounts):
async with self._utxo_reservation_lock:
txos = await self.get_effective_amount_estimators(funding_accounts)
fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
selector = CoinSelector(amount, fee)
spendables = selector.select(txos, self.coin_selection_strategy)
if spendables:
await self.reserve_outputs(s.txo for s in spendables)
return spendables
def reserve_outputs(self, txos):
return self.db.reserve_outputs(txos)
def release_outputs(self, txos):
return self.db.release_outputs(txos)
def release_tx(self, tx):
return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
def get_utxos(self, **constraints):
self.constraint_spending_utxos(constraints)
return self.db.get_utxos(**constraints)
def get_utxo_count(self, **constraints):
self.constraint_spending_utxos(constraints)
return self.db.get_utxo_count(**constraints)
def get_transactions(self, **constraints):
return self.db.get_transactions(**constraints)
def get_transaction_count(self, **constraints):
return self.db.get_transaction_count(**constraints)
async def get_local_status_and_history(self, address, history=None):
if not history:
address_details = await self.db.get_address(address=address)
history = address_details['history'] or ''
parts = history.split(':')[:-1]
return (
hexlify(sha256(history.encode())).decode() if history else None,
list(zip(parts[0::2], map(int, parts[1::2])))
)
@staticmethod
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
for i, branch in enumerate(branches):
other_branch = unhexlify(branch)[::-1]
other_branch_on_left = bool((branch_positions >> i) & 1)
if other_branch_on_left:
combined = other_branch + working_branch
else:
combined = working_branch + other_branch
working_branch = double_sha256(combined)
return hexlify(working_branch[::-1])
async def start(self):
if not os.path.exists(self.path):
os.mkdir(self.path)
await asyncio.wait([
self.db.open(),
self.headers.open()
])
first_connection = self.network.on_connected.first
asyncio.ensure_future(self.network.start())
await first_connection
async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync())
await self._on_ready_controller.stream.first
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
if len(self.accounts) > 10:
log.info("Loaded %i accounts", len(self.accounts))
else:
await self._report_state()
self.on_transaction.listen(self._reset_balance_cache)
async def join_network(self, *_):
log.info("Subscribing and updating accounts.")
async with self._header_processing_lock:
await self.update_headers()
await self.subscribe_accounts()
await self._update_tasks.done.wait()
self._on_ready_controller.add(True)
async def stop(self):
self._update_tasks.cancel()
await self._update_tasks.done.wait()
await self.network.stop()
await self.db.close()
await self.headers.close()
@property
def local_height_including_downloaded_height(self):
return max(self.headers.height, self._download_height)
async def initial_headers_sync(self):
target = self.network.remote_height + 1
current = len(self.headers)
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
total = 0
async with self.headers.checkpointed_connector() as buffer:
for chunk in chunks:
headers = await chunk
total += buffer.write(
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
)
self._download_height = current + total // self.headers.header_size
log.info("Headers sync: %s / %s", self._download_height, target)
async def update_headers(self, height=None, headers=None, subscription_update=False):
rewound = 0
while True:
if height is None or height > len(self.headers):
# sometimes header subscription updates are for a header in the future
# which can't be connected, so we do a normal header sync instead
height = len(self.headers)
headers = None
subscription_update = False
if not headers:
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
headers = header_response['hex']
if not headers:
# Nothing to do, network thinks we're already at the latest height.
return
added = await self.headers.connect(height, unhexlify(headers))
if added > 0:
height += added
self._on_header_controller.add(
BlockHeightEvent(self.headers.height, added))
if rewound > 0:
# we started rewinding blocks and apparently found
# a new chain
rewound = 0
await self.db.rewind_blockchain(height)
if subscription_update:
# subscription updates are for latest header already
# so we don't need to check if there are newer / more
# on another loop of update_headers(), just return instead
return
elif added == 0:
# we had headers to connect but none got connected, probably a reorganization
height -= 1
rewound += 1
log.warning(
"Blockchain Reorganization: attempting rewind to height %s from starting height %s",
height, height+rewound
)
else:
raise IndexError(f"headers.connect() returned negative number ({added})")
if height < 0:
raise IndexError(
"Blockchain reorganization rewound all the way back to genesis hash. "
"Something is very wrong. Maybe you are on the wrong blockchain?"
)
if rewound >= 100:
raise IndexError(
"Blockchain reorganization dropped {} headers. This is highly unusual. "
"Will not continue to attempt reorganizing. Please, delete the ledger "
"synchronization directory inside your wallet directory (folder: '{}') and "
"restart the program to synchronize from scratch."
.format(rewound, self.get_id())
)
headers = None # ready to download some more headers
# if we made it this far and this was a subscription_update
# it means something went wrong and now we're doing a more
# robust sync, turn off subscription update shortcut
subscription_update = False
async def receive_header(self, response):
async with self._header_processing_lock:
header = response[0]
await self.update_headers(
height=header['height'], headers=header['hex'], subscription_update=True
)
async def subscribe_accounts(self):
if self.network.is_connected and self.accounts:
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
async def subscribe_account(self, account: Account):
for address_manager in account.address_managers.values():
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
await account.ensure_address_gap()
async def unsubscribe_account(self, account: Account):
for address in await account.get_addresses():
await self.network.unsubscribe_address(address)
async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]):
await self.subscribe_addresses(address_manager, addresses)
await self._on_address_controller.add(
AddressesGeneratedEvent(address_manager, addresses)
)
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str]):
if self.network.is_connected and addresses:
await asyncio.wait([
self.subscribe_address(address_manager, address) for address in addresses
])
async def subscribe_address(self, address_manager: AddressManager, address: str):
remote_status = await self.network.subscribe_address(address)
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
def process_status_update(self, update):
address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status))
async def update_history(self, address, remote_status,
address_manager: AddressManager = None):
async with self._address_update_locks.setdefault(address, asyncio.Lock()):
self._known_addresses_out_of_sync.discard(address)
local_status, local_history = await self.get_local_status_and_history(address)
if local_status == remote_status:
return True
remote_history = await self.network.retriable_call(self.network.get_history, address)
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history)
if not we_need:
return True
cache_tasks: List[asyncio.Future[Transaction]] = []
synced_history = StringIO()
for i, (txid, remote_height) in enumerate(remote_history):
if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
synced_history.write(f'{txid}:{remote_height}:')
else:
check_local = (txid, remote_height) not in we_need
cache_tasks.append(asyncio.ensure_future(
self.cache_transaction(txid, remote_height, check_local=check_local)
))
synced_txs = []
for task in cache_tasks:
tx = await task
check_db_for_txos = []
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
if cache_item is not None:
if cache_item.tx is None:
await cache_item.has_tx.wait()
assert cache_item.tx is not None
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else:
check_db_for_txos.append(txi.txo_ref.id)
referenced_txos = {} if not check_db_for_txos else {
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos, no_tx=True)
}
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
referenced_txo = referenced_txos.get(txi.txo_ref.id)
if referenced_txo is not None:
txi.txo_ref = referenced_txo.ref
synced_history.write(f'{tx.id}:{tx.height}:')
synced_txs.append(tx)
await self.db.save_transaction_io_batch(
synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
)
await asyncio.wait([
self._on_transaction_controller.add(TransactionEvent(address, tx))
for tx in synced_txs
])
if address_manager is None:
address_manager = await self.get_address_manager_for_address(address)
if address_manager is not None:
await address_manager.ensure_address_gap()
local_status, local_history = \
await self.get_local_status_and_history(address, synced_history.getvalue())
if local_status != remote_status:
if local_history == remote_history:
return True
log.warning(
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
remote_status, len(remote_history), local_status, len(local_history)
)
log.warning("local: %s", local_history)
log.warning("remote: %s", remote_history)
self._known_addresses_out_of_sync.add(address)
return False
else:
return True
async def cache_transaction(self, txid, remote_height, check_local=True):
cache_item = self._tx_cache.get(txid)
if cache_item is None:
cache_item = self._tx_cache[txid] = TransactionCacheItem()
elif cache_item.tx is not None and \
cache_item.tx.height >= remote_height and \
(cache_item.tx.is_verified or remote_height < 1):
return cache_item.tx # cached tx is already up-to-date
async with cache_item.lock:
tx = cache_item.tx
if tx is None and check_local:
# check local db
tx = cache_item.tx = await self.db.get_transaction(txid=txid)
if tx is None:
# fetch from network
_raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height)
tx = Transaction(unhexlify(_raw))
cache_item.tx = tx # make sure it's saved before caching it
await self.maybe_verify_transaction(tx, remote_height)
return tx
async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height
if 0 < remote_height < len(self.headers):
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[remote_height]
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
details = await self.db.get_address(address=address)
for account in self.accounts:
if account.id == details['account']:
return account.address_managers[details['chain']]
return None
def broadcast(self, tx):
# broadcast can't be a retriable call yet
return self.network.broadcast(hexlify(tx.raw).decode())
async def wait(self, tx: Transaction, height=-1, timeout=1):
addresses = set()
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
addresses.add(
self.hash160_to_address(txi.txo_ref.txo.pubkey_hash)
)
for txo in tx.outputs:
if txo.has_address:
addresses.add(self.hash160_to_address(txo.pubkey_hash))
records = await self.db.get_addresses(address__in=addresses)
_, pending = await asyncio.wait([
self.on_transaction.where(partial(
lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,
address_record['address']
)) for address_record in records
], timeout=timeout)
if pending:
for record in records:
found = False
_, local_history = await self.get_local_status_and_history(None, history=record['history'])
for txid, local_height in local_history:
if txid == tx.id and local_height >= height:
found = True
if not found:
print(record['history'], addresses, tx.id)
raise asyncio.TimeoutError('Timed out waiting for transaction.')
async def _inflate_outputs(self, query, accounts): async def _inflate_outputs(self, query, accounts):
outputs = Outputs.from_base64(await query) outputs = Outputs.from_base64(await query)
txs = [] txs = []
@ -103,16 +669,6 @@ class MainNetLedger(BaseLedger):
for claim in (await self.claim_search(accounts, claim_id=claim_id))[0]: for claim in (await self.claim_search(accounts, claim_id=claim_id))[0]:
return claim return claim
async def start(self):
await super().start()
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
if len(self.accounts) > 10:
log.info("Loaded %i accounts", len(self.accounts))
else:
await self._report_state()
self.on_transaction.listen(self._reset_balance_cache)
async def _report_state(self): async def _report_state(self):
try: try:
for account in self.accounts: for account in self.accounts:
@ -147,14 +703,6 @@ class MainNetLedger(BaseLedger):
def constraint_spending_utxos(constraints): def constraint_spending_utxos(constraints):
constraints['txo_type__in'] = (0, TXO_TYPES['purchase']) constraints['txo_type__in'] = (0, TXO_TYPES['purchase'])
def get_utxos(self, **constraints):
self.constraint_spending_utxos(constraints)
return super().get_utxos(**constraints)
def get_utxo_count(self, **constraints):
self.constraint_spending_utxos(constraints)
return super().get_utxo_count(**constraints)
async def get_purchases(self, resolve=False, **constraints): async def get_purchases(self, resolve=False, **constraints):
purchases = await self.db.get_purchases(**constraints) purchases = await self.db.get_purchases(**constraints)
if resolve: if resolve:
@ -357,7 +905,7 @@ class MainNetLedger(BaseLedger):
return result return result
class TestNetLedger(MainNetLedger): class TestNetLedger(Ledger):
network_name = 'testnet' network_name = 'testnet'
pubkey_address_prefix = bytes((111,)) pubkey_address_prefix = bytes((111,))
script_address_prefix = bytes((196,)) script_address_prefix = bytes((196,))
@ -365,7 +913,7 @@ class TestNetLedger(MainNetLedger):
extended_private_key_prefix = unhexlify('04358394') extended_private_key_prefix = unhexlify('04358394')
class RegTestLedger(MainNetLedger): class RegTestLedger(Ledger):
network_name = 'regtest' network_name = 'regtest'
headers_class = UnvalidatedHeaders headers_class = UnvalidatedHeaders
pubkey_address_prefix = bytes((111,)) pubkey_address_prefix = bytes((111,))

View file

@ -1,43 +1,112 @@
import os import os
import json import json
import typing
import logging import logging
import asyncio
from binascii import unhexlify from binascii import unhexlify
from typing import Optional, List
from decimal import Decimal from decimal import Decimal
from typing import List, Type, MutableSequence, MutableMapping, Optional
from lbry.wallet.client.basemanager import BaseWalletManager
from lbry.wallet.client.wallet import ENCRYPT_ON_DISK
from lbry.wallet.rpc.jsonrpc import CodeMessageError
from lbry.error import KeyFeeAboveMaxAllowedError from lbry.error import KeyFeeAboveMaxAllowedError
from lbry.wallet.dewies import dewies_to_lbc from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
from lbry.wallet.account import Account
from lbry.wallet.ledger import MainNetLedger
from lbry.wallet.transaction import Transaction, Output
from lbry.wallet.database import WalletDatabase
from lbry.conf import Config from lbry.conf import Config
from .dewies import dewies_to_lbc
from .account import Account
from .ledger import Ledger, LedgerRegistry
from .transaction import Transaction, Output
from .database import Database
from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
from .rpc.jsonrpc import CodeMessageError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
if typing.TYPE_CHECKING: class WalletManager:
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
def __init__(self, wallets: MutableSequence[Wallet] = None,
class LbryWalletManager(BaseWalletManager): ledgers: MutableMapping[Type[Ledger], Ledger] = None) -> None:
self.wallets = wallets or []
def __init__(self, *args, **kwargs): self.ledgers = ledgers or {}
super().__init__(*args, **kwargs) self.running = False
self.config: Optional[Config] = None self.config: Optional[Config] = None
@classmethod
def from_config(cls, config: dict) -> 'WalletManager':
manager = cls()
for ledger_id, ledger_config in config.get('ledgers', {}).items():
manager.get_or_create_ledger(ledger_id, ledger_config)
for wallet_path in config.get('wallets', []):
wallet_storage = WalletStorage(wallet_path)
wallet = Wallet.from_storage(wallet_storage, manager)
manager.wallets.append(wallet)
return manager
def get_or_create_ledger(self, ledger_id, ledger_config=None):
ledger_class = LedgerRegistry.get_ledger_class(ledger_id)
ledger = self.ledgers.get(ledger_class)
if ledger is None:
ledger = ledger_class(ledger_config or {})
self.ledgers[ledger_class] = ledger
return ledger
def import_wallet(self, path):
storage = WalletStorage(path)
wallet = Wallet.from_storage(storage, self)
self.wallets.append(wallet)
return wallet
@property @property
def ledger(self) -> MainNetLedger: def default_wallet(self):
for wallet in self.wallets:
return wallet
@property
def default_account(self):
for wallet in self.wallets:
return wallet.default_account
@property
def accounts(self):
for wallet in self.wallets:
yield from wallet.accounts
async def start(self):
self.running = True
await asyncio.gather(*(
l.start() for l in self.ledgers.values()
))
async def stop(self):
await asyncio.gather(*(
l.stop() for l in self.ledgers.values()
))
self.running = False
def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet:
if wallet_id is None:
return self.default_wallet
return self.get_wallet_or_error(wallet_id)
def get_wallet_or_error(self, wallet_id: str) -> Wallet:
for wallet in self.wallets:
if wallet.id == wallet_id:
return wallet
raise ValueError(f"Couldn't find wallet: {wallet_id}.")
@staticmethod
def get_balance(wallet):
accounts = wallet.accounts
if not accounts:
return 0
return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts)
@property
def ledger(self) -> Ledger:
return self.default_account.ledger return self.default_account.ledger
@property @property
def db(self) -> WalletDatabase: def db(self) -> Database:
return self.ledger.db return self.ledger.db
def check_locked(self): def check_locked(self):
@ -194,7 +263,7 @@ class LbryWalletManager(BaseWalletManager):
if 'No such mempool or blockchain transaction.' in e.message: if 'No such mempool or blockchain transaction.' in e.message:
return {'success': False, 'code': 404, 'message': 'transaction not found'} return {'success': False, 'code': 404, 'message': 'transaction not found'}
return {'success': False, 'code': e.code, 'message': e.message} return {'success': False, 'code': e.code, 'message': e.message}
tx = self.ledger.transaction_class(unhexlify(raw)) tx = Transaction(unhexlify(raw))
await self.ledger.maybe_verify_transaction(tx, height) await self.ledger.maybe_verify_transaction(tx, height)
return tx return tx

View file

@ -13,7 +13,7 @@ from secrets import randbelow
import pbkdf2 import pbkdf2
from lbry.crypto.hash import hmac_sha512 from lbry.crypto.hash import hmac_sha512
from lbry.wallet.client.words import english from .words import english
# The hash of the mnemonic seed must begin with this # The hash of the mnemonic seed must begin with this
SEED_PREFIX = b'01' # Standard wallet SEED_PREFIX = b'01' # Standard wallet

View file

@ -1,9 +1,276 @@
import lbry import logging
from lbry.wallet.client.basenetwork import BaseNetwork import asyncio
from time import perf_counter
from operator import itemgetter
from typing import Dict, Optional, Tuple
from lbry import __version__
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController
log = logging.getLogger(__name__)
class Network(BaseNetwork): class ClientSession(BaseClientSession):
PROTOCOL_VERSION = lbry.__version__ def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network
self.server = server
super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
self.framer.max_size = self.max_errors = 1 << 32
self.bw_limit = -1
self.timeout = timeout
self.max_seconds_idle = timeout * 2
self.response_time: Optional[float] = None
self.connection_latency: Optional[float] = None
self._response_samples = 0
self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
@property
def available(self):
return not self.is_closing() and self.response_time is not None
@property
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport:
return None
return self.transport.get_extra_info('peername')
async def send_timed_server_version_request(self, args=(), timeout=None):
timeout = timeout or self.timeout
log.debug("send version request to %s:%i", *self.server)
start = perf_counter()
result = await asyncio.wait_for(
super().send_request('server.version', args), timeout=timeout
)
current_response_time = perf_counter() - start
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
self.response_time = response_sum / (self._response_samples + 1)
self._response_samples += 1
return result
async def send_request(self, method, args=()):
self.pending_amount += 1
log.debug("send %s to %s:%i", method, *self.server)
try:
if method == 'server.version':
return await self.send_timed_server_version_request(args, self.timeout)
request = asyncio.ensure_future(super().send_request(method, args))
while not request.done():
done, pending = await asyncio.wait([request], timeout=self.timeout)
if pending:
log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
if (perf_counter() - self.last_packet_received) < self.timeout:
continue
log.info("timeout sending %s to %s:%i", method, *self.server)
raise asyncio.TimeoutError
if done:
return request.result()
except (RPCError, ProtocolError) as e:
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args)
raise e
except ConnectionError:
log.warning("connection to %s:%i lost", *self.server)
self.synchronous_close()
raise
except asyncio.CancelledError:
log.info("cancelled sending %s to %s:%i", method, *self.server)
self.synchronous_close()
raise
finally:
self.pending_amount -= 1
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = default_delay = 1.0
while True:
try:
if self.is_closing():
await self.create_connection(self.timeout)
await self.ensure_server_version()
self._on_connect_cb()
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.ensure_server_version()
retry_delay = default_delay
except RPCError as e:
log.warning("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
retry_delay = 60 * 60
except (asyncio.TimeoutError, OSError):
await self.close()
retry_delay = min(60, retry_delay * 2)
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
try:
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
finally:
self.trigger_urgent_reconnect.clear()
async def ensure_server_version(self, required=None, timeout=3):
required = required or self.network.PROTOCOL_VERSION
return await asyncio.wait_for(
self.send_request('server.version', [__version__, required]), timeout=timeout
)
async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server)
start = perf_counter()
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.connection_latency = perf_counter() - start
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
controller.add(request.args)
def connection_lost(self, exc):
log.debug("Connection lost: %s:%d", *self.server)
super().connection_lost(exc)
self.response_time = None
self.connection_latency = None
self._response_samples = 0
self.pending_amount = 0
self._on_disconnect_controller.add(True)
class Network:
PROTOCOL_VERSION = __version__
def __init__(self, ledger):
self.ledger = ledger
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None
self._switch_task: Optional[asyncio.Task] = None
self.running = False
self.remote_height: int = 0
self._concurrency = asyncio.Semaphore(16)
self._on_connected_controller = StreamController()
self.on_connected = self._on_connected_controller.stream
self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
@property
def config(self):
return self.ledger.config
async def switch_forever(self):
while self.running:
if self.is_connected:
await self.client.on_disconnected.first
self.client = None
continue
self.client = await self.session_pool.wait_for_fastest_session()
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
try:
self._update_remote_height((await self.subscribe_headers(),))
self._on_connected_controller.add(True)
log.info("Subscribed to headers: %s:%d", *self.client.server)
except (asyncio.TimeoutError, ConnectionError):
log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
self.client.synchronous_close()
self.client = None
async def start(self):
self.running = True
self._switch_task = asyncio.ensure_future(self.switch_forever())
# this may become unnecessary when there are no more bugs found,
# but for now it helps understanding log reports
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height)
async def stop(self):
if self.running:
self.running = False
self._switch_task.cancel()
self.session_pool.stop()
@property
def is_connected(self):
return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args, restricted=True):
session = self.client if restricted else self.session_pool.fastest_session
if session and not session.is_closing():
return session.send_request(list_or_method, args)
else:
self.session_pool.trigger_nodelay_connect()
raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def retriable_call(self, function, *args, **kwargs):
async with self._concurrency:
while self.running:
if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
try:
return await function(*args, **kwargs)
except asyncio.TimeoutError:
log.warning("Wallet server call timed out, retrying.")
except ConnectionError:
pass
raise asyncio.CancelledError() # if we got here, we are shutting down
def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"]
def get_transaction(self, tx_hash, known_height=None):
# use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
def get_transaction_height(self, tx_hash, known_height=None):
restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
def get_merkle(self, tx_hash, height):
restricted = 0 > height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
def get_headers(self, height, count=10000, b64=False):
restricted = height >= self.remote_height - 100
return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address], True)
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address):
try:
return await self.rpc('blockchain.address.subscribe', [address], True)
except asyncio.TimeoutError:
# abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client:
self.client.abort()
raise asyncio.CancelledError()
def unsubscribe_address(self, address):
return self.rpc('blockchain.address.unsubscribe', [address], True)
def get_server_features(self):
return self.rpc('server.features', (), restricted=True)
def get_claims_by_ids(self, claim_ids): def get_claims_by_ids(self, claim_ids):
return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
@ -13,3 +280,95 @@ class Network(BaseNetwork):
def claim_search(self, **kwargs): def claim_search(self, **kwargs):
return self.rpc('blockchain.claimtrie.search', kwargs) return self.rpc('blockchain.claimtrie.search', kwargs)
class SessionPool:
def __init__(self, network: Network, timeout: float):
self.network = network
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout
self.new_connection_event = asyncio.Event()
@property
def online(self):
return any(not session.is_closing() for session in self.sessions)
@property
def available_sessions(self):
return (session for session in self.sessions if session.available)
@property
def fastest_session(self):
if not self.online:
return None
return min(
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
for session in self.available_sessions] or [(0, None)],
key=itemgetter(0)
)[1]
def _get_session_connect_callback(self, session: ClientSession):
loop = asyncio.get_event_loop()
def callback():
duplicate_connections = [
s for s in self.sessions
if s is not session and s.server_address_and_port == session.server_address_and_port
]
already_connected = None if not duplicate_connections else duplicate_connections[0]
if already_connected:
self.sessions.pop(session).cancel()
session.synchronous_close()
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
session.server[0], already_connected.server[0])
loop.call_later(3600, self._connect_session, session.server)
return
self.new_connection_event.set()
log.info("connected to %s:%i", *session.server)
return callback
def _connect_session(self, server: Tuple[str, int]):
session = None
for s in self.sessions:
if s.server == server:
session = s
break
if not session:
session = ClientSession(
network=self.network, server=server
)
session._on_connect_cb = self._get_session_connect_callback(session)
task = self.sessions.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def start(self, default_servers):
for server in default_servers:
self._connect_session(server)
def stop(self):
for session, task in self.sessions.items():
task.cancel()
session.synchronous_close()
self.sessions.clear()
def ensure_connections(self):
for session in self.sessions:
self._connect_session(session.server)
def trigger_nodelay_connect(self):
# used when other parts of the system sees we might have internet back
# bypasses the retry interval
for session in self.sessions:
session.trigger_urgent_reconnect.set()
async def wait_for_fastest_session(self):
while not self.fastest_session:
self.trigger_nodelay_connect()
self.new_connection_event.clear()
await self.new_connection_event.wait()
return self.fastest_session

View file

@ -12,28 +12,15 @@ from binascii import hexlify
from typing import Type, Optional from typing import Type, Optional
import urllib.request import urllib.request
import lbry
from lbry.wallet.server.server import Server from lbry.wallet.server.server import Server
from lbry.wallet.server.env import Env from lbry.wallet.server.env import Env
from lbry.wallet.client.wallet import Wallet from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent
from lbry.wallet.client.baseledger import BaseLedger, BlockHeightEvent
from lbry.wallet.client.basemanager import BaseWalletManager
from lbry.wallet.client.baseaccount import BaseAccount
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def get_manager_from_environment(default_manager=BaseWalletManager):
if 'TORBA_MANAGER' not in os.environ:
return default_manager
module_name = os.environ['TORBA_MANAGER'].split('-')[-1] # tox support
return importlib.import_module(module_name)
def get_ledger_from_environment():
return importlib.import_module('lbry.wallet')
def get_spvserver_from_ledger(ledger_module): def get_spvserver_from_ledger(ledger_module):
spvserver_path, regtest_class_name = ledger_module.__spvserver__.rsplit('.', 1) spvserver_path, regtest_class_name = ledger_module.__spvserver__.rsplit('.', 1)
spvserver_module = importlib.import_module(spvserver_path) spvserver_module = importlib.import_module(spvserver_path)
@ -50,16 +37,14 @@ def get_blockchain_node_from_ledger(ledger_module):
class Conductor: class Conductor:
def __init__(self, ledger_module=None, manager_module=None, enable_segwit=False, seed=None): def __init__(self, seed=None):
self.ledger_module = ledger_module or get_ledger_from_environment() self.manager_module = WalletManager
self.manager_module = manager_module or get_manager_from_environment() self.spv_module = get_spvserver_from_ledger(lbry.wallet)
self.spv_module = get_spvserver_from_ledger(self.ledger_module)
self.blockchain_node = get_blockchain_node_from_ledger(self.ledger_module) self.blockchain_node = get_blockchain_node_from_ledger(lbry.wallet)
self.blockchain_node.segwit_enabled = enable_segwit
self.spv_node = SPVNode(self.spv_module) self.spv_node = SPVNode(self.spv_module)
self.wallet_node = WalletNode( self.wallet_node = WalletNode(
self.manager_module, self.ledger_module.RegTestLedger, default_seed=seed self.manager_module, RegTestLedger, default_seed=seed
) )
self.blockchain_started = False self.blockchain_started = False
@ -119,15 +104,15 @@ class Conductor:
class WalletNode: class WalletNode:
def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger], def __init__(self, manager_class: Type[WalletManager], ledger_class: Type[Ledger],
verbose: bool = False, port: int = 5280, default_seed: str = None) -> None: verbose: bool = False, port: int = 5280, default_seed: str = None) -> None:
self.manager_class = manager_class self.manager_class = manager_class
self.ledger_class = ledger_class self.ledger_class = ledger_class
self.verbose = verbose self.verbose = verbose
self.manager: Optional[BaseWalletManager] = None self.manager: Optional[WalletManager] = None
self.ledger: Optional[BaseLedger] = None self.ledger: Optional[Ledger] = None
self.wallet: Optional[Wallet] = None self.wallet: Optional[Wallet] = None
self.account: Optional[BaseAccount] = None self.account: Optional[Account] = None
self.data_path: Optional[str] = None self.data_path: Optional[str] = None
self.port = port self.port = port
self.default_seed = default_seed self.default_seed = default_seed
@ -154,7 +139,7 @@ class WalletNode:
if not self.wallet: if not self.wallet:
raise ValueError('Wallet is required.') raise ValueError('Wallet is required.')
if seed or self.default_seed: if seed or self.default_seed:
self.ledger.account_class.from_dict( Account.from_dict(
self.ledger, self.wallet, {'seed': seed or self.default_seed} self.ledger, self.wallet, {'seed': seed or self.default_seed}
) )
else: else:
@ -250,7 +235,7 @@ class BlockchainNode:
P2SH_SEGWIT_ADDRESS = "p2sh-segwit" P2SH_SEGWIT_ADDRESS = "p2sh-segwit"
BECH32_ADDRESS = "bech32" BECH32_ADDRESS = "bech32"
def __init__(self, url, daemon, cli, segwit_enabled=False): def __init__(self, url, daemon, cli):
self.latest_release_url = url self.latest_release_url = url
self.project_dir = os.path.dirname(os.path.dirname(__file__)) self.project_dir = os.path.dirname(os.path.dirname(__file__))
self.bin_dir = os.path.join(self.project_dir, 'bin') self.bin_dir = os.path.join(self.project_dir, 'bin')
@ -266,7 +251,6 @@ class BlockchainNode:
self.rpcport = 9245 + 2 # avoid conflict with default rpc port self.rpcport = 9245 + 2 # avoid conflict with default rpc port
self.rpcuser = 'rpcuser' self.rpcuser = 'rpcuser'
self.rpcpassword = 'rpcpassword' self.rpcpassword = 'rpcpassword'
self.segwit_enabled = segwit_enabled
@property @property
def rpc_url(self): def rpc_url(self):
@ -326,8 +310,6 @@ class BlockchainNode:
f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}', f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}',
f'-port={self.peerport}' f'-port={self.peerport}'
] ]
if not self.segwit_enabled:
command.extend(['-addresstype=legacy', '-vbparams=segwit:0:999999999999'])
self.log.info(' '.join(command)) self.log.info(' '.join(command))
self.transport, self.protocol = await loop.subprocess_exec( self.transport, self.protocol = await loop.subprocess_exec(
BlockchainProcess, *command BlockchainProcess, *command

View file

@ -3,7 +3,7 @@ import logging
from aiohttp.web import Application, WebSocketResponse, json_response from aiohttp.web import Application, WebSocketResponse, json_response
from aiohttp.http_websocket import WSMsgType, WSCloseCode from aiohttp.http_websocket import WSMsgType, WSCloseCode
from lbry.wallet.client.util import satoshis_to_coins from lbry.wallet.util import satoshis_to_coins
from .node import Conductor from .node import Conductor

View file

@ -1,34 +1,430 @@
from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript, Template from typing import List
from lbry.wallet.client.basescript import PUSH_SINGLE, PUSH_INTEGER, OP_DROP, OP_2DROP, PUSH_SUBSCRIPT, OP_VERIFY from itertools import chain
from binascii import hexlify
from collections import namedtuple
from .bcd_data_stream import BCDataStream
from .util import subclass_tuple
class InputScript(BaseInputScript): # bitcoin opcodes
pass OP_0 = 0x00
OP_1 = 0x51
OP_16 = 0x60
OP_VERIFY = 0x69
OP_DUP = 0x76
OP_HASH160 = 0xa9
OP_EQUALVERIFY = 0x88
OP_CHECKSIG = 0xac
OP_CHECKMULTISIG = 0xae
OP_EQUAL = 0x87
OP_PUSHDATA1 = 0x4c
OP_PUSHDATA2 = 0x4d
OP_PUSHDATA4 = 0x4e
OP_RETURN = 0x6a
OP_2DROP = 0x6d
OP_DROP = 0x75
# lbry custom opcodes
# checks
OP_PRICECHECK = 0xb0 # checks that the BUY output is >= SELL price
# tx types
OP_CLAIM_NAME = 0xb5
OP_SUPPORT_CLAIM = 0xb6
OP_UPDATE_CLAIM = 0xb7
OP_SELL_CLAIM = 0xb8
OP_BUY_CLAIM = 0xb9
# template matching opcodes (not real opcodes)
# base class for PUSH_DATA related opcodes
# pylint: disable=invalid-name
PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name')
# opcode for variable length strings
# pylint: disable=invalid-name
PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP)
# opcode for variable size integers
# pylint: disable=invalid-name
PUSH_INTEGER = subclass_tuple('PUSH_INTEGER', PUSH_DATA_OP)
# opcode for variable number of variable length strings
# pylint: disable=invalid-name
PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP)
# opcode with embedded subscript parsing
# pylint: disable=invalid-name
PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template')
class OutputScript(BaseOutputScript): def is_push_data_opcode(opcode):
return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT))
# lbry custom opcodes
# checks def is_push_data_token(token):
OP_PRICECHECK = 0xb0 # checks that the BUY output is >= SELL price return 1 <= token <= OP_PUSHDATA4
# tx types
OP_CLAIM_NAME = 0xb5 def push_data(data):
OP_SUPPORT_CLAIM = 0xb6 size = len(data)
OP_UPDATE_CLAIM = 0xb7 if size < OP_PUSHDATA1:
OP_SELL_CLAIM = 0xb8 yield BCDataStream.uint8.pack(size)
OP_BUY_CLAIM = 0xb9 elif size <= 0xFF:
yield BCDataStream.uint8.pack(OP_PUSHDATA1)
yield BCDataStream.uint8.pack(size)
elif size <= 0xFFFF:
yield BCDataStream.uint8.pack(OP_PUSHDATA2)
yield BCDataStream.uint16.pack(size)
else:
yield BCDataStream.uint8.pack(OP_PUSHDATA4)
yield BCDataStream.uint32.pack(size)
yield bytes(data)
def read_data(token, stream):
if token < OP_PUSHDATA1:
return stream.read(token)
if token == OP_PUSHDATA1:
return stream.read(stream.read_uint8())
if token == OP_PUSHDATA2:
return stream.read(stream.read_uint16())
return stream.read(stream.read_uint32())
# opcode for OP_1 - OP_16
# pylint: disable=invalid-name
SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name')
def is_small_integer(token):
return OP_1 <= token <= OP_16
def push_small_integer(num):
assert 1 <= num <= 16
yield BCDataStream.uint8.pack(OP_1 + (num - 1))
def read_small_integer(token):
return (token - OP_1) + 1
class Token(namedtuple('Token', 'value')):
__slots__ = ()
def __repr__(self):
name = None
for var_name, var_value in globals().items():
if var_name.startswith('OP_') and var_value == self.value:
name = var_name
break
return name or self.value
class DataToken(Token):
__slots__ = ()
def __repr__(self):
return f'"{hexlify(self.value)}"'
class SmallIntegerToken(Token):
__slots__ = ()
def __repr__(self):
return f'SmallIntegerToken({self.value})'
def token_producer(source):
token = source.read_uint8()
while token is not None:
if is_push_data_token(token):
yield DataToken(read_data(token, source))
elif is_small_integer(token):
yield SmallIntegerToken(read_small_integer(token))
else:
yield Token(token)
token = source.read_uint8()
def tokenize(source):
return list(token_producer(source))
class ScriptError(Exception):
""" General script handling error. """
class ParseError(ScriptError):
""" Script parsing error. """
class Parser:
def __init__(self, opcodes, tokens):
self.opcodes = opcodes
self.tokens = tokens
self.values = {}
self.token_index = 0
self.opcode_index = 0
def parse(self):
while self.token_index < len(self.tokens) and self.opcode_index < len(self.opcodes):
token = self.tokens[self.token_index]
opcode = self.opcodes[self.opcode_index]
if token.value == 0 and isinstance(opcode, PUSH_SINGLE):
token = DataToken(b'')
if isinstance(token, DataToken):
if isinstance(opcode, (PUSH_SINGLE, PUSH_INTEGER, PUSH_SUBSCRIPT)):
self.push_single(opcode, token.value)
elif isinstance(opcode, PUSH_MANY):
self.consume_many_non_greedy()
else:
raise ParseError(f"DataToken found but opcode was '{opcode}'.")
elif isinstance(token, SmallIntegerToken):
if isinstance(opcode, SMALL_INTEGER):
self.values[opcode.name] = token.value
else:
raise ParseError(f"SmallIntegerToken found but opcode was '{opcode}'.")
elif token.value == opcode:
pass
else:
raise ParseError(f"Token is '{token.value}' and opcode is '{opcode}'.")
self.token_index += 1
self.opcode_index += 1
if self.token_index < len(self.tokens):
raise ParseError("Parse completed without all tokens being consumed.")
if self.opcode_index < len(self.opcodes):
raise ParseError("Parse completed without all opcodes being consumed.")
return self
def consume_many_non_greedy(self):
""" Allows PUSH_MANY to consume data without being greedy
in cases when one or more PUSH_SINGLEs follow a PUSH_MANY. This will
prioritize giving all PUSH_SINGLEs some data and only after that
subsume the rest into PUSH_MANY.
"""
token_values = []
while self.token_index < len(self.tokens):
token = self.tokens[self.token_index]
if not isinstance(token, DataToken):
self.token_index -= 1
break
token_values.append(token.value)
self.token_index += 1
push_opcodes = []
push_many_count = 0
while self.opcode_index < len(self.opcodes):
opcode = self.opcodes[self.opcode_index]
if not is_push_data_opcode(opcode):
self.opcode_index -= 1
break
if isinstance(opcode, PUSH_MANY):
push_many_count += 1
push_opcodes.append(opcode)
self.opcode_index += 1
if push_many_count > 1:
raise ParseError(
"Cannot have more than one consecutive PUSH_MANY, as there is no way to tell which"
" token value should go into which PUSH_MANY."
)
if len(push_opcodes) > len(token_values):
raise ParseError(
"Not enough token values to match all of the PUSH_MANY and PUSH_SINGLE opcodes."
)
many_opcode = push_opcodes.pop(0)
# consume data into PUSH_SINGLE opcodes, working backwards
for opcode in reversed(push_opcodes):
self.push_single(opcode, token_values.pop())
# finally PUSH_MANY gets everything that's left
self.values[many_opcode.name] = token_values
def push_single(self, opcode, value):
if isinstance(opcode, PUSH_SINGLE):
self.values[opcode.name] = value
elif isinstance(opcode, PUSH_INTEGER):
self.values[opcode.name] = int.from_bytes(value, 'little')
elif isinstance(opcode, PUSH_SUBSCRIPT):
self.values[opcode.name] = Script.from_source_with_template(value, opcode.template)
else:
raise ParseError(f"Not a push single or subscript: {opcode}")
class Template:
__slots__ = 'name', 'opcodes'
def __init__(self, name, opcodes):
self.name = name
self.opcodes = opcodes
def parse(self, tokens):
return Parser(self.opcodes, tokens).parse().values if self.opcodes else {}
def generate(self, values):
source = BCDataStream()
for opcode in self.opcodes:
if isinstance(opcode, PUSH_SINGLE):
data = values[opcode.name]
source.write_many(push_data(data))
elif isinstance(opcode, PUSH_INTEGER):
data = values[opcode.name]
source.write_many(push_data(
data.to_bytes((data.bit_length() + 7) // 8, byteorder='little')
))
elif isinstance(opcode, PUSH_SUBSCRIPT):
data = values[opcode.name]
source.write_many(push_data(data.source))
elif isinstance(opcode, PUSH_MANY):
for data in values[opcode.name]:
source.write_many(push_data(data))
elif isinstance(opcode, SMALL_INTEGER):
data = values[opcode.name]
source.write_many(push_small_integer(data))
else:
source.write_uint8(opcode)
return source.get_bytes()
class Script:
__slots__ = 'source', '_template', '_values', '_template_hint'
templates: List[Template] = []
NO_SCRIPT = Template('no_script', None) # special case
def __init__(self, source=None, template=None, values=None, template_hint=None):
self.source = source
self._template = template
self._values = values
self._template_hint = template_hint
if source is None and template and values:
self.generate()
@property
def template(self):
if self._template is None:
self.parse(self._template_hint)
return self._template
@property
def values(self):
if self._values is None:
self.parse(self._template_hint)
return self._values
@property
def tokens(self):
return tokenize(BCDataStream(self.source))
@classmethod
def from_source_with_template(cls, source, template):
return cls(source, template_hint=template)
def parse(self, template_hint=None):
tokens = self.tokens
if not tokens and not template_hint:
template_hint = self.NO_SCRIPT
for template in chain((template_hint,), self.templates):
if not template:
continue
try:
self._values = template.parse(tokens)
self._template = template
return
except ParseError:
continue
raise ValueError(f'No matching templates for source: {hexlify(self.source)}')
def generate(self):
self.source = self.template.generate(self._values)
class InputScript(Script):
__slots__ = ()
REDEEM_PUBKEY = Template('pubkey', (
PUSH_SINGLE('signature'),
))
REDEEM_PUBKEY_HASH = Template('pubkey_hash', (
PUSH_SINGLE('signature'), PUSH_SINGLE('pubkey')
))
REDEEM_SCRIPT = Template('script', (
SMALL_INTEGER('signatures_count'), PUSH_MANY('pubkeys'), SMALL_INTEGER('pubkeys_count'),
OP_CHECKMULTISIG
))
REDEEM_SCRIPT_HASH = Template('script_hash', (
OP_0, PUSH_MANY('signatures'), PUSH_SUBSCRIPT('script', REDEEM_SCRIPT)
))
templates = [
REDEEM_PUBKEY,
REDEEM_PUBKEY_HASH,
REDEEM_SCRIPT_HASH,
REDEEM_SCRIPT
]
@classmethod
def redeem_pubkey_hash(cls, signature, pubkey):
return cls(template=cls.REDEEM_PUBKEY_HASH, values={
'signature': signature,
'pubkey': pubkey
})
@classmethod
def redeem_script_hash(cls, signatures, pubkeys):
return cls(template=cls.REDEEM_SCRIPT_HASH, values={
'signatures': signatures,
'script': cls.redeem_script(signatures, pubkeys)
})
@classmethod
def redeem_script(cls, signatures, pubkeys):
return cls(template=cls.REDEEM_SCRIPT, values={
'signatures_count': len(signatures),
'pubkeys': pubkeys,
'pubkeys_count': len(pubkeys)
})
class OutputScript(Script):
__slots__ = ()
# output / payment script templates (aka scriptPubKey)
PAY_PUBKEY_FULL = Template('pay_pubkey_full', (
PUSH_SINGLE('pubkey'), OP_CHECKSIG
))
PAY_PUBKEY_HASH = Template('pay_pubkey_hash', (
OP_DUP, OP_HASH160, PUSH_SINGLE('pubkey_hash'), OP_EQUALVERIFY, OP_CHECKSIG
))
PAY_SCRIPT_HASH = Template('pay_script_hash', (
OP_HASH160, PUSH_SINGLE('script_hash'), OP_EQUAL
))
PAY_SEGWIT = Template('pay_script_hash+segwit', (
OP_0, PUSH_SINGLE('script_hash')
))
RETURN_DATA = Template('return_data', (
OP_RETURN, PUSH_SINGLE('data')
))
CLAIM_NAME_OPCODES = ( CLAIM_NAME_OPCODES = (
OP_CLAIM_NAME, PUSH_SINGLE('claim_name'), PUSH_SINGLE('claim'), OP_CLAIM_NAME, PUSH_SINGLE('claim_name'), PUSH_SINGLE('claim'),
OP_2DROP, OP_DROP OP_2DROP, OP_DROP
) )
CLAIM_NAME_PUBKEY = Template('claim_name+pay_pubkey_hash', ( CLAIM_NAME_PUBKEY = Template('claim_name+pay_pubkey_hash', (
CLAIM_NAME_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes CLAIM_NAME_OPCODES + PAY_PUBKEY_HASH.opcodes
)) ))
CLAIM_NAME_SCRIPT = Template('claim_name+pay_script_hash', ( CLAIM_NAME_SCRIPT = Template('claim_name+pay_script_hash', (
CLAIM_NAME_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes CLAIM_NAME_OPCODES + PAY_SCRIPT_HASH.opcodes
)) ))
SUPPORT_CLAIM_OPCODES = ( SUPPORT_CLAIM_OPCODES = (
@ -36,10 +432,10 @@ class OutputScript(BaseOutputScript):
OP_2DROP, OP_DROP OP_2DROP, OP_DROP
) )
SUPPORT_CLAIM_PUBKEY = Template('support_claim+pay_pubkey_hash', ( SUPPORT_CLAIM_PUBKEY = Template('support_claim+pay_pubkey_hash', (
SUPPORT_CLAIM_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes SUPPORT_CLAIM_OPCODES + PAY_PUBKEY_HASH.opcodes
)) ))
SUPPORT_CLAIM_SCRIPT = Template('support_claim+pay_script_hash', ( SUPPORT_CLAIM_SCRIPT = Template('support_claim+pay_script_hash', (
SUPPORT_CLAIM_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes SUPPORT_CLAIM_OPCODES + PAY_SCRIPT_HASH.opcodes
)) ))
UPDATE_CLAIM_OPCODES = ( UPDATE_CLAIM_OPCODES = (
@ -47,10 +443,10 @@ class OutputScript(BaseOutputScript):
OP_2DROP, OP_2DROP OP_2DROP, OP_2DROP
) )
UPDATE_CLAIM_PUBKEY = Template('update_claim+pay_pubkey_hash', ( UPDATE_CLAIM_PUBKEY = Template('update_claim+pay_pubkey_hash', (
UPDATE_CLAIM_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes UPDATE_CLAIM_OPCODES + PAY_PUBKEY_HASH.opcodes
)) ))
UPDATE_CLAIM_SCRIPT = Template('update_claim+pay_script_hash', ( UPDATE_CLAIM_SCRIPT = Template('update_claim+pay_script_hash', (
UPDATE_CLAIM_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes UPDATE_CLAIM_OPCODES + PAY_SCRIPT_HASH.opcodes
)) ))
SELL_SCRIPT = Template('sell_script', ( SELL_SCRIPT = Template('sell_script', (
@ -58,17 +454,22 @@ class OutputScript(BaseOutputScript):
)) ))
SELL_CLAIM = Template('sell_claim+pay_script_hash', ( SELL_CLAIM = Template('sell_claim+pay_script_hash', (
OP_SELL_CLAIM, PUSH_SINGLE('claim_id'), PUSH_SUBSCRIPT('sell_script', SELL_SCRIPT), OP_SELL_CLAIM, PUSH_SINGLE('claim_id'), PUSH_SUBSCRIPT('sell_script', SELL_SCRIPT),
PUSH_SUBSCRIPT('receive_script', BaseInputScript.REDEEM_SCRIPT), OP_2DROP, OP_2DROP PUSH_SUBSCRIPT('receive_script', InputScript.REDEEM_SCRIPT), OP_2DROP, OP_2DROP
) + BaseOutputScript.PAY_SCRIPT_HASH.opcodes) ) + PAY_SCRIPT_HASH.opcodes)
BUY_CLAIM = Template('buy_claim+pay_script_hash', ( BUY_CLAIM = Template('buy_claim+pay_script_hash', (
OP_BUY_CLAIM, PUSH_SINGLE('sell_id'), OP_BUY_CLAIM, PUSH_SINGLE('sell_id'),
PUSH_SINGLE('claim_id'), PUSH_SINGLE('claim_version'), PUSH_SINGLE('claim_id'), PUSH_SINGLE('claim_version'),
PUSH_SINGLE('owner_pubkey_hash'), PUSH_SINGLE('negotiation_signature'), PUSH_SINGLE('owner_pubkey_hash'), PUSH_SINGLE('negotiation_signature'),
OP_2DROP, OP_2DROP, OP_2DROP, OP_2DROP, OP_2DROP, OP_2DROP,
) + BaseOutputScript.PAY_SCRIPT_HASH.opcodes) ) + PAY_SCRIPT_HASH.opcodes)
templates = BaseOutputScript.templates + [ templates = [
PAY_PUBKEY_FULL,
PAY_PUBKEY_HASH,
PAY_SCRIPT_HASH,
PAY_SEGWIT,
RETURN_DATA,
CLAIM_NAME_PUBKEY, CLAIM_NAME_PUBKEY,
CLAIM_NAME_SCRIPT, CLAIM_NAME_SCRIPT,
SUPPORT_CLAIM_PUBKEY, SUPPORT_CLAIM_PUBKEY,
@ -79,6 +480,28 @@ class OutputScript(BaseOutputScript):
BUY_CLAIM, BUY_CLAIM,
] ]
@classmethod
def pay_pubkey_hash(cls, pubkey_hash):
return cls(template=cls.PAY_PUBKEY_HASH, values={
'pubkey_hash': pubkey_hash
})
@classmethod
def pay_script_hash(cls, script_hash):
return cls(template=cls.PAY_SCRIPT_HASH, values={
'script_hash': script_hash
})
@classmethod
def return_data(cls, data):
return cls(template=cls.RETURN_DATA, values={
'data': data
})
@property
def is_pay_pubkey(self):
return self.template.name.endswith('pay_pubkey_full')
@classmethod @classmethod
def pay_claim_name_pubkey_hash(cls, claim_name, claim, pubkey_hash): def pay_claim_name_pubkey_hash(cls, claim_name, claim, pubkey_hash):
return cls(template=cls.CLAIM_NAME_PUBKEY, values={ return cls(template=cls.CLAIM_NAME_PUBKEY, values={
@ -128,6 +551,18 @@ class OutputScript(BaseOutputScript):
'negotiation_signature': negotiation_signature, 'negotiation_signature': negotiation_signature,
}) })
@property
def is_pay_pubkey_hash(self):
return self.template.name.endswith('pay_pubkey_hash')
@property
def is_pay_script_hash(self):
return self.template.name.endswith('pay_script_hash')
@property
def is_return_data(self):
return self.template.name.endswith('return_data')
@property @property
def is_claim_name(self): def is_claim_name(self):
return self.template.name.startswith('claim_name+') return self.template.name.startswith('claim_name+')

View file

@ -6,7 +6,7 @@ from decimal import Decimal
from collections import namedtuple from collections import namedtuple
import lbry.wallet.server.tx as lib_tx import lbry.wallet.server.tx as lib_tx
from lbry.wallet.script import OutputScript from lbry.wallet.script import OutputScript, OP_CLAIM_NAME, OP_UPDATE_CLAIM, OP_SUPPORT_CLAIM
from lbry.wallet.server.tx import DeserializerSegWit from lbry.wallet.server.tx import DeserializerSegWit
from lbry.wallet.server.util import cachedproperty, subclasses from lbry.wallet.server.util import cachedproperty, subclasses
from lbry.wallet.server.hash import Base58, hash160, double_sha256, hash_to_hex_str, HASHX_LEN from lbry.wallet.server.hash import Base58, hash160, double_sha256, hash_to_hex_str, HASHX_LEN
@ -327,9 +327,9 @@ class LBC(Coin):
if script and script[0] == OpCodes.OP_RETURN or not script: if script and script[0] == OpCodes.OP_RETURN or not script:
return None return None
if script[0] in [ if script[0] in [
OutputScript.OP_CLAIM_NAME, OP_CLAIM_NAME,
OutputScript.OP_UPDATE_CLAIM, OP_UPDATE_CLAIM,
OutputScript.OP_SUPPORT_CLAIM, OP_SUPPORT_CLAIM,
]: ]:
return cls.address_to_hashX(cls.claim_address_handler(script)) return cls.address_to_hashX(cls.claim_address_handler(script))
else: else:

View file

@ -1,4 +1,4 @@
from lbry.wallet.client.basedatabase import constraints_to_sql from lbry.wallet.database import constraints_to_sql
CREATE_FULL_TEXT_SEARCH = """ CREATE_FULL_TEXT_SEARCH = """
create virtual table if not exists search using fts5( create virtual table if not exists search using fts5(

View file

@ -10,12 +10,12 @@ from contextvars import ContextVar
from functools import wraps from functools import wraps
from dataclasses import dataclass from dataclasses import dataclass
from lbry.wallet.client.basedatabase import query, interpolate from lbry.wallet.database import query, interpolate
from lbry.schema.url import URL, normalize_name from lbry.schema.url import URL, normalize_name
from lbry.schema.tags import clean_tags from lbry.schema.tags import clean_tags
from lbry.schema.result import Outputs from lbry.schema.result import Outputs
from lbry.wallet.ledger import BaseLedger, MainNetLedger, RegTestLedger from lbry.wallet import Ledger, RegTestLedger
from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS
from .full_text_search import FTS_ORDER_BY from .full_text_search import FTS_ORDER_BY
@ -67,7 +67,7 @@ class ReaderState:
stack: List[List] stack: List[List]
metrics: Dict metrics: Dict
is_tracking_metrics: bool is_tracking_metrics: bool
ledger: Type[BaseLedger] ledger: Type[Ledger]
query_timeout: float query_timeout: float
log: logging.Logger log: logging.Logger
@ -100,7 +100,7 @@ def initializer(log, _path, _ledger_name, query_timeout, _measure=False):
ctx.set( ctx.set(
ReaderState( ReaderState(
db=db, stack=[], metrics={}, is_tracking_metrics=_measure, db=db, stack=[], metrics={}, is_tracking_metrics=_measure,
ledger=MainNetLedger if _ledger_name == 'mainnet' else RegTestLedger, ledger=Ledger if _ledger_name == 'mainnet' else RegTestLedger,
query_timeout=query_timeout, log=log query_timeout=query_timeout, log=log
) )
) )

View file

@ -7,11 +7,11 @@ from collections import namedtuple
from lbry.wallet.server.leveldb import DB from lbry.wallet.server.leveldb import DB
from lbry.wallet.server.util import class_logger from lbry.wallet.server.util import class_logger
from lbry.wallet.client.basedatabase import query, constraints_to_sql from lbry.wallet.database import query, constraints_to_sql
from lbry.schema.tags import clean_tags from lbry.schema.tags import clean_tags
from lbry.schema.mime_types import guess_stream_type from lbry.schema.mime_types import guess_stream_type
from lbry.wallet.ledger import MainNetLedger, RegTestLedger from lbry.wallet import Ledger, RegTestLedger
from lbry.wallet.transaction import Transaction, Output from lbry.wallet.transaction import Transaction, Output
from lbry.wallet.server.db.canonical import register_canonical_functions from lbry.wallet.server.db.canonical import register_canonical_functions
from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished
@ -171,7 +171,7 @@ class SQLDB:
self._db_path = path self._db_path = path
self.db = None self.db = None
self.logger = class_logger(__name__, self.__class__.__name__) self.logger = class_logger(__name__, self.__class__.__name__)
self.ledger = MainNetLedger if self.main.coin.NET == 'mainnet' else RegTestLedger self.ledger = Ledger if self.main.coin.NET == 'mainnet' else RegTestLedger
self._fts_synced = False self._fts_synced = False
def open(self): def open(self):

View file

@ -1,9 +1,12 @@
import ecdsa
import struct import struct
import hashlib import hashlib
from binascii import hexlify, unhexlify import logging
from typing import List, Optional import typing
from binascii import hexlify, unhexlify
from typing import List, Iterable, Optional, Tuple
import ecdsa
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key from cryptography.hazmat.primitives.serialization import load_der_public_key
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
@ -11,34 +14,216 @@ from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from lbry.crypto.base58 import Base58 from lbry.error import InsufficientFundsError
from lbry.crypto.hash import hash160, sha256 from lbry.crypto.hash import hash160, sha256
from lbry.wallet.client.basetransaction import BaseTransaction, BaseInput, BaseOutput, ReadOnlyList from lbry.crypto.base58 import Base58
from lbry.schema.url import normalize_name
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
from lbry.schema.purchase import Purchase from lbry.schema.purchase import Purchase
from lbry.schema.url import normalize_name
from lbry.wallet.account import Account from .script import InputScript, OutputScript
from lbry.wallet.script import InputScript, OutputScript from .constants import COIN, NULL_HASH32
from .bcd_data_stream import BCDataStream
from .hash import TXRef, TXRefImmutable
from .util import ReadOnlyList
if typing.TYPE_CHECKING:
from lbry.wallet.account import Account
from lbry.wallet.ledger import Ledger
from lbry.wallet.wallet import Wallet
log = logging.getLogger()
class Input(BaseInput): class TXRefMutable(TXRef):
script: InputScript
script_class = InputScript __slots__ = ('tx',)
def __init__(self, tx: 'Transaction') -> None:
super().__init__()
self.tx = tx
@property
def id(self):
if self._id is None:
self._id = hexlify(self.hash[::-1]).decode()
return self._id
@property
def hash(self):
if self._hash is None:
self._hash = sha256(sha256(self.tx.raw_sans_segwit))
return self._hash
@property
def height(self):
return self.tx.height
def reset(self):
self._id = None
self._hash = None
class Output(BaseOutput): class TXORef:
script: OutputScript
script_class = OutputScript __slots__ = 'tx_ref', 'position'
def __init__(self, tx_ref: TXRef, position: int) -> None:
self.tx_ref = tx_ref
self.position = position
@property
def id(self):
return f'{self.tx_ref.id}:{self.position}'
@property
def hash(self):
return self.tx_ref.hash + BCDataStream.uint32.pack(self.position)
@property
def is_null(self):
return self.tx_ref.is_null
@property
def txo(self) -> Optional['Output']:
return None
class TXORefResolvable(TXORef):
__slots__ = ('_txo',)
def __init__(self, txo: 'Output') -> None:
assert txo.tx_ref is not None
assert txo.position is not None
super().__init__(txo.tx_ref, txo.position)
self._txo = txo
@property
def txo(self):
return self._txo
class InputOutput:
__slots__ = 'tx_ref', 'position'
def __init__(self, tx_ref: TXRef = None, position: int = None) -> None:
self.tx_ref = tx_ref
self.position = position
@property
def size(self) -> int:
""" Size of this input / output in bytes. """
stream = BCDataStream()
self.serialize_to(stream)
return len(stream.get_bytes())
def get_fee(self, ledger):
return self.size * ledger.fee_per_byte
def serialize_to(self, stream, alternate_script=None):
raise NotImplementedError
class Input(InputOutput):
NULL_SIGNATURE = b'\x00'*72
NULL_PUBLIC_KEY = b'\x00'*33
__slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'
def __init__(self, txo_ref: TXORef, script: InputScript, sequence: int = 0xFFFFFFFF,
tx_ref: TXRef = None, position: int = None) -> None:
super().__init__(tx_ref, position)
self.txo_ref = txo_ref
self.sequence = sequence
self.coinbase = script if txo_ref.is_null else None
self.script = script if not txo_ref.is_null else None
@property
def is_coinbase(self):
return self.coinbase is not None
@classmethod
def spend(cls, txo: 'Output') -> 'Input':
""" Create an input to spend the output."""
assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.'
script = InputScript.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
return cls(txo.ref, script)
@property
def amount(self) -> int:
""" Amount this input adds to the transaction. """
if self.txo_ref.txo is None:
raise ValueError('Cannot resolve output to get amount.')
return self.txo_ref.txo.amount
@property
def is_my_account(self) -> Optional[bool]:
""" True if the output this input spends is yours. """
if self.txo_ref.txo is None:
return False
return self.txo_ref.txo.is_my_account
@classmethod
def deserialize_from(cls, stream):
tx_ref = TXRefImmutable.from_hash(stream.read(32), -1)
position = stream.read_uint32()
script = stream.read_string()
sequence = stream.read_uint32()
return cls(
TXORef(tx_ref, position),
InputScript(script) if not tx_ref.is_null else script,
sequence
)
def serialize_to(self, stream, alternate_script=None):
stream.write(self.txo_ref.tx_ref.hash)
stream.write_uint32(self.txo_ref.position)
if alternate_script is not None:
stream.write_string(alternate_script)
else:
if self.is_coinbase:
stream.write_string(self.coinbase)
else:
stream.write_string(self.script.source)
stream.write_uint32(self.sequence)
class OutputEffectiveAmountEstimator:
__slots__ = 'txo', 'txi', 'fee', 'effective_amount'
def __init__(self, ledger: 'Ledger', txo: 'Output') -> None:
self.txo = txo
self.txi = Input.spend(txo)
self.fee: int = self.txi.get_fee(ledger)
self.effective_amount: int = txo.amount - self.fee
def __lt__(self, other):
return self.effective_amount < other.effective_amount
class Output(InputOutput):
__slots__ = ( __slots__ = (
'amount', 'script', 'is_change', 'is_my_account',
'channel', 'private_key', 'meta', 'channel', 'private_key', 'meta',
'purchase', 'purchased_claim', 'purchase_receipt', 'purchase', 'purchased_claim', 'purchase_receipt',
'reposted_claim', 'claims', 'reposted_claim', 'claims',
) )
def __init__(self, *args, channel: Optional['Output'] = None, def __init__(self, amount: int, script: OutputScript,
private_key: Optional[str] = None, **kwargs) -> None: tx_ref: TXRef = None, position: int = None,
super().__init__(*args, **kwargs) is_change: Optional[bool] = None, is_my_account: Optional[bool] = None,
channel: Optional['Output'] = None, private_key: Optional[str] = None
) -> None:
super().__init__(tx_ref, position)
self.amount = amount
self.script = script
self.is_change = is_change
self.is_my_account = is_my_account
self.channel = channel self.channel = channel
self.private_key = private_key self.private_key = private_key
self.purchase: 'Output' = None # txo containing purchase metadata self.purchase: 'Output' = None # txo containing purchase metadata
@ -49,10 +234,52 @@ class Output(BaseOutput):
self.meta = {} self.meta = {}
def update_annotations(self, annotated): def update_annotations(self, annotated):
super().update_annotations(annotated) if annotated is None:
self.is_change = False
self.is_my_account = False
else:
self.is_change = annotated.is_change
self.is_my_account = annotated.is_my_account
self.channel = annotated.channel if annotated else None self.channel = annotated.channel if annotated else None
self.private_key = annotated.private_key if annotated else None self.private_key = annotated.private_key if annotated else None
@property
def ref(self):
return TXORefResolvable(self)
@property
def id(self):
return self.ref.id
@property
def pubkey_hash(self):
return self.script.values['pubkey_hash']
@property
def has_address(self):
return 'pubkey_hash' in self.script.values
def get_address(self, ledger):
return ledger.hash160_to_address(self.pubkey_hash)
def get_estimator(self, ledger):
return OutputEffectiveAmountEstimator(ledger, self)
@classmethod
def pay_pubkey_hash(cls, amount, pubkey_hash):
return cls(amount, OutputScript.pay_pubkey_hash(pubkey_hash))
@classmethod
def deserialize_from(cls, stream):
return cls(
amount=stream.read_uint64(),
script=OutputScript(stream.read_string())
)
def serialize_to(self, stream, alternate_script=None):
stream.write_uint64(self.amount)
stream.write_string(self.script.source)
def get_fee(self, ledger): def get_fee(self, ledger):
name_fee = 0 name_fee = 0
if self.script.is_claim_name: if self.script.is_claim_name:
@ -180,34 +407,35 @@ class Output(BaseOutput):
@classmethod @classmethod
def pay_claim_name_pubkey_hash( def pay_claim_name_pubkey_hash(
cls, amount: int, claim_name: str, claim: Claim, pubkey_hash: bytes) -> 'Output': cls, amount: int, claim_name: str, claim: Claim, pubkey_hash: bytes) -> 'Output':
script = cls.script_class.pay_claim_name_pubkey_hash( script = OutputScript.pay_claim_name_pubkey_hash(
claim_name.encode(), claim, pubkey_hash) claim_name.encode(), claim, pubkey_hash)
txo = cls(amount, script) return cls(amount, script)
return txo
@classmethod @classmethod
def pay_update_claim_pubkey_hash( def pay_update_claim_pubkey_hash(
cls, amount: int, claim_name: str, claim_id: str, claim: Claim, pubkey_hash: bytes) -> 'Output': cls, amount: int, claim_name: str, claim_id: str, claim: Claim, pubkey_hash: bytes) -> 'Output':
script = cls.script_class.pay_update_claim_pubkey_hash( script = OutputScript.pay_update_claim_pubkey_hash(
claim_name.encode(), unhexlify(claim_id)[::-1], claim, pubkey_hash) claim_name.encode(), unhexlify(claim_id)[::-1], claim, pubkey_hash
txo = cls(amount, script) )
return txo return cls(amount, script)
@classmethod @classmethod
def pay_support_pubkey_hash(cls, amount: int, claim_name: str, claim_id: str, pubkey_hash: bytes) -> 'Output': def pay_support_pubkey_hash(cls, amount: int, claim_name: str, claim_id: str, pubkey_hash: bytes) -> 'Output':
script = cls.script_class.pay_support_pubkey_hash(claim_name.encode(), unhexlify(claim_id)[::-1], pubkey_hash) script = OutputScript.pay_support_pubkey_hash(
claim_name.encode(), unhexlify(claim_id)[::-1], pubkey_hash
)
return cls(amount, script) return cls(amount, script)
@classmethod @classmethod
def add_purchase_data(cls, purchase: Purchase) -> 'Output': def add_purchase_data(cls, purchase: Purchase) -> 'Output':
script = cls.script_class.return_data(purchase) script = OutputScript.return_data(purchase)
return cls(0, script) return cls(0, script)
@property @property
def is_purchase_data(self) -> bool: def is_purchase_data(self) -> bool:
return self.script.is_return_data and ( return self.script.is_return_data and (
isinstance(self.script.values['data'], Purchase) or isinstance(self.script.values['data'], Purchase) or
Purchase.has_start_byte(self.script.values['data']) Purchase.has_start_byte(self.script.values['data'])
) )
@property @property
@ -246,16 +474,331 @@ class Output(BaseOutput):
return self.claim.stream.fee return self.claim.stream.fee
class Transaction(BaseTransaction): class Transaction:
input_class = Input def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False,
output_class = Output height: int = -2, position: int = -1) -> None:
self._raw = raw
self._raw_sans_segwit = None
self.is_segwit_flag = 0
self.witnesses: List[bytes] = []
self.ref = TXRefMutable(self)
self.version = version
self.locktime = locktime
self._inputs: List[Input] = []
self._outputs: List[Output] = []
self.is_verified = is_verified
# Height Progression
# -2: not broadcast
# -1: in mempool but has unconfirmed inputs
# 0: in mempool and all inputs confirmed
# +num: confirmed in a specific block (height)
self.height = height
self.position = position
if raw is not None:
self._deserialize()
outputs: ReadOnlyList[Output] @property
inputs: ReadOnlyList[Input] def is_broadcast(self):
return self.height > -2
@property
def is_mempool(self):
return self.height in (-1, 0)
@property
def is_confirmed(self):
return self.height > 0
@property
def id(self):
return self.ref.id
@property
def hash(self):
return self.ref.hash
@property
def raw(self):
if self._raw is None:
self._raw = self._serialize()
return self._raw
@property
def raw_sans_segwit(self):
if self.is_segwit_flag:
if self._raw_sans_segwit is None:
self._raw_sans_segwit = self._serialize(sans_segwit=True)
return self._raw_sans_segwit
return self.raw
def _reset(self):
self._raw = None
self._raw_sans_segwit = None
self.ref.reset()
@property
def inputs(self) -> ReadOnlyList[Input]:
return ReadOnlyList(self._inputs)
@property
def outputs(self) -> ReadOnlyList[Output]:
return ReadOnlyList(self._outputs)
def _add(self, existing_ios: List, new_ios: Iterable[InputOutput], reset=False) -> 'Transaction':
for txio in new_ios:
txio.tx_ref = self.ref
txio.position = len(existing_ios)
existing_ios.append(txio)
if reset:
self._reset()
return self
def add_inputs(self, inputs: Iterable[Input]) -> 'Transaction':
return self._add(self._inputs, inputs, True)
def add_outputs(self, outputs: Iterable[Output]) -> 'Transaction':
return self._add(self._outputs, outputs, True)
@property
def size(self) -> int:
""" Size in bytes of the entire transaction. """
return len(self.raw)
@property
def base_size(self) -> int:
""" Size of transaction without inputs or outputs in bytes. """
return (
self.size
- sum(txi.size for txi in self._inputs)
- sum(txo.size for txo in self._outputs)
)
@property
def input_sum(self):
return sum(i.amount for i in self.inputs if i.txo_ref.txo is not None)
@property
def output_sum(self):
return sum(o.amount for o in self.outputs)
@property
def net_account_balance(self) -> int:
balance = 0
for txi in self.inputs:
if txi.txo_ref.txo is None:
continue
if txi.is_my_account is None:
raise ValueError(
"Cannot access net_account_balance if inputs/outputs do not "
"have is_my_account set properly."
)
if txi.is_my_account:
balance -= txi.amount
for txo in self.outputs:
if txo.is_my_account is None:
raise ValueError(
"Cannot access net_account_balance if inputs/outputs do not "
"have is_my_account set properly."
)
if txo.is_my_account:
balance += txo.amount
return balance
@property
def fee(self) -> int:
return self.input_sum - self.output_sum
def get_base_fee(self, ledger) -> int:
""" Fee for base tx excluding inputs and outputs. """
return self.base_size * ledger.fee_per_byte
def get_effective_input_sum(self, ledger) -> int:
""" Sum of input values *minus* the cost involved to spend them. """
return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs)
def get_total_output_sum(self, ledger) -> int:
""" Sum of output values *plus* the cost involved to spend them. """
return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)
def _serialize(self, with_inputs: bool = True, sans_segwit: bool = False) -> bytes:
stream = BCDataStream()
stream.write_uint32(self.version)
if with_inputs:
stream.write_compact_size(len(self._inputs))
for txin in self._inputs:
txin.serialize_to(stream)
stream.write_compact_size(len(self._outputs))
for txout in self._outputs:
txout.serialize_to(stream)
stream.write_uint32(self.locktime)
return stream.get_bytes()
def _serialize_for_signature(self, signing_input: int) -> bytes:
stream = BCDataStream()
stream.write_uint32(self.version)
stream.write_compact_size(len(self._inputs))
for i, txin in enumerate(self._inputs):
if signing_input == i:
assert txin.txo_ref.txo is not None
txin.serialize_to(stream, txin.txo_ref.txo.script.source)
else:
txin.serialize_to(stream, b'')
stream.write_compact_size(len(self._outputs))
for txout in self._outputs:
txout.serialize_to(stream)
stream.write_uint32(self.locktime)
stream.write_uint32(self.signature_hash_type(1)) # signature hash type: SIGHASH_ALL
return stream.get_bytes()
def _deserialize(self):
if self._raw is not None:
stream = BCDataStream(self._raw)
self.version = stream.read_uint32()
input_count = stream.read_compact_size()
if input_count == 0:
self.is_segwit_flag = stream.read_uint8()
input_count = stream.read_compact_size()
self._add(self._inputs, [
Input.deserialize_from(stream) for _ in range(input_count)
])
output_count = stream.read_compact_size()
self._add(self._outputs, [
Output.deserialize_from(stream) for _ in range(output_count)
])
if self.is_segwit_flag:
# drain witness portion of transaction
# too many witnesses for no crime
self.witnesses = []
for _ in range(input_count):
for _ in range(stream.read_compact_size()):
self.witnesses.append(stream.read(stream.read_compact_size()))
self.locktime = stream.read_uint32()
@classmethod @classmethod
def pay(cls, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account): def ensure_all_have_same_ledger_and_wallet(
cls, funding_accounts: Iterable['Account'],
change_account: 'Account' = None) -> Tuple['Ledger', 'Wallet']:
ledger = wallet = None
for account in funding_accounts:
if ledger is None:
ledger = account.ledger
wallet = account.wallet
if ledger != account.ledger:
raise ValueError(
'All funding accounts used to create a transaction must be on the same ledger.'
)
if wallet != account.wallet:
raise ValueError(
'All funding accounts used to create a transaction must be from the same wallet.'
)
if change_account is not None:
if change_account.ledger != ledger:
raise ValueError('Change account must use same ledger as funding accounts.')
if change_account.wallet != wallet:
raise ValueError('Change account must use same wallet as funding accounts.')
if ledger is None:
raise ValueError('No ledger found.')
if wallet is None:
raise ValueError('No wallet found.')
return ledger, wallet
@classmethod
async def create(cls, inputs: Iterable[Input], outputs: Iterable[Output],
funding_accounts: Iterable['Account'], change_account: 'Account',
sign: bool = True):
""" Find optimal set of inputs when only outputs are provided; add change
outputs if only inputs are provided or if inputs are greater than outputs. """
tx = cls() \
.add_inputs(inputs) \
.add_outputs(outputs)
ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
# value of the outputs plus associated fees
cost = (
tx.get_base_fee(ledger) +
tx.get_total_output_sum(ledger)
)
# value of the inputs less the cost to spend those inputs
payment = tx.get_effective_input_sum(ledger)
try:
for _ in range(5):
if payment < cost:
deficit = cost - payment
spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
if not spendables:
raise InsufficientFundsError()
payment += sum(s.effective_amount for s in spendables)
tx.add_inputs(s.txi for s in spendables)
cost_of_change = (
tx.get_base_fee(ledger) +
Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger)
)
if payment > cost:
change = payment - cost
if change > cost_of_change:
change_address = await change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = change - cost_of_change
change_output = Output.pay_pubkey_hash(change_amount, change_hash160)
change_output.is_change = True
tx.add_outputs([Output.pay_pubkey_hash(change_amount, change_hash160)])
if tx._outputs:
break
# this condition and the outer range(5) loop cover an edge case
# whereby a single input is just enough to cover the fee and
# has some change left over, but the change left over is less
# than the cost_of_change: thus the input is completely
# consumed and no output is added, which is an invalid tx.
# to be able to spend this input we must increase the cost
# of the TX and run through the balance algorithm a second time
# adding an extra input and change output, making tx valid.
# we do this 5 times in case the other UTXOs added are also
# less than the fee, after 5 attempts we give up and go home
cost += cost_of_change + 1
if sign:
await tx.sign(funding_accounts)
except Exception as e:
log.exception('Failed to create transaction:')
await ledger.release_tx(tx)
raise e
return tx
@staticmethod
def signature_hash_type(hash_type):
return hash_type
async def sign(self, funding_accounts: Iterable['Account']):
ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts)
for i, txi in enumerate(self._inputs):
assert txi.script is not None
assert txi.txo_ref.txo is not None
txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash:
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = await ledger.get_private_key_for_address(wallet, address)
assert private_key is not None, 'Cannot find private key for signing output.'
tx = self._serialize_for_signature(i)
txi.script.values['signature'] = \
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
txi.script.generate()
else:
raise NotImplementedError("Don't know how to spend this output.")
self._reset()
@classmethod
def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'):
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address)) output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address))
return cls.create([], [output], funding_accounts, change_account) return cls.create([], [output], funding_accounts, change_account)
@ -263,7 +806,7 @@ class Transaction(BaseTransaction):
@classmethod @classmethod
def claim_create( def claim_create(
cls, name: str, claim: Claim, amount: int, holding_address: str, cls, name: str, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
claim_output = Output.pay_claim_name_pubkey_hash( claim_output = Output.pay_claim_name_pubkey_hash(
amount, name, claim, ledger.address_to_hash160(holding_address) amount, name, claim, ledger.address_to_hash160(holding_address)
@ -275,7 +818,7 @@ class Transaction(BaseTransaction):
@classmethod @classmethod
def claim_update( def claim_update(
cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str, cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
updated_claim = Output.pay_update_claim_pubkey_hash( updated_claim = Output.pay_update_claim_pubkey_hash(
amount, previous_claim.claim_name, previous_claim.claim_id, amount, previous_claim.claim_name, previous_claim.claim_id,
@ -291,7 +834,7 @@ class Transaction(BaseTransaction):
@classmethod @classmethod
def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str, def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account): funding_accounts: List['Account'], change_account: 'Account'):
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
support_output = Output.pay_support_pubkey_hash( support_output = Output.pay_support_pubkey_hash(
amount, claim_name, claim_id, ledger.address_to_hash160(holding_address) amount, claim_name, claim_id, ledger.address_to_hash160(holding_address)
@ -300,7 +843,7 @@ class Transaction(BaseTransaction):
@classmethod @classmethod
def purchase(cls, claim_id: str, amount: int, merchant_address: bytes, def purchase(cls, claim_id: str, amount: int, merchant_address: bytes,
funding_accounts: List[Account], change_account: Account): funding_accounts: List['Account'], change_account: 'Account'):
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
payment = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(merchant_address)) payment = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(merchant_address))
data = Output.add_purchase_data(Purchase(claim_id)) data = Output.add_purchase_data(Purchase(claim_id))

View file

@ -1,6 +1,6 @@
import re import re
from typing import TypeVar, Sequence, Optional from typing import TypeVar, Sequence, Optional
from lbry.wallet.client.constants import COIN from .constants import COIN
def coins_to_satoshis(coins): def coins_to_satoshis(coins):

View file

@ -10,9 +10,11 @@ from collections import UserDict
from hashlib import sha256 from hashlib import sha256
from operator import attrgetter from operator import attrgetter
from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt
from .account import Account
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.wallet.client import basemanager, baseaccount, baseledger from lbry.wallet.manager import WalletManager
from lbry.wallet.ledger import Ledger
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -65,7 +67,7 @@ class Wallet:
preferences: TimestampedPreferences preferences: TimestampedPreferences
encryption_password: Optional[str] encryption_password: Optional[str]
def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None, def __init__(self, name: str = 'Wallet', accounts: MutableSequence['Account'] = None,
storage: 'WalletStorage' = None, preferences: dict = None) -> None: storage: 'WalletStorage' = None, preferences: dict = None) -> None:
self.name = name self.name = name
self.accounts = accounts or [] self.accounts = accounts or []
@ -79,30 +81,30 @@ class Wallet:
return os.path.basename(self.storage.path) return os.path.basename(self.storage.path)
return self.name return self.name
def add_account(self, account: 'baseaccount.BaseAccount'): def add_account(self, account: 'Account'):
self.accounts.append(account) self.accounts.append(account)
def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount': def generate_account(self, ledger: 'Ledger') -> 'Account':
return ledger.account_class.generate(ledger, self) return Account.generate(ledger, self)
@property @property
def default_account(self) -> Optional['baseaccount.BaseAccount']: def default_account(self) -> Optional['Account']:
for account in self.accounts: for account in self.accounts:
return account return account
return None return None
def get_account_or_default(self, account_id: str) -> Optional['baseaccount.BaseAccount']: def get_account_or_default(self, account_id: str) -> Optional['Account']:
if account_id is None: if account_id is None:
return self.default_account return self.default_account
return self.get_account_or_error(account_id) return self.get_account_or_error(account_id)
def get_account_or_error(self, account_id: str) -> 'baseaccount.BaseAccount': def get_account_or_error(self, account_id: str) -> 'Account':
for account in self.accounts: for account in self.accounts:
if account.id == account_id: if account.id == account_id:
return account return account
raise ValueError(f"Couldn't find account: {account_id}.") raise ValueError(f"Couldn't find account: {account_id}.")
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['baseaccount.BaseAccount']: def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['Account']:
return [ return [
self.get_account_or_error(account_id) self.get_account_or_error(account_id)
for account_id in account_ids for account_id in account_ids
@ -117,7 +119,7 @@ class Wallet:
return accounts return accounts
@classmethod @classmethod
def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet': def from_storage(cls, storage: 'WalletStorage', manager: 'WalletManager') -> 'Wallet':
json_dict = storage.read() json_dict = storage.read()
wallet = cls( wallet = cls(
name=json_dict.get('name', 'Wallet'), name=json_dict.get('name', 'Wallet'),
@ -127,7 +129,7 @@ class Wallet:
account_dicts: Sequence[dict] = json_dict.get('accounts', []) account_dicts: Sequence[dict] = json_dict.get('accounts', [])
for account_dict in account_dicts: for account_dict in account_dicts:
ledger = manager.get_or_create_ledger(account_dict['ledger']) ledger = manager.get_or_create_ledger(account_dict['ledger'])
ledger.account_class.from_dict(ledger, wallet, account_dict) Account.from_dict(ledger, wallet, account_dict)
return wallet return wallet
def to_dict(self, encrypt_password: str = None): def to_dict(self, encrypt_password: str = None):
@ -173,15 +175,15 @@ class Wallet:
decompressed = zlib.decompress(decrypted) decompressed = zlib.decompress(decrypted)
return json.loads(decompressed) return json.loads(decompressed)
def merge(self, manager: 'basemanager.BaseWalletManager', def merge(self, manager: 'WalletManager',
password: str, data: str) -> List['baseaccount.BaseAccount']: password: str, data: str) -> List['Account']:
assert not self.is_locked, "Cannot sync apply on a locked wallet." assert not self.is_locked, "Cannot sync apply on a locked wallet."
added_accounts = [] added_accounts = []
decrypted_data = self.unpack(password, data) decrypted_data = self.unpack(password, data)
self.preferences.merge(decrypted_data.get('preferences', {})) self.preferences.merge(decrypted_data.get('preferences', {}))
for account_dict in decrypted_data['accounts']: for account_dict in decrypted_data['accounts']:
ledger = manager.get_or_create_ledger(account_dict['ledger']) ledger = manager.get_or_create_ledger(account_dict['ledger'])
_, _, pubkey = ledger.account_class.keys_from_dict(ledger, account_dict) _, _, pubkey = Account.keys_from_dict(ledger, account_dict)
account_id = pubkey.address account_id = pubkey.address
local_match = None local_match = None
for local_account in self.accounts: for local_account in self.accounts:
@ -191,7 +193,7 @@ class Wallet:
if local_match is not None: if local_match is not None:
local_match.merge(account_dict) local_match.merge(account_dict)
else: else:
new_account = ledger.account_class.from_dict(ledger, self, account_dict) new_account = Account.from_dict(ledger, self, account_dict)
added_accounts.append(new_account) added_accounts.append(new_account)
return added_accounts return added_accounts

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
import json import json
from lbry.wallet.client.wallet import ENCRYPT_ON_DISK from lbry.wallet import ENCRYPT_ON_DISK
from lbry.error import InvalidPasswordError from lbry.error import InvalidPasswordError
from lbry.testcase import CommandTestCase from lbry.testcase import CommandTestCase
from lbry.wallet.dewies import dict_values_to_lbc from lbry.wallet.dewies import dict_values_to_lbc

View file

@ -2,7 +2,6 @@ import unittest
from unittest import mock from unittest import mock
import json import json
import lbry.wallet
from lbry.conf import Config from lbry.conf import Config
from lbry.extras.daemon.storage import SQLiteStorage from lbry.extras.daemon.storage import SQLiteStorage
from lbry.extras.daemon.ComponentManager import ComponentManager from lbry.extras.daemon.ComponentManager import ComponentManager
@ -11,8 +10,7 @@ from lbry.extras.daemon.Components import HASH_ANNOUNCER_COMPONENT
from lbry.extras.daemon.Components import UPNP_COMPONENT, BLOB_COMPONENT from lbry.extras.daemon.Components import UPNP_COMPONENT, BLOB_COMPONENT
from lbry.extras.daemon.Components import PEER_PROTOCOL_SERVER_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT from lbry.extras.daemon.Components import PEER_PROTOCOL_SERVER_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT
from lbry.extras.daemon.Daemon import Daemon as LBRYDaemon from lbry.extras.daemon.Daemon import Daemon as LBRYDaemon
from lbry.wallet import LbryWalletManager from lbry.wallet import WalletManager, Wallet
from lbry.wallet.client.wallet import Wallet
from tests import test_utils from tests import test_utils
# from tests.mocks import mock_conf_settings, FakeNetwork, FakeFileManager # from tests.mocks import mock_conf_settings, FakeNetwork, FakeFileManager
@ -37,7 +35,7 @@ def get_test_daemon(conf: Config, with_fee=False):
) )
daemon = LBRYDaemon(conf, component_manager=component_manager) daemon = LBRYDaemon(conf, component_manager=component_manager)
daemon.payment_rate_manager = OnlyFreePaymentsManager() daemon.payment_rate_manager = OnlyFreePaymentsManager()
daemon.wallet_manager = mock.Mock(spec=LbryWalletManager) daemon.wallet_manager = mock.Mock(spec=WalletManager)
daemon.wallet_manager.wallet = mock.Mock(spec=Wallet) daemon.wallet_manager.wallet = mock.Mock(spec=Wallet)
daemon.wallet_manager.use_encryption = False daemon.wallet_manager.use_encryption = False
daemon.wallet_manager.network = FakeNetwork() daemon.wallet_manager.network = FakeNetwork()

View file

@ -10,13 +10,10 @@ from lbry.testcase import get_fake_exchange_rate_manager
from lbry.utils import generate_id from lbry.utils import generate_id
from lbry.error import InsufficientFundsError from lbry.error import InsufficientFundsError
from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError
from lbry.wallet.client.wallet import Wallet from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output, Database
from lbry.wallet.client.constants import CENT, NULL_HASH32 from lbry.wallet.client.constants import CENT, NULL_HASH32
from lbry.wallet.client.basenetwork import ClientSession from lbry.wallet.client.basenetwork import ClientSession
from lbry.conf import Config from lbry.conf import Config
from lbry.wallet.ledger import MainNetLedger
from lbry.wallet.transaction import Transaction, Input, Output
from lbry.wallet.manager import LbryWalletManager
from lbry.extras.daemon.analytics import AnalyticsManager from lbry.extras.daemon.analytics import AnalyticsManager
from lbry.stream.stream_manager import StreamManager from lbry.stream.stream_manager import StreamManager
from lbry.stream.descriptor import StreamDescriptor from lbry.stream.descriptor import StreamDescriptor
@ -94,16 +91,16 @@ async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None):
return {'timestamp': 1984} return {'timestamp': 1984}
wallet = Wallet() wallet = Wallet()
ledger = MainNetLedger({ ledger = Ledger({
'db': MainNetLedger.database_class(':memory:'), 'db': Database(':memory:'),
'headers': FakeHeaders(514082) 'headers': FakeHeaders(514082)
}) })
await ledger.db.open() await ledger.db.open()
wallet.generate_account(ledger) wallet.generate_account(ledger)
manager = LbryWalletManager() manager = WalletManager()
manager.config = Config() manager.config = Config()
manager.wallets.append(wallet) manager.wallets.append(wallet)
manager.ledgers[MainNetLedger] = ledger manager.ledgers[Ledger] = ledger
manager.ledger.network.client = ClientSession( manager.ledger.network.client = ClientSession(
network=manager.ledger.network, server=('fakespv.lbry.com', 50001) network=manager.ledger.network, server=('fakespv.lbry.com', 50001)
) )

View file

@ -1,17 +1,13 @@
from binascii import hexlify from binascii import hexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.client.wallet import Wallet from lbry.wallet import Wallet, Ledger, Database, Headers, Account, SingleKey, HierarchicalDeterministic
from lbry.wallet.ledger import MainNetLedger, WalletDatabase
from lbry.wallet.header import Headers
from lbry.wallet.account import Account
from lbry.wallet.client.baseaccount import SingleKey, HierarchicalDeterministic
class TestAccount(AsyncioTestCase): class TestAccount(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': WalletDatabase(':memory:'), 'db': Database(':memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()
@ -236,8 +232,8 @@ class TestAccount(AsyncioTestCase):
class TestSingleKeyAccount(AsyncioTestCase): class TestSingleKeyAccount(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': WalletDatabase(':memory:'), 'db': Database(':memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()
@ -327,7 +323,7 @@ class TestSingleKeyAccount(AsyncioTestCase):
self.assertEqual(len(keys), 1) self.assertEqual(len(keys), 1)
async def test_generate_account_from_seed(self): async def test_generate_account_from_seed(self):
account = self.ledger.account_class.from_dict( account = Account.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "seed":
"carbon smart garage balance margin twelve chest sword toas" "carbon smart garage balance margin twelve chest sword toas"
@ -432,8 +428,8 @@ class AccountEncryptionTests(AsyncioTestCase):
} }
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': WalletDatabase(':memory:'), 'db': Database(':memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
@ -489,7 +485,7 @@ class AccountEncryptionTests(AsyncioTestCase):
account_data = self.unencrypted_account.copy() account_data = self.unencrypted_account.copy()
del account_data['seed'] del account_data['seed']
del account_data['private_key'] del account_data['private_key']
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data) account = Account.from_dict(self.ledger, Wallet(), account_data)
encrypted = account.to_dict('password') encrypted = account.to_dict('password')
self.assertFalse(encrypted['seed']) self.assertFalse(encrypted['seed'])
self.assertFalse(encrypted['private_key']) self.assertFalse(encrypted['private_key'])

View file

@ -1,6 +1,6 @@
import unittest import unittest
from lbry.wallet.client.bcd_data_stream import BCDataStream from lbry.wallet.bcd_data_stream import BCDataStream
class TestBCDataStream(unittest.TestCase): class TestBCDataStream(unittest.TestCase):

View file

@ -1,10 +1,10 @@
from binascii import unhexlify, hexlify from binascii import unhexlify, hexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.client.bip32 import PubKey, PrivateKey, from_extended_key_string
from lbry.wallet import Ledger, Database, Headers
from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
from lbry.wallet.client.bip32 import PubKey, PrivateKey, from_extended_key_string
from lbry.wallet import MainNetLedger as ledger_class
class BIP32Tests(AsyncioTestCase): class BIP32Tests(AsyncioTestCase):
@ -46,9 +46,9 @@ class BIP32Tests(AsyncioTestCase):
with self.assertRaisesRegex(ValueError, 'private key must be 32 bytes'): with self.assertRaisesRegex(ValueError, 'private key must be 32 bytes'):
PrivateKey(None, b'abcd', b'abcd'*8, 0, 255) PrivateKey(None, b'abcd', b'abcd'*8, 0, 255)
private_key = PrivateKey( private_key = PrivateKey(
ledger_class({ Ledger({
'db': ledger_class.database_class(':memory:'), 'db': Database(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': Headers(':memory:'),
}), }),
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
b'abcd'*8, 0, 1 b'abcd'*8, 0, 1
@ -67,9 +67,9 @@ class BIP32Tests(AsyncioTestCase):
async def test_private_key_derivation(self): async def test_private_key_derivation(self):
private_key = PrivateKey( private_key = PrivateKey(
ledger_class({ Ledger({
'db': ledger_class.database_class(':memory:'), 'db': Database(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': Headers(':memory:'),
}), }),
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
b'abcd'*8, 0, 1 b'abcd'*8, 0, 1
@ -84,9 +84,9 @@ class BIP32Tests(AsyncioTestCase):
self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED]) self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED])
async def test_from_extended_keys(self): async def test_from_extended_keys(self):
ledger = ledger_class({ ledger = Ledger({
'db': ledger_class.database_class(':memory:'), 'db': Database(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': Headers(':memory:'),
}) })
self.assertIsInstance( self.assertIsInstance(
from_extended_key_string( from_extended_key_string(

View file

@ -2,9 +2,9 @@ from types import GeneratorType
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet import MainNetLedger as ledger_class from lbry.wallet import Ledger, Database, Headers
from lbry.wallet.client.coinselection import CoinSelector, MAXIMUM_TRIES from lbry.wallet.client.coinselection import CoinSelector, MAXIMUM_TRIES
from lbry.wallet.client.constants import CENT from lbry.constants import CENT
from tests.unit.wallet.test_transaction import get_output as utxo from tests.unit.wallet.test_transaction import get_output as utxo
@ -20,9 +20,9 @@ def search(*args, **kwargs):
class BaseSelectionTestCase(AsyncioTestCase): class BaseSelectionTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = ledger_class({ self.ledger = Ledger({
'db': ledger_class.database_class(':memory:'), 'db': Database(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': Headers(':memory:'),
}) })
await self.ledger.db.open() await self.ledger.db.open()

View file

@ -6,11 +6,11 @@ import tempfile
import asyncio import asyncio
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from lbry.wallet import MainNetLedger from lbry.wallet import (
from lbry.wallet.transaction import Transaction Wallet, Account, Ledger, Database, Headers, Transaction, Input
from lbry.wallet.client.wallet import Wallet )
from lbry.wallet.client.constants import COIN from lbry.wallet.client.constants import COIN
from lbry.wallet.client.basedatabase import query, interpolate, constraints_to_sql, AIOSQLite from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite
from lbry.crypto.hash import sha256 from lbry.crypto.hash import sha256
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
@ -195,9 +195,9 @@ class TestQueryBuilder(unittest.TestCase):
class TestQueries(AsyncioTestCase): class TestQueries(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': MainNetLedger.database_class(':memory:'), 'db': Database(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': Headers(':memory:')
}) })
self.wallet = Wallet() self.wallet = Wallet()
await self.ledger.db.open() await self.ledger.db.open()
@ -206,13 +206,13 @@ class TestQueries(AsyncioTestCase):
await self.ledger.db.close() await self.ledger.db.close()
async def create_account(self, wallet=None): async def create_account(self, wallet=None):
account = self.ledger.account_class.generate(self.ledger, wallet or self.wallet) account = Account.generate(self.ledger, wallet or self.wallet)
await account.ensure_address_gap() await account.ensure_address_gap()
return account return account
async def create_tx_from_nothing(self, my_account, height): async def create_tx_from_nothing(self, my_account, height):
to_address = await my_account.receiving.get_or_create_usable_address() to_address = await my_account.receiving.get_or_create_usable_address()
to_hash = MainNetLedger.address_to_hash160(to_address) to_hash = Ledger.address_to_hash160(to_address)
tx = Transaction(height=height, is_verified=True) \ tx = Transaction(height=height, is_verified=True) \
.add_inputs([self.txi(self.txo(1, sha256(str(height).encode())))]) \ .add_inputs([self.txi(self.txo(1, sha256(str(height).encode())))]) \
.add_outputs([self.txo(1, to_hash)]) .add_outputs([self.txo(1, to_hash)])
@ -224,7 +224,7 @@ class TestQueries(AsyncioTestCase):
from_hash = txo.script.values['pubkey_hash'] from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.hash160_to_address(from_hash) from_address = self.ledger.hash160_to_address(from_hash)
to_address = await to_account.receiving.get_or_create_usable_address() to_address = await to_account.receiving.get_or_create_usable_address()
to_hash = MainNetLedger.address_to_hash160(to_address) to_hash = Ledger.address_to_hash160(to_address)
tx = Transaction(height=height, is_verified=True) \ tx = Transaction(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \ .add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)]) .add_outputs([self.txo(1, to_hash)])
@ -248,7 +248,7 @@ class TestQueries(AsyncioTestCase):
return get_output(int(amount*COIN), address) return get_output(int(amount*COIN), address)
def txi(self, txo): def txi(self, txo):
return Transaction.input_class.spend(txo) return Input.spend(txo)
async def test_large_tx_doesnt_hit_variable_limits(self): async def test_large_tx_doesnt_hit_variable_limits(self):
# SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html # SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html
@ -408,9 +408,9 @@ class TestUpgrade(AsyncioTestCase):
return [col[0] for col in conn.execute(sql).fetchall()] return [col[0] for col in conn.execute(sql).fetchall()]
async def test_reset_on_version_change(self): async def test_reset_on_version_change(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': MainNetLedger.database_class(self.path), 'db': Database(self.path),
'headers': MainNetLedger.headers_class(':memory:') 'headers': Headers(':memory:')
}) })
# initial open, pre-version enabled db # initial open, pre-version enabled db

View file

@ -2,10 +2,7 @@ import os
from binascii import hexlify from binascii import hexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.client.wallet import Wallet from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Database, Headers
from lbry.wallet.account import Account
from lbry.wallet.transaction import Transaction, Output, Input
from lbry.wallet.ledger import MainNetLedger
from tests.unit.wallet.test_transaction import get_transaction, get_output from tests.unit.wallet.test_transaction import get_transaction, get_output
from tests.unit.wallet.test_headers import HEADERS, block_bytes from tests.unit.wallet.test_headers import HEADERS, block_bytes
@ -40,9 +37,9 @@ class MockNetwork:
class LedgerTestCase(AsyncioTestCase): class LedgerTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': MainNetLedger.database_class(':memory:'), 'db': Database(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': Headers(':memory:')
}) })
self.account = Account.generate(self.ledger, Wallet(), "lbryum") self.account = Account.generate(self.ledger, Wallet(), "lbryum")
await self.ledger.db.open() await self.ledger.db.open()
@ -76,7 +73,7 @@ class LedgerTestCase(AsyncioTestCase):
class TestSynchronization(LedgerTestCase): class TestSynchronization(LedgerTestCase):
async def test_update_history(self): async def test_update_history(self):
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba") account = Account.generate(self.ledger, Wallet(), "torba")
address = await account.receiving.get_or_create_usable_address() address = await account.receiving.get_or_create_usable_address()
address_details = await self.ledger.db.get_address(address=address) address_details = await self.ledger.db.get_address(address=address)
self.assertIsNone(address_details['history']) self.assertIsNone(address_details['history'])

View file

@ -3,9 +3,7 @@ from binascii import unhexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.client.constants import CENT, NULL_HASH32 from lbry.wallet.client.constants import CENT, NULL_HASH32
from lbry.wallet.ledger import MainNetLedger from lbry.wallet import Ledger, Database, Headers, Transaction, Input, Output
from lbry.wallet.transaction import Transaction, Input, Output
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
@ -110,9 +108,9 @@ class TestValidatingOldSignatures(AsyncioTestCase):
)) ))
channel = channel_tx.outputs[0] channel = channel_tx.outputs[0]
ledger = MainNetLedger({ ledger = Ledger({
'db': MainNetLedger.database_class(':memory:'), 'db': Database(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': Headers(':memory:')
}) })
self.assertTrue(stream.is_signed_by(channel, ledger)) self.assertTrue(stream.is_signed_by(channel, ledger))

View file

@ -1,11 +1,11 @@
from lbry.wallet.script import OutputScript
import unittest import unittest
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from lbry.wallet.client.bcd_data_stream import BCDataStream from lbry.wallet.bcd_data_stream import BCDataStream
from lbry.wallet.client.basescript import Template, ParseError, tokenize, push_data from lbry.wallet.script import (
from lbry.wallet.client.basescript import PUSH_SINGLE, PUSH_INTEGER, PUSH_MANY, OP_HASH160, OP_EQUAL InputScript, OutputScript, Template, ParseError, tokenize, push_data,
from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript PUSH_SINGLE, PUSH_INTEGER, PUSH_MANY, OP_HASH160, OP_EQUAL
)
def parse(opcodes, source): def parse(opcodes, source):
@ -102,12 +102,12 @@ class TestRedeemPubKeyHash(unittest.TestCase):
def redeem_pubkey_hash(self, sig, pubkey): def redeem_pubkey_hash(self, sig, pubkey):
# this checks that factory function correctly sets up the script # this checks that factory function correctly sets up the script
src1 = BaseInputScript.redeem_pubkey_hash(unhexlify(sig), unhexlify(pubkey)) src1 = InputScript.redeem_pubkey_hash(unhexlify(sig), unhexlify(pubkey))
self.assertEqual(src1.template.name, 'pubkey_hash') self.assertEqual(src1.template.name, 'pubkey_hash')
self.assertEqual(hexlify(src1.values['signature']), sig) self.assertEqual(hexlify(src1.values['signature']), sig)
self.assertEqual(hexlify(src1.values['pubkey']), pubkey) self.assertEqual(hexlify(src1.values['pubkey']), pubkey)
# now we test that it will round trip # now we test that it will round trip
src2 = BaseInputScript(src1.source) src2 = InputScript(src1.source)
self.assertEqual(src2.template.name, 'pubkey_hash') self.assertEqual(src2.template.name, 'pubkey_hash')
self.assertEqual(hexlify(src2.values['signature']), sig) self.assertEqual(hexlify(src2.values['signature']), sig)
self.assertEqual(hexlify(src2.values['pubkey']), pubkey) self.assertEqual(hexlify(src2.values['pubkey']), pubkey)
@ -130,7 +130,7 @@ class TestRedeemScriptHash(unittest.TestCase):
def redeem_script_hash(self, sigs, pubkeys): def redeem_script_hash(self, sigs, pubkeys):
# this checks that factory function correctly sets up the script # this checks that factory function correctly sets up the script
src1 = BaseInputScript.redeem_script_hash( src1 = InputScript.redeem_script_hash(
[unhexlify(sig) for sig in sigs], [unhexlify(sig) for sig in sigs],
[unhexlify(pubkey) for pubkey in pubkeys] [unhexlify(pubkey) for pubkey in pubkeys]
) )
@ -141,7 +141,7 @@ class TestRedeemScriptHash(unittest.TestCase):
self.assertEqual(subscript1.values['signatures_count'], len(sigs)) self.assertEqual(subscript1.values['signatures_count'], len(sigs))
self.assertEqual(subscript1.values['pubkeys_count'], len(pubkeys)) self.assertEqual(subscript1.values['pubkeys_count'], len(pubkeys))
# now we test that it will round trip # now we test that it will round trip
src2 = BaseInputScript(src1.source) src2 = InputScript(src1.source)
subscript2 = src2.values['script'] subscript2 = src2.values['script']
self.assertEqual(src2.template.name, 'script_hash') self.assertEqual(src2.template.name, 'script_hash')
self.assertListEqual([hexlify(v) for v in src2.values['signatures']], sigs) self.assertListEqual([hexlify(v) for v in src2.values['signatures']], sigs)
@ -183,11 +183,11 @@ class TestPayPubKeyHash(unittest.TestCase):
def pay_pubkey_hash(self, pubkey_hash): def pay_pubkey_hash(self, pubkey_hash):
# this checks that factory function correctly sets up the script # this checks that factory function correctly sets up the script
src1 = BaseOutputScript.pay_pubkey_hash(unhexlify(pubkey_hash)) src1 = OutputScript.pay_pubkey_hash(unhexlify(pubkey_hash))
self.assertEqual(src1.template.name, 'pay_pubkey_hash') self.assertEqual(src1.template.name, 'pay_pubkey_hash')
self.assertEqual(hexlify(src1.values['pubkey_hash']), pubkey_hash) self.assertEqual(hexlify(src1.values['pubkey_hash']), pubkey_hash)
# now we test that it will round trip # now we test that it will round trip
src2 = BaseOutputScript(src1.source) src2 = OutputScript(src1.source)
self.assertEqual(src2.template.name, 'pay_pubkey_hash') self.assertEqual(src2.template.name, 'pay_pubkey_hash')
self.assertEqual(hexlify(src2.values['pubkey_hash']), pubkey_hash) self.assertEqual(hexlify(src2.values['pubkey_hash']), pubkey_hash)
return hexlify(src1.source) return hexlify(src1.source)
@ -203,11 +203,11 @@ class TestPayScriptHash(unittest.TestCase):
def pay_script_hash(self, script_hash): def pay_script_hash(self, script_hash):
# this checks that factory function correctly sets up the script # this checks that factory function correctly sets up the script
src1 = BaseOutputScript.pay_script_hash(unhexlify(script_hash)) src1 = OutputScript.pay_script_hash(unhexlify(script_hash))
self.assertEqual(src1.template.name, 'pay_script_hash') self.assertEqual(src1.template.name, 'pay_script_hash')
self.assertEqual(hexlify(src1.values['script_hash']), script_hash) self.assertEqual(hexlify(src1.values['script_hash']), script_hash)
# now we test that it will round trip # now we test that it will round trip
src2 = BaseOutputScript(src1.source) src2 = OutputScript(src1.source)
self.assertEqual(src2.template.name, 'pay_script_hash') self.assertEqual(src2.template.name, 'pay_script_hash')
self.assertEqual(hexlify(src2.values['script_hash']), script_hash) self.assertEqual(hexlify(src2.values['script_hash']), script_hash)
return hexlify(src1.source) return hexlify(src1.source)

View file

@ -4,10 +4,7 @@ from itertools import cycle
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.client.constants import CENT, COIN, NULL_HASH32 from lbry.wallet.client.constants import CENT, COIN, NULL_HASH32
from lbry.wallet.client.wallet import Wallet from lbry.wallet import Wallet, Account, Ledger, Database, Headers, Transaction, Output, Input
from lbry.wallet.ledger import MainNetLedger
from lbry.wallet.transaction import Transaction, Output, Input
NULL_HASH = b'\x00'*32 NULL_HASH = b'\x00'*32
@ -40,9 +37,9 @@ def get_claim_transaction(claim_name, claim=b''):
class TestSizeAndFeeEstimation(AsyncioTestCase): class TestSizeAndFeeEstimation(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': MainNetLedger.database_class(':memory:'), 'db': Database(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()
@ -266,9 +263,9 @@ class TestTransactionSerialization(unittest.TestCase):
class TestTransactionSigning(AsyncioTestCase): class TestTransactionSigning(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': MainNetLedger.database_class(':memory:'), 'db': Database(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()
@ -276,7 +273,7 @@ class TestTransactionSigning(AsyncioTestCase):
await self.ledger.db.close() await self.ledger.db.close()
async def test_sign(self): async def test_sign(self):
account = self.ledger.account_class.from_dict( account = Account.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "seed":
"carbon smart garage balance margin twelve chest sword toas" "carbon smart garage balance margin twelve chest sword toas"
@ -305,12 +302,12 @@ class TestTransactionSigning(AsyncioTestCase):
class TransactionIOBalancing(AsyncioTestCase): class TransactionIOBalancing(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = Ledger({
'db': MainNetLedger.database_class(':memory:'), 'db': Database(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()
self.account = self.ledger.account_class.from_dict( self.account = Account.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "carbon smart garage balance margin twelve chest sword " "seed": "carbon smart garage balance margin twelve chest sword "
"toast envelope bottom stomach absent" "toast envelope bottom stomach absent"
@ -328,7 +325,7 @@ class TransactionIOBalancing(AsyncioTestCase):
return get_output(int(amount*COIN), address or next(self.hash_cycler)) return get_output(int(amount*COIN), address or next(self.hash_cycler))
def txi(self, txo): def txi(self, txo):
return Transaction.input_class.spend(txo) return Input.spend(txo)
def tx(self, inputs, outputs): def tx(self, inputs, outputs):
return Transaction.create(inputs, outputs, [self.account], self.account) return Transaction.create(inputs, outputs, [self.account], self.account)

View file

@ -3,18 +3,18 @@ from binascii import hexlify
from unittest import TestCase, mock from unittest import TestCase, mock
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet import (
from lbry.wallet.ledger import MainNetLedger, RegTestLedger Ledger, RegTestLedger, WalletManager, Account,
from lbry.wallet.client.basemanager import BaseWalletManager Wallet, WalletStorage, TimestampedPreferences
from lbry.wallet.client.wallet import Wallet, WalletStorage, TimestampedPreferences )
class TestWalletCreation(AsyncioTestCase): class TestWalletCreation(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.manager = BaseWalletManager() self.manager = WalletManager()
config = {'data_path': '/tmp/wallet'} config = {'data_path': '/tmp/wallet'}
self.main_ledger = self.manager.get_or_create_ledger(MainNetLedger.get_id(), config) self.main_ledger = self.manager.get_or_create_ledger(Ledger.get_id(), config)
self.test_ledger = self.manager.get_or_create_ledger(RegTestLedger.get_id(), config) self.test_ledger = self.manager.get_or_create_ledger(RegTestLedger.get_id(), config)
def test_create_wallet_and_accounts(self): def test_create_wallet_and_accounts(self):
@ -66,7 +66,7 @@ class TestWalletCreation(AsyncioTestCase):
) )
self.assertEqual(len(wallet.accounts), 1) self.assertEqual(len(wallet.accounts), 1)
account = wallet.default_account account = wallet.default_account
self.assertIsInstance(account, MainNetLedger.account_class) self.assertIsInstance(account, Account)
self.maxDiff = None self.maxDiff = None
self.assertDictEqual(wallet_dict, wallet.to_dict()) self.assertDictEqual(wallet_dict, wallet.to_dict())
@ -75,9 +75,9 @@ class TestWalletCreation(AsyncioTestCase):
self.assertEqual(decrypted['accounts'][0]['name'], 'An Account') self.assertEqual(decrypted['accounts'][0]['name'], 'An Account')
def test_read_write(self): def test_read_write(self):
manager = BaseWalletManager() manager = WalletManager()
config = {'data_path': '/tmp/wallet'} config = {'data_path': '/tmp/wallet'}
ledger = manager.get_or_create_ledger(MainNetLedger.get_id(), config) ledger = manager.get_or_create_ledger(Ledger.get_id(), config)
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file: with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
wallet_file.write(b'{"version": 1}') wallet_file.write(b'{"version": 1}')