wallet commands and wallet_id argument to other commands

This commit is contained in:
Lex Berezhny 2019-09-20 00:05:37 -04:00
parent 9fbc83fe9d
commit 84587ac232
16 changed files with 608 additions and 406 deletions

File diff suppressed because it is too large Load diff

View file

@ -139,34 +139,34 @@ class Account(BaseAccount):
return details
def get_transaction_history(self, **constraints):
return self.ledger.get_transaction_history(account=self, **constraints)
return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints)
def get_transaction_history_count(self, **constraints):
return self.ledger.get_transaction_history_count(account=self, **constraints)
return self.ledger.get_transaction_history_count(wallet=self.wallet, accounts=[self], **constraints)
def get_claims(self, **constraints):
return self.ledger.get_claims(account=self, **constraints)
return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints)
def get_claim_count(self, **constraints):
return self.ledger.get_claim_count(account=self, **constraints)
return self.ledger.get_claim_count(wallet=self.wallet, accounts=[self], **constraints)
def get_streams(self, **constraints):
return self.ledger.get_streams(account=self, **constraints)
return self.ledger.get_streams(wallet=self.wallet, accounts=[self], **constraints)
def get_stream_count(self, **constraints):
return self.ledger.get_stream_count(account=self, **constraints)
return self.ledger.get_stream_count(wallet=self.wallet, accounts=[self], **constraints)
def get_channels(self, **constraints):
return self.ledger.get_channels(account=self, **constraints)
return self.ledger.get_channels(wallet=self.wallet, accounts=[self], **constraints)
def get_channel_count(self, **constraints):
return self.ledger.get_channel_count(account=self, **constraints)
return self.ledger.get_channel_count(wallet=self.wallet, accounts=[self], **constraints)
def get_supports(self, **constraints):
return self.ledger.get_supports(account=self, **constraints)
return self.ledger.get_supports(wallet=self.wallet, accounts=[self], **constraints)
def get_support_count(self, **constraints):
return self.ledger.get_support_count(account=self, **constraints)
return self.ledger.get_support_count(wallet=self.wallet, accounts=[self], **constraints)
def get_support_summary(self):
return self.ledger.db.get_supports_summary(account_id=self.id)

View file

