wip lbry.wallet

commit 533f31cc89 (parent: fef09c1773)
8 changed files with 800 additions and 746 deletions
@@ -1,17 +0,0 @@
-__node_daemon__ = 'lbrycrdd'
-__node_cli__ = 'lbrycrd-cli'
-__node_bin__ = ''
-__node_url__ = (
-    'https://github.com/lbryio/lbrycrd/releases/download/v0.17.4.4/lbrycrd-linux-1744.zip'
-)
-__spvserver__ = 'lbry.wallet.server.coin.LBCRegTest'
-
-from .bip32 import PubKey
-from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK
-from .manager import WalletManager
-from .network import Network
-from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent
-from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic
-from .transaction import Transaction, Output, Input
-from .script import OutputScript, InputScript
-from .header import Headers
@@ -2,7 +2,6 @@ import os
 import time
 import json
 import logging
 import typing
 import asyncio
 import random
 from functools import partial
@@ -16,14 +15,15 @@ import ecdsa
 from lbry.error import InvalidPasswordError
 from lbry.crypto.crypt import aes_encrypt, aes_decrypt
 
-from .bip32 import PrivateKey, PubKey, from_extended_key_string
-from .mnemonic import Mnemonic
-from .constants import COIN, CLAIM_TYPES, TXO_TYPES
-from .transaction import Transaction, Input, Output
+from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string
+from lbry.constants import COIN
+from lbry.blockchain.transaction import Transaction, Input, Output
+from lbry.blockchain.ledger import Ledger
+from lbry.db import Database
+from lbry.db.constants import CLAIM_TYPE_CODES, TXO_TYPES
+
+from .mnemonic import Mnemonic
 
 if typing.TYPE_CHECKING:
-    from .ledger import Ledger
     from .wallet import Wallet
 
 log = logging.getLogger(__name__)
@@ -71,12 +71,12 @@ class AddressManager:
     def to_dict_instance(self) -> Optional[dict]:
         raise NotImplementedError
 
-    def _query_addresses(self, **constraints):
-        return self.account.ledger.db.get_addresses(
-            accounts=[self.account],
+    async def _query_addresses(self, **constraints):
+        return (await self.account.db.get_addresses(
+            account=self.account,
             chain=self.chain_number,
             **constraints
-        )
+        ))[0]
 
     def get_private_key(self, index: int) -> PrivateKey:
         raise NotImplementedError
@@ -166,22 +166,22 @@ class HierarchicalDeterministic(AddressManager):
             start = addresses[0]['pubkey'].n+1 if addresses else 0
             end = start + (self.gap - existing_gap)
             new_keys = await self._generate_keys(start, end-1)
-            await self.account.ledger.announce_addresses(self, new_keys)
+            #await self.account.ledger.announce_addresses(self, new_keys)
             return new_keys
 
     async def _generate_keys(self, start: int, end: int) -> List[str]:
         if not self.address_generator_lock.locked():
             raise RuntimeError('Should not be called outside of address_generator_lock.')
         keys = [self.public_key.child(index) for index in range(start, end+1)]
-        await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
+        await self.account.db.add_keys(self.account, self.chain_number, keys)
         return [key.address for key in keys]
 
-    def get_address_records(self, only_usable: bool = False, **constraints):
+    async def get_address_records(self, only_usable: bool = False, **constraints):
         if only_usable:
             constraints['used_times__lt'] = self.maximum_uses_per_address
         if 'order_by' not in constraints:
             constraints['order_by'] = "used_times asc, n asc"
-        return self._query_addresses(**constraints)
+        return await self._query_addresses(**constraints)
 
 
 class SingleKey(AddressManager):
@@ -213,14 +213,14 @@ class SingleKey(AddressManager):
         async with self.address_generator_lock:
             exists = await self.get_address_records()
             if not exists:
-                await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
+                await self.account.db.add_keys(self.account, self.chain_number, [self.public_key])
                 new_keys = [self.public_key.address]
-                await self.account.ledger.announce_addresses(self, new_keys)
+                #await self.account.ledger.announce_addresses(self, new_keys)
                 return new_keys
         return []
 
-    def get_address_records(self, only_usable: bool = False, **constraints):
-        return self._query_addresses(**constraints)
+    async def get_address_records(self, only_usable: bool = False, **constraints):
+        return await self._query_addresses(**constraints)
 
 
 class Account:
@@ -233,12 +233,12 @@ class Account:
         HierarchicalDeterministic.name: HierarchicalDeterministic,
     }
 
-    def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str,
+    def __init__(self, ledger: 'Ledger', db: 'Database', name: str,
                  seed: str, private_key_string: str, encrypted: bool,
                  private_key: Optional[PrivateKey], public_key: PubKey,
                  address_generator: dict, modified_on: float, channel_keys: dict) -> None:
         self.ledger = ledger
-        self.wallet = wallet
+        self.db = db
         self.id = public_key.address
         self.name = name
         self.seed = seed
@@ -253,8 +253,6 @@ class Account:
         self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
         self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
         self.channel_keys = channel_keys
-        ledger.add_account(self)
-        wallet.add_account(self)
 
     def get_init_vector(self, key) -> Optional[bytes]:
         init_vector = self.init_vectors.get(key, None)
@@ -263,9 +261,9 @@ class Account:
         return init_vector
 
     @classmethod
-    def generate(cls, ledger: 'Ledger', wallet: 'Wallet',
+    def generate(cls, ledger: 'Ledger', db: 'Database',
                  name: str = None, address_generator: dict = None):
-        return cls.from_dict(ledger, wallet, {
+        return cls.from_dict(ledger, db, {
             'name': name,
             'seed': cls.mnemonic_class().make_seed(),
             'address_generator': address_generator or {}
@@ -297,14 +295,14 @@ class Account:
         return seed, private_key, public_key
 
     @classmethod
-    def from_dict(cls, ledger: 'Ledger', wallet: 'Wallet', d: dict):
+    def from_dict(cls, ledger: 'Ledger', db: 'Database', d: dict):
         seed, private_key, public_key = cls.keys_from_dict(ledger, d)
         name = d.get('name')
         if not name:
             name = f'Account #{public_key.address}'
         return cls(
             ledger=ledger,
-            wallet=wallet,
+            db=db,
             name=name,
             seed=seed,
             private_key_string=d.get('private_key', ''),
@@ -328,7 +326,6 @@ class Account:
         if seed:
             seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed'))
         d = {
-            'ledger': self.ledger.get_id(),
             'name': self.name,
             'seed': seed,
             'encrypted': bool(self.encrypted or encrypt_password),
@@ -365,7 +362,6 @@ class Account:
         details = {
             'id': self.id,
             'name': self.name,
-            'ledger': self.ledger.get_id(),
             'coins': round(satoshis/COIN, 2),
             'satoshis': satoshis,
             'encrypted': self.encrypted,
@@ -436,15 +432,18 @@ class Account:
             addresses.extend(new_addresses)
         return addresses
 
+    async def get_address_records(self, **constraints):
+        return await self.db.get_addresses(account=self, **constraints)
+
     async def get_addresses(self, **constraints) -> List[str]:
-        rows = await self.ledger.db.select_addresses([text('account_address.address')], accounts=[self], **constraints)
+        rows, _ = await self.get_address_records(cols=['account_address.address'], **constraints)
         return [r['address'] for r in rows]
 
-    def get_address_records(self, **constraints):
-        return self.ledger.db.get_addresses(accounts=[self], **constraints)
-
-    def get_address_count(self, **constraints):
-        return self.ledger.db.get_address_count(accounts=[self], **constraints)
+    async def get_valid_receiving_address(self, default_address: str) -> str:
+        if default_address is None:
+            return await self.receiving.get_or_create_usable_address()
+        self.ledger.valid_address_or_error(default_address)
+        return default_address
 
     def get_private_key(self, chain: int, index: int) -> PrivateKey:
         assert not self.encrypted, "Cannot get private key on encrypted wallet account."
@@ -459,7 +458,7 @@ class Account:
         if confirmations > 0:
             height = self.ledger.headers.height - (confirmations-1)
             constraints.update({'height__lte': height, 'height__gt': 0})
-        return self.ledger.db.get_balance(accounts=[self], **constraints)
+        return self.db.get_balance(account=self, **constraints)
 
     async def get_max_gap(self):
         change_gap = await self.change.get_max_gap()
@@ -469,24 +468,6 @@ class Account:
             'max_receiving_gap': receiving_gap,
         }
 
-    def get_txos(self, **constraints):
-        return self.ledger.get_txos(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_txo_count(self, **constraints):
-        return self.ledger.get_txo_count(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_utxos(self, **constraints):
-        return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_utxo_count(self, **constraints):
-        return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_transactions(self, **constraints):
-        return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_transaction_count(self, **constraints):
-        return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)
-
     async def fund(self, to_account, amount=None, everything=False,
                    outputs=1, broadcast=False, **constraints):
         assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
@@ -551,9 +532,9 @@ class Account:
         self.wallet.save()
 
     async def save_max_gap(self):
+        gap_changed = False
         if issubclass(self.address_generator, HierarchicalDeterministic):
             gap = await self.get_max_gap()
-            gap_changed = False
             new_receiving_gap = max(20, gap['max_receiving_gap'] + 1)
             if self.receiving.gap != new_receiving_gap:
                 self.receiving.gap = new_receiving_gap
@@ -562,8 +543,10 @@ class Account:
             if self.change.gap != new_change_gap:
                 self.change.gap = new_change_gap
                 gap_changed = True
-        if gap_changed:
-            self.wallet.save()
+        return gap_changed
+
+    def get_support_summary(self):
+        return self.db.get_supports_summary(account=self)
 
     async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False):
         tips_balance, supports_balance, claims_balance = 0, 0, 0
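Note on save_max_gap(): since Account no longer holds a wallet reference, it now reports whether a gap changed instead of saving the wallet itself. A hypothetical caller at the wallet layer (the wallet variable is whatever owns the account now) would do something like:

    # sketch: the persistence decision moved out of Account
    if await account.save_max_gap():
        wallet.save()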
@@ -571,7 +554,7 @@ class Account:
             include_claims=True)
         total = await get_total_balance()
         if reserved_subtotals:
-            claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPES)
+            claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPE_CODES)
             for txo in await self.get_support_summary():
                 if confirmations > 0 and not 0 < txo.tx_ref.height <= self.ledger.headers.height - (confirmations - 1):
                     continue
@@ -594,49 +577,3 @@ class Account:
                 'tips': tips_balance
             } if reserved_subtotals else None
         }
-
-    def get_transaction_history(self, **constraints):
-        return self.ledger.get_transaction_history(
-            wallet=self.wallet, accounts=[self], **constraints
-        )
-
-    def get_transaction_history_count(self, **constraints):
-        return self.ledger.get_transaction_history_count(
-            wallet=self.wallet, accounts=[self], **constraints
-        )
-
-    def get_claims(self, **constraints):
-        return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_claim_count(self, **constraints):
-        return self.ledger.get_claim_count(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_streams(self, **constraints):
-        return self.ledger.get_streams(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_stream_count(self, **constraints):
-        return self.ledger.get_stream_count(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_channels(self, **constraints):
-        return self.ledger.get_channels(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_channel_count(self, **constraints):
-        return self.ledger.get_channel_count(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_collections(self, **constraints):
-        return self.ledger.get_collections(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_collection_count(self, **constraints):
-        return self.ledger.get_collection_count(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_supports(self, **constraints):
-        return self.ledger.get_supports(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_support_count(self, **constraints):
-        return self.ledger.get_support_count(wallet=self.wallet, accounts=[self], **constraints)
-
-    def get_support_summary(self):
-        return self.ledger.db.get_supports_summary(wallet=self.wallet, accounts=[self])
-
-    async def release_all_outputs(self):
-        await self.ledger.db.release_all_outputs(self)
@@ -1,18 +1,32 @@
 from random import Random
 from typing import List
 
-from lbry.wallet.transaction import OutputEffectiveAmountEstimator
+from lbry.blockchain.transaction import Input, Output
 
 MAXIMUM_TRIES = 100000
 
-STRATEGIES = []
+COIN_SELECTION_STRATEGIES = []
 
 
 def strategy(method):
-    STRATEGIES.append(method.__name__)
+    COIN_SELECTION_STRATEGIES.append(method.__name__)
     return method
 
 
+class OutputEffectiveAmountEstimator:
+
+    __slots__ = 'txo', 'txi', 'fee', 'effective_amount'
+
+    def __init__(self, ledger, txo: Output) -> None:
+        self.txo = txo
+        self.txi = Input.spend(txo)
+        self.fee: int = self.txi.get_fee(ledger)
+        self.effective_amount: int = txo.amount - self.fee
+
+    def __lt__(self, other):
+        return self.effective_amount < other.effective_amount
+
+
 class CoinSelector:
 
     def __init__(self, target: int, cost_of_change: int, seed: str = None) -> None:
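OutputEffectiveAmountEstimator, moved here from the transaction module, captures what coin selection actually ranks: a UTXO is worth its amount minus the fee of the input that would spend it, and __lt__ sorts candidates by that value. A minimal illustration with made-up numbers (a real fee comes from Input.spend(txo).get_fee(ledger)):

    amount = 100_000_000     # hypothetical 1.0 LBC output, in dewies
    spend_fee = 150_000      # hypothetical fee added by the spending input
    effective_amount = amount - spend_fee  # 99_850_000 actually usable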
@@ -10,16 +10,14 @@ from typing import List, Type, MutableSequence, MutableMapping, Optional
 from lbry.error import KeyFeeAboveMaxAllowedError
 from lbry.conf import Config
 
-from .dewies import dewies_to_lbc
 from .account import Account
-from .ledger import Ledger, LedgerRegistry
-from .transaction import Transaction, Output
-from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
-from .rpc.jsonrpc import CodeMessageError
 
 if typing.TYPE_CHECKING:
+    from lbry.blockchain.dewies import dewies_to_lbc
+    from lbry.blockchain.ledger import Ledger
+    from lbry.db import Database
     from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
+    from lbry.blockchain.ledger import Ledger
+    from lbry.blockchain.transaction import Transaction, Output
 
+from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
 
 log = logging.getLogger(__name__)
@@ -27,35 +25,42 @@ log = logging.getLogger(__name__)
 
 class WalletManager:
 
-    def __init__(self, wallets: MutableSequence[Wallet] = None,
+    def __init__(self, ledger: Ledger, db: Database,
+                 wallets: MutableSequence[Wallet] = None,
                  ledgers: MutableMapping[Type[Ledger], Ledger] = None) -> None:
+        self.ledger = ledger
+        self.db = db
         self.wallets = wallets or []
         self.ledgers = ledgers or {}
         self.running = False
         self.config: Optional[Config] = None
 
-    @classmethod
-    def from_config(cls, config: dict) -> 'WalletManager':
-        manager = cls()
-        for ledger_id, ledger_config in config.get('ledgers', {}).items():
-            manager.get_or_create_ledger(ledger_id, ledger_config)
-        for wallet_path in config.get('wallets', []):
-            wallet_storage = WalletStorage(wallet_path)
-            wallet = Wallet.from_storage(wallet_storage, manager)
-            manager.wallets.append(wallet)
-        return manager
+    async def open(self):
+        conf = self.ledger.conf
 
-    def get_or_create_ledger(self, ledger_id, ledger_config=None):
-        ledger_class = LedgerRegistry.get_ledger_class(ledger_id)
-        ledger = self.ledgers.get(ledger_class)
-        if ledger is None:
-            ledger = ledger_class(ledger_config or {})
-            self.ledgers[ledger_class] = ledger
-        return ledger
+        wallets_directory = os.path.join(conf.wallet_dir, 'wallets')
+        if not os.path.exists(wallets_directory):
+            os.mkdir(wallets_directory)
+
+        for wallet_file in conf.wallets:
+            wallet_path = os.path.join(wallets_directory, wallet_file)
+            wallet_storage = WalletStorage(wallet_path)
+            wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage)
+            self.wallets.append(wallet)
+
+        self.ledger.coin_selection_strategy = self.ledger.conf.coin_selection_strategy
+        default_wallet = self.default_wallet
+        if default_wallet.default_account is None:
+            log.info('Wallet at %s is empty, generating a default account.', default_wallet.id)
+            default_wallet.generate_account()
+            default_wallet.save()
+        if default_wallet.is_locked and default_wallet.preferences.get(ENCRYPT_ON_DISK) is None:
+            default_wallet.preferences[ENCRYPT_ON_DISK] = True
+            default_wallet.save()
 
     def import_wallet(self, path):
         storage = WalletStorage(path)
-        wallet = Wallet.from_storage(storage, self)
+        wallet = Wallet.from_storage(self.ledger, self.db, storage)
         self.wallets.append(wallet)
         return wallet
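With from_config() and get_or_create_ledger() gone, the manager is now constructed around a single Ledger and Database and then opened. A rough sketch of the new startup path, assuming the caller has already built ledger and db:

    manager = WalletManager(ledger, db)
    await manager.open()  # loads conf.wallets, generating a default account if the wallet is empty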
@@ -104,123 +109,9 @@ class WalletManager:
             return 0
         return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts)
 
-    @property
-    def ledger(self) -> Ledger:
-        return self.default_account.ledger
-
-    @property
-    def db(self) -> 'Database':
-        return self.ledger.db
-
     def check_locked(self):
         return self.default_wallet.is_locked
 
-    @staticmethod
-    def migrate_lbryum_to_torba(path):
-        if not os.path.exists(path):
-            return None, None
-        with open(path, 'r') as f:
-            unmigrated_json = f.read()
-        unmigrated = json.loads(unmigrated_json)
-        # TODO: After several public releases of new torba based wallet, we can delete
-        #       this lbryum->torba conversion code and require that users who still
-        #       have old structured wallets install one of the earlier releases that
-        #       still has the below conversion code.
-        if 'master_public_keys' not in unmigrated:
-            return None, None
-        total = unmigrated.get('addr_history')
-        receiving_addresses, change_addresses = set(), set()
-        for _, unmigrated_account in unmigrated.get('accounts', {}).items():
-            receiving_addresses.update(map(unhexlify, unmigrated_account.get('receiving', [])))
-            change_addresses.update(map(unhexlify, unmigrated_account.get('change', [])))
-        log.info("Wallet migrator found %s receiving addresses and %s change addresses. %s in total on history.",
-                 len(receiving_addresses), len(change_addresses), len(total))
-
-        migrated_json = json.dumps({
-            'version': 1,
-            'name': 'My Wallet',
-            'accounts': [{
-                'version': 1,
-                'name': 'Main Account',
-                'ledger': 'lbc_mainnet',
-                'encrypted': unmigrated['use_encryption'],
-                'seed': unmigrated['seed'],
-                'seed_version': unmigrated['seed_version'],
-                'private_key': unmigrated['master_private_keys']['x/'],
-                'public_key': unmigrated['master_public_keys']['x/'],
-                'certificates': unmigrated.get('claim_certificates', {}),
-                'address_generator': {
-                    'name': 'deterministic-chain',
-                    'receiving': {'gap': 20, 'maximum_uses_per_address': 1},
-                    'change': {'gap': 6, 'maximum_uses_per_address': 1}
-                }
-            }]
-        }, indent=4, sort_keys=True)
-        mode = os.stat(path).st_mode
-        i = 1
-        backup_path_template = os.path.join(os.path.dirname(path), "old_lbryum_wallet") + "_%i"
-        while os.path.isfile(backup_path_template % i):
-            i += 1
-        os.rename(path, backup_path_template % i)
-        temp_path = f"{path}.tmp.{os.getpid()}"
-        with open(temp_path, "w") as f:
-            f.write(migrated_json)
-            f.flush()
-            os.fsync(f.fileno())
-        os.rename(temp_path, path)
-        os.chmod(path, mode)
-        return receiving_addresses, change_addresses
-
-    @classmethod
-    async def from_lbrynet_config(cls, config: Config):
-
-        ledger_id = {
-            'lbrycrd_main': 'lbc_mainnet',
-            'lbrycrd_testnet': 'lbc_testnet',
-            'lbrycrd_regtest': 'lbc_regtest'
-        }[config.blockchain_name]
-
-        ledger_config = {
-            'auto_connect': True,
-            'default_servers': config.lbryum_servers,
-            'data_path': config.wallet_dir,
-        }
-
-        wallets_directory = os.path.join(config.wallet_dir, 'wallets')
-        if not os.path.exists(wallets_directory):
-            os.mkdir(wallets_directory)
-
-        receiving_addresses, change_addresses = cls.migrate_lbryum_to_torba(
-            os.path.join(wallets_directory, 'default_wallet')
-        )
-
-        manager = cls.from_config({
-            'ledgers': {ledger_id: ledger_config},
-            'wallets': [
-                os.path.join(wallets_directory, wallet_file) for wallet_file in config.wallets
-            ]
-        })
-        manager.config = config
-        ledger = manager.get_or_create_ledger(ledger_id)
-        ledger.coin_selection_strategy = config.coin_selection_strategy
-        default_wallet = manager.default_wallet
-        if default_wallet.default_account is None:
-            log.info('Wallet at %s is empty, generating a default account.', default_wallet.id)
-            default_wallet.generate_account(ledger)
-            default_wallet.save()
-        if default_wallet.is_locked and default_wallet.preferences.get(ENCRYPT_ON_DISK) is None:
-            default_wallet.preferences[ENCRYPT_ON_DISK] = True
-            default_wallet.save()
-        if receiving_addresses or change_addresses:
-            if not os.path.exists(ledger.path):
-                os.mkdir(ledger.path)
-            await ledger.db.open()
-            try:
-                await manager._migrate_addresses(receiving_addresses, change_addresses)
-            finally:
-                await ledger.db.close()
-        return manager
-
     async def reset(self):
         self.ledger.config = {
             'auto_connect': True,
@@ -230,24 +121,6 @@ class WalletManager:
         await self.ledger.stop()
         await self.ledger.start()
 
-    async def _migrate_addresses(self, receiving_addresses: set, change_addresses: set):
-        async with self.default_account.receiving.address_generator_lock:
-            migrated_receiving = set(await self.default_account.receiving._generate_keys(0, len(receiving_addresses)))
-        async with self.default_account.change.address_generator_lock:
-            migrated_change = set(await self.default_account.change._generate_keys(0, len(change_addresses)))
-        receiving_addresses = set(map(self.default_account.ledger.public_key_to_address, receiving_addresses))
-        change_addresses = set(map(self.default_account.ledger.public_key_to_address, change_addresses))
-        if not any(change_addresses.difference(migrated_change)):
-            log.info("Successfully migrated %s change addresses.", len(change_addresses))
-        else:
-            log.warning("Failed to migrate %s change addresses!",
-                        len(set(change_addresses).difference(set(migrated_change))))
-        if not any(receiving_addresses.difference(migrated_receiving)):
-            log.info("Successfully migrated %s receiving addresses.", len(receiving_addresses))
-        else:
-            log.warning("Failed to migrate %s receiving addresses!",
-                        len(set(receiving_addresses).difference(set(migrated_receiving))))
-
     async def get_best_blockhash(self):
         if len(self.ledger.headers) <= 0:
             return self.ledger.genesis_hash
@@ -272,35 +145,3 @@ class WalletManager:
         await self.ledger.maybe_verify_transaction(tx, height, merkle)
         return tx
 
-    async def create_purchase_transaction(
-            self, accounts: List[Account], txo: Output, exchange: 'ExchangeRateManager',
-            override_max_key_fee=False):
-        fee = txo.claim.stream.fee
-        fee_amount = exchange.to_dewies(fee.currency, fee.amount)
-        if not override_max_key_fee and self.config.max_key_fee:
-            max_fee = self.config.max_key_fee
-            max_fee_amount = exchange.to_dewies(max_fee['currency'], Decimal(max_fee['amount']))
-            if max_fee_amount and fee_amount > max_fee_amount:
-                error_fee = f"{dewies_to_lbc(fee_amount)} LBC"
-                if fee.currency != 'LBC':
-                    error_fee += f" ({fee.amount} {fee.currency})"
-                error_max_fee = f"{dewies_to_lbc(max_fee_amount)} LBC"
-                if max_fee['currency'] != 'LBC':
-                    error_max_fee += f" ({max_fee['amount']} {max_fee['currency']})"
-                raise KeyFeeAboveMaxAllowedError(
-                    f"Purchase price of {error_fee} exceeds maximum "
-                    f"configured price of {error_max_fee}."
-                )
-        fee_address = fee.address or txo.get_address(self.ledger)
-        return await Transaction.purchase(
-            txo.claim_id, fee_amount, fee_address, accounts, accounts[0]
-        )
-
-    async def broadcast_or_release(self, tx, blocking=False):
-        try:
-            await self.ledger.broadcast(tx)
-            if blocking:
-                await self.ledger.wait(tx, timeout=None)
-        except:
-            await self.ledger.release_tx(tx)
-            raise
@@ -1,406 +0,0 @@
-import logging
-import asyncio
-import json
-from time import perf_counter
-from operator import itemgetter
-from typing import Dict, Optional, Tuple
-from binascii import hexlify
-
-from lbry import __version__
-from lbry.error import IncompatibleWalletServerError
-from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
-from lbry.wallet.stream import StreamController
-
-log = logging.getLogger(__name__)
-
-
-class ClientSession(BaseClientSession):
-    def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
-        self.network = network
-        self.server = server
-        super().__init__(*args, **kwargs)
-        self._on_disconnect_controller = StreamController()
-        self.on_disconnected = self._on_disconnect_controller.stream
-        self.framer.max_size = self.max_errors = 1 << 32
-        self.timeout = timeout
-        self.max_seconds_idle = timeout * 2
-        self.response_time: Optional[float] = None
-        self.connection_latency: Optional[float] = None
-        self._response_samples = 0
-        self.pending_amount = 0
-        self._on_connect_cb = on_connect_callback or (lambda: None)
-        self.trigger_urgent_reconnect = asyncio.Event()
-
-    @property
-    def available(self):
-        return not self.is_closing() and self.response_time is not None
-
-    @property
-    def server_address_and_port(self) -> Optional[Tuple[str, int]]:
-        if not self.transport:
-            return None
-        return self.transport.get_extra_info('peername')
-
-    async def send_timed_server_version_request(self, args=(), timeout=None):
-        timeout = timeout or self.timeout
-        log.debug("send version request to %s:%i", *self.server)
-        start = perf_counter()
-        result = await asyncio.wait_for(
-            super().send_request('server.version', args), timeout=timeout
-        )
-        current_response_time = perf_counter() - start
-        response_sum = (self.response_time or 0) * self._response_samples + current_response_time
-        self.response_time = response_sum / (self._response_samples + 1)
-        self._response_samples += 1
-        return result
-
-    async def send_request(self, method, args=()):
-        self.pending_amount += 1
-        log.debug("send %s%s to %s:%i", method, tuple(args), *self.server)
-        try:
-            if method == 'server.version':
-                return await self.send_timed_server_version_request(args, self.timeout)
-            request = asyncio.ensure_future(super().send_request(method, args))
-            while not request.done():
-                done, pending = await asyncio.wait([request], timeout=self.timeout)
-                if pending:
-                    log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
-                    if (perf_counter() - self.last_packet_received) < self.timeout:
-                        continue
-                    log.info("timeout sending %s to %s:%i", method, *self.server)
-                    raise asyncio.TimeoutError
-                if done:
-                    try:
-                        return request.result()
-                    except ConnectionResetError:
-                        log.error(
-                            "wallet server (%s) reset connection upon our %s request, json of %i args is %i bytes",
-                            self.server[0], method, len(args), len(json.dumps(args))
-                        )
-                        raise
-        except (RPCError, ProtocolError) as e:
-            log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
-                        *self.server, *e.args)
-            raise e
-        except ConnectionError:
-            log.warning("connection to %s:%i lost", *self.server)
-            self.synchronous_close()
-            raise
-        except asyncio.CancelledError:
-            log.info("cancelled sending %s to %s:%i", method, *self.server)
-            self.synchronous_close()
-            raise
-        finally:
-            self.pending_amount -= 1
-    async def ensure_session(self):
-        # Handles reconnecting and maintaining a session alive
-        # TODO: change to 'ping' on newer protocol (above 1.2)
-        retry_delay = default_delay = 1.0
-        while True:
-            try:
-                if self.is_closing():
-                    await self.create_connection(self.timeout)
-                    await self.ensure_server_version()
-                    self._on_connect_cb()
-                if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
-                    await self.ensure_server_version()
-                retry_delay = default_delay
-            except RPCError as e:
-                await self.close()
-                log.debug("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
-                retry_delay = 60 * 60
-            except IncompatibleWalletServerError:
-                await self.close()
-                retry_delay = 60 * 60
-                log.debug("Wallet server has an incompatible version, retrying in 1h: %s:%d", *self.server)
-            except (asyncio.TimeoutError, OSError):
-                await self.close()
-                retry_delay = min(60, retry_delay * 2)
-                log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
-            try:
-                await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
-            except asyncio.TimeoutError:
-                pass
-            finally:
-                self.trigger_urgent_reconnect.clear()
-
-    async def ensure_server_version(self, required=None, timeout=3):
-        required = required or self.network.PROTOCOL_VERSION
-        response = await asyncio.wait_for(
-            self.send_request('server.version', [__version__, required]), timeout=timeout
-        )
-        if tuple(int(piece) for piece in response[0].split(".")) < self.network.MINIMUM_REQUIRED:
-            raise IncompatibleWalletServerError(*self.server)
-        return response
-
-    async def create_connection(self, timeout=6):
-        connector = Connector(lambda: self, *self.server)
-        start = perf_counter()
-        await asyncio.wait_for(connector.create_connection(), timeout=timeout)
-        self.connection_latency = perf_counter() - start
-
-    async def handle_request(self, request):
-        controller = self.network.subscription_controllers[request.method]
-        controller.add(request.args)
-
-    def connection_lost(self, exc):
-        log.debug("Connection lost: %s:%d", *self.server)
-        super().connection_lost(exc)
-        self.response_time = None
-        self.connection_latency = None
-        self._response_samples = 0
-        self.pending_amount = 0
-        self._on_disconnect_controller.add(True)
-
-
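send_timed_server_version_request keeps an incremental running mean of response times, which SessionPool.fastest_session later uses for ranking. The update rule is the standard incremental mean, new_avg = (old_avg * n + sample) / (n + 1); a small sketch with hypothetical samples:

    avg, n = 0.0, 0
    for sample in (0.12, 0.08, 0.10):  # hypothetical response times, in seconds
        avg = (avg * n + sample) / (n + 1)
        n += 1
    # avg is now ~0.10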
-class Network:
-
-    PROTOCOL_VERSION = __version__
-    MINIMUM_REQUIRED = (0, 65, 0)
-
-    def __init__(self, ledger):
-        self.ledger = ledger
-        self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
-        self.client: Optional[ClientSession] = None
-        self.server_features = None
-        self._switch_task: Optional[asyncio.Task] = None
-        self.running = False
-        self.remote_height: int = 0
-        self._concurrency = asyncio.Semaphore(16)
-
-        self._on_connected_controller = StreamController()
-        self.on_connected = self._on_connected_controller.stream
-
-        self._on_header_controller = StreamController(merge_repeated_events=True)
-        self.on_header = self._on_header_controller.stream
-
-        self._on_status_controller = StreamController(merge_repeated_events=True)
-        self.on_status = self._on_status_controller.stream
-
-        self.subscription_controllers = {
-            'blockchain.headers.subscribe': self._on_header_controller,
-            'blockchain.address.subscribe': self._on_status_controller,
-        }
-
-    @property
-    def config(self):
-        return self.ledger.config
-
-    async def switch_forever(self):
-        while self.running:
-            if self.is_connected:
-                await self.client.on_disconnected.first
-                self.server_features = None
-                self.client = None
-                continue
-            self.client = await self.session_pool.wait_for_fastest_session()
-            log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
-            try:
-                self.server_features = await self.get_server_features()
-                self._update_remote_height((await self.subscribe_headers(),))
-                self._on_connected_controller.add(True)
-                log.info("Subscribed to headers: %s:%d", *self.client.server)
-            except (asyncio.TimeoutError, ConnectionError):
-                log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
-                self.client.synchronous_close()
-                self.server_features = None
-                self.client = None
-
-    async def start(self):
-        self.running = True
-        self._switch_task = asyncio.ensure_future(self.switch_forever())
-        # this may become unnecessary when there are no more bugs found,
-        # but for now it helps understanding log reports
-        self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
-        self.session_pool.start(self.config['default_servers'])
-        self.on_header.listen(self._update_remote_height)
-
-    async def stop(self):
-        if self.running:
-            self.running = False
-            self._switch_task.cancel()
-            self.session_pool.stop()
-
-    @property
-    def is_connected(self):
-        return self.client and not self.client.is_closing()
-
-    def rpc(self, list_or_method, args, restricted=True):
-        session = self.client if restricted else self.session_pool.fastest_session
-        if session and not session.is_closing():
-            return session.send_request(list_or_method, args)
-        else:
-            self.session_pool.trigger_nodelay_connect()
-            raise ConnectionError("Attempting to send rpc request when connection is not available.")
-
-    async def retriable_call(self, function, *args, **kwargs):
-        async with self._concurrency:
-            while self.running:
-                if not self.is_connected:
-                    log.warning("Wallet server unavailable, waiting for it to come back and retry.")
-                    await self.on_connected.first
-                await self.session_pool.wait_for_fastest_session()
-                try:
-                    return await function(*args, **kwargs)
-                except asyncio.TimeoutError:
-                    log.warning("Wallet server call timed out, retrying.")
-                except ConnectionError:
-                    pass
-        raise asyncio.CancelledError()  # if we got here, we are shutting down
-
-    def _update_remote_height(self, header_args):
-        self.remote_height = header_args[0]["height"]
-
-    def get_transaction(self, tx_hash, known_height=None):
-        # use any server if its old, otherwise restrict to who gave us the history
-        restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get', [hexlify(tx_hash[::-1]).decode()], restricted)
-
-    def get_transaction_and_merkle(self, tx_hash, known_height=None):
-        # use any server if its old, otherwise restrict to who gave us the history
-        restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.info', [hexlify(tx_hash[::-1]).decode()], restricted)
-
-    def get_transaction_height(self, tx_hash, known_height=None):
-        restricted = not known_height or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get_height', [hexlify(tx_hash[::-1]).decode()], restricted)
-
-    def get_merkle(self, tx_hash, height):
-        restricted = 0 > height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get_merkle', [hexlify(tx_hash[::-1]).decode(), height], restricted)
-
-    def get_headers(self, height, count=10000, b64=False):
-        restricted = height >= self.remote_height - 100
-        return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
-
-    # --- Subscribes, history and broadcasts are always aimed towards the master client directly
-    def get_history(self, address):
-        return self.rpc('blockchain.address.get_history', [address], True)
-
-    def broadcast(self, raw_transaction):
-        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
-
-    def subscribe_headers(self):
-        return self.rpc('blockchain.headers.subscribe', [True], True)
-
-    async def subscribe_address(self, address, *addresses):
-        addresses = list((address, ) + addresses)
-        try:
-            return await self.rpc('blockchain.address.subscribe', addresses, True)
-        except asyncio.TimeoutError:
-            log.warning(
-                "timed out subscribing to addresses from %s:%i",
-                *self.client.server_address_and_port
-            )
-            # abort and cancel, we can't lose a subscription, it will happen again on reconnect
-            if self.client:
-                self.client.abort()
-            raise asyncio.CancelledError()
-
-    def unsubscribe_address(self, address):
-        return self.rpc('blockchain.address.unsubscribe', [address], True)
-
-    def get_server_features(self):
-        return self.rpc('server.features', (), restricted=True)
-
-    def get_claims_by_ids(self, claim_ids):
-        return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
-
-    def resolve(self, urls):
-        return self.rpc('blockchain.claimtrie.resolve', urls)
-
-    def claim_search(self, **kwargs):
-        return self.rpc('blockchain.claimtrie.search', kwargs)
-
-
-class SessionPool:
-
-    def __init__(self, network: Network, timeout: float):
-        self.network = network
-        self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
-        self.timeout = timeout
-        self.new_connection_event = asyncio.Event()
-
-    @property
-    def online(self):
-        return any(not session.is_closing() for session in self.sessions)
-
-    @property
-    def available_sessions(self):
-        return (session for session in self.sessions if session.available)
-
-    @property
-    def fastest_session(self):
-        if not self.online:
-            return None
-        return min(
-            [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
-             for session in self.available_sessions] or [(0, None)],
-            key=itemgetter(0)
-        )[1]
-
-    def _get_session_connect_callback(self, session: ClientSession):
-        loop = asyncio.get_event_loop()
-
-        def callback():
-            duplicate_connections = [
-                s for s in self.sessions
-                if s is not session and s.server_address_and_port == session.server_address_and_port
-            ]
-            already_connected = None if not duplicate_connections else duplicate_connections[0]
-            if already_connected:
-                self.sessions.pop(session).cancel()
-                session.synchronous_close()
-                log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
-                          session.server[0], already_connected.server[0])
-                loop.call_later(3600, self._connect_session, session.server)
-                return
-            self.new_connection_event.set()
-            log.info("connected to %s:%i", *session.server)
-
-        return callback
-
-    def _connect_session(self, server: Tuple[str, int]):
-        session = None
-        for s in self.sessions:
-            if s.server == server:
-                session = s
-                break
-        if not session:
-            session = ClientSession(
-                network=self.network, server=server
-            )
-            session._on_connect_cb = self._get_session_connect_callback(session)
-        task = self.sessions.get(session, None)
-        if not task or task.done():
-            task = asyncio.create_task(session.ensure_session())
-            task.add_done_callback(lambda _: self.ensure_connections())
-            self.sessions[session] = task
-
-    def start(self, default_servers):
-        for server in default_servers:
-            self._connect_session(server)
-
-    def stop(self):
-        for session, task in self.sessions.items():
-            task.cancel()
-            session.synchronous_close()
-        self.sessions.clear()
-
-    def ensure_connections(self):
-        for session in self.sessions:
-            self._connect_session(session.server)
-
-    def trigger_nodelay_connect(self):
-        # used when other parts of the system sees we might have internet back
-        # bypasses the retry interval
-        for session in self.sessions:
-            session.trigger_urgent_reconnect.set()
-
-    async def wait_for_fastest_session(self):
-        while not self.fastest_session:
-            self.trigger_nodelay_connect()
-            self.new_connection_event.clear()
-            await self.new_connection_event.wait()
-        return self.fastest_session
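fastest_session ranks candidates by (response_time + connection_latency) * (pending_amount + 1), so a quick but busy server can lose to a slightly slower idle one. A worked example with hypothetical numbers:

    # server A: 0.05s responses + 0.02s latency, 3 pending -> 0.07 * 4 = 0.28
    # server B: 0.09s responses + 0.03s latency, 0 pending -> 0.12 * 1 = 0.12
    # B is picked despite slower individual requests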
lbry/wallet/sync.py (new file, 431 lines)

@@ -0,0 +1,431 @@
+import asyncio
+import logging
+from io import StringIO
+from functools import partial
+from operator import itemgetter
+from collections import defaultdict
+from binascii import hexlify, unhexlify
+from typing import List, Optional, DefaultDict, NamedTuple
+
+import pylru
+from lbry.crypto.hash import double_sha256, sha256
+
+from lbry.service.api import Client
+from lbry.tasks import TaskGroup
+from lbry.blockchain.transaction import Transaction
+from lbry.blockchain.ledger import Ledger
+from lbry.blockchain.block import get_block_filter
+from lbry.db import Database
+from lbry.event import EventController
+from lbry.service.base import Service, Sync
+
+from .account import Account, AddressManager
+
+
+class TransactionEvent(NamedTuple):
+    address: str
+    tx: Transaction
+
+
+class AddressesGeneratedEvent(NamedTuple):
+    address_manager: AddressManager
+    addresses: List[str]
+
+
+class TransactionCacheItem:
+    __slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications'
+
+    def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
+        self.has_tx = asyncio.Event()
+        self.lock = lock or asyncio.Lock()
+        self._tx = self.tx = tx
+        self.pending_verifications = 0
+
+    @property
+    def tx(self) -> Optional[Transaction]:
+        return self._tx
+
+    @tx.setter
+    def tx(self, tx: Transaction):
+        self._tx = tx
+        if tx is not None:
+            self.has_tx.set()
+
+
+class SPVSync(Sync):
+
+    def __init__(self, service: Service):
+        super().__init__(service)
+        return
+        self.headers = headers
+        self.network: Network = self.config.get('network') or Network(self)
+        self.network.on_header.listen(self.receive_header)
+        self.network.on_status.listen(self.process_status_update)
+        self.network.on_connected.listen(self.join_network)
+
+        self.accounts = []
+
+        self.on_address = self.ledger.on_address
+
+        self._on_header_controller = EventController()
+        self.on_header = self._on_header_controller.stream
+        self.on_header.listen(
+            lambda change: log.info(
+                '%s: added %s header blocks, final height %s',
+                self.ledger.get_id(), change, self.headers.height
+            )
+        )
+        self._download_height = 0
+
+        self._on_ready_controller = EventController()
+        self.on_ready = self._on_ready_controller.stream
+
+        self._tx_cache = pylru.lrucache(100000)
+        self._update_tasks = TaskGroup()
+        self._other_tasks = TaskGroup()  # that we dont need to start
+        self._header_processing_lock = asyncio.Lock()
+        self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
+        self._known_addresses_out_of_sync = set()
+
+    async def advance(self):
+        address_array = [
+            bytearray(a['address'].encode())
+            for a in await self.service.db.get_all_addresses()
+        ]
+        block_filters = await self.service.get_block_address_filters()
+        for block_hash, block_filter in block_filters.items():
+            bf = get_block_filter(block_filter)
+            if bf.MatchAny(address_array):
+                print(f'match: {block_hash} - {block_filter}')
+                tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
+                for txid, tx_filter in tx_filters.items():
+                    tf = get_block_filter(tx_filter)
+                    if tf.MatchAny(address_array):
+                        print(f'  match: {txid} - {tx_filter}')
+                        txs = await self.service.search_transactions([txid])
+                        tx = Transaction(unhexlify(txs[txid]))
+                        await self.service.db.insert_transaction(tx)
+
+    async def get_local_status_and_history(self, address, history=None):
+        if not history:
+            address_details = await self.db.get_address(address=address)
+            history = (address_details['history'] if address_details else '') or ''
+        parts = history.split(':')[:-1]
+        return (
+            hexlify(sha256(history.encode())).decode() if history else None,
+            list(zip(parts[0::2], map(int, parts[1::2])))
+        )
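The history string here is the Electrum-style address status that get_local_status_and_history computes: 'txid:height:' concatenated per history entry, with the status being the hex sha256 of that string (None for an empty history). A small check of the parsing, using made-up txids:

    from binascii import hexlify
    from lbry.crypto.hash import sha256

    history = 'deadbeef:5:cafef00d:7:'
    status = hexlify(sha256(history.encode())).decode()
    parts = history.split(':')[:-1]
    assert list(zip(parts[0::2], map(int, parts[1::2]))) == [('deadbeef', 5), ('cafef00d', 7)]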
+    @staticmethod
+    def get_root_of_merkle_tree(branches, branch_positions, working_branch):
+        for i, branch in enumerate(branches):
+            other_branch = unhexlify(branch)[::-1]
+            other_branch_on_left = bool((branch_positions >> i) & 1)
+            if other_branch_on_left:
+                combined = other_branch + working_branch
+            else:
+                combined = working_branch + other_branch
+            working_branch = double_sha256(combined)
+        return hexlify(working_branch[::-1])
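get_root_of_merkle_tree folds a server-supplied proof back up to the root: branch_positions is a bitmask saying, per level, whether the sibling hash sits on the left of the working hash. A minimal two-leaf check (the [::-1] flips account for txids being rendered byte-reversed), assuming only double_sha256 from lbry.crypto.hash:

    from binascii import hexlify
    from lbry.crypto.hash import double_sha256

    leaf0, leaf1 = double_sha256(b'tx0'), double_sha256(b'tx1')
    root = double_sha256(leaf0 + leaf1)
    # proof for leaf1 is one branch (leaf0) sitting on the left -> bit 0 set
    assert SPVSync.get_root_of_merkle_tree(
        [hexlify(leaf0[::-1])], 0b1, leaf1
    ) == hexlify(root[::-1])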
async def start(self):
|
||||
await self.headers.open()
|
||||
fully_synced = self.on_ready.first
|
||||
asyncio.create_task(self.network.start())
|
||||
await self.network.on_connected.first
|
||||
async with self._header_processing_lock:
|
||||
await self._update_tasks.add(self.initial_headers_sync())
|
||||
await fully_synced
|
||||
|
||||
async def join_network(self, *_):
|
||||
log.info("Subscribing and updating accounts.")
|
||||
await self._update_tasks.add(self.subscribe_accounts())
|
||||
await self._update_tasks.done.wait()
|
||||
self._on_ready_controller.add(True)
|
||||
|
||||
async def stop(self):
|
||||
self._update_tasks.cancel()
|
||||
self._other_tasks.cancel()
|
||||
await self._update_tasks.done.wait()
|
||||
await self._other_tasks.done.wait()
|
||||
await self.network.stop()
|
||||
await self.headers.close()
|
||||
|
||||
@property
|
||||
def local_height_including_downloaded_height(self):
|
||||
return max(self.headers.height, self._download_height)
|
||||
|
||||
async def initial_headers_sync(self):
|
||||
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
|
||||
self.headers.chunk_getter = get_chunk
|
||||
|
||||
async def doit():
|
||||
for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
|
||||
async with self._header_processing_lock:
|
||||
await self.headers.ensure_chunk_at(height)
|
||||
self._other_tasks.add(doit())
|
||||
await self.update_headers()
|
||||
|
||||
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||
rewound = 0
|
||||
while True:
|
||||
|
||||
if height is None or height > len(self.headers):
|
||||
# sometimes header subscription updates are for a header in the future
|
||||
# which can't be connected, so we do a normal header sync instead
|
||||
height = len(self.headers)
|
||||
headers = None
|
||||
subscription_update = False
|
||||
|
||||
if not headers:
|
||||
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
||||
headers = header_response['hex']
|
||||
|
||||
if not headers:
|
||||
# Nothing to do, network thinks we're already at the latest height.
|
||||
return
|
||||
|
||||
added = await self.headers.connect(height, unhexlify(headers))
|
||||
if added > 0:
|
||||
height += added
|
||||
self._on_header_controller.add(
|
||||
BlockHeightEvent(self.headers.height, added))
|
||||
|
||||
if rewound > 0:
|
||||
# we started rewinding blocks and apparently found
|
||||
# a new chain
|
||||
rewound = 0
|
||||
await self.db.rewind_blockchain(height)
|
||||
|
||||
if subscription_update:
|
||||
# subscription updates are for latest header already
|
||||
# so we don't need to check if there are newer / more
|
||||
# on another loop of update_headers(), just return instead
|
||||
return
|
||||
|
||||
elif added == 0:
|
||||
# we had headers to connect but none got connected, probably a reorganization
|
||||
height -= 1
|
||||
rewound += 1
|
||||
log.warning(
|
||||
"Blockchain Reorganization: attempting rewind to height %s from starting height %s",
|
||||
height, height+rewound
|
||||
)
|
||||
|
||||
else:
|
||||
raise IndexError(f"headers.connect() returned negative number ({added})")
|
||||
|
||||
if height < 0:
|
||||
raise IndexError(
|
||||
"Blockchain reorganization rewound all the way back to genesis hash. "
|
||||
"Something is very wrong. Maybe you are on the wrong blockchain?"
|
||||
)
|
||||
|
||||
if rewound >= 100:
|
||||
raise IndexError(
|
||||
"Blockchain reorganization dropped {} headers. This is highly unusual. "
|
||||
"Will not continue to attempt reorganizing. Please, delete the ledger "
|
||||
"synchronization directory inside your wallet directory (folder: '{}') and "
|
||||
"restart the program to synchronize from scratch."
|
||||
.format(rewound, self.ledger.get_id())
|
||||
)
|
||||
|
||||
headers = None # ready to download some more headers
|
||||
|
||||
# if we made it this far and this was a subscription_update
|
||||
# it means something went wrong and now we're doing a more
|
||||
# robust sync, turn off subscription update shortcut
|
||||
subscription_update = False
|
||||
|
||||
async def receive_header(self, response):
|
||||
async with self._header_processing_lock:
|
||||
header = response[0]
|
||||
await self.update_headers(
|
||||
height=header['height'], headers=header['hex'], subscription_update=True
|
||||
)
|
||||
|
||||
async def subscribe_accounts(self):
|
||||
if self.network.is_connected and self.accounts:
|
||||
log.info("Subscribe to %i accounts", len(self.accounts))
|
||||
await asyncio.wait([
|
||||
self.subscribe_account(a) for a in self.accounts
|
||||
])
|
||||
|
||||
async def subscribe_account(self, account: Account):
|
||||
for address_manager in account.address_managers.values():
|
||||
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
|
||||
await account.ensure_address_gap()
|
||||
|
||||
async def unsubscribe_account(self, account: Account):
|
||||
for address in await account.get_addresses():
|
||||
await self.network.unsubscribe_address(address)
|
||||
|
||||
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
|
||||
if self.network.is_connected and addresses:
|
||||
addresses_remaining = list(addresses)
|
||||
while addresses_remaining:
|
||||
batch = addresses_remaining[:batch_size]
|
||||
results = await self.network.subscribe_address(*batch)
|
||||
for address, remote_status in zip(batch, results):
|
||||
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
||||
addresses_remaining = addresses_remaining[batch_size:]
|
||||
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
|
||||
len(addresses), *self.network.client.server_address_and_port)
|
||||
log.info(
|
||||
"finished subscribing to %i addresses on %s:%i", len(addresses),
|
||||
*self.network.client.server_address_and_port
|
||||
)
|
||||
|
||||
    def process_status_update(self, update):
        address, remote_status = update
        self._update_tasks.add(self.update_history(address, remote_status))

    async def update_history(self, address, remote_status, address_manager: AddressManager = None):
        async with self._address_update_locks[address]:
            self._known_addresses_out_of_sync.discard(address)

            local_status, local_history = await self.get_local_status_and_history(address)

            if local_status == remote_status:
                return True

            remote_history = await self.network.retriable_call(self.network.get_history, address)
            remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
            we_need = set(remote_history) - set(local_history)
            if not we_need:
                return True

            cache_tasks: List[asyncio.Task[Transaction]] = []
            synced_history = StringIO()
            loop = asyncio.get_running_loop()
            for i, (txid, remote_height) in enumerate(remote_history):
                if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
                    synced_history.write(f'{txid}:{remote_height}:')
                else:
                    check_local = (txid, remote_height) not in we_need
                    cache_tasks.append(loop.create_task(
                        self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
                    ))

            synced_txs = []
            for task in cache_tasks:
                tx = await task

                check_db_for_txos = []
                for txi in tx.inputs:
                    if txi.txo_ref.txo is not None:
                        continue
                    cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
                    if cache_item is not None:
                        if cache_item.tx is None:
                            await cache_item.has_tx.wait()
                        assert cache_item.tx is not None
                        txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
                    else:
                        check_db_for_txos.append(txi.txo_ref.hash)

                referenced_txos = {} if not check_db_for_txos else {
                    txo.id: txo for txo in await self.db.get_txos(
                        txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
                    )
                }

                for txi in tx.inputs:
                    if txi.txo_ref.txo is not None:
                        continue
                    referenced_txo = referenced_txos.get(txi.txo_ref.id)
                    if referenced_txo is not None:
                        txi.txo_ref = referenced_txo.ref

                synced_history.write(f'{tx.id}:{tx.height}:')
                synced_txs.append(tx)

            await self.db.save_transaction_io_batch(
                synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
            )
            await asyncio.wait([
                self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
                for tx in synced_txs
            ])

            if address_manager is None:
                address_manager = await self.get_address_manager_for_address(address)

            if address_manager is not None:
                await address_manager.ensure_address_gap()

            local_status, local_history = \
                await self.get_local_status_and_history(address, synced_history.getvalue())
            if local_status != remote_status:
                if local_history == remote_history:
                    return True
                log.warning(
                    "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
                    remote_status, len(remote_history), local_status, len(local_history)
                )
                log.warning("local: %s", local_history)
                log.warning("remote: %s", remote_history)
                self._known_addresses_out_of_sync.add(address)
                return False
            else:
                return True
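The status strings compared above follow the Electrum address-subscription convention: a status is the sha256 digest of the concatenated 'txid:height:' entries for every transaction touching the address, which is why update_history rebuilds exactly that string in synced_history. A minimal sketch, assuming get_local_status_and_history hashes the history this way (illustrative, not the shipped implementation):

    from hashlib import sha256

    def history_to_status(history):
        # history: list of (txid, height) pairs in server order
        encoded = ''.join(f'{txid}:{height}:' for txid, height in history)
        return sha256(encoded.encode()).hexdigest()

    # identical histories hash to identical statuses, so a matching remote
    # status means there is nothing to sync:
    assert history_to_status([('beef', 1)]) == history_to_status([('beef', 1)])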
    async def cache_transaction(self, tx_hash, remote_height, check_local=True):
        cache_item = self._tx_cache.get(tx_hash)
        if cache_item is None:
            cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
        elif cache_item.tx is not None and \
                cache_item.tx.height >= remote_height and \
                (cache_item.tx.is_verified or remote_height < 1):
            return cache_item.tx  # cached tx is already up-to-date

        try:
            cache_item.pending_verifications += 1
            return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
        finally:
            cache_item.pending_verifications -= 1

    async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):

        async with cache_item.lock:

            tx = cache_item.tx

            if tx is None and check_local:
                # check local db
                tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)

            merkle = None
            if tx is None:
                # fetch from network
                _raw, merkle = await self.network.retriable_call(
                    self.network.get_transaction_and_merkle, tx_hash, remote_height
                )
                tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
                cache_item.tx = tx  # make sure it's saved before caching it
            await self.maybe_verify_transaction(tx, remote_height, merkle)
            return tx

    async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
        tx.height = remote_height
        cached = self._tx_cache.get(tx.hash)
        if not cached:
            # cache txs looked up by transaction_show too
            cached = TransactionCacheItem()
            cached.tx = tx
            self._tx_cache[tx.hash] = cached
        if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
            # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
            if not merkle:
                merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
            merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
            header = await self.headers.get(remote_height)
            tx.position = merkle['pos']
            tx.is_verified = merkle_root == header['merkle_root']
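get_root_of_merkle_tree is defined elsewhere on the ledger; the check above is standard SPV proof verification. A self-contained sketch of that computation, assuming double-sha256 hashing and the little-endian hash encoding used in Bitcoin-style chains:

    from binascii import hexlify, unhexlify
    from hashlib import sha256

    def double_sha256(data: bytes) -> bytes:
        return sha256(sha256(data).digest()).digest()

    def get_root_of_merkle_tree(branches, branch_positions, working_branch):
        # branches: hex sibling hashes from the server; bit i of
        # branch_positions says whether sibling i sits to the left.
        for i, branch in enumerate(branches):
            other_branch = unhexlify(branch)[::-1]
            if (branch_positions >> i) & 1:
                combined = other_branch + working_branch
            else:
                combined = working_branch + other_branch
            working_branch = double_sha256(combined)
        return hexlify(working_branch[::-1])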
    async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
        details = await self.db.get_address(address=address)
        for account in self.accounts:
            if account.id == details['account']:
                return account.address_managers[details['chain']]
        return None
@ -6,9 +6,9 @@ from lbry.error import (
    ServerPaymentInvalidAddressError,
    ServerPaymentWalletLockedError
)
from lbry.wallet.dewies import lbc_to_dewies
from lbry.wallet.stream import StreamController
from lbry.wallet.transaction import Output, Transaction
from lbry.blockchain.dewies import lbc_to_dewies
from lbry.event import EventController
from lbry.blockchain.transaction import Output, Transaction

log = logging.getLogger(__name__)


@ -22,7 +22,7 @@ class WalletServerPayer:
        self.payment_period = payment_period
        self.analytics_manager = analytics_manager
        self.max_fee = max_fee
        self._on_payment_controller = StreamController()
        self._on_payment_controller = EventController()
        self.on_payment = self._on_payment_controller.stream
        self.on_payment.listen(None, on_error=lambda e: logging.warning(e.args[0]))
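This hunk replaces the old StreamController with lbry.event's EventController; the pattern stays the same: the owner pushes events into the controller, consumers subscribe to its stream. A minimal sketch, assuming EventController.add is a coroutine as it is used elsewhere in this commit:

    async def demo(controller: EventController, tx):
        controller.stream.listen(
            lambda t: log.info('payment sent: %s', t.id),  # called per event
            on_error=lambda e: log.warning(e.args[0])
        )
        await controller.add(tx)  # delivers tx to every listener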
@ -5,16 +5,28 @@ import json
import zlib
import typing
import logging
from typing import List, Sequence, MutableSequence, Optional
from typing import List, Sequence, MutableSequence, Optional, Iterable
from collections import UserDict
from hashlib import sha256
from operator import attrgetter
from decimal import Decimal

from lbry.db import Database
from lbry.blockchain.ledger import Ledger
from lbry.constants import COIN, NULL_HASH32
from lbry.blockchain.transaction import Transaction, Input, Output
from lbry.blockchain.dewies import dewies_to_lbc
from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt
from lbry.crypto.bip32 import PubKey, PrivateKey
from lbry.schema.claim import Claim
from lbry.schema.purchase import Purchase
from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError

from .account import Account
from .coinselection import CoinSelector, OutputEffectiveAmountEstimator

if typing.TYPE_CHECKING:
    from lbry.wallet.manager import WalletManager
    from lbry.wallet.ledger import Ledger
    from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager


log = logging.getLogger(__name__)
@ -67,8 +79,11 @@ class Wallet:
    preferences: TimestampedPreferences
    encryption_password: Optional[str]

    def __init__(self, name: str = 'Wallet', accounts: MutableSequence['Account'] = None,
    def __init__(self, ledger: Ledger, db: Database,
                 name: str = 'Wallet', accounts: MutableSequence[Account] = None,
                 storage: 'WalletStorage' = None, preferences: dict = None) -> None:
        self.ledger = ledger
        self.db = db
        self.name = name
        self.accounts = accounts or []
        self.storage = storage or WalletStorage()
@ -79,30 +94,34 @@ class Wallet:
    def get_id(self):
        return os.path.basename(self.storage.path) if self.storage.path else self.name

    def add_account(self, account: 'Account'):
    def generate_account(self, name: str = None, address_generator: dict = None) -> Account:
        account = Account.generate(self.ledger, self.db, name, address_generator)
        self.accounts.append(account)
        return account

    def generate_account(self, ledger: 'Ledger') -> 'Account':
        return Account.generate(ledger, self)
    def add_account(self, account_dict) -> Account:
        account = Account.from_dict(self.ledger, self.db, account_dict)
        self.accounts.append(account)
        return account

    @property
    def default_account(self) -> Optional['Account']:
    def default_account(self) -> Optional[Account]:
        for account in self.accounts:
            return account
        return None

    def get_account_or_default(self, account_id: str) -> Optional['Account']:
    def get_account_or_default(self, account_id: str) -> Optional[Account]:
        if account_id is None:
            return self.default_account
        return self.get_account_or_error(account_id)

    def get_account_or_error(self, account_id: str) -> 'Account':
    def get_account_or_error(self, account_id: str) -> Account:
        for account in self.accounts:
            if account.id == account_id:
                return account
        raise ValueError(f"Couldn't find account: {account_id}.")

    def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['Account']:
    def get_accounts_or_all(self, account_ids: List[str]) -> Sequence[Account]:
        return [
            self.get_account_or_error(account_id)
            for account_id in account_ids
@ -116,24 +135,63 @@ class Wallet:
            accounts.append(details)
        return accounts

    async def _get_account_and_address_info_for_address(self, address):
        match = await self.db.get_address(accounts=self.accounts, address=address)
        if match:
            for account in self.accounts:
                if match['account'] == account.public_key.address:
                    return account, match

    async def get_private_key_for_address(self, address) -> Optional[PrivateKey]:
        match = await self._get_account_and_address_info_for_address(address)
        if match:
            account, address_info = match
            return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
        return None

    async def get_public_key_for_address(self, address) -> Optional[PubKey]:
        match = await self._get_account_and_address_info_for_address(address)
        if match:
            _, address_info = match
            return address_info['pubkey']
        return None

    async def get_account_for_address(self, address):
        match = await self._get_account_and_address_info_for_address(address)
        if match:
            return match[0]

    async def save_max_gap(self):
        gap_changed = False
        for account in self.accounts:
            if await account.save_max_gap():
                gap_changed = True
        if gap_changed:
            self.save()

    @classmethod
    def from_storage(cls, storage: 'WalletStorage', manager: 'WalletManager') -> 'Wallet':
    def from_storage(cls, ledger: Ledger, db: Database, storage: 'WalletStorage') -> 'Wallet':
        json_dict = storage.read()
        if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id():
            raise ValueError(
                f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}."
            )
        wallet = cls(
            ledger, db,
            name=json_dict.get('name', 'Wallet'),
            preferences=json_dict.get('preferences', {}),
            storage=storage
        )
        account_dicts: Sequence[dict] = json_dict.get('accounts', [])
        for account_dict in account_dicts:
            ledger = manager.get_or_create_ledger(account_dict['ledger'])
            Account.from_dict(ledger, wallet, account_dict)
            wallet.add_account(account_dict)
        return wallet

    def to_dict(self, encrypt_password: str = None):
        return {
            'version': WalletStorage.LATEST_VERSION,
            'name': self.name,
            'ledger': self.ledger.get_id(),
            'preferences': self.preferences.data,
            'accounts': [a.to_dict(encrypt_password) for a in self.accounts]
        }
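Taken together, from_storage and to_dict round-trip a wallet document shaped roughly as below; the field values here are illustrative, not actual defaults:

    {
        "version": 1,
        "name": "Wallet",
        "ledger": "lbc_mainnet",
        "preferences": {},
        "accounts": []
    }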
@ -173,15 +231,13 @@ class Wallet:
        decompressed = zlib.decompress(decrypted)
        return json.loads(decompressed)

    def merge(self, manager: 'WalletManager',
              password: str, data: str) -> List['Account']:
    def merge(self, password: str, data: str) -> List[Account]:
        assert not self.is_locked, "Cannot sync apply on a locked wallet."
        added_accounts = []
        decrypted_data = self.unpack(password, data)
        self.preferences.merge(decrypted_data.get('preferences', {}))
        for account_dict in decrypted_data['accounts']:
            ledger = manager.get_or_create_ledger(account_dict['ledger'])
            _, _, pubkey = Account.keys_from_dict(ledger, account_dict)
            _, _, pubkey = Account.keys_from_dict(self.ledger, account_dict)
            account_id = pubkey.address
            local_match = None
            for local_account in self.accounts:
@ -191,8 +247,9 @@ class Wallet:
            if local_match is not None:
                local_match.merge(account_dict)
            else:
                new_account = Account.from_dict(ledger, self, account_dict)
                added_accounts.append(new_account)
                added_accounts.append(
                    self.add_account(account_dict)
                )
        return added_accounts

    @property
@ -235,6 +292,203 @@ class Wallet:
        self.save()
        return True

    async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
        estimators = []
        for utxo in (await self.db.get_utxos(accounts=funding_accounts))[0]:
            estimators.append(OutputEffectiveAmountEstimator(self.ledger, utxo))
        return estimators

    async def get_spendable_utxos(self, amount: int, funding_accounts: Iterable[Account]):
        txos = await self.get_effective_amount_estimators(funding_accounts)
        fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger)
        selector = CoinSelector(amount, fee)
        spendables = selector.select(txos, self.ledger.coin_selection_strategy)
        if spendables:
            await self.db.reserve_outputs(s.txo for s in spendables)
        return spendables
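CoinSelector works on effective amounts, i.e. what a UTXO contributes after paying for the input that spends it. A toy illustration of that arithmetic (the fee number is made up; real values come from get_fee and the ledger's fee settings):

    COIN = 10**8                      # dewies per LBC
    utxo_amount = 1 * COIN            # a 1.0 LBC output
    input_fee = 15_000                # hypothetical cost to spend it
    effective_amount = utxo_amount - input_fee
    assert effective_amount == 99_985_000  # what this utxo can actually fund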
    async def create_transaction(self, inputs: Iterable[Input], outputs: Iterable[Output],
                                 funding_accounts: Iterable[Account], change_account: Account,
                                 sign: bool = True):
        """ Find optimal set of inputs when only outputs are provided; add change
            outputs if only inputs are provided or if inputs are greater than outputs. """

        tx = Transaction() \
            .add_inputs(inputs) \
            .add_outputs(outputs)

        # value of the outputs plus associated fees
        cost = (
            tx.get_base_fee(self.ledger) +
            tx.get_total_output_sum(self.ledger)
        )
        # value of the inputs less the cost to spend those inputs
        payment = tx.get_effective_input_sum(self.ledger)

        try:

            for _ in range(5):

                if payment < cost:
                    deficit = cost - payment
                    spendables = await self.get_spendable_utxos(deficit, funding_accounts)
                    if not spendables:
                        raise InsufficientFundsError()
                    payment += sum(s.effective_amount for s in spendables)
                    tx.add_inputs(s.txi for s in spendables)

                cost_of_change = (
                    tx.get_base_fee(self.ledger) +
                    Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger)
                )
                if payment > cost:
                    change = payment - cost
                    if change > cost_of_change:
                        change_address = await change_account.change.get_or_create_usable_address()
                        change_hash160 = change_account.ledger.address_to_hash160(change_address)
                        change_amount = change - cost_of_change
                        change_output = Output.pay_pubkey_hash(change_amount, change_hash160)
                        change_output.is_internal_transfer = True
                        tx.add_outputs([change_output])

                if tx._outputs:
                    break
                # this condition and the outer range(5) loop cover an edge case
                # whereby a single input is just enough to cover the fee and
                # has some change left over, but the change left over is less
                # than the cost_of_change: thus the input is completely
                # consumed and no output is added, which is an invalid tx.
                # to be able to spend this input we must increase the cost
                # of the TX and run through the balance algorithm a second time
                # adding an extra input and change output, making tx valid.
                # we do this 5 times in case the other UTXOs added are also
                # less than the fee, after 5 attempts we give up and go home
                cost += cost_of_change + 1

            if sign:
                await self.sign(tx)

        except Exception as e:
            log.exception('Failed to create transaction:')
            await self.db.release_tx(tx)
            raise e

        return tx
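A hypothetical call site for the balancing loop above, assuming an unlocked wallet with one funded account (the helper name and destination address are placeholders):

    async def send_one_lbc(wallet, address: str):
        account = wallet.default_account
        spend = Output.pay_pubkey_hash(
            1 * COIN, wallet.ledger.address_to_hash160(address)
        )
        # inputs and a change output are selected automatically; sign=True
        # by default, so the returned tx is ready to broadcast
        return await wallet.create_transaction(
            [], [spend], funding_accounts=[account], change_account=account
        )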
    async def sign(self, tx):
        for i, txi in enumerate(tx._inputs):
            assert txi.script is not None
            assert txi.txo_ref.txo is not None
            txo_script = txi.txo_ref.txo.script
            if txo_script.is_pay_pubkey_hash:
                address = self.ledger.hash160_to_address(txo_script.values['pubkey_hash'])
                private_key = await self.get_private_key_for_address(address)
                assert private_key is not None, 'Cannot find private key for signing output.'
                serialized = tx._serialize_for_signature(i)
                txi.script.values['signature'] = \
                    private_key.sign(serialized) + bytes((tx.signature_hash_type(1),))
                txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
                txi.script.generate()
            else:
                raise NotImplementedError("Don't know how to spend this output.")
        tx._reset()
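The byte appended to each DER signature above is the sighash flag; signature_hash_type(1) corresponds to SIGHASH_ALL in Bitcoin-style transactions, committing the signature to every input and output. Schematically:

    SIGHASH_ALL = 1                      # assumed flag value, Bitcoin convention
    der_sig = bytes.fromhex('3044')      # truncated placeholder signature
    script_sig_value = der_sig + bytes((SIGHASH_ALL,))
    assert script_sig_value[-1] == SIGHASH_ALL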
    def pay(self, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account):
        output = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(address))
        return self.create_transaction([], [output], funding_accounts, change_account)

    def claim_create(
            self, name: str, claim: Claim, amount: int, holding_address: str,
            funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
        claim_output = Output.pay_claim_name_pubkey_hash(
            amount, name, claim, self.ledger.address_to_hash160(holding_address)
        )
        if signing_channel is not None:
            claim_output.sign(signing_channel, b'placeholder txid:nout')
        return self.create_transaction(
            [], [claim_output], funding_accounts, change_account, sign=False
        )

    def claim_update(
            self, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
            funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
        updated_claim = Output.pay_update_claim_pubkey_hash(
            amount, previous_claim.claim_name, previous_claim.claim_id,
            claim, self.ledger.address_to_hash160(holding_address)
        )
        if signing_channel is not None:
            updated_claim.sign(signing_channel, b'placeholder txid:nout')
        else:
            updated_claim.clear_signature()
        return self.create_transaction(
            [Input.spend(previous_claim)], [updated_claim], funding_accounts, change_account, sign=False
        )

    def support(self, claim_name: str, claim_id: str, amount: int, holding_address: str,
                funding_accounts: List[Account], change_account: Account):
        support_output = Output.pay_support_pubkey_hash(
            amount, claim_name, claim_id, self.ledger.address_to_hash160(holding_address)
        )
        return self.create_transaction([], [support_output], funding_accounts, change_account)
    def purchase(self, claim_id: str, amount: int, merchant_address: bytes,
                 funding_accounts: List[Account], change_account: Account):
        payment = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(merchant_address))
        data = Output.add_purchase_data(Purchase(claim_id))
        return self.create_transaction(
            [], [payment, data], funding_accounts, change_account
        )

    async def create_purchase_transaction(
            self, accounts: List[Account], txo: Output, exchange: 'ExchangeRateManager',
            override_max_key_fee=False):
        fee = txo.claim.stream.fee
        fee_amount = exchange.to_dewies(fee.currency, fee.amount)
        if not override_max_key_fee and self.ledger.conf.max_key_fee:
            max_fee = self.ledger.conf.max_key_fee
            max_fee_amount = exchange.to_dewies(max_fee['currency'], Decimal(max_fee['amount']))
            if max_fee_amount and fee_amount > max_fee_amount:
                error_fee = f"{dewies_to_lbc(fee_amount)} LBC"
                if fee.currency != 'LBC':
                    error_fee += f" ({fee.amount} {fee.currency})"
                error_max_fee = f"{dewies_to_lbc(max_fee_amount)} LBC"
                if max_fee['currency'] != 'LBC':
                    error_max_fee += f" ({max_fee['amount']} {max_fee['currency']})"
                raise KeyFeeAboveMaxAllowedError(
                    f"Purchase price of {error_fee} exceeds maximum "
                    f"configured price of {error_max_fee}."
                )
        fee_address = fee.address or txo.get_address(self.ledger)
        return await self.purchase(
            txo.claim_id, fee_amount, fee_address, accounts, accounts[0]
        )
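The max_key_fee guard converts both sides to dewies before comparing. A self-contained sketch with a made-up exchange rate (to_dewies here is a toy stand-in for ExchangeRateManager.to_dewies):

    from decimal import Decimal

    COIN = 10**8
    usd_per_lbc = Decimal('0.03')  # made-up rate

    def to_dewies(currency: str, amount) -> int:
        lbc = Decimal(amount) / usd_per_lbc if currency == 'USD' else Decimal(amount)
        return int(lbc * COIN)

    fee_amount = to_dewies('LBC', 2000)    # stream priced at 2000 LBC
    max_fee_amount = to_dewies('USD', 50)  # user cap of $50, roughly 1666 LBC
    assert fee_amount > max_fee_amount     # KeyFeeAboveMaxAllowedError is raised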
    async def create_channel(
            self, name, amount, account, funding_accounts,
            claim_address, preview=False, **kwargs):

        claim = Claim()
        claim.channel.update(**kwargs)
        tx = await self.claim_create(
            name, claim, amount, claim_address, funding_accounts, funding_accounts[0]
        )
        txo = tx.outputs[0]
        txo.generate_channel_private_key()

        await self.sign(tx)

        if not preview:
            account.add_channel_private_key(txo.private_key)
            self.save()

        return tx

    async def get_channels(self):
        return await self.db.get_channels()

class WalletStorage: