wallet commands and wallet_id argument to other commands

This commit is contained in:
Lex Berezhny 2019-09-20 00:05:37 -04:00
parent 9fbc83fe9d
commit 84587ac232
16 changed files with 608 additions and 406 deletions

File diff suppressed because it is too large Load diff

View file

@ -139,34 +139,34 @@ class Account(BaseAccount):
return details return details
def get_transaction_history(self, **constraints): def get_transaction_history(self, **constraints):
return self.ledger.get_transaction_history(account=self, **constraints) return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints)
def get_transaction_history_count(self, **constraints): def get_transaction_history_count(self, **constraints):
return self.ledger.get_transaction_history_count(account=self, **constraints) return self.ledger.get_transaction_history_count(wallet=self.wallet, accounts=[self], **constraints)
def get_claims(self, **constraints): def get_claims(self, **constraints):
return self.ledger.get_claims(account=self, **constraints) return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints)
def get_claim_count(self, **constraints): def get_claim_count(self, **constraints):
return self.ledger.get_claim_count(account=self, **constraints) return self.ledger.get_claim_count(wallet=self.wallet, accounts=[self], **constraints)
def get_streams(self, **constraints): def get_streams(self, **constraints):
return self.ledger.get_streams(account=self, **constraints) return self.ledger.get_streams(wallet=self.wallet, accounts=[self], **constraints)
def get_stream_count(self, **constraints): def get_stream_count(self, **constraints):
return self.ledger.get_stream_count(account=self, **constraints) return self.ledger.get_stream_count(wallet=self.wallet, accounts=[self], **constraints)
def get_channels(self, **constraints): def get_channels(self, **constraints):
return self.ledger.get_channels(account=self, **constraints) return self.ledger.get_channels(wallet=self.wallet, accounts=[self], **constraints)
def get_channel_count(self, **constraints): def get_channel_count(self, **constraints):
return self.ledger.get_channel_count(account=self, **constraints) return self.ledger.get_channel_count(wallet=self.wallet, accounts=[self], **constraints)
def get_supports(self, **constraints): def get_supports(self, **constraints):
return self.ledger.get_supports(account=self, **constraints) return self.ledger.get_supports(wallet=self.wallet, accounts=[self], **constraints)
def get_support_count(self, **constraints): def get_support_count(self, **constraints):
return self.ledger.get_support_count(account=self, **constraints) return self.ledger.get_support_count(wallet=self.wallet, accounts=[self], **constraints)
def get_support_summary(self): def get_support_summary(self):
return self.ledger.db.get_supports_summary(account_id=self.id) return self.ledger.db.get_supports_summary(account_id=self.id)

View file

@ -22,18 +22,18 @@ class WalletDatabase(BaseDatabase):
claim_id text, claim_id text,
claim_name text claim_name text
); );
create index if not exists txo_address_idx on txo (address);
create index if not exists txo_claim_id_idx on txo (claim_id); create index if not exists txo_claim_id_idx on txo (claim_id);
create index if not exists txo_txo_type_idx on txo (txo_type); create index if not exists txo_txo_type_idx on txo (txo_type);
""" """
CREATE_TABLES_QUERY = ( CREATE_TABLES_QUERY = (
BaseDatabase.CREATE_TX_TABLE + BaseDatabase.PRAGMAS +
BaseDatabase.CREATE_ACCOUNT_TABLE +
BaseDatabase.CREATE_PUBKEY_ADDRESS_TABLE + BaseDatabase.CREATE_PUBKEY_ADDRESS_TABLE +
BaseDatabase.CREATE_PUBKEY_ADDRESS_INDEX + BaseDatabase.CREATE_TX_TABLE +
CREATE_TXO_TABLE + CREATE_TXO_TABLE +
BaseDatabase.CREATE_TXO_INDEX + BaseDatabase.CREATE_TXI_TABLE
BaseDatabase.CREATE_TXI_TABLE +
BaseDatabase.CREATE_TXI_INDEX
) )
def txo_to_row(self, tx, address, txo): def txo_to_row(self, tx, address, txo):
@ -50,18 +50,16 @@ class WalletDatabase(BaseDatabase):
row['claim_name'] = txo.claim_name row['claim_name'] = txo.claim_name
return row return row
async def get_txos(self, **constraints) -> List[Output]: async def get_txos(self, wallet=None, no_tx=False, **constraints) -> List[Output]:
my_accounts = constraints.get('my_accounts', constraints.get('accounts', [])) txos = await super().get_txos(wallet=wallet, no_tx=no_tx, **constraints)
txos = await super().get_txos(**constraints)
channel_ids = set() channel_ids = set()
for txo in txos: for txo in txos:
if txo.is_claim and txo.can_decode_claim: if txo.is_claim and txo.can_decode_claim:
if txo.claim.is_signed: if txo.claim.is_signed:
channel_ids.add(txo.claim.signing_channel_id) channel_ids.add(txo.claim.signing_channel_id)
if txo.claim.is_channel and my_accounts: if txo.claim.is_channel and wallet:
for account in my_accounts: for account in wallet.accounts:
private_key = account.get_channel_private_key( private_key = account.get_channel_private_key(
txo.claim.channel.public_key_bytes txo.claim.channel.public_key_bytes
) )
@ -73,7 +71,7 @@ class WalletDatabase(BaseDatabase):
channels = { channels = {
txo.claim_id: txo for txo in txo.claim_id: txo for txo in
(await self.get_claims( (await self.get_claims(
my_accounts=my_accounts, wallet=wallet,
claim_id__in=channel_ids claim_id__in=channel_ids
)) ))
} }

View file

