update imports and more merging
This commit is contained in:
parent
c9e410a6f4
commit
fb1af9e3d2
44 changed files with 3667 additions and 470 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -10,7 +10,7 @@ lbry.egg-info
|
|||
__pycache__
|
||||
_trial_temp/
|
||||
|
||||
/tests/integration/files
|
||||
/tests/integration/blockchain/files
|
||||
/tests/.coverage.*
|
||||
|
||||
/lbry/wallet/bin
|
||||
|
|
|
@ -9,7 +9,7 @@ from contextlib import contextmanager
|
|||
from appdirs import user_data_dir, user_config_dir
|
||||
from lbry.error import InvalidCurrencyError
|
||||
from lbry.dht import constants
|
||||
from lbry.wallet.client.coinselection import STRATEGIES
|
||||
from lbry.wallet.coinselection import STRATEGIES
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from lbry.stream.stream_manager import StreamManager
|
|||
from lbry.extras.daemon.Component import Component
|
||||
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
|
||||
from lbry.extras.daemon.storage import SQLiteStorage
|
||||
from lbry.wallet import LbryWalletManager
|
||||
from lbry.wallet import WalletManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -17,8 +17,11 @@ from traceback import format_exc
|
|||
from aiohttp import web
|
||||
from functools import wraps, partial
|
||||
from google.protobuf.message import DecodeError
|
||||
from lbry.wallet.client.wallet import Wallet, ENCRYPT_ON_DISK
|
||||
from lbry.wallet.client.baseaccount import SingleKey, HierarchicalDeterministic
|
||||
from lbry.wallet import (
|
||||
Wallet, WalletManager, ENCRYPT_ON_DISK, SingleKey, HierarchicalDeterministic,
|
||||
Ledger, Transaction, Output, Input, Account
|
||||
)
|
||||
from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies, dict_values_to_lbc
|
||||
|
||||
from lbry import utils
|
||||
from lbry.conf import Config, Setting, NOT_SET
|
||||
|
@ -39,9 +42,6 @@ from lbry.extras.daemon.ComponentManager import ComponentManager
|
|||
from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder
|
||||
from lbry.extras.daemon import comment_client
|
||||
from lbry.extras.daemon.undecorated import undecorated
|
||||
from lbry.wallet.transaction import Transaction, Output, Input
|
||||
from lbry.wallet.account import Account as LBCAccount
|
||||
from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies, dict_values_to_lbc
|
||||
from lbry.schema.claim import Claim
|
||||
from lbry.schema.url import URL
|
||||
|
||||
|
@ -51,8 +51,6 @@ if typing.TYPE_CHECKING:
|
|||
from lbry.extras.daemon.Components import UPnPComponent
|
||||
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
|
||||
from lbry.extras.daemon.storage import SQLiteStorage
|
||||
from lbry.wallet.manager import LbryWalletManager
|
||||
from lbry.wallet.ledger import MainNetLedger
|
||||
from lbry.stream.stream_manager import StreamManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -322,7 +320,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
return self.component_manager.get_component(DHT_COMPONENT)
|
||||
|
||||
@property
|
||||
def wallet_manager(self) -> typing.Optional['LbryWalletManager']:
|
||||
def wallet_manager(self) -> typing.Optional['WalletManager']:
|
||||
return self.component_manager.get_component(WALLET_COMPONENT)
|
||||
|
||||
@property
|
||||
|
@ -676,7 +674,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
return None, None
|
||||
|
||||
@property
|
||||
def ledger(self) -> Optional['MainNetLedger']:
|
||||
def ledger(self) -> Optional['Ledger']:
|
||||
try:
|
||||
return self.wallet_manager.default_account.ledger
|
||||
except AttributeError:
|
||||
|
@ -1161,7 +1159,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
|
||||
wallet = self.wallet_manager.import_wallet(wallet_path)
|
||||
if not wallet.accounts and create_account:
|
||||
account = LBCAccount.generate(
|
||||
account = Account.generate(
|
||||
self.ledger, wallet, address_generator={
|
||||
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
|
||||
}
|
||||
|
@ -1464,7 +1462,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
Returns: {Account}
|
||||
"""
|
||||
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
|
||||
account = LBCAccount.from_dict(
|
||||
account = Account.from_dict(
|
||||
self.ledger, wallet, {
|
||||
'name': account_name,
|
||||
'seed': seed,
|
||||
|
@ -1498,7 +1496,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
Returns: {Account}
|
||||
"""
|
||||
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
|
||||
account = LBCAccount.generate(
|
||||
account = Account.generate(
|
||||
self.ledger, wallet, account_name, {
|
||||
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
|
||||
}
|
||||
|
@ -2134,7 +2132,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
"""
|
||||
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
|
||||
if account_id:
|
||||
account: LBCAccount = wallet.get_account_or_error(account_id)
|
||||
account = wallet.get_account_or_error(account_id)
|
||||
claims = account.get_claims
|
||||
claim_count = account.get_claim_count
|
||||
else:
|
||||
|
@ -2657,7 +2655,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
"""
|
||||
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
|
||||
if account_id:
|
||||
account: LBCAccount = wallet.get_account_or_error(account_id)
|
||||
account = wallet.get_account_or_error(account_id)
|
||||
channels = account.get_channels
|
||||
channel_count = account.get_channel_count
|
||||
else:
|
||||
|
@ -2732,7 +2730,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
if channels and channels[0].get_address(self.ledger) != holding_address:
|
||||
holding_address = channels[0].get_address(self.ledger)
|
||||
|
||||
account: LBCAccount = await self.ledger.get_account_for_address(wallet, holding_address)
|
||||
account = await self.ledger.get_account_for_address(wallet, holding_address)
|
||||
if account:
|
||||
# Case 1: channel holding address is in one of the accounts we already have
|
||||
# simply add the certificate to existing account
|
||||
|
@ -2741,7 +2739,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
# Case 2: channel holding address hasn't changed and thus is in the bundled read-only account
|
||||
# create a single-address holding account to manage the channel
|
||||
if holding_address == data['holding_address']:
|
||||
account = LBCAccount.from_dict(self.ledger, wallet, {
|
||||
account = Account.from_dict(self.ledger, wallet, {
|
||||
'name': f"Holding Account For Channel {data['name']}",
|
||||
'public_key': data['holding_public_key'],
|
||||
'address_generator': {'name': 'single-address'}
|
||||
|
@ -3384,7 +3382,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
"""
|
||||
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
|
||||
if account_id:
|
||||
account: LBCAccount = wallet.get_account_or_error(account_id)
|
||||
account = wallet.get_account_or_error(account_id)
|
||||
streams = account.get_streams
|
||||
stream_count = account.get_stream_count
|
||||
else:
|
||||
|
@ -3727,7 +3725,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
"""
|
||||
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
|
||||
if account_id:
|
||||
account: LBCAccount = wallet.get_account_or_error(account_id)
|
||||
account = wallet.get_account_or_error(account_id)
|
||||
collections = account.get_collections
|
||||
collection_count = account.get_collection_count
|
||||
else:
|
||||
|
@ -3854,7 +3852,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
"""
|
||||
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
|
||||
if account_id:
|
||||
account: LBCAccount = wallet.get_account_or_error(account_id)
|
||||
account = wallet.get_account_or_error(account_id)
|
||||
supports = account.get_supports
|
||||
support_count = account.get_support_count
|
||||
else:
|
||||
|
@ -4002,7 +4000,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
"""
|
||||
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
|
||||
if account_id:
|
||||
account: LBCAccount = wallet.get_account_or_error(account_id)
|
||||
account = wallet.get_account_or_error(account_id)
|
||||
transactions = account.get_transaction_history
|
||||
transaction_count = account.get_transaction_history_count
|
||||
else:
|
||||
|
@ -4696,7 +4694,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
if 'fee_currency' in kwargs or 'fee_amount' in kwargs:
|
||||
return claim_address
|
||||
|
||||
async def get_receiving_address(self, address: str, account: Optional[LBCAccount]) -> str:
|
||||
async def get_receiving_address(self, address: str, account: Optional[Account]) -> str:
|
||||
if address is None and account is not None:
|
||||
return await account.receiving.get_or_create_usable_address()
|
||||
self.valid_address_or_error(address)
|
||||
|
|
|
@ -6,11 +6,9 @@ from json import JSONEncoder
|
|||
|
||||
from google.protobuf.message import DecodeError
|
||||
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
from lbry.wallet.client.bip32 import PubKey
|
||||
from lbry.schema.claim import Claim
|
||||
from lbry.wallet.ledger import MainNetLedger, Account
|
||||
from lbry.wallet.transaction import Transaction, Output
|
||||
from lbry.wallet import Wallet, Ledger, Account, Transaction, Output
|
||||
from lbry.wallet.bip32 import PubKey
|
||||
from lbry.wallet.dewies import dewies_to_lbc
|
||||
from lbry.stream.managed_stream import ManagedStream
|
||||
|
||||
|
@ -114,7 +112,7 @@ def encode_file_doc():
|
|||
|
||||
class JSONResponseEncoder(JSONEncoder):
|
||||
|
||||
def __init__(self, *args, ledger: MainNetLedger, include_protobuf=False, **kwargs):
|
||||
def __init__(self, *args, ledger: Ledger, include_protobuf=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ledger = ledger
|
||||
self.include_protobuf = include_protobuf
|
||||
|
|
|
@ -5,7 +5,7 @@ import typing
|
|||
import asyncio
|
||||
import binascii
|
||||
import time
|
||||
from lbry.wallet.client.basedatabase import SQLiteMixin
|
||||
from lbry.wallet import SQLiteMixin
|
||||
from lbry.conf import Config
|
||||
from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies
|
||||
from lbry.wallet.transaction import Transaction
|
||||
|
|
|
@ -14,17 +14,15 @@ from lbry.stream.managed_stream import ManagedStream
|
|||
from lbry.schema.claim import Claim
|
||||
from lbry.schema.url import URL
|
||||
from lbry.wallet.dewies import dewies_to_lbc
|
||||
from lbry.wallet.transaction import Output
|
||||
from lbry.wallet import WalletManager, Wallet, Transaction, Output
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbry.conf import Config
|
||||
from lbry.blob.blob_manager import BlobManager
|
||||
from lbry.dht.node import Node
|
||||
from lbry.extras.daemon.analytics import AnalyticsManager
|
||||
from lbry.extras.daemon.storage import SQLiteStorage, StoredContentClaim
|
||||
from lbry.wallet import LbryWalletManager
|
||||
from lbry.wallet.transaction import Transaction
|
||||
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -66,7 +64,7 @@ def path_or_none(p) -> Optional[str]:
|
|||
|
||||
class StreamManager:
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager',
|
||||
wallet_manager: 'LbryWalletManager', storage: 'SQLiteStorage', node: Optional['Node'],
|
||||
wallet_manager: 'WalletManager', storage: 'SQLiteStorage', node: Optional['Node'],
|
||||
analytics_manager: Optional['AnalyticsManager'] = None):
|
||||
self.loop = loop
|
||||
self.config = config
|
||||
|
|
|
@ -14,18 +14,11 @@ from time import time
|
|||
from binascii import unhexlify
|
||||
from functools import partial
|
||||
|
||||
import lbry.wallet
|
||||
from lbry.wallet import WalletManager, Wallet, Ledger, Account, Transaction
|
||||
from lbry.conf import Config
|
||||
from lbry.wallet import LbryWalletManager
|
||||
from lbry.wallet.account import Account
|
||||
from lbry.wallet.util import satoshis_to_coins
|
||||
from lbry.wallet.orchstr8 import Conductor
|
||||
from lbry.wallet.transaction import Transaction
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
from lbry.wallet.client.util import satoshis_to_coins
|
||||
from lbry.wallet.orchstr8.node import BlockchainNode, WalletNode
|
||||
from lbry.wallet.client.baseledger import BaseLedger
|
||||
from lbry.wallet.client.baseaccount import BaseAccount
|
||||
from lbry.wallet.client.basemanager import BaseWalletManager
|
||||
|
||||
from lbry.extras.daemon.Daemon import Daemon, jsonrpc_dumps_pretty
|
||||
from lbry.extras.daemon.Components import Component, WalletComponent
|
||||
|
@ -215,25 +208,19 @@ class AdvanceTimeTestCase(AsyncioTestCase):
|
|||
class IntegrationTestCase(AsyncioTestCase):
|
||||
|
||||
SEED = None
|
||||
LEDGER = lbry.wallet
|
||||
MANAGER = LbryWalletManager
|
||||
ENABLE_SEGWIT = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conductor: Optional[Conductor] = None
|
||||
self.blockchain: Optional[BlockchainNode] = None
|
||||
self.wallet_node: Optional[WalletNode] = None
|
||||
self.manager: Optional[BaseWalletManager] = None
|
||||
self.ledger: Optional[BaseLedger] = None
|
||||
self.manager: Optional[WalletManager] = None
|
||||
self.ledger: Optional[Ledger] = None
|
||||
self.wallet: Optional[Wallet] = None
|
||||
self.account: Optional[BaseAccount] = None
|
||||
self.account: Optional[Account] = None
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.conductor = Conductor(
|
||||
ledger_module=self.LEDGER, manager_module=self.MANAGER,
|
||||
enable_segwit=self.ENABLE_SEGWIT, seed=self.SEED
|
||||
)
|
||||
self.conductor = Conductor(seed=self.SEED)
|
||||
await self.conductor.start_blockchain()
|
||||
self.addCleanup(self.conductor.stop_blockchain)
|
||||
await self.conductor.start_spv()
|
||||
|
@ -317,14 +304,13 @@ class CommandTestCase(IntegrationTestCase):
|
|||
VERBOSITY = logging.WARN
|
||||
blob_lru_cache_size = 0
|
||||
|
||||
account: Account
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
|
||||
logging.getLogger('lbry.blob_exchange').setLevel(self.VERBOSITY)
|
||||
logging.getLogger('lbry.daemon').setLevel(self.VERBOSITY)
|
||||
logging.getLogger('lbry.stream').setLevel(self.VERBOSITY)
|
||||
logging.getLogger('lbry.wallet').setLevel(self.VERBOSITY)
|
||||
|
||||
self.daemons = []
|
||||
self.extra_wallet_nodes = []
|
||||
|
@ -419,9 +405,7 @@ class CommandTestCase(IntegrationTestCase):
|
|||
return txid
|
||||
|
||||
async def on_transaction_dict(self, tx):
|
||||
await self.ledger.wait(
|
||||
self.ledger.transaction_class(unhexlify(tx['hex']))
|
||||
)
|
||||
await self.ledger.wait(Transaction(unhexlify(tx['hex'])))
|
||||
|
||||
@staticmethod
|
||||
def get_all_addresses(tx):
|
||||
|
|
|
@ -6,6 +6,12 @@ __node_url__ = (
|
|||
)
|
||||
__spvserver__ = 'lbry.wallet.server.coin.LBCRegTest'
|
||||
|
||||
from lbry.wallet.manager import LbryWalletManager
|
||||
from lbry.wallet.network import Network
|
||||
from lbry.wallet.ledger import MainNetLedger, RegTestLedger, TestNetLedger
|
||||
from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK
|
||||
from .manager import WalletManager
|
||||
from .network import Network
|
||||
from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent
|
||||
from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic
|
||||
from .transaction import Transaction, Output, Input
|
||||
from .script import OutputScript, InputScript
|
||||
from .database import SQLiteMixin, Database
|
||||
from .header import Headers
|
||||
|
|
|
@ -1,14 +1,28 @@
|
|||
import os
|
||||
import time
|
||||
import json
|
||||
import ecdsa
|
||||
import logging
|
||||
import typing
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from functools import partial
|
||||
from hashlib import sha256
|
||||
from string import hexdigits
|
||||
from typing import Type, Dict, Tuple, Optional, Any, List
|
||||
|
||||
import ecdsa
|
||||
from lbry.wallet.constants import CLAIM_TYPES, TXO_TYPES
|
||||
from lbry.error import InvalidPasswordError
|
||||
from lbry.crypto.crypt import aes_encrypt, aes_decrypt
|
||||
|
||||
from lbry.wallet.client.baseaccount import BaseAccount, HierarchicalDeterministic
|
||||
from .bip32 import PrivateKey, PubKey, from_extended_key_string
|
||||
from .mnemonic import Mnemonic
|
||||
from .constants import COIN, CLAIM_TYPES, TXO_TYPES
|
||||
from .transaction import Transaction, Input, Output
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .ledger import Ledger
|
||||
from .wallet import Wallet
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -22,22 +36,483 @@ def validate_claim_id(claim_id):
|
|||
raise Exception("Claim id is not hex encoded")
|
||||
|
||||
|
||||
class Account(BaseAccount):
|
||||
class AddressManager:
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.channel_keys = {}
|
||||
name: str
|
||||
|
||||
__slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'
|
||||
|
||||
def __init__(self, account, public_key, chain_number):
|
||||
self.account = account
|
||||
self.public_key = public_key
|
||||
self.chain_number = chain_number
|
||||
self.address_generator_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, account: 'Account', d: dict) \
|
||||
-> Tuple['AddressManager', 'AddressManager']:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> Dict:
|
||||
d: Dict[str, Any] = {'name': cls.name}
|
||||
receiving_dict = receiving.to_dict_instance()
|
||||
if receiving_dict:
|
||||
d['receiving'] = receiving_dict
|
||||
change_dict = change.to_dict_instance()
|
||||
if change_dict:
|
||||
d['change'] = change_dict
|
||||
return d
|
||||
|
||||
def merge(self, d: dict):
|
||||
pass
|
||||
|
||||
def to_dict_instance(self) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _query_addresses(self, **constraints):
|
||||
return self.account.ledger.db.get_addresses(
|
||||
accounts=[self.account],
|
||||
chain=self.chain_number,
|
||||
**constraints
|
||||
)
|
||||
|
||||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_public_key(self, index: int) -> PubKey:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_max_gap(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def ensure_address_gap(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
|
||||
records = await self.get_address_records(only_usable=only_usable, **constraints)
|
||||
return [r['address'] for r in records]
|
||||
|
||||
async def get_or_create_usable_address(self) -> str:
|
||||
addresses = await self.get_addresses(only_usable=True, limit=10)
|
||||
if addresses:
|
||||
return random.choice(addresses)
|
||||
addresses = await self.ensure_address_gap()
|
||||
return addresses[0]
|
||||
|
||||
|
||||
class HierarchicalDeterministic(AddressManager):
|
||||
""" Implements simple version of Bitcoin Hierarchical Deterministic key management. """
|
||||
|
||||
name: str = "deterministic-chain"
|
||||
|
||||
__slots__ = 'gap', 'maximum_uses_per_address'
|
||||
|
||||
def __init__(self, account: 'Account', chain: int, gap: int, maximum_uses_per_address: int) -> None:
|
||||
super().__init__(account, account.public_key.child(chain), chain)
|
||||
self.gap = gap
|
||||
self.maximum_uses_per_address = maximum_uses_per_address
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, account: 'Account', d: dict) -> Tuple[AddressManager, AddressManager]:
|
||||
return (
|
||||
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
|
||||
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
|
||||
)
|
||||
|
||||
def merge(self, d: dict):
|
||||
self.gap = d.get('gap', self.gap)
|
||||
self.maximum_uses_per_address = d.get('maximum_uses_per_address', self.maximum_uses_per_address)
|
||||
|
||||
def to_dict_instance(self):
|
||||
return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address}
|
||||
|
||||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
return self.account.private_key.child(self.chain_number).child(index)
|
||||
|
||||
def get_public_key(self, index: int) -> PubKey:
|
||||
return self.account.public_key.child(self.chain_number).child(index)
|
||||
|
||||
async def get_max_gap(self) -> int:
|
||||
addresses = await self._query_addresses(order_by="n asc")
|
||||
max_gap = 0
|
||||
current_gap = 0
|
||||
for address in addresses:
|
||||
if address['used_times'] == 0:
|
||||
current_gap += 1
|
||||
else:
|
||||
max_gap = max(max_gap, current_gap)
|
||||
current_gap = 0
|
||||
return max_gap
|
||||
|
||||
async def ensure_address_gap(self) -> List[str]:
|
||||
async with self.address_generator_lock:
|
||||
addresses = await self._query_addresses(limit=self.gap, order_by="n desc")
|
||||
|
||||
existing_gap = 0
|
||||
for address in addresses:
|
||||
if address['used_times'] == 0:
|
||||
existing_gap += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if existing_gap == self.gap:
|
||||
return []
|
||||
|
||||
start = addresses[0]['pubkey'].n+1 if addresses else 0
|
||||
end = start + (self.gap - existing_gap)
|
||||
new_keys = await self._generate_keys(start, end-1)
|
||||
await self.account.ledger.announce_addresses(self, new_keys)
|
||||
return new_keys
|
||||
|
||||
async def _generate_keys(self, start: int, end: int) -> List[str]:
|
||||
if not self.address_generator_lock.locked():
|
||||
raise RuntimeError('Should not be called outside of address_generator_lock.')
|
||||
keys = [self.public_key.child(index) for index in range(start, end+1)]
|
||||
await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
|
||||
return [key.address for key in keys]
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
if only_usable:
|
||||
constraints['used_times__lt'] = self.maximum_uses_per_address
|
||||
if 'order_by' not in constraints:
|
||||
constraints['order_by'] = "used_times asc, n asc"
|
||||
return self._query_addresses(**constraints)
|
||||
|
||||
|
||||
class SingleKey(AddressManager):
|
||||
""" Single Key address manager always returns the same address for all operations. """
|
||||
|
||||
name: str = "single-address"
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, account: 'Account', d: dict) \
|
||||
-> Tuple[AddressManager, AddressManager]:
|
||||
same_address_manager = cls(account, account.public_key, 0)
|
||||
return same_address_manager, same_address_manager
|
||||
|
||||
def to_dict_instance(self):
|
||||
return None
|
||||
|
||||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
return self.account.private_key
|
||||
|
||||
def get_public_key(self, index: int) -> PubKey:
|
||||
return self.account.public_key
|
||||
|
||||
async def get_max_gap(self) -> int:
|
||||
return 0
|
||||
|
||||
async def ensure_address_gap(self) -> List[str]:
|
||||
async with self.address_generator_lock:
|
||||
exists = await self.get_address_records()
|
||||
if not exists:
|
||||
await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
|
||||
new_keys = [self.public_key.address]
|
||||
await self.account.ledger.announce_addresses(self, new_keys)
|
||||
return new_keys
|
||||
return []
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
return self._query_addresses(**constraints)
|
||||
|
||||
|
||||
class Account:
|
||||
|
||||
mnemonic_class = Mnemonic
|
||||
private_key_class = PrivateKey
|
||||
public_key_class = PubKey
|
||||
address_generators: Dict[str, Type[AddressManager]] = {
|
||||
SingleKey.name: SingleKey,
|
||||
HierarchicalDeterministic.name: HierarchicalDeterministic,
|
||||
}
|
||||
|
||||
def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str,
|
||||
seed: str, private_key_string: str, encrypted: bool,
|
||||
private_key: Optional[PrivateKey], public_key: PubKey,
|
||||
address_generator: dict, modified_on: float, channel_keys: dict) -> None:
|
||||
self.ledger = ledger
|
||||
self.wallet = wallet
|
||||
self.id = public_key.address
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.modified_on = modified_on
|
||||
self.private_key_string = private_key_string
|
||||
self.init_vectors: Dict[str, bytes] = {}
|
||||
self.encrypted = encrypted
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
generator_name = address_generator.get('name', HierarchicalDeterministic.name)
|
||||
self.address_generator = self.address_generators[generator_name]
|
||||
self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
|
||||
self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
|
||||
self.channel_keys = channel_keys
|
||||
ledger.add_account(self)
|
||||
wallet.add_account(self)
|
||||
|
||||
def get_init_vector(self, key) -> Optional[bytes]:
|
||||
init_vector = self.init_vectors.get(key, None)
|
||||
if init_vector is None:
|
||||
init_vector = self.init_vectors[key] = os.urandom(16)
|
||||
return init_vector
|
||||
|
||||
@classmethod
|
||||
def generate(cls, ledger: 'Ledger', wallet: 'Wallet',
|
||||
name: str = None, address_generator: dict = None):
|
||||
return cls.from_dict(ledger, wallet, {
|
||||
'name': name,
|
||||
'seed': cls.mnemonic_class().make_seed(),
|
||||
'address_generator': address_generator or {}
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_private_key_from_seed(cls, ledger: 'Ledger', seed: str, password: str):
|
||||
return cls.private_key_class.from_seed(
|
||||
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password or 'lbryum')
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def keys_from_dict(cls, ledger: 'Ledger', d: dict) \
|
||||
-> Tuple[str, Optional[PrivateKey], PubKey]:
|
||||
seed = d.get('seed', '')
|
||||
private_key_string = d.get('private_key', '')
|
||||
private_key = None
|
||||
public_key = None
|
||||
encrypted = d.get('encrypted', False)
|
||||
if not encrypted:
|
||||
if seed:
|
||||
private_key = cls.get_private_key_from_seed(ledger, seed, '')
|
||||
public_key = private_key.public_key
|
||||
elif private_key_string:
|
||||
private_key = from_extended_key_string(ledger, private_key_string)
|
||||
public_key = private_key.public_key
|
||||
if public_key is None:
|
||||
public_key = from_extended_key_string(ledger, d['public_key'])
|
||||
return seed, private_key, public_key
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, ledger: 'Ledger', wallet: 'Wallet', d: dict):
|
||||
seed, private_key, public_key = cls.keys_from_dict(ledger, d)
|
||||
name = d.get('name')
|
||||
if not name:
|
||||
name = f'Account #{public_key.address}'
|
||||
return cls(
|
||||
ledger=ledger,
|
||||
wallet=wallet,
|
||||
name=name,
|
||||
seed=seed,
|
||||
private_key_string=d.get('private_key', ''),
|
||||
encrypted=d.get('encrypted', False),
|
||||
private_key=private_key,
|
||||
public_key=public_key,
|
||||
address_generator=d.get('address_generator', {}),
|
||||
modified_on=d.get('modified_on', time.time()),
|
||||
channel_keys=d.get('certificates', {})
|
||||
)
|
||||
|
||||
def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True):
|
||||
private_key_string, seed = self.private_key_string, self.seed
|
||||
if not self.encrypted and self.private_key:
|
||||
private_key_string = self.private_key.extended_key_string()
|
||||
if not self.encrypted and encrypt_password:
|
||||
if private_key_string:
|
||||
private_key_string = aes_encrypt(
|
||||
encrypt_password, private_key_string, self.get_init_vector('private_key')
|
||||
)
|
||||
if seed:
|
||||
seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed'))
|
||||
d = {
|
||||
'ledger': self.ledger.get_id(),
|
||||
'name': self.name,
|
||||
'seed': seed,
|
||||
'encrypted': bool(self.encrypted or encrypt_password),
|
||||
'private_key': private_key_string,
|
||||
'public_key': self.public_key.extended_key_string(),
|
||||
'address_generator': self.address_generator.to_dict(self.receiving, self.change),
|
||||
'modified_on': self.modified_on
|
||||
}
|
||||
if include_channel_keys:
|
||||
d['certificates'] = self.channel_keys
|
||||
return d
|
||||
|
||||
def merge(self, d: dict):
|
||||
if d.get('modified_on', 0) > self.modified_on:
|
||||
self.name = d['name']
|
||||
self.modified_on = d.get('modified_on', time.time())
|
||||
assert self.address_generator.name == d['address_generator']['name']
|
||||
for chain_name in ('change', 'receiving'):
|
||||
if chain_name in d['address_generator']:
|
||||
chain_object = getattr(self, chain_name)
|
||||
chain_object.merge(d['address_generator'][chain_name])
|
||||
self.channel_keys.update(d.get('certificates', {}))
|
||||
|
||||
@property
|
||||
def hash(self) -> bytes:
|
||||
assert not self.encrypted, "Cannot hash an encrypted account."
|
||||
h = sha256(json.dumps(self.to_dict(include_channel_keys=False)).encode())
|
||||
for cert in sorted(self.channel_keys.keys()):
|
||||
h.update(cert.encode())
|
||||
return h.digest()
|
||||
|
||||
def merge(self, d: dict):
|
||||
super().merge(d)
|
||||
self.channel_keys.update(d.get('certificates', {}))
|
||||
async def get_details(self, show_seed=False, **kwargs):
|
||||
satoshis = await self.get_balance(**kwargs)
|
||||
details = {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'ledger': self.ledger.get_id(),
|
||||
'coins': round(satoshis/COIN, 2),
|
||||
'satoshis': satoshis,
|
||||
'encrypted': self.encrypted,
|
||||
'public_key': self.public_key.extended_key_string(),
|
||||
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
|
||||
}
|
||||
if show_seed:
|
||||
details['seed'] = self.seed
|
||||
details['certificates'] = len(self.channel_keys)
|
||||
return details
|
||||
|
||||
def decrypt(self, password: str) -> bool:
|
||||
assert self.encrypted, "Key is not encrypted."
|
||||
try:
|
||||
seed = self._decrypt_seed(password)
|
||||
except (ValueError, InvalidPasswordError):
|
||||
return False
|
||||
try:
|
||||
private_key = self._decrypt_private_key_string(password)
|
||||
except (TypeError, ValueError, InvalidPasswordError):
|
||||
return False
|
||||
self.seed = seed
|
||||
self.private_key = private_key
|
||||
self.private_key_string = ""
|
||||
self.encrypted = False
|
||||
return True
|
||||
|
||||
def _decrypt_private_key_string(self, password: str) -> Optional[PrivateKey]:
|
||||
if not self.private_key_string:
|
||||
return None
|
||||
private_key_string, self.init_vectors['private_key'] = aes_decrypt(password, self.private_key_string)
|
||||
if not private_key_string:
|
||||
return None
|
||||
return from_extended_key_string(
|
||||
self.ledger, private_key_string
|
||||
)
|
||||
|
||||
def _decrypt_seed(self, password: str) -> str:
|
||||
if not self.seed:
|
||||
return ""
|
||||
seed, self.init_vectors['seed'] = aes_decrypt(password, self.seed)
|
||||
if not seed:
|
||||
return ""
|
||||
try:
|
||||
Mnemonic().mnemonic_decode(seed)
|
||||
except IndexError:
|
||||
# failed to decode the seed, this either means it decrypted and is invalid
|
||||
# or that we hit an edge case where an incorrect password gave valid padding
|
||||
raise ValueError("Failed to decode seed.")
|
||||
return seed
|
||||
|
||||
def encrypt(self, password: str) -> bool:
|
||||
assert not self.encrypted, "Key is already encrypted."
|
||||
if self.seed:
|
||||
self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed'))
|
||||
if isinstance(self.private_key, PrivateKey):
|
||||
self.private_key_string = aes_encrypt(
|
||||
password, self.private_key.extended_key_string(), self.get_init_vector('private_key')
|
||||
)
|
||||
self.private_key = None
|
||||
self.encrypted = True
|
||||
return True
|
||||
|
||||
async def ensure_address_gap(self):
|
||||
addresses = []
|
||||
for address_manager in self.address_managers.values():
|
||||
new_addresses = await address_manager.ensure_address_gap()
|
||||
addresses.extend(new_addresses)
|
||||
return addresses
|
||||
|
||||
async def get_addresses(self, **constraints) -> List[str]:
|
||||
rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints)
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def get_address_records(self, **constraints):
|
||||
return self.ledger.db.get_addresses(accounts=[self], **constraints)
|
||||
|
||||
def get_address_count(self, **constraints):
|
||||
return self.ledger.db.get_address_count(accounts=[self], **constraints)
|
||||
|
||||
def get_private_key(self, chain: int, index: int) -> PrivateKey:
|
||||
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
|
||||
return self.address_managers[chain].get_private_key(index)
|
||||
|
||||
def get_public_key(self, chain: int, index: int) -> PubKey:
|
||||
return self.address_managers[chain].get_public_key(index)
|
||||
|
||||
def get_balance(self, confirmations: int = 0, include_claims=False, **constraints):
|
||||
if not include_claims:
|
||||
constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])})
|
||||
if confirmations > 0:
|
||||
height = self.ledger.headers.height - (confirmations-1)
|
||||
constraints.update({'height__lte': height, 'height__gt': 0})
|
||||
return self.ledger.db.get_balance(accounts=[self], **constraints)
|
||||
|
||||
async def get_max_gap(self):
|
||||
change_gap = await self.change.get_max_gap()
|
||||
receiving_gap = await self.receiving.get_max_gap()
|
||||
return {
|
||||
'max_change_gap': change_gap,
|
||||
'max_receiving_gap': receiving_gap,
|
||||
}
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_transactions(self, **constraints):
|
||||
return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_transaction_count(self, **constraints):
|
||||
return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
async def fund(self, to_account, amount=None, everything=False,
|
||||
outputs=1, broadcast=False, **constraints):
|
||||
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
|
||||
if everything:
|
||||
utxos = await self.get_utxos(**constraints)
|
||||
await self.ledger.reserve_outputs(utxos)
|
||||
tx = await Transaction.create(
|
||||
inputs=[Input.spend(txo) for txo in utxos],
|
||||
outputs=[],
|
||||
funding_accounts=[self],
|
||||
change_account=to_account
|
||||
)
|
||||
elif amount > 0:
|
||||
to_address = await to_account.change.get_or_create_usable_address()
|
||||
to_hash160 = to_account.ledger.address_to_hash160(to_address)
|
||||
tx = await Transaction.create(
|
||||
inputs=[],
|
||||
outputs=[
|
||||
Output.pay_pubkey_hash(amount//outputs, to_hash160)
|
||||
for _ in range(outputs)
|
||||
],
|
||||
funding_accounts=[self],
|
||||
change_account=self
|
||||
)
|
||||
else:
|
||||
raise ValueError('An amount is required.')
|
||||
|
||||
if broadcast:
|
||||
await self.ledger.broadcast(tx)
|
||||
else:
|
||||
await self.ledger.release_tx(tx)
|
||||
|
||||
return tx
|
||||
|
||||
def add_channel_private_key(self, private_key):
|
||||
public_key_bytes = private_key.get_verifying_key().to_der()
|
||||
|
@ -81,11 +556,6 @@ class Account(BaseAccount):
|
|||
if gap_changed:
|
||||
self.wallet.save()
|
||||
|
||||
def get_balance(self, confirmations=0, include_claims=False, **constraints):
|
||||
if not include_claims:
|
||||
constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])})
|
||||
return super().get_balance(confirmations, **constraints)
|
||||
|
||||
async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False):
|
||||
tips_balance, supports_balance, claims_balance = 0, 0, 0
|
||||
get_total_balance = partial(self.get_balance, confirmations=confirmations, include_claims=True)
|
||||
|
@ -116,29 +586,6 @@ class Account(BaseAccount):
|
|||
} if reserved_subtotals else None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_private_key_from_seed(cls, ledger, seed: str, password: str):
|
||||
return super().get_private_key_from_seed(
|
||||
ledger, seed, password or 'lbryum'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, ledger, wallet, d: dict) -> 'Account':
|
||||
account = super().from_dict(ledger, wallet, d)
|
||||
account.channel_keys = d.get('certificates', {})
|
||||
return account
|
||||
|
||||
def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True):
|
||||
d = super().to_dict(encrypt_password)
|
||||
if include_channel_keys:
|
||||
d['certificates'] = self.channel_keys
|
||||
return d
|
||||
|
||||
async def get_details(self, **kwargs):
|
||||
details = await super().get_details(**kwargs)
|
||||
details['certificates'] = len(self.channel_keys)
|
||||
return details
|
||||
|
||||
def get_transaction_history(self, **constraints):
|
||||
return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ from coincurve import PublicKey, PrivateKey as _PrivateKey
|
|||
|
||||
from lbry.crypto.hash import hmac_sha512, hash160, double_sha256
|
||||
from lbry.crypto.base58 import Base58
|
||||
from lbry.wallet.client.util import cachedproperty
|
||||
from .util import cachedproperty
|
||||
|
||||
|
||||
class DerivationError(Exception):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from random import Random
|
||||
from typing import List
|
||||
|
||||
from lbry.wallet.client import basetransaction
|
||||
from lbry.wallet.transaction import OutputEffectiveAmountEstimator
|
||||
|
||||
MAXIMUM_TRIES = 100000
|
||||
|
||||
|
@ -25,8 +25,8 @@ class CoinSelector:
|
|||
self.random.seed(seed, version=1)
|
||||
|
||||
def select(
|
||||
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
self, txos: List[OutputEffectiveAmountEstimator],
|
||||
strategy_name: str = None) -> List[OutputEffectiveAmountEstimator]:
|
||||
if not txos:
|
||||
return []
|
||||
available = sum(c.effective_amount for c in txos)
|
||||
|
@ -35,16 +35,16 @@ class CoinSelector:
|
|||
return getattr(self, strategy_name or "standard")(txos, available)
|
||||
|
||||
@strategy
|
||||
def prefer_confirmed(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def prefer_confirmed(self, txos: List[OutputEffectiveAmountEstimator],
|
||||
available: int) -> List[OutputEffectiveAmountEstimator]:
|
||||
return (
|
||||
self.only_confirmed(txos, available) or
|
||||
self.standard(txos, available)
|
||||
)
|
||||
|
||||
@strategy
|
||||
def only_confirmed(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
_) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def only_confirmed(self, txos: List[OutputEffectiveAmountEstimator],
|
||||
_) -> List[OutputEffectiveAmountEstimator]:
|
||||
confirmed = [t for t in txos if t.txo.tx_ref and t.txo.tx_ref.height > 0]
|
||||
if not confirmed:
|
||||
return []
|
||||
|
@ -54,8 +54,8 @@ class CoinSelector:
|
|||
return self.standard(confirmed, confirmed_available)
|
||||
|
||||
@strategy
|
||||
def standard(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def standard(self, txos: List[OutputEffectiveAmountEstimator],
|
||||
available: int) -> List[OutputEffectiveAmountEstimator]:
|
||||
return (
|
||||
self.branch_and_bound(txos, available) or
|
||||
self.closest_match(txos, available) or
|
||||
|
@ -63,8 +63,8 @@ class CoinSelector:
|
|||
)
|
||||
|
||||
@strategy
|
||||
def branch_and_bound(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
available: int) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def branch_and_bound(self, txos: List[OutputEffectiveAmountEstimator],
|
||||
available: int) -> List[OutputEffectiveAmountEstimator]:
|
||||
# see bitcoin implementation for more info:
|
||||
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp
|
||||
|
||||
|
@ -123,8 +123,8 @@ class CoinSelector:
|
|||
return []
|
||||
|
||||
@strategy
|
||||
def closest_match(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
_) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def closest_match(self, txos: List[OutputEffectiveAmountEstimator],
|
||||
_) -> List[OutputEffectiveAmountEstimator]:
|
||||
""" Pick one UTXOs that is larger than the target but with the smallest change. """
|
||||
target = self.target + self.cost_of_change
|
||||
smallest_change = None
|
||||
|
@ -137,8 +137,8 @@ class CoinSelector:
|
|||
return [best_match] if best_match else []
|
||||
|
||||
@strategy
|
||||
def random_draw(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
_) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def random_draw(self, txos: List[OutputEffectiveAmountEstimator],
|
||||
_) -> List[OutputEffectiveAmountEstimator]:
|
||||
""" Accumulate UTXOs at random until there is enough to cover the target. """
|
||||
target = self.target + self.cost_of_change
|
||||
self.random.shuffle(txos, self.random.random)
|
||||
|
|
|
@ -1,3 +1,10 @@
|
|||
NULL_HASH32 = b'\x00'*32
|
||||
|
||||
CENT = 1000000
|
||||
COIN = 100*CENT
|
||||
|
||||
TIMEOUT = 30.0
|
||||
|
||||
TXO_TYPES = {
|
||||
"stream": 1,
|
||||
"channel": 2,
|
||||
|
|
|
@ -1,14 +1,321 @@
|
|||
from typing import List
|
||||
import logging
|
||||
import asyncio
|
||||
import sqlite3
|
||||
|
||||
from lbry.wallet.client.basedatabase import BaseDatabase
|
||||
from binascii import hexlify
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
|
||||
|
||||
from lbry.wallet.transaction import Output
|
||||
from lbry.wallet.constants import TXO_TYPES, CLAIM_TYPES
|
||||
from .bip32 import PubKey
|
||||
from .transaction import Transaction, Output, OutputScript, TXRefImmutable
|
||||
from .constants import TXO_TYPES, CLAIM_TYPES
|
||||
|
||||
|
||||
class WalletDatabase(BaseDatabase):
|
||||
log = logging.getLogger(__name__)
|
||||
sqlite3.enable_callback_tracebacks(True)
|
||||
|
||||
SCHEMA_VERSION = f"{BaseDatabase.SCHEMA_VERSION}+1"
|
||||
|
||||
class AIOSQLite:
|
||||
|
||||
def __init__(self):
|
||||
# has to be single threaded as there is no mapping of thread:connection
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.connection: sqlite3.Connection = None
|
||||
self._closing = False
|
||||
self.query_count = 0
|
||||
|
||||
@classmethod
|
||||
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
|
||||
sqlite3.enable_callback_tracebacks(True)
|
||||
def _connect():
|
||||
return sqlite3.connect(path, *args, **kwargs)
|
||||
db = cls()
|
||||
db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect)
|
||||
return db
|
||||
|
||||
async def close(self):
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close)
|
||||
self.executor.shutdown(wait=True)
|
||||
self.connection = None
|
||||
|
||||
def executemany(self, sql: str, params: Iterable):
|
||||
params = params if params is not None else []
|
||||
# this fetchall is needed to prevent SQLITE_MISUSE
|
||||
return self.run(lambda conn: conn.executemany(sql, params).fetchall())
|
||||
|
||||
def executescript(self, script: str) -> Awaitable:
|
||||
return self.run(lambda conn: conn.executescript(script))
|
||||
|
||||
def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
|
||||
|
||||
def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
|
||||
|
||||
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters))
|
||||
|
||||
def run(self, fun, *args, **kwargs) -> Awaitable:
|
||||
return asyncio.get_event_loop().run_in_executor(
|
||||
self.executor, lambda: self.__run_transaction(fun, *args, **kwargs)
|
||||
)
|
||||
|
||||
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
|
||||
self.connection.execute('begin')
|
||||
try:
|
||||
self.query_count += 1
|
||||
result = fun(self.connection, *args, **kwargs) # type: ignore
|
||||
self.connection.commit()
|
||||
return result
|
||||
except (Exception, OSError) as e:
|
||||
log.exception('Error running transaction:', exc_info=e)
|
||||
self.connection.rollback()
|
||||
log.warning("rolled back")
|
||||
raise
|
||||
|
||||
def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
|
||||
return asyncio.get_event_loop().run_in_executor(
|
||||
self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
|
||||
)
|
||||
|
||||
def __run_transaction_with_foreign_keys_disabled(self,
|
||||
fun: Callable[[sqlite3.Connection, Any, Any], Any],
|
||||
args, kwargs):
|
||||
foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone()
|
||||
if not foreign_keys_enabled:
|
||||
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
|
||||
try:
|
||||
self.connection.execute('pragma foreign_keys=off').fetchone()
|
||||
return self.__run_transaction(fun, *args, **kwargs)
|
||||
finally:
|
||||
self.connection.execute('pragma foreign_keys=on').fetchone()
|
||||
|
||||
|
||||
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
||||
sql, values = [], {}
|
||||
for key, constraint in constraints.items():
|
||||
tag = '0'
|
||||
if '#' in key:
|
||||
key, tag = key[:key.index('#')], key[key.index('#')+1:]
|
||||
col, op, key = key, '=', key.replace('.', '_')
|
||||
if not key:
|
||||
sql.append(constraint)
|
||||
continue
|
||||
if key.startswith('$'):
|
||||
values[key] = constraint
|
||||
continue
|
||||
if key.endswith('__not'):
|
||||
col, op = col[:-len('__not')], '!='
|
||||
elif key.endswith('__is_null'):
|
||||
col = col[:-len('__is_null')]
|
||||
sql.append(f'{col} IS NULL')
|
||||
continue
|
||||
if key.endswith('__is_not_null'):
|
||||
col = col[:-len('__is_not_null')]
|
||||
sql.append(f'{col} IS NOT NULL')
|
||||
continue
|
||||
if key.endswith('__lt'):
|
||||
col, op = col[:-len('__lt')], '<'
|
||||
elif key.endswith('__lte'):
|
||||
col, op = col[:-len('__lte')], '<='
|
||||
elif key.endswith('__gt'):
|
||||
col, op = col[:-len('__gt')], '>'
|
||||
elif key.endswith('__gte'):
|
||||
col, op = col[:-len('__gte')], '>='
|
||||
elif key.endswith('__like'):
|
||||
col, op = col[:-len('__like')], 'LIKE'
|
||||
elif key.endswith('__not_like'):
|
||||
col, op = col[:-len('__not_like')], 'NOT LIKE'
|
||||
elif key.endswith('__in') or key.endswith('__not_in'):
|
||||
if key.endswith('__in'):
|
||||
col, op = col[:-len('__in')], 'IN'
|
||||
else:
|
||||
col, op = col[:-len('__not_in')], 'NOT IN'
|
||||
if constraint:
|
||||
if isinstance(constraint, (list, set, tuple)):
|
||||
keys = []
|
||||
for i, val in enumerate(constraint):
|
||||
keys.append(f':{key}{tag}_{i}')
|
||||
values[f'{key}{tag}_{i}'] = val
|
||||
sql.append(f'{col} {op} ({", ".join(keys)})')
|
||||
elif isinstance(constraint, str):
|
||||
sql.append(f'{col} {op} ({constraint})')
|
||||
else:
|
||||
raise ValueError(f"{col} requires a list, set or string as constraint value.")
|
||||
continue
|
||||
elif key.endswith('__any') or key.endswith('__or'):
|
||||
where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_')
|
||||
sql.append(f'({where})')
|
||||
values.update(subvalues)
|
||||
continue
|
||||
if key.endswith('__and'):
|
||||
where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_')
|
||||
sql.append(f'({where})')
|
||||
values.update(subvalues)
|
||||
continue
|
||||
sql.append(f'{col} {op} :{prepend_key}{key}{tag}')
|
||||
values[prepend_key+key+tag] = constraint
|
||||
return joiner.join(sql) if sql else '', values
|
||||
|
||||
|
||||
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
|
||||
sql = [select]
|
||||
limit = constraints.pop('limit', None)
|
||||
offset = constraints.pop('offset', None)
|
||||
order_by = constraints.pop('order_by', None)
|
||||
|
||||
accounts = constraints.pop('accounts', [])
|
||||
if accounts:
|
||||
constraints['account__in'] = [a.public_key.address for a in accounts]
|
||||
|
||||
where, values = constraints_to_sql(constraints)
|
||||
if where:
|
||||
sql.append('WHERE')
|
||||
sql.append(where)
|
||||
|
||||
if order_by:
|
||||
sql.append('ORDER BY')
|
||||
if isinstance(order_by, str):
|
||||
sql.append(order_by)
|
||||
elif isinstance(order_by, list):
|
||||
sql.append(', '.join(order_by))
|
||||
else:
|
||||
raise ValueError("order_by must be string or list")
|
||||
|
||||
if limit is not None:
|
||||
sql.append(f'LIMIT {limit}')
|
||||
|
||||
if offset is not None:
|
||||
sql.append(f'OFFSET {offset}')
|
||||
|
||||
return ' '.join(sql), values
|
||||
|
||||
|
||||
def interpolate(sql, values):
|
||||
for k in sorted(values.keys(), reverse=True):
|
||||
value = values[k]
|
||||
if isinstance(value, bytes):
|
||||
value = f"X'{hexlify(value).decode()}'"
|
||||
elif isinstance(value, str):
|
||||
value = f"'{value}'"
|
||||
else:
|
||||
value = str(value)
|
||||
sql = sql.replace(f":{k}", value)
|
||||
return sql
|
||||
|
||||
|
||||
def rows_to_dict(rows, fields):
|
||||
if rows:
|
||||
return [dict(zip(fields, r)) for r in rows]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
class SQLiteMixin:
|
||||
|
||||
SCHEMA_VERSION: Optional[str] = None
|
||||
CREATE_TABLES_QUERY: str
|
||||
MAX_QUERY_VARIABLES = 900
|
||||
|
||||
CREATE_VERSION_TABLE = """
|
||||
create table if not exists version (
|
||||
version text
|
||||
);
|
||||
"""
|
||||
|
||||
def __init__(self, path):
|
||||
self._db_path = path
|
||||
self.db: AIOSQLite = None
|
||||
self.ledger = None
|
||||
|
||||
async def open(self):
|
||||
log.info("connecting to database: %s", self._db_path)
|
||||
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
|
||||
if self.SCHEMA_VERSION:
|
||||
tables = [t[0] for t in await self.db.execute_fetchall(
|
||||
"SELECT name FROM sqlite_master WHERE type='table';"
|
||||
)]
|
||||
if tables:
|
||||
if 'version' in tables:
|
||||
version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
|
||||
if version == (self.SCHEMA_VERSION,):
|
||||
return
|
||||
await self.db.executescript('\n'.join(
|
||||
f"DROP TABLE {table};" for table in tables
|
||||
))
|
||||
await self.db.execute(self.CREATE_VERSION_TABLE)
|
||||
await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
|
||||
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
||||
|
||||
async def close(self):
|
||||
await self.db.close()
|
||||
|
||||
@staticmethod
|
||||
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
|
||||
replace: bool = False) -> Tuple[str, List]:
|
||||
columns, values = [], []
|
||||
for column, value in data.items():
|
||||
columns.append(column)
|
||||
values.append(value)
|
||||
policy = ""
|
||||
if ignore_duplicate:
|
||||
policy = " OR IGNORE"
|
||||
if replace:
|
||||
policy = " OR REPLACE"
|
||||
sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
|
||||
policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
|
||||
)
|
||||
return sql, values
|
||||
|
||||
@staticmethod
|
||||
def _update_sql(table: str, data: dict, where: str,
|
||||
constraints: Union[list, tuple]) -> Tuple[str, list]:
|
||||
columns, values = [], []
|
||||
for column, value in data.items():
|
||||
columns.append(f"{column} = ?")
|
||||
values.append(value)
|
||||
values.extend(constraints)
|
||||
sql = "UPDATE {} SET {} WHERE {}".format(
|
||||
table, ', '.join(columns), where
|
||||
)
|
||||
return sql, values
|
||||
|
||||
|
||||
class Database(SQLiteMixin):
|
||||
|
||||
SCHEMA_VERSION = "1.1"
|
||||
|
||||
PRAGMAS = """
|
||||
pragma journal_mode=WAL;
|
||||
"""
|
||||
|
||||
CREATE_ACCOUNT_TABLE = """
|
||||
create table if not exists account_address (
|
||||
account text not null,
|
||||
address text not null,
|
||||
chain integer not null,
|
||||
pubkey blob not null,
|
||||
chain_code blob not null,
|
||||
n integer not null,
|
||||
depth integer not null,
|
||||
primary key (account, address)
|
||||
);
|
||||
create index if not exists address_account_idx on account_address (address, account);
|
||||
"""
|
||||
|
||||
CREATE_PUBKEY_ADDRESS_TABLE = """
|
||||
create table if not exists pubkey_address (
|
||||
address text primary key,
|
||||
history text,
|
||||
used_times integer not null default 0
|
||||
);
|
||||
"""
|
||||
|
||||
CREATE_TX_TABLE = """
|
||||
create table if not exists tx (
|
||||
|
@ -42,25 +349,35 @@ class WalletDatabase(BaseDatabase):
|
|||
create index if not exists txo_txo_type_idx on txo (txo_type);
|
||||
"""
|
||||
|
||||
CREATE_TXI_TABLE = """
|
||||
create table if not exists txi (
|
||||
txid text references tx,
|
||||
txoid text references txo,
|
||||
address text references pubkey_address
|
||||
);
|
||||
create index if not exists txi_address_idx on txi (address);
|
||||
create index if not exists txi_txoid_idx on txi (txoid);
|
||||
"""
|
||||
|
||||
CREATE_TABLES_QUERY = (
|
||||
BaseDatabase.PRAGMAS +
|
||||
BaseDatabase.CREATE_ACCOUNT_TABLE +
|
||||
BaseDatabase.CREATE_PUBKEY_ADDRESS_TABLE +
|
||||
PRAGMAS +
|
||||
CREATE_ACCOUNT_TABLE +
|
||||
CREATE_PUBKEY_ADDRESS_TABLE +
|
||||
CREATE_TX_TABLE +
|
||||
CREATE_TXO_TABLE +
|
||||
BaseDatabase.CREATE_TXI_TABLE
|
||||
CREATE_TXI_TABLE
|
||||
)
|
||||
|
||||
def tx_to_row(self, tx):
|
||||
row = super().tx_to_row(tx)
|
||||
txos = tx.outputs
|
||||
if len(txos) >= 2 and txos[1].can_decode_purchase_data:
|
||||
txos[0].purchase = txos[1]
|
||||
row['purchased_claim_id'] = txos[1].purchase_data.claim_id
|
||||
return row
|
||||
|
||||
def txo_to_row(self, tx, address, txo):
|
||||
row = super().txo_to_row(tx, address, txo)
|
||||
@staticmethod
|
||||
def txo_to_row(tx, address, txo):
|
||||
row = {
|
||||
'txid': tx.id,
|
||||
'txoid': txo.id,
|
||||
'address': address,
|
||||
'position': txo.position,
|
||||
'amount': txo.amount,
|
||||
'script': sqlite3.Binary(txo.script.source)
|
||||
}
|
||||
if txo.is_claim:
|
||||
if txo.can_decode_claim:
|
||||
row['txo_type'] = TXO_TYPES.get(txo.claim.claim_type, TXO_TYPES['stream'])
|
||||
|
@ -76,39 +393,212 @@ class WalletDatabase(BaseDatabase):
|
|||
row['claim_name'] = txo.claim_name
|
||||
return row
|
||||
|
||||
async def get_transactions(self, **constraints):
|
||||
txs = await super().get_transactions(**constraints)
|
||||
for tx in txs:
|
||||
@staticmethod
|
||||
def tx_to_row(tx):
|
||||
row = {
|
||||
'txid': tx.id,
|
||||
'raw': sqlite3.Binary(tx.raw),
|
||||
'height': tx.height,
|
||||
'position': tx.position,
|
||||
'is_verified': tx.is_verified
|
||||
}
|
||||
txos = tx.outputs
|
||||
if len(txos) >= 2 and txos[1].can_decode_purchase_data:
|
||||
txos[0].purchase = txos[1]
|
||||
return txs
|
||||
row['purchased_claim_id'] = txos[1].purchase_data.claim_id
|
||||
return row
|
||||
|
||||
@staticmethod
|
||||
def constrain_purchases(constraints):
|
||||
accounts = constraints.pop('accounts', None)
|
||||
assert accounts, "'accounts' argument required to find purchases"
|
||||
if not {'purchased_claim_id', 'purchased_claim_id__in'}.intersection(constraints):
|
||||
constraints['purchased_claim_id__is_not_null'] = True
|
||||
async def insert_transaction(self, tx):
|
||||
await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))
|
||||
|
||||
async def update_transaction(self, tx):
|
||||
await self.db.execute_fetchall(*self._update_sql("tx", {
|
||||
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
||||
}, 'txid = ?', (tx.id,)))
|
||||
|
||||
def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash, history):
|
||||
conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True))
|
||||
|
||||
for txo in tx.outputs:
|
||||
if txo.script.is_pay_pubkey_hash and txo.pubkey_hash == txhash:
|
||||
conn.execute(*self._insert_sql(
|
||||
"txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
|
||||
)).fetchall()
|
||||
elif txo.script.is_pay_script_hash:
|
||||
# TODO: implement script hash payments
|
||||
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
|
||||
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
txo = txi.txo_ref.txo
|
||||
if txo.has_address and txo.get_address(self.ledger) == address:
|
||||
conn.execute(*self._insert_sql("txi", {
|
||||
'txid': tx.id,
|
||||
'txoid': txo.id,
|
||||
'address': address,
|
||||
}, ignore_duplicate=True)).fetchall()
|
||||
|
||||
conn.execute(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':') // 2, address)
|
||||
)
|
||||
|
||||
def save_transaction_io(self, tx: Transaction, address, txhash, history):
|
||||
return self.db.run(self._transaction_io, tx, address, txhash, history)
|
||||
|
||||
def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history):
|
||||
def __many(conn):
|
||||
for tx in txs:
|
||||
self._transaction_io(conn, tx, address, txhash, history)
|
||||
return self.db.run(__many)
|
||||
|
||||
async def reserve_outputs(self, txos, is_reserved=True):
|
||||
txoids = ((is_reserved, txo.id) for txo in txos)
|
||||
await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids)
|
||||
|
||||
async def release_outputs(self, txos):
|
||||
await self.reserve_outputs(txos, is_reserved=False)
|
||||
|
||||
async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
|
||||
# TODO:
|
||||
# 1. delete transactions above_height
|
||||
# 2. update address histories removing deleted TXs
|
||||
return True
|
||||
|
||||
async def select_transactions(self, cols, accounts=None, **constraints):
|
||||
if not {'txid', 'txid__in'}.intersection(constraints):
|
||||
assert accounts, "'accounts' argument required when no 'txid' constraint is present"
|
||||
constraints.update({
|
||||
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
|
||||
})
|
||||
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
|
||||
where = f" WHERE account_address.account IN ({account_values})"
|
||||
constraints['txid__in'] = f"""
|
||||
SELECT txid FROM txi JOIN account_address USING (address)
|
||||
WHERE account_address.account IN ({account_values})
|
||||
SELECT txo.txid FROM txo JOIN account_address USING (address) {where}
|
||||
UNION
|
||||
SELECT txi.txid FROM txi JOIN account_address USING (address) {where}
|
||||
"""
|
||||
return await self.db.execute_fetchall(
|
||||
*query(f"SELECT {cols} FROM tx", **constraints)
|
||||
)
|
||||
|
||||
async def get_purchases(self, **constraints):
|
||||
self.constrain_purchases(constraints)
|
||||
return [tx.outputs[0] for tx in await self.get_transactions(**constraints)]
|
||||
async def get_transactions(self, wallet=None, **constraints):
|
||||
tx_rows = await self.select_transactions(
|
||||
'txid, raw, height, position, is_verified',
|
||||
order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
|
||||
**constraints
|
||||
)
|
||||
|
||||
def get_purchase_count(self, **constraints):
|
||||
self.constrain_purchases(constraints)
|
||||
return self.get_transaction_count(**constraints)
|
||||
if not tx_rows:
|
||||
return []
|
||||
|
||||
async def get_txos(self, wallet=None, no_tx=False, **constraints) -> List[Output]:
|
||||
txos = await super().get_txos(wallet=wallet, no_tx=no_tx, **constraints)
|
||||
txids, txs, txi_txoids = [], [], []
|
||||
for row in tx_rows:
|
||||
txids.append(row[0])
|
||||
txs.append(Transaction(
|
||||
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
|
||||
))
|
||||
for txi in txs[-1].inputs:
|
||||
txi_txoids.append(txi.txo_ref.id)
|
||||
|
||||
step = self.MAX_QUERY_VARIABLES
|
||||
annotated_txos = {}
|
||||
for offset in range(0, len(txids), step):
|
||||
annotated_txos.update({
|
||||
txo.id: txo for txo in
|
||||
(await self.get_txos(
|
||||
wallet=wallet,
|
||||
txid__in=txids[offset:offset+step],
|
||||
))
|
||||
})
|
||||
|
||||
referenced_txos = {}
|
||||
for offset in range(0, len(txi_txoids), step):
|
||||
referenced_txos.update({
|
||||
txo.id: txo for txo in
|
||||
(await self.get_txos(
|
||||
wallet=wallet,
|
||||
txoid__in=txi_txoids[offset:offset+step],
|
||||
))
|
||||
})
|
||||
|
||||
for tx in txs:
|
||||
for txi in tx.inputs:
|
||||
txo = referenced_txos.get(txi.txo_ref.id)
|
||||
if txo:
|
||||
txi.txo_ref = txo.ref
|
||||
for txo in tx.outputs:
|
||||
_txo = annotated_txos.get(txo.id)
|
||||
if _txo:
|
||||
txo.update_annotations(_txo)
|
||||
else:
|
||||
txo.update_annotations(None)
|
||||
|
||||
for tx in txs:
|
||||
txos = tx.outputs
|
||||
if len(txos) >= 2 and txos[1].can_decode_purchase_data:
|
||||
txos[0].purchase = txos[1]
|
||||
|
||||
return txs
|
||||
|
||||
async def get_transaction_count(self, **constraints):
|
||||
constraints.pop('wallet', None)
|
||||
constraints.pop('offset', None)
|
||||
constraints.pop('limit', None)
|
||||
constraints.pop('order_by', None)
|
||||
count = await self.select_transactions('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
async def get_transaction(self, **constraints):
|
||||
txs = await self.get_transactions(limit=1, **constraints)
|
||||
if txs:
|
||||
return txs[0]
|
||||
|
||||
async def select_txos(self, cols, **constraints):
|
||||
sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)"
|
||||
if 'accounts' in constraints:
|
||||
sql += " JOIN account_address USING (address)"
|
||||
return await self.db.execute_fetchall(*query(sql, **constraints))
|
||||
|
||||
async def get_txos(self, wallet=None, no_tx=False, **constraints):
|
||||
my_accounts = {a.public_key.address for a in wallet.accounts} if wallet else set()
|
||||
if 'order_by' not in constraints:
|
||||
constraints['order_by'] = [
|
||||
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
|
||||
]
|
||||
rows = await self.select_txos(
|
||||
"""
|
||||
tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, (
|
||||
select group_concat(account||"|"||chain) from account_address
|
||||
where account_address.address=txo.address
|
||||
)
|
||||
""",
|
||||
**constraints
|
||||
)
|
||||
txos = []
|
||||
txs = {}
|
||||
for row in rows:
|
||||
if no_tx:
|
||||
txo = Output(
|
||||
amount=row[6],
|
||||
script=OutputScript(row[7]),
|
||||
tx_ref=TXRefImmutable.from_id(row[0], row[2]),
|
||||
position=row[5]
|
||||
)
|
||||
else:
|
||||
if row[0] not in txs:
|
||||
txs[row[0]] = Transaction(
|
||||
row[1], height=row[2], position=row[3], is_verified=row[4]
|
||||
)
|
||||
txo = txs[row[0]].outputs[row[5]]
|
||||
row_accounts = dict(a.split('|') for a in row[8].split(','))
|
||||
account_match = set(row_accounts) & my_accounts
|
||||
if account_match:
|
||||
txo.is_my_account = True
|
||||
txo.is_change = row_accounts[account_match.pop()] == '1'
|
||||
else:
|
||||
txo.is_change = txo.is_my_account = False
|
||||
txos.append(txo)
|
||||
|
||||
channel_ids = set()
|
||||
for txo in txos:
|
||||
|
@ -138,6 +628,112 @@ class WalletDatabase(BaseDatabase):
|
|||
|
||||
return txos
|
||||
|
||||
async def get_txo_count(self, **constraints):
|
||||
constraints.pop('wallet', None)
|
||||
constraints.pop('offset', None)
|
||||
constraints.pop('limit', None)
|
||||
constraints.pop('order_by', None)
|
||||
count = await self.select_txos('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
@staticmethod
|
||||
def constrain_utxo(constraints):
|
||||
constraints['is_reserved'] = False
|
||||
constraints['txoid__not_in'] = "SELECT txoid FROM txi"
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
self.constrain_utxo(constraints)
|
||||
return self.get_txos(**constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
self.constrain_utxo(constraints)
|
||||
return self.get_txo_count(**constraints)
|
||||
|
||||
async def get_balance(self, wallet=None, accounts=None, **constraints):
|
||||
assert wallet or accounts, \
|
||||
"'wallet' or 'accounts' constraints required to calculate balance"
|
||||
constraints['accounts'] = accounts or wallet.accounts
|
||||
self.constrain_utxo(constraints)
|
||||
balance = await self.select_txos('SUM(amount)', **constraints)
|
||||
return balance[0][0] or 0
|
||||
|
||||
async def select_addresses(self, cols, **constraints):
|
||||
return await self.db.execute_fetchall(*query(
|
||||
f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)",
|
||||
**constraints
|
||||
))
|
||||
|
||||
async def get_addresses(self, cols=None, **constraints):
|
||||
cols = cols or (
|
||||
'address', 'account', 'chain', 'history', 'used_times',
|
||||
'pubkey', 'chain_code', 'n', 'depth'
|
||||
)
|
||||
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols)
|
||||
if 'pubkey' in cols:
|
||||
for address in addresses:
|
||||
address['pubkey'] = PubKey(
|
||||
self.ledger, address.pop('pubkey'), address.pop('chain_code'),
|
||||
address.pop('n'), address.pop('depth')
|
||||
)
|
||||
return addresses
|
||||
|
||||
async def get_address_count(self, cols=None, **constraints):
|
||||
count = await self.select_addresses('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
async def get_address(self, **constraints):
|
||||
addresses = await self.get_addresses(limit=1, **constraints)
|
||||
if addresses:
|
||||
return addresses[0]
|
||||
|
||||
async def add_keys(self, account, chain, pubkeys):
|
||||
await self.db.executemany(
|
||||
"insert or ignore into account_address "
|
||||
"(account, address, chain, pubkey, chain_code, n, depth) values "
|
||||
"(?, ?, ?, ?, ?, ?, ?)", ((
|
||||
account.id, k.address, chain,
|
||||
sqlite3.Binary(k.pubkey_bytes),
|
||||
sqlite3.Binary(k.chain_code),
|
||||
k.n, k.depth
|
||||
) for k in pubkeys)
|
||||
)
|
||||
await self.db.executemany(
|
||||
"insert or ignore into pubkey_address (address) values (?)",
|
||||
((pubkey.address,) for pubkey in pubkeys)
|
||||
)
|
||||
|
||||
async def _set_address_history(self, address, history):
|
||||
await self.db.execute_fetchall(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':')//2, address)
|
||||
)
|
||||
|
||||
async def set_address_history(self, address, history):
|
||||
await self._set_address_history(address, history)
|
||||
|
||||
@staticmethod
|
||||
def constrain_purchases(constraints):
|
||||
accounts = constraints.pop('accounts', None)
|
||||
assert accounts, "'accounts' argument required to find purchases"
|
||||
if not {'purchased_claim_id', 'purchased_claim_id__in'}.intersection(constraints):
|
||||
constraints['purchased_claim_id__is_not_null'] = True
|
||||
constraints.update({
|
||||
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
|
||||
})
|
||||
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
|
||||
constraints['txid__in'] = f"""
|
||||
SELECT txid FROM txi JOIN account_address USING (address)
|
||||
WHERE account_address.account IN ({account_values})
|
||||
"""
|
||||
|
||||
async def get_purchases(self, **constraints):
|
||||
self.constrain_purchases(constraints)
|
||||
return [tx.outputs[0] for tx in await self.get_transactions(**constraints)]
|
||||
|
||||
def get_purchase_count(self, **constraints):
|
||||
self.constrain_purchases(constraints)
|
||||
return self.get_transaction_count(**constraints)
|
||||
|
||||
@staticmethod
|
||||
def constrain_claims(constraints):
|
||||
claim_type = constraints.pop('claim_type', None)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import textwrap
|
||||
from lbry.wallet.client.util import coins_to_satoshis, satoshis_to_coins
|
||||
from .util import coins_to_satoshis, satoshis_to_coins
|
||||
|
||||
|
||||
def lbc_to_dewies(lbc: str) -> int:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from binascii import hexlify, unhexlify
|
||||
from lbry.wallet.client.constants import NULL_HASH32
|
||||
from .constants import NULL_HASH32
|
||||
|
||||
|
||||
class TXRef:
|
||||
|
|
|
@ -1,10 +1,252 @@
|
|||
import os
|
||||
import struct
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from io import BytesIO
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, Iterator, Tuple
|
||||
from binascii import hexlify, unhexlify
|
||||
|
||||
from lbry.crypto.hash import sha512, double_sha256, ripemd160
|
||||
from lbry.wallet.client.baseheader import BaseHeaders
|
||||
from lbry.wallet.client.util import ArithUint256
|
||||
from lbry.wallet.util import ArithUint256
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidHeader(Exception):
|
||||
|
||||
def __init__(self, height, message):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.height = height
|
||||
|
||||
|
||||
class BaseHeaders:
|
||||
|
||||
header_size: int
|
||||
chunk_size: int
|
||||
|
||||
max_target: int
|
||||
genesis_hash: Optional[bytes]
|
||||
target_timespan: int
|
||||
|
||||
validate_difficulty: bool = True
|
||||
checkpoint = None
|
||||
|
||||
def __init__(self, path) -> None:
|
||||
if path == ':memory:':
|
||||
self.io = BytesIO()
|
||||
self.path = path
|
||||
self._size: Optional[int] = None
|
||||
|
||||
async def open(self):
|
||||
if self.path != ':memory:':
|
||||
if not os.path.exists(self.path):
|
||||
self.io = open(self.path, 'w+b')
|
||||
else:
|
||||
self.io = open(self.path, 'r+b')
|
||||
|
||||
async def close(self):
|
||||
self.io.close()
|
||||
|
||||
@staticmethod
|
||||
def serialize(header: dict) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def deserialize(height, header):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_next_chunk_target(self, chunk: int) -> ArithUint256:
|
||||
return ArithUint256(self.max_target)
|
||||
|
||||
@staticmethod
|
||||
def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict],
|
||||
current: Optional[dict]) -> ArithUint256:
|
||||
return chunk_target
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self._size is None:
|
||||
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
|
||||
return self._size
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
def __getitem__(self, height) -> dict:
|
||||
if isinstance(height, slice):
|
||||
raise NotImplementedError("Slicing of header chain has not been implemented yet.")
|
||||
if not 0 <= height <= self.height:
|
||||
raise IndexError(f"{height} is out of bounds, current height: {self.height}")
|
||||
return self.deserialize(height, self.get_raw_header(height))
|
||||
|
||||
def get_raw_header(self, height) -> bytes:
|
||||
self.io.seek(height * self.header_size, os.SEEK_SET)
|
||||
return self.io.read(self.header_size)
|
||||
|
||||
@property
|
||||
def height(self) -> int:
|
||||
return len(self)-1
|
||||
|
||||
@property
|
||||
def bytes_size(self):
|
||||
return len(self) * self.header_size
|
||||
|
||||
def hash(self, height=None) -> bytes:
|
||||
return self.hash_header(
|
||||
self.get_raw_header(height if height is not None else self.height)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def hash_header(header: bytes) -> bytes:
|
||||
if header is None:
|
||||
return b'0' * 64
|
||||
return hexlify(double_sha256(header)[::-1])
|
||||
|
||||
@asynccontextmanager
|
||||
async def checkpointed_connector(self):
|
||||
buf = BytesIO()
|
||||
try:
|
||||
yield buf
|
||||
finally:
|
||||
await asyncio.sleep(0)
|
||||
final_height = len(self) + buf.tell() // self.header_size
|
||||
verifiable_bytes = (self.checkpoint[0] - len(self)) * self.header_size if self.checkpoint else 0
|
||||
if verifiable_bytes > 0 and final_height >= self.checkpoint[0]:
|
||||
buf.seek(0)
|
||||
self.io.seek(0)
|
||||
h = hashlib.sha256()
|
||||
h.update(self.io.read())
|
||||
h.update(buf.read(verifiable_bytes))
|
||||
if h.hexdigest().encode() == self.checkpoint[1]:
|
||||
buf.seek(0)
|
||||
self._write(len(self), buf.read(verifiable_bytes))
|
||||
remaining = buf.read()
|
||||
buf.seek(0)
|
||||
buf.write(remaining)
|
||||
buf.truncate()
|
||||
else:
|
||||
log.warning("Checkpoint mismatch, connecting headers through slow method.")
|
||||
if buf.tell() > 0:
|
||||
await self.connect(len(self), buf.getvalue())
|
||||
|
||||
async def connect(self, start: int, headers: bytes) -> int:
|
||||
added = 0
|
||||
bail = False
|
||||
for height, chunk in self._iterate_chunks(start, headers):
|
||||
try:
|
||||
# validate_chunk() is CPU bound and reads previous chunks from file system
|
||||
self.validate_chunk(height, chunk)
|
||||
except InvalidHeader as e:
|
||||
bail = True
|
||||
chunk = chunk[:(height-e.height)*self.header_size]
|
||||
added += self._write(height, chunk) if chunk else 0
|
||||
if bail:
|
||||
break
|
||||
return added
|
||||
|
||||
def _write(self, height, verified_chunk):
|
||||
self.io.seek(height * self.header_size, os.SEEK_SET)
|
||||
written = self.io.write(verified_chunk) // self.header_size
|
||||
self.io.truncate()
|
||||
# .seek()/.write()/.truncate() might also .flush() when needed
|
||||
# the goal here is mainly to ensure we're definitely flush()'ing
|
||||
self.io.flush()
|
||||
self._size = self.io.tell() // self.header_size
|
||||
return written
|
||||
|
||||
def validate_chunk(self, height, chunk):
|
||||
previous_hash, previous_header, previous_previous_header = None, None, None
|
||||
if height > 0:
|
||||
previous_header = self[height-1]
|
||||
previous_hash = self.hash(height-1)
|
||||
if height > 1:
|
||||
previous_previous_header = self[height-2]
|
||||
chunk_target = self.get_next_chunk_target(height // 2016 - 1)
|
||||
for current_hash, current_header in self._iterate_headers(height, chunk):
|
||||
block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
|
||||
self.validate_header(height, current_hash, current_header, previous_hash, block_target)
|
||||
previous_previous_header = previous_header
|
||||
previous_header = current_header
|
||||
previous_hash = current_hash
|
||||
|
||||
def validate_header(self, height: int, current_hash: bytes,
|
||||
header: dict, previous_hash: bytes, target: ArithUint256):
|
||||
|
||||
if previous_hash is None:
|
||||
if self.genesis_hash is not None and self.genesis_hash != current_hash:
|
||||
raise InvalidHeader(
|
||||
height, f"genesis header doesn't match: {current_hash.decode()} "
|
||||
f"vs expected {self.genesis_hash.decode()}")
|
||||
return
|
||||
|
||||
if header['prev_block_hash'] != previous_hash:
|
||||
raise InvalidHeader(
|
||||
height, "previous hash mismatch: {} vs expected {}".format(
|
||||
header['prev_block_hash'].decode(), previous_hash.decode())
|
||||
)
|
||||
|
||||
if self.validate_difficulty:
|
||||
|
||||
if header['bits'] != target.compact:
|
||||
raise InvalidHeader(
|
||||
height, "bits mismatch: {} vs expected {}".format(
|
||||
header['bits'], target.compact)
|
||||
)
|
||||
|
||||
proof_of_work = self.get_proof_of_work(current_hash)
|
||||
if proof_of_work > target:
|
||||
raise InvalidHeader(
|
||||
height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}"
|
||||
)
|
||||
|
||||
async def repair(self):
|
||||
previous_header_hash = fail = None
|
||||
batch_size = 36
|
||||
for start_height in range(0, self.height, batch_size):
|
||||
self.io.seek(self.header_size * start_height)
|
||||
headers = self.io.read(self.header_size*batch_size)
|
||||
if len(headers) % self.header_size != 0:
|
||||
headers = headers[:(len(headers) // self.header_size) * self.header_size]
|
||||
for header_hash, header in self._iterate_headers(start_height, headers):
|
||||
height = header['block_height']
|
||||
if height:
|
||||
if header['prev_block_hash'] != previous_header_hash:
|
||||
fail = True
|
||||
else:
|
||||
if header_hash != self.genesis_hash:
|
||||
fail = True
|
||||
if fail:
|
||||
log.warning("Header file corrupted at height %s, truncating it.", height - 1)
|
||||
self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET)
|
||||
self.io.truncate()
|
||||
self.io.flush()
|
||||
self._size = None
|
||||
return
|
||||
previous_header_hash = header_hash
|
||||
|
||||
@staticmethod
|
||||
def get_proof_of_work(header_hash: bytes) -> ArithUint256:
|
||||
return ArithUint256(int(b'0x' + header_hash, 16))
|
||||
|
||||
def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
|
||||
assert len(headers) % self.header_size == 0, f"{len(headers)} {len(headers)%self.header_size}"
|
||||
start = 0
|
||||
end = (self.chunk_size - height % self.chunk_size) * self.header_size
|
||||
while start < end:
|
||||
yield height + (start // self.header_size), headers[start:end]
|
||||
start = end
|
||||
end = min(len(headers), end + self.chunk_size * self.header_size)
|
||||
|
||||
def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
|
||||
assert len(headers) % self.header_size == 0, len(headers)
|
||||
for idx in range(len(headers) // self.header_size):
|
||||
start, end = idx * self.header_size, (idx + 1) * self.header_size
|
||||
header = headers[start:end]
|
||||
yield self.hash_header(header), self.deserialize(height+idx, header)
|
||||
|
||||
|
||||
class Headers(BaseHeaders):
|
||||
|
|
|
@ -1,41 +1,96 @@
|
|||
import os
|
||||
import zlib
|
||||
import pylru
|
||||
import base64
|
||||
import asyncio
|
||||
import logging
|
||||
from binascii import unhexlify
|
||||
from functools import partial
|
||||
from typing import Tuple, List
|
||||
from datetime import datetime
|
||||
|
||||
import pylru
|
||||
from lbry.wallet.client.baseledger import BaseLedger, TransactionEvent
|
||||
from lbry.wallet.client.baseaccount import SingleKey
|
||||
from io import StringIO
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from collections import namedtuple
|
||||
from binascii import hexlify, unhexlify
|
||||
from typing import Dict, Tuple, Type, Iterable, List, Optional
|
||||
|
||||
from lbry.schema.result import Outputs
|
||||
from lbry.schema.url import URL
|
||||
from lbry.wallet.dewies import dewies_to_lbc
|
||||
from lbry.wallet.account import Account
|
||||
from lbry.wallet.network import Network
|
||||
from lbry.wallet.database import WalletDatabase
|
||||
from lbry.wallet.transaction import Transaction, Output
|
||||
from lbry.wallet.header import Headers, UnvalidatedHeaders
|
||||
from lbry.wallet.constants import TXO_TYPES
|
||||
from lbry.crypto.hash import hash160, double_sha256, sha256
|
||||
from lbry.crypto.base58 import Base58
|
||||
|
||||
from .tasks import TaskGroup
|
||||
from .database import Database
|
||||
from .stream import StreamController
|
||||
from .dewies import dewies_to_lbc
|
||||
from .account import Account, AddressManager, SingleKey
|
||||
from .network import Network
|
||||
from .transaction import Transaction, Output
|
||||
from .header import Headers, UnvalidatedHeaders
|
||||
from .constants import TXO_TYPES, COIN, NULL_HASH32
|
||||
from .bip32 import PubKey, PrivateKey
|
||||
from .coinselection import CoinSelector
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LedgerType = Type['BaseLedger']
|
||||
|
||||
class MainNetLedger(BaseLedger):
|
||||
|
||||
class LedgerRegistry(type):
|
||||
|
||||
ledgers: Dict[str, LedgerType] = {}
|
||||
|
||||
def __new__(mcs, name, bases, attrs):
|
||||
cls: LedgerType = super().__new__(mcs, name, bases, attrs)
|
||||
if not (name == 'BaseLedger' and not bases):
|
||||
ledger_id = cls.get_id()
|
||||
assert ledger_id not in mcs.ledgers, \
|
||||
f'Ledger with id "{ledger_id}" already registered.'
|
||||
mcs.ledgers[ledger_id] = cls
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def get_ledger_class(mcs, ledger_id: str) -> LedgerType:
|
||||
return mcs.ledgers[ledger_id]
|
||||
|
||||
|
||||
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
|
||||
pass
|
||||
|
||||
|
||||
class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))):
|
||||
pass
|
||||
|
||||
|
||||
class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
|
||||
pass
|
||||
|
||||
|
||||
class TransactionCacheItem:
|
||||
__slots__ = '_tx', 'lock', 'has_tx'
|
||||
|
||||
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
|
||||
self.has_tx = asyncio.Event()
|
||||
self.lock = lock or asyncio.Lock()
|
||||
self._tx = self.tx = tx
|
||||
|
||||
@property
|
||||
def tx(self) -> Optional[Transaction]:
|
||||
return self._tx
|
||||
|
||||
@tx.setter
|
||||
def tx(self, tx: Transaction):
|
||||
self._tx = tx
|
||||
if tx is not None:
|
||||
self.has_tx.set()
|
||||
|
||||
|
||||
class Ledger(metaclass=LedgerRegistry):
|
||||
name = 'LBRY Credits'
|
||||
symbol = 'LBC'
|
||||
network_name = 'mainnet'
|
||||
|
||||
headers: Headers
|
||||
|
||||
account_class = Account
|
||||
database_class = WalletDatabase
|
||||
headers_class = Headers
|
||||
network_class = Network
|
||||
transaction_class = Transaction
|
||||
|
||||
db: WalletDatabase
|
||||
|
||||
secret_prefix = bytes((0x1c,))
|
||||
pubkey_address_prefix = bytes((0x55,))
|
||||
|
@ -51,11 +106,522 @@ class MainNetLedger(BaseLedger):
|
|||
default_fee_per_byte = 50
|
||||
default_fee_per_name_char = 200000
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, config=None):
|
||||
self.config = config or {}
|
||||
self.db: Database = self.config.get('db') or Database(
|
||||
os.path.join(self.path, "blockchain.db")
|
||||
)
|
||||
self.db.ledger = self
|
||||
self.headers: Headers = self.config.get('headers') or self.headers_class(
|
||||
os.path.join(self.path, "headers")
|
||||
)
|
||||
self.network: Network = self.config.get('network') or Network(self)
|
||||
self.network.on_header.listen(self.receive_header)
|
||||
self.network.on_status.listen(self.process_status_update)
|
||||
self.network.on_connected.listen(self.join_network)
|
||||
|
||||
self.accounts = []
|
||||
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
|
||||
|
||||
self._on_transaction_controller = StreamController()
|
||||
self.on_transaction = self._on_transaction_controller.stream
|
||||
self.on_transaction.listen(
|
||||
lambda e: log.info(
|
||||
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
|
||||
self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id
|
||||
)
|
||||
)
|
||||
|
||||
self._on_address_controller = StreamController()
|
||||
self.on_address = self._on_address_controller.stream
|
||||
self.on_address.listen(
|
||||
lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses)
|
||||
)
|
||||
|
||||
self._on_header_controller = StreamController()
|
||||
self.on_header = self._on_header_controller.stream
|
||||
self.on_header.listen(
|
||||
lambda change: log.info(
|
||||
'%s: added %s header blocks, final height %s',
|
||||
self.get_id(), change, self.headers.height
|
||||
)
|
||||
)
|
||||
self._download_height = 0
|
||||
|
||||
self._on_ready_controller = StreamController()
|
||||
self.on_ready = self._on_ready_controller.stream
|
||||
|
||||
self._tx_cache = pylru.lrucache(100000)
|
||||
self._update_tasks = TaskGroup()
|
||||
self._utxo_reservation_lock = asyncio.Lock()
|
||||
self._header_processing_lock = asyncio.Lock()
|
||||
self._address_update_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
self.coin_selection_strategy = None
|
||||
self._known_addresses_out_of_sync = set()
|
||||
|
||||
self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char)
|
||||
self._balance_cache = pylru.lrucache(100000)
|
||||
|
||||
@classmethod
|
||||
def get_id(cls):
|
||||
return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower())
|
||||
|
||||
@classmethod
|
||||
def hash160_to_address(cls, h160):
|
||||
raw_address = cls.pubkey_address_prefix + h160
|
||||
return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4]))
|
||||
|
||||
@staticmethod
|
||||
def address_to_hash160(address):
|
||||
return Base58.decode(address)[1:21]
|
||||
|
||||
@classmethod
|
||||
def is_valid_address(cls, address):
|
||||
decoded = Base58.decode_check(address)
|
||||
return decoded[0] == cls.pubkey_address_prefix[0]
|
||||
|
||||
@classmethod
|
||||
def public_key_to_address(cls, public_key):
|
||||
return cls.hash160_to_address(hash160(public_key))
|
||||
|
||||
@staticmethod
|
||||
def private_key_to_wif(private_key):
|
||||
return b'\x1c' + private_key + b'\x01'
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return os.path.join(self.config['data_path'], self.get_id())
|
||||
|
||||
def add_account(self, account: Account):
|
||||
self.accounts.append(account)
|
||||
|
||||
async def _get_account_and_address_info_for_address(self, wallet, address):
|
||||
match = await self.db.get_address(accounts=wallet.accounts, address=address)
|
||||
if match:
|
||||
for account in wallet.accounts:
|
||||
if match['account'] == account.public_key.address:
|
||||
return account, match
|
||||
|
||||
async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
account, address_info = match
|
||||
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
|
||||
return None
|
||||
|
||||
async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
_, address_info = match
|
||||
return address_info['pubkey']
|
||||
return None
|
||||
|
||||
async def get_account_for_address(self, wallet, address):
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
return match[0]
|
||||
|
||||
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
|
||||
estimators = []
|
||||
for account in funding_accounts:
|
||||
utxos = await account.get_utxos()
|
||||
for utxo in utxos:
|
||||
estimators.append(utxo.get_estimator(self))
|
||||
return estimators
|
||||
|
||||
async def get_addresses(self, **constraints):
|
||||
return await self.db.get_addresses(**constraints)
|
||||
|
||||
def get_address_count(self, **constraints):
|
||||
return self.db.get_address_count(**constraints)
|
||||
|
||||
async def get_spendable_utxos(self, amount: int, funding_accounts):
|
||||
async with self._utxo_reservation_lock:
|
||||
txos = await self.get_effective_amount_estimators(funding_accounts)
|
||||
fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
|
||||
selector = CoinSelector(amount, fee)
|
||||
spendables = selector.select(txos, self.coin_selection_strategy)
|
||||
if spendables:
|
||||
await self.reserve_outputs(s.txo for s in spendables)
|
||||
return spendables
|
||||
|
||||
def reserve_outputs(self, txos):
|
||||
return self.db.reserve_outputs(txos)
|
||||
|
||||
def release_outputs(self, txos):
|
||||
return self.db.release_outputs(txos)
|
||||
|
||||
def release_tx(self, tx):
|
||||
return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
self.constraint_spending_utxos(constraints)
|
||||
return self.db.get_utxos(**constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
self.constraint_spending_utxos(constraints)
|
||||
return self.db.get_utxo_count(**constraints)
|
||||
|
||||
def get_transactions(self, **constraints):
|
||||
return self.db.get_transactions(**constraints)
|
||||
|
||||
def get_transaction_count(self, **constraints):
|
||||
return self.db.get_transaction_count(**constraints)
|
||||
|
||||
async def get_local_status_and_history(self, address, history=None):
|
||||
if not history:
|
||||
address_details = await self.db.get_address(address=address)
|
||||
history = address_details['history'] or ''
|
||||
parts = history.split(':')[:-1]
|
||||
return (
|
||||
hexlify(sha256(history.encode())).decode() if history else None,
|
||||
list(zip(parts[0::2], map(int, parts[1::2])))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
|
||||
for i, branch in enumerate(branches):
|
||||
other_branch = unhexlify(branch)[::-1]
|
||||
other_branch_on_left = bool((branch_positions >> i) & 1)
|
||||
if other_branch_on_left:
|
||||
combined = other_branch + working_branch
|
||||
else:
|
||||
combined = working_branch + other_branch
|
||||
working_branch = double_sha256(combined)
|
||||
return hexlify(working_branch[::-1])
|
||||
|
||||
async def start(self):
|
||||
if not os.path.exists(self.path):
|
||||
os.mkdir(self.path)
|
||||
await asyncio.wait([
|
||||
self.db.open(),
|
||||
self.headers.open()
|
||||
])
|
||||
first_connection = self.network.on_connected.first
|
||||
asyncio.ensure_future(self.network.start())
|
||||
await first_connection
|
||||
async with self._header_processing_lock:
|
||||
await self._update_tasks.add(self.initial_headers_sync())
|
||||
await self._on_ready_controller.stream.first
|
||||
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
|
||||
await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
|
||||
if len(self.accounts) > 10:
|
||||
log.info("Loaded %i accounts", len(self.accounts))
|
||||
else:
|
||||
await self._report_state()
|
||||
self.on_transaction.listen(self._reset_balance_cache)
|
||||
|
||||
async def join_network(self, *_):
|
||||
log.info("Subscribing and updating accounts.")
|
||||
async with self._header_processing_lock:
|
||||
await self.update_headers()
|
||||
await self.subscribe_accounts()
|
||||
await self._update_tasks.done.wait()
|
||||
self._on_ready_controller.add(True)
|
||||
|
||||
async def stop(self):
|
||||
self._update_tasks.cancel()
|
||||
await self._update_tasks.done.wait()
|
||||
await self.network.stop()
|
||||
await self.db.close()
|
||||
await self.headers.close()
|
||||
|
||||
@property
|
||||
def local_height_including_downloaded_height(self):
|
||||
return max(self.headers.height, self._download_height)
|
||||
|
||||
async def initial_headers_sync(self):
|
||||
target = self.network.remote_height + 1
|
||||
current = len(self.headers)
|
||||
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
|
||||
chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
|
||||
total = 0
|
||||
async with self.headers.checkpointed_connector() as buffer:
|
||||
for chunk in chunks:
|
||||
headers = await chunk
|
||||
total += buffer.write(
|
||||
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
|
||||
)
|
||||
self._download_height = current + total // self.headers.header_size
|
||||
log.info("Headers sync: %s / %s", self._download_height, target)
|
||||
|
||||
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||
rewound = 0
|
||||
while True:
|
||||
|
||||
if height is None or height > len(self.headers):
|
||||
# sometimes header subscription updates are for a header in the future
|
||||
# which can't be connected, so we do a normal header sync instead
|
||||
height = len(self.headers)
|
||||
headers = None
|
||||
subscription_update = False
|
||||
|
||||
if not headers:
|
||||
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
||||
headers = header_response['hex']
|
||||
|
||||
if not headers:
|
||||
# Nothing to do, network thinks we're already at the latest height.
|
||||
return
|
||||
|
||||
added = await self.headers.connect(height, unhexlify(headers))
|
||||
if added > 0:
|
||||
height += added
|
||||
self._on_header_controller.add(
|
||||
BlockHeightEvent(self.headers.height, added))
|
||||
|
||||
if rewound > 0:
|
||||
# we started rewinding blocks and apparently found
|
||||
# a new chain
|
||||
rewound = 0
|
||||
await self.db.rewind_blockchain(height)
|
||||
|
||||
if subscription_update:
|
||||
# subscription updates are for latest header already
|
||||
# so we don't need to check if there are newer / more
|
||||
# on another loop of update_headers(), just return instead
|
||||
return
|
||||
|
||||
elif added == 0:
|
||||
# we had headers to connect but none got connected, probably a reorganization
|
||||
height -= 1
|
||||
rewound += 1
|
||||
log.warning(
|
||||
"Blockchain Reorganization: attempting rewind to height %s from starting height %s",
|
||||
height, height+rewound
|
||||
)
|
||||
|
||||
else:
|
||||
raise IndexError(f"headers.connect() returned negative number ({added})")
|
||||
|
||||
if height < 0:
|
||||
raise IndexError(
|
||||
"Blockchain reorganization rewound all the way back to genesis hash. "
|
||||
"Something is very wrong. Maybe you are on the wrong blockchain?"
|
||||
)
|
||||
|
||||
if rewound >= 100:
|
||||
raise IndexError(
|
||||
"Blockchain reorganization dropped {} headers. This is highly unusual. "
|
||||
"Will not continue to attempt reorganizing. Please, delete the ledger "
|
||||
"synchronization directory inside your wallet directory (folder: '{}') and "
|
||||
"restart the program to synchronize from scratch."
|
||||
.format(rewound, self.get_id())
|
||||
)
|
||||
|
||||
headers = None # ready to download some more headers
|
||||
|
||||
# if we made it this far and this was a subscription_update
|
||||
# it means something went wrong and now we're doing a more
|
||||
# robust sync, turn off subscription update shortcut
|
||||
subscription_update = False
|
||||
|
||||
async def receive_header(self, response):
|
||||
async with self._header_processing_lock:
|
||||
header = response[0]
|
||||
await self.update_headers(
|
||||
height=header['height'], headers=header['hex'], subscription_update=True
|
||||
)
|
||||
|
||||
async def subscribe_accounts(self):
|
||||
if self.network.is_connected and self.accounts:
|
||||
await asyncio.wait([
|
||||
self.subscribe_account(a) for a in self.accounts
|
||||
])
|
||||
|
||||
async def subscribe_account(self, account: Account):
|
||||
for address_manager in account.address_managers.values():
|
||||
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
|
||||
await account.ensure_address_gap()
|
||||
|
||||
async def unsubscribe_account(self, account: Account):
|
||||
for address in await account.get_addresses():
|
||||
await self.network.unsubscribe_address(address)
|
||||
|
||||
async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]):
|
||||
await self.subscribe_addresses(address_manager, addresses)
|
||||
await self._on_address_controller.add(
|
||||
AddressesGeneratedEvent(address_manager, addresses)
|
||||
)
|
||||
|
||||
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str]):
|
||||
if self.network.is_connected and addresses:
|
||||
await asyncio.wait([
|
||||
self.subscribe_address(address_manager, address) for address in addresses
|
||||
])
|
||||
|
||||
async def subscribe_address(self, address_manager: AddressManager, address: str):
|
||||
remote_status = await self.network.subscribe_address(address)
|
||||
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
||||
|
||||
def process_status_update(self, update):
|
||||
address, remote_status = update
|
||||
self._update_tasks.add(self.update_history(address, remote_status))
|
||||
|
||||
async def update_history(self, address, remote_status,
|
||||
address_manager: AddressManager = None):
|
||||
|
||||
async with self._address_update_locks.setdefault(address, asyncio.Lock()):
|
||||
self._known_addresses_out_of_sync.discard(address)
|
||||
|
||||
local_status, local_history = await self.get_local_status_and_history(address)
|
||||
|
||||
if local_status == remote_status:
|
||||
return True
|
||||
|
||||
remote_history = await self.network.retriable_call(self.network.get_history, address)
|
||||
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
|
||||
we_need = set(remote_history) - set(local_history)
|
||||
if not we_need:
|
||||
return True
|
||||
|
||||
cache_tasks: List[asyncio.Future[Transaction]] = []
|
||||
synced_history = StringIO()
|
||||
for i, (txid, remote_height) in enumerate(remote_history):
|
||||
if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
|
||||
synced_history.write(f'{txid}:{remote_height}:')
|
||||
else:
|
||||
check_local = (txid, remote_height) not in we_need
|
||||
cache_tasks.append(asyncio.ensure_future(
|
||||
self.cache_transaction(txid, remote_height, check_local=check_local)
|
||||
))
|
||||
|
||||
synced_txs = []
|
||||
for task in cache_tasks:
|
||||
tx = await task
|
||||
|
||||
check_db_for_txos = []
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
continue
|
||||
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
|
||||
if cache_item is not None:
|
||||
if cache_item.tx is None:
|
||||
await cache_item.has_tx.wait()
|
||||
assert cache_item.tx is not None
|
||||
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
|
||||
else:
|
||||
check_db_for_txos.append(txi.txo_ref.id)
|
||||
|
||||
referenced_txos = {} if not check_db_for_txos else {
|
||||
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos, no_tx=True)
|
||||
}
|
||||
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
continue
|
||||
referenced_txo = referenced_txos.get(txi.txo_ref.id)
|
||||
if referenced_txo is not None:
|
||||
txi.txo_ref = referenced_txo.ref
|
||||
|
||||
synced_history.write(f'{tx.id}:{tx.height}:')
|
||||
synced_txs.append(tx)
|
||||
|
||||
await self.db.save_transaction_io_batch(
|
||||
synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
|
||||
)
|
||||
await asyncio.wait([
|
||||
self._on_transaction_controller.add(TransactionEvent(address, tx))
|
||||
for tx in synced_txs
|
||||
])
|
||||
|
||||
if address_manager is None:
|
||||
address_manager = await self.get_address_manager_for_address(address)
|
||||
|
||||
if address_manager is not None:
|
||||
await address_manager.ensure_address_gap()
|
||||
|
||||
local_status, local_history = \
|
||||
await self.get_local_status_and_history(address, synced_history.getvalue())
|
||||
if local_status != remote_status:
|
||||
if local_history == remote_history:
|
||||
return True
|
||||
log.warning(
|
||||
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
|
||||
remote_status, len(remote_history), local_status, len(local_history)
|
||||
)
|
||||
log.warning("local: %s", local_history)
|
||||
log.warning("remote: %s", remote_history)
|
||||
self._known_addresses_out_of_sync.add(address)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
async def cache_transaction(self, txid, remote_height, check_local=True):
|
||||
cache_item = self._tx_cache.get(txid)
|
||||
if cache_item is None:
|
||||
cache_item = self._tx_cache[txid] = TransactionCacheItem()
|
||||
elif cache_item.tx is not None and \
|
||||
cache_item.tx.height >= remote_height and \
|
||||
(cache_item.tx.is_verified or remote_height < 1):
|
||||
return cache_item.tx # cached tx is already up-to-date
|
||||
|
||||
async with cache_item.lock:
|
||||
|
||||
tx = cache_item.tx
|
||||
|
||||
if tx is None and check_local:
|
||||
# check local db
|
||||
tx = cache_item.tx = await self.db.get_transaction(txid=txid)
|
||||
|
||||
if tx is None:
|
||||
# fetch from network
|
||||
_raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height)
|
||||
tx = Transaction(unhexlify(_raw))
|
||||
cache_item.tx = tx # make sure it's saved before caching it
|
||||
|
||||
await self.maybe_verify_transaction(tx, remote_height)
|
||||
return tx
|
||||
|
||||
async def maybe_verify_transaction(self, tx, remote_height):
|
||||
tx.height = remote_height
|
||||
if 0 < remote_height < len(self.headers):
|
||||
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
|
||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||
header = self.headers[remote_height]
|
||||
tx.position = merkle['pos']
|
||||
tx.is_verified = merkle_root == header['merkle_root']
|
||||
|
||||
async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
|
||||
details = await self.db.get_address(address=address)
|
||||
for account in self.accounts:
|
||||
if account.id == details['account']:
|
||||
return account.address_managers[details['chain']]
|
||||
return None
|
||||
|
||||
def broadcast(self, tx):
|
||||
# broadcast can't be a retriable call yet
|
||||
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||
|
||||
async def wait(self, tx: Transaction, height=-1, timeout=1):
|
||||
addresses = set()
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
addresses.add(
|
||||
self.hash160_to_address(txi.txo_ref.txo.pubkey_hash)
|
||||
)
|
||||
for txo in tx.outputs:
|
||||
if txo.has_address:
|
||||
addresses.add(self.hash160_to_address(txo.pubkey_hash))
|
||||
records = await self.db.get_addresses(address__in=addresses)
|
||||
_, pending = await asyncio.wait([
|
||||
self.on_transaction.where(partial(
|
||||
lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,
|
||||
address_record['address']
|
||||
)) for address_record in records
|
||||
], timeout=timeout)
|
||||
if pending:
|
||||
for record in records:
|
||||
found = False
|
||||
_, local_history = await self.get_local_status_and_history(None, history=record['history'])
|
||||
for txid, local_height in local_history:
|
||||
if txid == tx.id and local_height >= height:
|
||||
found = True
|
||||
if not found:
|
||||
print(record['history'], addresses, tx.id)
|
||||
raise asyncio.TimeoutError('Timed out waiting for transaction.')
|
||||
|
||||
async def _inflate_outputs(self, query, accounts):
|
||||
outputs = Outputs.from_base64(await query)
|
||||
txs = []
|
||||
|
@ -103,16 +669,6 @@ class MainNetLedger(BaseLedger):
|
|||
for claim in (await self.claim_search(accounts, claim_id=claim_id))[0]:
|
||||
return claim
|
||||
|
||||
async def start(self):
|
||||
await super().start()
|
||||
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
|
||||
await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
|
||||
if len(self.accounts) > 10:
|
||||
log.info("Loaded %i accounts", len(self.accounts))
|
||||
else:
|
||||
await self._report_state()
|
||||
self.on_transaction.listen(self._reset_balance_cache)
|
||||
|
||||
async def _report_state(self):
|
||||
try:
|
||||
for account in self.accounts:
|
||||
|
@ -147,14 +703,6 @@ class MainNetLedger(BaseLedger):
|
|||
def constraint_spending_utxos(constraints):
|
||||
constraints['txo_type__in'] = (0, TXO_TYPES['purchase'])
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
self.constraint_spending_utxos(constraints)
|
||||
return super().get_utxos(**constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
self.constraint_spending_utxos(constraints)
|
||||
return super().get_utxo_count(**constraints)
|
||||
|
||||
async def get_purchases(self, resolve=False, **constraints):
|
||||
purchases = await self.db.get_purchases(**constraints)
|
||||
if resolve:
|
||||
|
@ -357,7 +905,7 @@ class MainNetLedger(BaseLedger):
|
|||
return result
|
||||
|
||||
|
||||
class TestNetLedger(MainNetLedger):
|
||||
class TestNetLedger(Ledger):
|
||||
network_name = 'testnet'
|
||||
pubkey_address_prefix = bytes((111,))
|
||||
script_address_prefix = bytes((196,))
|
||||
|
@ -365,7 +913,7 @@ class TestNetLedger(MainNetLedger):
|
|||
extended_private_key_prefix = unhexlify('04358394')
|
||||
|
||||
|
||||
class RegTestLedger(MainNetLedger):
|
||||
class RegTestLedger(Ledger):
|
||||
network_name = 'regtest'
|
||||
headers_class = UnvalidatedHeaders
|
||||
pubkey_address_prefix = bytes((111,))
|
||||
|
|
|
@ -1,43 +1,112 @@
|
|||
import os
|
||||
import json
|
||||
import typing
|
||||
import logging
|
||||
import asyncio
|
||||
from binascii import unhexlify
|
||||
from typing import Optional, List
|
||||
from decimal import Decimal
|
||||
|
||||
from lbry.wallet.client.basemanager import BaseWalletManager
|
||||
from lbry.wallet.client.wallet import ENCRYPT_ON_DISK
|
||||
from lbry.wallet.rpc.jsonrpc import CodeMessageError
|
||||
from typing import List, Type, MutableSequence, MutableMapping, Optional
|
||||
|
||||
from lbry.error import KeyFeeAboveMaxAllowedError
|
||||
from lbry.wallet.dewies import dewies_to_lbc
|
||||
from lbry.wallet.account import Account
|
||||
from lbry.wallet.ledger import MainNetLedger
|
||||
from lbry.wallet.transaction import Transaction, Output
|
||||
from lbry.wallet.database import WalletDatabase
|
||||
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
|
||||
from lbry.conf import Config
|
||||
|
||||
from .dewies import dewies_to_lbc
|
||||
from .account import Account
|
||||
from .ledger import Ledger, LedgerRegistry
|
||||
from .transaction import Transaction, Output
|
||||
from .database import Database
|
||||
from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
|
||||
from .rpc.jsonrpc import CodeMessageError
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
|
||||
class WalletManager:
|
||||
|
||||
|
||||
class LbryWalletManager(BaseWalletManager):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, wallets: MutableSequence[Wallet] = None,
|
||||
ledgers: MutableMapping[Type[Ledger], Ledger] = None) -> None:
|
||||
self.wallets = wallets or []
|
||||
self.ledgers = ledgers or {}
|
||||
self.running = False
|
||||
self.config: Optional[Config] = None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> 'WalletManager':
|
||||
manager = cls()
|
||||
for ledger_id, ledger_config in config.get('ledgers', {}).items():
|
||||
manager.get_or_create_ledger(ledger_id, ledger_config)
|
||||
for wallet_path in config.get('wallets', []):
|
||||
wallet_storage = WalletStorage(wallet_path)
|
||||
wallet = Wallet.from_storage(wallet_storage, manager)
|
||||
manager.wallets.append(wallet)
|
||||
return manager
|
||||
|
||||
def get_or_create_ledger(self, ledger_id, ledger_config=None):
|
||||
ledger_class = LedgerRegistry.get_ledger_class(ledger_id)
|
||||
ledger = self.ledgers.get(ledger_class)
|
||||
if ledger is None:
|
||||
ledger = ledger_class(ledger_config or {})
|
||||
self.ledgers[ledger_class] = ledger
|
||||
return ledger
|
||||
|
||||
def import_wallet(self, path):
|
||||
storage = WalletStorage(path)
|
||||
wallet = Wallet.from_storage(storage, self)
|
||||
self.wallets.append(wallet)
|
||||
return wallet
|
||||
|
||||
@property
|
||||
def ledger(self) -> MainNetLedger:
|
||||
def default_wallet(self):
|
||||
for wallet in self.wallets:
|
||||
return wallet
|
||||
|
||||
@property
|
||||
def default_account(self):
|
||||
for wallet in self.wallets:
|
||||
return wallet.default_account
|
||||
|
||||
@property
|
||||
def accounts(self):
|
||||
for wallet in self.wallets:
|
||||
yield from wallet.accounts
|
||||
|
||||
async def start(self):
|
||||
self.running = True
|
||||
await asyncio.gather(*(
|
||||
l.start() for l in self.ledgers.values()
|
||||
))
|
||||
|
||||
async def stop(self):
|
||||
await asyncio.gather(*(
|
||||
l.stop() for l in self.ledgers.values()
|
||||
))
|
||||
self.running = False
|
||||
|
||||
def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet:
|
||||
if wallet_id is None:
|
||||
return self.default_wallet
|
||||
return self.get_wallet_or_error(wallet_id)
|
||||
|
||||
def get_wallet_or_error(self, wallet_id: str) -> Wallet:
|
||||
for wallet in self.wallets:
|
||||
if wallet.id == wallet_id:
|
||||
return wallet
|
||||
raise ValueError(f"Couldn't find wallet: {wallet_id}.")
|
||||
|
||||
@staticmethod
|
||||
def get_balance(wallet):
|
||||
accounts = wallet.accounts
|
||||
if not accounts:
|
||||
return 0
|
||||
return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts)
|
||||
|
||||
@property
|
||||
def ledger(self) -> Ledger:
|
||||
return self.default_account.ledger
|
||||
|
||||
@property
|
||||
def db(self) -> WalletDatabase:
|
||||
def db(self) -> Database:
|
||||
return self.ledger.db
|
||||
|
||||
def check_locked(self):
|
||||
|
@ -194,7 +263,7 @@ class LbryWalletManager(BaseWalletManager):
|
|||
if 'No such mempool or blockchain transaction.' in e.message:
|
||||
return {'success': False, 'code': 404, 'message': 'transaction not found'}
|
||||
return {'success': False, 'code': e.code, 'message': e.message}
|
||||
tx = self.ledger.transaction_class(unhexlify(raw))
|
||||
tx = Transaction(unhexlify(raw))
|
||||
await self.ledger.maybe_verify_transaction(tx, height)
|
||||
return tx
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from secrets import randbelow
|
|||
import pbkdf2
|
||||
|
||||
from lbry.crypto.hash import hmac_sha512
|
||||
from lbry.wallet.client.words import english
|
||||
from .words import english
|
||||
|
||||
# The hash of the mnemonic seed must begin with this
|
||||
SEED_PREFIX = b'01' # Standard wallet
|
||||
|
|
|
@ -1,9 +1,276 @@
|
|||
import lbry
|
||||
from lbry.wallet.client.basenetwork import BaseNetwork
|
||||
import logging
|
||||
import asyncio
|
||||
from time import perf_counter
|
||||
from operator import itemgetter
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from lbry import __version__
|
||||
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
|
||||
from lbry.wallet.stream import StreamController
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Network(BaseNetwork):
|
||||
PROTOCOL_VERSION = lbry.__version__
|
||||
class ClientSession(BaseClientSession):
|
||||
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
|
||||
self.network = network
|
||||
self.server = server
|
||||
super().__init__(*args, **kwargs)
|
||||
self._on_disconnect_controller = StreamController()
|
||||
self.on_disconnected = self._on_disconnect_controller.stream
|
||||
self.framer.max_size = self.max_errors = 1 << 32
|
||||
self.bw_limit = -1
|
||||
self.timeout = timeout
|
||||
self.max_seconds_idle = timeout * 2
|
||||
self.response_time: Optional[float] = None
|
||||
self.connection_latency: Optional[float] = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_connect_cb = on_connect_callback or (lambda: None)
|
||||
self.trigger_urgent_reconnect = asyncio.Event()
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
return not self.is_closing() and self.response_time is not None
|
||||
|
||||
@property
|
||||
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
|
||||
if not self.transport:
|
||||
return None
|
||||
return self.transport.get_extra_info('peername')
|
||||
|
||||
async def send_timed_server_version_request(self, args=(), timeout=None):
|
||||
timeout = timeout or self.timeout
|
||||
log.debug("send version request to %s:%i", *self.server)
|
||||
start = perf_counter()
|
||||
result = await asyncio.wait_for(
|
||||
super().send_request('server.version', args), timeout=timeout
|
||||
)
|
||||
current_response_time = perf_counter() - start
|
||||
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
|
||||
self.response_time = response_sum / (self._response_samples + 1)
|
||||
self._response_samples += 1
|
||||
return result
|
||||
|
||||
async def send_request(self, method, args=()):
|
||||
self.pending_amount += 1
|
||||
log.debug("send %s to %s:%i", method, *self.server)
|
||||
try:
|
||||
if method == 'server.version':
|
||||
return await self.send_timed_server_version_request(args, self.timeout)
|
||||
request = asyncio.ensure_future(super().send_request(method, args))
|
||||
while not request.done():
|
||||
done, pending = await asyncio.wait([request], timeout=self.timeout)
|
||||
if pending:
|
||||
log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
|
||||
if (perf_counter() - self.last_packet_received) < self.timeout:
|
||||
continue
|
||||
log.info("timeout sending %s to %s:%i", method, *self.server)
|
||||
raise asyncio.TimeoutError
|
||||
if done:
|
||||
return request.result()
|
||||
except (RPCError, ProtocolError) as e:
|
||||
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
|
||||
*self.server, *e.args)
|
||||
raise e
|
||||
except ConnectionError:
|
||||
log.warning("connection to %s:%i lost", *self.server)
|
||||
self.synchronous_close()
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
log.info("cancelled sending %s to %s:%i", method, *self.server)
|
||||
self.synchronous_close()
|
||||
raise
|
||||
finally:
|
||||
self.pending_amount -= 1
|
||||
|
||||
async def ensure_session(self):
|
||||
# Handles reconnecting and maintaining a session alive
|
||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||
retry_delay = default_delay = 1.0
|
||||
while True:
|
||||
try:
|
||||
if self.is_closing():
|
||||
await self.create_connection(self.timeout)
|
||||
await self.ensure_server_version()
|
||||
self._on_connect_cb()
|
||||
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
|
||||
await self.ensure_server_version()
|
||||
retry_delay = default_delay
|
||||
except RPCError as e:
|
||||
log.warning("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
|
||||
retry_delay = 60 * 60
|
||||
except (asyncio.TimeoutError, OSError):
|
||||
await self.close()
|
||||
retry_delay = min(60, retry_delay * 2)
|
||||
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
|
||||
try:
|
||||
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
self.trigger_urgent_reconnect.clear()
|
||||
|
||||
async def ensure_server_version(self, required=None, timeout=3):
|
||||
required = required or self.network.PROTOCOL_VERSION
|
||||
return await asyncio.wait_for(
|
||||
self.send_request('server.version', [__version__, required]), timeout=timeout
|
||||
)
|
||||
|
||||
async def create_connection(self, timeout=6):
|
||||
connector = Connector(lambda: self, *self.server)
|
||||
start = perf_counter()
|
||||
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
||||
self.connection_latency = perf_counter() - start
|
||||
|
||||
async def handle_request(self, request):
|
||||
controller = self.network.subscription_controllers[request.method]
|
||||
controller.add(request.args)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
log.debug("Connection lost: %s:%d", *self.server)
|
||||
super().connection_lost(exc)
|
||||
self.response_time = None
|
||||
self.connection_latency = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_disconnect_controller.add(True)
|
||||
|
||||
|
||||
class Network:
|
||||
|
||||
PROTOCOL_VERSION = __version__
|
||||
|
||||
def __init__(self, ledger):
|
||||
self.ledger = ledger
|
||||
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
|
||||
self.client: Optional[ClientSession] = None
|
||||
self._switch_task: Optional[asyncio.Task] = None
|
||||
self.running = False
|
||||
self.remote_height: int = 0
|
||||
self._concurrency = asyncio.Semaphore(16)
|
||||
|
||||
self._on_connected_controller = StreamController()
|
||||
self.on_connected = self._on_connected_controller.stream
|
||||
|
||||
self._on_header_controller = StreamController(merge_repeated_events=True)
|
||||
self.on_header = self._on_header_controller.stream
|
||||
|
||||
self._on_status_controller = StreamController(merge_repeated_events=True)
|
||||
self.on_status = self._on_status_controller.stream
|
||||
|
||||
self.subscription_controllers = {
|
||||
'blockchain.headers.subscribe': self._on_header_controller,
|
||||
'blockchain.address.subscribe': self._on_status_controller,
|
||||
}
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.ledger.config
|
||||
|
||||
async def switch_forever(self):
|
||||
while self.running:
|
||||
if self.is_connected:
|
||||
await self.client.on_disconnected.first
|
||||
self.client = None
|
||||
continue
|
||||
self.client = await self.session_pool.wait_for_fastest_session()
|
||||
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
|
||||
try:
|
||||
self._update_remote_height((await self.subscribe_headers(),))
|
||||
self._on_connected_controller.add(True)
|
||||
log.info("Subscribed to headers: %s:%d", *self.client.server)
|
||||
except (asyncio.TimeoutError, ConnectionError):
|
||||
log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
|
||||
self.client.synchronous_close()
|
||||
self.client = None
|
||||
|
||||
async def start(self):
|
||||
self.running = True
|
||||
self._switch_task = asyncio.ensure_future(self.switch_forever())
|
||||
# this may become unnecessary when there are no more bugs found,
|
||||
# but for now it helps understanding log reports
|
||||
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
|
||||
self.session_pool.start(self.config['default_servers'])
|
||||
self.on_header.listen(self._update_remote_height)
|
||||
|
||||
async def stop(self):
|
||||
if self.running:
|
||||
self.running = False
|
||||
self._switch_task.cancel()
|
||||
self.session_pool.stop()
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
return self.client and not self.client.is_closing()
|
||||
|
||||
def rpc(self, list_or_method, args, restricted=True):
|
||||
session = self.client if restricted else self.session_pool.fastest_session
|
||||
if session and not session.is_closing():
|
||||
return session.send_request(list_or_method, args)
|
||||
else:
|
||||
self.session_pool.trigger_nodelay_connect()
|
||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||
|
||||
async def retriable_call(self, function, *args, **kwargs):
|
||||
async with self._concurrency:
|
||||
while self.running:
|
||||
if not self.is_connected:
|
||||
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
|
||||
await self.on_connected.first
|
||||
await self.session_pool.wait_for_fastest_session()
|
||||
try:
|
||||
return await function(*args, **kwargs)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Wallet server call timed out, retrying.")
|
||||
except ConnectionError:
|
||||
pass
|
||||
raise asyncio.CancelledError() # if we got here, we are shutting down
|
||||
|
||||
def _update_remote_height(self, header_args):
|
||||
self.remote_height = header_args[0]["height"]
|
||||
|
||||
def get_transaction(self, tx_hash, known_height=None):
|
||||
# use any server if its old, otherwise restrict to who gave us the history
|
||||
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
|
||||
return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
|
||||
|
||||
def get_transaction_height(self, tx_hash, known_height=None):
|
||||
restricted = not known_height or 0 > known_height > self.remote_height - 10
|
||||
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
|
||||
|
||||
def get_merkle(self, tx_hash, height):
|
||||
restricted = 0 > height > self.remote_height - 10
|
||||
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
|
||||
|
||||
def get_headers(self, height, count=10000, b64=False):
|
||||
restricted = height >= self.remote_height - 100
|
||||
return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
|
||||
|
||||
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
|
||||
def get_history(self, address):
|
||||
return self.rpc('blockchain.address.get_history', [address], True)
|
||||
|
||||
def broadcast(self, raw_transaction):
|
||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
|
||||
|
||||
def subscribe_headers(self):
|
||||
return self.rpc('blockchain.headers.subscribe', [True], True)
|
||||
|
||||
async def subscribe_address(self, address):
|
||||
try:
|
||||
return await self.rpc('blockchain.address.subscribe', [address], True)
|
||||
except asyncio.TimeoutError:
|
||||
# abort and cancel, we can't lose a subscription, it will happen again on reconnect
|
||||
if self.client:
|
||||
self.client.abort()
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
def unsubscribe_address(self, address):
|
||||
return self.rpc('blockchain.address.unsubscribe', [address], True)
|
||||
|
||||
def get_server_features(self):
|
||||
return self.rpc('server.features', (), restricted=True)
|
||||
|
||||
def get_claims_by_ids(self, claim_ids):
|
||||
return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
|
||||
|
@ -13,3 +280,95 @@ class Network(BaseNetwork):
|
|||
|
||||
def claim_search(self, **kwargs):
|
||||
return self.rpc('blockchain.claimtrie.search', kwargs)
|
||||
|
||||
|
||||
class SessionPool:
|
||||
|
||||
def __init__(self, network: Network, timeout: float):
|
||||
self.network = network
|
||||
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
|
||||
self.timeout = timeout
|
||||
self.new_connection_event = asyncio.Event()
|
||||
|
||||
@property
|
||||
def online(self):
|
||||
return any(not session.is_closing() for session in self.sessions)
|
||||
|
||||
@property
|
||||
def available_sessions(self):
|
||||
return (session for session in self.sessions if session.available)
|
||||
|
||||
@property
|
||||
def fastest_session(self):
|
||||
if not self.online:
|
||||
return None
|
||||
return min(
|
||||
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
|
||||
for session in self.available_sessions] or [(0, None)],
|
||||
key=itemgetter(0)
|
||||
)[1]
|
||||
|
||||
def _get_session_connect_callback(self, session: ClientSession):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def callback():
|
||||
duplicate_connections = [
|
||||
s for s in self.sessions
|
||||
if s is not session and s.server_address_and_port == session.server_address_and_port
|
||||
]
|
||||
already_connected = None if not duplicate_connections else duplicate_connections[0]
|
||||
if already_connected:
|
||||
self.sessions.pop(session).cancel()
|
||||
session.synchronous_close()
|
||||
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
|
||||
session.server[0], already_connected.server[0])
|
||||
loop.call_later(3600, self._connect_session, session.server)
|
||||
return
|
||||
self.new_connection_event.set()
|
||||
log.info("connected to %s:%i", *session.server)
|
||||
|
||||
return callback
|
||||
|
||||
def _connect_session(self, server: Tuple[str, int]):
|
||||
session = None
|
||||
for s in self.sessions:
|
||||
if s.server == server:
|
||||
session = s
|
||||
break
|
||||
if not session:
|
||||
session = ClientSession(
|
||||
network=self.network, server=server
|
||||
)
|
||||
session._on_connect_cb = self._get_session_connect_callback(session)
|
||||
task = self.sessions.get(session, None)
|
||||
if not task or task.done():
|
||||
task = asyncio.create_task(session.ensure_session())
|
||||
task.add_done_callback(lambda _: self.ensure_connections())
|
||||
self.sessions[session] = task
|
||||
|
||||
def start(self, default_servers):
|
||||
for server in default_servers:
|
||||
self._connect_session(server)
|
||||
|
||||
def stop(self):
|
||||
for session, task in self.sessions.items():
|
||||
task.cancel()
|
||||
session.synchronous_close()
|
||||
self.sessions.clear()
|
||||
|
||||
def ensure_connections(self):
|
||||
for session in self.sessions:
|
||||
self._connect_session(session.server)
|
||||
|
||||
def trigger_nodelay_connect(self):
|
||||
# used when other parts of the system sees we might have internet back
|
||||
# bypasses the retry interval
|
||||
for session in self.sessions:
|
||||
session.trigger_urgent_reconnect.set()
|
||||
|
||||
async def wait_for_fastest_session(self):
|
||||
while not self.fastest_session:
|
||||
self.trigger_nodelay_connect()
|
||||
self.new_connection_event.clear()
|
||||
await self.new_connection_event.wait()
|
||||
return self.fastest_session
|
||||
|
|
|
@ -12,28 +12,15 @@ from binascii import hexlify
|
|||
from typing import Type, Optional
|
||||
import urllib.request
|
||||
|
||||
import lbry
|
||||
from lbry.wallet.server.server import Server
|
||||
from lbry.wallet.server.env import Env
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
from lbry.wallet.client.baseledger import BaseLedger, BlockHeightEvent
|
||||
from lbry.wallet.client.basemanager import BaseWalletManager
|
||||
from lbry.wallet.client.baseaccount import BaseAccount
|
||||
from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_manager_from_environment(default_manager=BaseWalletManager):
|
||||
if 'TORBA_MANAGER' not in os.environ:
|
||||
return default_manager
|
||||
module_name = os.environ['TORBA_MANAGER'].split('-')[-1] # tox support
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
def get_ledger_from_environment():
|
||||
return importlib.import_module('lbry.wallet')
|
||||
|
||||
|
||||
def get_spvserver_from_ledger(ledger_module):
|
||||
spvserver_path, regtest_class_name = ledger_module.__spvserver__.rsplit('.', 1)
|
||||
spvserver_module = importlib.import_module(spvserver_path)
|
||||
|
@ -50,16 +37,14 @@ def get_blockchain_node_from_ledger(ledger_module):
|
|||
|
||||
class Conductor:
|
||||
|
||||
def __init__(self, ledger_module=None, manager_module=None, enable_segwit=False, seed=None):
|
||||
self.ledger_module = ledger_module or get_ledger_from_environment()
|
||||
self.manager_module = manager_module or get_manager_from_environment()
|
||||
self.spv_module = get_spvserver_from_ledger(self.ledger_module)
|
||||
def __init__(self, seed=None):
|
||||
self.manager_module = WalletManager
|
||||
self.spv_module = get_spvserver_from_ledger(lbry.wallet)
|
||||
|
||||
self.blockchain_node = get_blockchain_node_from_ledger(self.ledger_module)
|
||||
self.blockchain_node.segwit_enabled = enable_segwit
|
||||
self.blockchain_node = get_blockchain_node_from_ledger(lbry.wallet)
|
||||
self.spv_node = SPVNode(self.spv_module)
|
||||
self.wallet_node = WalletNode(
|
||||
self.manager_module, self.ledger_module.RegTestLedger, default_seed=seed
|
||||
self.manager_module, RegTestLedger, default_seed=seed
|
||||
)
|
||||
|
||||
self.blockchain_started = False
|
||||
|
@ -119,15 +104,15 @@ class Conductor:
|
|||
|
||||
class WalletNode:
|
||||
|
||||
def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger],
|
||||
def __init__(self, manager_class: Type[WalletManager], ledger_class: Type[Ledger],
|
||||
verbose: bool = False, port: int = 5280, default_seed: str = None) -> None:
|
||||
self.manager_class = manager_class
|
||||
self.ledger_class = ledger_class
|
||||
self.verbose = verbose
|
||||
self.manager: Optional[BaseWalletManager] = None
|
||||
self.ledger: Optional[BaseLedger] = None
|
||||
self.manager: Optional[WalletManager] = None
|
||||
self.ledger: Optional[Ledger] = None
|
||||
self.wallet: Optional[Wallet] = None
|
||||
self.account: Optional[BaseAccount] = None
|
||||
self.account: Optional[Account] = None
|
||||
self.data_path: Optional[str] = None
|
||||
self.port = port
|
||||
self.default_seed = default_seed
|
||||
|
@ -154,7 +139,7 @@ class WalletNode:
|
|||
if not self.wallet:
|
||||
raise ValueError('Wallet is required.')
|
||||
if seed or self.default_seed:
|
||||
self.ledger.account_class.from_dict(
|
||||
Account.from_dict(
|
||||
self.ledger, self.wallet, {'seed': seed or self.default_seed}
|
||||
)
|
||||
else:
|
||||
|
@ -250,7 +235,7 @@ class BlockchainNode:
|
|||
P2SH_SEGWIT_ADDRESS = "p2sh-segwit"
|
||||
BECH32_ADDRESS = "bech32"
|
||||
|
||||
def __init__(self, url, daemon, cli, segwit_enabled=False):
|
||||
def __init__(self, url, daemon, cli):
|
||||
self.latest_release_url = url
|
||||
self.project_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
self.bin_dir = os.path.join(self.project_dir, 'bin')
|
||||
|
@ -266,7 +251,6 @@ class BlockchainNode:
|
|||
self.rpcport = 9245 + 2 # avoid conflict with default rpc port
|
||||
self.rpcuser = 'rpcuser'
|
||||
self.rpcpassword = 'rpcpassword'
|
||||
self.segwit_enabled = segwit_enabled
|
||||
|
||||
@property
|
||||
def rpc_url(self):
|
||||
|
@ -326,8 +310,6 @@ class BlockchainNode:
|
|||
f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}',
|
||||
f'-port={self.peerport}'
|
||||
]
|
||||
if not self.segwit_enabled:
|
||||
command.extend(['-addresstype=legacy', '-vbparams=segwit:0:999999999999'])
|
||||
self.log.info(' '.join(command))
|
||||
self.transport, self.protocol = await loop.subprocess_exec(
|
||||
BlockchainProcess, *command
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from aiohttp.web import Application, WebSocketResponse, json_response
|
||||
from aiohttp.http_websocket import WSMsgType, WSCloseCode
|
||||
|
||||
from lbry.wallet.client.util import satoshis_to_coins
|
||||
from lbry.wallet.util import satoshis_to_coins
|
||||
from .node import Conductor
|
||||
|
||||
|
||||
|
|
|
@ -1,34 +1,430 @@
|
|||
from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript, Template
|
||||
from lbry.wallet.client.basescript import PUSH_SINGLE, PUSH_INTEGER, OP_DROP, OP_2DROP, PUSH_SUBSCRIPT, OP_VERIFY
|
||||
from typing import List
|
||||
from itertools import chain
|
||||
from binascii import hexlify
|
||||
from collections import namedtuple
|
||||
|
||||
from .bcd_data_stream import BCDataStream
|
||||
from .util import subclass_tuple
|
||||
|
||||
|
||||
class InputScript(BaseInputScript):
|
||||
# bitcoin opcodes
|
||||
OP_0 = 0x00
|
||||
OP_1 = 0x51
|
||||
OP_16 = 0x60
|
||||
OP_VERIFY = 0x69
|
||||
OP_DUP = 0x76
|
||||
OP_HASH160 = 0xa9
|
||||
OP_EQUALVERIFY = 0x88
|
||||
OP_CHECKSIG = 0xac
|
||||
OP_CHECKMULTISIG = 0xae
|
||||
OP_EQUAL = 0x87
|
||||
OP_PUSHDATA1 = 0x4c
|
||||
OP_PUSHDATA2 = 0x4d
|
||||
OP_PUSHDATA4 = 0x4e
|
||||
OP_RETURN = 0x6a
|
||||
OP_2DROP = 0x6d
|
||||
OP_DROP = 0x75
|
||||
|
||||
# lbry custom opcodes
|
||||
# checks
|
||||
OP_PRICECHECK = 0xb0 # checks that the BUY output is >= SELL price
|
||||
# tx types
|
||||
OP_CLAIM_NAME = 0xb5
|
||||
OP_SUPPORT_CLAIM = 0xb6
|
||||
OP_UPDATE_CLAIM = 0xb7
|
||||
OP_SELL_CLAIM = 0xb8
|
||||
OP_BUY_CLAIM = 0xb9
|
||||
|
||||
# template matching opcodes (not real opcodes)
|
||||
# base class for PUSH_DATA related opcodes
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name')
|
||||
# opcode for variable length strings
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP)
|
||||
# opcode for variable size integers
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_INTEGER = subclass_tuple('PUSH_INTEGER', PUSH_DATA_OP)
|
||||
# opcode for variable number of variable length strings
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP)
|
||||
# opcode with embedded subscript parsing
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template')
|
||||
|
||||
|
||||
def is_push_data_opcode(opcode):
|
||||
return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT))
|
||||
|
||||
|
||||
def is_push_data_token(token):
|
||||
return 1 <= token <= OP_PUSHDATA4
|
||||
|
||||
|
||||
def push_data(data):
|
||||
size = len(data)
|
||||
if size < OP_PUSHDATA1:
|
||||
yield BCDataStream.uint8.pack(size)
|
||||
elif size <= 0xFF:
|
||||
yield BCDataStream.uint8.pack(OP_PUSHDATA1)
|
||||
yield BCDataStream.uint8.pack(size)
|
||||
elif size <= 0xFFFF:
|
||||
yield BCDataStream.uint8.pack(OP_PUSHDATA2)
|
||||
yield BCDataStream.uint16.pack(size)
|
||||
else:
|
||||
yield BCDataStream.uint8.pack(OP_PUSHDATA4)
|
||||
yield BCDataStream.uint32.pack(size)
|
||||
yield bytes(data)
|
||||
|
||||
|
||||
def read_data(token, stream):
|
||||
if token < OP_PUSHDATA1:
|
||||
return stream.read(token)
|
||||
if token == OP_PUSHDATA1:
|
||||
return stream.read(stream.read_uint8())
|
||||
if token == OP_PUSHDATA2:
|
||||
return stream.read(stream.read_uint16())
|
||||
return stream.read(stream.read_uint32())
|
||||
|
||||
|
||||
# opcode for OP_1 - OP_16
|
||||
# pylint: disable=invalid-name
|
||||
SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name')
|
||||
|
||||
|
||||
def is_small_integer(token):
|
||||
return OP_1 <= token <= OP_16
|
||||
|
||||
|
||||
def push_small_integer(num):
|
||||
assert 1 <= num <= 16
|
||||
yield BCDataStream.uint8.pack(OP_1 + (num - 1))
|
||||
|
||||
|
||||
def read_small_integer(token):
|
||||
return (token - OP_1) + 1
|
||||
|
||||
|
||||
class Token(namedtuple('Token', 'value')):
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
name = None
|
||||
for var_name, var_value in globals().items():
|
||||
if var_name.startswith('OP_') and var_value == self.value:
|
||||
name = var_name
|
||||
break
|
||||
return name or self.value
|
||||
|
||||
|
||||
class DataToken(Token):
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
return f'"{hexlify(self.value)}"'
|
||||
|
||||
|
||||
class SmallIntegerToken(Token):
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
return f'SmallIntegerToken({self.value})'
|
||||
|
||||
|
||||
def token_producer(source):
|
||||
token = source.read_uint8()
|
||||
while token is not None:
|
||||
if is_push_data_token(token):
|
||||
yield DataToken(read_data(token, source))
|
||||
elif is_small_integer(token):
|
||||
yield SmallIntegerToken(read_small_integer(token))
|
||||
else:
|
||||
yield Token(token)
|
||||
token = source.read_uint8()
|
||||
|
||||
|
||||
def tokenize(source):
|
||||
return list(token_producer(source))
|
||||
|
||||
|
||||
class ScriptError(Exception):
|
||||
""" General script handling error. """
|
||||
|
||||
|
||||
class ParseError(ScriptError):
|
||||
""" Script parsing error. """
|
||||
|
||||
|
||||
class Parser:
|
||||
|
||||
def __init__(self, opcodes, tokens):
|
||||
self.opcodes = opcodes
|
||||
self.tokens = tokens
|
||||
self.values = {}
|
||||
self.token_index = 0
|
||||
self.opcode_index = 0
|
||||
|
||||
def parse(self):
|
||||
while self.token_index < len(self.tokens) and self.opcode_index < len(self.opcodes):
|
||||
token = self.tokens[self.token_index]
|
||||
opcode = self.opcodes[self.opcode_index]
|
||||
if token.value == 0 and isinstance(opcode, PUSH_SINGLE):
|
||||
token = DataToken(b'')
|
||||
if isinstance(token, DataToken):
|
||||
if isinstance(opcode, (PUSH_SINGLE, PUSH_INTEGER, PUSH_SUBSCRIPT)):
|
||||
self.push_single(opcode, token.value)
|
||||
elif isinstance(opcode, PUSH_MANY):
|
||||
self.consume_many_non_greedy()
|
||||
else:
|
||||
raise ParseError(f"DataToken found but opcode was '{opcode}'.")
|
||||
elif isinstance(token, SmallIntegerToken):
|
||||
if isinstance(opcode, SMALL_INTEGER):
|
||||
self.values[opcode.name] = token.value
|
||||
else:
|
||||
raise ParseError(f"SmallIntegerToken found but opcode was '{opcode}'.")
|
||||
elif token.value == opcode:
|
||||
pass
|
||||
else:
|
||||
raise ParseError(f"Token is '{token.value}' and opcode is '{opcode}'.")
|
||||
self.token_index += 1
|
||||
self.opcode_index += 1
|
||||
|
||||
if self.token_index < len(self.tokens):
|
||||
raise ParseError("Parse completed without all tokens being consumed.")
|
||||
|
||||
if self.opcode_index < len(self.opcodes):
|
||||
raise ParseError("Parse completed without all opcodes being consumed.")
|
||||
|
||||
return self
|
||||
|
||||
def consume_many_non_greedy(self):
|
||||
""" Allows PUSH_MANY to consume data without being greedy
|
||||
in cases when one or more PUSH_SINGLEs follow a PUSH_MANY. This will
|
||||
prioritize giving all PUSH_SINGLEs some data and only after that
|
||||
subsume the rest into PUSH_MANY.
|
||||
"""
|
||||
|
||||
token_values = []
|
||||
while self.token_index < len(self.tokens):
|
||||
token = self.tokens[self.token_index]
|
||||
if not isinstance(token, DataToken):
|
||||
self.token_index -= 1
|
||||
break
|
||||
token_values.append(token.value)
|
||||
self.token_index += 1
|
||||
|
||||
push_opcodes = []
|
||||
push_many_count = 0
|
||||
while self.opcode_index < len(self.opcodes):
|
||||
opcode = self.opcodes[self.opcode_index]
|
||||
if not is_push_data_opcode(opcode):
|
||||
self.opcode_index -= 1
|
||||
break
|
||||
if isinstance(opcode, PUSH_MANY):
|
||||
push_many_count += 1
|
||||
push_opcodes.append(opcode)
|
||||
self.opcode_index += 1
|
||||
|
||||
if push_many_count > 1:
|
||||
raise ParseError(
|
||||
"Cannot have more than one consecutive PUSH_MANY, as there is no way to tell which"
|
||||
" token value should go into which PUSH_MANY."
|
||||
)
|
||||
|
||||
if len(push_opcodes) > len(token_values):
|
||||
raise ParseError(
|
||||
"Not enough token values to match all of the PUSH_MANY and PUSH_SINGLE opcodes."
|
||||
)
|
||||
|
||||
many_opcode = push_opcodes.pop(0)
|
||||
|
||||
# consume data into PUSH_SINGLE opcodes, working backwards
|
||||
for opcode in reversed(push_opcodes):
|
||||
self.push_single(opcode, token_values.pop())
|
||||
|
||||
# finally PUSH_MANY gets everything that's left
|
||||
self.values[many_opcode.name] = token_values
|
||||
|
||||
def push_single(self, opcode, value):
|
||||
if isinstance(opcode, PUSH_SINGLE):
|
||||
self.values[opcode.name] = value
|
||||
elif isinstance(opcode, PUSH_INTEGER):
|
||||
self.values[opcode.name] = int.from_bytes(value, 'little')
|
||||
elif isinstance(opcode, PUSH_SUBSCRIPT):
|
||||
self.values[opcode.name] = Script.from_source_with_template(value, opcode.template)
|
||||
else:
|
||||
raise ParseError(f"Not a push single or subscript: {opcode}")
|
||||
|
||||
|
||||
class OutputScript(BaseOutputScript):
|
||||
class Template:
|
||||
|
||||
# lbry custom opcodes
|
||||
__slots__ = 'name', 'opcodes'
|
||||
|
||||
# checks
|
||||
OP_PRICECHECK = 0xb0 # checks that the BUY output is >= SELL price
|
||||
def __init__(self, name, opcodes):
|
||||
self.name = name
|
||||
self.opcodes = opcodes
|
||||
|
||||
# tx types
|
||||
OP_CLAIM_NAME = 0xb5
|
||||
OP_SUPPORT_CLAIM = 0xb6
|
||||
OP_UPDATE_CLAIM = 0xb7
|
||||
OP_SELL_CLAIM = 0xb8
|
||||
OP_BUY_CLAIM = 0xb9
|
||||
def parse(self, tokens):
|
||||
return Parser(self.opcodes, tokens).parse().values if self.opcodes else {}
|
||||
|
||||
def generate(self, values):
|
||||
source = BCDataStream()
|
||||
for opcode in self.opcodes:
|
||||
if isinstance(opcode, PUSH_SINGLE):
|
||||
data = values[opcode.name]
|
||||
source.write_many(push_data(data))
|
||||
elif isinstance(opcode, PUSH_INTEGER):
|
||||
data = values[opcode.name]
|
||||
source.write_many(push_data(
|
||||
data.to_bytes((data.bit_length() + 7) // 8, byteorder='little')
|
||||
))
|
||||
elif isinstance(opcode, PUSH_SUBSCRIPT):
|
||||
data = values[opcode.name]
|
||||
source.write_many(push_data(data.source))
|
||||
elif isinstance(opcode, PUSH_MANY):
|
||||
for data in values[opcode.name]:
|
||||
source.write_many(push_data(data))
|
||||
elif isinstance(opcode, SMALL_INTEGER):
|
||||
data = values[opcode.name]
|
||||
source.write_many(push_small_integer(data))
|
||||
else:
|
||||
source.write_uint8(opcode)
|
||||
return source.get_bytes()
|
||||
|
||||
|
||||
class Script:
|
||||
|
||||
__slots__ = 'source', '_template', '_values', '_template_hint'
|
||||
|
||||
templates: List[Template] = []
|
||||
|
||||
NO_SCRIPT = Template('no_script', None) # special case
|
||||
|
||||
def __init__(self, source=None, template=None, values=None, template_hint=None):
|
||||
self.source = source
|
||||
self._template = template
|
||||
self._values = values
|
||||
self._template_hint = template_hint
|
||||
if source is None and template and values:
|
||||
self.generate()
|
||||
|
||||
@property
|
||||
def template(self):
|
||||
if self._template is None:
|
||||
self.parse(self._template_hint)
|
||||
return self._template
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
if self._values is None:
|
||||
self.parse(self._template_hint)
|
||||
return self._values
|
||||
|
||||
@property
|
||||
def tokens(self):
|
||||
return tokenize(BCDataStream(self.source))
|
||||
|
||||
@classmethod
|
||||
def from_source_with_template(cls, source, template):
|
||||
return cls(source, template_hint=template)
|
||||
|
||||
def parse(self, template_hint=None):
|
||||
tokens = self.tokens
|
||||
if not tokens and not template_hint:
|
||||
template_hint = self.NO_SCRIPT
|
||||
for template in chain((template_hint,), self.templates):
|
||||
if not template:
|
||||
continue
|
||||
try:
|
||||
self._values = template.parse(tokens)
|
||||
self._template = template
|
||||
return
|
||||
except ParseError:
|
||||
continue
|
||||
raise ValueError(f'No matching templates for source: {hexlify(self.source)}')
|
||||
|
||||
def generate(self):
|
||||
self.source = self.template.generate(self._values)
|
||||
|
||||
|
||||
class InputScript(Script):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
REDEEM_PUBKEY = Template('pubkey', (
|
||||
PUSH_SINGLE('signature'),
|
||||
))
|
||||
REDEEM_PUBKEY_HASH = Template('pubkey_hash', (
|
||||
PUSH_SINGLE('signature'), PUSH_SINGLE('pubkey')
|
||||
))
|
||||
REDEEM_SCRIPT = Template('script', (
|
||||
SMALL_INTEGER('signatures_count'), PUSH_MANY('pubkeys'), SMALL_INTEGER('pubkeys_count'),
|
||||
OP_CHECKMULTISIG
|
||||
))
|
||||
REDEEM_SCRIPT_HASH = Template('script_hash', (
|
||||
OP_0, PUSH_MANY('signatures'), PUSH_SUBSCRIPT('script', REDEEM_SCRIPT)
|
||||
))
|
||||
|
||||
templates = [
|
||||
REDEEM_PUBKEY,
|
||||
REDEEM_PUBKEY_HASH,
|
||||
REDEEM_SCRIPT_HASH,
|
||||
REDEEM_SCRIPT
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def redeem_pubkey_hash(cls, signature, pubkey):
|
||||
return cls(template=cls.REDEEM_PUBKEY_HASH, values={
|
||||
'signature': signature,
|
||||
'pubkey': pubkey
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def redeem_script_hash(cls, signatures, pubkeys):
|
||||
return cls(template=cls.REDEEM_SCRIPT_HASH, values={
|
||||
'signatures': signatures,
|
||||
'script': cls.redeem_script(signatures, pubkeys)
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def redeem_script(cls, signatures, pubkeys):
|
||||
return cls(template=cls.REDEEM_SCRIPT, values={
|
||||
'signatures_count': len(signatures),
|
||||
'pubkeys': pubkeys,
|
||||
'pubkeys_count': len(pubkeys)
|
||||
})
|
||||
|
||||
|
||||
class OutputScript(Script):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
# output / payment script templates (aka scriptPubKey)
|
||||
PAY_PUBKEY_FULL = Template('pay_pubkey_full', (
|
||||
PUSH_SINGLE('pubkey'), OP_CHECKSIG
|
||||
))
|
||||
PAY_PUBKEY_HASH = Template('pay_pubkey_hash', (
|
||||
OP_DUP, OP_HASH160, PUSH_SINGLE('pubkey_hash'), OP_EQUALVERIFY, OP_CHECKSIG
|
||||
))
|
||||
PAY_SCRIPT_HASH = Template('pay_script_hash', (
|
||||
OP_HASH160, PUSH_SINGLE('script_hash'), OP_EQUAL
|
||||
))
|
||||
PAY_SEGWIT = Template('pay_script_hash+segwit', (
|
||||
OP_0, PUSH_SINGLE('script_hash')
|
||||
))
|
||||
RETURN_DATA = Template('return_data', (
|
||||
OP_RETURN, PUSH_SINGLE('data')
|
||||
))
|
||||
|
||||
CLAIM_NAME_OPCODES = (
|
||||
OP_CLAIM_NAME, PUSH_SINGLE('claim_name'), PUSH_SINGLE('claim'),
|
||||
OP_2DROP, OP_DROP
|
||||
)
|
||||
CLAIM_NAME_PUBKEY = Template('claim_name+pay_pubkey_hash', (
|
||||
CLAIM_NAME_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes
|
||||
CLAIM_NAME_OPCODES + PAY_PUBKEY_HASH.opcodes
|
||||
))
|
||||
CLAIM_NAME_SCRIPT = Template('claim_name+pay_script_hash', (
|
||||
CLAIM_NAME_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes
|
||||
CLAIM_NAME_OPCODES + PAY_SCRIPT_HASH.opcodes
|
||||
))
|
||||
|
||||
SUPPORT_CLAIM_OPCODES = (
|
||||
|
@ -36,10 +432,10 @@ class OutputScript(BaseOutputScript):
|
|||
OP_2DROP, OP_DROP
|
||||
)
|
||||
SUPPORT_CLAIM_PUBKEY = Template('support_claim+pay_pubkey_hash', (
|
||||
SUPPORT_CLAIM_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes
|
||||
SUPPORT_CLAIM_OPCODES + PAY_PUBKEY_HASH.opcodes
|
||||
))
|
||||
SUPPORT_CLAIM_SCRIPT = Template('support_claim+pay_script_hash', (
|
||||
SUPPORT_CLAIM_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes
|
||||
SUPPORT_CLAIM_OPCODES + PAY_SCRIPT_HASH.opcodes
|
||||
))
|
||||
|
||||
UPDATE_CLAIM_OPCODES = (
|
||||
|
@ -47,10 +443,10 @@ class OutputScript(BaseOutputScript):
|
|||
OP_2DROP, OP_2DROP
|
||||
)
|
||||
UPDATE_CLAIM_PUBKEY = Template('update_claim+pay_pubkey_hash', (
|
||||
UPDATE_CLAIM_OPCODES + BaseOutputScript.PAY_PUBKEY_HASH.opcodes
|
||||
UPDATE_CLAIM_OPCODES + PAY_PUBKEY_HASH.opcodes
|
||||
))
|
||||
UPDATE_CLAIM_SCRIPT = Template('update_claim+pay_script_hash', (
|
||||
UPDATE_CLAIM_OPCODES + BaseOutputScript.PAY_SCRIPT_HASH.opcodes
|
||||
UPDATE_CLAIM_OPCODES + PAY_SCRIPT_HASH.opcodes
|
||||
))
|
||||
|
||||
SELL_SCRIPT = Template('sell_script', (
|
||||
|
@ -58,17 +454,22 @@ class OutputScript(BaseOutputScript):
|
|||
))
|
||||
SELL_CLAIM = Template('sell_claim+pay_script_hash', (
|
||||
OP_SELL_CLAIM, PUSH_SINGLE('claim_id'), PUSH_SUBSCRIPT('sell_script', SELL_SCRIPT),
|
||||
PUSH_SUBSCRIPT('receive_script', BaseInputScript.REDEEM_SCRIPT), OP_2DROP, OP_2DROP
|
||||
) + BaseOutputScript.PAY_SCRIPT_HASH.opcodes)
|
||||
PUSH_SUBSCRIPT('receive_script', InputScript.REDEEM_SCRIPT), OP_2DROP, OP_2DROP
|
||||
) + PAY_SCRIPT_HASH.opcodes)
|
||||
|
||||
BUY_CLAIM = Template('buy_claim+pay_script_hash', (
|
||||
OP_BUY_CLAIM, PUSH_SINGLE('sell_id'),
|
||||
PUSH_SINGLE('claim_id'), PUSH_SINGLE('claim_version'),
|
||||
PUSH_SINGLE('owner_pubkey_hash'), PUSH_SINGLE('negotiation_signature'),
|
||||
OP_2DROP, OP_2DROP, OP_2DROP,
|
||||
) + BaseOutputScript.PAY_SCRIPT_HASH.opcodes)
|
||||
) + PAY_SCRIPT_HASH.opcodes)
|
||||
|
||||
templates = BaseOutputScript.templates + [
|
||||
templates = [
|
||||
PAY_PUBKEY_FULL,
|
||||
PAY_PUBKEY_HASH,
|
||||
PAY_SCRIPT_HASH,
|
||||
PAY_SEGWIT,
|
||||
RETURN_DATA,
|
||||
CLAIM_NAME_PUBKEY,
|
||||
CLAIM_NAME_SCRIPT,
|
||||
SUPPORT_CLAIM_PUBKEY,
|
||||
|
@ -79,6 +480,28 @@ class OutputScript(BaseOutputScript):
|
|||
BUY_CLAIM,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def pay_pubkey_hash(cls, pubkey_hash):
|
||||
return cls(template=cls.PAY_PUBKEY_HASH, values={
|
||||
'pubkey_hash': pubkey_hash
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def pay_script_hash(cls, script_hash):
|
||||
return cls(template=cls.PAY_SCRIPT_HASH, values={
|
||||
'script_hash': script_hash
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def return_data(cls, data):
|
||||
return cls(template=cls.RETURN_DATA, values={
|
||||
'data': data
|
||||
})
|
||||
|
||||
@property
|
||||
def is_pay_pubkey(self):
|
||||
return self.template.name.endswith('pay_pubkey_full')
|
||||
|
||||
@classmethod
|
||||
def pay_claim_name_pubkey_hash(cls, claim_name, claim, pubkey_hash):
|
||||
return cls(template=cls.CLAIM_NAME_PUBKEY, values={
|
||||
|
@ -128,6 +551,18 @@ class OutputScript(BaseOutputScript):
|
|||
'negotiation_signature': negotiation_signature,
|
||||
})
|
||||
|
||||
@property
|
||||
def is_pay_pubkey_hash(self):
|
||||
return self.template.name.endswith('pay_pubkey_hash')
|
||||
|
||||
@property
|
||||
def is_pay_script_hash(self):
|
||||
return self.template.name.endswith('pay_script_hash')
|
||||
|
||||
@property
|
||||
def is_return_data(self):
|
||||
return self.template.name.endswith('return_data')
|
||||
|
||||
@property
|
||||
def is_claim_name(self):
|
||||
return self.template.name.startswith('claim_name+')
|
||||
|
|
|
@ -6,7 +6,7 @@ from decimal import Decimal
|
|||
from collections import namedtuple
|
||||
|
||||
import lbry.wallet.server.tx as lib_tx
|
||||
from lbry.wallet.script import OutputScript
|
||||
from lbry.wallet.script import OutputScript, OP_CLAIM_NAME, OP_UPDATE_CLAIM, OP_SUPPORT_CLAIM
|
||||
from lbry.wallet.server.tx import DeserializerSegWit
|
||||
from lbry.wallet.server.util import cachedproperty, subclasses
|
||||
from lbry.wallet.server.hash import Base58, hash160, double_sha256, hash_to_hex_str, HASHX_LEN
|
||||
|
@ -327,9 +327,9 @@ class LBC(Coin):
|
|||
if script and script[0] == OpCodes.OP_RETURN or not script:
|
||||
return None
|
||||
if script[0] in [
|
||||
OutputScript.OP_CLAIM_NAME,
|
||||
OutputScript.OP_UPDATE_CLAIM,
|
||||
OutputScript.OP_SUPPORT_CLAIM,
|
||||
OP_CLAIM_NAME,
|
||||
OP_UPDATE_CLAIM,
|
||||
OP_SUPPORT_CLAIM,
|
||||
]:
|
||||
return cls.address_to_hashX(cls.claim_address_handler(script))
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from lbry.wallet.client.basedatabase import constraints_to_sql
|
||||
from lbry.wallet.database import constraints_to_sql
|
||||
|
||||
CREATE_FULL_TEXT_SEARCH = """
|
||||
create virtual table if not exists search using fts5(
|
||||
|
|
|
@ -10,12 +10,12 @@ from contextvars import ContextVar
|
|||
from functools import wraps
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lbry.wallet.client.basedatabase import query, interpolate
|
||||
from lbry.wallet.database import query, interpolate
|
||||
|
||||
from lbry.schema.url import URL, normalize_name
|
||||
from lbry.schema.tags import clean_tags
|
||||
from lbry.schema.result import Outputs
|
||||
from lbry.wallet.ledger import BaseLedger, MainNetLedger, RegTestLedger
|
||||
from lbry.wallet import Ledger, RegTestLedger
|
||||
|
||||
from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS
|
||||
from .full_text_search import FTS_ORDER_BY
|
||||
|
@ -67,7 +67,7 @@ class ReaderState:
|
|||
stack: List[List]
|
||||
metrics: Dict
|
||||
is_tracking_metrics: bool
|
||||
ledger: Type[BaseLedger]
|
||||
ledger: Type[Ledger]
|
||||
query_timeout: float
|
||||
log: logging.Logger
|
||||
|
||||
|
@ -100,7 +100,7 @@ def initializer(log, _path, _ledger_name, query_timeout, _measure=False):
|
|||
ctx.set(
|
||||
ReaderState(
|
||||
db=db, stack=[], metrics={}, is_tracking_metrics=_measure,
|
||||
ledger=MainNetLedger if _ledger_name == 'mainnet' else RegTestLedger,
|
||||
ledger=Ledger if _ledger_name == 'mainnet' else RegTestLedger,
|
||||
query_timeout=query_timeout, log=log
|
||||
)
|
||||
)
|
||||
|
|
|
@ -7,11 +7,11 @@ from collections import namedtuple
|
|||
|
||||
from lbry.wallet.server.leveldb import DB
|
||||
from lbry.wallet.server.util import class_logger
|
||||
from lbry.wallet.client.basedatabase import query, constraints_to_sql
|
||||
from lbry.wallet.database import query, constraints_to_sql
|
||||
|
||||
from lbry.schema.tags import clean_tags
|
||||
from lbry.schema.mime_types import guess_stream_type
|
||||
from lbry.wallet.ledger import MainNetLedger, RegTestLedger
|
||||
from lbry.wallet import Ledger, RegTestLedger
|
||||
from lbry.wallet.transaction import Transaction, Output
|
||||
from lbry.wallet.server.db.canonical import register_canonical_functions
|
||||
from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished
|
||||
|
@ -171,7 +171,7 @@ class SQLDB:
|
|||
self._db_path = path
|
||||
self.db = None
|
||||
self.logger = class_logger(__name__, self.__class__.__name__)
|
||||
self.ledger = MainNetLedger if self.main.coin.NET == 'mainnet' else RegTestLedger
|
||||
self.ledger = Ledger if self.main.coin.NET == 'mainnet' else RegTestLedger
|
||||
self._fts_synced = False
|
||||
|
||||
def open(self):
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
import ecdsa
|
||||
import struct
|
||||
import hashlib
|
||||
from binascii import hexlify, unhexlify
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
import typing
|
||||
|
||||
from binascii import hexlify, unhexlify
|
||||
from typing import List, Iterable, Optional, Tuple
|
||||
|
||||
import ecdsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
@ -11,34 +14,216 @@ from cryptography.hazmat.primitives.asymmetric import ec
|
|||
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
|
||||
from lbry.crypto.base58 import Base58
|
||||
from lbry.error import InsufficientFundsError
|
||||
from lbry.crypto.hash import hash160, sha256
|
||||
from lbry.wallet.client.basetransaction import BaseTransaction, BaseInput, BaseOutput, ReadOnlyList
|
||||
from lbry.crypto.base58 import Base58
|
||||
from lbry.schema.url import normalize_name
|
||||
from lbry.schema.claim import Claim
|
||||
from lbry.schema.purchase import Purchase
|
||||
from lbry.schema.url import normalize_name
|
||||
from lbry.wallet.account import Account
|
||||
from lbry.wallet.script import InputScript, OutputScript
|
||||
|
||||
from .script import InputScript, OutputScript
|
||||
from .constants import COIN, NULL_HASH32
|
||||
from .bcd_data_stream import BCDataStream
|
||||
from .hash import TXRef, TXRefImmutable
|
||||
from .util import ReadOnlyList
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbry.wallet.account import Account
|
||||
from lbry.wallet.ledger import Ledger
|
||||
from lbry.wallet.wallet import Wallet
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
class Input(BaseInput):
|
||||
script: InputScript
|
||||
script_class = InputScript
|
||||
class TXRefMutable(TXRef):
|
||||
|
||||
__slots__ = ('tx',)
|
||||
|
||||
def __init__(self, tx: 'Transaction') -> None:
|
||||
super().__init__()
|
||||
self.tx = tx
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
if self._id is None:
|
||||
self._id = hexlify(self.hash[::-1]).decode()
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
if self._hash is None:
|
||||
self._hash = sha256(sha256(self.tx.raw_sans_segwit))
|
||||
return self._hash
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.tx.height
|
||||
|
||||
def reset(self):
|
||||
self._id = None
|
||||
self._hash = None
|
||||
|
||||
|
||||
class Output(BaseOutput):
|
||||
script: OutputScript
|
||||
script_class = OutputScript
|
||||
class TXORef:
|
||||
|
||||
__slots__ = 'tx_ref', 'position'
|
||||
|
||||
def __init__(self, tx_ref: TXRef, position: int) -> None:
|
||||
self.tx_ref = tx_ref
|
||||
self.position = position
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return f'{self.tx_ref.id}:{self.position}'
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
return self.tx_ref.hash + BCDataStream.uint32.pack(self.position)
|
||||
|
||||
@property
|
||||
def is_null(self):
|
||||
return self.tx_ref.is_null
|
||||
|
||||
@property
|
||||
def txo(self) -> Optional['Output']:
|
||||
return None
|
||||
|
||||
|
||||
class TXORefResolvable(TXORef):
|
||||
|
||||
__slots__ = ('_txo',)
|
||||
|
||||
def __init__(self, txo: 'Output') -> None:
|
||||
assert txo.tx_ref is not None
|
||||
assert txo.position is not None
|
||||
super().__init__(txo.tx_ref, txo.position)
|
||||
self._txo = txo
|
||||
|
||||
@property
|
||||
def txo(self):
|
||||
return self._txo
|
||||
|
||||
|
||||
class InputOutput:
|
||||
|
||||
__slots__ = 'tx_ref', 'position'
|
||||
|
||||
def __init__(self, tx_ref: TXRef = None, position: int = None) -> None:
|
||||
self.tx_ref = tx_ref
|
||||
self.position = position
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
""" Size of this input / output in bytes. """
|
||||
stream = BCDataStream()
|
||||
self.serialize_to(stream)
|
||||
return len(stream.get_bytes())
|
||||
|
||||
def get_fee(self, ledger):
|
||||
return self.size * ledger.fee_per_byte
|
||||
|
||||
def serialize_to(self, stream, alternate_script=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Input(InputOutput):
|
||||
|
||||
NULL_SIGNATURE = b'\x00'*72
|
||||
NULL_PUBLIC_KEY = b'\x00'*33
|
||||
|
||||
__slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'
|
||||
|
||||
def __init__(self, txo_ref: TXORef, script: InputScript, sequence: int = 0xFFFFFFFF,
|
||||
tx_ref: TXRef = None, position: int = None) -> None:
|
||||
super().__init__(tx_ref, position)
|
||||
self.txo_ref = txo_ref
|
||||
self.sequence = sequence
|
||||
self.coinbase = script if txo_ref.is_null else None
|
||||
self.script = script if not txo_ref.is_null else None
|
||||
|
||||
@property
|
||||
def is_coinbase(self):
|
||||
return self.coinbase is not None
|
||||
|
||||
@classmethod
|
||||
def spend(cls, txo: 'Output') -> 'Input':
|
||||
""" Create an input to spend the output."""
|
||||
assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.'
|
||||
script = InputScript.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
|
||||
return cls(txo.ref, script)
|
||||
|
||||
@property
|
||||
def amount(self) -> int:
|
||||
""" Amount this input adds to the transaction. """
|
||||
if self.txo_ref.txo is None:
|
||||
raise ValueError('Cannot resolve output to get amount.')
|
||||
return self.txo_ref.txo.amount
|
||||
|
||||
@property
|
||||
def is_my_account(self) -> Optional[bool]:
|
||||
""" True if the output this input spends is yours. """
|
||||
if self.txo_ref.txo is None:
|
||||
return False
|
||||
return self.txo_ref.txo.is_my_account
|
||||
|
||||
@classmethod
|
||||
def deserialize_from(cls, stream):
|
||||
tx_ref = TXRefImmutable.from_hash(stream.read(32), -1)
|
||||
position = stream.read_uint32()
|
||||
script = stream.read_string()
|
||||
sequence = stream.read_uint32()
|
||||
return cls(
|
||||
TXORef(tx_ref, position),
|
||||
InputScript(script) if not tx_ref.is_null else script,
|
||||
sequence
|
||||
)
|
||||
|
||||
def serialize_to(self, stream, alternate_script=None):
|
||||
stream.write(self.txo_ref.tx_ref.hash)
|
||||
stream.write_uint32(self.txo_ref.position)
|
||||
if alternate_script is not None:
|
||||
stream.write_string(alternate_script)
|
||||
else:
|
||||
if self.is_coinbase:
|
||||
stream.write_string(self.coinbase)
|
||||
else:
|
||||
stream.write_string(self.script.source)
|
||||
stream.write_uint32(self.sequence)
|
||||
|
||||
|
||||
class OutputEffectiveAmountEstimator:
|
||||
|
||||
__slots__ = 'txo', 'txi', 'fee', 'effective_amount'
|
||||
|
||||
def __init__(self, ledger: 'Ledger', txo: 'Output') -> None:
|
||||
self.txo = txo
|
||||
self.txi = Input.spend(txo)
|
||||
self.fee: int = self.txi.get_fee(ledger)
|
||||
self.effective_amount: int = txo.amount - self.fee
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.effective_amount < other.effective_amount
|
||||
|
||||
|
||||
class Output(InputOutput):
|
||||
|
||||
__slots__ = (
|
||||
'amount', 'script', 'is_change', 'is_my_account',
|
||||
'channel', 'private_key', 'meta',
|
||||
'purchase', 'purchased_claim', 'purchase_receipt',
|
||||
'reposted_claim', 'claims',
|
||||
)
|
||||
|
||||
def __init__(self, *args, channel: Optional['Output'] = None,
|
||||
private_key: Optional[str] = None, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, amount: int, script: OutputScript,
|
||||
tx_ref: TXRef = None, position: int = None,
|
||||
is_change: Optional[bool] = None, is_my_account: Optional[bool] = None,
|
||||
channel: Optional['Output'] = None, private_key: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__(tx_ref, position)
|
||||
self.amount = amount
|
||||
self.script = script
|
||||
self.is_change = is_change
|
||||
self.is_my_account = is_my_account
|
||||
self.channel = channel
|
||||
self.private_key = private_key
|
||||
self.purchase: 'Output' = None # txo containing purchase metadata
|
||||
|
@ -49,10 +234,52 @@ class Output(BaseOutput):
|
|||
self.meta = {}
|
||||
|
||||
def update_annotations(self, annotated):
|
||||
super().update_annotations(annotated)
|
||||
if annotated is None:
|
||||
self.is_change = False
|
||||
self.is_my_account = False
|
||||
else:
|
||||
self.is_change = annotated.is_change
|
||||
self.is_my_account = annotated.is_my_account
|
||||
self.channel = annotated.channel if annotated else None
|
||||
self.private_key = annotated.private_key if annotated else None
|
||||
|
||||
@property
|
||||
def ref(self):
|
||||
return TXORefResolvable(self)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.ref.id
|
||||
|
||||
@property
|
||||
def pubkey_hash(self):
|
||||
return self.script.values['pubkey_hash']
|
||||
|
||||
@property
|
||||
def has_address(self):
|
||||
return 'pubkey_hash' in self.script.values
|
||||
|
||||
def get_address(self, ledger):
|
||||
return ledger.hash160_to_address(self.pubkey_hash)
|
||||
|
||||
def get_estimator(self, ledger):
|
||||
return OutputEffectiveAmountEstimator(ledger, self)
|
||||
|
||||
@classmethod
|
||||
def pay_pubkey_hash(cls, amount, pubkey_hash):
|
||||
return cls(amount, OutputScript.pay_pubkey_hash(pubkey_hash))
|
||||
|
||||
@classmethod
|
||||
def deserialize_from(cls, stream):
|
||||
return cls(
|
||||
amount=stream.read_uint64(),
|
||||
script=OutputScript(stream.read_string())
|
||||
)
|
||||
|
||||
def serialize_to(self, stream, alternate_script=None):
|
||||
stream.write_uint64(self.amount)
|
||||
stream.write_string(self.script.source)
|
||||
|
||||
def get_fee(self, ledger):
|
||||
name_fee = 0
|
||||
if self.script.is_claim_name:
|
||||
|
@ -180,27 +407,28 @@ class Output(BaseOutput):
|
|||
@classmethod
|
||||
def pay_claim_name_pubkey_hash(
|
||||
cls, amount: int, claim_name: str, claim: Claim, pubkey_hash: bytes) -> 'Output':
|
||||
script = cls.script_class.pay_claim_name_pubkey_hash(
|
||||
script = OutputScript.pay_claim_name_pubkey_hash(
|
||||
claim_name.encode(), claim, pubkey_hash)
|
||||
txo = cls(amount, script)
|
||||
return txo
|
||||
return cls(amount, script)
|
||||
|
||||
@classmethod
|
||||
def pay_update_claim_pubkey_hash(
|
||||
cls, amount: int, claim_name: str, claim_id: str, claim: Claim, pubkey_hash: bytes) -> 'Output':
|
||||
script = cls.script_class.pay_update_claim_pubkey_hash(
|
||||
claim_name.encode(), unhexlify(claim_id)[::-1], claim, pubkey_hash)
|
||||
txo = cls(amount, script)
|
||||
return txo
|
||||
script = OutputScript.pay_update_claim_pubkey_hash(
|
||||
claim_name.encode(), unhexlify(claim_id)[::-1], claim, pubkey_hash
|
||||
)
|
||||
return cls(amount, script)
|
||||
|
||||
@classmethod
|
||||
def pay_support_pubkey_hash(cls, amount: int, claim_name: str, claim_id: str, pubkey_hash: bytes) -> 'Output':
|
||||
script = cls.script_class.pay_support_pubkey_hash(claim_name.encode(), unhexlify(claim_id)[::-1], pubkey_hash)
|
||||
script = OutputScript.pay_support_pubkey_hash(
|
||||
claim_name.encode(), unhexlify(claim_id)[::-1], pubkey_hash
|
||||
)
|
||||
return cls(amount, script)
|
||||
|
||||
@classmethod
|
||||
def add_purchase_data(cls, purchase: Purchase) -> 'Output':
|
||||
script = cls.script_class.return_data(purchase)
|
||||
script = OutputScript.return_data(purchase)
|
||||
return cls(0, script)
|
||||
|
||||
@property
|
||||
|
@ -246,16 +474,331 @@ class Output(BaseOutput):
|
|||
return self.claim.stream.fee
|
||||
|
||||
|
||||
class Transaction(BaseTransaction):
|
||||
class Transaction:
|
||||
|
||||
input_class = Input
|
||||
output_class = Output
|
||||
def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False,
|
||||
height: int = -2, position: int = -1) -> None:
|
||||
self._raw = raw
|
||||
self._raw_sans_segwit = None
|
||||
self.is_segwit_flag = 0
|
||||
self.witnesses: List[bytes] = []
|
||||
self.ref = TXRefMutable(self)
|
||||
self.version = version
|
||||
self.locktime = locktime
|
||||
self._inputs: List[Input] = []
|
||||
self._outputs: List[Output] = []
|
||||
self.is_verified = is_verified
|
||||
# Height Progression
|
||||
# -2: not broadcast
|
||||
# -1: in mempool but has unconfirmed inputs
|
||||
# 0: in mempool and all inputs confirmed
|
||||
# +num: confirmed in a specific block (height)
|
||||
self.height = height
|
||||
self.position = position
|
||||
if raw is not None:
|
||||
self._deserialize()
|
||||
|
||||
outputs: ReadOnlyList[Output]
|
||||
inputs: ReadOnlyList[Input]
|
||||
@property
|
||||
def is_broadcast(self):
|
||||
return self.height > -2
|
||||
|
||||
@property
|
||||
def is_mempool(self):
|
||||
return self.height in (-1, 0)
|
||||
|
||||
@property
|
||||
def is_confirmed(self):
|
||||
return self.height > 0
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.ref.id
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
return self.ref.hash
|
||||
|
||||
@property
|
||||
def raw(self):
|
||||
if self._raw is None:
|
||||
self._raw = self._serialize()
|
||||
return self._raw
|
||||
|
||||
@property
|
||||
def raw_sans_segwit(self):
|
||||
if self.is_segwit_flag:
|
||||
if self._raw_sans_segwit is None:
|
||||
self._raw_sans_segwit = self._serialize(sans_segwit=True)
|
||||
return self._raw_sans_segwit
|
||||
return self.raw
|
||||
|
||||
def _reset(self):
|
||||
self._raw = None
|
||||
self._raw_sans_segwit = None
|
||||
self.ref.reset()
|
||||
|
||||
@property
|
||||
def inputs(self) -> ReadOnlyList[Input]:
|
||||
return ReadOnlyList(self._inputs)
|
||||
|
||||
@property
|
||||
def outputs(self) -> ReadOnlyList[Output]:
|
||||
return ReadOnlyList(self._outputs)
|
||||
|
||||
def _add(self, existing_ios: List, new_ios: Iterable[InputOutput], reset=False) -> 'Transaction':
|
||||
for txio in new_ios:
|
||||
txio.tx_ref = self.ref
|
||||
txio.position = len(existing_ios)
|
||||
existing_ios.append(txio)
|
||||
if reset:
|
||||
self._reset()
|
||||
return self
|
||||
|
||||
def add_inputs(self, inputs: Iterable[Input]) -> 'Transaction':
|
||||
return self._add(self._inputs, inputs, True)
|
||||
|
||||
def add_outputs(self, outputs: Iterable[Output]) -> 'Transaction':
|
||||
return self._add(self._outputs, outputs, True)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
""" Size in bytes of the entire transaction. """
|
||||
return len(self.raw)
|
||||
|
||||
@property
|
||||
def base_size(self) -> int:
|
||||
""" Size of transaction without inputs or outputs in bytes. """
|
||||
return (
|
||||
self.size
|
||||
- sum(txi.size for txi in self._inputs)
|
||||
- sum(txo.size for txo in self._outputs)
|
||||
)
|
||||
|
||||
@property
|
||||
def input_sum(self):
|
||||
return sum(i.amount for i in self.inputs if i.txo_ref.txo is not None)
|
||||
|
||||
@property
|
||||
def output_sum(self):
|
||||
return sum(o.amount for o in self.outputs)
|
||||
|
||||
@property
|
||||
def net_account_balance(self) -> int:
|
||||
balance = 0
|
||||
for txi in self.inputs:
|
||||
if txi.txo_ref.txo is None:
|
||||
continue
|
||||
if txi.is_my_account is None:
|
||||
raise ValueError(
|
||||
"Cannot access net_account_balance if inputs/outputs do not "
|
||||
"have is_my_account set properly."
|
||||
)
|
||||
if txi.is_my_account:
|
||||
balance -= txi.amount
|
||||
for txo in self.outputs:
|
||||
if txo.is_my_account is None:
|
||||
raise ValueError(
|
||||
"Cannot access net_account_balance if inputs/outputs do not "
|
||||
"have is_my_account set properly."
|
||||
)
|
||||
if txo.is_my_account:
|
||||
balance += txo.amount
|
||||
return balance
|
||||
|
||||
@property
|
||||
def fee(self) -> int:
|
||||
return self.input_sum - self.output_sum
|
||||
|
||||
def get_base_fee(self, ledger) -> int:
|
||||
""" Fee for base tx excluding inputs and outputs. """
|
||||
return self.base_size * ledger.fee_per_byte
|
||||
|
||||
def get_effective_input_sum(self, ledger) -> int:
|
||||
""" Sum of input values *minus* the cost involved to spend them. """
|
||||
return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs)
|
||||
|
||||
def get_total_output_sum(self, ledger) -> int:
|
||||
""" Sum of output values *plus* the cost involved to spend them. """
|
||||
return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)
|
||||
|
||||
def _serialize(self, with_inputs: bool = True, sans_segwit: bool = False) -> bytes:
|
||||
stream = BCDataStream()
|
||||
stream.write_uint32(self.version)
|
||||
if with_inputs:
|
||||
stream.write_compact_size(len(self._inputs))
|
||||
for txin in self._inputs:
|
||||
txin.serialize_to(stream)
|
||||
stream.write_compact_size(len(self._outputs))
|
||||
for txout in self._outputs:
|
||||
txout.serialize_to(stream)
|
||||
stream.write_uint32(self.locktime)
|
||||
return stream.get_bytes()
|
||||
|
||||
def _serialize_for_signature(self, signing_input: int) -> bytes:
|
||||
stream = BCDataStream()
|
||||
stream.write_uint32(self.version)
|
||||
stream.write_compact_size(len(self._inputs))
|
||||
for i, txin in enumerate(self._inputs):
|
||||
if signing_input == i:
|
||||
assert txin.txo_ref.txo is not None
|
||||
txin.serialize_to(stream, txin.txo_ref.txo.script.source)
|
||||
else:
|
||||
txin.serialize_to(stream, b'')
|
||||
stream.write_compact_size(len(self._outputs))
|
||||
for txout in self._outputs:
|
||||
txout.serialize_to(stream)
|
||||
stream.write_uint32(self.locktime)
|
||||
stream.write_uint32(self.signature_hash_type(1)) # signature hash type: SIGHASH_ALL
|
||||
return stream.get_bytes()
|
||||
|
||||
def _deserialize(self):
|
||||
if self._raw is not None:
|
||||
stream = BCDataStream(self._raw)
|
||||
self.version = stream.read_uint32()
|
||||
input_count = stream.read_compact_size()
|
||||
if input_count == 0:
|
||||
self.is_segwit_flag = stream.read_uint8()
|
||||
input_count = stream.read_compact_size()
|
||||
self._add(self._inputs, [
|
||||
Input.deserialize_from(stream) for _ in range(input_count)
|
||||
])
|
||||
output_count = stream.read_compact_size()
|
||||
self._add(self._outputs, [
|
||||
Output.deserialize_from(stream) for _ in range(output_count)
|
||||
])
|
||||
if self.is_segwit_flag:
|
||||
# drain witness portion of transaction
|
||||
# too many witnesses for no crime
|
||||
self.witnesses = []
|
||||
for _ in range(input_count):
|
||||
for _ in range(stream.read_compact_size()):
|
||||
self.witnesses.append(stream.read(stream.read_compact_size()))
|
||||
self.locktime = stream.read_uint32()
|
||||
|
||||
@classmethod
|
||||
def pay(cls, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account):
|
||||
def ensure_all_have_same_ledger_and_wallet(
|
||||
cls, funding_accounts: Iterable['Account'],
|
||||
change_account: 'Account' = None) -> Tuple['Ledger', 'Wallet']:
|
||||
ledger = wallet = None
|
||||
for account in funding_accounts:
|
||||
if ledger is None:
|
||||
ledger = account.ledger
|
||||
wallet = account.wallet
|
||||
if ledger != account.ledger:
|
||||
raise ValueError(
|
||||
'All funding accounts used to create a transaction must be on the same ledger.'
|
||||
)
|
||||
if wallet != account.wallet:
|
||||
raise ValueError(
|
||||
'All funding accounts used to create a transaction must be from the same wallet.'
|
||||
)
|
||||
if change_account is not None:
|
||||
if change_account.ledger != ledger:
|
||||
raise ValueError('Change account must use same ledger as funding accounts.')
|
||||
if change_account.wallet != wallet:
|
||||
raise ValueError('Change account must use same wallet as funding accounts.')
|
||||
if ledger is None:
|
||||
raise ValueError('No ledger found.')
|
||||
if wallet is None:
|
||||
raise ValueError('No wallet found.')
|
||||
return ledger, wallet
|
||||
|
||||
@classmethod
|
||||
async def create(cls, inputs: Iterable[Input], outputs: Iterable[Output],
|
||||
funding_accounts: Iterable['Account'], change_account: 'Account',
|
||||
sign: bool = True):
|
||||
""" Find optimal set of inputs when only outputs are provided; add change
|
||||
outputs if only inputs are provided or if inputs are greater than outputs. """
|
||||
|
||||
tx = cls() \
|
||||
.add_inputs(inputs) \
|
||||
.add_outputs(outputs)
|
||||
|
||||
ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
|
||||
# value of the outputs plus associated fees
|
||||
cost = (
|
||||
tx.get_base_fee(ledger) +
|
||||
tx.get_total_output_sum(ledger)
|
||||
)
|
||||
# value of the inputs less the cost to spend those inputs
|
||||
payment = tx.get_effective_input_sum(ledger)
|
||||
|
||||
try:
|
||||
|
||||
for _ in range(5):
|
||||
|
||||
if payment < cost:
|
||||
deficit = cost - payment
|
||||
spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
|
||||
if not spendables:
|
||||
raise InsufficientFundsError()
|
||||
payment += sum(s.effective_amount for s in spendables)
|
||||
tx.add_inputs(s.txi for s in spendables)
|
||||
|
||||
cost_of_change = (
|
||||
tx.get_base_fee(ledger) +
|
||||
Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger)
|
||||
)
|
||||
if payment > cost:
|
||||
change = payment - cost
|
||||
if change > cost_of_change:
|
||||
change_address = await change_account.change.get_or_create_usable_address()
|
||||
change_hash160 = change_account.ledger.address_to_hash160(change_address)
|
||||
change_amount = change - cost_of_change
|
||||
change_output = Output.pay_pubkey_hash(change_amount, change_hash160)
|
||||
change_output.is_change = True
|
||||
tx.add_outputs([Output.pay_pubkey_hash(change_amount, change_hash160)])
|
||||
|
||||
if tx._outputs:
|
||||
break
|
||||
# this condition and the outer range(5) loop cover an edge case
|
||||
# whereby a single input is just enough to cover the fee and
|
||||
# has some change left over, but the change left over is less
|
||||
# than the cost_of_change: thus the input is completely
|
||||
# consumed and no output is added, which is an invalid tx.
|
||||
# to be able to spend this input we must increase the cost
|
||||
# of the TX and run through the balance algorithm a second time
|
||||
# adding an extra input and change output, making tx valid.
|
||||
# we do this 5 times in case the other UTXOs added are also
|
||||
# less than the fee, after 5 attempts we give up and go home
|
||||
cost += cost_of_change + 1
|
||||
|
||||
if sign:
|
||||
await tx.sign(funding_accounts)
|
||||
|
||||
except Exception as e:
|
||||
log.exception('Failed to create transaction:')
|
||||
await ledger.release_tx(tx)
|
||||
raise e
|
||||
|
||||
return tx
|
||||
|
||||
@staticmethod
|
||||
def signature_hash_type(hash_type):
|
||||
return hash_type
|
||||
|
||||
async def sign(self, funding_accounts: Iterable['Account']):
|
||||
ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts)
|
||||
for i, txi in enumerate(self._inputs):
|
||||
assert txi.script is not None
|
||||
assert txi.txo_ref.txo is not None
|
||||
txo_script = txi.txo_ref.txo.script
|
||||
if txo_script.is_pay_pubkey_hash:
|
||||
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
|
||||
private_key = await ledger.get_private_key_for_address(wallet, address)
|
||||
assert private_key is not None, 'Cannot find private key for signing output.'
|
||||
tx = self._serialize_for_signature(i)
|
||||
txi.script.values['signature'] = \
|
||||
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
|
||||
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
|
||||
txi.script.generate()
|
||||
else:
|
||||
raise NotImplementedError("Don't know how to spend this output.")
|
||||
self._reset()
|
||||
|
||||
@classmethod
|
||||
def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'):
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address))
|
||||
return cls.create([], [output], funding_accounts, change_account)
|
||||
|
@ -263,7 +806,7 @@ class Transaction(BaseTransaction):
|
|||
@classmethod
|
||||
def claim_create(
|
||||
cls, name: str, claim: Claim, amount: int, holding_address: str,
|
||||
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
|
||||
funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
claim_output = Output.pay_claim_name_pubkey_hash(
|
||||
amount, name, claim, ledger.address_to_hash160(holding_address)
|
||||
|
@ -275,7 +818,7 @@ class Transaction(BaseTransaction):
|
|||
@classmethod
|
||||
def claim_update(
|
||||
cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
|
||||
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
|
||||
funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
updated_claim = Output.pay_update_claim_pubkey_hash(
|
||||
amount, previous_claim.claim_name, previous_claim.claim_id,
|
||||
|
@ -291,7 +834,7 @@ class Transaction(BaseTransaction):
|
|||
|
||||
@classmethod
|
||||
def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str,
|
||||
funding_accounts: List[Account], change_account: Account):
|
||||
funding_accounts: List['Account'], change_account: 'Account'):
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
support_output = Output.pay_support_pubkey_hash(
|
||||
amount, claim_name, claim_id, ledger.address_to_hash160(holding_address)
|
||||
|
@ -300,7 +843,7 @@ class Transaction(BaseTransaction):
|
|||
|
||||
@classmethod
|
||||
def purchase(cls, claim_id: str, amount: int, merchant_address: bytes,
|
||||
funding_accounts: List[Account], change_account: Account):
|
||||
funding_accounts: List['Account'], change_account: 'Account'):
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
payment = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(merchant_address))
|
||||
data = Output.add_purchase_data(Purchase(claim_id))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import re
|
||||
from typing import TypeVar, Sequence, Optional
|
||||
from lbry.wallet.client.constants import COIN
|
||||
from .constants import COIN
|
||||
|
||||
|
||||
def coins_to_satoshis(coins):
|
||||
|
|
|
@ -10,9 +10,11 @@ from collections import UserDict
|
|||
from hashlib import sha256
|
||||
from operator import attrgetter
|
||||
from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt
|
||||
from .account import Account
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbry.wallet.client import basemanager, baseaccount, baseledger
|
||||
from lbry.wallet.manager import WalletManager
|
||||
from lbry.wallet.ledger import Ledger
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -65,7 +67,7 @@ class Wallet:
|
|||
preferences: TimestampedPreferences
|
||||
encryption_password: Optional[str]
|
||||
|
||||
def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None,
|
||||
def __init__(self, name: str = 'Wallet', accounts: MutableSequence['Account'] = None,
|
||||
storage: 'WalletStorage' = None, preferences: dict = None) -> None:
|
||||
self.name = name
|
||||
self.accounts = accounts or []
|
||||
|
@ -79,30 +81,30 @@ class Wallet:
|
|||
return os.path.basename(self.storage.path)
|
||||
return self.name
|
||||
|
||||
def add_account(self, account: 'baseaccount.BaseAccount'):
|
||||
def add_account(self, account: 'Account'):
|
||||
self.accounts.append(account)
|
||||
|
||||
def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount':
|
||||
return ledger.account_class.generate(ledger, self)
|
||||
def generate_account(self, ledger: 'Ledger') -> 'Account':
|
||||
return Account.generate(ledger, self)
|
||||
|
||||
@property
|
||||
def default_account(self) -> Optional['baseaccount.BaseAccount']:
|
||||
def default_account(self) -> Optional['Account']:
|
||||
for account in self.accounts:
|
||||
return account
|
||||
return None
|
||||
|
||||
def get_account_or_default(self, account_id: str) -> Optional['baseaccount.BaseAccount']:
|
||||
def get_account_or_default(self, account_id: str) -> Optional['Account']:
|
||||
if account_id is None:
|
||||
return self.default_account
|
||||
return self.get_account_or_error(account_id)
|
||||
|
||||
def get_account_or_error(self, account_id: str) -> 'baseaccount.BaseAccount':
|
||||
def get_account_or_error(self, account_id: str) -> 'Account':
|
||||
for account in self.accounts:
|
||||
if account.id == account_id:
|
||||
return account
|
||||
raise ValueError(f"Couldn't find account: {account_id}.")
|
||||
|
||||
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['baseaccount.BaseAccount']:
|
||||
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['Account']:
|
||||
return [
|
||||
self.get_account_or_error(account_id)
|
||||
for account_id in account_ids
|
||||
|
@ -117,7 +119,7 @@ class Wallet:
|
|||
return accounts
|
||||
|
||||
@classmethod
|
||||
def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet':
|
||||
def from_storage(cls, storage: 'WalletStorage', manager: 'WalletManager') -> 'Wallet':
|
||||
json_dict = storage.read()
|
||||
wallet = cls(
|
||||
name=json_dict.get('name', 'Wallet'),
|
||||
|
@ -127,7 +129,7 @@ class Wallet:
|
|||
account_dicts: Sequence[dict] = json_dict.get('accounts', [])
|
||||
for account_dict in account_dicts:
|
||||
ledger = manager.get_or_create_ledger(account_dict['ledger'])
|
||||
ledger.account_class.from_dict(ledger, wallet, account_dict)
|
||||
Account.from_dict(ledger, wallet, account_dict)
|
||||
return wallet
|
||||
|
||||
def to_dict(self, encrypt_password: str = None):
|
||||
|
@ -173,15 +175,15 @@ class Wallet:
|
|||
decompressed = zlib.decompress(decrypted)
|
||||
return json.loads(decompressed)
|
||||
|
||||
def merge(self, manager: 'basemanager.BaseWalletManager',
|
||||
password: str, data: str) -> List['baseaccount.BaseAccount']:
|
||||
def merge(self, manager: 'WalletManager',
|
||||
password: str, data: str) -> List['Account']:
|
||||
assert not self.is_locked, "Cannot sync apply on a locked wallet."
|
||||
added_accounts = []
|
||||
decrypted_data = self.unpack(password, data)
|
||||
self.preferences.merge(decrypted_data.get('preferences', {}))
|
||||
for account_dict in decrypted_data['accounts']:
|
||||
ledger = manager.get_or_create_ledger(account_dict['ledger'])
|
||||
_, _, pubkey = ledger.account_class.keys_from_dict(ledger, account_dict)
|
||||
_, _, pubkey = Account.keys_from_dict(ledger, account_dict)
|
||||
account_id = pubkey.address
|
||||
local_match = None
|
||||
for local_account in self.accounts:
|
||||
|
@ -191,7 +193,7 @@ class Wallet:
|
|||
if local_match is not None:
|
||||
local_match.merge(account_dict)
|
||||
else:
|
||||
new_account = ledger.account_class.from_dict(ledger, self, account_dict)
|
||||
new_account = Account.from_dict(ledger, self, account_dict)
|
||||
added_accounts.append(new_account)
|
||||
return added_accounts
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
|
||||
from lbry.wallet.client.wallet import ENCRYPT_ON_DISK
|
||||
from lbry.wallet import ENCRYPT_ON_DISK
|
||||
from lbry.error import InvalidPasswordError
|
||||
from lbry.testcase import CommandTestCase
|
||||
from lbry.wallet.dewies import dict_values_to_lbc
|
||||
|
|
|
@ -2,7 +2,6 @@ import unittest
|
|||
from unittest import mock
|
||||
import json
|
||||
|
||||
import lbry.wallet
|
||||
from lbry.conf import Config
|
||||
from lbry.extras.daemon.storage import SQLiteStorage
|
||||
from lbry.extras.daemon.ComponentManager import ComponentManager
|
||||
|
@ -11,8 +10,7 @@ from lbry.extras.daemon.Components import HASH_ANNOUNCER_COMPONENT
|
|||
from lbry.extras.daemon.Components import UPNP_COMPONENT, BLOB_COMPONENT
|
||||
from lbry.extras.daemon.Components import PEER_PROTOCOL_SERVER_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT
|
||||
from lbry.extras.daemon.Daemon import Daemon as LBRYDaemon
|
||||
from lbry.wallet import LbryWalletManager
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
from lbry.wallet import WalletManager, Wallet
|
||||
|
||||
from tests import test_utils
|
||||
# from tests.mocks import mock_conf_settings, FakeNetwork, FakeFileManager
|
||||
|
@ -37,7 +35,7 @@ def get_test_daemon(conf: Config, with_fee=False):
|
|||
)
|
||||
daemon = LBRYDaemon(conf, component_manager=component_manager)
|
||||
daemon.payment_rate_manager = OnlyFreePaymentsManager()
|
||||
daemon.wallet_manager = mock.Mock(spec=LbryWalletManager)
|
||||
daemon.wallet_manager = mock.Mock(spec=WalletManager)
|
||||
daemon.wallet_manager.wallet = mock.Mock(spec=Wallet)
|
||||
daemon.wallet_manager.use_encryption = False
|
||||
daemon.wallet_manager.network = FakeNetwork()
|
||||
|
|
|
@ -10,13 +10,10 @@ from lbry.testcase import get_fake_exchange_rate_manager
|
|||
from lbry.utils import generate_id
|
||||
from lbry.error import InsufficientFundsError
|
||||
from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output, Database
|
||||
from lbry.wallet.client.constants import CENT, NULL_HASH32
|
||||
from lbry.wallet.client.basenetwork import ClientSession
|
||||
from lbry.conf import Config
|
||||
from lbry.wallet.ledger import MainNetLedger
|
||||
from lbry.wallet.transaction import Transaction, Input, Output
|
||||
from lbry.wallet.manager import LbryWalletManager
|
||||
from lbry.extras.daemon.analytics import AnalyticsManager
|
||||
from lbry.stream.stream_manager import StreamManager
|
||||
from lbry.stream.descriptor import StreamDescriptor
|
||||
|
@ -94,16 +91,16 @@ async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None):
|
|||
return {'timestamp': 1984}
|
||||
|
||||
wallet = Wallet()
|
||||
ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(':memory:'),
|
||||
ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': FakeHeaders(514082)
|
||||
})
|
||||
await ledger.db.open()
|
||||
wallet.generate_account(ledger)
|
||||
manager = LbryWalletManager()
|
||||
manager = WalletManager()
|
||||
manager.config = Config()
|
||||
manager.wallets.append(wallet)
|
||||
manager.ledgers[MainNetLedger] = ledger
|
||||
manager.ledgers[Ledger] = ledger
|
||||
manager.ledger.network.client = ClientSession(
|
||||
network=manager.ledger.network, server=('fakespv.lbry.com', 50001)
|
||||
)
|
||||
|
|
|
@ -1,17 +1,13 @@
|
|||
from binascii import hexlify
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
from lbry.wallet.ledger import MainNetLedger, WalletDatabase
|
||||
from lbry.wallet.header import Headers
|
||||
from lbry.wallet.account import Account
|
||||
from lbry.wallet.client.baseaccount import SingleKey, HierarchicalDeterministic
|
||||
from lbry.wallet import Wallet, Ledger, Database, Headers, Account, SingleKey, HierarchicalDeterministic
|
||||
|
||||
|
||||
class TestAccount(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': WalletDatabase(':memory:'),
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
await self.ledger.db.open()
|
||||
|
@ -236,8 +232,8 @@ class TestAccount(AsyncioTestCase):
|
|||
class TestSingleKeyAccount(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': WalletDatabase(':memory:'),
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
await self.ledger.db.open()
|
||||
|
@ -327,7 +323,7 @@ class TestSingleKeyAccount(AsyncioTestCase):
|
|||
self.assertEqual(len(keys), 1)
|
||||
|
||||
async def test_generate_account_from_seed(self):
|
||||
account = self.ledger.account_class.from_dict(
|
||||
account = Account.from_dict(
|
||||
self.ledger, Wallet(), {
|
||||
"seed":
|
||||
"carbon smart garage balance margin twelve chest sword toas"
|
||||
|
@ -432,8 +428,8 @@ class AccountEncryptionTests(AsyncioTestCase):
|
|||
}
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': WalletDatabase(':memory:'),
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
|
||||
|
@ -489,7 +485,7 @@ class AccountEncryptionTests(AsyncioTestCase):
|
|||
account_data = self.unencrypted_account.copy()
|
||||
del account_data['seed']
|
||||
del account_data['private_key']
|
||||
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
||||
account = Account.from_dict(self.ledger, Wallet(), account_data)
|
||||
encrypted = account.to_dict('password')
|
||||
self.assertFalse(encrypted['seed'])
|
||||
self.assertFalse(encrypted['private_key'])
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
|
||||
from lbry.wallet.client.bcd_data_stream import BCDataStream
|
||||
from lbry.wallet.bcd_data_stream import BCDataStream
|
||||
|
||||
|
||||
class TestBCDataStream(unittest.TestCase):
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from binascii import unhexlify, hexlify
|
||||
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
from lbry.wallet.client.bip32 import PubKey, PrivateKey, from_extended_key_string
|
||||
from lbry.wallet import Ledger, Database, Headers
|
||||
|
||||
from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
|
||||
from lbry.wallet.client.bip32 import PubKey, PrivateKey, from_extended_key_string
|
||||
from lbry.wallet import MainNetLedger as ledger_class
|
||||
|
||||
|
||||
class BIP32Tests(AsyncioTestCase):
|
||||
|
@ -46,9 +46,9 @@ class BIP32Tests(AsyncioTestCase):
|
|||
with self.assertRaisesRegex(ValueError, 'private key must be 32 bytes'):
|
||||
PrivateKey(None, b'abcd', b'abcd'*8, 0, 255)
|
||||
private_key = PrivateKey(
|
||||
ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:'),
|
||||
}),
|
||||
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
|
||||
b'abcd'*8, 0, 1
|
||||
|
@ -67,9 +67,9 @@ class BIP32Tests(AsyncioTestCase):
|
|||
|
||||
async def test_private_key_derivation(self):
|
||||
private_key = PrivateKey(
|
||||
ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:'),
|
||||
}),
|
||||
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
|
||||
b'abcd'*8, 0, 1
|
||||
|
@ -84,9 +84,9 @@ class BIP32Tests(AsyncioTestCase):
|
|||
self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED])
|
||||
|
||||
async def test_from_extended_keys(self):
|
||||
ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:'),
|
||||
})
|
||||
self.assertIsInstance(
|
||||
from_extended_key_string(
|
||||
|
|
|
@ -2,9 +2,9 @@ from types import GeneratorType
|
|||
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
|
||||
from lbry.wallet import MainNetLedger as ledger_class
|
||||
from lbry.wallet import Ledger, Database, Headers
|
||||
from lbry.wallet.client.coinselection import CoinSelector, MAXIMUM_TRIES
|
||||
from lbry.wallet.client.constants import CENT
|
||||
from lbry.constants import CENT
|
||||
|
||||
from tests.unit.wallet.test_transaction import get_output as utxo
|
||||
|
||||
|
@ -20,9 +20,9 @@ def search(*args, **kwargs):
|
|||
class BaseSelectionTestCase(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:'),
|
||||
})
|
||||
await self.ledger.db.open()
|
||||
|
||||
|
|
|
@ -6,11 +6,11 @@ import tempfile
|
|||
import asyncio
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
from lbry.wallet import MainNetLedger
|
||||
from lbry.wallet.transaction import Transaction
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
from lbry.wallet import (
|
||||
Wallet, Account, Ledger, Database, Headers, Transaction, Input
|
||||
)
|
||||
from lbry.wallet.client.constants import COIN
|
||||
from lbry.wallet.client.basedatabase import query, interpolate, constraints_to_sql, AIOSQLite
|
||||
from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite
|
||||
from lbry.crypto.hash import sha256
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
|
||||
|
@ -195,9 +195,9 @@ class TestQueryBuilder(unittest.TestCase):
|
|||
class TestQueries(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(':memory:'),
|
||||
'headers': MainNetLedger.headers_class(':memory:')
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
self.wallet = Wallet()
|
||||
await self.ledger.db.open()
|
||||
|
@ -206,13 +206,13 @@ class TestQueries(AsyncioTestCase):
|
|||
await self.ledger.db.close()
|
||||
|
||||
async def create_account(self, wallet=None):
|
||||
account = self.ledger.account_class.generate(self.ledger, wallet or self.wallet)
|
||||
account = Account.generate(self.ledger, wallet or self.wallet)
|
||||
await account.ensure_address_gap()
|
||||
return account
|
||||
|
||||
async def create_tx_from_nothing(self, my_account, height):
|
||||
to_address = await my_account.receiving.get_or_create_usable_address()
|
||||
to_hash = MainNetLedger.address_to_hash160(to_address)
|
||||
to_hash = Ledger.address_to_hash160(to_address)
|
||||
tx = Transaction(height=height, is_verified=True) \
|
||||
.add_inputs([self.txi(self.txo(1, sha256(str(height).encode())))]) \
|
||||
.add_outputs([self.txo(1, to_hash)])
|
||||
|
@ -224,7 +224,7 @@ class TestQueries(AsyncioTestCase):
|
|||
from_hash = txo.script.values['pubkey_hash']
|
||||
from_address = self.ledger.hash160_to_address(from_hash)
|
||||
to_address = await to_account.receiving.get_or_create_usable_address()
|
||||
to_hash = MainNetLedger.address_to_hash160(to_address)
|
||||
to_hash = Ledger.address_to_hash160(to_address)
|
||||
tx = Transaction(height=height, is_verified=True) \
|
||||
.add_inputs([self.txi(txo)]) \
|
||||
.add_outputs([self.txo(1, to_hash)])
|
||||
|
@ -248,7 +248,7 @@ class TestQueries(AsyncioTestCase):
|
|||
return get_output(int(amount*COIN), address)
|
||||
|
||||
def txi(self, txo):
|
||||
return Transaction.input_class.spend(txo)
|
||||
return Input.spend(txo)
|
||||
|
||||
async def test_large_tx_doesnt_hit_variable_limits(self):
|
||||
# SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html
|
||||
|
@ -408,9 +408,9 @@ class TestUpgrade(AsyncioTestCase):
|
|||
return [col[0] for col in conn.execute(sql).fetchall()]
|
||||
|
||||
async def test_reset_on_version_change(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(self.path),
|
||||
'headers': MainNetLedger.headers_class(':memory:')
|
||||
self.ledger = Ledger({
|
||||
'db': Database(self.path),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
|
||||
# initial open, pre-version enabled db
|
||||
|
|
|
@ -2,10 +2,7 @@ import os
|
|||
from binascii import hexlify
|
||||
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
from lbry.wallet.account import Account
|
||||
from lbry.wallet.transaction import Transaction, Output, Input
|
||||
from lbry.wallet.ledger import MainNetLedger
|
||||
from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Database, Headers
|
||||
|
||||
from tests.unit.wallet.test_transaction import get_transaction, get_output
|
||||
from tests.unit.wallet.test_headers import HEADERS, block_bytes
|
||||
|
@ -40,9 +37,9 @@ class MockNetwork:
|
|||
class LedgerTestCase(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(':memory:'),
|
||||
'headers': MainNetLedger.headers_class(':memory:')
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
self.account = Account.generate(self.ledger, Wallet(), "lbryum")
|
||||
await self.ledger.db.open()
|
||||
|
@ -76,7 +73,7 @@ class LedgerTestCase(AsyncioTestCase):
|
|||
class TestSynchronization(LedgerTestCase):
|
||||
|
||||
async def test_update_history(self):
|
||||
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
||||
account = Account.generate(self.ledger, Wallet(), "torba")
|
||||
address = await account.receiving.get_or_create_usable_address()
|
||||
address_details = await self.ledger.db.get_address(address=address)
|
||||
self.assertIsNone(address_details['history'])
|
||||
|
|
|
@ -3,9 +3,7 @@ from binascii import unhexlify
|
|||
from lbry.testcase import AsyncioTestCase
|
||||
from lbry.wallet.client.constants import CENT, NULL_HASH32
|
||||
|
||||
from lbry.wallet.ledger import MainNetLedger
|
||||
from lbry.wallet.transaction import Transaction, Input, Output
|
||||
|
||||
from lbry.wallet import Ledger, Database, Headers, Transaction, Input, Output
|
||||
from lbry.schema.claim import Claim
|
||||
|
||||
|
||||
|
@ -110,9 +108,9 @@ class TestValidatingOldSignatures(AsyncioTestCase):
|
|||
))
|
||||
channel = channel_tx.outputs[0]
|
||||
|
||||
ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(':memory:'),
|
||||
'headers': MainNetLedger.headers_class(':memory:')
|
||||
ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
|
||||
self.assertTrue(stream.is_signed_by(channel, ledger))
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from lbry.wallet.script import OutputScript
|
||||
import unittest
|
||||
from binascii import hexlify, unhexlify
|
||||
|
||||
from lbry.wallet.client.bcd_data_stream import BCDataStream
|
||||
from lbry.wallet.client.basescript import Template, ParseError, tokenize, push_data
|
||||
from lbry.wallet.client.basescript import PUSH_SINGLE, PUSH_INTEGER, PUSH_MANY, OP_HASH160, OP_EQUAL
|
||||
from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript
|
||||
from lbry.wallet.bcd_data_stream import BCDataStream
|
||||
from lbry.wallet.script import (
|
||||
InputScript, OutputScript, Template, ParseError, tokenize, push_data,
|
||||
PUSH_SINGLE, PUSH_INTEGER, PUSH_MANY, OP_HASH160, OP_EQUAL
|
||||
)
|
||||
|
||||
|
||||
def parse(opcodes, source):
|
||||
|
@ -102,12 +102,12 @@ class TestRedeemPubKeyHash(unittest.TestCase):
|
|||
|
||||
def redeem_pubkey_hash(self, sig, pubkey):
|
||||
# this checks that factory function correctly sets up the script
|
||||
src1 = BaseInputScript.redeem_pubkey_hash(unhexlify(sig), unhexlify(pubkey))
|
||||
src1 = InputScript.redeem_pubkey_hash(unhexlify(sig), unhexlify(pubkey))
|
||||
self.assertEqual(src1.template.name, 'pubkey_hash')
|
||||
self.assertEqual(hexlify(src1.values['signature']), sig)
|
||||
self.assertEqual(hexlify(src1.values['pubkey']), pubkey)
|
||||
# now we test that it will round trip
|
||||
src2 = BaseInputScript(src1.source)
|
||||
src2 = InputScript(src1.source)
|
||||
self.assertEqual(src2.template.name, 'pubkey_hash')
|
||||
self.assertEqual(hexlify(src2.values['signature']), sig)
|
||||
self.assertEqual(hexlify(src2.values['pubkey']), pubkey)
|
||||
|
@ -130,7 +130,7 @@ class TestRedeemScriptHash(unittest.TestCase):
|
|||
|
||||
def redeem_script_hash(self, sigs, pubkeys):
|
||||
# this checks that factory function correctly sets up the script
|
||||
src1 = BaseInputScript.redeem_script_hash(
|
||||
src1 = InputScript.redeem_script_hash(
|
||||
[unhexlify(sig) for sig in sigs],
|
||||
[unhexlify(pubkey) for pubkey in pubkeys]
|
||||
)
|
||||
|
@ -141,7 +141,7 @@ class TestRedeemScriptHash(unittest.TestCase):
|
|||
self.assertEqual(subscript1.values['signatures_count'], len(sigs))
|
||||
self.assertEqual(subscript1.values['pubkeys_count'], len(pubkeys))
|
||||
# now we test that it will round trip
|
||||
src2 = BaseInputScript(src1.source)
|
||||
src2 = InputScript(src1.source)
|
||||
subscript2 = src2.values['script']
|
||||
self.assertEqual(src2.template.name, 'script_hash')
|
||||
self.assertListEqual([hexlify(v) for v in src2.values['signatures']], sigs)
|
||||
|
@ -183,11 +183,11 @@ class TestPayPubKeyHash(unittest.TestCase):
|
|||
|
||||
def pay_pubkey_hash(self, pubkey_hash):
|
||||
# this checks that factory function correctly sets up the script
|
||||
src1 = BaseOutputScript.pay_pubkey_hash(unhexlify(pubkey_hash))
|
||||
src1 = OutputScript.pay_pubkey_hash(unhexlify(pubkey_hash))
|
||||
self.assertEqual(src1.template.name, 'pay_pubkey_hash')
|
||||
self.assertEqual(hexlify(src1.values['pubkey_hash']), pubkey_hash)
|
||||
# now we test that it will round trip
|
||||
src2 = BaseOutputScript(src1.source)
|
||||
src2 = OutputScript(src1.source)
|
||||
self.assertEqual(src2.template.name, 'pay_pubkey_hash')
|
||||
self.assertEqual(hexlify(src2.values['pubkey_hash']), pubkey_hash)
|
||||
return hexlify(src1.source)
|
||||
|
@ -203,11 +203,11 @@ class TestPayScriptHash(unittest.TestCase):
|
|||
|
||||
def pay_script_hash(self, script_hash):
|
||||
# this checks that factory function correctly sets up the script
|
||||
src1 = BaseOutputScript.pay_script_hash(unhexlify(script_hash))
|
||||
src1 = OutputScript.pay_script_hash(unhexlify(script_hash))
|
||||
self.assertEqual(src1.template.name, 'pay_script_hash')
|
||||
self.assertEqual(hexlify(src1.values['script_hash']), script_hash)
|
||||
# now we test that it will round trip
|
||||
src2 = BaseOutputScript(src1.source)
|
||||
src2 = OutputScript(src1.source)
|
||||
self.assertEqual(src2.template.name, 'pay_script_hash')
|
||||
self.assertEqual(hexlify(src2.values['script_hash']), script_hash)
|
||||
return hexlify(src1.source)
|
||||
|
|
|
@ -4,10 +4,7 @@ from itertools import cycle
|
|||
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
from lbry.wallet.client.constants import CENT, COIN, NULL_HASH32
|
||||
from lbry.wallet.client.wallet import Wallet
|
||||
|
||||
from lbry.wallet.ledger import MainNetLedger
|
||||
from lbry.wallet.transaction import Transaction, Output, Input
|
||||
from lbry.wallet import Wallet, Account, Ledger, Database, Headers, Transaction, Output, Input
|
||||
|
||||
|
||||
NULL_HASH = b'\x00'*32
|
||||
|
@ -40,9 +37,9 @@ def get_claim_transaction(claim_name, claim=b''):
|
|||
class TestSizeAndFeeEstimation(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(':memory:'),
|
||||
'headers': MainNetLedger.headers_class(':memory:')
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
await self.ledger.db.open()
|
||||
|
||||
|
@ -266,9 +263,9 @@ class TestTransactionSerialization(unittest.TestCase):
|
|||
class TestTransactionSigning(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(':memory:'),
|
||||
'headers': MainNetLedger.headers_class(':memory:')
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
await self.ledger.db.open()
|
||||
|
||||
|
@ -276,7 +273,7 @@ class TestTransactionSigning(AsyncioTestCase):
|
|||
await self.ledger.db.close()
|
||||
|
||||
async def test_sign(self):
|
||||
account = self.ledger.account_class.from_dict(
|
||||
account = Account.from_dict(
|
||||
self.ledger, Wallet(), {
|
||||
"seed":
|
||||
"carbon smart garage balance margin twelve chest sword toas"
|
||||
|
@ -305,12 +302,12 @@ class TestTransactionSigning(AsyncioTestCase):
|
|||
class TransactionIOBalancing(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(':memory:'),
|
||||
'headers': MainNetLedger.headers_class(':memory:')
|
||||
self.ledger = Ledger({
|
||||
'db': Database(':memory:'),
|
||||
'headers': Headers(':memory:')
|
||||
})
|
||||
await self.ledger.db.open()
|
||||
self.account = self.ledger.account_class.from_dict(
|
||||
self.account = Account.from_dict(
|
||||
self.ledger, Wallet(), {
|
||||
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||
"toast envelope bottom stomach absent"
|
||||
|
@ -328,7 +325,7 @@ class TransactionIOBalancing(AsyncioTestCase):
|
|||
return get_output(int(amount*COIN), address or next(self.hash_cycler))
|
||||
|
||||
def txi(self, txo):
|
||||
return Transaction.input_class.spend(txo)
|
||||
return Input.spend(txo)
|
||||
|
||||
def tx(self, inputs, outputs):
|
||||
return Transaction.create(inputs, outputs, [self.account], self.account)
|
||||
|
|
|
@ -3,18 +3,18 @@ from binascii import hexlify
|
|||
|
||||
from unittest import TestCase, mock
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
|
||||
from lbry.wallet.ledger import MainNetLedger, RegTestLedger
|
||||
from lbry.wallet.client.basemanager import BaseWalletManager
|
||||
from lbry.wallet.client.wallet import Wallet, WalletStorage, TimestampedPreferences
|
||||
from lbry.wallet import (
|
||||
Ledger, RegTestLedger, WalletManager, Account,
|
||||
Wallet, WalletStorage, TimestampedPreferences
|
||||
)
|
||||
|
||||
|
||||
class TestWalletCreation(AsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.manager = BaseWalletManager()
|
||||
self.manager = WalletManager()
|
||||
config = {'data_path': '/tmp/wallet'}
|
||||
self.main_ledger = self.manager.get_or_create_ledger(MainNetLedger.get_id(), config)
|
||||
self.main_ledger = self.manager.get_or_create_ledger(Ledger.get_id(), config)
|
||||
self.test_ledger = self.manager.get_or_create_ledger(RegTestLedger.get_id(), config)
|
||||
|
||||
def test_create_wallet_and_accounts(self):
|
||||
|
@ -66,7 +66,7 @@ class TestWalletCreation(AsyncioTestCase):
|
|||
)
|
||||
self.assertEqual(len(wallet.accounts), 1)
|
||||
account = wallet.default_account
|
||||
self.assertIsInstance(account, MainNetLedger.account_class)
|
||||
self.assertIsInstance(account, Account)
|
||||
self.maxDiff = None
|
||||
self.assertDictEqual(wallet_dict, wallet.to_dict())
|
||||
|
||||
|
@ -75,9 +75,9 @@ class TestWalletCreation(AsyncioTestCase):
|
|||
self.assertEqual(decrypted['accounts'][0]['name'], 'An Account')
|
||||
|
||||
def test_read_write(self):
|
||||
manager = BaseWalletManager()
|
||||
manager = WalletManager()
|
||||
config = {'data_path': '/tmp/wallet'}
|
||||
ledger = manager.get_or_create_ledger(MainNetLedger.get_id(), config)
|
||||
ledger = manager.get_or_create_ledger(Ledger.get_id(), config)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
|
||||
wallet_file.write(b'{"version": 1}')
|
||||
|
|
Loading…
Reference in a new issue