wallet commands and wallet_id argument to other commands
This commit is contained in:
parent
9fbc83fe9d
commit
84587ac232
16 changed files with 608 additions and 406 deletions
File diff suppressed because it is too large
Load diff
|
@ -139,34 +139,34 @@ class Account(BaseAccount):
|
|||
return details
|
||||
|
||||
def get_transaction_history(self, **constraints):
|
||||
return self.ledger.get_transaction_history(account=self, **constraints)
|
||||
return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_transaction_history_count(self, **constraints):
|
||||
return self.ledger.get_transaction_history_count(account=self, **constraints)
|
||||
return self.ledger.get_transaction_history_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_claims(self, **constraints):
|
||||
return self.ledger.get_claims(account=self, **constraints)
|
||||
return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_claim_count(self, **constraints):
|
||||
return self.ledger.get_claim_count(account=self, **constraints)
|
||||
return self.ledger.get_claim_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_streams(self, **constraints):
|
||||
return self.ledger.get_streams(account=self, **constraints)
|
||||
return self.ledger.get_streams(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_stream_count(self, **constraints):
|
||||
return self.ledger.get_stream_count(account=self, **constraints)
|
||||
return self.ledger.get_stream_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_channels(self, **constraints):
|
||||
return self.ledger.get_channels(account=self, **constraints)
|
||||
return self.ledger.get_channels(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_channel_count(self, **constraints):
|
||||
return self.ledger.get_channel_count(account=self, **constraints)
|
||||
return self.ledger.get_channel_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_supports(self, **constraints):
|
||||
return self.ledger.get_supports(account=self, **constraints)
|
||||
return self.ledger.get_supports(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_support_count(self, **constraints):
|
||||
return self.ledger.get_support_count(account=self, **constraints)
|
||||
return self.ledger.get_support_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_support_summary(self):
|
||||
return self.ledger.db.get_supports_summary(account_id=self.id)
|
||||
|
|
|
@ -22,18 +22,18 @@ class WalletDatabase(BaseDatabase):
|
|||
claim_id text,
|
||||
claim_name text
|
||||
);
|
||||
create index if not exists txo_address_idx on txo (address);
|
||||
create index if not exists txo_claim_id_idx on txo (claim_id);
|
||||
create index if not exists txo_txo_type_idx on txo (txo_type);
|
||||
"""
|
||||
|
||||
CREATE_TABLES_QUERY = (
|
||||
BaseDatabase.CREATE_TX_TABLE +
|
||||
BaseDatabase.PRAGMAS +
|
||||
BaseDatabase.CREATE_ACCOUNT_TABLE +
|
||||
BaseDatabase.CREATE_PUBKEY_ADDRESS_TABLE +
|
||||
BaseDatabase.CREATE_PUBKEY_ADDRESS_INDEX +
|
||||
BaseDatabase.CREATE_TX_TABLE +
|
||||
CREATE_TXO_TABLE +
|
||||
BaseDatabase.CREATE_TXO_INDEX +
|
||||
BaseDatabase.CREATE_TXI_TABLE +
|
||||
BaseDatabase.CREATE_TXI_INDEX
|
||||
BaseDatabase.CREATE_TXI_TABLE
|
||||
)
|
||||
|
||||
def txo_to_row(self, tx, address, txo):
|
||||
|
@ -50,18 +50,16 @@ class WalletDatabase(BaseDatabase):
|
|||
row['claim_name'] = txo.claim_name
|
||||
return row
|
||||
|
||||
async def get_txos(self, **constraints) -> List[Output]:
|
||||
my_accounts = constraints.get('my_accounts', constraints.get('accounts', []))
|
||||
|
||||
txos = await super().get_txos(**constraints)
|
||||
async def get_txos(self, wallet=None, no_tx=False, **constraints) -> List[Output]:
|
||||
txos = await super().get_txos(wallet=wallet, no_tx=no_tx, **constraints)
|
||||
|
||||
channel_ids = set()
|
||||
for txo in txos:
|
||||
if txo.is_claim and txo.can_decode_claim:
|
||||
if txo.claim.is_signed:
|
||||
channel_ids.add(txo.claim.signing_channel_id)
|
||||
if txo.claim.is_channel and my_accounts:
|
||||
for account in my_accounts:
|
||||
if txo.claim.is_channel and wallet:
|
||||
for account in wallet.accounts:
|
||||
private_key = account.get_channel_private_key(
|
||||
txo.claim.channel.public_key_bytes
|
||||
)
|
||||
|
@ -73,7 +71,7 @@ class WalletDatabase(BaseDatabase):
|
|||
channels = {
|
||||
txo.claim_id: txo for txo in
|
||||
(await self.get_claims(
|
||||
my_accounts=my_accounts,
|
||||
wallet=wallet,
|
||||
claim_id__in=channel_ids
|
||||
))
|
||||
}
|
||||
|
|
|
@ -121,39 +121,30 @@ class MainNetLedger(BaseLedger):
|
|||
return super().get_utxo_count(**constraints)
|
||||
|
||||
def get_claims(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_claims(**constraints)
|
||||
|
||||
def get_claim_count(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_claim_count(**constraints)
|
||||
|
||||
def get_streams(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_streams(**constraints)
|
||||
|
||||
def get_stream_count(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_stream_count(**constraints)
|
||||
|
||||
def get_channels(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_channels(**constraints)
|
||||
|
||||
def get_channel_count(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_channel_count(**constraints)
|
||||
|
||||
def get_supports(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_supports(**constraints)
|
||||
|
||||
def get_support_count(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_support_count(**constraints)
|
||||
|
||||
async def get_transaction_history(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
txs = await self.db.get_transactions(**constraints)
|
||||
headers = self.headers
|
||||
history = []
|
||||
|
@ -249,7 +240,6 @@ class MainNetLedger(BaseLedger):
|
|||
return history
|
||||
|
||||
def get_transaction_history_count(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_transaction_count(**constraints)
|
||||
|
||||
|
||||
|
|
|
@ -203,7 +203,7 @@ class Transaction(BaseTransaction):
|
|||
|
||||
@classmethod
|
||||
def pay(cls, amount: int, address: bytes, funding_accounts: List[Account], change_account: Account):
|
||||
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address))
|
||||
return cls.create([], [output], funding_accounts, change_account)
|
||||
|
||||
|
@ -211,7 +211,7 @@ class Transaction(BaseTransaction):
|
|||
def claim_create(
|
||||
cls, name: str, claim: Claim, amount: int, holding_address: str,
|
||||
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
|
||||
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
claim_output = Output.pay_claim_name_pubkey_hash(
|
||||
amount, name, claim, ledger.address_to_hash160(holding_address)
|
||||
)
|
||||
|
@ -223,7 +223,7 @@ class Transaction(BaseTransaction):
|
|||
def claim_update(
|
||||
cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
|
||||
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None):
|
||||
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
updated_claim = Output.pay_update_claim_pubkey_hash(
|
||||
amount, previous_claim.claim_name, previous_claim.claim_id,
|
||||
claim, ledger.address_to_hash160(holding_address)
|
||||
|
@ -239,7 +239,7 @@ class Transaction(BaseTransaction):
|
|||
@classmethod
|
||||
def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str,
|
||||
funding_accounts: List[Account], change_account: Account):
|
||||
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
support_output = Output.pay_support_pubkey_hash(
|
||||
amount, claim_name, claim_id, ledger.address_to_hash160(holding_address)
|
||||
)
|
||||
|
@ -248,7 +248,7 @@ class Transaction(BaseTransaction):
|
|||
@classmethod
|
||||
def purchase(cls, claim: Output, amount: int, merchant_address: bytes,
|
||||
funding_accounts: List[Account], change_account: Account):
|
||||
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
|
||||
ledger, wallet = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
claim_output = Output.purchase_claim_pubkey_hash(
|
||||
amount, claim.claim_id, ledger.address_to_hash160(merchant_address)
|
||||
)
|
||||
|
|
|
@ -466,45 +466,47 @@ class ChannelCommands(CommandTestCase):
|
|||
'title': 'foo', 'email': 'new@email.com'}
|
||||
)
|
||||
|
||||
# send channel to someone else
|
||||
# move channel to another account
|
||||
new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
|
||||
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id'])
|
||||
account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
|
||||
|
||||
# before sending
|
||||
# before moving
|
||||
self.assertEqual(len(await self.daemon.jsonrpc_channel_list()), 3)
|
||||
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 0)
|
||||
|
||||
other_address = await account2.receiving.get_or_create_usable_address()
|
||||
tx = await self.out(self.channel_update(claim_id, claim_address=other_address))
|
||||
|
||||
# after sending
|
||||
# after moving
|
||||
self.assertEqual(len(await self.daemon.jsonrpc_channel_list()), 3)
|
||||
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=self.account.id)), 2)
|
||||
self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 1)
|
||||
|
||||
# shoud not have private key
|
||||
txo = (await account2.get_channels())[0]
|
||||
self.assertIsNone(txo.private_key)
|
||||
|
||||
# send the private key too
|
||||
private_key = self.account.get_channel_private_key(unhexlify(channel['public_key']))
|
||||
account2.add_channel_private_key(private_key)
|
||||
|
||||
# now should have private key
|
||||
txo = (await account2.get_channels())[0]
|
||||
self.assertIsNotNone(txo.private_key)
|
||||
|
||||
async def test_channel_export_import_into_new_account(self):
|
||||
async def test_channel_export_import_before_sending_channel(self):
|
||||
# export
|
||||
tx = await self.channel_create('@foo', '1.0')
|
||||
claim_id = self.get_claim_id(tx)
|
||||
channel_private_key = (await self.account.get_channels())[0].private_key
|
||||
exported_data = await self.out(self.daemon.jsonrpc_channel_export(claim_id))
|
||||
|
||||
# import
|
||||
daemon2 = await self.add_daemon()
|
||||
self.assertEqual(0, len(await daemon2.jsonrpc_channel_list()))
|
||||
await daemon2.jsonrpc_channel_import(exported_data)
|
||||
channels = await daemon2.jsonrpc_channel_list()
|
||||
self.assertEqual(1, len(channels))
|
||||
self.assertEqual(channel_private_key.to_string(), channels[0].private_key.to_string())
|
||||
|
||||
# second wallet can't update until channel is sent to it
|
||||
with self.assertRaisesRegex(AssertionError, 'Cannot find private key for signing output.'):
|
||||
await daemon2.jsonrpc_channel_update(claim_id, bid='0.5')
|
||||
|
||||
# now send the channel as well
|
||||
await self.channel_update(claim_id, claim_address=await daemon2.jsonrpc_address_unused())
|
||||
|
||||
# second wallet should be able to update now
|
||||
await daemon2.jsonrpc_channel_update(claim_id, bid='0.5')
|
||||
|
||||
|
||||
class StreamCommands(ClaimTestCase):
|
||||
|
||||
|
@ -565,7 +567,7 @@ class StreamCommands(ClaimTestCase):
|
|||
async def test_publishing_checks_all_accounts_for_channel(self):
|
||||
account1_id, account1 = self.account.id, self.account
|
||||
new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
|
||||
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id'])
|
||||
account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
|
||||
|
||||
await self.out(self.channel_create('@spam', '1.0'))
|
||||
self.assertEqual('8.989893', (await self.daemon.jsonrpc_account_balance())['available'])
|
||||
|
@ -782,7 +784,7 @@ class StreamCommands(ClaimTestCase):
|
|||
|
||||
# send claim to someone else
|
||||
new_account = await self.out(self.daemon.jsonrpc_account_create('second account'))
|
||||
account2_id, account2 = new_account['id'], self.daemon.get_account_or_error(new_account['id'])
|
||||
account2_id, account2 = new_account['id'], self.wallet.get_account_or_error(new_account['id'])
|
||||
|
||||
# before sending
|
||||
self.assertEqual(len(await self.daemon.jsonrpc_claim_list()), 4)
|
||||
|
@ -1079,13 +1081,12 @@ class StreamCommands(ClaimTestCase):
|
|||
class SupportCommands(CommandTestCase):
|
||||
|
||||
async def test_regular_supports_and_tip_supports(self):
|
||||
# account2 will be used to send tips and supports to account1
|
||||
account2_id = (await self.out(self.daemon.jsonrpc_account_create('second account')))['id']
|
||||
account2 = self.daemon.get_account_or_error(account2_id)
|
||||
wallet2 = await self.daemon.jsonrpc_wallet_add('wallet2', create_wallet=True, create_account=True)
|
||||
account2 = wallet2.accounts[0]
|
||||
|
||||
# send account2 5 LBC out of the 10 LBC in account1
|
||||
result = await self.out(self.daemon.jsonrpc_account_send(
|
||||
'5.0', await self.daemon.jsonrpc_address_unused(account2_id)
|
||||
'5.0', await self.daemon.jsonrpc_address_unused(wallet_id='wallet2')
|
||||
))
|
||||
await self.on_transaction_dict(result)
|
||||
|
||||
|
@ -1103,7 +1104,7 @@ class SupportCommands(CommandTestCase):
|
|||
# send a tip to the claim using account2
|
||||
tip = await self.out(
|
||||
self.daemon.jsonrpc_support_create(
|
||||
claim_id, '1.0', True, account2_id, funding_account_ids=[account2_id])
|
||||
claim_id, '1.0', True, account2.id, 'wallet2', funding_account_ids=[account2.id])
|
||||
)
|
||||
await self.confirm_tx(tip['txid'])
|
||||
|
||||
|
@ -1122,7 +1123,7 @@ class SupportCommands(CommandTestCase):
|
|||
|
||||
# verify that the outgoing tip is marked correctly as is_tip=True in account2
|
||||
txs2 = await self.out(
|
||||
self.daemon.jsonrpc_transaction_list(account2_id)
|
||||
self.daemon.jsonrpc_transaction_list(wallet_id='wallet2', account_id=account2.id)
|
||||
)
|
||||
self.assertEqual(len(txs2[0]['support_info']), 1)
|
||||
self.assertEqual(txs2[0]['support_info'][0]['balance_delta'], '-1.0')
|
||||
|
@ -1134,7 +1135,7 @@ class SupportCommands(CommandTestCase):
|
|||
# send a support to the claim using account2
|
||||
support = await self.out(
|
||||
self.daemon.jsonrpc_support_create(
|
||||
claim_id, '2.0', False, account2_id, funding_account_ids=[account2_id])
|
||||
claim_id, '2.0', False, account2.id, 'wallet2', funding_account_ids=[account2.id])
|
||||
)
|
||||
await self.confirm_tx(support['txid'])
|
||||
|
||||
|
@ -1143,7 +1144,7 @@ class SupportCommands(CommandTestCase):
|
|||
await self.assertBalance(account2, '1.999717')
|
||||
|
||||
# verify that the outgoing support is marked correctly as is_tip=False in account2
|
||||
txs2 = await self.out(self.daemon.jsonrpc_transaction_list(account2_id))
|
||||
txs2 = await self.out(self.daemon.jsonrpc_transaction_list(wallet_id='wallet2'))
|
||||
self.assertEqual(len(txs2[0]['support_info']), 1)
|
||||
self.assertEqual(txs2[0]['support_info'][0]['balance_delta'], '-2.0')
|
||||
self.assertEqual(txs2[0]['support_info'][0]['claim_id'], claim_id)
|
||||
|
|
|
@ -61,13 +61,17 @@ class TestAccount(AsyncioTestCase):
|
|||
address = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx')
|
||||
|
||||
private_key = await self.ledger.get_private_key_for_address('bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx')
|
||||
private_key = await self.ledger.get_private_key_for_address(
|
||||
account.wallet, 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx'
|
||||
)
|
||||
self.assertEqual(
|
||||
private_key.extended_key_string(),
|
||||
'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR'
|
||||
'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT'
|
||||
)
|
||||
private_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||
private_key = await self.ledger.get_private_key_for_address(
|
||||
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
|
||||
)
|
||||
self.assertIsNone(private_key)
|
||||
|
||||
def test_load_and_save_account(self):
|
||||
|
|
|
@ -60,7 +60,7 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
|
|||
await account.receiving._generate_keys(8, 11)
|
||||
records = await account.receiving.get_address_records()
|
||||
self.assertEqual(
|
||||
[r['position'] for r in records],
|
||||
[r['pubkey'].n for r in records],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
||||
)
|
||||
|
||||
|
@ -69,7 +69,7 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
|
|||
self.assertEqual(len(new_keys), 8)
|
||||
records = await account.receiving.get_address_records()
|
||||
self.assertEqual(
|
||||
[r['position'] for r in records],
|
||||
[r['pubkey'].n for r in records],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
|
||||
)
|
||||
|
||||
|
@ -125,14 +125,18 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
|
|||
address = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||
|
||||
private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||
private_key = await self.ledger.get_private_key_for_address(
|
||||
account.wallet, '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J'
|
||||
)
|
||||
self.assertEqual(
|
||||
private_key.extended_key_string(),
|
||||
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
|
||||
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
|
||||
)
|
||||
|
||||
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||
invalid_key = await self.ledger.get_private_key_for_address(
|
||||
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
|
||||
)
|
||||
self.assertIsNone(invalid_key)
|
||||
|
||||
self.assertEqual(
|
||||
|
@ -275,12 +279,18 @@ class TestSingleKeyAccount(AsyncioTestCase):
|
|||
self.assertEqual(len(new_keys), 1)
|
||||
self.assertEqual(new_keys[0], account.public_key.address)
|
||||
records = await account.receiving.get_address_records()
|
||||
pubkey = records[0].pop('pubkey')
|
||||
self.assertEqual(records, [{
|
||||
'position': 0, 'chain': 0,
|
||||
'chain': 0,
|
||||
'account': account.public_key.address,
|
||||
'address': account.public_key.address,
|
||||
'history': None,
|
||||
'used_times': 0
|
||||
}])
|
||||
self.assertEqual(
|
||||
pubkey.extended_key_string(),
|
||||
account.public_key.extended_key_string()
|
||||
)
|
||||
|
||||
# case #1: no new addresses needed
|
||||
empty = await account.receiving.ensure_address_gap()
|
||||
|
@ -333,14 +343,18 @@ class TestSingleKeyAccount(AsyncioTestCase):
|
|||
address = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(address[0], account.public_key.address)
|
||||
|
||||
private_key = await self.ledger.get_private_key_for_address(address[0])
|
||||
private_key = await self.ledger.get_private_key_for_address(
|
||||
account.wallet, address[0]
|
||||
)
|
||||
self.assertEqual(
|
||||
private_key.extended_key_string(),
|
||||
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||
)
|
||||
|
||||
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||
invalid_key = await self.ledger.get_private_key_for_address(
|
||||
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
|
||||
)
|
||||
self.assertIsNone(invalid_key)
|
||||
|
||||
self.assertEqual(
|
||||
|
|
|
@ -199,13 +199,14 @@ class TestQueries(AsyncioTestCase):
|
|||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
})
|
||||
self.wallet = Wallet()
|
||||
await self.ledger.db.open()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.ledger.db.close()
|
||||
|
||||
async def create_account(self):
|
||||
account = self.ledger.account_class.generate(self.ledger, Wallet())
|
||||
async def create_account(self, wallet=None):
|
||||
account = self.ledger.account_class.generate(self.ledger, wallet or self.wallet)
|
||||
await account.ensure_address_gap()
|
||||
return account
|
||||
|
||||
|
@ -264,7 +265,8 @@ class TestQueries(AsyncioTestCase):
|
|||
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
|
||||
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
|
||||
for limit in range(variable_limit-2, variable_limit+2):
|
||||
txs = await self.ledger.get_transactions(limit=limit, order_by='height asc')
|
||||
txs = await self.ledger.get_transactions(
|
||||
accounts=self.wallet.accounts, limit=limit, order_by='height asc')
|
||||
self.assertEqual(len(txs), limit)
|
||||
inputs, outputs, last_tx = set(), set(), txs[0]
|
||||
for tx in txs[1:]:
|
||||
|
@ -274,19 +276,23 @@ class TestQueries(AsyncioTestCase):
|
|||
last_tx = tx
|
||||
|
||||
async def test_queries(self):
|
||||
self.assertEqual(0, await self.ledger.db.get_address_count())
|
||||
account1 = await self.create_account()
|
||||
self.assertEqual(26, await self.ledger.db.get_address_count())
|
||||
account2 = await self.create_account()
|
||||
self.assertEqual(52, await self.ledger.db.get_address_count())
|
||||
wallet1 = Wallet()
|
||||
account1 = await self.create_account(wallet1)
|
||||
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account1]))
|
||||
wallet2 = Wallet()
|
||||
account2 = await self.create_account(wallet2)
|
||||
account3 = await self.create_account(wallet2)
|
||||
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account2]))
|
||||
|
||||
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2]))
|
||||
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2, account3]))
|
||||
self.assertEqual(0, await self.ledger.db.get_utxo_count())
|
||||
self.assertEqual([], await self.ledger.db.get_utxos())
|
||||
self.assertEqual(0, await self.ledger.db.get_txo_count())
|
||||
self.assertEqual(0, await self.ledger.db.get_balance())
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
|
||||
|
||||
tx1 = await self.create_tx_from_nothing(account1, 1)
|
||||
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1]))
|
||||
|
@ -294,20 +300,28 @@ class TestQueries(AsyncioTestCase):
|
|||
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1]))
|
||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
|
||||
self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2]))
|
||||
self.assertEqual(10**8, await self.ledger.db.get_balance())
|
||||
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet1))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
|
||||
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1]))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
|
||||
|
||||
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
|
||||
tx2b = await self.create_tx_from_nothing(account3, 2)
|
||||
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
|
||||
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2]))
|
||||
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account3]))
|
||||
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
|
||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
|
||||
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2]))
|
||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
|
||||
self.assertEqual(10**8, await self.ledger.db.get_balance())
|
||||
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account3]))
|
||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account3]))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
|
||||
self.assertEqual(10**8+10**8, await self.ledger.db.get_balance(wallet=wallet2))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
|
||||
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2]))
|
||||
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
|
||||
|
||||
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
|
||||
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
|
||||
|
@ -316,22 +330,24 @@ class TestQueries(AsyncioTestCase):
|
|||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
|
||||
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2]))
|
||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance())
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
|
||||
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet2))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
|
||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
|
||||
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
|
||||
|
||||
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
|
||||
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
||||
self.assertEqual([3, 2, 1], [tx.height for tx in txs])
|
||||
|
||||
txs = await self.ledger.db.get_transactions(accounts=[account1])
|
||||
txs = await self.ledger.db.get_transactions(wallet=wallet1, accounts=wallet1.accounts)
|
||||
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
|
||||
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
||||
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
||||
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
||||
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
||||
|
||||
txs = await self.ledger.db.get_transactions(accounts=[account2])
|
||||
txs = await self.ledger.db.get_transactions(wallet=wallet2, accounts=[account2])
|
||||
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
|
||||
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
||||
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
||||
|
@ -343,18 +359,18 @@ class TestQueries(AsyncioTestCase):
|
|||
self.assertEqual(tx.id, tx2.id)
|
||||
self.assertEqual(tx.inputs[0].is_my_account, False)
|
||||
self.assertEqual(tx.outputs[0].is_my_account, False)
|
||||
tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account1])
|
||||
tx = await self.ledger.db.get_transaction(wallet=wallet1, txid=tx2.id)
|
||||
self.assertEqual(tx.inputs[0].is_my_account, True)
|
||||
self.assertEqual(tx.outputs[0].is_my_account, False)
|
||||
tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account2])
|
||||
tx = await self.ledger.db.get_transaction(wallet=wallet2, txid=tx2.id)
|
||||
self.assertEqual(tx.inputs[0].is_my_account, False)
|
||||
self.assertEqual(tx.outputs[0].is_my_account, True)
|
||||
|
||||
# height 0 sorted to the top with the rest in descending order
|
||||
tx4 = await self.create_tx_from_nothing(account1, 0)
|
||||
txos = await self.ledger.db.get_txos()
|
||||
self.assertEqual([0, 2, 1], [txo.tx_ref.height for txo in txos])
|
||||
self.assertEqual([tx4.id, tx2.id, tx1.id], [txo.tx_ref.id for txo in txos])
|
||||
self.assertEqual([0, 2, 2, 1], [txo.tx_ref.height for txo in txos])
|
||||
self.assertEqual([tx4.id, tx2.id, tx2b.id, tx1.id], [txo.tx_ref.id for txo in txos])
|
||||
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
|
||||
self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs])
|
||||
self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
||||
|
@ -382,13 +398,13 @@ class TestUpgrade(AsyncioTestCase):
|
|||
def add_address(self, address):
|
||||
with sqlite3.connect(self.path) as conn:
|
||||
conn.execute("""
|
||||
INSERT INTO pubkey_address (address, account, chain, position, pubkey)
|
||||
VALUES (?, 'account1', 0, 0, 'pubkey blob')
|
||||
INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth)
|
||||
VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0)
|
||||
""", (address,))
|
||||
|
||||
def get_addresses(self):
|
||||
with sqlite3.connect(self.path) as conn:
|
||||
sql = "SELECT address FROM pubkey_address ORDER BY address;"
|
||||
sql = "SELECT address FROM account_address ORDER BY address;"
|
||||
return [col[0] for col in conn.execute(sql).fetchall()]
|
||||
|
||||
async def test_reset_on_version_change(self):
|
||||
|
@ -401,7 +417,7 @@ class TestUpgrade(AsyncioTestCase):
|
|||
self.ledger.db.SCHEMA_VERSION = None
|
||||
self.assertEqual(self.get_tables(), [])
|
||||
await self.ledger.db.open()
|
||||
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo'])
|
||||
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo'])
|
||||
self.assertEqual(self.get_addresses(), [])
|
||||
self.add_address('address1')
|
||||
await self.ledger.db.close()
|
||||
|
@ -410,17 +426,17 @@ class TestUpgrade(AsyncioTestCase):
|
|||
self.ledger.db.SCHEMA_VERSION = '1.0'
|
||||
await self.ledger.db.open()
|
||||
self.assertEqual(self.get_version(), '1.0')
|
||||
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_addresses(), []) # address1 deleted during version upgrade
|
||||
self.add_address('address2')
|
||||
await self.ledger.db.close()
|
||||
|
||||
# nothing changes
|
||||
self.assertEqual(self.get_version(), '1.0')
|
||||
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
await self.ledger.db.open()
|
||||
self.assertEqual(self.get_version(), '1.0')
|
||||
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_addresses(), ['address2'])
|
||||
await self.ledger.db.close()
|
||||
|
||||
|
@ -431,7 +447,7 @@ class TestUpgrade(AsyncioTestCase):
|
|||
"""
|
||||
await self.ledger.db.open()
|
||||
self.assertEqual(self.get_version(), '1.1')
|
||||
self.assertEqual(self.get_tables(), ['foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_tables(), ['account_address', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_addresses(), []) # all tables got reset
|
||||
await self.ledger.db.close()
|
||||
|
||||
|
|
|
@ -115,7 +115,7 @@ class HierarchicalDeterministic(AddressManager):
|
|||
return self.account.public_key.child(self.chain_number).child(index)
|
||||
|
||||
async def get_max_gap(self) -> int:
|
||||
addresses = await self._query_addresses(order_by="position ASC")
|
||||
addresses = await self._query_addresses(order_by="n asc")
|
||||
max_gap = 0
|
||||
current_gap = 0
|
||||
for address in addresses:
|
||||
|
@ -128,7 +128,7 @@ class HierarchicalDeterministic(AddressManager):
|
|||
|
||||
async def ensure_address_gap(self) -> List[str]:
|
||||
async with self.address_generator_lock:
|
||||
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
|
||||
addresses = await self._query_addresses(limit=self.gap, order_by="n desc")
|
||||
|
||||
existing_gap = 0
|
||||
for address in addresses:
|
||||
|
@ -140,7 +140,7 @@ class HierarchicalDeterministic(AddressManager):
|
|||
if existing_gap == self.gap:
|
||||
return []
|
||||
|
||||
start = addresses[0]['position']+1 if addresses else 0
|
||||
start = addresses[0]['pubkey'].n+1 if addresses else 0
|
||||
end = start + (self.gap - existing_gap)
|
||||
new_keys = await self._generate_keys(start, end-1)
|
||||
await self.account.ledger.announce_addresses(self, new_keys)
|
||||
|
@ -149,15 +149,15 @@ class HierarchicalDeterministic(AddressManager):
|
|||
async def _generate_keys(self, start: int, end: int) -> List[str]:
|
||||
if not self.address_generator_lock.locked():
|
||||
raise RuntimeError('Should not be called outside of address_generator_lock.')
|
||||
keys = [(index, self.public_key.child(index)) for index in range(start, end+1)]
|
||||
keys = [self.public_key.child(index) for index in range(start, end+1)]
|
||||
await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
|
||||
return [key[1].address for key in keys]
|
||||
return [key.address for key in keys]
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
if only_usable:
|
||||
constraints['used_times__lt'] = self.maximum_uses_per_address
|
||||
if 'order_by' not in constraints:
|
||||
constraints['order_by'] = "used_times ASC, position ASC"
|
||||
constraints['order_by'] = "used_times asc, n asc"
|
||||
return self._query_addresses(**constraints)
|
||||
|
||||
|
||||
|
@ -190,9 +190,7 @@ class SingleKey(AddressManager):
|
|||
async with self.address_generator_lock:
|
||||
exists = await self.get_address_records()
|
||||
if not exists:
|
||||
await self.account.ledger.db.add_keys(
|
||||
self.account, self.chain_number, [(0, self.public_key)]
|
||||
)
|
||||
await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
|
||||
new_keys = [self.public_key.address]
|
||||
await self.account.ledger.announce_addresses(self, new_keys)
|
||||
return new_keys
|
||||
|
@ -417,16 +415,16 @@ class BaseAccount:
|
|||
}
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
return self.ledger.get_utxos(account=self, **constraints)
|
||||
return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
return self.ledger.get_utxo_count(account=self, **constraints)
|
||||
return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_transactions(self, **constraints):
|
||||
return self.ledger.get_transactions(account=self, **constraints)
|
||||
return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_transaction_count(self, **constraints):
|
||||
return self.ledger.get_transaction_count(account=self, **constraints)
|
||||
return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
async def fund(self, to_account, amount=None, everything=False,
|
||||
outputs=1, broadcast=False, **constraints):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
from binascii import hexlify
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
|
@ -8,7 +9,7 @@ from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict,
|
|||
import sqlite3
|
||||
|
||||
from torba.client.basetransaction import BaseTransaction, TXRefImmutable
|
||||
from torba.client.baseaccount import BaseAccount
|
||||
from torba.client.bip32 import PubKey
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -165,9 +166,8 @@ def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
|
|||
offset = constraints.pop('offset', None)
|
||||
order_by = constraints.pop('order_by', None)
|
||||
|
||||
constraints.pop('my_accounts', None)
|
||||
accounts = constraints.pop('accounts', None)
|
||||
if accounts is not None:
|
||||
accounts = constraints.pop('accounts', [])
|
||||
if accounts:
|
||||
constraints['account__in'] = [a.public_key.address for a in accounts]
|
||||
|
||||
where, values = constraints_to_sql(constraints)
|
||||
|
@ -285,26 +285,33 @@ class SQLiteMixin:
|
|||
|
||||
class BaseDatabase(SQLiteMixin):
|
||||
|
||||
SCHEMA_VERSION = "1.0"
|
||||
SCHEMA_VERSION = "1.1"
|
||||
|
||||
PRAGMAS = """
|
||||
pragma journal_mode=WAL;
|
||||
"""
|
||||
|
||||
CREATE_ACCOUNT_TABLE = """
|
||||
create table if not exists account_address (
|
||||
account text not null,
|
||||
address text not null,
|
||||
chain integer not null,
|
||||
pubkey blob not null,
|
||||
chain_code blob not null,
|
||||
n integer not null,
|
||||
depth integer not null,
|
||||
primary key (account, address)
|
||||
);
|
||||
create index if not exists address_account_idx on account_address (address, account);
|
||||
"""
|
||||
|
||||
CREATE_PUBKEY_ADDRESS_TABLE = """
|
||||
create table if not exists pubkey_address (
|
||||
address text primary key,
|
||||
account text not null,
|
||||
chain integer not null,
|
||||
position integer not null,
|
||||
pubkey blob not null,
|
||||
history text,
|
||||
used_times integer not null default 0
|
||||
);
|
||||
"""
|
||||
CREATE_PUBKEY_ADDRESS_INDEX = """
|
||||
create index if not exists pubkey_address_account_idx on pubkey_address (account);
|
||||
"""
|
||||
|
||||
CREATE_TX_TABLE = """
|
||||
create table if not exists tx (
|
||||
|
@ -326,8 +333,6 @@ class BaseDatabase(SQLiteMixin):
|
|||
script blob not null,
|
||||
is_reserved boolean not null default 0
|
||||
);
|
||||
"""
|
||||
CREATE_TXO_INDEX = """
|
||||
create index if not exists txo_address_idx on txo (address);
|
||||
"""
|
||||
|
||||
|
@ -337,21 +342,17 @@ class BaseDatabase(SQLiteMixin):
|
|||
txoid text references txo,
|
||||
address text references pubkey_address
|
||||
);
|
||||
"""
|
||||
CREATE_TXI_INDEX = """
|
||||
create index if not exists txi_address_idx on txi (address);
|
||||
create index if not exists txi_txoid_idx on txi (txoid);
|
||||
"""
|
||||
|
||||
CREATE_TABLES_QUERY = (
|
||||
PRAGMAS +
|
||||
CREATE_TX_TABLE +
|
||||
CREATE_ACCOUNT_TABLE +
|
||||
CREATE_PUBKEY_ADDRESS_TABLE +
|
||||
CREATE_PUBKEY_ADDRESS_INDEX +
|
||||
CREATE_TX_TABLE +
|
||||
CREATE_TXO_TABLE +
|
||||
CREATE_TXO_INDEX +
|
||||
CREATE_TXI_TABLE +
|
||||
CREATE_TXI_INDEX
|
||||
CREATE_TXI_TABLE
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -434,25 +435,22 @@ class BaseDatabase(SQLiteMixin):
|
|||
|
||||
async def select_transactions(self, cols, accounts=None, **constraints):
|
||||
if not set(constraints) & {'txid', 'txid__in'}:
|
||||
assert accounts is not None, "'accounts' argument required when no 'txid' constraint"
|
||||
assert accounts, "'accounts' argument required when no 'txid' constraint is present"
|
||||
constraints.update({
|
||||
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
|
||||
})
|
||||
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
|
||||
where = f" WHERE account_address.account IN ({account_values})"
|
||||
constraints['txid__in'] = f"""
|
||||
SELECT txo.txid FROM txo
|
||||
INNER JOIN pubkey_address USING (address) WHERE pubkey_address.account IN ({account_values})
|
||||
SELECT txo.txid FROM txo JOIN account_address USING (address) {where}
|
||||
UNION
|
||||
SELECT txi.txid FROM txi
|
||||
INNER JOIN pubkey_address USING (address) WHERE pubkey_address.account IN ({account_values})
|
||||
SELECT txi.txid FROM txi JOIN account_address USING (address) {where}
|
||||
"""
|
||||
return await self.db.execute_fetchall(
|
||||
*query("SELECT {} FROM tx".format(cols), **constraints)
|
||||
)
|
||||
|
||||
async def get_transactions(self, **constraints):
|
||||
accounts = constraints.get('accounts', None)
|
||||
|
||||
async def get_transactions(self, wallet=None, **constraints):
|
||||
tx_rows = await self.select_transactions(
|
||||
'txid, raw, height, position, is_verified',
|
||||
order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
|
||||
|
@ -477,7 +475,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
annotated_txos.update({
|
||||
txo.id: txo for txo in
|
||||
(await self.get_txos(
|
||||
my_accounts=accounts,
|
||||
wallet=wallet,
|
||||
txid__in=txids[offset:offset+step],
|
||||
))
|
||||
})
|
||||
|
@ -487,7 +485,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
referenced_txos.update({
|
||||
txo.id: txo for txo in
|
||||
(await self.get_txos(
|
||||
my_accounts=accounts,
|
||||
wallet=wallet,
|
||||
txoid__in=txi_txoids[offset:offset+step],
|
||||
))
|
||||
})
|
||||
|
@ -507,6 +505,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
return txs
|
||||
|
||||
async def get_transaction_count(self, **constraints):
|
||||
constraints.pop('wallet', None)
|
||||
constraints.pop('offset', None)
|
||||
constraints.pop('limit', None)
|
||||
constraints.pop('order_by', None)
|
||||
|
@ -519,23 +518,24 @@ class BaseDatabase(SQLiteMixin):
|
|||
return txs[0]
|
||||
|
||||
async def select_txos(self, cols, **constraints):
|
||||
return await self.db.execute_fetchall(*query(
|
||||
"SELECT {} FROM txo"
|
||||
" JOIN pubkey_address USING (address)"
|
||||
" JOIN tx USING (txid)".format(cols), **constraints
|
||||
))
|
||||
sql = "SELECT {} FROM txo JOIN tx USING (txid)"
|
||||
if 'accounts' in constraints:
|
||||
sql += " JOIN account_address USING (address)"
|
||||
return await self.db.execute_fetchall(*query(sql.format(cols), **constraints))
|
||||
|
||||
async def get_txos(self, my_accounts=None, no_tx=False, **constraints):
|
||||
my_accounts = [
|
||||
(a.public_key.address if isinstance(a, BaseAccount) else a)
|
||||
for a in (my_accounts or constraints.get('accounts', []))
|
||||
]
|
||||
async def get_txos(self, wallet=None, no_tx=False, **constraints):
|
||||
my_accounts = set(a.public_key.address for a in wallet.accounts) if wallet else set()
|
||||
if 'order_by' not in constraints:
|
||||
constraints['order_by'] = [
|
||||
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"]
|
||||
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
|
||||
]
|
||||
rows = await self.select_txos(
|
||||
"tx.txid, raw, tx.height, tx.position, tx.is_verified, "
|
||||
"txo.position, chain, account, amount, script",
|
||||
"""
|
||||
tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, (
|
||||
select json_group_object(account, chain) from account_address
|
||||
where account_address.address=txo.address
|
||||
)
|
||||
""",
|
||||
**constraints
|
||||
)
|
||||
txos = []
|
||||
|
@ -544,8 +544,8 @@ class BaseDatabase(SQLiteMixin):
|
|||
for row in rows:
|
||||
if no_tx:
|
||||
txo = output_class(
|
||||
amount=row[8],
|
||||
script=output_class.script_class(row[9]),
|
||||
amount=row[6],
|
||||
script=output_class.script_class(row[7]),
|
||||
tx_ref=TXRefImmutable.from_id(row[0], row[2]),
|
||||
position=row[5]
|
||||
)
|
||||
|
@ -555,12 +555,18 @@ class BaseDatabase(SQLiteMixin):
|
|||
row[1], height=row[2], position=row[3], is_verified=row[4]
|
||||
)
|
||||
txo = txs[row[0]].outputs[row[5]]
|
||||
txo.is_change = row[6] == 1
|
||||
txo.is_my_account = row[7] in my_accounts
|
||||
row_accounts = json.loads(row[8])
|
||||
account_match = set(row_accounts) & my_accounts
|
||||
if account_match:
|
||||
txo.is_my_account = True
|
||||
txo.is_change = row_accounts[account_match.pop()] == 1
|
||||
else:
|
||||
txo.is_change = txo.is_my_account = False
|
||||
txos.append(txo)
|
||||
return txos
|
||||
|
||||
async def get_txo_count(self, **constraints):
|
||||
constraints.pop('wallet', None)
|
||||
constraints.pop('offset', None)
|
||||
constraints.pop('limit', None)
|
||||
constraints.pop('order_by', None)
|
||||
|
@ -580,41 +586,56 @@ class BaseDatabase(SQLiteMixin):
|
|||
self.constrain_utxo(constraints)
|
||||
return self.get_txo_count(**constraints)
|
||||
|
||||
async def get_balance(self, **constraints):
|
||||
async def get_balance(self, wallet=None, accounts=None, **constraints):
|
||||
assert wallet or accounts, \
|
||||
"'wallet' or 'accounts' constraints required to calculate balance"
|
||||
constraints['accounts'] = accounts or wallet.accounts
|
||||
self.constrain_utxo(constraints)
|
||||
balance = await self.select_txos('SUM(amount)', **constraints)
|
||||
return balance[0][0] or 0
|
||||
|
||||
async def select_addresses(self, cols, **constraints):
|
||||
return await self.db.execute_fetchall(*query(
|
||||
"SELECT {} FROM pubkey_address".format(cols), **constraints
|
||||
"SELECT {} FROM pubkey_address JOIN account_address USING (address)".format(cols),
|
||||
**constraints
|
||||
))
|
||||
|
||||
async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'),
|
||||
**constraints):
|
||||
addresses = await self.select_addresses(', '.join(cols), **constraints)
|
||||
return rows_to_dict(addresses, cols)
|
||||
async def get_addresses(self, **constraints):
|
||||
cols = (
|
||||
'address', 'account', 'chain', 'history', 'used_times',
|
||||
'pubkey', 'chain_code', 'n', 'depth'
|
||||
)
|
||||
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols)
|
||||
for address in addresses:
|
||||
address['pubkey'] = PubKey(
|
||||
self.ledger, address.pop('pubkey'), address.pop('chain_code'),
|
||||
address.pop('n'), address.pop('depth')
|
||||
)
|
||||
return addresses
|
||||
|
||||
async def get_address_count(self, **constraints):
|
||||
count = await self.select_addresses('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
async def get_address(self, **constraints):
|
||||
addresses = await self.get_addresses(
|
||||
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
|
||||
limit=1, **constraints
|
||||
)
|
||||
addresses = await self.get_addresses(limit=1, **constraints)
|
||||
if addresses:
|
||||
return addresses[0]
|
||||
|
||||
async def add_keys(self, account, chain, keys):
|
||||
async def add_keys(self, account, chain, pubkeys):
|
||||
await self.db.executemany(
|
||||
"insert into pubkey_address (address, account, chain, position, pubkey) values (?, ?, ?, ?, ?)",
|
||||
(
|
||||
(pubkey.address, account.public_key.address, chain, position,
|
||||
sqlite3.Binary(pubkey.pubkey_bytes))
|
||||
for position, pubkey in keys
|
||||
"insert or ignore into account_address "
|
||||
"(account, address, chain, pubkey, chain_code, n, depth) values "
|
||||
"(?, ?, ?, ?, ?, ?, ?)", ((
|
||||
account.id, k.address, chain,
|
||||
sqlite3.Binary(k.pubkey_bytes),
|
||||
sqlite3.Binary(k.chain_code),
|
||||
k.n, k.depth
|
||||
) for k in pubkeys)
|
||||
)
|
||||
await self.db.executemany(
|
||||
"insert or ignore into pubkey_address (address) values (?)",
|
||||
((pubkey.address,) for pubkey in pubkeys)
|
||||
)
|
||||
|
||||
async def _set_address_history(self, address, history):
|
||||
|
|
|
@ -179,29 +179,29 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
def add_account(self, account: baseaccount.BaseAccount):
|
||||
self.accounts.append(account)
|
||||
|
||||
async def _get_account_and_address_info_for_address(self, address):
|
||||
match = await self.db.get_address(address=address)
|
||||
async def _get_account_and_address_info_for_address(self, wallet, address):
|
||||
match = await self.db.get_address(accounts=wallet.accounts, address=address)
|
||||
if match:
|
||||
for account in self.accounts:
|
||||
for account in wallet.accounts:
|
||||
if match['account'] == account.public_key.address:
|
||||
return account, match
|
||||
|
||||
async def get_private_key_for_address(self, address) -> Optional[PrivateKey]:
|
||||
match = await self._get_account_and_address_info_for_address(address)
|
||||
async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
account, address_info = match
|
||||
return account.get_private_key(address_info['chain'], address_info['position'])
|
||||
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
|
||||
return None
|
||||
|
||||
async def get_public_key_for_address(self, address) -> Optional[PubKey]:
|
||||
match = await self._get_account_and_address_info_for_address(address)
|
||||
async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
account, address_info = match
|
||||
return account.get_public_key(address_info['chain'], address_info['position'])
|
||||
_, address_info = match
|
||||
return address_info['pubkey']
|
||||
return None
|
||||
|
||||
async def get_account_for_address(self, address):
|
||||
match = await self._get_account_and_address_info_for_address(address)
|
||||
async def get_account_for_address(self, wallet, address):
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
return match[0]
|
||||
|
||||
|
@ -214,15 +214,9 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
return estimators
|
||||
|
||||
async def get_addresses(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
addresses = await self.db.get_addresses(**constraints)
|
||||
for address in addresses:
|
||||
public_key = await self.get_public_key_for_address(address['address'])
|
||||
address['public_key'] = public_key.extended_key_string()
|
||||
return addresses
|
||||
return await self.db.get_addresses(**constraints)
|
||||
|
||||
def get_address_count(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_address_count(**constraints)
|
||||
|
||||
async def get_spendable_utxos(self, amount: int, funding_accounts):
|
||||
|
@ -244,29 +238,16 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
def release_tx(self, tx):
|
||||
return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
|
||||
|
||||
def constraint_account_or_all(self, constraints):
|
||||
if 'accounts' in constraints:
|
||||
return
|
||||
account = constraints.pop('account', None)
|
||||
if account:
|
||||
constraints['accounts'] = [account]
|
||||
else:
|
||||
constraints['accounts'] = self.accounts
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_utxos(**constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_utxo_count(**constraints)
|
||||
|
||||
def get_transactions(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_transactions(**constraints)
|
||||
|
||||
def get_transaction_count(self, **constraints):
|
||||
self.constraint_account_or_all(constraints)
|
||||
return self.db.get_transaction_count(**constraints)
|
||||
|
||||
async def get_local_status_and_history(self, address, history=None):
|
||||
|
@ -577,7 +558,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
addresses.add(
|
||||
self.hash160_to_address(txo.script.values['pubkey_hash'])
|
||||
)
|
||||
records = await self.db.get_addresses(cols=('address',), address__in=addresses)
|
||||
records = await self.db.get_addresses(address__in=addresses)
|
||||
_, pending = await asyncio.wait([
|
||||
self.on_transaction.where(partial(
|
||||
lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Type, MutableSequence, MutableMapping
|
||||
from typing import Type, MutableSequence, MutableMapping, Optional
|
||||
|
||||
from torba.client.baseledger import BaseLedger, LedgerRegistry
|
||||
from torba.client.wallet import Wallet, WalletStorage
|
||||
|
@ -41,16 +41,6 @@ class BaseWalletManager:
|
|||
self.wallets.append(wallet)
|
||||
return wallet
|
||||
|
||||
async def get_detailed_accounts(self, **kwargs):
|
||||
ledgers = {}
|
||||
for i, account in enumerate(self.accounts):
|
||||
details = await account.get_details(**kwargs)
|
||||
details['is_default'] = i == 0
|
||||
ledger_id = account.ledger.get_id()
|
||||
ledgers.setdefault(ledger_id, [])
|
||||
ledgers[ledger_id].append(details)
|
||||
return ledgers
|
||||
|
||||
@property
|
||||
def default_wallet(self):
|
||||
for wallet in self.wallets:
|
||||
|
@ -78,3 +68,14 @@ class BaseWalletManager:
|
|||
l.stop() for l in self.ledgers.values()
|
||||
))
|
||||
self.running = False
|
||||
|
||||
def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet:
|
||||
if wallet_id is None:
|
||||
return self.default_wallet
|
||||
return self.get_wallet_or_error(wallet_id)
|
||||
|
||||
def get_wallet_or_error(self, wallet_id: str) -> Wallet:
|
||||
for wallet in self.wallets:
|
||||
if wallet.id == wallet_id:
|
||||
return wallet
|
||||
raise ValueError(f"Couldn't find wallet: {wallet_id}.")
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import typing
|
||||
from typing import List, Iterable, Optional
|
||||
from typing import List, Iterable, Optional, Tuple
|
||||
from binascii import hexlify
|
||||
|
||||
from torba.client.basescript import BaseInputScript, BaseOutputScript
|
||||
|
@ -12,7 +12,7 @@ from torba.client.util import ReadOnlyList
|
|||
from torba.client.errors import InsufficientFundsError
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from torba.client import baseledger
|
||||
from torba.client import baseledger, wallet as basewallet
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
@ -431,21 +431,32 @@ class BaseTransaction:
|
|||
self.locktime = stream.read_uint32()
|
||||
|
||||
@classmethod
|
||||
def ensure_all_have_same_ledger(cls, funding_accounts: Iterable[BaseAccount],
|
||||
change_account: BaseAccount = None) -> 'baseledger.BaseLedger':
|
||||
ledger = None
|
||||
def ensure_all_have_same_ledger_and_wallet(
|
||||
cls, funding_accounts: Iterable[BaseAccount],
|
||||
change_account: BaseAccount = None) -> Tuple['baseledger.BaseLedger', 'basewallet.Wallet']:
|
||||
ledger = wallet = None
|
||||
for account in funding_accounts:
|
||||
if ledger is None:
|
||||
ledger = account.ledger
|
||||
wallet = account.wallet
|
||||
if ledger != account.ledger:
|
||||
raise ValueError(
|
||||
'All funding accounts used to create a transaction must be on the same ledger.'
|
||||
)
|
||||
if change_account is not None and change_account.ledger != ledger:
|
||||
if wallet != account.wallet:
|
||||
raise ValueError(
|
||||
'All funding accounts used to create a transaction must be from the same wallet.'
|
||||
)
|
||||
if change_account is not None:
|
||||
if change_account.ledger != ledger:
|
||||
raise ValueError('Change account must use same ledger as funding accounts.')
|
||||
if change_account.wallet != wallet:
|
||||
raise ValueError('Change account must use same wallet as funding accounts.')
|
||||
if ledger is None:
|
||||
raise ValueError('No ledger found.')
|
||||
return ledger
|
||||
if wallet is None:
|
||||
raise ValueError('No wallet found.')
|
||||
return ledger, wallet
|
||||
|
||||
@classmethod
|
||||
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
|
||||
|
@ -458,7 +469,7 @@ class BaseTransaction:
|
|||
.add_inputs(inputs) \
|
||||
.add_outputs(outputs)
|
||||
|
||||
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
|
||||
ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
|
||||
# value of the outputs plus associated fees
|
||||
cost = (
|
||||
|
@ -524,15 +535,15 @@ class BaseTransaction:
|
|||
return hash_type
|
||||
|
||||
async def sign(self, funding_accounts: Iterable[BaseAccount]):
|
||||
ledger = self.ensure_all_have_same_ledger(funding_accounts)
|
||||
ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts)
|
||||
for i, txi in enumerate(self._inputs):
|
||||
assert txi.script is not None
|
||||
assert txi.txo_ref.txo is not None
|
||||
txo_script = txi.txo_ref.txo.script
|
||||
if txo_script.is_pay_pubkey_hash:
|
||||
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
|
||||
private_key = await ledger.get_private_key_for_address(address)
|
||||
assert private_key is not None
|
||||
private_key = await ledger.get_private_key_for_address(wallet, address)
|
||||
assert private_key is not None, 'Cannot find private key for signing output.'
|
||||
tx = self._serialize_for_signature(i)
|
||||
txi.script.values['signature'] = \
|
||||
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
|
||||
|
|
|
@ -3,7 +3,7 @@ import stat
|
|||
import json
|
||||
import zlib
|
||||
import typing
|
||||
from typing import Sequence, MutableSequence
|
||||
from typing import List, Sequence, MutableSequence, Optional
|
||||
from hashlib import sha256
|
||||
from operator import attrgetter
|
||||
from torba.client.hash import better_aes_encrypt, better_aes_decrypt
|
||||
|
@ -25,12 +25,51 @@ class Wallet:
|
|||
self.accounts = accounts or []
|
||||
self.storage = storage or WalletStorage()
|
||||
|
||||
def add_account(self, account):
|
||||
@property
|
||||
def id(self):
|
||||
if self.storage.path:
|
||||
return os.path.basename(self.storage.path)
|
||||
return self.name
|
||||
|
||||
def add_account(self, account: 'baseaccount.BaseAccount'):
|
||||
self.accounts.append(account)
|
||||
|
||||
def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount':
|
||||
return ledger.account_class.generate(ledger, self)
|
||||
|
||||
@property
|
||||
def default_account(self) -> Optional['baseaccount.BaseAccount']:
|
||||
for account in self.accounts:
|
||||
return account
|
||||
return None
|
||||
|
||||
def get_account_or_default(self, account_id: str) -> Optional['baseaccount.BaseAccount']:
|
||||
if account_id is None:
|
||||
return self.default_account
|
||||
return self.get_account_or_error(account_id)
|
||||
|
||||
def get_account_or_error(self, account_id: str) -> 'baseaccount.BaseAccount':
|
||||
for account in self.accounts:
|
||||
if account.id == account_id:
|
||||
return account
|
||||
raise ValueError(f"Couldn't find account: {account_id}.")
|
||||
|
||||
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['baseaccount.BaseAccount']:
|
||||
return [
|
||||
self.get_account_or_error(account_id)
|
||||
for account_id in account_ids
|
||||
] if account_ids else self.accounts
|
||||
|
||||
async def get_detailed_accounts(self, **kwargs):
|
||||
ledgers = {}
|
||||
for i, account in enumerate(self.accounts):
|
||||
details = await account.get_details(**kwargs)
|
||||
details['is_default'] = i == 0
|
||||
ledger_id = account.ledger.get_id()
|
||||
ledgers.setdefault(ledger_id, [])
|
||||
ledgers[ledger_id].append(details)
|
||||
return ledgers
|
||||
|
||||
@classmethod
|
||||
def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet':
|
||||
json_dict = storage.read()
|
||||
|
@ -54,11 +93,6 @@ class Wallet:
|
|||
def save(self):
|
||||
self.storage.write(self.to_dict())
|
||||
|
||||
@property
|
||||
def default_account(self):
|
||||
for account in self.accounts:
|
||||
return account
|
||||
|
||||
@property
|
||||
def hash(self) -> bytes:
|
||||
h = sha256()
|
||||
|
|
|
@ -151,7 +151,9 @@ class WalletNode:
|
|||
|
||||
async def start(self, spv_node: 'SPVNode', seed=None, connect=True):
|
||||
self.data_path = tempfile.mkdtemp()
|
||||
wallet_file_name = os.path.join(self.data_path, 'my_wallet.json')
|
||||
wallets_dir = os.path.join(self.data_path, 'wallets')
|
||||
os.mkdir(wallets_dir)
|
||||
wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json')
|
||||
with open(wallet_file_name, 'w') as wallet_file:
|
||||
wallet_file.write('{"version": 1, "accounts": []}\n')
|
||||
self.manager = self.manager_class.from_config({
|
||||
|
|
Loading…
Reference in a new issue