@ -121,39 +121,30 @@ class MainNetLedger(BaseLedger):
return super().get_utxo_count(**constraints) return super().get_utxo_count(**constraints)
def get_claims(self, **constraints): def get_claims(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_claims(**constraints) return self.db.get_claims(**constraints)
def get_claim_count(self, **constraints): def get_claim_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_claim_count(**constraints) return self.db.get_claim_count(**constraints)
def get_streams(self, **constraints): def get_streams(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_streams(**constraints) return self.db.get_streams(**constraints)
def get_stream_count(self, **constraints): def get_stream_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_stream_count(**constraints) return self.db.get_stream_count(**constraints)
def get_channels(self, **constraints): def get_channels(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_channels(**constraints) return self.db.get_channels(**constraints)
def get_channel_count(self, **constraints): def get_channel_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_channel_count(**constraints) return self.db.get_channel_count(**constraints)
def get_supports(self, **constraints): def get_supports(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_supports(**constraints) return self.db.get_supports(**constraints)
def get_support_count(self, **constraints): def get_support_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_support_count(**constraints) return self.db.get_support_count(**constraints)
async def get_transaction_history(self, **constraints): async def get_transaction_history(self, **constraints):
self.constraint_account_or_all(constraints)
txs = await self.db.get_transactions(**constraints) txs = await self.db.get_transactions(**constraints)
headers = self.headers headers = self.headers
history = [] history = []
@ -249,7 +240,6 @@ class MainNetLedger(BaseLedger):
return history return history
def get_transaction_history_count(self, **constraints): def get_transaction_history_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_transaction_count(**constraints) return self.db.get_transaction_count(**constraints)

View file

@ -203,7 +203,7 @@ class Transaction(BaseTransaction):
@classmethod @classmethod
def pay(cls, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account): def pay(cls, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address)) output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address))
return cls.create([], [output], funding_accounts, change_account) return cls.create([], [output], funding_accounts, change_account)
@ -211,7 +211,7 @@ class Transaction(BaseTransaction):
def claim_create( def claim_create(
cls, name: str, claim: Claim, amount: int, holding_address: str, cls, name: str, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
claim_output = Output.pay_claim_name_pubkey_hash( claim_output = Output.pay_claim_name_pubkey_hash(
amount, name, claim, ledger.address_to_hash160(holding_address) amount, name, claim, ledger.address_to_hash160(holding_address)
) )
@ -223,7 +223,7 @@ class Transaction(BaseTransaction):
def claim_update( def claim_update(
cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str, cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
updated_claim = Output.pay_update_claim_pubkey_hash( updated_claim = Output.pay_update_claim_pubkey_hash(
amount, previous_claim.claim_name, previous_claim.claim_id, amount, previous_claim.claim_name, previous_claim.claim_id,
claim, ledger.address_to_hash160(holding_address) claim, ledger.address_to_hash160(holding_address)
@ -239,7 +239,7 @@ class Transaction(BaseTransaction):
@classmethod @classmethod
def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str, def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account): funding_accounts: List[Account], change_account: Account):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
support_output = Output.pay_support_pubkey_hash( support_output = Output.pay_support_pubkey_hash(
amount, claim_name, claim_id, ledger.address_to_hash160(holding_address) amount, claim_name, claim_id, ledger.address_to_hash160(holding_address)
) )
@ -248,7 +248,7 @@ class Transaction(BaseTransaction):
@classmethod @classmethod
def purchase(cls, claim: Output, amount: int, merchant_address: bytes, def purchase(cls, claim: Output, amount: int, merchant_address: bytes,
funding_accounts: List[Account], change_account: Account): funding_accounts: List[Account], change_account: Account):
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
claim_output = Output.purchase_claim_pubkey_hash( claim_output = Output.purchase_claim_pubkey_hash(
amount, claim.claim_id, ledger.address_to_hash160(merchant_address) amount, claim.claim_id, ledger.address_to_hash160(merchant_address)
) )

View file

@ -466,45 +466,47 @@ class ChannelCommands(CommandTestCase):
'title': 'foo', 'email': 'new@email.com'} 'title': 'foo', 'email': 'new@email.com'}
) )
# send channel to someone else # move channel to another account
new_account = await self.out(self.daemon.jsonrpc_account_create('second account')) new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id']) account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
# before sending # before moving
self.assertEqual(len(await self.daemon.jsonrpc_channel_list()), 3) self.assertEqual(len(await self.daemon.jsonrpc_channel_list()), 3)
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 0) self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 0)
other_address = await account2.receiving.get_or_create_usable_address() other_address = await account2.receiving.get_or_create_usable_address()
tx = await self.out(self.channel_update(claim_id, claim_address=other_address)) tx = await self.out(self.channel_update(claim_id, claim_address=other_address))
# after sending # after moving
self.assertEqual(len(await self.daemon.jsonrpc_channel_list()), 3) self.assertEqual(len(await self.daemon.jsonrpc_channel_list()), 3)
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=self.account.id)), 2) self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=self.account.id)), 2)
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 1) self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 1)
# shoud not have private key async def test_channel_export_import_before_sending_channel(self):
txo = (await account2.get_channels())[0] # export
self.assertIsNone(txo.private_key)
# send the private key too
private_key = self.account.get_channel_private_key(unhexlify(channel['public_key']))
account2.add_channel_private_key(private_key)
# now should have private key
txo = (await account2.get_channels())[0]
self.assertIsNotNone(txo.private_key)
async def test_channel_export_import_into_new_account(self):
tx = await self.channel_create('@foo', '1.0') tx = await self.channel_create('@foo', '1.0')
claim_id = self.get_claim_id(tx) claim_id = self.get_claim_id(tx)
channel_private_key = (await self.account.get_channels())[0].private_key channel_private_key = (await self.account.get_channels())[0].private_key
exported_data = await self.out(self.daemon.jsonrpc_channel_export(claim_id)) exported_data = await self.out(self.daemon.jsonrpc_channel_export(claim_id))
# import
daemon2 = await self.add_daemon() daemon2 = await self.add_daemon()
self.assertEqual(0, len(await daemon2.jsonrpc_channel_list()))
await daemon2.jsonrpc_channel_import(exported_data) await daemon2.jsonrpc_channel_import(exported_data)
channels = await daemon2.jsonrpc_channel_list() channels = await daemon2.jsonrpc_channel_list()
self.assertEqual(1, len(channels)) self.assertEqual(1, len(channels))
self.assertEqual(channel_private_key.to_string(), channels[0].private_key.to_string()) self.assertEqual(channel_private_key.to_string(), channels[0].private_key.to_string())
# second wallet can't update until channel is sent to it
with self.assertRaisesRegex(AssertionError, 'Cannot find private key for signing output.'):
await daemon2.jsonrpc_channel_update(claim_id, bid='0.5')
# now send the channel as well
await self.channel_update(claim_id, claim_address=await daemon2.jsonrpc_address_unused())
# second wallet should be able to update now
await daemon2.jsonrpc_channel_update(claim_id, bid='0.5')
class StreamCommands(ClaimTestCase): class StreamCommands(ClaimTestCase):
@ -565,7 +567,7 @@ class StreamCommands(ClaimTestCase):
async def test_publishing_checks_all_accounts_for_channel(self): async def test_publishing_checks_all_accounts_for_channel(self):
account1_id, account1 = self.account.id, self.account account1_id, account1 = self.account.id, self.account
new_account = await self.out(self.daemon.jsonrpc_account_create('second account')) new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id']) account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
await self.out(self.channel_create('@spam', '1.0')) await self.out(self.channel_create('@spam', '1.0'))
self.assertEqual('8.989893', (await self.daemon.jsonrpc_account_balance())['available']) self.assertEqual('8.989893', (await self.daemon.jsonrpc_account_balance())['available'])
@ -782,7 +784,7 @@ class StreamCommands(ClaimTestCase):
# send claim to someone else # send claim to someone else
new_account = await self.out(self.daemon.jsonrpc_account_create('second account')) new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id']) account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
# before sending # before sending
self.assertEqual(len(await self.daemon.jsonrpc_claim_list()), 4) self.assertEqual(len(await self.daemon.jsonrpc_claim_list()), 4)
@ -1079,13 +1081,12 @@ class StreamCommands(ClaimTestCase):
class SupportCommands(CommandTestCase): class SupportCommands(CommandTestCase):
async def test_regular_supports_and_tip_supports(self): async def test_regular_supports_and_tip_supports(self):
# account2 will be used to send tips and supports to account1 wallet2 = await self.daemon.jsonrpc_wallet_add('wallet2', create_wallet=True, create_account=True)
account2_id = (await self.out(self.daemon.jsonrpc_account_create('second account')))['id'] account2 = wallet2.accounts[0]
account2 = self.daemon.get_account_or_error(account2_id)
# send account2 5 LBC out of the 10 LBC in account1 # send account2 5 LBC out of the 10 LBC in account1
result = await self.out(self.daemon.jsonrpc_account_send( result = await self.out(self.daemon.jsonrpc_account_send(
'5.0', await self.daemon.jsonrpc_address_unused(account2_id) '5.0', await self.daemon.jsonrpc_address_unused(wallet_id='wallet2')
)) ))
await self.on_transaction_dict(result) await self.on_transaction_dict(result)
@ -1103,7 +1104,7 @@ class SupportCommands(CommandTestCase):
# send a tip to the claim using account2 # send a tip to the claim using account2
tip = await self.out( tip = await self.out(
self.daemon.jsonrpc_support_create( self.daemon.jsonrpc_support_create(
claim_id, '1.0', True, account2_id, funding_account_ids=[account2_id]) claim_id, '1.0', True, account2.id, 'wallet2', funding_account_ids=[account2.id])
) )
await self.confirm_tx(tip['txid']) await self.confirm_tx(tip['txid'])
@ -1122,7 +1123,7 @@ class SupportCommands(CommandTestCase):
# verify that the outgoing tip is marked correctly as is_tip=True in account2 # verify that the outgoing tip is marked correctly as is_tip=True in account2
txs2 = await self.out( txs2 = await self.out(
self.daemon.jsonrpc_transaction_list(account2_id) self.daemon.jsonrpc_transaction_list(wallet_id='wallet2', account_id=account2.id)
) )
self.assertEqual(len(txs2[0]['support_info']), 1) self.assertEqual(len(txs2[0]['support_info']), 1)
self.assertEqual(txs2[0]['support_info'][0]['balance_delta'], '-1.0') self.assertEqual(txs2[0]['support_info'][0]['balance_delta'], '-1.0')
@ -1134,7 +1135,7 @@ class SupportCommands(CommandTestCase):
# send a support to the claim using account2 # send a support to the claim using account2
support = await self.out( support = await self.out(
self.daemon.jsonrpc_support_create( self.daemon.jsonrpc_support_create(
claim_id, '2.0', False, account2_id, funding_account_ids=[account2_id]) claim_id, '2.0', False, account2.id, 'wallet2', funding_account_ids=[account2.id])
) )
await self.confirm_tx(support['txid']) await self.confirm_tx(support['txid'])
@ -1143,7 +1144,7 @@ class SupportCommands(CommandTestCase):
await self.assertBalance(account2, '1.999717') await self.assertBalance(account2, '1.999717')
# verify that the outgoing support is marked correctly as is_tip=False in account2 # verify that the outgoing support is marked correctly as is_tip=False in account2
txs2 = await self.out(self.daemon.jsonrpc_transaction_list(account2_id)) txs2 = await self.out(self.daemon.jsonrpc_transaction_list(wallet_id='wallet2'))
self.assertEqual(len(txs2[0]['support_info']), 1) self.assertEqual(len(txs2[0]['support_info']), 1)
self.assertEqual(txs2[0]['support_info'][0]['balance_delta'], '-2.0') self.assertEqual(txs2[0]['support_info'][0]['balance_delta'], '-2.0')
self.assertEqual(txs2[0]['support_info'][0]['claim_id'], claim_id) self.assertEqual(txs2[0]['support_info'][0]['claim_id'], claim_id)

