refactored lbry.wallet

This commit is contained in:
Lex Berezhny 2020-05-06 10:53:31 -04:00
parent db89607e4e
commit 8c91777e5d
10 changed files with 883 additions and 710 deletions

View file

@ -54,7 +54,8 @@ class API:
def __init__(self, service: Service): def __init__(self, service: Service):
self.service = service self.service = service
self.wallet_manager = service.wallet_manager self.wallets = service.wallets
self.ledger = service.ledger
async def stop(self): async def stop(self):
""" """
@ -291,7 +292,7 @@ class API:
if isinstance(urls, str): if isinstance(urls, str):
urls = [urls] urls = [urls]
return await self.service.resolve( return await self.service.resolve(
urls, wallet=self.wallet_manager.get_wallet_or_default(wallet_id), **kwargs urls, wallet=self.wallets.get_or_default(wallet_id), **kwargs
) )
async def get( async def get(
@ -317,7 +318,7 @@ class API:
""" """
return await self.service.get( return await self.service.get(
uri, file_name=file_name, download_directory=download_directory, timeout=timeout, save_file=save_file, uri, file_name=file_name, download_directory=download_directory, timeout=timeout, save_file=save_file,
wallet=self.wallet_manager.get_wallet_or_default(wallet_id) wallet=self.wallets.get_or_default(wallet_id)
) )
SETTINGS_DOC = """ SETTINGS_DOC = """
@ -396,7 +397,7 @@ class API:
Returns: Returns:
(dict) Dictionary of preference(s) (dict) Dictionary of preference(s)
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if key: if key:
if key in wallet.preferences: if key in wallet.preferences:
return {key: wallet.preferences[key]} return {key: wallet.preferences[key]}
@ -418,7 +419,7 @@ class API:
Returns: Returns:
(dict) Dictionary with key/value of new preference (dict) Dictionary with key/value of new preference
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if value and isinstance(value, str) and value[0] in ('[', '{'): if value and isinstance(value, str) and value[0] in ('[', '{'):
value = json.loads(value) value = json.loads(value)
wallet.preferences[key] = value wallet.preferences[key] = value
@ -444,8 +445,8 @@ class API:
Returns: {Paginated[Wallet]} Returns: {Paginated[Wallet]}
""" """
if wallet_id: if wallet_id:
return paginate_list([self.wallet_manager.get_wallet_or_error(wallet_id)], 1, 1) return paginate_list([self.wallets.get_wallet_or_error(wallet_id)], 1, 1)
return paginate_list(self.wallet_manager.wallets, page, page_size) return paginate_list(self.wallets.wallets, page, page_size)
async def wallet_reconnect(self): async def wallet_reconnect(self):
""" """
@ -458,7 +459,7 @@ class API:
Returns: None Returns: None
""" """
return self.wallet_manager.reset() return self.wallets.reset()
async def wallet_create( async def wallet_create(
self, wallet_id, skip_on_startup=False, create_account=False, single_key=False): self, wallet_id, skip_on_startup=False, create_account=False, single_key=False):
@ -478,13 +479,13 @@ class API:
Returns: {Wallet} Returns: {Wallet}
""" """
wallet_path = os.path.join(self.conf.wallet_dir, 'wallets', wallet_id) wallet_path = os.path.join(self.conf.wallet_dir, 'wallets', wallet_id)
for wallet in self.wallet_manager.wallets: for wallet in self.wallets.wallets:
if wallet.id == wallet_id: if wallet.id == wallet_id:
raise Exception(f"Wallet at path '{wallet_path}' already exists and is loaded.") raise Exception(f"Wallet at path '{wallet_path}' already exists and is loaded.")
if os.path.exists(wallet_path): if os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.") raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.")
wallet = self.wallet_manager.import_wallet(wallet_path) wallet = self.wallets.import_wallet(wallet_path)
if not wallet.accounts and create_account: if not wallet.accounts and create_account:
account = Account.generate( account = Account.generate(
self.ledger, wallet, address_generator={ self.ledger, wallet, address_generator={
@ -512,12 +513,12 @@ class API:
Returns: {Wallet} Returns: {Wallet}
""" """
wallet_path = os.path.join(self.conf.wallet_dir, 'wallets', wallet_id) wallet_path = os.path.join(self.conf.wallet_dir, 'wallets', wallet_id)
for wallet in self.wallet_manager.wallets: for wallet in self.wallets.wallets:
if wallet.id == wallet_id: if wallet.id == wallet_id:
raise Exception(f"Wallet at path '{wallet_path}' is already loaded.") raise Exception(f"Wallet at path '{wallet_path}' is already loaded.")
if not os.path.exists(wallet_path): if not os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' was not found.") raise Exception(f"Wallet at path '{wallet_path}' was not found.")
wallet = self.wallet_manager.import_wallet(wallet_path) wallet = self.wallets.import_wallet(wallet_path)
if self.ledger.sync.network.is_connected: if self.ledger.sync.network.is_connected:
for account in wallet.accounts: for account in wallet.accounts:
await self.ledger.subscribe_account(account) await self.ledger.subscribe_account(account)
@ -535,8 +536,8 @@ class API:
Returns: {Wallet} Returns: {Wallet}
""" """
wallet = self.wallet_manager.get_wallet_or_error(wallet_id) wallet = self.wallets.get_wallet_or_error(wallet_id)
self.wallet_manager.wallets.remove(wallet) self.wallets.wallets.remove(wallet)
for account in wallet.accounts: for account in wallet.accounts:
await self.ledger.unsubscribe_account(account) await self.ledger.unsubscribe_account(account)
return wallet return wallet
@ -556,7 +557,7 @@ class API:
Returns: Returns:
(decimal) amount of lbry credits in wallet (decimal) amount of lbry credits in wallet
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
balance = await self.ledger.get_detailed_balance( balance = await self.ledger.get_detailed_balance(
accounts=wallet.accounts, confirmations=confirmations accounts=wallet.accounts, confirmations=confirmations
) )
@ -575,9 +576,9 @@ class API:
Returns: Returns:
Dictionary of wallet status information. Dictionary of wallet status information.
""" """
if self.wallet_manager is None: if self.wallets is None:
return {'is_encrypted': None, 'is_syncing': None, 'is_locked': None} return {'is_encrypted': None, 'is_syncing': None, 'is_locked': None}
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
return { return {
'is_encrypted': wallet.is_encrypted, 'is_encrypted': wallet.is_encrypted,
'is_syncing': len(self.ledger._update_tasks) > 0, 'is_syncing': len(self.ledger._update_tasks) > 0,
@ -598,7 +599,7 @@ class API:
Returns: Returns:
(bool) true if wallet is unlocked, otherwise false (bool) true if wallet is unlocked, otherwise false
""" """
return self.wallet_manager.get_wallet_or_default(wallet_id).unlock(password) return self.wallets.get_or_default(wallet_id).unlock(password)
async def wallet_lock(self, wallet_id=None): async def wallet_lock(self, wallet_id=None):
""" """
@ -613,7 +614,7 @@ class API:
Returns: Returns:
(bool) true if wallet is locked, otherwise false (bool) true if wallet is locked, otherwise false
""" """
return self.wallet_manager.get_wallet_or_default(wallet_id).lock() return self.wallets.get_or_default(wallet_id).lock()
async def wallet_decrypt(self, wallet_id=None): async def wallet_decrypt(self, wallet_id=None):
""" """
@ -628,7 +629,7 @@ class API:
Returns: Returns:
(bool) true if wallet is decrypted, otherwise false (bool) true if wallet is decrypted, otherwise false
""" """
return self.wallet_manager.get_wallet_or_default(wallet_id).decrypt() return self.wallets.get_or_default(wallet_id).decrypt()
async def wallet_encrypt(self, new_password, wallet_id=None): async def wallet_encrypt(self, new_password, wallet_id=None):
""" """
@ -645,7 +646,7 @@ class API:
Returns: Returns:
(bool) true if wallet is decrypted, otherwise false (bool) true if wallet is decrypted, otherwise false
""" """
return self.wallet_manager.get_wallet_or_default(wallet_id).encrypt(new_password) return self.wallets.get_or_default(wallet_id).encrypt(new_password)
async def wallet_send( async def wallet_send(
self, amount, addresses, wallet_id=None, self, amount, addresses, wallet_id=None,
@ -666,10 +667,10 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
account = wallet.get_account_or_default(change_account_id) account = wallet.accounts.get_or_default(change_account_id)
accounts = wallet.get_accounts_or_all(funding_account_ids) accounts = wallet.accounts.get_or_all(funding_account_ids)
amount = self.get_dewies_or_error("amount", amount) amount = self.get_dewies_or_error("amount", amount)
@ -730,7 +731,7 @@ class API:
'confirmations': confirmations, 'confirmations': confirmations,
'show_seed': show_seed 'show_seed': show_seed
} }
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if account_id: if account_id:
return paginate_list([await wallet.get_account_or_error(account_id).get_details(**kwargs)], 1, 1) return paginate_list([await wallet.get_account_or_error(account_id).get_details(**kwargs)], 1, 1)
else: else:
@ -754,8 +755,8 @@ class API:
Returns: Returns:
(decimal) amount of lbry credits in wallet (decimal) amount of lbry credits in wallet
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
account = wallet.get_account_or_default(account_id) account = wallet.accounts.get_or_default(account_id)
balance = await account.get_detailed_balance( balance = await account.get_detailed_balance(
confirmations=confirmations, reserved_subtotals=True, confirmations=confirmations, reserved_subtotals=True,
) )
@ -783,7 +784,7 @@ class API:
Returns: {Account} Returns: {Account}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
account = Account.from_dict( account = Account.from_dict(
self.ledger, wallet, { self.ledger, wallet, {
'name': account_name, 'name': account_name,
@ -816,7 +817,7 @@ class API:
Returns: {Account} Returns: {Account}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
account = Account.generate( account = Account.generate(
self.ledger.ledger, wallet, account_name, { self.ledger.ledger, wallet, account_name, {
'name': SingleKey.name if single_key else HierarchicalDeterministic.name 'name': SingleKey.name if single_key else HierarchicalDeterministic.name
@ -840,7 +841,7 @@ class API:
Returns: {Account} Returns: {Account}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
wallet.accounts.remove(account) wallet.accounts.remove(account)
wallet.save() wallet.save()
@ -872,7 +873,7 @@ class API:
Returns: {Account} Returns: {Account}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
change_made = False change_made = False
@ -921,7 +922,7 @@ class API:
Returns: Returns:
(map) maximum gap for change and receiving addresses (map) maximum gap for change and receiving addresses
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
return wallet.get_account_or_error(account_id).get_max_gap() return wallet.get_account_or_error(account_id).get_max_gap()
async def account_fund(self, to_account=None, from_account=None, amount='0.0', async def account_fund(self, to_account=None, from_account=None, amount='0.0',
@ -950,9 +951,9 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
to_account = wallet.get_account_or_default(to_account) to_account = wallet.accounts.get_or_default(to_account)
from_account = wallet.get_account_or_default(from_account) from_account = wallet.accounts.get_or_default(from_account)
amount = self.get_dewies_or_error('amount', amount) if amount else None amount = self.get_dewies_or_error('amount', amount) if amount else None
if not isinstance(outputs, int): if not isinstance(outputs, int):
raise ValueError("--outputs must be an integer.") raise ValueError("--outputs must be an integer.")
@ -1000,7 +1001,7 @@ class API:
Returns: Returns:
(str) sha256 hash of wallet (str) sha256 hash of wallet
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
return hexlify(wallet.hash).decode() return hexlify(wallet.hash).decode()
async def sync_apply(self, password, data=None, wallet_id=None, blocking=False): async def sync_apply(self, password, data=None, wallet_id=None, blocking=False):
@ -1026,10 +1027,10 @@ class API:
(map) sync hash and data (map) sync hash and data
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
wallet_changed = False wallet_changed = False
if data is not None: if data is not None:
added_accounts = wallet.merge(self.wallet_manager, password, data) added_accounts = wallet.merge(self.wallets, password, data)
if added_accounts and self.ledger.sync.network.is_connected: if added_accounts and self.ledger.sync.network.is_connected:
if blocking: if blocking:
await asyncio.wait([ await asyncio.wait([
@ -1070,8 +1071,8 @@ class API:
Returns: Returns:
(bool) true, if address is associated with current wallet (bool) true, if address is associated with current wallet
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
account = wallet.get_account_or_default(account_id) account = wallet.accounts.get_or_default(account_id)
match = await self.ledger.db.get_address(address=address, accounts=[account]) match = await self.ledger.db.get_address(address=address, accounts=[account])
if match is not None: if match is not None:
return True return True
@ -1094,7 +1095,7 @@ class API:
Returns: {Paginated[Address]} Returns: {Paginated[Address]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
constraints = {} constraints = {}
if address: if address:
constraints['address'] = address constraints['address'] = address
@ -1122,8 +1123,8 @@ class API:
Returns: {Address} Returns: {Address}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
return wallet.get_account_or_default(account_id).receiving.get_or_create_usable_address() return wallet.accounts.get_or_default(account_id).receiving.get_or_create_usable_address()
async def address_block_filters(self): async def address_block_filters(self):
return await self.service.get_block_address_filters() return await self.service.get_block_address_filters()
@ -1175,7 +1176,7 @@ class API:
Returns: {Paginated[File]} Returns: {Paginated[File]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
sort = sort or 'rowid' sort = sort or 'rowid'
comparison = comparison or 'eq' comparison = comparison or 'eq'
paginated = paginate_list( paginated = paginate_list(
@ -1348,7 +1349,7 @@ class API:
Returns: {Paginated[Output]} Returns: {Paginated[Output]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
constraints = { constraints = {
"wallet": wallet, "wallet": wallet,
"accounts": [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, "accounts": [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts,
@ -1385,9 +1386,9 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
accounts = wallet.get_accounts_or_all(funding_account_ids) accounts = wallet.accounts.get_or_all(funding_account_ids)
txo = None txo = None
if claim_id: if claim_id:
txo = await self.ledger.get_claim_by_claim_id(accounts, claim_id, include_purchase_receipt=True) txo = await self.ledger.get_claim_by_claim_id(accounts, claim_id, include_purchase_receipt=True)
@ -1407,7 +1408,7 @@ class API:
claim = txo.claim claim = txo.claim
if not claim.is_stream or not claim.stream.has_fee: if not claim.is_stream or not claim.stream.has_fee:
raise Exception(f"Claim '{claim_id}' does not have a purchase price.") raise Exception(f"Claim '{claim_id}' does not have a purchase price.")
tx = await self.wallet_manager.create_purchase_transaction( tx = await self.wallets.create_purchase_transaction(
accounts, txo, self.exchange_rate_manager, override_max_key_fee accounts, txo, self.exchange_rate_manager, override_max_key_fee
) )
if not preview: if not preview:
@ -1594,7 +1595,7 @@ class API:
Returns: {Paginated[Output]} Returns: {Paginated[Output]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(kwargs.pop('wallet_id', None)) wallet = self.wallets.get_or_default(kwargs.pop('wallet_id', None))
if {'claim_id', 'claim_ids'}.issubset(kwargs): if {'claim_id', 'claim_ids'}.issubset(kwargs):
raise ValueError("Only 'claim_id' or 'claim_ids' is allowed, not both.") raise ValueError("Only 'claim_id' or 'claim_ids' is allowed, not both.")
if kwargs.pop('valid_channel_signature', False): if kwargs.pop('valid_channel_signature', False):
@ -1697,16 +1698,15 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
account = wallet.get_account_or_default(account_id)
funding_accounts = wallet.get_accounts_or_all(funding_account_ids)
self.service.ledger.valid_channel_name_or_error(name) self.service.ledger.valid_channel_name_or_error(name)
#amount = self.get_dewies_or_error('bid', bid, positive_value=True) wallet = self.wallets.get_or_default(wallet_id)
amount = lbc_to_dewies(bid) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
account = wallet.accounts.get_or_default(account_id)
funding_accounts = wallet.accounts.get_or_all(funding_account_ids)
amount = self.ledger.get_dewies_or_error('bid', bid, positive_value=True)
claim_address = await account.get_valid_receiving_address(claim_address) claim_address = await account.get_valid_receiving_address(claim_address)
existing_channels, _ = await self.service.get_channels(wallet=wallet, claim_name=name) existing_channels, _ = await wallet.channels.list(claim_name=name)
if len(existing_channels) > 0: if len(existing_channels) > 0:
if not allow_duplicate_name: if not allow_duplicate_name:
raise Exception( raise Exception(
@ -1714,14 +1714,14 @@ class API:
f"Use --allow-duplicate-name flag to override." f"Use --allow-duplicate-name flag to override."
) )
tx = await wallet.create_channel( tx = await wallet.channels.create(
name, amount, account, funding_accounts, claim_address, preview, **kwargs name, amount, account, funding_accounts, claim_address, preview, **kwargs
) )
if not preview: if not preview:
await self.service.broadcast_or_release(tx, blocking) await self.service.broadcast_or_release(tx, blocking)
else: else:
await account.ledger.release_tx(tx) await self.service.release_tx(tx)
return tx return tx
@ -1813,9 +1813,9 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
funding_accounts = wallet.get_accounts_or_all(funding_account_ids) funding_accounts = wallet.accounts.get_or_all(funding_account_ids)
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
accounts = [account] accounts = [account]
@ -1903,7 +1903,7 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
@ -1985,7 +1985,7 @@ class API:
Returns: Returns:
(str) serialized channel private key (str) serialized channel private key
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
channel = await self.get_channel_or_error(wallet, account_id, channel_id, channel_name, for_signing=True) channel = await self.get_channel_or_error(wallet, account_id, channel_id, channel_name, for_signing=True)
address = channel.get_address(self.ledger) address = channel.get_address(self.ledger)
public_key = await self.ledger.get_public_key_for_address(wallet, address) public_key = await self.ledger.get_public_key_for_address(wallet, address)
@ -2014,7 +2014,7 @@ class API:
Returns: Returns:
(dict) Result dictionary (dict) Result dictionary
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
decoded = base58.b58decode(channel_data) decoded = base58.b58decode(channel_data)
data = json.loads(decoded) data = json.loads(decoded)
@ -2161,7 +2161,7 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
self.valid_stream_name_or_error(name) self.valid_stream_name_or_error(name)
wallet = self.wallet_manager.get_wallet_or_default(kwargs.get('wallet_id')) wallet = self.wallets.get_or_default(kwargs.get('wallet_id'))
if kwargs.get('account_id'): if kwargs.get('account_id'):
accounts = [wallet.get_account_or_error(kwargs.get('account_id'))] accounts = [wallet.get_account_or_error(kwargs.get('account_id'))]
else: else:
@ -2218,10 +2218,10 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
self.valid_stream_name_or_error(name) self.valid_stream_name_or_error(name)
account = wallet.get_account_or_default(account_id) account = wallet.accounts.get_or_default(account_id)
funding_accounts = wallet.get_accounts_or_all(funding_account_ids) funding_accounts = wallet.accounts.get_or_all(funding_account_ids)
channel = await self.get_channel_or_none(wallet, channel_account_id, channel_id, channel_name, for_signing=True) channel = await self.get_channel_or_none(wallet, channel_account_id, channel_id, channel_name, for_signing=True)
amount = self.get_dewies_or_error('bid', bid, positive_value=True) amount = self.get_dewies_or_error('bid', bid, positive_value=True)
claim_address = await self.get_receiving_address(claim_address, account) claim_address = await self.get_receiving_address(claim_address, account)
@ -2359,17 +2359,17 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) self.ledger.valid_stream_name_or_error(name)
wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
self.valid_stream_name_or_error(name) account = wallet.accounts.get_or_default(account_id)
account = wallet.get_account_or_default(account_id) funding_accounts = wallet.accounts.get_or_all(funding_account_ids)
funding_accounts = wallet.get_accounts_or_all(funding_account_ids) channel = await wallet.channels.get_for_signing_or_none(claim_id=channel_id, claim_name=channel_name)
channel = await self.get_channel_or_none(wallet, channel_account_id, channel_id, channel_name, for_signing=True) amount = self.ledger.get_dewies_or_error('bid', bid, positive_value=True)
amount = self.get_dewies_or_error('bid', bid, positive_value=True) claim_address = await account.get_valid_receiving_address(claim_address)
claim_address = await self.get_receiving_address(claim_address, account) kwargs['fee_address'] = self.ledger.get_fee_address(kwargs, claim_address)
kwargs['fee_address'] = self.get_fee_address(kwargs, claim_address)
claims = await account.get_claims(claim_name=name) claims = await wallet.streams.list(claim_name=name)
if len(claims) > 0: if len(claims) > 0:
if not allow_duplicate_name: if not allow_duplicate_name:
raise Exception( raise Exception(
@ -2377,11 +2377,15 @@ class API:
f"Use --allow-duplicate-name flag to override." f"Use --allow-duplicate-name flag to override."
) )
file_path, spec = await self._video_file_analyzer.verify_or_repair( # TODO: fix
validate_file, optimize_file, file_path, ignore_non_video=True #file_path, spec = await self._video_file_analyzer.verify_or_repair(
) # validate_file, optimize_file, file_path, ignore_non_video=True
kwargs.update(spec) #)
#kwargs.update(spec)
wallet.streams.create(
)
claim = Claim() claim = Claim()
claim.stream.update(file_path=file_path, sd_hash='0' * 96, **kwargs) claim.stream.update(file_path=file_path, sd_hash='0' * 96, **kwargs)
tx = await Transaction.claim_create( tx = await Transaction.claim_create(
@ -2529,9 +2533,9 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
funding_accounts = wallet.get_accounts_or_all(funding_account_ids) funding_accounts = wallet.accounts.get_or_all(funding_account_ids)
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
accounts = [account] accounts = [account]
@ -2651,7 +2655,7 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
@ -2813,9 +2817,9 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
account = wallet.get_account_or_default(account_id) account = wallet.accounts.get_or_default(account_id)
funding_accounts = wallet.get_accounts_or_all(funding_account_ids) funding_accounts = wallet.accounts.get_or_all(funding_account_ids)
self.valid_collection_name_or_error(name) self.valid_collection_name_or_error(name)
channel = await self.get_channel_or_none(wallet, channel_account_id, channel_id, channel_name, for_signing=True) channel = await self.get_channel_or_none(wallet, channel_account_id, channel_id, channel_name, for_signing=True)
amount = self.get_dewies_or_error('bid', bid, positive_value=True) amount = self.get_dewies_or_error('bid', bid, positive_value=True)
@ -2933,8 +2937,8 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
funding_accounts = wallet.get_accounts_or_all(funding_account_ids) funding_accounts = wallet.accounts.get_or_all(funding_account_ids)
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
accounts = [account] accounts = [account]
@ -3041,7 +3045,7 @@ class API:
Returns: {Paginated[Output]} Returns: {Paginated[Output]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
collections = account.get_collections collections = account.get_collections
@ -3069,7 +3073,7 @@ class API:
Returns: {Paginated[Output]} Returns: {Paginated[Output]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if claim_id: if claim_id:
txo = await self.ledger.get_claim_by_claim_id(wallet.accounts, claim_id) txo = await self.ledger.get_claim_by_claim_id(wallet.accounts, claim_id)
@ -3121,14 +3125,14 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
funding_accounts = wallet.get_accounts_or_all(funding_account_ids) funding_accounts = wallet.accounts.get_or_all(funding_account_ids)
amount = self.get_dewies_or_error("amount", amount) amount = self.ledger.get_dewies_or_error("amount", amount)
claim = await self.ledger.get_claim_by_claim_id(wallet.accounts, claim_id) claim = await self.ledger.get_claim_by_claim_id(wallet.accounts, claim_id)
claim_address = claim.get_address(self.ledger.ledger) claim_address = claim.get_address(self.ledger.ledger)
if not tip: if not tip:
account = wallet.get_account_or_default(account_id) account = wallet.accounts.get_or_default(account_id)
claim_address = await account.receiving.get_or_create_usable_address() claim_address = await account.receiving.get_or_create_usable_address()
tx = await Transaction.support( tx = await Transaction.support(
@ -3217,7 +3221,7 @@ class API:
Returns: {Transaction} Returns: {Transaction}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first."
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
@ -3329,7 +3333,7 @@ class API:
} }
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
transactions = account.get_transaction_history transactions = account.get_transaction_history
@ -3441,7 +3445,7 @@ class API:
Returns: {Paginated[Output]} Returns: {Paginated[Output]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if account_id: if account_id:
account = wallet.get_account_or_error(account_id) account = wallet.get_account_or_error(account_id)
claims = account.get_txos claims = account.get_txos
@ -3501,7 +3505,7 @@ class API:
Returns: {List[Transaction]} Returns: {List[Transaction]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
accounts = [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts accounts = [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts
txos = await self.ledger.get_txos( txos = await self.ledger.get_txos(
wallet=wallet, accounts=accounts, wallet=wallet, accounts=accounts,
@ -3559,7 +3563,7 @@ class API:
Returns: int Returns: int
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
return self.ledger.get_txo_sum( return self.ledger.get_txo_sum(
wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts,
**self._constrain_txo_from_kwargs({}, **kwargs) **self._constrain_txo_from_kwargs({}, **kwargs)
@ -3614,7 +3618,7 @@ class API:
Returns: List[Dict] Returns: List[Dict]
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
plot = await self.ledger.get_txo_plot( plot = await self.ledger.get_txo_plot(
wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts,
days_back=days_back, start_day=start_day, days_after=days_after, end_day=end_day, days_back=days_back, start_day=start_day, days_after=days_after, end_day=end_day,
@ -3666,7 +3670,7 @@ class API:
Returns: Returns:
None None
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if account_id is not None: if account_id is not None:
await wallet.get_account_or_error(account_id).release_all_outputs() await wallet.get_account_or_error(account_id).release_all_outputs()
else: else:
@ -4160,7 +4164,7 @@ class API:
"timestamp": (int) The time at which comment was entered into the server at, in nanoseconds. "timestamp": (int) The time at which comment was entered into the server at, in nanoseconds.
} }
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
channel = await self.get_channel_or_error( channel = await self.get_channel_or_error(
wallet, channel_account_id, channel_id, channel_name, for_signing=True wallet, channel_account_id, channel_id, channel_name, for_signing=True
) )
@ -4216,7 +4220,7 @@ class API:
if 'error' in channel: if 'error' in channel:
raise ValueError(channel['error']) raise ValueError(channel['error'])
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
# channel = await self.get_channel_or_none(wallet, None, **channel) # channel = await self.get_channel_or_none(wallet, None, **channel)
channel_claim = await self.get_channel_or_error(wallet, [], **channel) channel_claim = await self.get_channel_or_error(wallet, [], **channel)
edited_comment = { edited_comment = {
@ -4249,7 +4253,7 @@ class API:
} }
} }
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
abandon_comment_body = {'comment_id': comment_id} abandon_comment_body = {'comment_id': comment_id}
channel = await comment_client.post( channel = await comment_client.post(
self.conf.comment_server, 'get_channel_from_comment_id', comment_id=comment_id self.conf.comment_server, 'get_channel_from_comment_id', comment_id=comment_id
@ -4281,7 +4285,7 @@ class API:
"hidden": (bool) flag indicating if comment_id was hidden "hidden": (bool) flag indicating if comment_id was hidden
} }
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallets.get_or_default(wallet_id)
if isinstance(comment_ids, str): if isinstance(comment_ids, str):
comment_ids = [comment_ids] comment_ids = [comment_ids]

View file

@ -1,18 +1,14 @@
import os import os
import asyncio import asyncio
import logging import logging
from datetime import datetime from typing import List, Optional, Tuple, NamedTuple
from typing import Iterable, List, Optional, Tuple, NamedTuple
from lbry.db import Database from lbry.db import Database
from lbry.db.constants import TXO_TYPES from lbry.db.constants import TXO_TYPES
from lbry.schema.result import Censor from lbry.schema.result import Censor
from lbry.blockchain.dewies import dewies_to_lbc
from lbry.blockchain.transaction import Transaction, Output from lbry.blockchain.transaction import Transaction, Output
from lbry.blockchain.ledger import Ledger from lbry.blockchain.ledger import Ledger
from lbry.crypto.bip32 import PubKey, PrivateKey from lbry.wallet import WalletManager, AddressManager
from lbry.wallet.account import Account, AddressManager, SingleKey
from lbry.wallet.manager import WalletManager
from lbry.event import EventController from lbry.event import EventController
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -62,7 +58,7 @@ class Service:
def __init__(self, ledger: Ledger, db_url: str): def __init__(self, ledger: Ledger, db_url: str):
self.ledger, self.conf = ledger, ledger.conf self.ledger, self.conf = ledger, ledger.conf
self.db = Database(ledger, db_url) self.db = Database(ledger, db_url)
self.wallet_manager = WalletManager(ledger, self.db) self.wallets = WalletManager(ledger, self.db)
#self.on_address = sync.on_address #self.on_address = sync.on_address
#self.accounts = sync.accounts #self.accounts = sync.accounts
@ -77,7 +73,8 @@ class Service:
async def start(self): async def start(self):
await self.db.open() await self.db.open()
await self.wallet_manager.open() await self.wallets.ensure_path_exists()
await self.wallets.load()
await self.sync.start() await self.sync.start()
async def stop(self): async def stop(self):
@ -106,30 +103,9 @@ class Service:
path = os.path.join(self.conf.wallet_dir, file_name) path = os.path.join(self.conf.wallet_dir, file_name)
return self.wallet_manager.import_wallet(path) return self.wallet_manager.import_wallet(path)
def add_account(self, account: Account):
self.ledger.add_account(account)
async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
return await self.ledger.get_private_key_for_address(wallet, address)
async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
return await self.ledger.get_public_key_for_address(wallet, address)
async def get_account_for_address(self, wallet, address):
return await self.ledger.get_account_for_address(wallet, address)
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
return await self.ledger.get_effective_amount_estimators(funding_accounts)
async def get_addresses(self, **constraints): async def get_addresses(self, **constraints):
return await self.db.get_addresses(**constraints) return await self.db.get_addresses(**constraints)
def get_address_count(self, **constraints):
return self.db.get_address_count(**constraints)
async def get_spendable_utxos(self, amount: int, funding_accounts):
return await self.ledger.get_spendable_utxos(amount, funding_accounts)
def reserve_outputs(self, txos): def reserve_outputs(self, txos):
return self.db.reserve_outputs(txos) return self.db.reserve_outputs(txos)
@ -143,19 +119,12 @@ class Service:
self.constraint_spending_utxos(constraints) self.constraint_spending_utxos(constraints)
return self.db.get_utxos(**constraints) return self.db.get_utxos(**constraints)
def get_utxo_count(self, **constraints):
self.constraint_spending_utxos(constraints)
return self.db.get_utxo_count(**constraints)
async def get_txos(self, resolve=False, **constraints) -> List[Output]: async def get_txos(self, resolve=False, **constraints) -> List[Output]:
txos = await self.db.get_txos(**constraints) txos = await self.db.get_txos(**constraints)
if resolve: if resolve:
return await self._resolve_for_local_results(constraints.get('accounts', []), txos) return await self._resolve_for_local_results(constraints.get('accounts', []), txos)
return txos return txos
def get_txo_count(self, **constraints):
return self.db.get_txo_count(**constraints)
def get_txo_sum(self, **constraints): def get_txo_sum(self, **constraints):
return self.db.get_txo_sum(**constraints) return self.db.get_txo_sum(**constraints)
@ -165,8 +134,21 @@ class Service:
def get_transactions(self, **constraints): def get_transactions(self, **constraints):
return self.db.get_transactions(**constraints) return self.db.get_transactions(**constraints)
def get_transaction_count(self, **constraints): async def get_transaction(self, tx_hash: bytes):
return self.db.get_transaction_count(**constraints) tx = await self.db.get_transaction(tx_hash=tx_hash)
if tx:
return tx
try:
raw, merkle = await self.ledger.network.get_transaction_and_merkle(tx_hash)
except CodeMessageError as e:
if 'No such mempool or blockchain transaction.' in e.message:
return {'success': False, 'code': 404, 'message': 'transaction not found'}
return {'success': False, 'code': e.code, 'message': e.message}
height = merkle.get('block_height')
tx = Transaction(unhexlify(raw), height=height)
if height and height > 0:
await self.ledger.maybe_verify_transaction(tx, height, merkle)
return tx
async def search_transactions(self, txids): async def search_transactions(self, txids):
raise NotImplementedError raise NotImplementedError
@ -181,6 +163,20 @@ class Service:
return account.address_managers[details['chain']] return account.address_managers[details['chain']]
return None return None
async def reset(self):
self.ledger.config = {
'auto_connect': True,
'default_servers': self.config.lbryum_servers,
'data_path': self.config.wallet_dir,
}
await self.ledger.stop()
await self.ledger.start()
async def get_best_blockhash(self):
if len(self.ledger.headers) <= 0:
return self.ledger.genesis_hash
return (await self.ledger.headers.hash(self.ledger.headers.height)).decode()
async def broadcast_or_release(self, tx, blocking=False): async def broadcast_or_release(self, tx, blocking=False):
try: try:
await self.broadcast(tx) await self.broadcast(tx)
@ -206,39 +202,12 @@ class Service:
for claim in (await self.search_claims(accounts, claim_id=claim_id, **kwargs))[0]: for claim in (await self.search_claims(accounts, claim_id=claim_id, **kwargs))[0]:
return claim return claim
async def _report_state(self):
try:
for account in self.accounts:
balance = dewies_to_lbc(await account.get_balance(include_claims=True))
channel_count = await account.get_channel_count()
claim_count = await account.get_claim_count()
if isinstance(account.receiving, SingleKey):
log.info("Loaded single key account %s with %s LBC. "
"%d channels, %d certificates and %d claims",
account.id, balance, channel_count, len(account.channel_keys), claim_count)
else:
total_receiving = len(await account.receiving.get_addresses())
total_change = len(await account.change.get_addresses())
log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), "
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
account.id, balance, total_receiving, account.receiving.gap, total_change,
account.change.gap, channel_count, len(account.channel_keys), claim_count)
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception(
'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:')
async def _reset_balance_cache(self, e):
return await self.ledger._reset_balance_cache(e)
@staticmethod @staticmethod
def constraint_spending_utxos(constraints): def constraint_spending_utxos(constraints):
constraints['txo_type__in'] = (0, TXO_TYPES['purchase']) constraints['txo_type__in'] = (0, TXO_TYPES['purchase'])
async def get_purchases(self, resolve=False, **constraints): async def get_purchases(self, wallet, resolve=False, **constraints):
purchases = await self.db.get_purchases(**constraints) purchases = await wallet.get_purchases(**constraints)
if resolve: if resolve:
claim_ids = [p.purchased_claim_id for p in purchases] claim_ids = [p.purchased_claim_id for p in purchases]
try: try:
@ -253,9 +222,6 @@ class Service:
purchase.purchased_claim = lookup.get(purchase.purchased_claim_id) purchase.purchased_claim = lookup.get(purchase.purchased_claim_id)
return purchases return purchases
def get_purchase_count(self, resolve=False, **constraints):
return self.db.get_purchase_count(**constraints)
async def _resolve_for_local_results(self, accounts, txos): async def _resolve_for_local_results(self, accounts, txos):
results = [] results = []
response = await self.resolve( response = await self.resolve(
@ -272,33 +238,6 @@ class Service:
results.append(txo) results.append(txo)
return results return results
async def get_claims(self, resolve=False, **constraints):
claims = await self.db.get_claims(**constraints)
if resolve:
return await self._resolve_for_local_results(constraints.get('accounts', []), claims)
return claims
def get_claim_count(self, **constraints):
return self.db.get_claim_count(**constraints)
async def get_streams(self, resolve=False, **constraints):
streams = await self.db.get_streams(**constraints)
if resolve:
return await self._resolve_for_local_results(constraints.get('accounts', []), streams)
return streams
def get_stream_count(self, **constraints):
return self.db.get_stream_count(**constraints)
async def get_channels(self, resolve=False, **constraints):
channels = await self.db.get_channels(**constraints)
if resolve:
return await self._resolve_for_local_results(constraints.get('accounts', []), channels)
return channels
def get_channel_count(self, **constraints):
return self.db.get_channel_count(**constraints)
async def resolve_collection(self, collection, offset=0, page_size=1): async def resolve_collection(self, collection, offset=0, page_size=1):
claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset] claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset]
try: try:
@ -319,138 +258,3 @@ class Service:
if not found: if not found:
claims.append(None) claims.append(None)
return claims return claims
async def get_collections(self, resolve_claims=0, **constraints):
collections = await self.db.get_collections(**constraints)
if resolve_claims > 0:
for collection in collections:
collection.claims = await self.resolve_collection(collection, page_size=resolve_claims)
return collections
def get_collection_count(self, resolve_claims=0, **constraints):
return self.db.get_collection_count(**constraints)
def get_supports(self, **constraints):
return self.db.get_supports(**constraints)
def get_support_count(self, **constraints):
return self.db.get_support_count(**constraints)
async def get_transaction_history(self, **constraints):
txs: List[Transaction] = await self.db.get_transactions(
include_is_my_output=True, include_is_spent=True,
**constraints
)
headers = self.headers
history = []
for tx in txs: # pylint: disable=too-many-nested-blocks
ts = headers.estimated_timestamp(tx.height)
item = {
'txid': tx.id,
'timestamp': ts,
'date': datetime.fromtimestamp(ts).isoformat(' ')[:-3] if tx.height > 0 else None,
'confirmations': (headers.height+1) - tx.height if tx.height > 0 else 0,
'claim_info': [],
'update_info': [],
'support_info': [],
'abandon_info': [],
'purchase_info': []
}
is_my_inputs = all([txi.is_my_input for txi in tx.inputs])
if is_my_inputs:
# fees only matter if we are the ones paying them
item['value'] = dewies_to_lbc(tx.net_account_balance+tx.fee)
item['fee'] = dewies_to_lbc(-tx.fee)
else:
# someone else paid the fees
item['value'] = dewies_to_lbc(tx.net_account_balance)
item['fee'] = '0.0'
for txo in tx.my_claim_outputs:
item['claim_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(-txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'nout': txo.position,
'is_spent': txo.is_spent,
})
for txo in tx.my_update_outputs:
if is_my_inputs: # updating my own claim
previous = None
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
other_txo = txi.txo_ref.txo
if (other_txo.is_claim or other_txo.script.is_support_claim) \
and other_txo.claim_id == txo.claim_id:
previous = other_txo
break
if previous is not None:
item['update_info'].append({
'address': txo.get_address(self),
'balance_delta': dewies_to_lbc(previous.amount-txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'nout': txo.position,
'is_spent': txo.is_spent,
})
else: # someone sent us their claim
item['update_info'].append({
'address': txo.get_address(self),
'balance_delta': dewies_to_lbc(0),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'nout': txo.position,
'is_spent': txo.is_spent,
})
for txo in tx.my_support_outputs:
item['support_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'is_tip': not is_my_inputs,
'nout': txo.position,
'is_spent': txo.is_spent,
})
if is_my_inputs:
for txo in tx.other_support_outputs:
item['support_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(-txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'is_tip': is_my_inputs,
'nout': txo.position,
'is_spent': txo.is_spent,
})
for txo in tx.my_abandon_outputs:
item['abandon_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'nout': txo.position
})
for txo in tx.any_purchase_outputs:
item['purchase_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.purchased_claim_id,
'nout': txo.position,
'is_spent': txo.is_spent,
})
history.append(item)
return history
def get_transaction_history_count(self, **constraints):
return self.db.get_transaction_count(**constraints)
async def get_detailed_balance(self, accounts, confirmations=0):
return self.ledger.get_detailed_balance(accounts, confirmations)

View file

@ -14,16 +14,15 @@ from time import time
from binascii import unhexlify from binascii import unhexlify
from functools import partial from functools import partial
from lbry.wallet import WalletManager, Wallet, Account
from lbry.blockchain.ledger import Ledger from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction, Input, Output from lbry.blockchain.transaction import Transaction, Input, Output
from lbry.blockchain.util import satoshis_to_coins from lbry.blockchain.util import satoshis_to_coins
from lbry.constants import CENT, NULL_HASH32
from lbry.wallet.wallet import Wallet, Account
from lbry.wallet.manager import WalletManager
from lbry.conf import Config
from lbry.blockchain.lbrycrd import Lbrycrd from lbry.blockchain.lbrycrd import Lbrycrd
from lbry.constants import CENT, NULL_HASH32
from lbry.service.full_node import FullNode from lbry.service.full_node import FullNode
from lbry.service.daemon import Daemon from lbry.service.daemon import Daemon
from lbry.conf import Config
from lbry.extras.daemon.daemon import jsonrpc_dumps_pretty from lbry.extras.daemon.daemon import jsonrpc_dumps_pretty
from lbry.extras.daemon.components import Component, WalletComponent from lbry.extras.daemon.components import Component, WalletComponent
@ -363,7 +362,7 @@ class CommandTestCase(IntegrationTestCase):
self.block_expected = 0 self.block_expected = 0
await self.generate(200, wait=False) await self.generate(200, wait=False)
self.chain.ledger.conf.spv_address_filters = False self.ledger.conf.spv_address_filters = False
self.service = FullNode( self.service = FullNode(
self.ledger, f'sqlite:///{self.chain.data_dir}/full_node.db', Lbrycrd(self.ledger) self.ledger, f'sqlite:///{self.chain.data_dir}/full_node.db', Lbrycrd(self.ledger)
) )
@ -372,10 +371,13 @@ class CommandTestCase(IntegrationTestCase):
self.addCleanup(self.daemon.stop) self.addCleanup(self.daemon.stop)
await self.daemon.start() await self.daemon.start()
self.wallet = self.service.wallet_manager.default_wallet self.wallet = self.service.wallets.default
self.account = self.wallet.accounts[0] self.account = self.wallet.accounts.default
addresses = await self.account.ensure_address_gap() addresses = await self.account.ensure_address_gap()
self.ledger.conf.upload_dir = os.path.join(self.ledger.conf.data_dir, 'uploads')
os.mkdir(self.ledger.conf.upload_dir)
await self.chain.send_to_address(addresses[0], '10.0') await self.chain.send_to_address(addresses[0], '10.0')
await self.generate(5) await self.generate(5)
@ -527,7 +529,9 @@ class CommandTestCase(IntegrationTestCase):
return self.sout(tx) return self.sout(tx)
def create_upload_file(self, data, prefix=None, suffix=None): def create_upload_file(self, data, prefix=None, suffix=None):
file_path = tempfile.mktemp(prefix=prefix or "tmp", suffix=suffix or "", dir=self.daemon.conf.upload_dir) file_path = tempfile.mktemp(
prefix=prefix or "tmp", suffix=suffix or "", dir=self.ledger.conf.upload_dir
)
with open(file_path, 'w+b') as file: with open(file_path, 'w+b') as file:
file.write(data) file.write(data)
file.flush() file.flush()
@ -539,7 +543,7 @@ class CommandTestCase(IntegrationTestCase):
if file_path is None: if file_path is None:
file_path = self.create_upload_file(data=data, prefix=prefix, suffix=suffix) file_path = self.create_upload_file(data=data, prefix=prefix, suffix=suffix)
return await self.confirm_and_render( return await self.confirm_and_render(
self.daemon.jsonrpc_stream_create(name, bid, file_path=file_path, **kwargs), confirm self.api.stream_create(name, bid, file_path=file_path, **kwargs), confirm
) )
async def stream_update( async def stream_update(

View file

@ -0,0 +1,3 @@
from .account import Account, AddressManager, SingleKey
from .wallet import Wallet
from .manager import WalletManager

View file

@ -9,12 +9,10 @@ from hashlib import sha256
from string import hexdigits from string import hexdigits
from typing import Type, Dict, Tuple, Optional, Any, List from typing import Type, Dict, Tuple, Optional, Any, List
from sqlalchemy import text
import ecdsa import ecdsa
from lbry.error import InvalidPasswordError from lbry.error import InvalidPasswordError
from lbry.crypto.crypt import aes_encrypt, aes_decrypt from lbry.crypto.crypt import aes_encrypt, aes_decrypt
from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string
from lbry.constants import COIN from lbry.constants import COIN
from lbry.blockchain.transaction import Transaction, Input, Output from lbry.blockchain.transaction import Transaction, Input, Output

View file

@ -1,147 +1,102 @@
import os import os
import json
import typing
import logging
import asyncio import asyncio
from binascii import unhexlify import logging
from decimal import Decimal from typing import Optional, Dict
from typing import List, Type, MutableSequence, MutableMapping, Optional
from lbry.error import KeyFeeAboveMaxAllowedError
from lbry.conf import Config
from .account import Account
from lbry.blockchain.dewies import dewies_to_lbc
from lbry.blockchain.ledger import Ledger
from lbry.db import Database from lbry.db import Database
from lbry.blockchain.ledger import Ledger from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction, Output
from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
from .wallet import Wallet
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class WalletManager: class WalletManager:
def __init__(self, ledger: Ledger, db: Database, def __init__(self, ledger: Ledger, db: Database):
wallets: MutableSequence[Wallet] = None,
ledgers: MutableMapping[Type[Ledger], Ledger] = None) -> None:
self.ledger = ledger self.ledger = ledger
self.db = db self.db = db
self.wallets = wallets or [] self.wallets: Dict[str, Wallet] = {}
self.ledgers = ledgers or {}
self.running = False
self.config: Optional[Config] = None
async def open(self): def __getitem__(self, wallet_id: str) -> Wallet:
conf = self.ledger.conf try:
return self.wallets[wallet_id]
wallets_directory = os.path.join(conf.wallet_dir, 'wallets') except KeyError:
if not os.path.exists(wallets_directory):
os.mkdir(wallets_directory)
for wallet_file in conf.wallets:
wallet_path = os.path.join(wallets_directory, wallet_file)
wallet_storage = WalletStorage(wallet_path)
wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage)
self.wallets.append(wallet)
self.ledger.coin_selection_strategy = self.ledger.conf.coin_selection_strategy
default_wallet = self.default_wallet
if default_wallet.default_account is None:
log.info('Wallet at %s is empty, generating a default account.', default_wallet.id)
default_wallet.generate_account()
default_wallet.save()
if default_wallet.is_locked and default_wallet.preferences.get(ENCRYPT_ON_DISK) is None:
default_wallet.preferences[ENCRYPT_ON_DISK] = True
default_wallet.save()
def import_wallet(self, path):
storage = WalletStorage(path)
wallet = Wallet.from_storage(self.ledger, self.db, storage)
self.wallets.append(wallet)
return wallet
@property
def default_wallet(self):
for wallet in self.wallets:
return wallet
@property
def default_account(self):
for wallet in self.wallets:
return wallet.default_account
@property
def accounts(self):
for wallet in self.wallets:
yield from wallet.accounts
async def start(self):
self.running = True
await asyncio.gather(*(
l.start() for l in self.ledgers.values()
))
async def stop(self):
await asyncio.gather(*(
l.stop() for l in self.ledgers.values()
))
self.running = False
def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet:
if wallet_id is None:
return self.default_wallet
return self.get_wallet_or_error(wallet_id)
def get_wallet_or_error(self, wallet_id: str) -> Wallet:
for wallet in self.wallets:
if wallet.id == wallet_id:
return wallet
raise ValueError(f"Couldn't find wallet: {wallet_id}.") raise ValueError(f"Couldn't find wallet: {wallet_id}.")
@staticmethod @property
def get_balance(wallet): def default(self) -> Optional[Wallet]:
accounts = wallet.accounts for wallet in self.wallets.values():
if not accounts: return wallet
return 0
return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts)
def check_locked(self): def get_or_default(self, wallet_id: Optional[str]) -> Optional[Wallet]:
return self.default_wallet.is_locked if wallet_id:
return self[wallet_id]
return self.default
async def reset(self): @property
self.ledger.config = { def path(self):
'auto_connect': True, return os.path.join(self.ledger.conf.wallet_dir, 'wallets')
'default_servers': self.config.lbryum_servers,
'data_path': self.config.wallet_dir,
}
await self.ledger.stop()
await self.ledger.start()
async def get_best_blockhash(self): def sync_ensure_path_exists(self):
if len(self.ledger.headers) <= 0: if not os.path.exists(self.path):
return self.ledger.genesis_hash os.mkdir(self.path)
return (await self.ledger.headers.hash(self.ledger.headers.height)).decode()
def get_unused_address(self): async def ensure_path_exists(self):
return self.default_account.receiving.get_or_create_usable_address() await asyncio.get_running_loop().run_in_executor(
None, self.sync_ensure_path_exists
)
async def get_transaction(self, tx_hash: bytes): async def load(self):
tx = await self.db.get_transaction(tx_hash=tx_hash) wallets_directory = self.path
if tx: for wallet_id in self.ledger.conf.wallets:
return tx if wallet_id in self.wallets:
try: log.warning(f"Ignoring duplicate wallet_id in config: {wallet_id}")
raw, merkle = await self.ledger.network.get_transaction_and_merkle(tx_hash) continue
except CodeMessageError as e: wallet_path = os.path.join(wallets_directory, wallet_id)
if 'No such mempool or blockchain transaction.' in e.message: if not os.path.exists(wallet_path):
return {'success': False, 'code': 404, 'message': 'transaction not found'} if not wallet_id == "default_wallet": # we'll probably generate this wallet, don't show error
return {'success': False, 'code': e.code, 'message': e.message} log.error(f"Could not load wallet, file does not exist: {wallet_path}")
height = merkle.get('block_height') continue
tx = Transaction(unhexlify(raw), height=height) wallet = await Wallet.from_path(self.ledger, self.db, wallet_path)
if height and height > 0: self.add(wallet)
await self.ledger.maybe_verify_transaction(tx, height, merkle) default_wallet = self.default
return tx if default_wallet is None:
if self.ledger.conf.create_default_wallet:
assert self.ledger.conf.wallets[0] == "default_wallet", (
"Requesting to generate the default wallet but the 'wallets' "
"config setting does not include 'default_wallet' as the first wallet."
)
await self.create(
self.ledger.conf.wallets[0], 'Wallet',
create_account=self.ledger.conf.create_default_account
)
elif not default_wallet.has_accounts and self.ledger.conf.create_default_account:
default_wallet.accounts.generate()
def add(self, wallet: Wallet) -> Wallet:
self.wallets[wallet.id] = wallet
return wallet
async def add_from_path(self, wallet_path) -> Wallet:
wallet_id = os.path.basename(wallet_path)
if wallet_id in self.wallets:
existing = self.wallets.get(wallet_id)
if existing.storage.path == wallet_path:
raise Exception(f"Wallet '{wallet_id}' is already loaded.")
raise Exception(
f"Wallet '{wallet_id}' is already loaded from '{existing.storage.path}'"
f" and cannot be loaded from '{wallet_path}'. Consider changing the wallet"
f" filename to be unique in order to avoid conflicts."
)
wallet = await Wallet.from_path(self.ledger, self.db, wallet_path)
return self.add(wallet)
async def create(self, wallet_id: str, name: str, create_account=False, single_key=False) -> Wallet:
if wallet_id in self.wallets:
raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.")
wallet_path = os.path.join(self.path, wallet_id)
if os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.")
wallet = await Wallet.create(self.ledger, self.db, wallet_path, name, create_account, single_key)
return self.add(wallet)

View file

@ -0,0 +1,39 @@
import time
import json
from hashlib import sha256
from collections import UserDict
class TimestampedPreferences(UserDict):
def __init__(self, d: dict = None):
super().__init__()
if d is not None:
self.data = d.copy()
def __getitem__(self, key):
return self.data[key]['value']
def __setitem__(self, key, value):
self.data[key] = {
'value': value,
'ts': time.time()
}
def __repr__(self):
return repr(self.to_dict_without_ts())
def to_dict_without_ts(self):
return {
key: value['value'] for key, value in self.data.items()
}
@property
def hash(self):
return sha256(json.dumps(self.data).encode()).digest()
def merge(self, other: dict):
for key, value in other.items():
if key in self.data and value['ts'] < self.data[key]['ts']:
continue
self.data[key] = value

60
lbry/wallet/storage.py Normal file
View file

@ -0,0 +1,60 @@
import os
import stat
import json
import asyncio
class WalletStorage:
VERSION = 1
def __init__(self, path=None):
self.path = path
def sync_read(self):
with open(self.path, 'r') as f:
json_data = f.read()
json_dict = json.loads(json_data)
if json_dict.get('version') == self.VERSION:
return json_dict
else:
return self.upgrade(json_dict)
async def read(self):
return await asyncio.get_running_loop().run_in_executor(
None, self.sync_read
)
def upgrade(self, json_dict):
version = json_dict.pop('version', -1)
if version == -1:
pass
json_dict['version'] = self.VERSION
return json_dict
def sync_write(self, json_dict):
json_data = json.dumps(json_dict, indent=4, sort_keys=True)
if self.path is None:
return json_data
temp_path = "{}.tmp.{}".format(self.path, os.getpid())
with open(temp_path, "w") as f:
f.write(json_data)
f.flush()
os.fsync(f.fileno())
if os.path.exists(self.path):
mode = os.stat(self.path).st_mode
else:
mode = stat.S_IREAD | stat.S_IWRITE
try:
os.rename(temp_path, self.path)
except Exception: # pylint: disable=broad-except
os.remove(self.path)
os.rename(temp_path, self.path)
os.chmod(self.path, mode)
async def write(self, json_dict):
return await asyncio.get_running_loop().run_in_executor(
None, self.sync_write, json_dict
)

View file

@ -1,16 +1,14 @@
import os import os
import time
import stat
import json import json
import zlib import zlib
import typing import asyncio
import logging import logging
from typing import List, Sequence, MutableSequence, Optional, Iterable from typing import List, Sequence, Tuple, Optional, Iterable
from collections import UserDict
from hashlib import sha256 from hashlib import sha256
from operator import attrgetter from operator import attrgetter
from decimal import Decimal from decimal import Decimal
from lbry.db import Database, SPENDABLE_TYPE_CODES from lbry.db import Database, SPENDABLE_TYPE_CODES
from lbry.blockchain.ledger import Ledger from lbry.blockchain.ledger import Ledger
from lbry.constants import COIN, NULL_HASH32 from lbry.constants import COIN, NULL_HASH32
@ -22,11 +20,10 @@ from lbry.schema.claim import Claim
from lbry.schema.purchase import Purchase from lbry.schema.purchase import Purchase
from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError
from .account import Account from .account import Account, SingleKey, HierarchicalDeterministic
from .coinselection import CoinSelector, OutputEffectiveAmountEstimator from .coinselection import CoinSelector, OutputEffectiveAmountEstimator
from .storage import WalletStorage
if typing.TYPE_CHECKING: from .preferences import TimestampedPreferences
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -34,41 +31,6 @@ log = logging.getLogger(__name__)
ENCRYPT_ON_DISK = 'encrypt-on-disk' ENCRYPT_ON_DISK = 'encrypt-on-disk'
class TimestampedPreferences(UserDict):
def __init__(self, d: dict = None):
super().__init__()
if d is not None:
self.data = d.copy()
def __getitem__(self, key):
return self.data[key]['value']
def __setitem__(self, key, value):
self.data[key] = {
'value': value,
'ts': time.time()
}
def __repr__(self):
return repr(self.to_dict_without_ts())
def to_dict_without_ts(self):
return {
key: value['value'] for key, value in self.data.items()
}
@property
def hash(self):
return sha256(json.dumps(self.data).encode()).digest()
def merge(self, other: dict):
for key, value in other.items():
if key in self.data and value['ts'] < self.data[key]['ts']:
continue
self.data[key] = value
class Wallet: class Wallet:
""" The primary role of Wallet is to encapsulate a collection """ The primary role of Wallet is to encapsulate a collection
of accounts (seed/private keys) and the spending rules / settings of accounts (seed/private keys) and the spending rules / settings
@ -76,102 +38,45 @@ class Wallet:
by physical files on the filesystem. by physical files on the filesystem.
""" """
preferences: TimestampedPreferences def __init__(self, ledger: Ledger, db: Database, name: str, storage: WalletStorage, preferences: dict):
encryption_password: Optional[str]
def __init__(self, ledger: Ledger, db: Database,
name: str = 'Wallet', accounts: MutableSequence[Account] = None,
storage: 'WalletStorage' = None, preferences: dict = None) -> None:
self.ledger = ledger self.ledger = ledger
self.db = db self.db = db
self.name = name self.name = name
self.accounts = accounts or [] self.storage = storage
self.storage = storage or WalletStorage()
self.preferences = TimestampedPreferences(preferences or {}) self.preferences = TimestampedPreferences(preferences or {})
self.encryption_password = None self.encryption_password: Optional[str] = None
self.id = self.get_id() self.id = self.get_id()
self.utxo_lock = asyncio.Lock()
self.accounts = AccountListManager(self)
self.claims = ClaimListManager(self)
self.streams = StreamListManager(self)
self.channels = ChannelListManager(self)
self.collections = CollectionListManager(self)
self.purchases = PurchaseListManager(self)
self.supports = SupportListManager(self)
def get_id(self): def get_id(self):
return os.path.basename(self.storage.path) if self.storage.path else self.name return os.path.basename(self.storage.path) if self.storage.path else self.name
def generate_account(self, name: str = None, address_generator: dict = None) -> Account: @classmethod
account = Account.generate(self.ledger, self.db, name, address_generator) async def create(cls, ledger: Ledger, db: Database, path: str, name: str, create_account=False, single_key=False):
self.accounts.append(account) wallet = cls(ledger, db, name, WalletStorage(path), {})
return account if create_account:
wallet.accounts.generate(address_generator={
def add_account(self, account_dict) -> Account: 'name': SingleKey.name if single_key else HierarchicalDeterministic.name
account = Account.from_dict(self.ledger, self.db, account_dict) })
self.accounts.append(account) await wallet.save()
return account return wallet
@property
def default_account(self) -> Optional[Account]:
for account in self.accounts:
return account
return None
def get_account_or_default(self, account_id: str) -> Optional[Account]:
if account_id is None:
return self.default_account
return self.get_account_or_error(account_id)
def get_account_or_error(self, account_id: str) -> Account:
for account in self.accounts:
if account.id == account_id:
return account
raise ValueError(f"Couldn't find account: {account_id}.")
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence[Account]:
return [
self.get_account_or_error(account_id)
for account_id in account_ids
] if account_ids else self.accounts
async def get_detailed_accounts(self, **kwargs):
accounts = []
for i, account in enumerate(self.accounts):
details = await account.get_details(**kwargs)
details['is_default'] = i == 0
accounts.append(details)
return accounts
async def _get_account_and_address_info_for_address(self, address):
match = await self.db.get_address(accounts=self.accounts, address=address)
if match:
for account in self.accounts:
if match['account'] == account.public_key.address:
return account, match
async def get_private_key_for_address(self, address) -> Optional[PrivateKey]:
match = await self._get_account_and_address_info_for_address(address)
if match:
account, address_info = match
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
return None
async def get_public_key_for_address(self, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(address)
if match:
_, address_info = match
return address_info['pubkey']
return None
async def get_account_for_address(self, address):
match = await self._get_account_and_address_info_for_address(address)
if match:
return match[0]
async def save_max_gap(self):
gap_changed = False
for account in self.accounts:
if await account.save_max_gap():
gap_changed = True
if gap_changed:
self.save()
@classmethod @classmethod
def from_storage(cls, ledger: Ledger, db: Database, storage: 'WalletStorage') -> 'Wallet': async def from_path(cls, ledger: Ledger, db: Database, path: str):
json_dict = storage.read() return await cls.from_storage(ledger, db, WalletStorage(path))
@classmethod
async def from_storage(cls, ledger: Ledger, db: Database, storage: WalletStorage) -> 'Wallet':
json_dict = await storage.read()
if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id(): if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id():
raise ValueError( raise ValueError(
f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}." f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}."
@ -179,33 +84,32 @@ class Wallet:
wallet = cls( wallet = cls(
ledger, db, ledger, db,
name=json_dict.get('name', 'Wallet'), name=json_dict.get('name', 'Wallet'),
storage=storage,
preferences=json_dict.get('preferences', {}), preferences=json_dict.get('preferences', {}),
storage=storage
) )
account_dicts: Sequence[dict] = json_dict.get('accounts', []) for account_dict in json_dict.get('accounts', []):
for account_dict in account_dicts: wallet.accounts.add_from_dict(account_dict)
wallet.add_account(account_dict)
return wallet return wallet
def to_dict(self, encrypt_password: str = None): def to_dict(self, encrypt_password: str = None):
return { return {
'version': WalletStorage.LATEST_VERSION, 'version': WalletStorage.VERSION,
'name': self.name,
'ledger': self.ledger.get_id(), 'ledger': self.ledger.get_id(),
'name': self.name,
'preferences': self.preferences.data, 'preferences': self.preferences.data,
'accounts': [a.to_dict(encrypt_password) for a in self.accounts] 'accounts': [a.to_dict(encrypt_password) for a in self.accounts]
} }
def save(self): async def save(self):
if self.preferences.get(ENCRYPT_ON_DISK, False): if self.preferences.get(ENCRYPT_ON_DISK, False):
if self.encryption_password is not None: if self.encryption_password is not None:
return self.storage.write(self.to_dict(encrypt_password=self.encryption_password)) return await self.storage.write(self.to_dict(encrypt_password=self.encryption_password))
elif not self.is_locked: elif not self.is_locked:
log.warning( log.warning(
"Disk encryption requested but no password available for encryption. " "Disk encryption requested but no password available for encryption. "
"Saving wallet in an unencrypted state." "Saving wallet in an unencrypted state."
) )
return self.storage.write(self.to_dict()) return await self.storage.write(self.to_dict())
@property @property
def hash(self) -> bytes: def hash(self) -> bytes:
@ -248,7 +152,7 @@ class Wallet:
local_match.merge(account_dict) local_match.merge(account_dict)
else: else:
added_accounts.append( added_accounts.append(
self.add_account(account_dict) self.accounts.add_from_dict(account_dict)
) )
return added_accounts return added_accounts
@ -292,6 +196,44 @@ class Wallet:
self.save() self.save()
return True return True
@property
def has_accounts(self):
return len(self.accounts) > 0
async def _get_account_and_address_info_for_address(self, address):
match = await self.db.get_address(accounts=self.accounts, address=address)
if match:
for account in self.accounts:
if match['account'] == account.public_key.address:
return account, match
async def get_private_key_for_address(self, address) -> Optional[PrivateKey]:
match = await self._get_account_and_address_info_for_address(address)
if match:
account, address_info = match
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
return None
async def get_public_key_for_address(self, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(address)
if match:
_, address_info = match
return address_info['pubkey']
return None
async def get_account_for_address(self, address):
match = await self._get_account_and_address_info_for_address(address)
if match:
return match[0]
async def save_max_gap(self):
gap_changed = False
for account in self.accounts:
if await account.save_max_gap():
gap_changed = True
if gap_changed:
self.save()
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]): async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
estimators = [] estimators = []
utxos = await self.db.get_utxos( utxos = await self.db.get_utxos(
@ -303,6 +245,7 @@ class Wallet:
return estimators return estimators
async def get_spendable_utxos(self, amount: int, funding_accounts: Iterable[Account]): async def get_spendable_utxos(self, amount: int, funding_accounts: Iterable[Account]):
async with self.utxo_lock:
txos = await self.get_effective_amount_estimators(funding_accounts) txos = await self.get_effective_amount_estimators(funding_accounts)
fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger) fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger)
selector = CoinSelector(amount, fee) selector = CoinSelector(amount, fee)
@ -311,6 +254,12 @@ class Wallet:
await self.db.reserve_outputs(s.txo for s in spendables) await self.db.reserve_outputs(s.txo for s in spendables)
return spendables return spendables
async def list_transactions(self, **constraints):
return txs_to_dict(await self.db.get_transactions(
include_is_my_output=True, include_is_spent=True,
**constraints
), self.ledger)
async def create_transaction(self, inputs: Iterable[Input], outputs: Iterable[Output], async def create_transaction(self, inputs: Iterable[Input], outputs: Iterable[Output],
funding_accounts: Iterable[Account], change_account: Account, funding_accounts: Iterable[Account], change_account: Account,
sign: bool = True): sign: bool = True):
@ -397,46 +346,271 @@ class Wallet:
raise NotImplementedError("Don't know how to spend this output.") raise NotImplementedError("Don't know how to spend this output.")
tx._reset() tx._reset()
@classmethod async def pay(self, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account):
def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'): output = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(address))
output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address)) return await self.create_transaction([], [output], funding_accounts, change_account)
return cls.create([], [output], funding_accounts, change_account)
def claim_create( async def _report_state(self):
try:
for account in self.accounts:
balance = dewies_to_lbc(await account.get_balance(include_claims=True))
_, channel_count = await account.get_channels(limit=1)
claim_count = await account.get_claim_count()
if isinstance(account.receiving, SingleKey):
log.info("Loaded single key account %s with %s LBC. "
"%d channels, %d certificates and %d claims",
account.id, balance, channel_count, len(account.channel_keys), claim_count)
else:
total_receiving = len(await account.receiving.get_addresses())
total_change = len(await account.change.get_addresses())
log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), "
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
account.id, balance, total_receiving, account.receiving.gap, total_change,
account.change.gap, channel_count, len(account.channel_keys), claim_count)
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception(
'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:')
class AccountListManager:
__slots__ = 'wallet', '_accounts'
def __init__(self, wallet: Wallet):
self.wallet = wallet
self._accounts: List[Account] = []
def __len__(self):
return self._accounts.__len__()
def __iter__(self):
return self._accounts.__iter__()
def __getitem__(self, account_id: str) -> Account:
for account in self:
if account.id == account_id:
return account
raise ValueError(f"Couldn't find account: {account_id}.")
@property
def default(self) -> Optional[Account]:
for account in self:
return account
def generate(self, name: str = None, address_generator: dict = None) -> Account:
account = Account.generate(self.wallet.ledger, self.wallet.db, name, address_generator)
self._accounts.append(account)
return account
def add_from_dict(self, account_dict: dict) -> Account:
account = Account.from_dict(self.wallet.ledger, self.wallet.db, account_dict)
self._accounts.append(account)
return account
async def remove(self, account_id: str) -> Account:
account = self[account_id]
self._accounts.remove(account)
await self.wallet.save()
return account
def get_or_none(self, account_id: str) -> Optional[Account]:
if account_id is not None:
return self[account_id]
def get_or_default(self, account_id: str) -> Optional[Account]:
if account_id is None:
return self.default
return self[account_id]
def get_or_all(self, account_ids: List[str]) -> List[Account]:
return [self[account_id] for account_id in account_ids] if account_ids else self._accounts
async def get_account_details(self, **kwargs):
accounts = []
for i, account in enumerate(self._accounts):
details = await account.get_details(**kwargs)
details['is_default'] = i == 0
accounts.append(details)
return accounts
class BaseListManager:
__slots__ = 'wallet', 'db'
def __init__(self, wallet: Wallet):
self.wallet = wallet
self.db = wallet.db
async def create(self, **kwargs) -> Transaction:
raise NotImplementedError
async def delete(self, **constraints):
raise NotImplementedError
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
raise NotImplementedError
async def get(self, **constraints) -> Output:
raise NotImplementedError
async def get_or_none(self, **constraints) -> Optional[Output]:
raise NotImplementedError
class ClaimListManager(BaseListManager):
name = 'claim'
__slots__ = ()
async def create(
self, name: str, claim: Claim, amount: int, holding_address: str, self, name: str, claim: Claim, amount: int, holding_address: str,
funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
claim_output = Output.pay_claim_name_pubkey_hash( claim_output = Output.pay_claim_name_pubkey_hash(
amount, name, claim, self.ledger.address_to_hash160(holding_address) amount, name, claim, self.wallet.ledger.address_to_hash160(holding_address)
) )
if signing_channel is not None: if signing_channel is not None:
claim_output.sign(signing_channel, b'placeholder txid:nout') claim_output.sign(signing_channel, b'placeholder txid:nout')
return self.create_transaction( return await self.wallet.create_transaction(
[], [claim_output], funding_accounts, change_account, sign=False [], [claim_output], funding_accounts, change_account, sign=False
) )
@classmethod async def update(
def claim_update( self, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str, funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
updated_claim = Output.pay_update_claim_pubkey_hash( updated_claim = Output.pay_update_claim_pubkey_hash(
amount, previous_claim.claim_name, previous_claim.claim_id, amount, previous_claim.claim_name, previous_claim.claim_id,
claim, ledger.address_to_hash160(holding_address) claim, self.wallet.ledger.address_to_hash160(holding_address)
) )
if signing_channel is not None: if signing_channel is not None:
updated_claim.sign(signing_channel, b'placeholder txid:nout') updated_claim.sign(signing_channel, b'placeholder txid:nout')
else: else:
updated_claim.clear_signature() updated_claim.clear_signature()
return cls.create( return await self.wallet.create_transaction(
[Input.spend(previous_claim)], [updated_claim], funding_accounts, change_account, sign=False [Input.spend(previous_claim)], [updated_claim], funding_accounts, change_account, sign=False
) )
@classmethod async def delete(self, claim_id=None, txid=None, nout=None):
def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str, claim = await self.get(claim_id=claim_id, txid=txid, nout=nout)
funding_accounts: List['Account'], change_account: 'Account'): return await self.wallet.create_transaction(
support_output = Output.pay_support_pubkey_hash( [Input.spend(claim)], [], self.wallet._accounts, self.wallet._accounts[0]
amount, claim_name, claim_id, ledger.address_to_hash160(holding_address) )
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_claims(wallet=self.wallet, **constraints)
async def get(self, claim_id=None, claim_name=None, txid=None, nout=None) -> Output:
if txid is not None and nout is not None:
key, value, constraints = 'txid:nout', f'{txid}:{nout}', {'tx_hash': '', 'position': nout}
elif claim_id is not None:
key, value, constraints = 'id', claim_id, {'claim_id': claim_id}
elif claim_name is not None:
key, value, constraints = 'name', claim_name, {'claim_name': claim_name}
else:
raise ValueError(f"Couldn't find {self.name} because an {self.name}_id or name was not provided.")
claims, _ = await self.list(**constraints)
if len(claims) == 1:
return claims[0]
elif len(claims) > 1:
raise ValueError(
f"Multiple {self.name}s found with {key} '{value}', "
f"pass a {self.name}_id to narrow it down."
)
raise ValueError(f"Couldn't find {self.name} with {key} '{value}'.")
async def get_or_none(self, claim_id=None, claim_name=None, txid=None, nout=None) -> Optional[Output]:
if any((claim_id, claim_name, all((txid, nout)))):
return await self.get(claim_id, claim_name, txid, nout)
class StreamListManager(ClaimListManager):
__slots__ = ()
async def create(self, *args, **kwargs):
return await super().create(*args, **kwargs)
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_streams(wallet=self.wallet, **constraints)
class CollectionListManager(ClaimListManager):
__slots__ = ()
async def create(self, *args, **kwargs):
return await super().create(*args, **kwargs)
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_collections(wallet=self.wallet, **constraints)
class ChannelListManager(ClaimListManager):
name = 'channel'
__slots__ = ()
async def create(self, name: str, amount: int, account: Account, funding_accounts: List[Account],
claim_address: str, preview=False, **kwargs):
claim = Claim()
claim.channel.update(**kwargs)
tx = await super().create(
name, claim, amount, claim_address, funding_accounts, funding_accounts[0]
)
txo = tx.outputs[0]
txo.generate_channel_private_key()
await self.wallet.sign(tx)
if not preview:
account.add_channel_private_key(txo.private_key)
await self.wallet.save()
return tx
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_channels(wallet=self.wallet, **constraints)
async def get_for_signing(self, **kwargs) -> Output:
channel = await self.get(**kwargs)
if not channel.has_private_key:
raise Exception(
f"Couldn't find private key for channel '{channel.claim_name}', can't use channel for signing. "
)
return channel
async def get_for_signing_or_none(self, **kwargs) -> Optional[Output]:
if any(kwargs.values()):
return await self.get_for_signing(**kwargs)
class SupportListManager(BaseListManager):
__slots__ = ()
async def create(self, name: str, claim_id: str, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account) -> Transaction:
support_output = Output.pay_support_pubkey_hash(
amount, name, claim_id, self.wallet.ledger.address_to_hash160(holding_address)
)
return await self.wallet.create_transaction(
[], [support_output], funding_accounts, change_account
)
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_supports(**constraints)
async def get(self, **constraints) -> Output:
raise NotImplementedError
async def get_or_none(self, **constraints) -> Optional[Output]:
raise NotImplementedError
class PurchaseListManager(BaseListManager):
__slots__ = ()
async def create(self, name: str, claim_id: str, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account) -> Transaction:
support_output = Output.pay_support_pubkey_hash(
amount, name, claim_id, self.wallet.ledger.address_to_hash160(holding_address)
)
return await self.wallet.create_transaction(
[], [support_output], funding_accounts, change_account
) )
return cls.create([], [support_output], funding_accounts, change_account)
def purchase(self, claim_id: str, amount: int, merchant_address: bytes, def purchase(self, claim_id: str, amount: int, merchant_address: bytes,
funding_accounts: List['Account'], change_account: 'Account'): funding_accounts: List['Account'], change_account: 'Account'):
@ -470,84 +644,120 @@ class Wallet:
txo.claim_id, fee_amount, fee_address, accounts, accounts[0] txo.claim_id, fee_amount, fee_address, accounts, accounts[0]
) )
async def create_channel( async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
self, name, amount, account, funding_accounts, return await self.db.get_purchases(**constraints)
claim_address, preview=False, **kwargs):
claim = Claim() async def get(self, **constraints) -> Output:
claim.channel.update(**kwargs) raise NotImplementedError
tx = await self.claim_create(
name, claim, amount, claim_address, funding_accounts, funding_accounts[0]
)
txo = tx.outputs[0]
txo.generate_channel_private_key()
await self.sign(tx) async def get_or_none(self, **constraints) -> Optional[Output]:
raise NotImplementedError
if not preview:
account.add_channel_private_key(txo.private_key)
self.save()
return tx
async def get_channels(self):
return await self.db.get_channels()
class WalletStorage: def txs_to_dict(txs, ledger):
history = []
LATEST_VERSION = 1 for tx in txs: # pylint: disable=too-many-nested-blocks
ts = headers.estimated_timestamp(tx.height)
def __init__(self, path=None, default=None): item = {
self.path = path 'txid': tx.id,
self._default = default or { 'timestamp': ts,
'version': self.LATEST_VERSION, 'date': datetime.fromtimestamp(ts).isoformat(' ')[:-3] if tx.height > 0 else None,
'name': 'My Wallet', 'confirmations': (headers.height + 1) - tx.height if tx.height > 0 else 0,
'preferences': {}, 'claim_info': [],
'accounts': [] 'update_info': [],
'support_info': [],
'abandon_info': [],
'purchase_info': []
} }
is_my_inputs = all([txi.is_my_input for txi in tx.inputs])
def read(self): if is_my_inputs:
if self.path and os.path.exists(self.path): # fees only matter if we are the ones paying them
with open(self.path, 'r') as f: item['value'] = dewies_to_lbc(tx.net_account_balance + tx.fee)
json_data = f.read() item['fee'] = dewies_to_lbc(-tx.fee)
json_dict = json.loads(json_data)
if json_dict.get('version') == self.LATEST_VERSION and \
set(json_dict) == set(self._default):
return json_dict
else: else:
return self.upgrade(json_dict) # someone else paid the fees
else: item['value'] = dewies_to_lbc(tx.net_account_balance)
return self._default.copy() item['fee'] = '0.0'
for txo in tx.my_claim_outputs:
def upgrade(self, json_dict): item['claim_info'].append({
json_dict = json_dict.copy() 'address': txo.get_address(self.ledger),
version = json_dict.pop('version', -1) 'balance_delta': dewies_to_lbc(-txo.amount),
if version == -1: 'amount': dewies_to_lbc(txo.amount),
pass 'claim_id': txo.claim_id,
upgraded = self._default.copy() 'claim_name': txo.claim_name,
upgraded.update(json_dict) 'nout': txo.position,
return json_dict 'is_spent': txo.is_spent,
})
def write(self, json_dict): for txo in tx.my_update_outputs:
if is_my_inputs: # updating my own claim
json_data = json.dumps(json_dict, indent=4, sort_keys=True) previous = None
if self.path is None: for txi in tx.inputs:
return json_data if txi.txo_ref.txo is not None:
other_txo = txi.txo_ref.txo
temp_path = "{}.tmp.{}".format(self.path, os.getpid()) if (other_txo.is_claim or other_txo.script.is_support_claim) \
with open(temp_path, "w") as f: and other_txo.claim_id == txo.claim_id:
f.write(json_data) previous = other_txo
f.flush() break
os.fsync(f.fileno()) if previous is not None:
item['update_info'].append({
if os.path.exists(self.path): 'address': txo.get_address(self),
mode = os.stat(self.path).st_mode 'balance_delta': dewies_to_lbc(previous.amount - txo.amount),
else: 'amount': dewies_to_lbc(txo.amount),
mode = stat.S_IREAD | stat.S_IWRITE 'claim_id': txo.claim_id,
try: 'claim_name': txo.claim_name,
os.rename(temp_path, self.path) 'nout': txo.position,
except Exception: # pylint: disable=broad-except 'is_spent': txo.is_spent,
os.remove(self.path) })
os.rename(temp_path, self.path) else: # someone sent us their claim
os.chmod(self.path, mode) item['update_info'].append({
'address': txo.get_address(self),
'balance_delta': dewies_to_lbc(0),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'nout': txo.position,
'is_spent': txo.is_spent,
})
for txo in tx.my_support_outputs:
item['support_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'is_tip': not is_my_inputs,
'nout': txo.position,
'is_spent': txo.is_spent,
})
if is_my_inputs:
for txo in tx.other_support_outputs:
item['support_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(-txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'is_tip': is_my_inputs,
'nout': txo.position,
'is_spent': txo.is_spent,
})
for txo in tx.my_abandon_outputs:
item['abandon_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'nout': txo.position
})
for txo in tx.any_purchase_outputs:
item['purchase_info'].append({
'address': txo.get_address(self.ledger),
'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.purchased_claim_id,
'nout': txo.position,
'is_spent': txo.is_spent,
})
history.append(item)
return history

View file

@ -0,0 +1,96 @@
import os
import shutil
import tempfile
from lbry.testcase import AsyncioTestCase
from lbry.blockchain.ledger import Ledger
from lbry.wallet import WalletManager, Wallet, Account
from lbry.db import Database
from lbry.conf import Config
class TestWalletManager(AsyncioTestCase):
async def asyncSetUp(self):
self.temp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.temp_dir)
self.ledger = Ledger(Config(
wallet_dir=self.temp_dir
))
self.db = Database.from_memory(self.ledger)
async def test_ensure_path_exists(self):
wm = WalletManager(self.ledger, self.db)
self.assertFalse(os.path.exists(wm.path))
await wm.ensure_path_exists()
self.assertTrue(os.path.exists(wm.path))
async def test_load_with_default_wallet_account_progression(self):
wm = WalletManager(self.ledger, self.db)
await wm.ensure_path_exists()
# first, no defaults
self.ledger.conf.create_default_wallet = False
self.ledger.conf.create_default_account = False
await wm.load()
self.assertIsNone(wm.default)
# then, yes to default wallet but no to default account
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = False
await wm.load()
self.assertIsInstance(wm.default, Wallet)
self.assertTrue(os.path.exists(wm.default.storage.path))
self.assertIsNone(wm.default.accounts.default)
# finally, yes to all the things
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = True
await wm.load()
self.assertIsInstance(wm.default, Wallet)
self.assertIsInstance(wm.default.accounts.default, Account)
async def test_load_with_create_default_everything_upfront(self):
wm = WalletManager(self.ledger, self.db)
await wm.ensure_path_exists()
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = True
await wm.load()
self.assertIsInstance(wm.default, Wallet)
self.assertIsInstance(wm.default.accounts.default, Account)
self.assertTrue(os.path.exists(wm.default.storage.path))
async def test_load_errors(self):
_wm = WalletManager(self.ledger, self.db)
await _wm.ensure_path_exists()
await _wm.create('bar', '')
await _wm.create('foo', '')
wm = WalletManager(self.ledger, self.db)
self.ledger.conf.wallets = ['bar', 'foo', 'foo']
with self.assertLogs(level='WARN') as cm:
await wm.load()
self.assertEqual(
cm.output, [
'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo',
]
)
self.assertEqual({'bar', 'foo'}, set(wm.wallets))
async def test_creating_and_accessing_wallets(self):
wm = WalletManager(self.ledger, self.db)
await wm.ensure_path_exists()
await wm.load()
default_wallet = wm.default
self.assertEqual(default_wallet, wm['default_wallet'])
self.assertEqual(default_wallet, wm.get_or_default(None))
new_wallet = await wm.create('second', 'Second Wallet')
self.assertEqual(default_wallet, wm.default)
self.assertEqual(new_wallet, wm['second'])
self.assertEqual(new_wallet, wm.get_or_default('second'))
self.assertEqual(default_wallet, wm.get_or_default(None))
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
_ = wm['invalid']
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
wm.get_or_default('invalid')