@ -22,18 +22,18 @@ class WalletDatabase(BaseDatabase):
claim_id text,
claim_name text
);
create index if not exists txo_address_idx on txo (address);
create index if not exists txo_claim_id_idx on txo (claim_id);
create index if not exists txo_txo_type_idx on txo (txo_type);
"""
CREATE_TABLES_QUERY = (
BaseDatabase.CREATE_TX_TABLE +
BaseDatabase.CREATE_PUBKEY_ADDRESS_TABLE +
BaseDatabase.CREATE_PUBKEY_ADDRESS_INDEX +
CREATE_TXO_TABLE +
BaseDatabase.CREATE_TXO_INDEX +
BaseDatabase.CREATE_TXI_TABLE +
BaseDatabase.CREATE_TXI_INDEX
BaseDatabase.PRAGMAS +
BaseDatabase.CREATE_ACCOUNT_TABLE +
BaseDatabase.CREATE_PUBKEY_ADDRESS_TABLE +
BaseDatabase.CREATE_TX_TABLE +
CREATE_TXO_TABLE +
BaseDatabase.CREATE_TXI_TABLE
)
def txo_to_row(self, tx, address, txo):
@ -50,18 +50,16 @@ class WalletDatabase(BaseDatabase):
row['claim_name'] = txo.claim_name
return row
async def get_txos(self, **constraints) -> List[Output]:
my_accounts = constraints.get('my_accounts', constraints.get('accounts', []))
txos = await super().get_txos(**constraints)
async def get_txos(self, wallet=None, no_tx=False, **constraints) -> List[Output]:
txos = await super().get_txos(wallet=wallet, no_tx=no_tx, **constraints)
channel_ids = set()
for txo in txos:
if txo.is_claim and txo.can_decode_claim:
if txo.claim.is_signed:
channel_ids.add(txo.claim.signing_channel_id)
if txo.claim.is_channel and my_accounts:
for account in my_accounts:
if txo.claim.is_channel and wallet:
for account in wallet.accounts:
private_key = account.get_channel_private_key(
txo.claim.channel.public_key_bytes
)
@ -73,7 +71,7 @@ class WalletDatabase(BaseDatabase):
channels = {
txo.claim_id: txo for txo in
(await self.get_claims(
my_accounts=my_accounts,
wallet=wallet,
claim_id__in=channel_ids
))
}

View file

@ -121,39 +121,30 @@ class MainNetLedger(BaseLedger):
return super().get_utxo_count(**constraints)
def get_claims(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_claims(**constraints)
def get_claim_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_claim_count(**constraints)
def get_streams(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_streams(**constraints)
def get_stream_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_stream_count(**constraints)
def get_channels(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_channels(**constraints)
def get_channel_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_channel_count(**constraints)
def get_supports(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_supports(**constraints)
def get_support_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_support_count(**constraints)
async def get_transaction_history(self, **constraints):
self.constraint_account_or_all(constraints)
txs = await self.db.get_transactions(**constraints)
headers = self.headers
history = []
@ -249,7 +240,6 @@ class MainNetLedger(BaseLedger):
return history
def get_transaction_history_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_transaction_count(**constraints)

View file

@ -203,7 +203,7 @@ class Transaction(BaseTransaction):
@classmethod
def pay(cls, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address))
return cls.create([], [output], funding_accounts, change_account)
@ -211,7 +211,7 @@ class Transaction(BaseTransaction):
def claim_create(
cls, name: str, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
claim_output = Output.pay_claim_name_pubkey_hash(
amount, name, claim, ledger.address_to_hash160(holding_address)
)
@ -223,7 +223,7 @@ class Transaction(BaseTransaction):
def claim_update(
cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
updated_claim = Output.pay_update_claim_pubkey_hash(
amount, previous_claim.claim_name, previous_claim.claim_id,
claim, ledger.address_to_hash160(holding_address)
@ -239,7 +239,7 @@ class Transaction(BaseTransaction):
@classmethod
def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
support_output = Output.pay_support_pubkey_hash(
amount, claim_name, claim_id, ledger.address_to_hash160(holding_address)
)
@ -248,7 +248,7 @@ class Transaction(BaseTransaction):
@classmethod
def purchase(cls, claim: Output, amount: int, merchant_address: bytes,
funding_accounts: List[Account], change_account: Account):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
claim_output = Output.purchase_claim_pubkey_hash(
amount, claim.claim_id, ledger.address_to_hash160(merchant_address)
)

View file

@ -466,45 +466,47 @@ class ChannelCommands(CommandTestCase):
'title': 'foo', 'email': 'new@email.com'}
)
# send channel to someone else
# move channel to another account
new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id'])
account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
# before sending
# before moving
self.assertEqual(len(await self.daemon.jsonrpc_channel_list()), 3)
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 0)
other_address = await account2.receiving.get_or_create_usable_address()
tx = await self.out(self.channel_update(claim_id, claim_address=other_address))
# after sending
# after moving
self.assertEqual(len(await self.daemon.jsonrpc_channel_list()), 3)
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=self.account.id)), 2)
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 1)
# shoud not have private key
txo = (await account2.get_channels())[0]
self.assertIsNone(txo.private_key)
# send the private key too
private_key = self.account.get_channel_private_key(unhexlify(channel['public_key']))
account2.add_channel_private_key(private_key)
# now should have private key
txo = (await account2.get_channels())[0]
self.assertIsNotNone(txo.private_key)
async def test_channel_export_import_into_new_account(self):
async def test_channel_export_import_before_sending_channel(self):
# export
tx = await self.channel_create('@foo', '1.0')
claim_id = self.get_claim_id(tx)
channel_private_key = (await self.account.get_channels())[0].private_key
exported_data = await self.out(self.daemon.jsonrpc_channel_export(claim_id))
# import
daemon2 = await self.add_daemon()
self.assertEqual(0, len(await daemon2.jsonrpc_channel_list()))
await daemon2.jsonrpc_channel_import(exported_data)
channels = await daemon2.jsonrpc_channel_list()
self.assertEqual(1, len(channels))
self.assertEqual(channel_private_key.to_string(), channels[0].private_key.to_string())
# second wallet can't update until channel is sent to it
with self.assertRaisesRegex(AssertionError, 'Cannot find private key for signing output.'):
await daemon2.jsonrpc_channel_update(claim_id, bid='0.5')
# now send the channel as well
await self.channel_update(claim_id, claim_address=await daemon2.jsonrpc_address_unused())
# second wallet should be able to update now
await daemon2.jsonrpc_channel_update(claim_id, bid='0.5')
class StreamCommands(ClaimTestCase):
@ -565,7 +567,7 @@ class StreamCommands(ClaimTestCase):
async def test_publishing_checks_all_accounts_for_channel(self):
account1_id, account1 = self.account.id, self.account
new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id'])
account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
await self.out(self.channel_create('@spam', '1.0'))
self.assertEqual('8.989893', (await self.daemon.jsonrpc_account_balance())['available'])
@ -782,7 +784,7 @@ class StreamCommands(ClaimTestCase):
# send claim to someone else
new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id'])
account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
# before sending
self.assertEqual(len(await self.daemon.jsonrpc_claim_list()), 4)
@ -1079,13 +1081,12 @@ class StreamCommands(ClaimTestCase):
class SupportCommands(CommandTestCase):
async def test_regular_supports_and_tip_supports(self):
# account2 will be used to send tips and supports to account1
account2_id = (await self.out(self.daemon.jsonrpc_account_create('second account')))['id']
account2 = self.daemon.get_account_or_error(account2_id)
wallet2 = await self.daemon.jsonrpc_wallet_add('wallet2', create_wallet=True, create_account=True)
account2 = wallet2.accounts[0]
# send account2 5 LBC out of the 10 LBC in account1
result = await self.out(self.daemon.jsonrpc_account_send(
'5.0', await self.daemon.jsonrpc_address_unused(account2_id)
'5.0', await self.daemon.jsonrpc_address_unused(wallet_id='wallet2')
))
await self.on_transaction_dict(result)
@ -1103,7 +1104,7 @@ class SupportCommands(CommandTestCase):
# send a tip to the claim using account2
tip = await self.out(
self.daemon.jsonrpc_support_create(
claim_id, '1.0', True, account2_id, funding_account_ids=[account2_id])
claim_id, '1.0', True, account2.id, 'wallet2', funding_account_ids=[account2.id])
)
await self.confirm_tx(tip['txid'])
@ -1122,7 +1123,7 @@ class SupportCommands(CommandTestCase):
# verify that the outgoing tip is marked correctly as is_tip=True in account2
txs2 = await self.out(
self.daemon.jsonrpc_transaction_list(account2_id)
self.daemon.jsonrpc_transaction_list(wallet_id='wallet2', account_id=account2.id)
)
self.assertEqual(len(txs2[0]['support_info']), 1)
self.assertEqual(txs2[0]['support_info'][0]['balance_delta'], '-1.0')
@ -1134,7 +1135,7 @@ class SupportCommands(CommandTestCase):
# send a support to the claim using account2
support = await self.out(
self.daemon.jsonrpc_support_create(
claim_id, '2.0', False, account2_id, funding_account_ids=[account2_id])
claim_id, '2.0', False, account2.id, 'wallet2', funding_account_ids=[account2.id])
)
await self.confirm_tx(support['txid'])
@ -1143,7 +1144,7 @@ class SupportCommands(CommandTestCase):
await self.assertBalance(account2, '1.999717')
# verify that the outgoing support is marked correctly as is_tip=False in account2
txs2 = await self.out(self.daemon.jsonrpc_transaction_list(account2_id))
txs2 = await self.out(self.daemon.jsonrpc_transaction_list(wallet_id='wallet2'))
self.assertEqual(len(txs2[0]['support_info']), 1)
self.assertEqual(txs2[0]['support_info'][0]['balance_delta'], '-2.0')
self.assertEqual(txs2[0]['support_info'][0]['claim_id'], claim_id)

View file

@ -61,13 +61,17 @@ class TestAccount(AsyncioTestCase):
address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx')
private_key = await self.ledger.get_private_key_for_address('bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx')
private_key = await self.ledger.get_private_key_for_address(
account.wallet, 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx'
)
self.assertEqual(
private_key.extended_key_string(),
'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR'
'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT'
)
private_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
private_key = await self.ledger.get_private_key_for_address(
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
)
self.assertIsNone(private_key)
def test_load_and_save_account(self):

View file

@ -60,7 +60,7 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
await account.receiving._generate_keys(8, 11)
records = await account.receiving.get_address_records()
self.assertEqual(
[r['position'] for r in records],
[r['pubkey'].n for r in records],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
)
@ -69,7 +69,7 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
self.assertEqual(len(new_keys), 8)
records = await account.receiving.get_address_records()
self.assertEqual(
[r['position'] for r in records],
[r['pubkey'].n for r in records],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
)
@ -125,14 +125,18 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
private_key = await self.ledger.get_private_key_for_address(
account.wallet, '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J'
)
self.assertEqual(
private_key.extended_key_string(),
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
)
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
invalid_key = await self.ledger.get_private_key_for_address(
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
)
self.assertIsNone(invalid_key)
self.assertEqual(
@ -275,12 +279,18 @@ class TestSingleKeyAccount(AsyncioTestCase):
self.assertEqual(len(new_keys), 1)
self.assertEqual(new_keys[0], account.public_key.address)
records = await account.receiving.get_address_records()
pubkey = records[0].pop('pubkey')
self.assertEqual(records, [{
'position': 0, 'chain': 0,
'chain': 0,
'account': account.public_key.address,
'address': account.public_key.address,
'history': None,
'used_times': 0
}])
self.assertEqual(
pubkey.extended_key_string(),
account.public_key.extended_key_string()
)
# case #1: no new addresses needed
empty = await account.receiving.ensure_address_gap()
@ -333,14 +343,18 @@ class TestSingleKeyAccount(AsyncioTestCase):
address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], account.public_key.address)
private_key = await self.ledger.get_private_key_for_address(address[0])
private_key = await self.ledger.get_private_key_for_address(
account.wallet, address[0]
)
self.assertEqual(
private_key.extended_key_string(),
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
)
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
invalid_key = await self.ledger.get_private_key_for_address(
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
)
self.assertIsNone(invalid_key)
self.assertEqual(

View file

@ -199,13 +199,14 @@ class TestQueries(AsyncioTestCase):
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),
})
self.wallet = Wallet()
await self.ledger.db.open()
async def asyncTearDown(self):
await self.ledger.db.close()
async def create_account(self):
account = self.ledger.account_class.generate(self.ledger, Wallet())
async def create_account(self, wallet=None):
account = self.ledger.account_class.generate(self.ledger, wallet or self.wallet)
await account.ensure_address_gap()
return account
@ -264,7 +265,8 @@ class TestQueries(AsyncioTestCase):
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
for limit in range(variable_limit-2, variable_limit+2):
txs = await self.ledger.get_transactions(limit=limit, order_by='height asc')
txs = await self.ledger.get_transactions(
accounts=self.wallet.accounts, limit=limit, order_by='height asc')
self.assertEqual(len(txs), limit)
inputs, outputs, last_tx = set(), set(), txs[0]
for tx in txs[1:]:
@ -274,19 +276,23 @@ class TestQueries(AsyncioTestCase):
last_tx = tx
async def test_queries(self):
self.assertEqual(0, await self.ledger.db.get_address_count())
account1 = await self.create_account()
self.assertEqual(26, await self.ledger.db.get_address_count())
account2 = await self.create_account()
self.assertEqual(52, await self.ledger.db.get_address_count())
wallet1 = Wallet()
account1 = await self.create_account(wallet1)
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account1]))
wallet2 = Wallet()
account2 = await self.create_account(wallet2)
account3 = await self.create_account(wallet2)
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2]))
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2, account3]))
self.assertEqual(0, await self.ledger.db.get_utxo_count())
self.assertEqual([], await self.ledger.db.get_utxos())
self.assertEqual(0, await self.ledger.db.get_txo_count())
self.assertEqual(0, await self.ledger.db.get_balance())
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
tx1 = await self.create_tx_from_nothing(account1, 1)
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1]))
@ -294,20 +300,28 @@ class TestQueries(AsyncioTestCase):
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance())
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
tx2b = await self.create_tx_from_nothing(account3, 2)
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account3]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance())
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account3]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account3]))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(10**8+10**8, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
@ -316,22 +330,24 @@ class TestQueries(AsyncioTestCase):
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance())
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual([3, 2, 1], [tx.height for tx in txs])
txs = await self.ledger.db.get_transactions(accounts=[account1])
txs = await self.ledger.db.get_transactions(wallet=wallet1, accounts=wallet1.accounts)
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True)
txs = await self.ledger.db.get_transactions(accounts=[account2])
txs = await self.ledger.db.get_transactions(wallet=wallet2, accounts=[account2])
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False)
@ -343,18 +359,18 @@ class TestQueries(AsyncioTestCase):
self.assertEqual(tx.id, tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, False)
tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account1])
tx = await self.ledger.db.get_transaction(wallet=wallet1, txid=tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, True)
self.assertEqual(tx.outputs[0].is_my_account, False)
tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account2])
tx = await self.ledger.db.get_transaction(wallet=wallet2, txid=tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, True)
# height 0 sorted to the top with the rest in descending order
tx4 = await self.create_tx_from_nothing(account1, 0)
txos = await self.ledger.db.get_txos()
self.assertEqual([0, 2, 1], [txo.tx_ref.height for txo in txos])
self.assertEqual([tx4.id, tx2.id, tx1.id], [txo.tx_ref.id for txo in txos])
self.assertEqual([0, 2, 2, 1], [txo.tx_ref.height for txo in txos])
self.assertEqual([tx4.id, tx2.id, tx2b.id, tx1.id], [txo.tx_ref.id for txo in txos])
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs])
self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
@ -382,13 +398,13 @@ class TestUpgrade(AsyncioTestCase):
def add_address(self, address):
with sqlite3.connect(self.path) as conn:
conn.execute("""
INSERT INTO pubkey_address (address, account, chain, position, pubkey)
VALUES (?, 'account1', 0, 0, 'pubkey blob')
INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth)
VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0)
""", (address,))
def get_addresses(self):
with sqlite3.connect(self.path) as conn:
sql = "SELECT address FROM pubkey_address ORDER BY address;"
sql = "SELECT address FROM account_address ORDER BY address;"
return [col[0] for col in conn.execute(sql).fetchall()]
async def test_reset_on_version_change(self):
@ -401,7 +417,7 @@ class TestUpgrade(AsyncioTestCase):
self.ledger.db.SCHEMA_VERSION = None
self.assertEqual(self.get_tables(), [])
await self.ledger.db.open()
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo'])
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo'])
self.assertEqual(self.get_addresses(), [])
self.add_address('address1')
await self.ledger.db.close()
@ -410,17 +426,17 @@ class TestUpgrade(AsyncioTestCase):
self.ledger.db.SCHEMA_VERSION = '1.0'
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), []) # address1 deleted during version upgrade
self.add_address('address2')
await self.ledger.db.close()
# nothing changes
self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), ['address2'])
await self.ledger.db.close()
@ -431,7 +447,7 @@ class TestUpgrade(AsyncioTestCase):
"""
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.1')
self.assertEqual(self.get_tables(), ['foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_tables(), ['account_address', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), []) # all tables got reset
await self.ledger.db.close()

View file

@ -115,7 +115,7 @@ class HierarchicalDeterministic(AddressManager):
return self.account.public_key.child(self.chain_number).child(index)
async def get_max_gap(self) -> int:
addresses = await self._query_addresses(order_by="position ASC")
addresses = await self._query_addresses(order_by="n asc")
max_gap = 0
current_gap = 0
for address in addresses:
@ -128,7 +128,7 @@ class HierarchicalDeterministic(AddressManager):
async def ensure_address_gap(self) -> List[str]:
async with self.address_generator_lock:
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
addresses = await self._query_addresses(limit=self.gap, order_by="n desc")
existing_gap = 0
for address in addresses:
@ -140,7 +140,7 @@ class HierarchicalDeterministic(AddressManager):
if existing_gap == self.gap:
return []
start = addresses[0]['position']+1 if addresses else 0
start = addresses[0]['pubkey'].n+1 if addresses else 0
end = start + (self.gap - existing_gap)
new_keys = await self._generate_keys(start, end-1)
await self.account.ledger.announce_addresses(self, new_keys)
@ -149,15 +149,15 @@ class HierarchicalDeterministic(AddressManager):
async def _generate_keys(self, start: int, end: int) -> List[str]:
if not self.address_generator_lock.locked():
raise RuntimeError('Should not be called outside of address_generator_lock.')
keys = [(index, self.public_key.child(index)) for index in range(start, end+1)]
keys = [self.public_key.child(index) for index in range(start, end+1)]
await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
return [key[1].address for key in keys]
return [key.address for key in keys]
def get_address_records(self, only_usable: bool = False, **constraints):
if only_usable:
constraints['used_times__lt'] = self.maximum_uses_per_address
if 'order_by' not in constraints:
constraints['order_by'] = "used_times ASC, position ASC"
constraints['order_by'] = "used_times asc, n asc"
return self._query_addresses(**constraints)
@ -190,9 +190,7 @@ class SingleKey(AddressManager):
async with self.address_generator_lock:
exists = await self.get_address_records()
if not exists:
await self.account.ledger.db.add_keys(
self.account, self.chain_number, [(0, self.public_key)]
)
await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
new_keys = [self.public_key.address]
await self.account.ledger.announce_addresses(self, new_keys)
return new_keys
@ -417,16 +415,16 @@ class BaseAccount:
}
def get_utxos(self, **constraints):
return self.ledger.get_utxos(account=self, **constraints)
return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)
def get_utxo_count(self, **constraints):
return self.ledger.get_utxo_count(account=self, **constraints)
return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)
def get_transactions(self, **constraints):
return self.ledger.get_transactions(account=self, **constraints)
return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)
def get_transaction_count(self, **constraints):
return self.ledger.get_transaction_count(account=self, **constraints)
return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)
async def fund(self, to_account, amount=None, everything=False,
outputs=1, broadcast=False, **constraints):