View file

@ -61,13 +61,17 @@ class TestAccount(AsyncioTestCase):
address = await account.receiving.ensure_address_gap() address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx')
private_key = await self.ledger.get_private_key_for_address('bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') private_key = await self.ledger.get_private_key_for_address(
account.wallet, 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx'
)
self.assertEqual( self.assertEqual(
private_key.extended_key_string(), private_key.extended_key_string(),
'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR' 'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR'
'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT' 'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT'
) )
private_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') private_key = await self.ledger.get_private_key_for_address(
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
)
self.assertIsNone(private_key) self.assertIsNone(private_key)
def test_load_and_save_account(self): def test_load_and_save_account(self):

View file

@ -60,7 +60,7 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
await account.receiving._generate_keys(8, 11) await account.receiving._generate_keys(8, 11)
records = await account.receiving.get_address_records() records = await account.receiving.get_address_records()
self.assertEqual( self.assertEqual(
[r['position'] for r in records], [r['pubkey'].n for r in records],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
) )
@ -69,7 +69,7 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
self.assertEqual(len(new_keys), 8) self.assertEqual(len(new_keys), 8)
records = await account.receiving.get_address_records() records = await account.receiving.get_address_records()
self.assertEqual( self.assertEqual(
[r['position'] for r in records], [r['pubkey'].n for r in records],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
) )
@ -125,14 +125,18 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
address = await account.receiving.ensure_address_gap() address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J') self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J') private_key = await self.ledger.get_private_key_for_address(
account.wallet, '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J'
)
self.assertEqual( self.assertEqual(
private_key.extended_key_string(), private_key.extended_key_string(),
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o' 'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi' 'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
) )
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') invalid_key = await self.ledger.get_private_key_for_address(
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
)
self.assertIsNone(invalid_key) self.assertIsNone(invalid_key)
self.assertEqual( self.assertEqual(
@ -275,12 +279,18 @@ class TestSingleKeyAccount(AsyncioTestCase):
self.assertEqual(len(new_keys), 1) self.assertEqual(len(new_keys), 1)
self.assertEqual(new_keys[0], account.public_key.address) self.assertEqual(new_keys[0], account.public_key.address)
records = await account.receiving.get_address_records() records = await account.receiving.get_address_records()
pubkey = records[0].pop('pubkey')
self.assertEqual(records, [{ self.assertEqual(records, [{
'position': 0, 'chain': 0, 'chain': 0,
'account': account.public_key.address, 'account': account.public_key.address,
'address': account.public_key.address, 'address': account.public_key.address,
'history': None,
'used_times': 0 'used_times': 0
}]) }])
self.assertEqual(
pubkey.extended_key_string(),
account.public_key.extended_key_string()
)
# case #1: no new addresses needed # case #1: no new addresses needed
empty = await account.receiving.ensure_address_gap() empty = await account.receiving.ensure_address_gap()
@ -333,14 +343,18 @@ class TestSingleKeyAccount(AsyncioTestCase):
address = await account.receiving.ensure_address_gap() address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], account.public_key.address) self.assertEqual(address[0], account.public_key.address)
private_key = await self.ledger.get_private_key_for_address(address[0]) private_key = await self.ledger.get_private_key_for_address(
account.wallet, address[0]
)
self.assertEqual( self.assertEqual(
private_key.extended_key_string(), private_key.extended_key_string(),
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp' 'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna', '5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
) )
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') invalid_key = await self.ledger.get_private_key_for_address(
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
)
self.assertIsNone(invalid_key) self.assertIsNone(invalid_key)
self.assertEqual( self.assertEqual(

View file

@ -199,13 +199,14 @@ class TestQueries(AsyncioTestCase):
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'),
}) })
self.wallet = Wallet()
await self.ledger.db.open() await self.ledger.db.open()
async def asyncTearDown(self): async def asyncTearDown(self):
await self.ledger.db.close() await self.ledger.db.close()
async def create_account(self): async def create_account(self, wallet=None):
account = self.ledger.account_class.generate(self.ledger, Wallet()) account = self.ledger.account_class.generate(self.ledger, wallet or self.wallet)
await account.ensure_address_gap() await account.ensure_address_gap()
return account return account
@ -264,7 +265,8 @@ class TestQueries(AsyncioTestCase):
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height) tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
for limit in range(variable_limit-2, variable_limit+2): for limit in range(variable_limit-2, variable_limit+2):
txs = await self.ledger.get_transactions(limit=limit, order_by='height asc') txs = await self.ledger.get_transactions(
accounts=self.wallet.accounts, limit=limit, order_by='height asc')
self.assertEqual(len(txs), limit) self.assertEqual(len(txs), limit)
inputs, outputs, last_tx = set(), set(), txs[0] inputs, outputs, last_tx = set(), set(), txs[0]
for tx in txs[1:]: for tx in txs[1:]:
@ -274,19 +276,23 @@ class TestQueries(AsyncioTestCase):
last_tx = tx last_tx = tx
async def test_queries(self): async def test_queries(self):
self.assertEqual(0, await self.ledger.db.get_address_count()) wallet1 = Wallet()
account1 = await self.create_account() account1 = await self.create_account(wallet1)
self.assertEqual(26, await self.ledger.db.get_address_count()) self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account1]))
account2 = await self.create_account() wallet2 = Wallet()
self.assertEqual(52, await self.ledger.db.get_address_count()) account2 = await self.create_account(wallet2)
account3 = await self.create_account(wallet2)
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2])) self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2, account3]))
self.assertEqual(0, await self.ledger.db.get_utxo_count()) self.assertEqual(0, await self.ledger.db.get_utxo_count())
self.assertEqual([], await self.ledger.db.get_utxos()) self.assertEqual([], await self.ledger.db.get_utxos())
self.assertEqual(0, await self.ledger.db.get_txo_count()) self.assertEqual(0, await self.ledger.db.get_txo_count())
self.assertEqual(0, await self.ledger.db.get_balance()) self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
tx1 = await self.create_tx_from_nothing(account1, 1) tx1 = await self.create_tx_from_nothing(account1, 1)
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1])) self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1]))
@ -294,20 +300,28 @@ class TestQueries(AsyncioTestCase):
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1])) self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2])) self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance()) self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1])) self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2) tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
tx2b = await self.create_tx_from_nothing(account3, 2)
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1])) self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2])) self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account3]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1])) self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2])) self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2])) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance()) self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account3]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account3]))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(10**8+10**8, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2])) self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3) tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1])) self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
@ -316,22 +330,24 @@ class TestQueries(AsyncioTestCase):
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2])) self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2])) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance()) self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
txs = await self.ledger.db.get_transactions(accounts=[account1, account2]) txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual([3, 2, 1], [tx.height for tx in txs]) self.assertEqual([3, 2, 1], [tx.height for tx in txs])
txs = await self.ledger.db.get_transactions(accounts=[account1]) txs = await self.ledger.db.get_transactions(wallet=wallet1, accounts=wallet1.accounts)
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True) self.assertEqual(txs[1].outputs[0].is_my_account, True)
txs = await self.ledger.db.get_transactions(accounts=[account2]) txs = await self.ledger.db.get_transactions(wallet=wallet2, accounts=[account2])
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs]) self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[0].outputs[0].is_my_account, False)
@ -343,18 +359,18 @@ class TestQueries(AsyncioTestCase):
self.assertEqual(tx.id, tx2.id) self.assertEqual(tx.id, tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False)
tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account1]) tx = await self.ledger.db.get_transaction(wallet=wallet1, txid=tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, True) self.assertEqual(tx.inputs[0].is_my_account, True)
self.assertEqual(tx.outputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False)
tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account2]) tx = await self.ledger.db.get_transaction(wallet=wallet2, txid=tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, True) self.assertEqual(tx.outputs[0].is_my_account, True)
# height 0 sorted to the top with the rest in descending order # height 0 sorted to the top with the rest in descending order
tx4 = await self.create_tx_from_nothing(account1, 0) tx4 = await self.create_tx_from_nothing(account1, 0)
txos = await self.ledger.db.get_txos() txos = await self.ledger.db.get_txos()
self.assertEqual([0, 2, 1], [txo.tx_ref.height for txo in txos]) self.assertEqual([0, 2, 2, 1], [txo.tx_ref.height for txo in txos])
self.assertEqual([tx4.id, tx2.id, tx1.id], [txo.tx_ref.id for txo in txos]) self.assertEqual([tx4.id, tx2.id, tx2b.id, tx1.id], [txo.tx_ref.id for txo in txos])
txs = await self.ledger.db.get_transactions(accounts=[account1, account2]) txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs]) self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs])
self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
@ -382,13 +398,13 @@ class TestUpgrade(AsyncioTestCase):
def add_address(self, address): def add_address(self, address):
with sqlite3.connect(self.path) as conn: with sqlite3.connect(self.path) as conn:
conn.execute(""" conn.execute("""
INSERT INTO pubkey_address (address, account, chain, position, pubkey) INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth)
VALUES (?, 'account1', 0, 0, 'pubkey blob') VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0)
""", (address,)) """, (address,))
def get_addresses(self): def get_addresses(self):
with sqlite3.connect(self.path) as conn: with sqlite3.connect(self.path) as conn:
sql = "SELECT address FROM pubkey_address ORDER BY address;" sql = "SELECT address FROM account_address ORDER BY address;"
return [col[0] for col in conn.execute(sql).fetchall()] return [col[0] for col in conn.execute(sql).fetchall()]
async def test_reset_on_version_change(self): async def test_reset_on_version_change(self):
@ -401,7 +417,7 @@ class TestUpgrade(AsyncioTestCase):
self.ledger.db.SCHEMA_VERSION = None self.ledger.db.SCHEMA_VERSION = None
self.assertEqual(self.get_tables(), []) self.assertEqual(self.get_tables(), [])
await self.ledger.db.open() await self.ledger.db.open()
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo']) self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo'])
self.assertEqual(self.get_addresses(), []) self.assertEqual(self.get_addresses(), [])
self.add_address('address1') self.add_address('address1')
await self.ledger.db.close() await self.ledger.db.close()
@ -410,17 +426,17 @@ class TestUpgrade(AsyncioTestCase):
self.ledger.db.SCHEMA_VERSION = '1.0' self.ledger.db.SCHEMA_VERSION = '1.0'
await self.ledger.db.open() await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0') self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version']) self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), []) # address1 deleted during version upgrade self.assertEqual(self.get_addresses(), []) # address1 deleted during version upgrade
self.add_address('address2') self.add_address('address2')
await self.ledger.db.close() await self.ledger.db.close()
# nothing changes # nothing changes
self.assertEqual(self.get_version(), '1.0') self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version']) self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
await self.ledger.db.open() await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0') self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version']) self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), ['address2']) self.assertEqual(self.get_addresses(), ['address2'])
await self.ledger.db.close() await self.ledger.db.close()
@ -431,7 +447,7 @@ class TestUpgrade(AsyncioTestCase):
""" """
await self.ledger.db.open() await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.1') self.assertEqual(self.get_version(), '1.1')
self.assertEqual(self.get_tables(), ['foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) self.assertEqual(self.get_tables(), ['account_address', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), []) # all tables got reset self.assertEqual(self.get_addresses(), []) # all tables got reset
await self.ledger.db.close() await self.ledger.db.close()

View file

@ -115,7 +115,7 @@ class HierarchicalDeterministic(AddressManager):
return self.account.public_key.child(self.chain_number).child(index) return self.account.public_key.child(self.chain_number).child(index)
async def get_max_gap(self) -> int: async def get_max_gap(self) -> int:
addresses = await self._query_addresses(order_by="position ASC") addresses = await self._query_addresses(order_by="n asc")
max_gap = 0 max_gap = 0
current_gap = 0 current_gap = 0
for address in addresses: for address in addresses:
@ -128,7 +128,7 @@ class HierarchicalDeterministic(AddressManager):
async def ensure_address_gap(self) -> List[str]: async def ensure_address_gap(self) -> List[str]:
async with self.address_generator_lock: async with self.address_generator_lock:
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC") addresses = await self._query_addresses(limit=self.gap, order_by="n desc")
existing_gap = 0 existing_gap = 0
for address in addresses: for address in addresses:
@ -140,7 +140,7 @@ class HierarchicalDeterministic(AddressManager):
if existing_gap == self.gap: if existing_gap == self.gap:
return [] return []
start = addresses[0]['position']+1 if addresses else 0 start = addresses[0]['pubkey'].n+1 if addresses else 0
end = start + (self.gap - existing_gap) end = start + (self.gap - existing_gap)
new_keys = await self._generate_keys(start, end-1) new_keys = await self._generate_keys(start, end-1)
await self.account.ledger.announce_addresses(self, new_keys) await self.account.ledger.announce_addresses(self, new_keys)
@ -149,15 +149,15 @@ class HierarchicalDeterministic(AddressManager):
async def _generate_keys(self, start: int, end: int) -> List[str]: async def _generate_keys(self, start: int, end: int) -> List[str]:
if not self.address_generator_lock.locked(): if not self.address_generator_lock.locked():
raise RuntimeError('Should not be called outside of address_generator_lock.') raise RuntimeError('Should not be called outside of address_generator_lock.')
keys = [(index, self.public_key.child(index)) for index in range(start, end+1)] keys = [self.public_key.child(index) for index in range(start, end+1)]
await self.account.ledger.db.add_keys(self.account, self.chain_number, keys) await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
return [key[1].address for key in keys] return [key.address for key in keys]
def get_address_records(self, only_usable: bool = False, **constraints): def get_address_records(self, only_usable: bool = False, **constraints):
if only_usable: if only_usable:
constraints['used_times__lt'] = self.maximum_uses_per_address constraints['used_times__lt'] = self.maximum_uses_per_address
if 'order_by' not in constraints: if 'order_by' not in constraints:
constraints['order_by'] = "used_times ASC, position ASC" constraints['order_by'] = "used_times asc, n asc"
return self._query_addresses(**constraints) return self._query_addresses(**constraints)
@ -190,9 +190,7 @@ class SingleKey(AddressManager):
async with self.address_generator_lock: async with self.address_generator_lock:
exists = await self.get_address_records() exists = await self.get_address_records()
if not exists: if not exists:
await self.account.ledger.db.add_keys( await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
self.account, self.chain_number, [(0, self.public_key)]
)
new_keys = [self.public_key.address] new_keys = [self.public_key.address]
await self.account.ledger.announce_addresses(self, new_keys) await self.account.ledger.announce_addresses(self, new_keys)
return new_keys return new_keys
@ -417,16 +415,16 @@ class BaseAccount:
} }
def get_utxos(self, **constraints): def get_utxos(self, **constraints):
return self.ledger.get_utxos(account=self, **constraints) return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)
def get_utxo_count(self, **constraints): def get_utxo_count(self, **constraints):
return self.ledger.get_utxo_count(account=self, **constraints) return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)
def get_transactions(self, **constraints): def get_transactions(self, **constraints):
return self.ledger.get_transactions(account=self, **constraints) return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)
def get_transaction_count(self, **constraints): def get_transaction_count(self, **constraints):
return self.ledger.get_transaction_count(account=self, **constraints) return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)
async def fund(self, to_account, amount=None, everything=False, async def fund(self, to_account, amount=None, everything=False,
outputs=1, broadcast=False, **constraints): outputs=1, broadcast=False, **constraints):

