From 8c91777e5d05ab7dd3003fb7ca9150b9ff6a137c Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 6 May 2020 10:53:31 -0400 Subject: [PATCH] refactored lbry.wallet --- lbry/service/api.py | 214 ++++----- lbry/service/base.py | 268 ++---------- lbry/testcase.py | 22 +- lbry/wallet/__init__.py | 3 + lbry/wallet/account.py | 4 +- lbry/wallet/manager.py | 195 ++++----- lbry/wallet/preferences.py | 39 ++ lbry/wallet/storage.py | 60 +++ lbry/wallet/wallet.py | 692 +++++++++++++++++++----------- tests/unit/wallet/test_manager.py | 96 +++++ 10 files changed, 883 insertions(+), 710 deletions(-) create mode 100644 lbry/wallet/preferences.py create mode 100644 lbry/wallet/storage.py create mode 100644 tests/unit/wallet/test_manager.py diff --git a/lbry/service/api.py b/lbry/service/api.py index 65e5a08a0..9b468914f 100644 --- a/lbry/service/api.py +++ b/lbry/service/api.py @@ -54,7 +54,8 @@ class API: def __init__(self, service: Service): self.service = service - self.wallet_manager = service.wallet_manager + self.wallets = service.wallets + self.ledger = service.ledger async def stop(self): """ @@ -291,7 +292,7 @@ class API: if isinstance(urls, str): urls = [urls] return await self.service.resolve( - urls, wallet=self.wallet_manager.get_wallet_or_default(wallet_id), **kwargs + urls, wallet=self.wallets.get_or_default(wallet_id), **kwargs ) async def get( @@ -317,7 +318,7 @@ class API: """ return await self.service.get( uri, file_name=file_name, download_directory=download_directory, timeout=timeout, save_file=save_file, - wallet=self.wallet_manager.get_wallet_or_default(wallet_id) + wallet=self.wallets.get_or_default(wallet_id) ) SETTINGS_DOC = """ @@ -396,7 +397,7 @@ class API: Returns: (dict) Dictionary of preference(s) """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if key: if key in wallet.preferences: return {key: wallet.preferences[key]} @@ -418,7 +419,7 @@ class API: Returns: (dict) Dictionary with key/value of new preference """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if value and isinstance(value, str) and value[0] in ('[', '{'): value = json.loads(value) wallet.preferences[key] = value @@ -444,8 +445,8 @@ class API: Returns: {Paginated[Wallet]} """ if wallet_id: - return paginate_list([self.wallet_manager.get_wallet_or_error(wallet_id)], 1, 1) - return paginate_list(self.wallet_manager.wallets, page, page_size) + return paginate_list([self.wallets.get_wallet_or_error(wallet_id)], 1, 1) + return paginate_list(self.wallets.wallets, page, page_size) async def wallet_reconnect(self): """ @@ -458,7 +459,7 @@ class API: Returns: None """ - return self.wallet_manager.reset() + return self.wallets.reset() async def wallet_create( self, wallet_id, skip_on_startup=False, create_account=False, single_key=False): @@ -478,13 +479,13 @@ class API: Returns: {Wallet} """ wallet_path = os.path.join(self.conf.wallet_dir, 'wallets', wallet_id) - for wallet in self.wallet_manager.wallets: + for wallet in self.wallets.wallets: if wallet.id == wallet_id: raise Exception(f"Wallet at path '{wallet_path}' already exists and is loaded.") if os.path.exists(wallet_path): raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.") - wallet = self.wallet_manager.import_wallet(wallet_path) + wallet = self.wallets.import_wallet(wallet_path) if not wallet.accounts and create_account: account = Account.generate( self.ledger, wallet, address_generator={ @@ -512,12 +513,12 @@ class API: Returns: {Wallet} """ wallet_path = os.path.join(self.conf.wallet_dir, 'wallets', wallet_id) - for wallet in self.wallet_manager.wallets: + for wallet in self.wallets.wallets: if wallet.id == wallet_id: raise Exception(f"Wallet at path '{wallet_path}' is already loaded.") if not os.path.exists(wallet_path): raise Exception(f"Wallet at path '{wallet_path}' was not found.") - wallet = self.wallet_manager.import_wallet(wallet_path) + wallet = self.wallets.import_wallet(wallet_path) if self.ledger.sync.network.is_connected: for account in wallet.accounts: await self.ledger.subscribe_account(account) @@ -535,8 +536,8 @@ class API: Returns: {Wallet} """ - wallet = self.wallet_manager.get_wallet_or_error(wallet_id) - self.wallet_manager.wallets.remove(wallet) + wallet = self.wallets.get_wallet_or_error(wallet_id) + self.wallets.wallets.remove(wallet) for account in wallet.accounts: await self.ledger.unsubscribe_account(account) return wallet @@ -556,7 +557,7 @@ class API: Returns: (decimal) amount of lbry credits in wallet """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) balance = await self.ledger.get_detailed_balance( accounts=wallet.accounts, confirmations=confirmations ) @@ -575,9 +576,9 @@ class API: Returns: Dictionary of wallet status information. """ - if self.wallet_manager is None: + if self.wallets is None: return {'is_encrypted': None, 'is_syncing': None, 'is_locked': None} - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) return { 'is_encrypted': wallet.is_encrypted, 'is_syncing': len(self.ledger._update_tasks) > 0, @@ -598,7 +599,7 @@ class API: Returns: (bool) true if wallet is unlocked, otherwise false """ - return self.wallet_manager.get_wallet_or_default(wallet_id).unlock(password) + return self.wallets.get_or_default(wallet_id).unlock(password) async def wallet_lock(self, wallet_id=None): """ @@ -613,7 +614,7 @@ class API: Returns: (bool) true if wallet is locked, otherwise false """ - return self.wallet_manager.get_wallet_or_default(wallet_id).lock() + return self.wallets.get_or_default(wallet_id).lock() async def wallet_decrypt(self, wallet_id=None): """ @@ -628,7 +629,7 @@ class API: Returns: (bool) true if wallet is decrypted, otherwise false """ - return self.wallet_manager.get_wallet_or_default(wallet_id).decrypt() + return self.wallets.get_or_default(wallet_id).decrypt() async def wallet_encrypt(self, new_password, wallet_id=None): """ @@ -645,7 +646,7 @@ class API: Returns: (bool) true if wallet is decrypted, otherwise false """ - return self.wallet_manager.get_wallet_or_default(wallet_id).encrypt(new_password) + return self.wallets.get_or_default(wallet_id).encrypt(new_password) async def wallet_send( self, amount, addresses, wallet_id=None, @@ -666,10 +667,10 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." - account = wallet.get_account_or_default(change_account_id) - accounts = wallet.get_accounts_or_all(funding_account_ids) + account = wallet.accounts.get_or_default(change_account_id) + accounts = wallet.accounts.get_or_all(funding_account_ids) amount = self.get_dewies_or_error("amount", amount) @@ -730,7 +731,7 @@ class API: 'confirmations': confirmations, 'show_seed': show_seed } - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if account_id: return paginate_list([await wallet.get_account_or_error(account_id).get_details(**kwargs)], 1, 1) else: @@ -754,8 +755,8 @@ class API: Returns: (decimal) amount of lbry credits in wallet """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - account = wallet.get_account_or_default(account_id) + wallet = self.wallets.get_or_default(wallet_id) + account = wallet.accounts.get_or_default(account_id) balance = await account.get_detailed_balance( confirmations=confirmations, reserved_subtotals=True, ) @@ -783,7 +784,7 @@ class API: Returns: {Account} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) account = Account.from_dict( self.ledger, wallet, { 'name': account_name, @@ -816,7 +817,7 @@ class API: Returns: {Account} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) account = Account.generate( self.ledger.ledger, wallet, account_name, { 'name': SingleKey.name if single_key else HierarchicalDeterministic.name @@ -840,7 +841,7 @@ class API: Returns: {Account} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) account = wallet.get_account_or_error(account_id) wallet.accounts.remove(account) wallet.save() @@ -872,7 +873,7 @@ class API: Returns: {Account} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) account = wallet.get_account_or_error(account_id) change_made = False @@ -921,7 +922,7 @@ class API: Returns: (map) maximum gap for change and receiving addresses """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) return wallet.get_account_or_error(account_id).get_max_gap() async def account_fund(self, to_account=None, from_account=None, amount='0.0', @@ -950,9 +951,9 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - to_account = wallet.get_account_or_default(to_account) - from_account = wallet.get_account_or_default(from_account) + wallet = self.wallets.get_or_default(wallet_id) + to_account = wallet.accounts.get_or_default(to_account) + from_account = wallet.accounts.get_or_default(from_account) amount = self.get_dewies_or_error('amount', amount) if amount else None if not isinstance(outputs, int): raise ValueError("--outputs must be an integer.") @@ -1000,7 +1001,7 @@ class API: Returns: (str) sha256 hash of wallet """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) return hexlify(wallet.hash).decode() async def sync_apply(self, password, data=None, wallet_id=None, blocking=False): @@ -1026,10 +1027,10 @@ class API: (map) sync hash and data """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) wallet_changed = False if data is not None: - added_accounts = wallet.merge(self.wallet_manager, password, data) + added_accounts = wallet.merge(self.wallets, password, data) if added_accounts and self.ledger.sync.network.is_connected: if blocking: await asyncio.wait([ @@ -1070,8 +1071,8 @@ class API: Returns: (bool) true, if address is associated with current wallet """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - account = wallet.get_account_or_default(account_id) + wallet = self.wallets.get_or_default(wallet_id) + account = wallet.accounts.get_or_default(account_id) match = await self.ledger.db.get_address(address=address, accounts=[account]) if match is not None: return True @@ -1094,7 +1095,7 @@ class API: Returns: {Paginated[Address]} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) constraints = {} if address: constraints['address'] = address @@ -1122,8 +1123,8 @@ class API: Returns: {Address} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - return wallet.get_account_or_default(account_id).receiving.get_or_create_usable_address() + wallet = self.wallets.get_or_default(wallet_id) + return wallet.accounts.get_or_default(account_id).receiving.get_or_create_usable_address() async def address_block_filters(self): return await self.service.get_block_address_filters() @@ -1175,7 +1176,7 @@ class API: Returns: {Paginated[File]} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) sort = sort or 'rowid' comparison = comparison or 'eq' paginated = paginate_list( @@ -1348,7 +1349,7 @@ class API: Returns: {Paginated[Output]} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) constraints = { "wallet": wallet, "accounts": [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, @@ -1385,9 +1386,9 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." - accounts = wallet.get_accounts_or_all(funding_account_ids) + accounts = wallet.accounts.get_or_all(funding_account_ids) txo = None if claim_id: txo = await self.ledger.get_claim_by_claim_id(accounts, claim_id, include_purchase_receipt=True) @@ -1407,7 +1408,7 @@ class API: claim = txo.claim if not claim.is_stream or not claim.stream.has_fee: raise Exception(f"Claim '{claim_id}' does not have a purchase price.") - tx = await self.wallet_manager.create_purchase_transaction( + tx = await self.wallets.create_purchase_transaction( accounts, txo, self.exchange_rate_manager, override_max_key_fee ) if not preview: @@ -1594,7 +1595,7 @@ class API: Returns: {Paginated[Output]} """ - wallet = self.wallet_manager.get_wallet_or_default(kwargs.pop('wallet_id', None)) + wallet = self.wallets.get_or_default(kwargs.pop('wallet_id', None)) if {'claim_id', 'claim_ids'}.issubset(kwargs): raise ValueError("Only 'claim_id' or 'claim_ids' is allowed, not both.") if kwargs.pop('valid_channel_signature', False): @@ -1697,16 +1698,15 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." - account = wallet.get_account_or_default(account_id) - funding_accounts = wallet.get_accounts_or_all(funding_account_ids) self.service.ledger.valid_channel_name_or_error(name) - #amount = self.get_dewies_or_error('bid', bid, positive_value=True) - amount = lbc_to_dewies(bid) + wallet = self.wallets.get_or_default(wallet_id) + assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." + account = wallet.accounts.get_or_default(account_id) + funding_accounts = wallet.accounts.get_or_all(funding_account_ids) + amount = self.ledger.get_dewies_or_error('bid', bid, positive_value=True) claim_address = await account.get_valid_receiving_address(claim_address) - existing_channels, _ = await self.service.get_channels(wallet=wallet, claim_name=name) + existing_channels, _ = await wallet.channels.list(claim_name=name) if len(existing_channels) > 0: if not allow_duplicate_name: raise Exception( @@ -1714,14 +1714,14 @@ class API: f"Use --allow-duplicate-name flag to override." ) - tx = await wallet.create_channel( + tx = await wallet.channels.create( name, amount, account, funding_accounts, claim_address, preview, **kwargs ) if not preview: await self.service.broadcast_or_release(tx, blocking) else: - await account.ledger.release_tx(tx) + await self.service.release_tx(tx) return tx @@ -1813,9 +1813,9 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." - funding_accounts = wallet.get_accounts_or_all(funding_account_ids) + funding_accounts = wallet.accounts.get_or_all(funding_account_ids) if account_id: account = wallet.get_account_or_error(account_id) accounts = [account] @@ -1903,7 +1903,7 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." if account_id: account = wallet.get_account_or_error(account_id) @@ -1985,7 +1985,7 @@ class API: Returns: (str) serialized channel private key """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) channel = await self.get_channel_or_error(wallet, account_id, channel_id, channel_name, for_signing=True) address = channel.get_address(self.ledger) public_key = await self.ledger.get_public_key_for_address(wallet, address) @@ -2014,7 +2014,7 @@ class API: Returns: (dict) Result dictionary """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) decoded = base58.b58decode(channel_data) data = json.loads(decoded) @@ -2161,7 +2161,7 @@ class API: Returns: {Transaction} """ self.valid_stream_name_or_error(name) - wallet = self.wallet_manager.get_wallet_or_default(kwargs.get('wallet_id')) + wallet = self.wallets.get_or_default(kwargs.get('wallet_id')) if kwargs.get('account_id'): accounts = [wallet.get_account_or_error(kwargs.get('account_id'))] else: @@ -2218,10 +2218,10 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) self.valid_stream_name_or_error(name) - account = wallet.get_account_or_default(account_id) - funding_accounts = wallet.get_accounts_or_all(funding_account_ids) + account = wallet.accounts.get_or_default(account_id) + funding_accounts = wallet.accounts.get_or_all(funding_account_ids) channel = await self.get_channel_or_none(wallet, channel_account_id, channel_id, channel_name, for_signing=True) amount = self.get_dewies_or_error('bid', bid, positive_value=True) claim_address = await self.get_receiving_address(claim_address, account) @@ -2359,17 +2359,17 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + self.ledger.valid_stream_name_or_error(name) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." - self.valid_stream_name_or_error(name) - account = wallet.get_account_or_default(account_id) - funding_accounts = wallet.get_accounts_or_all(funding_account_ids) - channel = await self.get_channel_or_none(wallet, channel_account_id, channel_id, channel_name, for_signing=True) - amount = self.get_dewies_or_error('bid', bid, positive_value=True) - claim_address = await self.get_receiving_address(claim_address, account) - kwargs['fee_address'] = self.get_fee_address(kwargs, claim_address) + account = wallet.accounts.get_or_default(account_id) + funding_accounts = wallet.accounts.get_or_all(funding_account_ids) + channel = await wallet.channels.get_for_signing_or_none(claim_id=channel_id, claim_name=channel_name) + amount = self.ledger.get_dewies_or_error('bid', bid, positive_value=True) + claim_address = await account.get_valid_receiving_address(claim_address) + kwargs['fee_address'] = self.ledger.get_fee_address(kwargs, claim_address) - claims = await account.get_claims(claim_name=name) + claims = await wallet.streams.list(claim_name=name) if len(claims) > 0: if not allow_duplicate_name: raise Exception( @@ -2377,11 +2377,15 @@ class API: f"Use --allow-duplicate-name flag to override." ) - file_path, spec = await self._video_file_analyzer.verify_or_repair( - validate_file, optimize_file, file_path, ignore_non_video=True - ) - kwargs.update(spec) + # TODO: fix + #file_path, spec = await self._video_file_analyzer.verify_or_repair( + # validate_file, optimize_file, file_path, ignore_non_video=True + #) + #kwargs.update(spec) + wallet.streams.create( + + ) claim = Claim() claim.stream.update(file_path=file_path, sd_hash='0' * 96, **kwargs) tx = await Transaction.claim_create( @@ -2529,9 +2533,9 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." - funding_accounts = wallet.get_accounts_or_all(funding_account_ids) + funding_accounts = wallet.accounts.get_or_all(funding_account_ids) if account_id: account = wallet.get_account_or_error(account_id) accounts = [account] @@ -2651,7 +2655,7 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." if account_id: account = wallet.get_account_or_error(account_id) @@ -2813,9 +2817,9 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - account = wallet.get_account_or_default(account_id) - funding_accounts = wallet.get_accounts_or_all(funding_account_ids) + wallet = self.wallets.get_or_default(wallet_id) + account = wallet.accounts.get_or_default(account_id) + funding_accounts = wallet.accounts.get_or_all(funding_account_ids) self.valid_collection_name_or_error(name) channel = await self.get_channel_or_none(wallet, channel_account_id, channel_id, channel_name, for_signing=True) amount = self.get_dewies_or_error('bid', bid, positive_value=True) @@ -2933,8 +2937,8 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - funding_accounts = wallet.get_accounts_or_all(funding_account_ids) + wallet = self.wallets.get_or_default(wallet_id) + funding_accounts = wallet.accounts.get_or_all(funding_account_ids) if account_id: account = wallet.get_account_or_error(account_id) accounts = [account] @@ -3041,7 +3045,7 @@ class API: Returns: {Paginated[Output]} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if account_id: account = wallet.get_account_or_error(account_id) collections = account.get_collections @@ -3069,7 +3073,7 @@ class API: Returns: {Paginated[Output]} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if claim_id: txo = await self.ledger.get_claim_by_claim_id(wallet.accounts, claim_id) @@ -3121,14 +3125,14 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." - funding_accounts = wallet.get_accounts_or_all(funding_account_ids) - amount = self.get_dewies_or_error("amount", amount) + funding_accounts = wallet.accounts.get_or_all(funding_account_ids) + amount = self.ledger.get_dewies_or_error("amount", amount) claim = await self.ledger.get_claim_by_claim_id(wallet.accounts, claim_id) claim_address = claim.get_address(self.ledger.ledger) if not tip: - account = wallet.get_account_or_default(account_id) + account = wallet.accounts.get_or_default(account_id) claim_address = await account.receiving.get_or_create_usable_address() tx = await Transaction.support( @@ -3217,7 +3221,7 @@ class API: Returns: {Transaction} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) assert not wallet.is_locked, "Cannot spend funds with locked wallet, unlock first." if account_id: account = wallet.get_account_or_error(account_id) @@ -3329,7 +3333,7 @@ class API: } """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if account_id: account = wallet.get_account_or_error(account_id) transactions = account.get_transaction_history @@ -3441,7 +3445,7 @@ class API: Returns: {Paginated[Output]} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if account_id: account = wallet.get_account_or_error(account_id) claims = account.get_txos @@ -3501,7 +3505,7 @@ class API: Returns: {List[Transaction]} """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) accounts = [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts txos = await self.ledger.get_txos( wallet=wallet, accounts=accounts, @@ -3559,7 +3563,7 @@ class API: Returns: int """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) return self.ledger.get_txo_sum( wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, **self._constrain_txo_from_kwargs({}, **kwargs) @@ -3614,7 +3618,7 @@ class API: Returns: List[Dict] """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) plot = await self.ledger.get_txo_plot( wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, days_back=days_back, start_day=start_day, days_after=days_after, end_day=end_day, @@ -3666,7 +3670,7 @@ class API: Returns: None """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if account_id is not None: await wallet.get_account_or_error(account_id).release_all_outputs() else: @@ -4160,7 +4164,7 @@ class API: "timestamp": (int) The time at which comment was entered into the server at, in nanoseconds. } """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) channel = await self.get_channel_or_error( wallet, channel_account_id, channel_id, channel_name, for_signing=True ) @@ -4216,7 +4220,7 @@ class API: if 'error' in channel: raise ValueError(channel['error']) - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) # channel = await self.get_channel_or_none(wallet, None, **channel) channel_claim = await self.get_channel_or_error(wallet, [], **channel) edited_comment = { @@ -4249,7 +4253,7 @@ class API: } } """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) abandon_comment_body = {'comment_id': comment_id} channel = await comment_client.post( self.conf.comment_server, 'get_channel_from_comment_id', comment_id=comment_id @@ -4281,7 +4285,7 @@ class API: "hidden": (bool) flag indicating if comment_id was hidden } """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) + wallet = self.wallets.get_or_default(wallet_id) if isinstance(comment_ids, str): comment_ids = [comment_ids] diff --git a/lbry/service/base.py b/lbry/service/base.py index 81288a92d..8e58baa38 100644 --- a/lbry/service/base.py +++ b/lbry/service/base.py @@ -1,18 +1,14 @@ import os import asyncio import logging -from datetime import datetime -from typing import Iterable, List, Optional, Tuple, NamedTuple +from typing import List, Optional, Tuple, NamedTuple from lbry.db import Database from lbry.db.constants import TXO_TYPES from lbry.schema.result import Censor -from lbry.blockchain.dewies import dewies_to_lbc from lbry.blockchain.transaction import Transaction, Output from lbry.blockchain.ledger import Ledger -from lbry.crypto.bip32 import PubKey, PrivateKey -from lbry.wallet.account import Account, AddressManager, SingleKey -from lbry.wallet.manager import WalletManager +from lbry.wallet import WalletManager, AddressManager from lbry.event import EventController log = logging.getLogger(__name__) @@ -62,7 +58,7 @@ class Service: def __init__(self, ledger: Ledger, db_url: str): self.ledger, self.conf = ledger, ledger.conf self.db = Database(ledger, db_url) - self.wallet_manager = WalletManager(ledger, self.db) + self.wallets = WalletManager(ledger, self.db) #self.on_address = sync.on_address #self.accounts = sync.accounts @@ -77,7 +73,8 @@ class Service: async def start(self): await self.db.open() - await self.wallet_manager.open() + await self.wallets.ensure_path_exists() + await self.wallets.load() await self.sync.start() async def stop(self): @@ -106,30 +103,9 @@ class Service: path = os.path.join(self.conf.wallet_dir, file_name) return self.wallet_manager.import_wallet(path) - def add_account(self, account: Account): - self.ledger.add_account(account) - - async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]: - return await self.ledger.get_private_key_for_address(wallet, address) - - async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]: - return await self.ledger.get_public_key_for_address(wallet, address) - - async def get_account_for_address(self, wallet, address): - return await self.ledger.get_account_for_address(wallet, address) - - async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]): - return await self.ledger.get_effective_amount_estimators(funding_accounts) - async def get_addresses(self, **constraints): return await self.db.get_addresses(**constraints) - def get_address_count(self, **constraints): - return self.db.get_address_count(**constraints) - - async def get_spendable_utxos(self, amount: int, funding_accounts): - return await self.ledger.get_spendable_utxos(amount, funding_accounts) - def reserve_outputs(self, txos): return self.db.reserve_outputs(txos) @@ -143,19 +119,12 @@ class Service: self.constraint_spending_utxos(constraints) return self.db.get_utxos(**constraints) - def get_utxo_count(self, **constraints): - self.constraint_spending_utxos(constraints) - return self.db.get_utxo_count(**constraints) - async def get_txos(self, resolve=False, **constraints) -> List[Output]: txos = await self.db.get_txos(**constraints) if resolve: return await self._resolve_for_local_results(constraints.get('accounts', []), txos) return txos - def get_txo_count(self, **constraints): - return self.db.get_txo_count(**constraints) - def get_txo_sum(self, **constraints): return self.db.get_txo_sum(**constraints) @@ -165,8 +134,21 @@ class Service: def get_transactions(self, **constraints): return self.db.get_transactions(**constraints) - def get_transaction_count(self, **constraints): - return self.db.get_transaction_count(**constraints) + async def get_transaction(self, tx_hash: bytes): + tx = await self.db.get_transaction(tx_hash=tx_hash) + if tx: + return tx + try: + raw, merkle = await self.ledger.network.get_transaction_and_merkle(tx_hash) + except CodeMessageError as e: + if 'No such mempool or blockchain transaction.' in e.message: + return {'success': False, 'code': 404, 'message': 'transaction not found'} + return {'success': False, 'code': e.code, 'message': e.message} + height = merkle.get('block_height') + tx = Transaction(unhexlify(raw), height=height) + if height and height > 0: + await self.ledger.maybe_verify_transaction(tx, height, merkle) + return tx async def search_transactions(self, txids): raise NotImplementedError @@ -181,6 +163,20 @@ class Service: return account.address_managers[details['chain']] return None + async def reset(self): + self.ledger.config = { + 'auto_connect': True, + 'default_servers': self.config.lbryum_servers, + 'data_path': self.config.wallet_dir, + } + await self.ledger.stop() + await self.ledger.start() + + async def get_best_blockhash(self): + if len(self.ledger.headers) <= 0: + return self.ledger.genesis_hash + return (await self.ledger.headers.hash(self.ledger.headers.height)).decode() + async def broadcast_or_release(self, tx, blocking=False): try: await self.broadcast(tx) @@ -206,39 +202,12 @@ class Service: for claim in (await self.search_claims(accounts, claim_id=claim_id, **kwargs))[0]: return claim - async def _report_state(self): - try: - for account in self.accounts: - balance = dewies_to_lbc(await account.get_balance(include_claims=True)) - channel_count = await account.get_channel_count() - claim_count = await account.get_claim_count() - if isinstance(account.receiving, SingleKey): - log.info("Loaded single key account %s with %s LBC. " - "%d channels, %d certificates and %d claims", - account.id, balance, channel_count, len(account.channel_keys), claim_count) - else: - total_receiving = len(await account.receiving.get_addresses()) - total_change = len(await account.change.get_addresses()) - log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), " - "%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ", - account.id, balance, total_receiving, account.receiving.gap, total_change, - account.change.gap, channel_count, len(account.channel_keys), claim_count) - except Exception as err: - if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 - raise - log.exception( - 'Failed to display wallet state, please file issue ' - 'for this bug along with the traceback you see below:') - - async def _reset_balance_cache(self, e): - return await self.ledger._reset_balance_cache(e) - @staticmethod def constraint_spending_utxos(constraints): constraints['txo_type__in'] = (0, TXO_TYPES['purchase']) - async def get_purchases(self, resolve=False, **constraints): - purchases = await self.db.get_purchases(**constraints) + async def get_purchases(self, wallet, resolve=False, **constraints): + purchases = await wallet.get_purchases(**constraints) if resolve: claim_ids = [p.purchased_claim_id for p in purchases] try: @@ -253,9 +222,6 @@ class Service: purchase.purchased_claim = lookup.get(purchase.purchased_claim_id) return purchases - def get_purchase_count(self, resolve=False, **constraints): - return self.db.get_purchase_count(**constraints) - async def _resolve_for_local_results(self, accounts, txos): results = [] response = await self.resolve( @@ -272,33 +238,6 @@ class Service: results.append(txo) return results - async def get_claims(self, resolve=False, **constraints): - claims = await self.db.get_claims(**constraints) - if resolve: - return await self._resolve_for_local_results(constraints.get('accounts', []), claims) - return claims - - def get_claim_count(self, **constraints): - return self.db.get_claim_count(**constraints) - - async def get_streams(self, resolve=False, **constraints): - streams = await self.db.get_streams(**constraints) - if resolve: - return await self._resolve_for_local_results(constraints.get('accounts', []), streams) - return streams - - def get_stream_count(self, **constraints): - return self.db.get_stream_count(**constraints) - - async def get_channels(self, resolve=False, **constraints): - channels = await self.db.get_channels(**constraints) - if resolve: - return await self._resolve_for_local_results(constraints.get('accounts', []), channels) - return channels - - def get_channel_count(self, **constraints): - return self.db.get_channel_count(**constraints) - async def resolve_collection(self, collection, offset=0, page_size=1): claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset] try: @@ -319,138 +258,3 @@ class Service: if not found: claims.append(None) return claims - - async def get_collections(self, resolve_claims=0, **constraints): - collections = await self.db.get_collections(**constraints) - if resolve_claims > 0: - for collection in collections: - collection.claims = await self.resolve_collection(collection, page_size=resolve_claims) - return collections - - def get_collection_count(self, resolve_claims=0, **constraints): - return self.db.get_collection_count(**constraints) - - def get_supports(self, **constraints): - return self.db.get_supports(**constraints) - - def get_support_count(self, **constraints): - return self.db.get_support_count(**constraints) - - async def get_transaction_history(self, **constraints): - txs: List[Transaction] = await self.db.get_transactions( - include_is_my_output=True, include_is_spent=True, - **constraints - ) - headers = self.headers - history = [] - for tx in txs: # pylint: disable=too-many-nested-blocks - ts = headers.estimated_timestamp(tx.height) - item = { - 'txid': tx.id, - 'timestamp': ts, - 'date': datetime.fromtimestamp(ts).isoformat(' ')[:-3] if tx.height > 0 else None, - 'confirmations': (headers.height+1) - tx.height if tx.height > 0 else 0, - 'claim_info': [], - 'update_info': [], - 'support_info': [], - 'abandon_info': [], - 'purchase_info': [] - } - is_my_inputs = all([txi.is_my_input for txi in tx.inputs]) - if is_my_inputs: - # fees only matter if we are the ones paying them - item['value'] = dewies_to_lbc(tx.net_account_balance+tx.fee) - item['fee'] = dewies_to_lbc(-tx.fee) - else: - # someone else paid the fees - item['value'] = dewies_to_lbc(tx.net_account_balance) - item['fee'] = '0.0' - for txo in tx.my_claim_outputs: - item['claim_info'].append({ - 'address': txo.get_address(self.ledger), - 'balance_delta': dewies_to_lbc(-txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - for txo in tx.my_update_outputs: - if is_my_inputs: # updating my own claim - previous = None - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - other_txo = txi.txo_ref.txo - if (other_txo.is_claim or other_txo.script.is_support_claim) \ - and other_txo.claim_id == txo.claim_id: - previous = other_txo - break - if previous is not None: - item['update_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(previous.amount-txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - else: # someone sent us their claim - item['update_info'].append({ - 'address': txo.get_address(self), - 'balance_delta': dewies_to_lbc(0), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - for txo in tx.my_support_outputs: - item['support_info'].append({ - 'address': txo.get_address(self.ledger), - 'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'is_tip': not is_my_inputs, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - if is_my_inputs: - for txo in tx.other_support_outputs: - item['support_info'].append({ - 'address': txo.get_address(self.ledger), - 'balance_delta': dewies_to_lbc(-txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'is_tip': is_my_inputs, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - for txo in tx.my_abandon_outputs: - item['abandon_info'].append({ - 'address': txo.get_address(self.ledger), - 'balance_delta': dewies_to_lbc(txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.claim_id, - 'claim_name': txo.claim_name, - 'nout': txo.position - }) - for txo in tx.any_purchase_outputs: - item['purchase_info'].append({ - 'address': txo.get_address(self.ledger), - 'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount), - 'amount': dewies_to_lbc(txo.amount), - 'claim_id': txo.purchased_claim_id, - 'nout': txo.position, - 'is_spent': txo.is_spent, - }) - history.append(item) - return history - - def get_transaction_history_count(self, **constraints): - return self.db.get_transaction_count(**constraints) - - async def get_detailed_balance(self, accounts, confirmations=0): - return self.ledger.get_detailed_balance(accounts, confirmations) diff --git a/lbry/testcase.py b/lbry/testcase.py index 9cf394dc9..8bb7e252e 100644 --- a/lbry/testcase.py +++ b/lbry/testcase.py @@ -14,16 +14,15 @@ from time import time from binascii import unhexlify from functools import partial +from lbry.wallet import WalletManager, Wallet, Account from lbry.blockchain.ledger import Ledger from lbry.blockchain.transaction import Transaction, Input, Output from lbry.blockchain.util import satoshis_to_coins -from lbry.constants import CENT, NULL_HASH32 -from lbry.wallet.wallet import Wallet, Account -from lbry.wallet.manager import WalletManager -from lbry.conf import Config from lbry.blockchain.lbrycrd import Lbrycrd +from lbry.constants import CENT, NULL_HASH32 from lbry.service.full_node import FullNode from lbry.service.daemon import Daemon +from lbry.conf import Config from lbry.extras.daemon.daemon import jsonrpc_dumps_pretty from lbry.extras.daemon.components import Component, WalletComponent @@ -363,7 +362,7 @@ class CommandTestCase(IntegrationTestCase): self.block_expected = 0 await self.generate(200, wait=False) - self.chain.ledger.conf.spv_address_filters = False + self.ledger.conf.spv_address_filters = False self.service = FullNode( self.ledger, f'sqlite:///{self.chain.data_dir}/full_node.db', Lbrycrd(self.ledger) ) @@ -372,10 +371,13 @@ class CommandTestCase(IntegrationTestCase): self.addCleanup(self.daemon.stop) await self.daemon.start() - self.wallet = self.service.wallet_manager.default_wallet - self.account = self.wallet.accounts[0] + self.wallet = self.service.wallets.default + self.account = self.wallet.accounts.default addresses = await self.account.ensure_address_gap() + self.ledger.conf.upload_dir = os.path.join(self.ledger.conf.data_dir, 'uploads') + os.mkdir(self.ledger.conf.upload_dir) + await self.chain.send_to_address(addresses[0], '10.0') await self.generate(5) @@ -527,7 +529,9 @@ class CommandTestCase(IntegrationTestCase): return self.sout(tx) def create_upload_file(self, data, prefix=None, suffix=None): - file_path = tempfile.mktemp(prefix=prefix or "tmp", suffix=suffix or "", dir=self.daemon.conf.upload_dir) + file_path = tempfile.mktemp( + prefix=prefix or "tmp", suffix=suffix or "", dir=self.ledger.conf.upload_dir + ) with open(file_path, 'w+b') as file: file.write(data) file.flush() @@ -539,7 +543,7 @@ class CommandTestCase(IntegrationTestCase): if file_path is None: file_path = self.create_upload_file(data=data, prefix=prefix, suffix=suffix) return await self.confirm_and_render( - self.daemon.jsonrpc_stream_create(name, bid, file_path=file_path, **kwargs), confirm + self.api.stream_create(name, bid, file_path=file_path, **kwargs), confirm ) async def stream_update( diff --git a/lbry/wallet/__init__.py b/lbry/wallet/__init__.py index e69de29bb..266bec4ad 100644 --- a/lbry/wallet/__init__.py +++ b/lbry/wallet/__init__.py @@ -0,0 +1,3 @@ +from .account import Account, AddressManager, SingleKey +from .wallet import Wallet +from .manager import WalletManager diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index 31fa614de..30bbba51e 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -9,12 +9,10 @@ from hashlib import sha256 from string import hexdigits from typing import Type, Dict, Tuple, Optional, Any, List -from sqlalchemy import text - import ecdsa + from lbry.error import InvalidPasswordError from lbry.crypto.crypt import aes_encrypt, aes_decrypt - from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string from lbry.constants import COIN from lbry.blockchain.transaction import Transaction, Input, Output diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index 57a2b2d62..cd40150da 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -1,147 +1,102 @@ import os -import json -import typing -import logging import asyncio -from binascii import unhexlify -from decimal import Decimal -from typing import List, Type, MutableSequence, MutableMapping, Optional +import logging +from typing import Optional, Dict -from lbry.error import KeyFeeAboveMaxAllowedError -from lbry.conf import Config - -from .account import Account -from lbry.blockchain.dewies import dewies_to_lbc -from lbry.blockchain.ledger import Ledger from lbry.db import Database from lbry.blockchain.ledger import Ledger -from lbry.blockchain.transaction import Transaction, Output - -from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK +from .wallet import Wallet log = logging.getLogger(__name__) class WalletManager: - def __init__(self, ledger: Ledger, db: Database, - wallets: MutableSequence[Wallet] = None, - ledgers: MutableMapping[Type[Ledger], Ledger] = None) -> None: + def __init__(self, ledger: Ledger, db: Database): self.ledger = ledger self.db = db - self.wallets = wallets or [] - self.ledgers = ledgers or {} - self.running = False - self.config: Optional[Config] = None + self.wallets: Dict[str, Wallet] = {} - async def open(self): - conf = self.ledger.conf - - wallets_directory = os.path.join(conf.wallet_dir, 'wallets') - if not os.path.exists(wallets_directory): - os.mkdir(wallets_directory) - - for wallet_file in conf.wallets: - wallet_path = os.path.join(wallets_directory, wallet_file) - wallet_storage = WalletStorage(wallet_path) - wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage) - self.wallets.append(wallet) - - self.ledger.coin_selection_strategy = self.ledger.conf.coin_selection_strategy - default_wallet = self.default_wallet - if default_wallet.default_account is None: - log.info('Wallet at %s is empty, generating a default account.', default_wallet.id) - default_wallet.generate_account() - default_wallet.save() - if default_wallet.is_locked and default_wallet.preferences.get(ENCRYPT_ON_DISK) is None: - default_wallet.preferences[ENCRYPT_ON_DISK] = True - default_wallet.save() - - def import_wallet(self, path): - storage = WalletStorage(path) - wallet = Wallet.from_storage(self.ledger, self.db, storage) - self.wallets.append(wallet) - return wallet + def __getitem__(self, wallet_id: str) -> Wallet: + try: + return self.wallets[wallet_id] + except KeyError: + raise ValueError(f"Couldn't find wallet: {wallet_id}.") @property - def default_wallet(self): - for wallet in self.wallets: + def default(self) -> Optional[Wallet]: + for wallet in self.wallets.values(): return wallet - @property - def default_account(self): - for wallet in self.wallets: - return wallet.default_account + def get_or_default(self, wallet_id: Optional[str]) -> Optional[Wallet]: + if wallet_id: + return self[wallet_id] + return self.default @property - def accounts(self): - for wallet in self.wallets: - yield from wallet.accounts + def path(self): + return os.path.join(self.ledger.conf.wallet_dir, 'wallets') - async def start(self): - self.running = True - await asyncio.gather(*( - l.start() for l in self.ledgers.values() - )) + def sync_ensure_path_exists(self): + if not os.path.exists(self.path): + os.mkdir(self.path) - async def stop(self): - await asyncio.gather(*( - l.stop() for l in self.ledgers.values() - )) - self.running = False + async def ensure_path_exists(self): + await asyncio.get_running_loop().run_in_executor( + None, self.sync_ensure_path_exists + ) - def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet: - if wallet_id is None: - return self.default_wallet - return self.get_wallet_or_error(wallet_id) + async def load(self): + wallets_directory = self.path + for wallet_id in self.ledger.conf.wallets: + if wallet_id in self.wallets: + log.warning(f"Ignoring duplicate wallet_id in config: {wallet_id}") + continue + wallet_path = os.path.join(wallets_directory, wallet_id) + if not os.path.exists(wallet_path): + if not wallet_id == "default_wallet": # we'll probably generate this wallet, don't show error + log.error(f"Could not load wallet, file does not exist: {wallet_path}") + continue + wallet = await Wallet.from_path(self.ledger, self.db, wallet_path) + self.add(wallet) + default_wallet = self.default + if default_wallet is None: + if self.ledger.conf.create_default_wallet: + assert self.ledger.conf.wallets[0] == "default_wallet", ( + "Requesting to generate the default wallet but the 'wallets' " + "config setting does not include 'default_wallet' as the first wallet." + ) + await self.create( + self.ledger.conf.wallets[0], 'Wallet', + create_account=self.ledger.conf.create_default_account + ) + elif not default_wallet.has_accounts and self.ledger.conf.create_default_account: + default_wallet.accounts.generate() - def get_wallet_or_error(self, wallet_id: str) -> Wallet: - for wallet in self.wallets: - if wallet.id == wallet_id: - return wallet - raise ValueError(f"Couldn't find wallet: {wallet_id}.") + def add(self, wallet: Wallet) -> Wallet: + self.wallets[wallet.id] = wallet + return wallet - @staticmethod - def get_balance(wallet): - accounts = wallet.accounts - if not accounts: - return 0 - return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts) - - def check_locked(self): - return self.default_wallet.is_locked - - async def reset(self): - self.ledger.config = { - 'auto_connect': True, - 'default_servers': self.config.lbryum_servers, - 'data_path': self.config.wallet_dir, - } - await self.ledger.stop() - await self.ledger.start() - - async def get_best_blockhash(self): - if len(self.ledger.headers) <= 0: - return self.ledger.genesis_hash - return (await self.ledger.headers.hash(self.ledger.headers.height)).decode() - - def get_unused_address(self): - return self.default_account.receiving.get_or_create_usable_address() - - async def get_transaction(self, tx_hash: bytes): - tx = await self.db.get_transaction(tx_hash=tx_hash) - if tx: - return tx - try: - raw, merkle = await self.ledger.network.get_transaction_and_merkle(tx_hash) - except CodeMessageError as e: - if 'No such mempool or blockchain transaction.' in e.message: - return {'success': False, 'code': 404, 'message': 'transaction not found'} - return {'success': False, 'code': e.code, 'message': e.message} - height = merkle.get('block_height') - tx = Transaction(unhexlify(raw), height=height) - if height and height > 0: - await self.ledger.maybe_verify_transaction(tx, height, merkle) - return tx + async def add_from_path(self, wallet_path) -> Wallet: + wallet_id = os.path.basename(wallet_path) + if wallet_id in self.wallets: + existing = self.wallets.get(wallet_id) + if existing.storage.path == wallet_path: + raise Exception(f"Wallet '{wallet_id}' is already loaded.") + raise Exception( + f"Wallet '{wallet_id}' is already loaded from '{existing.storage.path}'" + f" and cannot be loaded from '{wallet_path}'. Consider changing the wallet" + f" filename to be unique in order to avoid conflicts." + ) + wallet = await Wallet.from_path(self.ledger, self.db, wallet_path) + return self.add(wallet) + async def create(self, wallet_id: str, name: str, create_account=False, single_key=False) -> Wallet: + if wallet_id in self.wallets: + raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.") + wallet_path = os.path.join(self.path, wallet_id) + if os.path.exists(wallet_path): + raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.") + wallet = await Wallet.create(self.ledger, self.db, wallet_path, name, create_account, single_key) + return self.add(wallet) diff --git a/lbry/wallet/preferences.py b/lbry/wallet/preferences.py new file mode 100644 index 000000000..8e526cb18 --- /dev/null +++ b/lbry/wallet/preferences.py @@ -0,0 +1,39 @@ +import time +import json +from hashlib import sha256 +from collections import UserDict + + +class TimestampedPreferences(UserDict): + + def __init__(self, d: dict = None): + super().__init__() + if d is not None: + self.data = d.copy() + + def __getitem__(self, key): + return self.data[key]['value'] + + def __setitem__(self, key, value): + self.data[key] = { + 'value': value, + 'ts': time.time() + } + + def __repr__(self): + return repr(self.to_dict_without_ts()) + + def to_dict_without_ts(self): + return { + key: value['value'] for key, value in self.data.items() + } + + @property + def hash(self): + return sha256(json.dumps(self.data).encode()).digest() + + def merge(self, other: dict): + for key, value in other.items(): + if key in self.data and value['ts'] < self.data[key]['ts']: + continue + self.data[key] = value diff --git a/lbry/wallet/storage.py b/lbry/wallet/storage.py new file mode 100644 index 000000000..3859bb204 --- /dev/null +++ b/lbry/wallet/storage.py @@ -0,0 +1,60 @@ +import os +import stat +import json +import asyncio + + +class WalletStorage: + VERSION = 1 + + def __init__(self, path=None): + self.path = path + + def sync_read(self): + with open(self.path, 'r') as f: + json_data = f.read() + json_dict = json.loads(json_data) + if json_dict.get('version') == self.VERSION: + return json_dict + else: + return self.upgrade(json_dict) + + async def read(self): + return await asyncio.get_running_loop().run_in_executor( + None, self.sync_read + ) + + def upgrade(self, json_dict): + version = json_dict.pop('version', -1) + if version == -1: + pass + json_dict['version'] = self.VERSION + return json_dict + + def sync_write(self, json_dict): + + json_data = json.dumps(json_dict, indent=4, sort_keys=True) + if self.path is None: + return json_data + + temp_path = "{}.tmp.{}".format(self.path, os.getpid()) + with open(temp_path, "w") as f: + f.write(json_data) + f.flush() + os.fsync(f.fileno()) + + if os.path.exists(self.path): + mode = os.stat(self.path).st_mode + else: + mode = stat.S_IREAD | stat.S_IWRITE + try: + os.rename(temp_path, self.path) + except Exception: # pylint: disable=broad-except + os.remove(self.path) + os.rename(temp_path, self.path) + os.chmod(self.path, mode) + + async def write(self, json_dict): + return await asyncio.get_running_loop().run_in_executor( + None, self.sync_write, json_dict + ) diff --git a/lbry/wallet/wallet.py b/lbry/wallet/wallet.py index a065db846..c20228e24 100644 --- a/lbry/wallet/wallet.py +++ b/lbry/wallet/wallet.py @@ -1,16 +1,14 @@ import os -import time -import stat import json import zlib -import typing +import asyncio import logging -from typing import List, Sequence, MutableSequence, Optional, Iterable -from collections import UserDict +from typing import List, Sequence, Tuple, Optional, Iterable from hashlib import sha256 from operator import attrgetter from decimal import Decimal + from lbry.db import Database, SPENDABLE_TYPE_CODES from lbry.blockchain.ledger import Ledger from lbry.constants import COIN, NULL_HASH32 @@ -22,11 +20,10 @@ from lbry.schema.claim import Claim from lbry.schema.purchase import Purchase from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError -from .account import Account +from .account import Account, SingleKey, HierarchicalDeterministic from .coinselection import CoinSelector, OutputEffectiveAmountEstimator - -if typing.TYPE_CHECKING: - from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager +from .storage import WalletStorage +from .preferences import TimestampedPreferences log = logging.getLogger(__name__) @@ -34,41 +31,6 @@ log = logging.getLogger(__name__) ENCRYPT_ON_DISK = 'encrypt-on-disk' -class TimestampedPreferences(UserDict): - - def __init__(self, d: dict = None): - super().__init__() - if d is not None: - self.data = d.copy() - - def __getitem__(self, key): - return self.data[key]['value'] - - def __setitem__(self, key, value): - self.data[key] = { - 'value': value, - 'ts': time.time() - } - - def __repr__(self): - return repr(self.to_dict_without_ts()) - - def to_dict_without_ts(self): - return { - key: value['value'] for key, value in self.data.items() - } - - @property - def hash(self): - return sha256(json.dumps(self.data).encode()).digest() - - def merge(self, other: dict): - for key, value in other.items(): - if key in self.data and value['ts'] < self.data[key]['ts']: - continue - self.data[key] = value - - class Wallet: """ The primary role of Wallet is to encapsulate a collection of accounts (seed/private keys) and the spending rules / settings @@ -76,102 +38,45 @@ class Wallet: by physical files on the filesystem. """ - preferences: TimestampedPreferences - encryption_password: Optional[str] - - def __init__(self, ledger: Ledger, db: Database, - name: str = 'Wallet', accounts: MutableSequence[Account] = None, - storage: 'WalletStorage' = None, preferences: dict = None) -> None: + def __init__(self, ledger: Ledger, db: Database, name: str, storage: WalletStorage, preferences: dict): self.ledger = ledger self.db = db self.name = name - self.accounts = accounts or [] - self.storage = storage or WalletStorage() + self.storage = storage self.preferences = TimestampedPreferences(preferences or {}) - self.encryption_password = None + self.encryption_password: Optional[str] = None self.id = self.get_id() + self.utxo_lock = asyncio.Lock() + + self.accounts = AccountListManager(self) + self.claims = ClaimListManager(self) + self.streams = StreamListManager(self) + self.channels = ChannelListManager(self) + self.collections = CollectionListManager(self) + self.purchases = PurchaseListManager(self) + self.supports = SupportListManager(self) + def get_id(self): return os.path.basename(self.storage.path) if self.storage.path else self.name - def generate_account(self, name: str = None, address_generator: dict = None) -> Account: - account = Account.generate(self.ledger, self.db, name, address_generator) - self.accounts.append(account) - return account - - def add_account(self, account_dict) -> Account: - account = Account.from_dict(self.ledger, self.db, account_dict) - self.accounts.append(account) - return account - - @property - def default_account(self) -> Optional[Account]: - for account in self.accounts: - return account - return None - - def get_account_or_default(self, account_id: str) -> Optional[Account]: - if account_id is None: - return self.default_account - return self.get_account_or_error(account_id) - - def get_account_or_error(self, account_id: str) -> Account: - for account in self.accounts: - if account.id == account_id: - return account - raise ValueError(f"Couldn't find account: {account_id}.") - - def get_accounts_or_all(self, account_ids: List[str]) -> Sequence[Account]: - return [ - self.get_account_or_error(account_id) - for account_id in account_ids - ] if account_ids else self.accounts - - async def get_detailed_accounts(self, **kwargs): - accounts = [] - for i, account in enumerate(self.accounts): - details = await account.get_details(**kwargs) - details['is_default'] = i == 0 - accounts.append(details) - return accounts - - async def _get_account_and_address_info_for_address(self, address): - match = await self.db.get_address(accounts=self.accounts, address=address) - if match: - for account in self.accounts: - if match['account'] == account.public_key.address: - return account, match - - async def get_private_key_for_address(self, address) -> Optional[PrivateKey]: - match = await self._get_account_and_address_info_for_address(address) - if match: - account, address_info = match - return account.get_private_key(address_info['chain'], address_info['pubkey'].n) - return None - - async def get_public_key_for_address(self, address) -> Optional[PubKey]: - match = await self._get_account_and_address_info_for_address(address) - if match: - _, address_info = match - return address_info['pubkey'] - return None - - async def get_account_for_address(self, address): - match = await self._get_account_and_address_info_for_address(address) - if match: - return match[0] - - async def save_max_gap(self): - gap_changed = False - for account in self.accounts: - if await account.save_max_gap(): - gap_changed = True - if gap_changed: - self.save() + @classmethod + async def create(cls, ledger: Ledger, db: Database, path: str, name: str, create_account=False, single_key=False): + wallet = cls(ledger, db, name, WalletStorage(path), {}) + if create_account: + wallet.accounts.generate(address_generator={ + 'name': SingleKey.name if single_key else HierarchicalDeterministic.name + }) + await wallet.save() + return wallet @classmethod - def from_storage(cls, ledger: Ledger, db: Database, storage: 'WalletStorage') -> 'Wallet': - json_dict = storage.read() + async def from_path(cls, ledger: Ledger, db: Database, path: str): + return await cls.from_storage(ledger, db, WalletStorage(path)) + + @classmethod + async def from_storage(cls, ledger: Ledger, db: Database, storage: WalletStorage) -> 'Wallet': + json_dict = await storage.read() if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id(): raise ValueError( f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}." @@ -179,33 +84,32 @@ class Wallet: wallet = cls( ledger, db, name=json_dict.get('name', 'Wallet'), + storage=storage, preferences=json_dict.get('preferences', {}), - storage=storage ) - account_dicts: Sequence[dict] = json_dict.get('accounts', []) - for account_dict in account_dicts: - wallet.add_account(account_dict) + for account_dict in json_dict.get('accounts', []): + wallet.accounts.add_from_dict(account_dict) return wallet def to_dict(self, encrypt_password: str = None): return { - 'version': WalletStorage.LATEST_VERSION, - 'name': self.name, + 'version': WalletStorage.VERSION, 'ledger': self.ledger.get_id(), + 'name': self.name, 'preferences': self.preferences.data, 'accounts': [a.to_dict(encrypt_password) for a in self.accounts] } - def save(self): + async def save(self): if self.preferences.get(ENCRYPT_ON_DISK, False): if self.encryption_password is not None: - return self.storage.write(self.to_dict(encrypt_password=self.encryption_password)) + return await self.storage.write(self.to_dict(encrypt_password=self.encryption_password)) elif not self.is_locked: log.warning( "Disk encryption requested but no password available for encryption. " "Saving wallet in an unencrypted state." ) - return self.storage.write(self.to_dict()) + return await self.storage.write(self.to_dict()) @property def hash(self) -> bytes: @@ -248,7 +152,7 @@ class Wallet: local_match.merge(account_dict) else: added_accounts.append( - self.add_account(account_dict) + self.accounts.add_from_dict(account_dict) ) return added_accounts @@ -292,6 +196,44 @@ class Wallet: self.save() return True + @property + def has_accounts(self): + return len(self.accounts) > 0 + + async def _get_account_and_address_info_for_address(self, address): + match = await self.db.get_address(accounts=self.accounts, address=address) + if match: + for account in self.accounts: + if match['account'] == account.public_key.address: + return account, match + + async def get_private_key_for_address(self, address) -> Optional[PrivateKey]: + match = await self._get_account_and_address_info_for_address(address) + if match: + account, address_info = match + return account.get_private_key(address_info['chain'], address_info['pubkey'].n) + return None + + async def get_public_key_for_address(self, address) -> Optional[PubKey]: + match = await self._get_account_and_address_info_for_address(address) + if match: + _, address_info = match + return address_info['pubkey'] + return None + + async def get_account_for_address(self, address): + match = await self._get_account_and_address_info_for_address(address) + if match: + return match[0] + + async def save_max_gap(self): + gap_changed = False + for account in self.accounts: + if await account.save_max_gap(): + gap_changed = True + if gap_changed: + self.save() + async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]): estimators = [] utxos = await self.db.get_utxos( @@ -303,13 +245,20 @@ class Wallet: return estimators async def get_spendable_utxos(self, amount: int, funding_accounts: Iterable[Account]): - txos = await self.get_effective_amount_estimators(funding_accounts) - fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger) - selector = CoinSelector(amount, fee) - spendables = selector.select(txos, self.ledger.coin_selection_strategy) - if spendables: - await self.db.reserve_outputs(s.txo for s in spendables) - return spendables + async with self.utxo_lock: + txos = await self.get_effective_amount_estimators(funding_accounts) + fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger) + selector = CoinSelector(amount, fee) + spendables = selector.select(txos, self.ledger.coin_selection_strategy) + if spendables: + await self.db.reserve_outputs(s.txo for s in spendables) + return spendables + + async def list_transactions(self, **constraints): + return txs_to_dict(await self.db.get_transactions( + include_is_my_output=True, include_is_spent=True, + **constraints + ), self.ledger) async def create_transaction(self, inputs: Iterable[Input], outputs: Iterable[Output], funding_accounts: Iterable[Account], change_account: Account, @@ -397,46 +346,271 @@ class Wallet: raise NotImplementedError("Don't know how to spend this output.") tx._reset() - @classmethod - def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'): - output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address)) - return cls.create([], [output], funding_accounts, change_account) + async def pay(self, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account): + output = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(address)) + return await self.create_transaction([], [output], funding_accounts, change_account) - def claim_create( + async def _report_state(self): + try: + for account in self.accounts: + balance = dewies_to_lbc(await account.get_balance(include_claims=True)) + _, channel_count = await account.get_channels(limit=1) + claim_count = await account.get_claim_count() + if isinstance(account.receiving, SingleKey): + log.info("Loaded single key account %s with %s LBC. " + "%d channels, %d certificates and %d claims", + account.id, balance, channel_count, len(account.channel_keys), claim_count) + else: + total_receiving = len(await account.receiving.get_addresses()) + total_change = len(await account.change.get_addresses()) + log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), " + "%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ", + account.id, balance, total_receiving, account.receiving.gap, total_change, + account.change.gap, channel_count, len(account.channel_keys), claim_count) + except Exception as err: + if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 + raise + log.exception( + 'Failed to display wallet state, please file issue ' + 'for this bug along with the traceback you see below:') + + +class AccountListManager: + __slots__ = 'wallet', '_accounts' + + def __init__(self, wallet: Wallet): + self.wallet = wallet + self._accounts: List[Account] = [] + + def __len__(self): + return self._accounts.__len__() + + def __iter__(self): + return self._accounts.__iter__() + + def __getitem__(self, account_id: str) -> Account: + for account in self: + if account.id == account_id: + return account + raise ValueError(f"Couldn't find account: {account_id}.") + + @property + def default(self) -> Optional[Account]: + for account in self: + return account + + def generate(self, name: str = None, address_generator: dict = None) -> Account: + account = Account.generate(self.wallet.ledger, self.wallet.db, name, address_generator) + self._accounts.append(account) + return account + + def add_from_dict(self, account_dict: dict) -> Account: + account = Account.from_dict(self.wallet.ledger, self.wallet.db, account_dict) + self._accounts.append(account) + return account + + async def remove(self, account_id: str) -> Account: + account = self[account_id] + self._accounts.remove(account) + await self.wallet.save() + return account + + def get_or_none(self, account_id: str) -> Optional[Account]: + if account_id is not None: + return self[account_id] + + def get_or_default(self, account_id: str) -> Optional[Account]: + if account_id is None: + return self.default + return self[account_id] + + def get_or_all(self, account_ids: List[str]) -> List[Account]: + return [self[account_id] for account_id in account_ids] if account_ids else self._accounts + + async def get_account_details(self, **kwargs): + accounts = [] + for i, account in enumerate(self._accounts): + details = await account.get_details(**kwargs) + details['is_default'] = i == 0 + accounts.append(details) + return accounts + + +class BaseListManager: + __slots__ = 'wallet', 'db' + + def __init__(self, wallet: Wallet): + self.wallet = wallet + self.db = wallet.db + + async def create(self, **kwargs) -> Transaction: + raise NotImplementedError + + async def delete(self, **constraints): + raise NotImplementedError + + async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: + raise NotImplementedError + + async def get(self, **constraints) -> Output: + raise NotImplementedError + + async def get_or_none(self, **constraints) -> Optional[Output]: + raise NotImplementedError + + +class ClaimListManager(BaseListManager): + name = 'claim' + __slots__ = () + + async def create( self, name: str, claim: Claim, amount: int, holding_address: str, - funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): + funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): claim_output = Output.pay_claim_name_pubkey_hash( - amount, name, claim, self.ledger.address_to_hash160(holding_address) + amount, name, claim, self.wallet.ledger.address_to_hash160(holding_address) ) if signing_channel is not None: claim_output.sign(signing_channel, b'placeholder txid:nout') - return self.create_transaction( + return await self.wallet.create_transaction( [], [claim_output], funding_accounts, change_account, sign=False ) - @classmethod - def claim_update( - cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str, - funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): + async def update( + self, previous_claim: Output, claim: Claim, amount: int, holding_address: str, + funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): updated_claim = Output.pay_update_claim_pubkey_hash( amount, previous_claim.claim_name, previous_claim.claim_id, - claim, ledger.address_to_hash160(holding_address) + claim, self.wallet.ledger.address_to_hash160(holding_address) ) if signing_channel is not None: updated_claim.sign(signing_channel, b'placeholder txid:nout') else: updated_claim.clear_signature() - return cls.create( + return await self.wallet.create_transaction( [Input.spend(previous_claim)], [updated_claim], funding_accounts, change_account, sign=False ) - @classmethod - def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str, - funding_accounts: List['Account'], change_account: 'Account'): - support_output = Output.pay_support_pubkey_hash( - amount, claim_name, claim_id, ledger.address_to_hash160(holding_address) + async def delete(self, claim_id=None, txid=None, nout=None): + claim = await self.get(claim_id=claim_id, txid=txid, nout=nout) + return await self.wallet.create_transaction( + [Input.spend(claim)], [], self.wallet._accounts, self.wallet._accounts[0] + ) + + async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.db.get_claims(wallet=self.wallet, **constraints) + + async def get(self, claim_id=None, claim_name=None, txid=None, nout=None) -> Output: + if txid is not None and nout is not None: + key, value, constraints = 'txid:nout', f'{txid}:{nout}', {'tx_hash': '', 'position': nout} + elif claim_id is not None: + key, value, constraints = 'id', claim_id, {'claim_id': claim_id} + elif claim_name is not None: + key, value, constraints = 'name', claim_name, {'claim_name': claim_name} + else: + raise ValueError(f"Couldn't find {self.name} because an {self.name}_id or name was not provided.") + claims, _ = await self.list(**constraints) + if len(claims) == 1: + return claims[0] + elif len(claims) > 1: + raise ValueError( + f"Multiple {self.name}s found with {key} '{value}', " + f"pass a {self.name}_id to narrow it down." + ) + raise ValueError(f"Couldn't find {self.name} with {key} '{value}'.") + + async def get_or_none(self, claim_id=None, claim_name=None, txid=None, nout=None) -> Optional[Output]: + if any((claim_id, claim_name, all((txid, nout)))): + return await self.get(claim_id, claim_name, txid, nout) + + +class StreamListManager(ClaimListManager): + __slots__ = () + + async def create(self, *args, **kwargs): + return await super().create(*args, **kwargs) + + async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.db.get_streams(wallet=self.wallet, **constraints) + + +class CollectionListManager(ClaimListManager): + __slots__ = () + + async def create(self, *args, **kwargs): + return await super().create(*args, **kwargs) + + async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.db.get_collections(wallet=self.wallet, **constraints) + + +class ChannelListManager(ClaimListManager): + name = 'channel' + __slots__ = () + + async def create(self, name: str, amount: int, account: Account, funding_accounts: List[Account], + claim_address: str, preview=False, **kwargs): + claim = Claim() + claim.channel.update(**kwargs) + tx = await super().create( + name, claim, amount, claim_address, funding_accounts, funding_accounts[0] + ) + txo = tx.outputs[0] + txo.generate_channel_private_key() + await self.wallet.sign(tx) + if not preview: + account.add_channel_private_key(txo.private_key) + await self.wallet.save() + return tx + + async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.db.get_channels(wallet=self.wallet, **constraints) + + async def get_for_signing(self, **kwargs) -> Output: + channel = await self.get(**kwargs) + if not channel.has_private_key: + raise Exception( + f"Couldn't find private key for channel '{channel.claim_name}', can't use channel for signing. " + ) + return channel + + async def get_for_signing_or_none(self, **kwargs) -> Optional[Output]: + if any(kwargs.values()): + return await self.get_for_signing(**kwargs) + + +class SupportListManager(BaseListManager): + __slots__ = () + + async def create(self, name: str, claim_id: str, amount: int, holding_address: str, + funding_accounts: List[Account], change_account: Account) -> Transaction: + support_output = Output.pay_support_pubkey_hash( + amount, name, claim_id, self.wallet.ledger.address_to_hash160(holding_address) + ) + return await self.wallet.create_transaction( + [], [support_output], funding_accounts, change_account + ) + + async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.db.get_supports(**constraints) + + async def get(self, **constraints) -> Output: + raise NotImplementedError + + async def get_or_none(self, **constraints) -> Optional[Output]: + raise NotImplementedError + + +class PurchaseListManager(BaseListManager): + __slots__ = () + + async def create(self, name: str, claim_id: str, amount: int, holding_address: str, + funding_accounts: List[Account], change_account: Account) -> Transaction: + support_output = Output.pay_support_pubkey_hash( + amount, name, claim_id, self.wallet.ledger.address_to_hash160(holding_address) + ) + return await self.wallet.create_transaction( + [], [support_output], funding_accounts, change_account ) - return cls.create([], [support_output], funding_accounts, change_account) def purchase(self, claim_id: str, amount: int, merchant_address: bytes, funding_accounts: List['Account'], change_account: 'Account'): @@ -470,84 +644,120 @@ class Wallet: txo.claim_id, fee_amount, fee_address, accounts, accounts[0] ) - async def create_channel( - self, name, amount, account, funding_accounts, - claim_address, preview=False, **kwargs): + async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.db.get_purchases(**constraints) - claim = Claim() - claim.channel.update(**kwargs) - tx = await self.claim_create( - name, claim, amount, claim_address, funding_accounts, funding_accounts[0] - ) - txo = tx.outputs[0] - txo.generate_channel_private_key() + async def get(self, **constraints) -> Output: + raise NotImplementedError - await self.sign(tx) - - if not preview: - account.add_channel_private_key(txo.private_key) - self.save() - - return tx - - async def get_channels(self): - return await self.db.get_channels() + async def get_or_none(self, **constraints) -> Optional[Output]: + raise NotImplementedError -class WalletStorage: - - LATEST_VERSION = 1 - - def __init__(self, path=None, default=None): - self.path = path - self._default = default or { - 'version': self.LATEST_VERSION, - 'name': 'My Wallet', - 'preferences': {}, - 'accounts': [] +def txs_to_dict(txs, ledger): + history = [] + for tx in txs: # pylint: disable=too-many-nested-blocks + ts = headers.estimated_timestamp(tx.height) + item = { + 'txid': tx.id, + 'timestamp': ts, + 'date': datetime.fromtimestamp(ts).isoformat(' ')[:-3] if tx.height > 0 else None, + 'confirmations': (headers.height + 1) - tx.height if tx.height > 0 else 0, + 'claim_info': [], + 'update_info': [], + 'support_info': [], + 'abandon_info': [], + 'purchase_info': [] } - - def read(self): - if self.path and os.path.exists(self.path): - with open(self.path, 'r') as f: - json_data = f.read() - json_dict = json.loads(json_data) - if json_dict.get('version') == self.LATEST_VERSION and \ - set(json_dict) == set(self._default): - return json_dict - else: - return self.upgrade(json_dict) + is_my_inputs = all([txi.is_my_input for txi in tx.inputs]) + if is_my_inputs: + # fees only matter if we are the ones paying them + item['value'] = dewies_to_lbc(tx.net_account_balance + tx.fee) + item['fee'] = dewies_to_lbc(-tx.fee) else: - return self._default.copy() - - def upgrade(self, json_dict): - json_dict = json_dict.copy() - version = json_dict.pop('version', -1) - if version == -1: - pass - upgraded = self._default.copy() - upgraded.update(json_dict) - return json_dict - - def write(self, json_dict): - - json_data = json.dumps(json_dict, indent=4, sort_keys=True) - if self.path is None: - return json_data - - temp_path = "{}.tmp.{}".format(self.path, os.getpid()) - with open(temp_path, "w") as f: - f.write(json_data) - f.flush() - os.fsync(f.fileno()) - - if os.path.exists(self.path): - mode = os.stat(self.path).st_mode - else: - mode = stat.S_IREAD | stat.S_IWRITE - try: - os.rename(temp_path, self.path) - except Exception: # pylint: disable=broad-except - os.remove(self.path) - os.rename(temp_path, self.path) - os.chmod(self.path, mode) + # someone else paid the fees + item['value'] = dewies_to_lbc(tx.net_account_balance) + item['fee'] = '0.0' + for txo in tx.my_claim_outputs: + item['claim_info'].append({ + 'address': txo.get_address(self.ledger), + 'balance_delta': dewies_to_lbc(-txo.amount), + 'amount': dewies_to_lbc(txo.amount), + 'claim_id': txo.claim_id, + 'claim_name': txo.claim_name, + 'nout': txo.position, + 'is_spent': txo.is_spent, + }) + for txo in tx.my_update_outputs: + if is_my_inputs: # updating my own claim + previous = None + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + other_txo = txi.txo_ref.txo + if (other_txo.is_claim or other_txo.script.is_support_claim) \ + and other_txo.claim_id == txo.claim_id: + previous = other_txo + break + if previous is not None: + item['update_info'].append({ + 'address': txo.get_address(self), + 'balance_delta': dewies_to_lbc(previous.amount - txo.amount), + 'amount': dewies_to_lbc(txo.amount), + 'claim_id': txo.claim_id, + 'claim_name': txo.claim_name, + 'nout': txo.position, + 'is_spent': txo.is_spent, + }) + else: # someone sent us their claim + item['update_info'].append({ + 'address': txo.get_address(self), + 'balance_delta': dewies_to_lbc(0), + 'amount': dewies_to_lbc(txo.amount), + 'claim_id': txo.claim_id, + 'claim_name': txo.claim_name, + 'nout': txo.position, + 'is_spent': txo.is_spent, + }) + for txo in tx.my_support_outputs: + item['support_info'].append({ + 'address': txo.get_address(self.ledger), + 'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount), + 'amount': dewies_to_lbc(txo.amount), + 'claim_id': txo.claim_id, + 'claim_name': txo.claim_name, + 'is_tip': not is_my_inputs, + 'nout': txo.position, + 'is_spent': txo.is_spent, + }) + if is_my_inputs: + for txo in tx.other_support_outputs: + item['support_info'].append({ + 'address': txo.get_address(self.ledger), + 'balance_delta': dewies_to_lbc(-txo.amount), + 'amount': dewies_to_lbc(txo.amount), + 'claim_id': txo.claim_id, + 'claim_name': txo.claim_name, + 'is_tip': is_my_inputs, + 'nout': txo.position, + 'is_spent': txo.is_spent, + }) + for txo in tx.my_abandon_outputs: + item['abandon_info'].append({ + 'address': txo.get_address(self.ledger), + 'balance_delta': dewies_to_lbc(txo.amount), + 'amount': dewies_to_lbc(txo.amount), + 'claim_id': txo.claim_id, + 'claim_name': txo.claim_name, + 'nout': txo.position + }) + for txo in tx.any_purchase_outputs: + item['purchase_info'].append({ + 'address': txo.get_address(self.ledger), + 'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount), + 'amount': dewies_to_lbc(txo.amount), + 'claim_id': txo.purchased_claim_id, + 'nout': txo.position, + 'is_spent': txo.is_spent, + }) + history.append(item) + return history diff --git a/tests/unit/wallet/test_manager.py b/tests/unit/wallet/test_manager.py new file mode 100644 index 000000000..a2e33d2aa --- /dev/null +++ b/tests/unit/wallet/test_manager.py @@ -0,0 +1,96 @@ +import os +import shutil +import tempfile + + +from lbry.testcase import AsyncioTestCase +from lbry.blockchain.ledger import Ledger +from lbry.wallet import WalletManager, Wallet, Account +from lbry.db import Database +from lbry.conf import Config + + +class TestWalletManager(AsyncioTestCase): + + async def asyncSetUp(self): + self.temp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.temp_dir) + self.ledger = Ledger(Config( + wallet_dir=self.temp_dir + )) + self.db = Database.from_memory(self.ledger) + + async def test_ensure_path_exists(self): + wm = WalletManager(self.ledger, self.db) + self.assertFalse(os.path.exists(wm.path)) + await wm.ensure_path_exists() + self.assertTrue(os.path.exists(wm.path)) + + async def test_load_with_default_wallet_account_progression(self): + wm = WalletManager(self.ledger, self.db) + await wm.ensure_path_exists() + + # first, no defaults + self.ledger.conf.create_default_wallet = False + self.ledger.conf.create_default_account = False + await wm.load() + self.assertIsNone(wm.default) + + # then, yes to default wallet but no to default account + self.ledger.conf.create_default_wallet = True + self.ledger.conf.create_default_account = False + await wm.load() + self.assertIsInstance(wm.default, Wallet) + self.assertTrue(os.path.exists(wm.default.storage.path)) + self.assertIsNone(wm.default.accounts.default) + + # finally, yes to all the things + self.ledger.conf.create_default_wallet = True + self.ledger.conf.create_default_account = True + await wm.load() + self.assertIsInstance(wm.default, Wallet) + self.assertIsInstance(wm.default.accounts.default, Account) + + async def test_load_with_create_default_everything_upfront(self): + wm = WalletManager(self.ledger, self.db) + await wm.ensure_path_exists() + self.ledger.conf.create_default_wallet = True + self.ledger.conf.create_default_account = True + await wm.load() + self.assertIsInstance(wm.default, Wallet) + self.assertIsInstance(wm.default.accounts.default, Account) + self.assertTrue(os.path.exists(wm.default.storage.path)) + + async def test_load_errors(self): + _wm = WalletManager(self.ledger, self.db) + await _wm.ensure_path_exists() + await _wm.create('bar', '') + await _wm.create('foo', '') + + wm = WalletManager(self.ledger, self.db) + self.ledger.conf.wallets = ['bar', 'foo', 'foo'] + with self.assertLogs(level='WARN') as cm: + await wm.load() + self.assertEqual( + cm.output, [ + 'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo', + ] + ) + self.assertEqual({'bar', 'foo'}, set(wm.wallets)) + + async def test_creating_and_accessing_wallets(self): + wm = WalletManager(self.ledger, self.db) + await wm.ensure_path_exists() + await wm.load() + default_wallet = wm.default + self.assertEqual(default_wallet, wm['default_wallet']) + self.assertEqual(default_wallet, wm.get_or_default(None)) + new_wallet = await wm.create('second', 'Second Wallet') + self.assertEqual(default_wallet, wm.default) + self.assertEqual(new_wallet, wm['second']) + self.assertEqual(new_wallet, wm.get_or_default('second')) + self.assertEqual(default_wallet, wm.get_or_default(None)) + with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"): + _ = wm['invalid'] + with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"): + wm.get_or_default('invalid')