View file

@ -1,5 +1,6 @@
import logging
import asyncio
import json
from binascii import hexlify
from concurrent.futures.thread import ThreadPoolExecutor
@ -8,7 +9,7 @@ from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict,
import sqlite3
from torba.client.basetransaction import BaseTransaction, TXRefImmutable
from torba.client.baseaccount import BaseAccount
from torba.client.bip32 import PubKey
log = logging.getLogger(__name__)
@ -165,9 +166,8 @@ def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
offset = constraints.pop('offset', None)
order_by = constraints.pop('order_by', None)
constraints.pop('my_accounts', None)
accounts = constraints.pop('accounts', None)
if accounts is not None:
accounts = constraints.pop('accounts', [])
if accounts:
constraints['account__in'] = [a.public_key.address for a in accounts]
where, values = constraints_to_sql(constraints)
@ -285,26 +285,33 @@ class SQLiteMixin:
class BaseDatabase(SQLiteMixin):
SCHEMA_VERSION = "1.0"
SCHEMA_VERSION = "1.1"
PRAGMAS = """
pragma journal_mode=WAL;
"""
CREATE_ACCOUNT_TABLE = """
create table if not exists account_address (
account text not null,
address text not null,
chain integer not null,
pubkey blob not null,
chain_code blob not null,
n integer not null,
depth integer not null,
primary key (account, address)
);
create index if not exists address_account_idx on account_address (address, account);
"""
CREATE_PUBKEY_ADDRESS_TABLE = """
create table if not exists pubkey_address (
address text primary key,
account text not null,
chain integer not null,
position integer not null,
pubkey blob not null,
history text,
used_times integer not null default 0
);
"""
CREATE_PUBKEY_ADDRESS_INDEX = """
create index if not exists pubkey_address_account_idx on pubkey_address (account);
"""
CREATE_TX_TABLE = """
create table if not exists tx (
@ -326,8 +333,6 @@ class BaseDatabase(SQLiteMixin):
script blob not null,
is_reserved boolean not null default 0
);
"""
CREATE_TXO_INDEX = """
create index if not exists txo_address_idx on txo (address);
"""
@ -337,21 +342,17 @@ class BaseDatabase(SQLiteMixin):
txoid text references txo,
address text references pubkey_address
);
"""
CREATE_TXI_INDEX = """
create index if not exists txi_address_idx on txi (address);
create index if not exists txi_txoid_idx on txi (txoid);
"""
CREATE_TABLES_QUERY = (
PRAGMAS +
CREATE_TX_TABLE +
CREATE_ACCOUNT_TABLE +
CREATE_PUBKEY_ADDRESS_TABLE +
CREATE_PUBKEY_ADDRESS_INDEX +
CREATE_TX_TABLE +
CREATE_TXO_TABLE +
CREATE_TXO_INDEX +
CREATE_TXI_TABLE +
CREATE_TXI_INDEX
CREATE_TXI_TABLE
)
@staticmethod
@ -434,25 +435,22 @@ class BaseDatabase(SQLiteMixin):
async def select_transactions(self, cols, accounts=None, **constraints):
if not set(constraints) & {'txid', 'txid__in'}:
assert accounts is not None, "'accounts' argument required when no 'txid' constraint"
assert accounts, "'accounts' argument required when no 'txid' constraint is present"
constraints.update({
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
})
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
where = f" WHERE account_address.account IN ({account_values})"
constraints['txid__in'] = f"""
SELECT txo.txid FROM txo
INNER JOIN pubkey_address USING (address) WHERE pubkey_address.account IN ({account_values})
SELECT txo.txid FROM txo JOIN account_address USING (address) {where}
UNION
SELECT txi.txid FROM txi
INNER JOIN pubkey_address USING (address) WHERE pubkey_address.account IN ({account_values})
SELECT txi.txid FROM txi JOIN account_address USING (address) {where}
"""
return await self.db.execute_fetchall(
*query("SELECT {} FROM tx".format(cols), **constraints)
)
async def get_transactions(self, **constraints):
accounts = constraints.get('accounts', None)
async def get_transactions(self, wallet=None, **constraints):
tx_rows = await self.select_transactions(
'txid, raw, height, position, is_verified',
order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
@ -477,7 +475,7 @@ class BaseDatabase(SQLiteMixin):
annotated_txos.update({
txo.id: txo for txo in
(await self.get_txos(
my_accounts=accounts,
wallet=wallet,
txid__in=txids[offset:offset+step],
))
})
@ -487,7 +485,7 @@ class BaseDatabase(SQLiteMixin):
referenced_txos.update({
txo.id: txo for txo in
(await self.get_txos(
my_accounts=accounts,
wallet=wallet,
txoid__in=txi_txoids[offset:offset+step],
))
})
@ -507,6 +505,7 @@ class BaseDatabase(SQLiteMixin):
return txs
async def get_transaction_count(self, **constraints):
constraints.pop('wallet', None)
constraints.pop('offset', None)
constraints.pop('limit', None)
constraints.pop('order_by', None)
@ -519,23 +518,24 @@ class BaseDatabase(SQLiteMixin):
return txs[0]
async def select_txos(self, cols, **constraints):
return await self.db.execute_fetchall(*query(
"SELECT {} FROM txo"
" JOIN pubkey_address USING (address)"
" JOIN tx USING (txid)".format(cols), **constraints
))
sql = "SELECT {} FROM txo JOIN tx USING (txid)"
if 'accounts' in constraints:
sql += " JOIN account_address USING (address)"
return await self.db.execute_fetchall(*query(sql.format(cols), **constraints))
async def get_txos(self, my_accounts=None, no_tx=False, **constraints):
my_accounts = [
(a.public_key.address if isinstance(a, BaseAccount) else a)
for a in (my_accounts or constraints.get('accounts', []))
]
async def get_txos(self, wallet=None, no_tx=False, **constraints):
my_accounts = set(a.public_key.address for a in wallet.accounts) if wallet else set()
if 'order_by' not in constraints:
constraints['order_by'] = [
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"]
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
]
rows = await self.select_txos(
"tx.txid, raw, tx.height, tx.position, tx.is_verified, "
"txo.position, chain, account, amount, script",
"""
tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, (
select json_group_object(account, chain) from account_address
where account_address.address=txo.address
)
""",
**constraints
)
txos = []
@ -544,8 +544,8 @@ class BaseDatabase(SQLiteMixin):
for row in rows:
if no_tx:
txo = output_class(
amount=row[8],
script=output_class.script_class(row[9]),
amount=row[6],
script=output_class.script_class(row[7]),
tx_ref=TXRefImmutable.from_id(row[0], row[2]),
position=row[5]
)
@ -555,12 +555,18 @@ class BaseDatabase(SQLiteMixin):
row[1], height=row[2], position=row[3], is_verified=row[4]
)
txo = txs[row[0]].outputs[row[5]]
txo.is_change = row[6] == 1
txo.is_my_account = row[7] in my_accounts
row_accounts = json.loads(row[8])
account_match = set(row_accounts) & my_accounts
if account_match:
txo.is_my_account = True
txo.is_change = row_accounts[account_match.pop()] == 1
else:
txo.is_change = txo.is_my_account = False
txos.append(txo)
return txos
async def get_txo_count(self, **constraints):
constraints.pop('wallet', None)
constraints.pop('offset', None)
constraints.pop('limit', None)
constraints.pop('order_by', None)
@ -580,41 +586,56 @@ class BaseDatabase(SQLiteMixin):
self.constrain_utxo(constraints)
return self.get_txo_count(**constraints)
async def get_balance(self, **constraints):
async def get_balance(self, wallet=None, accounts=None, **constraints):
assert wallet or accounts, \
"'wallet' or 'accounts' constraints required to calculate balance"
constraints['accounts'] = accounts or wallet.accounts
self.constrain_utxo(constraints)
balance = await self.select_txos('SUM(amount)', **constraints)
return balance[0][0] or 0
async def select_addresses(self, cols, **constraints):
return await self.db.execute_fetchall(*query(
"SELECT {} FROM pubkey_address".format(cols), **constraints
"SELECT {} FROM pubkey_address JOIN account_address USING (address)".format(cols),
**constraints
))
async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'),
**constraints):
addresses = await self.select_addresses(', '.join(cols), **constraints)
return rows_to_dict(addresses, cols)
async def get_addresses(self, **constraints):
cols = (
'address', 'account', 'chain', 'history', 'used_times',
'pubkey', 'chain_code', 'n', 'depth'
)
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols)
for address in addresses:
address['pubkey'] = PubKey(
self.ledger, address.pop('pubkey'), address.pop('chain_code'),
address.pop('n'), address.pop('depth')
)
return addresses
async def get_address_count(self, **constraints):
count = await self.select_addresses('count(*)', **constraints)
return count[0][0]
async def get_address(self, **constraints):
addresses = await self.get_addresses(
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
limit=1, **constraints
)
addresses = await self.get_addresses(limit=1, **constraints)
if addresses:
return addresses[0]
async def add_keys(self, account, chain, keys):
async def add_keys(self, account, chain, pubkeys):
await self.db.executemany(
"insert into pubkey_address (address, account, chain, position, pubkey) values (?, ?, ?, ?, ?)",
(
(pubkey.address, account.public_key.address, chain, position,
sqlite3.Binary(pubkey.pubkey_bytes))
for position, pubkey in keys
)
"insert or ignore into account_address "
"(account, address, chain, pubkey, chain_code, n, depth) values "
"(?, ?, ?, ?, ?, ?, ?)", ((
account.id, k.address, chain,
sqlite3.Binary(k.pubkey_bytes),
sqlite3.Binary(k.chain_code),
k.n, k.depth
) for k in pubkeys)
)
await self.db.executemany(
"insert or ignore into pubkey_address (address) values (?)",
((pubkey.address,) for pubkey in pubkeys)
)
async def _set_address_history(self, address, history):

View file

@ -179,29 +179,29 @@ class BaseLedger(metaclass=LedgerRegistry):
def add_account(self, account: baseaccount.BaseAccount):
self.accounts.append(account)
async def _get_account_and_address_info_for_address(self, address):
match = await self.db.get_address(address=address)
async def _get_account_and_address_info_for_address(self, wallet, address):
match = await self.db.get_address(accounts=wallet.accounts, address=address)
if match:
for account in self.accounts:
for account in wallet.accounts:
if match['account'] == account.public_key.address:
return account, match
async def get_private_key_for_address(self, address) -> Optional[PrivateKey]:
match = await self._get_account_and_address_info_for_address(address)
async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
match = await self._get_account_and_address_info_for_address(wallet, address)
if match:
account, address_info = match
return account.get_private_key(address_info['chain'], address_info['position'])
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
return None
async def get_public_key_for_address(self, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(address)
async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(wallet, address)
if match:
account, address_info = match
return account.get_public_key(address_info['chain'], address_info['position'])
_, address_info = match
return address_info['pubkey']
return None
async def get_account_for_address(self, address):
match = await self._get_account_and_address_info_for_address(address)
async def get_account_for_address(self, wallet, address):
match = await self._get_account_and_address_info_for_address(wallet, address)
if match:
return match[0]
@ -214,15 +214,9 @@ class BaseLedger(metaclass=LedgerRegistry):
return estimators
async def get_addresses(self, **constraints):
self.constraint_account_or_all(constraints)
addresses = await self.db.get_addresses(**constraints)
for address in addresses:
public_key = await self.get_public_key_for_address(address['address'])
address['public_key'] = public_key.extended_key_string()
return addresses
return await self.db.get_addresses(**constraints)
def get_address_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_address_count(**constraints)
async def get_spendable_utxos(self, amount: int, funding_accounts):
@ -244,29 +238,16 @@ class BaseLedger(metaclass=LedgerRegistry):
def release_tx(self, tx):
return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
def constraint_account_or_all(self, constraints):
if 'accounts' in constraints:
return
account = constraints.pop('account', None)
if account:
constraints['accounts'] = [account]
else:
constraints['accounts'] = self.accounts
def get_utxos(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_utxos(**constraints)
def get_utxo_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_utxo_count(**constraints)
def get_transactions(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_transactions(**constraints)
def get_transaction_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_transaction_count(**constraints)
async def get_local_status_and_history(self, address, history=None):
@ -577,7 +558,7 @@ class BaseLedger(metaclass=LedgerRegistry):
addresses.add(
self.hash160_to_address(txo.script.values['pubkey_hash'])
)
records = await self.db.get_addresses(cols=('address',), address__in=addresses)
records = await self.db.get_addresses(address__in=addresses)
_, pending = await asyncio.wait([
self.on_transaction.where(partial(
lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,

View file

@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Type, MutableSequence, MutableMapping
from typing import Type, MutableSequence, MutableMapping, Optional
from torba.client.baseledger import BaseLedger, LedgerRegistry
from torba.client.wallet import Wallet, WalletStorage
@ -41,16 +41,6 @@ class BaseWalletManager:
self.wallets.append(wallet)
return wallet
async def get_detailed_accounts(self, **kwargs):
ledgers = {}
for i, account in enumerate(self.accounts):
details = await account.get_details(**kwargs)
details['is_default'] = i == 0
ledger_id = account.ledger.get_id()
ledgers.setdefault(ledger_id, [])
ledgers[ledger_id].append(details)
return ledgers
@property
def default_wallet(self):
for wallet in self.wallets:
@ -78,3 +68,14 @@ class BaseWalletManager:
l.stop() for l in self.ledgers.values()
))
self.running = False
def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet:
if wallet_id is None:
return self.default_wallet
return self.get_wallet_or_error(wallet_id)
def get_wallet_or_error(self, wallet_id: str) -> Wallet:
for wallet in self.wallets:
if wallet.id == wallet_id:
return wallet
raise ValueError(f"Couldn't find wallet: {wallet_id}.")

View file

@ -1,6 +1,6 @@
import logging
import typing
from typing import List, Iterable, Optional
from typing import List, Iterable, Optional, Tuple
from binascii import hexlify
from torba.client.basescript import BaseInputScript, BaseOutputScript
@ -12,7 +12,7 @@ from torba.client.util import ReadOnlyList
from torba.client.errors import InsufficientFundsError
if typing.TYPE_CHECKING:
from torba.client import baseledger
from torba.client import baseledger, wallet as basewallet
log = logging.getLogger()
@ -431,21 +431,32 @@ class BaseTransaction:
self.locktime = stream.read_uint32()
@classmethod
def ensure_all_have_same_ledger(cls, funding_accounts: Iterable[BaseAccount],
change_account: BaseAccount = None) -> 'baseledger.BaseLedger':
ledger = None
def ensure_all_have_same_ledger_and_wallet(
cls, funding_accounts: Iterable[BaseAccount],
change_account: BaseAccount = None) -> Tuple['baseledger.BaseLedger', 'basewallet.Wallet']:
ledger = wallet = None
for account in funding_accounts:
if ledger is None:
ledger = account.ledger
wallet = account.wallet
if ledger != account.ledger:
raise ValueError(
'All funding accounts used to create a transaction must be on the same ledger.'
)
if change_account is not None and change_account.ledger != ledger:
raise ValueError('Change account must use same ledger as funding accounts.')
if wallet != account.wallet:
raise ValueError(
'All funding accounts used to create a transaction must be from the same wallet.'
)
if change_account is not None:
if change_account.ledger != ledger:
raise ValueError('Change account must use same ledger as funding accounts.')
if change_account.wallet != wallet:
raise ValueError('Change account must use same wallet as funding accounts.')
if ledger is None:
raise ValueError('No ledger found.')
return ledger
if wallet is None:
raise ValueError('No wallet found.')
return ledger, wallet
@classmethod
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
@ -458,7 +469,7 @@ class BaseTransaction:
.add_inputs(inputs) \
.add_outputs(outputs)
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
# value of the outputs plus associated fees
cost = (
@ -524,15 +535,15 @@ class BaseTransaction:
return hash_type
async def sign(self, funding_accounts: Iterable[BaseAccount]):
ledger = self.ensure_all_have_same_ledger(funding_accounts)
ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts)
for i, txi in enumerate(self._inputs):
assert txi.script is not None
assert txi.txo_ref.txo is not None
txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash:
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = await ledger.get_private_key_for_address(address)
assert private_key is not None
private_key = await ledger.get_private_key_for_address(wallet, address)
assert private_key is not None, 'Cannot find private key for signing output.'
tx = self._serialize_for_signature(i)
txi.script.values['signature'] = \
private_key.sign(tx) + bytes((self.signature_hash_type(1),))

View file

@ -3,7 +3,7 @@ import stat
import json
import zlib
import typing
from typing import Sequence, MutableSequence
from typing import List, Sequence, MutableSequence, Optional
from hashlib import sha256
from operator import attrgetter
from torba.client.hash import better_aes_encrypt, better_aes_decrypt
@ -25,12 +25,51 @@ class Wallet:
self.accounts = accounts or []
self.storage = storage or WalletStorage()
def add_account(self, account):
@property
def id(self):
if self.storage.path:
return os.path.basename(self.storage.path)
return self.name
def add_account(self, account: 'baseaccount.BaseAccount'):
self.accounts.append(account)
def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount':
return ledger.account_class.generate(ledger, self)
@property
def default_account(self) -> Optional['baseaccount.BaseAccount']:
for account in self.accounts:
return account
return None
def get_account_or_default(self, account_id: str) -> Optional['baseaccount.BaseAccount']:
if account_id is None:
return self.default_account
return self.get_account_or_error(account_id)
def get_account_or_error(self, account_id: str) -> 'baseaccount.BaseAccount':
for account in self.accounts:
if account.id == account_id:
return account
raise ValueError(f"Couldn't find account: {account_id}.")
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['baseaccount.BaseAccount']:
return [
self.get_account_or_error(account_id)
for account_id in account_ids
] if account_ids else self.accounts
async def get_detailed_accounts(self, **kwargs):
ledgers = {}
for i, account in enumerate(self.accounts):
details = await account.get_details(**kwargs)
details['is_default'] = i == 0
ledger_id = account.ledger.get_id()
ledgers.setdefault(ledger_id, [])
ledgers[ledger_id].append(details)
return ledgers
@classmethod
def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet':
json_dict = storage.read()
@ -54,11 +93,6 @@ class Wallet:
def save(self):
self.storage.write(self.to_dict())
@property
def default_account(self):
for account in self.accounts:
return account
@property
def hash(self) -> bytes:
h = sha256()

View file

@ -151,7 +151,9 @@ class WalletNode:
async def start(self, spv_node: 'SPVNode', seed=None, connect=True):
self.data_path = tempfile.mkdtemp()
wallet_file_name = os.path.join(self.data_path, 'my_wallet.json')
wallets_dir = os.path.join(self.data_path, 'wallets')
os.mkdir(wallets_dir)
wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json')
with open(wallet_file_name, 'w') as wallet_file:
wallet_file.write('{"version": 1, "accounts": []}\n')
self.manager = self.manager_class.from_config({