View file

@ -1,5 +1,6 @@
import logging import logging
import asyncio import asyncio
import json
from binascii import hexlify from binascii import hexlify
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
@ -8,7 +9,7 @@ from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict,
import sqlite3 import sqlite3
from torba.client.basetransaction import BaseTransaction, TXRefImmutable from torba.client.basetransaction import BaseTransaction, TXRefImmutable
from torba.client.baseaccount import BaseAccount from torba.client.bip32 import PubKey
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -165,9 +166,8 @@ def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
offset = constraints.pop('offset', None) offset = constraints.pop('offset', None)
order_by = constraints.pop('order_by', None) order_by = constraints.pop('order_by', None)
constraints.pop('my_accounts', None) accounts = constraints.pop('accounts', [])
accounts = constraints.pop('accounts', None) if accounts:
if accounts is not None:
constraints['account__in'] = [a.public_key.address for a in accounts] constraints['account__in'] = [a.public_key.address for a in accounts]
where, values = constraints_to_sql(constraints) where, values = constraints_to_sql(constraints)
@ -285,26 +285,33 @@ class SQLiteMixin:
class BaseDatabase(SQLiteMixin): class BaseDatabase(SQLiteMixin):
SCHEMA_VERSION = "1.0" SCHEMA_VERSION = "1.1"
PRAGMAS = """ PRAGMAS = """
pragma journal_mode=WAL; pragma journal_mode=WAL;
""" """
CREATE_ACCOUNT_TABLE = """
create table if not exists account_address (
account text not null,
address text not null,
chain integer not null,
pubkey blob not null,
chain_code blob not null,
n integer not null,
depth integer not null,
primary key (account, address)
);
create index if not exists address_account_idx on account_address (address, account);
"""
CREATE_PUBKEY_ADDRESS_TABLE = """ CREATE_PUBKEY_ADDRESS_TABLE = """
create table if not exists pubkey_address ( create table if not exists pubkey_address (
address text primary key, address text primary key,
account text not null,
chain integer not null,
position integer not null,
pubkey blob not null,
history text, history text,
used_times integer not null default 0 used_times integer not null default 0
); );
""" """
CREATE_PUBKEY_ADDRESS_INDEX = """
create index if not exists pubkey_address_account_idx on pubkey_address (account);
"""
CREATE_TX_TABLE = """ CREATE_TX_TABLE = """
create table if not exists tx ( create table if not exists tx (
@ -326,8 +333,6 @@ class BaseDatabase(SQLiteMixin):
script blob not null, script blob not null,
is_reserved boolean not null default 0 is_reserved boolean not null default 0
); );
"""
CREATE_TXO_INDEX = """
create index if not exists txo_address_idx on txo (address); create index if not exists txo_address_idx on txo (address);
""" """
@ -337,21 +342,17 @@ class BaseDatabase(SQLiteMixin):
txoid text references txo, txoid text references txo,
address text references pubkey_address address text references pubkey_address
); );
"""
CREATE_TXI_INDEX = """
create index if not exists txi_address_idx on txi (address); create index if not exists txi_address_idx on txi (address);
create index if not exists txi_txoid_idx on txi (txoid); create index if not exists txi_txoid_idx on txi (txoid);
""" """
CREATE_TABLES_QUERY = ( CREATE_TABLES_QUERY = (
PRAGMAS + PRAGMAS +
CREATE_TX_TABLE + CREATE_ACCOUNT_TABLE +
CREATE_PUBKEY_ADDRESS_TABLE + CREATE_PUBKEY_ADDRESS_TABLE +
CREATE_PUBKEY_ADDRESS_INDEX + CREATE_TX_TABLE +
CREATE_TXO_TABLE + CREATE_TXO_TABLE +
CREATE_TXO_INDEX + CREATE_TXI_TABLE
CREATE_TXI_TABLE +
CREATE_TXI_INDEX
) )
@staticmethod @staticmethod
@ -434,25 +435,22 @@ class BaseDatabase(SQLiteMixin):
async def select_transactions(self, cols, accounts=None, **constraints): async def select_transactions(self, cols, accounts=None, **constraints):
if not set(constraints) & {'txid', 'txid__in'}: if not set(constraints) & {'txid', 'txid__in'}:
assert accounts is not None, "'accounts' argument required when no 'txid' constraint" assert accounts, "'accounts' argument required when no 'txid' constraint is present"
constraints.update({ constraints.update({
f'$account{i}': a.public_key.address for i, a in enumerate(accounts) f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
}) })
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))]) account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
where = f" WHERE account_address.account IN ({account_values})"
constraints['txid__in'] = f""" constraints['txid__in'] = f"""
SELECT txo.txid FROM txo SELECT txo.txid FROM txo JOIN account_address USING (address) {where}
INNER JOIN pubkey_address USING (address) WHERE pubkey_address.account IN ({account_values})
UNION UNION
SELECT txi.txid FROM txi SELECT txi.txid FROM txi JOIN account_address USING (address) {where}
INNER JOIN pubkey_address USING (address) WHERE pubkey_address.account IN ({account_values})
""" """
return await self.db.execute_fetchall( return await self.db.execute_fetchall(
*query("SELECT {} FROM tx".format(cols), **constraints) *query("SELECT {} FROM tx".format(cols), **constraints)
) )
async def get_transactions(self, **constraints): async def get_transactions(self, wallet=None, **constraints):
accounts = constraints.get('accounts', None)
tx_rows = await self.select_transactions( tx_rows = await self.select_transactions(
'txid, raw, height, position, is_verified', 'txid, raw, height, position, is_verified',
order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]), order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
@ -477,7 +475,7 @@ class BaseDatabase(SQLiteMixin):
annotated_txos.update({ annotated_txos.update({
txo.id: txo for txo in txo.id: txo for txo in
(await self.get_txos( (await self.get_txos(
my_accounts=accounts, wallet=wallet,
txid__in=txids[offset:offset+step], txid__in=txids[offset:offset+step],
)) ))
}) })
@ -487,7 +485,7 @@ class BaseDatabase(SQLiteMixin):
referenced_txos.update({ referenced_txos.update({
txo.id: txo for txo in txo.id: txo for txo in
(await self.get_txos( (await self.get_txos(
my_accounts=accounts, wallet=wallet,
txoid__in=txi_txoids[offset:offset+step], txoid__in=txi_txoids[offset:offset+step],
)) ))
}) })
@ -507,6 +505,7 @@ class BaseDatabase(SQLiteMixin):
return txs return txs
async def get_transaction_count(self, **constraints): async def get_transaction_count(self, **constraints):
constraints.pop('wallet', None)
constraints.pop('offset', None) constraints.pop('offset', None)
constraints.pop('limit', None) constraints.pop('limit', None)
constraints.pop('order_by', None) constraints.pop('order_by', None)
@ -519,23 +518,24 @@ class BaseDatabase(SQLiteMixin):
return txs[0] return txs[0]
async def select_txos(self, cols, **constraints): async def select_txos(self, cols, **constraints):
return await self.db.execute_fetchall(*query( sql = "SELECT {} FROM txo JOIN tx USING (txid)"
"SELECT {} FROM txo" if 'accounts' in constraints:
" JOIN pubkey_address USING (address)" sql += " JOIN account_address USING (address)"
" JOIN tx USING (txid)".format(cols), **constraints return await self.db.execute_fetchall(*query(sql.format(cols), **constraints))
))
async def get_txos(self, my_accounts=None, no_tx=False, **constraints): async def get_txos(self, wallet=None, no_tx=False, **constraints):
my_accounts = [ my_accounts = set(a.public_key.address for a in wallet.accounts) if wallet else set()
(a.public_key.address if isinstance(a, BaseAccount) else a)
for a in (my_accounts or constraints.get('accounts', []))
]
if 'order_by' not in constraints: if 'order_by' not in constraints:
constraints['order_by'] = [ constraints['order_by'] = [
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"] "tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
]
rows = await self.select_txos( rows = await self.select_txos(
"tx.txid, raw, tx.height, tx.position, tx.is_verified, " """
"txo.position, chain, account, amount, script", tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, (
select json_group_object(account, chain) from account_address
where account_address.address=txo.address
)
""",
**constraints **constraints
) )
txos = [] txos = []
@ -544,8 +544,8 @@ class BaseDatabase(SQLiteMixin):
for row in rows: for row in rows:
if no_tx: if no_tx:
txo = output_class( txo = output_class(
amount=row[8], amount=row[6],
script=output_class.script_class(row[9]), script=output_class.script_class(row[7]),
tx_ref=TXRefImmutable.from_id(row[0], row[2]), tx_ref=TXRefImmutable.from_id(row[0], row[2]),
position=row[5] position=row[5]
) )
@ -555,12 +555,18 @@ class BaseDatabase(SQLiteMixin):
row[1], height=row[2], position=row[3], is_verified=row[4] row[1], height=row[2], position=row[3], is_verified=row[4]
) )
txo = txs[row[0]].outputs[row[5]] txo = txs[row[0]].outputs[row[5]]
txo.is_change = row[6] == 1 row_accounts = json.loads(row[8])
txo.is_my_account = row[7] in my_accounts account_match = set(row_accounts) & my_accounts
if account_match:
txo.is_my_account = True
txo.is_change = row_accounts[account_match.pop()] == 1
else:
txo.is_change = txo.is_my_account = False
txos.append(txo) txos.append(txo)
return txos return txos
async def get_txo_count(self, **constraints): async def get_txo_count(self, **constraints):
constraints.pop('wallet', None)
constraints.pop('offset', None) constraints.pop('offset', None)
constraints.pop('limit', None) constraints.pop('limit', None)
constraints.pop('order_by', None) constraints.pop('order_by', None)
@ -580,41 +586,56 @@ class BaseDatabase(SQLiteMixin):
self.constrain_utxo(constraints) self.constrain_utxo(constraints)
return self.get_txo_count(**constraints) return self.get_txo_count(**constraints)
async def get_balance(self, **constraints): async def get_balance(self, wallet=None, accounts=None, **constraints):
assert wallet or accounts, \
"'wallet' or 'accounts' constraints required to calculate balance"
constraints['accounts'] = accounts or wallet.accounts
self.constrain_utxo(constraints) self.constrain_utxo(constraints)
balance = await self.select_txos('SUM(amount)', **constraints) balance = await self.select_txos('SUM(amount)', **constraints)
return balance[0][0] or 0 return balance[0][0] or 0
async def select_addresses(self, cols, **constraints): async def select_addresses(self, cols, **constraints):
return await self.db.execute_fetchall(*query( return await self.db.execute_fetchall(*query(
"SELECT {} FROM pubkey_address".format(cols), **constraints "SELECT {} FROM pubkey_address JOIN account_address USING (address)".format(cols),
**constraints
)) ))
async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), async def get_addresses(self, **constraints):
**constraints): cols = (
addresses = await self.select_addresses(', '.join(cols), **constraints) 'address', 'account', 'chain', 'history', 'used_times',
return rows_to_dict(addresses, cols) 'pubkey', 'chain_code', 'n', 'depth'
)
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols)
for address in addresses:
address['pubkey'] = PubKey(
self.ledger, address.pop('pubkey'), address.pop('chain_code'),
address.pop('n'), address.pop('depth')
)
return addresses
async def get_address_count(self, **constraints): async def get_address_count(self, **constraints):
count = await self.select_addresses('count(*)', **constraints) count = await self.select_addresses('count(*)', **constraints)
return count[0][0] return count[0][0]
async def get_address(self, **constraints): async def get_address(self, **constraints):
addresses = await self.get_addresses( addresses = await self.get_addresses(limit=1, **constraints)
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
limit=1, **constraints
)
if addresses: if addresses:
return addresses[0] return addresses[0]
async def add_keys(self, account, chain, keys): async def add_keys(self, account, chain, pubkeys):
await self.db.executemany( await self.db.executemany(
"insert into pubkey_address (address, account, chain, position, pubkey) values (?, ?, ?, ?, ?)", "insert or ignore into account_address "
( "(account, address, chain, pubkey, chain_code, n, depth) values "
(pubkey.address, account.public_key.address, chain, position, "(?, ?, ?, ?, ?, ?, ?)", ((
sqlite3.Binary(pubkey.pubkey_bytes)) account.id, k.address, chain,
for position, pubkey in keys sqlite3.Binary(k.pubkey_bytes),
sqlite3.Binary(k.chain_code),
k.n, k.depth
) for k in pubkeys)
) )
await self.db.executemany(
"insert or ignore into pubkey_address (address) values (?)",
((pubkey.address,) for pubkey in pubkeys)
) )
async def _set_address_history(self, address, history): async def _set_address_history(self, address, history):

View file

@ -179,29 +179,29 @@ class BaseLedger(metaclass=LedgerRegistry):
def add_account(self, account: baseaccount.BaseAccount): def add_account(self, account: baseaccount.BaseAccount):
self.accounts.append(account) self.accounts.append(account)
async def _get_account_and_address_info_for_address(self, address): async def _get_account_and_address_info_for_address(self, wallet, address):
match = await self.db.get_address(address=address) match = await self.db.get_address(accounts=wallet.accounts, address=address)
if match: if match:
for account in self.accounts: for account in wallet.accounts:
if match['account'] == account.public_key.address: if match['account'] == account.public_key.address:
return account, match return account, match
async def get_private_key_for_address(self, address) -> Optional[PrivateKey]: async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
match = await self._get_account_and_address_info_for_address(address) match = await self._get_account_and_address_info_for_address(wallet, address)
if match: if match:
account, address_info = match account, address_info = match
return account.get_private_key(address_info['chain'], address_info['position']) return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
return None return None
async def get_public_key_for_address(self, address) -> Optional[PubKey]: async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(address) match = await self._get_account_and_address_info_for_address(wallet, address)
if match: if match:
account, address_info = match _, address_info = match
return account.get_public_key(address_info['chain'], address_info['position']) return address_info['pubkey']
return None return None
async def get_account_for_address(self, address): async def get_account_for_address(self, wallet, address):
match = await self._get_account_and_address_info_for_address(address) match = await self._get_account_and_address_info_for_address(wallet, address)
if match: if match:
return match[0] return match[0]
@ -214,15 +214,9 @@ class BaseLedger(metaclass=LedgerRegistry):
return estimators return estimators
async def get_addresses(self, **constraints): async def get_addresses(self, **constraints):
self.constraint_account_or_all(constraints) return await self.db.get_addresses(**constraints)
addresses = await self.db.get_addresses(**constraints)
for address in addresses:
public_key = await self.get_public_key_for_address(address['address'])
address['public_key'] = public_key.extended_key_string()
return addresses
def get_address_count(self, **constraints): def get_address_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_address_count(**constraints) return self.db.get_address_count(**constraints)
async def get_spendable_utxos(self, amount: int, funding_accounts): async def get_spendable_utxos(self, amount: int, funding_accounts):
@ -244,29 +238,16 @@ class BaseLedger(metaclass=LedgerRegistry):
def release_tx(self, tx): def release_tx(self, tx):
return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs]) return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
def constraint_account_or_all(self, constraints):
if 'accounts' in constraints:
return
account = constraints.pop('account', None)
if account:
constraints['accounts'] = [account]
else:
constraints['accounts'] = self.accounts
def get_utxos(self, **constraints): def get_utxos(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_utxos(**constraints) return self.db.get_utxos(**constraints)
def get_utxo_count(self, **constraints): def get_utxo_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_utxo_count(**constraints) return self.db.get_utxo_count(**constraints)
def get_transactions(self, **constraints): def get_transactions(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_transactions(**constraints) return self.db.get_transactions(**constraints)
def get_transaction_count(self, **constraints): def get_transaction_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_transaction_count(**constraints) return self.db.get_transaction_count(**constraints)
async def get_local_status_and_history(self, address, history=None): async def get_local_status_and_history(self, address, history=None):
@ -577,7 +558,7 @@ class BaseLedger(metaclass=LedgerRegistry):
addresses.add( addresses.add(
self.hash160_to_address(txo.script.values['pubkey_hash']) self.hash160_to_address(txo.script.values['pubkey_hash'])
) )
records = await self.db.get_addresses(cols=('address',), address__in=addresses) records = await self.db.get_addresses(address__in=addresses)
_, pending = await asyncio.wait([ _, pending = await asyncio.wait([
self.on_transaction.where(partial( self.on_transaction.where(partial(
lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id, lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,

View file

@ -1,6 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import Type, MutableSequence, MutableMapping from typing import Type, MutableSequence, MutableMapping, Optional
from torba.client.baseledger import BaseLedger, LedgerRegistry from torba.client.baseledger import BaseLedger, LedgerRegistry
from torba.client.wallet import Wallet, WalletStorage from torba.client.wallet import Wallet, WalletStorage
@ -41,16 +41,6 @@ class BaseWalletManager:
self.wallets.append(wallet) self.wallets.append(wallet)
return wallet return wallet
async def get_detailed_accounts(self, **kwargs):
ledgers = {}
for i, account in enumerate(self.accounts):
details = await account.get_details(**kwargs)
details['is_default'] = i == 0
ledger_id = account.ledger.get_id()
ledgers.setdefault(ledger_id, [])
ledgers[ledger_id].append(details)
return ledgers
@property @property
def default_wallet(self): def default_wallet(self):
for wallet in self.wallets: for wallet in self.wallets:
@ -78,3 +68,14 @@ class BaseWalletManager:
l.stop() for l in self.ledgers.values() l.stop() for l in self.ledgers.values()
)) ))
self.running = False self.running = False
def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet:
if wallet_id is None:
return self.default_wallet
return self.get_wallet_or_error(wallet_id)
def get_wallet_or_error(self, wallet_id: str) -> Wallet:
for wallet in self.wallets:
if wallet.id == wallet_id:
return wallet
raise ValueError(f"Couldn't find wallet: {wallet_id}.")

View file

@ -1,6 +1,6 @@
import logging import logging
import typing import typing
from typing import List, Iterable, Optional from typing import List, Iterable, Optional, Tuple
from binascii import hexlify from binascii import hexlify
from torba.client.basescript import BaseInputScript, BaseOutputScript from torba.client.basescript import BaseInputScript, BaseOutputScript
@ -12,7 +12,7 @@ from torba.client.util import ReadOnlyList
from torba.client.errors import InsufficientFundsError from torba.client.errors import InsufficientFundsError
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from torba.client import baseledger from torba.client import baseledger, wallet as basewallet
log = logging.getLogger() log = logging.getLogger()
@ -431,21 +431,32 @@ class BaseTransaction:
self.locktime = stream.read_uint32() self.locktime = stream.read_uint32()
@classmethod @classmethod
def ensure_all_have_same_ledger(cls, funding_accounts: Iterable[BaseAccount], def ensure_all_have_same_ledger_and_wallet(
change_account: BaseAccount = None) -> 'baseledger.BaseLedger': cls, funding_accounts: Iterable[BaseAccount],
ledger = None change_account: BaseAccount = None) -> Tuple['baseledger.BaseLedger', 'basewallet.Wallet']:
ledger = wallet = None
for account in funding_accounts: for account in funding_accounts:
if ledger is None: if ledger is None:
ledger = account.ledger ledger = account.ledger
wallet = account.wallet
if ledger != account.ledger: if ledger != account.ledger:
raise ValueError( raise ValueError(
'All funding accounts used to create a transaction must be on the same ledger.' 'All funding accounts used to create a transaction must be on the same ledger.'
) )
if change_account is not None and change_account.ledger != ledger: if wallet != account.wallet:
raise ValueError(
'All funding accounts used to create a transaction must be from the same wallet.'
)
if change_account is not None:
if change_account.ledger != ledger:
raise ValueError('Change account must use same ledger as funding accounts.') raise ValueError('Change account must use same ledger as funding accounts.')
if change_account.wallet != wallet:
raise ValueError('Change account must use same wallet as funding accounts.')
if ledger is None: if ledger is None:
raise ValueError('No ledger found.') raise ValueError('No ledger found.')
return ledger if wallet is None:
raise ValueError('No wallet found.')
return ledger, wallet
@classmethod @classmethod
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput], async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
@ -458,7 +469,7 @@ class BaseTransaction:
.add_inputs(inputs) \ .add_inputs(inputs) \
.add_outputs(outputs) .add_outputs(outputs)
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
# value of the outputs plus associated fees # value of the outputs plus associated fees
cost = ( cost = (
@ -524,15 +535,15 @@ class BaseTransaction:
return hash_type return hash_type
async def sign(self, funding_accounts: Iterable[BaseAccount]): async def sign(self, funding_accounts: Iterable[BaseAccount]):
ledger = self.ensure_all_have_same_ledger(funding_accounts) ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts)
for i, txi in enumerate(self._inputs): for i, txi in enumerate(self._inputs):
assert txi.script is not None assert txi.script is not None
assert txi.txo_ref.txo is not None assert txi.txo_ref.txo is not None
txo_script = txi.txo_ref.txo.script txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash: if txo_script.is_pay_pubkey_hash:
address = ledger.hash160_to_address(txo_script.values['pubkey_hash']) address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = await ledger.get_private_key_for_address(address) private_key = await ledger.get_private_key_for_address(wallet, address)
assert private_key is not None assert private_key is not None, 'Cannot find private key for signing output.'
tx = self._serialize_for_signature(i) tx = self._serialize_for_signature(i)
txi.script.values['signature'] = \ txi.script.values['signature'] = \
private_key.sign(tx) + bytes((self.signature_hash_type(1),)) private_key.sign(tx) + bytes((self.signature_hash_type(1),))

View file

@ -3,7 +3,7 @@ import stat
import json import json
import zlib import zlib
import typing import typing
from typing import Sequence, MutableSequence from typing import List, Sequence, MutableSequence, Optional
from hashlib import sha256 from hashlib import sha256
from operator import attrgetter from operator import attrgetter
from torba.client.hash import better_aes_encrypt, better_aes_decrypt from torba.client.hash import better_aes_encrypt, better_aes_decrypt
@ -25,12 +25,51 @@ class Wallet:
self.accounts = accounts or [] self.accounts = accounts or []
self.storage = storage or WalletStorage() self.storage = storage or WalletStorage()
def add_account(self, account): @property
def id(self):
if self.storage.path:
return os.path.basename(self.storage.path)
return self.name
def add_account(self, account: 'baseaccount.BaseAccount'):
self.accounts.append(account) self.accounts.append(account)
def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount': def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount':
return ledger.account_class.generate(ledger, self) return ledger.account_class.generate(ledger, self)
@property
def default_account(self) -> Optional['baseaccount.BaseAccount']:
for account in self.accounts:
return account
return None
def get_account_or_default(self, account_id: str) -> Optional['baseaccount.BaseAccount']:
if account_id is None:
return self.default_account
return self.get_account_or_error(account_id)
def get_account_or_error(self, account_id: str) -> 'baseaccount.BaseAccount':
for account in self.accounts:
if account.id == account_id:
return account
raise ValueError(f"Couldn't find account: {account_id}.")
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['baseaccount.BaseAccount']:
return [
self.get_account_or_error(account_id)
for account_id in account_ids
] if account_ids else self.accounts
async def get_detailed_accounts(self, **kwargs):
ledgers = {}
for i, account in enumerate(self.accounts):
details = await account.get_details(**kwargs)
details['is_default'] = i == 0
ledger_id = account.ledger.get_id()
ledgers.setdefault(ledger_id, [])
ledgers[ledger_id].append(details)
return ledgers
@classmethod @classmethod
def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet': def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet':
json_dict = storage.read() json_dict = storage.read()
@ -54,11 +93,6 @@ class Wallet:
def save(self): def save(self):
self.storage.write(self.to_dict()) self.storage.write(self.to_dict())
@property
def default_account(self):
for account in self.accounts:
return account
@property @property
def hash(self) -> bytes: def hash(self) -> bytes:
h = sha256() h = sha256()

View file

@ -151,7 +151,9 @@ class WalletNode:
async def start(self, spv_node: 'SPVNode', seed=None, connect=True): async def start(self, spv_node: 'SPVNode', seed=None, connect=True):
self.data_path = tempfile.mkdtemp() self.data_path = tempfile.mkdtemp()
wallet_file_name = os.path.join(self.data_path, 'my_wallet.json') wallets_dir = os.path.join(self.data_path, 'wallets')
os.mkdir(wallets_dir)
wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json')
with open(wallet_file_name, 'w') as wallet_file: with open(wallet_file_name, 'w') as wallet_file:
wallet_file.write('{"version": 1, "accounts": []}\n') wallet_file.write('{"version": 1, "accounts": []}\n')
self.manager = self.manager_class.from_config({ self.manager = self.manager_class.from_config({