twisted -> asyncio
This commit is contained in:
parent
0ce4b9a7de
commit
2c5fd4aade
21 changed files with 480 additions and 721 deletions
|
@ -1,5 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from orchstr8.testcase import IntegrationTestCase, d2f
|
from orchstr8.testcase import IntegrationTestCase
|
||||||
from torba.constants import COIN
|
from torba.constants import COIN
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,14 +9,14 @@ class BasicTransactionTests(IntegrationTestCase):
|
||||||
|
|
||||||
async def test_sending_and_receiving(self):
|
async def test_sending_and_receiving(self):
|
||||||
account1, account2 = self.account, self.wallet.generate_account(self.ledger)
|
account1, account2 = self.account, self.wallet.generate_account(self.ledger)
|
||||||
await d2f(self.ledger.update_account(account2))
|
await self.ledger.update_account(account2)
|
||||||
|
|
||||||
self.assertEqual(await self.get_balance(account1), 0)
|
self.assertEqual(await self.get_balance(account1), 0)
|
||||||
self.assertEqual(await self.get_balance(account2), 0)
|
self.assertEqual(await self.get_balance(account2), 0)
|
||||||
|
|
||||||
sendtxids = []
|
sendtxids = []
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
address1 = await d2f(account1.receiving.get_or_create_usable_address())
|
address1 = await account1.receiving.get_or_create_usable_address()
|
||||||
sendtxid = await self.blockchain.send_to_address(address1, 1.1)
|
sendtxid = await self.blockchain.send_to_address(address1, 1.1)
|
||||||
sendtxids.append(sendtxid)
|
sendtxids.append(sendtxid)
|
||||||
await self.on_transaction_id(sendtxid) # mempool
|
await self.on_transaction_id(sendtxid) # mempool
|
||||||
|
@ -28,13 +28,13 @@ class BasicTransactionTests(IntegrationTestCase):
|
||||||
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
|
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
|
||||||
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
|
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
|
||||||
|
|
||||||
address2 = await d2f(account2.receiving.get_or_create_usable_address())
|
address2 = await account2.receiving.get_or_create_usable_address()
|
||||||
hash2 = self.ledger.address_to_hash160(address2)
|
hash2 = self.ledger.address_to_hash160(address2)
|
||||||
tx = await d2f(self.ledger.transaction_class.create(
|
tx = await self.ledger.transaction_class.create(
|
||||||
[],
|
[],
|
||||||
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash2)],
|
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash2)],
|
||||||
[account1], account1
|
[account1], account1
|
||||||
))
|
)
|
||||||
await self.broadcast(tx)
|
await self.broadcast(tx)
|
||||||
await self.on_transaction(tx) # mempool
|
await self.on_transaction(tx) # mempool
|
||||||
await self.blockchain.generate(1)
|
await self.blockchain.generate(1)
|
||||||
|
@ -43,18 +43,18 @@ class BasicTransactionTests(IntegrationTestCase):
|
||||||
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
|
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
|
||||||
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
|
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
|
||||||
|
|
||||||
utxos = await d2f(self.account.get_utxos())
|
utxos = await self.account.get_utxos()
|
||||||
tx = await d2f(self.ledger.transaction_class.create(
|
tx = await self.ledger.transaction_class.create(
|
||||||
[self.ledger.transaction_class.input_class.spend(utxos[0])],
|
[self.ledger.transaction_class.input_class.spend(utxos[0])],
|
||||||
[],
|
[],
|
||||||
[account1], account1
|
[account1], account1
|
||||||
))
|
)
|
||||||
await self.broadcast(tx)
|
await self.broadcast(tx)
|
||||||
await self.on_transaction(tx) # mempool
|
await self.on_transaction(tx) # mempool
|
||||||
await self.blockchain.generate(1)
|
await self.blockchain.generate(1)
|
||||||
await self.on_transaction(tx) # confirmed
|
await self.on_transaction(tx) # confirmed
|
||||||
|
|
||||||
txs = await d2f(account1.get_transactions())
|
txs = await account1.get_transactions()
|
||||||
tx = txs[1]
|
tx = txs[1]
|
||||||
self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1)
|
self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1)
|
||||||
self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1)
|
self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1)
|
||||||
|
|
|
@ -1,25 +1,26 @@
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
from twisted.trial import unittest
|
|
||||||
from twisted.internet import defer
|
from orchstr8.testcase import AsyncioTestCase
|
||||||
|
|
||||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
from torba.baseaccount import HierarchicalDeterministic, SingleKey
|
from torba.baseaccount import HierarchicalDeterministic, SingleKey
|
||||||
from torba.wallet import Wallet
|
from torba.wallet import Wallet
|
||||||
|
|
||||||
|
|
||||||
class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
class TestHierarchicalDeterministicAccount(AsyncioTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def asyncSetUp(self):
|
||||||
def setUp(self):
|
|
||||||
self.ledger = ledger_class({
|
self.ledger = ledger_class({
|
||||||
'db': ledger_class.database_class(':memory:'),
|
'db': ledger_class.database_class(':memory:'),
|
||||||
'headers': ledger_class.headers_class(':memory:'),
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
})
|
})
|
||||||
yield self.ledger.db.open()
|
await self.ledger.db.open()
|
||||||
self.account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
self.account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def asyncTearDown(self):
|
||||||
def test_generate_account(self):
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
async def test_generate_account(self):
|
||||||
account = self.account
|
account = self.account
|
||||||
|
|
||||||
self.assertEqual(account.ledger, self.ledger)
|
self.assertEqual(account.ledger, self.ledger)
|
||||||
|
@ -27,81 +28,77 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
||||||
self.assertEqual(account.public_key.ledger, self.ledger)
|
self.assertEqual(account.public_key.ledger, self.ledger)
|
||||||
self.assertEqual(account.private_key.public_key, account.public_key)
|
self.assertEqual(account.private_key.public_key, account.public_key)
|
||||||
|
|
||||||
addresses = yield account.receiving.get_addresses()
|
addresses = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(addresses), 0)
|
self.assertEqual(len(addresses), 0)
|
||||||
addresses = yield account.change.get_addresses()
|
addresses = await account.change.get_addresses()
|
||||||
self.assertEqual(len(addresses), 0)
|
self.assertEqual(len(addresses), 0)
|
||||||
|
|
||||||
yield account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
addresses = yield account.receiving.get_addresses()
|
addresses = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(addresses), 20)
|
self.assertEqual(len(addresses), 20)
|
||||||
addresses = yield account.change.get_addresses()
|
addresses = await account.change.get_addresses()
|
||||||
self.assertEqual(len(addresses), 6)
|
self.assertEqual(len(addresses), 6)
|
||||||
|
|
||||||
addresses = yield account.get_addresses()
|
addresses = await account.get_addresses()
|
||||||
self.assertEqual(len(addresses), 26)
|
self.assertEqual(len(addresses), 26)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
|
||||||
def test_generate_keys_over_batch_threshold_saves_it_properly(self):
|
await self.account.receiving.generate_keys(0, 200)
|
||||||
yield self.account.receiving.generate_keys(0, 200)
|
records = await self.account.receiving.get_address_records()
|
||||||
records = yield self.account.receiving.get_address_records()
|
|
||||||
self.assertEqual(201, len(records))
|
self.assertEqual(201, len(records))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_ensure_address_gap(self):
|
||||||
def test_ensure_address_gap(self):
|
|
||||||
account = self.account
|
account = self.account
|
||||||
|
|
||||||
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
|
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
|
||||||
|
|
||||||
yield account.receiving.generate_keys(4, 7)
|
await account.receiving.generate_keys(4, 7)
|
||||||
yield account.receiving.generate_keys(0, 3)
|
await account.receiving.generate_keys(0, 3)
|
||||||
yield account.receiving.generate_keys(8, 11)
|
await account.receiving.generate_keys(8, 11)
|
||||||
records = yield account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[r['position'] for r in records],
|
[r['position'] for r in records],
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
||||||
)
|
)
|
||||||
|
|
||||||
# we have 12, but default gap is 20
|
# we have 12, but default gap is 20
|
||||||
new_keys = yield account.receiving.ensure_address_gap()
|
new_keys = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(len(new_keys), 8)
|
self.assertEqual(len(new_keys), 8)
|
||||||
records = yield account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[r['position'] for r in records],
|
[r['position'] for r in records],
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
|
||||||
)
|
)
|
||||||
|
|
||||||
# case #1: no new addresses needed
|
# case #1: no new addresses needed
|
||||||
empty = yield account.receiving.ensure_address_gap()
|
empty = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(len(empty), 0)
|
self.assertEqual(len(empty), 0)
|
||||||
|
|
||||||
# case #2: only one new addressed needed
|
# case #2: only one new addressed needed
|
||||||
records = yield account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
||||||
new_keys = yield account.receiving.ensure_address_gap()
|
new_keys = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(len(new_keys), 1)
|
self.assertEqual(len(new_keys), 1)
|
||||||
|
|
||||||
# case #3: 20 addresses needed
|
# case #3: 20 addresses needed
|
||||||
yield self.ledger.db.set_address_history(new_keys[0], 'a:1:')
|
await self.ledger.db.set_address_history(new_keys[0], 'a:1:')
|
||||||
new_keys = yield account.receiving.ensure_address_gap()
|
new_keys = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(len(new_keys), 20)
|
self.assertEqual(len(new_keys), 20)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_get_or_create_usable_address(self):
|
||||||
def test_get_or_create_usable_address(self):
|
|
||||||
account = self.account
|
account = self.account
|
||||||
|
|
||||||
keys = yield account.receiving.get_addresses()
|
keys = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(keys), 0)
|
self.assertEqual(len(keys), 0)
|
||||||
|
|
||||||
address = yield account.receiving.get_or_create_usable_address()
|
address = await account.receiving.get_or_create_usable_address()
|
||||||
self.assertIsNotNone(address)
|
self.assertIsNotNone(address)
|
||||||
|
|
||||||
keys = yield account.receiving.get_addresses()
|
keys = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(keys), 20)
|
self.assertEqual(len(keys), 20)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_generate_account_from_seed(self):
|
||||||
def test_generate_account_from_seed(self):
|
|
||||||
account = self.ledger.account_class.from_dict(
|
account = self.ledger.account_class.from_dict(
|
||||||
self.ledger, Wallet(), {
|
self.ledger, Wallet(), {
|
||||||
"seed": "carbon smart garage balance margin twelve chest sword "
|
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||||
|
@ -123,17 +120,17 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
||||||
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7UbpV'
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7UbpV'
|
||||||
'NzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g'
|
'NzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g'
|
||||||
)
|
)
|
||||||
address = yield account.receiving.ensure_address_gap()
|
address = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||||
|
|
||||||
private_key = yield self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
private_key.extended_key_string(),
|
private_key.extended_key_string(),
|
||||||
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
|
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
|
||||||
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
|
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
|
||||||
)
|
)
|
||||||
|
|
||||||
invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||||
self.assertIsNone(invalid_key)
|
self.assertIsNone(invalid_key)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -141,8 +138,7 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
||||||
b'1c01ae1e4c7d89e39f6d3aa7792c097a30ca7d40be249b6de52c81ec8cf9aab48b01'
|
b'1c01ae1e4c7d89e39f6d3aa7792c097a30ca7d40be249b6de52c81ec8cf9aab48b01'
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_load_and_save_account(self):
|
||||||
def test_load_and_save_account(self):
|
|
||||||
account_data = {
|
account_data = {
|
||||||
'name': 'My Account',
|
'name': 'My Account',
|
||||||
'seed':
|
'seed':
|
||||||
|
@ -164,11 +160,11 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
||||||
|
|
||||||
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
||||||
|
|
||||||
yield account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
addresses = yield account.receiving.get_addresses()
|
addresses = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(addresses), 5)
|
self.assertEqual(len(addresses), 5)
|
||||||
addresses = yield account.change.get_addresses()
|
addresses = await account.change.get_addresses()
|
||||||
self.assertEqual(len(addresses), 5)
|
self.assertEqual(len(addresses), 5)
|
||||||
|
|
||||||
self.maxDiff = None
|
self.maxDiff = None
|
||||||
|
@ -176,20 +172,21 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
||||||
self.assertDictEqual(account_data, account.to_dict())
|
self.assertDictEqual(account_data, account.to_dict())
|
||||||
|
|
||||||
|
|
||||||
class TestSingleKeyAccount(unittest.TestCase):
|
class TestSingleKeyAccount(AsyncioTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def asyncSetUp(self):
|
||||||
def setUp(self):
|
|
||||||
self.ledger = ledger_class({
|
self.ledger = ledger_class({
|
||||||
'db': ledger_class.database_class(':memory:'),
|
'db': ledger_class.database_class(':memory:'),
|
||||||
'headers': ledger_class.headers_class(':memory:'),
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
})
|
})
|
||||||
yield self.ledger.db.open()
|
await self.ledger.db.open()
|
||||||
self.account = self.ledger.account_class.generate(
|
self.account = self.ledger.account_class.generate(
|
||||||
self.ledger, Wallet(), "torba", {'name': 'single-address'})
|
self.ledger, Wallet(), "torba", {'name': 'single-address'})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def asyncTearDown(self):
|
||||||
def test_generate_account(self):
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
async def test_generate_account(self):
|
||||||
account = self.account
|
account = self.account
|
||||||
|
|
||||||
self.assertEqual(account.ledger, self.ledger)
|
self.assertEqual(account.ledger, self.ledger)
|
||||||
|
@ -197,37 +194,36 @@ class TestSingleKeyAccount(unittest.TestCase):
|
||||||
self.assertEqual(account.public_key.ledger, self.ledger)
|
self.assertEqual(account.public_key.ledger, self.ledger)
|
||||||
self.assertEqual(account.private_key.public_key, account.public_key)
|
self.assertEqual(account.private_key.public_key, account.public_key)
|
||||||
|
|
||||||
addresses = yield account.receiving.get_addresses()
|
addresses = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(addresses), 0)
|
self.assertEqual(len(addresses), 0)
|
||||||
addresses = yield account.change.get_addresses()
|
addresses = await account.change.get_addresses()
|
||||||
self.assertEqual(len(addresses), 0)
|
self.assertEqual(len(addresses), 0)
|
||||||
|
|
||||||
yield account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
addresses = yield account.receiving.get_addresses()
|
addresses = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(addresses), 1)
|
self.assertEqual(len(addresses), 1)
|
||||||
self.assertEqual(addresses[0], account.public_key.address)
|
self.assertEqual(addresses[0], account.public_key.address)
|
||||||
addresses = yield account.change.get_addresses()
|
addresses = await account.change.get_addresses()
|
||||||
self.assertEqual(len(addresses), 1)
|
self.assertEqual(len(addresses), 1)
|
||||||
self.assertEqual(addresses[0], account.public_key.address)
|
self.assertEqual(addresses[0], account.public_key.address)
|
||||||
|
|
||||||
addresses = yield account.get_addresses()
|
addresses = await account.get_addresses()
|
||||||
self.assertEqual(len(addresses), 1)
|
self.assertEqual(len(addresses), 1)
|
||||||
self.assertEqual(addresses[0], account.public_key.address)
|
self.assertEqual(addresses[0], account.public_key.address)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_ensure_address_gap(self):
|
||||||
def test_ensure_address_gap(self):
|
|
||||||
account = self.account
|
account = self.account
|
||||||
|
|
||||||
self.assertIsInstance(account.receiving, SingleKey)
|
self.assertIsInstance(account.receiving, SingleKey)
|
||||||
addresses = yield account.receiving.get_addresses()
|
addresses = await account.receiving.get_addresses()
|
||||||
self.assertEqual(addresses, [])
|
self.assertEqual(addresses, [])
|
||||||
|
|
||||||
# we have 12, but default gap is 20
|
# we have 12, but default gap is 20
|
||||||
new_keys = yield account.receiving.ensure_address_gap()
|
new_keys = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(len(new_keys), 1)
|
self.assertEqual(len(new_keys), 1)
|
||||||
self.assertEqual(new_keys[0], account.public_key.address)
|
self.assertEqual(new_keys[0], account.public_key.address)
|
||||||
records = yield account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
self.assertEqual(records, [{
|
self.assertEqual(records, [{
|
||||||
'position': 0, 'chain': 0,
|
'position': 0, 'chain': 0,
|
||||||
'account': account.public_key.address,
|
'account': account.public_key.address,
|
||||||
|
@ -236,37 +232,35 @@ class TestSingleKeyAccount(unittest.TestCase):
|
||||||
}])
|
}])
|
||||||
|
|
||||||
# case #1: no new addresses needed
|
# case #1: no new addresses needed
|
||||||
empty = yield account.receiving.ensure_address_gap()
|
empty = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(len(empty), 0)
|
self.assertEqual(len(empty), 0)
|
||||||
|
|
||||||
# case #2: after use, still no new address needed
|
# case #2: after use, still no new address needed
|
||||||
records = yield account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
||||||
empty = yield account.receiving.ensure_address_gap()
|
empty = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(len(empty), 0)
|
self.assertEqual(len(empty), 0)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_get_or_create_usable_address(self):
|
||||||
def test_get_or_create_usable_address(self):
|
|
||||||
account = self.account
|
account = self.account
|
||||||
|
|
||||||
addresses = yield account.receiving.get_addresses()
|
addresses = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(addresses), 0)
|
self.assertEqual(len(addresses), 0)
|
||||||
|
|
||||||
address1 = yield account.receiving.get_or_create_usable_address()
|
address1 = await account.receiving.get_or_create_usable_address()
|
||||||
self.assertIsNotNone(address1)
|
self.assertIsNotNone(address1)
|
||||||
|
|
||||||
yield self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:')
|
await self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:')
|
||||||
records = yield account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
self.assertEqual(records[0]['used_times'], 3)
|
self.assertEqual(records[0]['used_times'], 3)
|
||||||
|
|
||||||
address2 = yield account.receiving.get_or_create_usable_address()
|
address2 = await account.receiving.get_or_create_usable_address()
|
||||||
self.assertEqual(address1, address2)
|
self.assertEqual(address1, address2)
|
||||||
|
|
||||||
keys = yield account.receiving.get_addresses()
|
keys = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(keys), 1)
|
self.assertEqual(len(keys), 1)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_generate_account_from_seed(self):
|
||||||
def test_generate_account_from_seed(self):
|
|
||||||
account = self.ledger.account_class.from_dict(
|
account = self.ledger.account_class.from_dict(
|
||||||
self.ledger, Wallet(), {
|
self.ledger, Wallet(), {
|
||||||
"seed":
|
"seed":
|
||||||
|
@ -285,17 +279,17 @@ class TestSingleKeyAccount(unittest.TestCase):
|
||||||
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
||||||
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||||
)
|
)
|
||||||
address = yield account.receiving.ensure_address_gap()
|
address = await account.receiving.ensure_address_gap()
|
||||||
self.assertEqual(address[0], account.public_key.address)
|
self.assertEqual(address[0], account.public_key.address)
|
||||||
|
|
||||||
private_key = yield self.ledger.get_private_key_for_address(address[0])
|
private_key = await self.ledger.get_private_key_for_address(address[0])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
private_key.extended_key_string(),
|
private_key.extended_key_string(),
|
||||||
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||||
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
)
|
)
|
||||||
|
|
||||||
invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||||
self.assertIsNone(invalid_key)
|
self.assertIsNone(invalid_key)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -303,8 +297,7 @@ class TestSingleKeyAccount(unittest.TestCase):
|
||||||
b'1c92caa0ef99bfd5e2ceb73b66da8cd726a9370be8c368d448a322f3c5b23aaab901'
|
b'1c92caa0ef99bfd5e2ceb73b66da8cd726a9370be8c368d448a322f3c5b23aaab901'
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_load_and_save_account(self):
|
||||||
def test_load_and_save_account(self):
|
|
||||||
account_data = {
|
account_data = {
|
||||||
'name': 'My Account',
|
'name': 'My Account',
|
||||||
'seed':
|
'seed':
|
||||||
|
@ -322,11 +315,11 @@ class TestSingleKeyAccount(unittest.TestCase):
|
||||||
|
|
||||||
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
||||||
|
|
||||||
yield account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
addresses = yield account.receiving.get_addresses()
|
addresses = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(addresses), 1)
|
self.assertEqual(len(addresses), 1)
|
||||||
addresses = yield account.change.get_addresses()
|
addresses = await account.change.get_addresses()
|
||||||
self.assertEqual(len(addresses), 1)
|
self.assertEqual(len(addresses), 1)
|
||||||
|
|
||||||
self.maxDiff = None
|
self.maxDiff = None
|
||||||
|
@ -334,7 +327,7 @@ class TestSingleKeyAccount(unittest.TestCase):
|
||||||
self.assertDictEqual(account_data, account.to_dict())
|
self.assertDictEqual(account_data, account.to_dict())
|
||||||
|
|
||||||
|
|
||||||
class AccountEncryptionTests(unittest.TestCase):
|
class AccountEncryptionTests(AsyncioTestCase):
|
||||||
password = "password"
|
password = "password"
|
||||||
init_vector = b'0000000000000000'
|
init_vector = b'0000000000000000'
|
||||||
unencrypted_account = {
|
unencrypted_account = {
|
||||||
|
@ -368,7 +361,7 @@ class AccountEncryptionTests(unittest.TestCase):
|
||||||
'address_generator': {'name': 'single-address'}
|
'address_generator': {'name': 'single-address'}
|
||||||
}
|
}
|
||||||
|
|
||||||
def setUp(self):
|
async def asyncSetUp(self):
|
||||||
self.ledger = ledger_class({
|
self.ledger = ledger_class({
|
||||||
'db': ledger_class.database_class(':memory:'),
|
'db': ledger_class.database_class(':memory:'),
|
||||||
'headers': ledger_class.headers_class(':memory:'),
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from twisted.trial import unittest
|
import unittest
|
||||||
|
|
||||||
from torba.bcd_data_stream import BCDataStream
|
from torba.bcd_data_stream import BCDataStream
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
import unittest
|
||||||
from binascii import unhexlify, hexlify
|
from binascii import unhexlify, hexlify
|
||||||
from twisted.trial import unittest
|
|
||||||
|
|
||||||
from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
|
from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
|
||||||
from torba.bip32 import PubKey, PrivateKey, from_extended_key_string
|
from torba.bip32 import PubKey, PrivateKey, from_extended_key_string
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from twisted.trial import unittest
|
import unittest
|
||||||
from types import GeneratorType
|
from types import GeneratorType
|
||||||
|
|
||||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
from twisted.trial import unittest
|
import unittest
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from torba.wallet import Wallet
|
from torba.wallet import Wallet
|
||||||
from torba.constants import COIN
|
from torba.constants import COIN
|
||||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
from torba.basedatabase import query, constraints_to_sql
|
from torba.basedatabase import query, constraints_to_sql
|
||||||
|
|
||||||
|
from orchstr8.testcase import AsyncioTestCase
|
||||||
|
|
||||||
from .test_transaction import get_output, NULL_HASH
|
from .test_transaction import get_output, NULL_HASH
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,53 +101,52 @@ class TestQueryBuilder(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestQueries(unittest.TestCase):
|
class TestQueries(AsyncioTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
async def asyncSetUp(self):
|
||||||
self.ledger = ledger_class({
|
self.ledger = ledger_class({
|
||||||
'db': ledger_class.database_class(':memory:'),
|
'db': ledger_class.database_class(':memory:'),
|
||||||
'headers': ledger_class.headers_class(':memory:'),
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
})
|
})
|
||||||
return self.ledger.db.open()
|
await self.ledger.db.open()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def asyncTearDown(self):
|
||||||
def create_account(self):
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
async def create_account(self):
|
||||||
account = self.ledger.account_class.generate(self.ledger, Wallet())
|
account = self.ledger.account_class.generate(self.ledger, Wallet())
|
||||||
yield account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
return account
|
return account
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_tx_from_nothing(self, my_account, height):
|
||||||
def create_tx_from_nothing(self, my_account, height):
|
to_address = await my_account.receiving.get_or_create_usable_address()
|
||||||
to_address = yield my_account.receiving.get_or_create_usable_address()
|
|
||||||
to_hash = ledger_class.address_to_hash160(to_address)
|
to_hash = ledger_class.address_to_hash160(to_address)
|
||||||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||||
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
|
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
|
||||||
.add_outputs([self.txo(1, to_hash)])
|
.add_outputs([self.txo(1, to_hash)])
|
||||||
yield self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
|
await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
|
||||||
return tx
|
return tx
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_tx_from_txo(self, txo, to_account, height):
|
||||||
def create_tx_from_txo(self, txo, to_account, height):
|
|
||||||
from_hash = txo.script.values['pubkey_hash']
|
from_hash = txo.script.values['pubkey_hash']
|
||||||
from_address = self.ledger.hash160_to_address(from_hash)
|
from_address = self.ledger.hash160_to_address(from_hash)
|
||||||
to_address = yield to_account.receiving.get_or_create_usable_address()
|
to_address = await to_account.receiving.get_or_create_usable_address()
|
||||||
to_hash = ledger_class.address_to_hash160(to_address)
|
to_hash = ledger_class.address_to_hash160(to_address)
|
||||||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||||
.add_inputs([self.txi(txo)]) \
|
.add_inputs([self.txi(txo)]) \
|
||||||
.add_outputs([self.txo(1, to_hash)])
|
.add_outputs([self.txo(1, to_hash)])
|
||||||
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
||||||
yield self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
|
await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
|
||||||
return tx
|
return tx
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_tx_to_nowhere(self, txo, height):
|
||||||
def create_tx_to_nowhere(self, txo, height):
|
|
||||||
from_hash = txo.script.values['pubkey_hash']
|
from_hash = txo.script.values['pubkey_hash']
|
||||||
from_address = self.ledger.hash160_to_address(from_hash)
|
from_address = self.ledger.hash160_to_address(from_hash)
|
||||||
to_hash = NULL_HASH
|
to_hash = NULL_HASH
|
||||||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||||
.add_inputs([self.txi(txo)]) \
|
.add_inputs([self.txi(txo)]) \
|
||||||
.add_outputs([self.txo(1, to_hash)])
|
.add_outputs([self.txo(1, to_hash)])
|
||||||
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
||||||
return tx
|
return tx
|
||||||
|
|
||||||
def txo(self, amount, address):
|
def txo(self, amount, address):
|
||||||
|
@ -155,39 +155,38 @@ class TestQueries(unittest.TestCase):
|
||||||
def txi(self, txo):
|
def txi(self, txo):
|
||||||
return ledger_class.transaction_class.input_class.spend(txo)
|
return ledger_class.transaction_class.input_class.spend(txo)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_get_transactions(self):
|
||||||
def test_get_transactions(self):
|
account1 = await self.create_account()
|
||||||
account1 = yield self.create_account()
|
account2 = await self.create_account()
|
||||||
account2 = yield self.create_account()
|
tx1 = await self.create_tx_from_nothing(account1, 1)
|
||||||
tx1 = yield self.create_tx_from_nothing(account1, 1)
|
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
|
||||||
tx2 = yield self.create_tx_from_txo(tx1.outputs[0], account2, 2)
|
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
|
||||||
tx3 = yield self.create_tx_to_nowhere(tx2.outputs[0], 3)
|
|
||||||
|
|
||||||
txs = yield self.ledger.db.get_transactions()
|
txs = await self.ledger.db.get_transactions()
|
||||||
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
||||||
self.assertEqual([3, 2, 1], [tx.height for tx in txs])
|
self.assertEqual([3, 2, 1], [tx.height for tx in txs])
|
||||||
|
|
||||||
txs = yield self.ledger.db.get_transactions(account=account1)
|
txs = await self.ledger.db.get_transactions(account=account1)
|
||||||
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
|
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
|
||||||
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
||||||
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
||||||
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
||||||
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
||||||
|
|
||||||
txs = yield self.ledger.db.get_transactions(account=account2)
|
txs = await self.ledger.db.get_transactions(account=account2)
|
||||||
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
|
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
|
||||||
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
||||||
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
||||||
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
||||||
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
||||||
|
|
||||||
tx = yield self.ledger.db.get_transaction(txid=tx2.id)
|
tx = await self.ledger.db.get_transaction(txid=tx2.id)
|
||||||
self.assertEqual(tx.id, tx2.id)
|
self.assertEqual(tx.id, tx2.id)
|
||||||
self.assertEqual(tx.inputs[0].is_my_account, False)
|
self.assertEqual(tx.inputs[0].is_my_account, False)
|
||||||
self.assertEqual(tx.outputs[0].is_my_account, False)
|
self.assertEqual(tx.outputs[0].is_my_account, False)
|
||||||
tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account1)
|
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account1)
|
||||||
self.assertEqual(tx.inputs[0].is_my_account, True)
|
self.assertEqual(tx.inputs[0].is_my_account, True)
|
||||||
self.assertEqual(tx.outputs[0].is_my_account, False)
|
self.assertEqual(tx.outputs[0].is_my_account, False)
|
||||||
tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account2)
|
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2)
|
||||||
self.assertEqual(tx.inputs[0].is_my_account, False)
|
self.assertEqual(tx.inputs[0].is_my_account, False)
|
||||||
self.assertEqual(tx.outputs[0].is_my_account, True)
|
self.assertEqual(tx.outputs[0].is_my_account, True)
|
||||||
|
|
|
@ -1,11 +1,6 @@
|
||||||
from unittest import TestCase
|
from unittest import TestCase, mock
|
||||||
from torba.hash import aes_decrypt, aes_encrypt
|
from torba.hash import aes_decrypt, aes_encrypt
|
||||||
|
|
||||||
try:
|
|
||||||
from unittest import mock
|
|
||||||
except ImportError:
|
|
||||||
import mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestAESEncryptDecrypt(TestCase):
|
class TestAESEncryptDecrypt(TestCase):
|
||||||
message = 'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks'
|
message = 'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks'
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
from twisted.trial import unittest
|
from orchstr8.testcase import AsyncioTestCase
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from torba.coin.bitcoinsegwit import MainHeaders
|
from torba.coin.bitcoinsegwit import MainHeaders
|
||||||
|
|
||||||
|
@ -11,7 +10,7 @@ def block_bytes(blocks):
|
||||||
return blocks * MainHeaders.header_size
|
return blocks * MainHeaders.header_size
|
||||||
|
|
||||||
|
|
||||||
class BitcoinHeadersTestCase(unittest.TestCase):
|
class BitcoinHeadersTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
# Download headers instead of storing them in git.
|
# Download headers instead of storing them in git.
|
||||||
HEADER_URL = 'http://headers.electrum.org/blockchain_headers'
|
HEADER_URL = 'http://headers.electrum.org/blockchain_headers'
|
||||||
|
@ -39,7 +38,7 @@ class BitcoinHeadersTestCase(unittest.TestCase):
|
||||||
headers.seek(after, os.SEEK_SET)
|
headers.seek(after, os.SEEK_SET)
|
||||||
return headers.read(upto)
|
return headers.read(upto)
|
||||||
|
|
||||||
def get_headers(self, upto: int = -1):
|
async def get_headers(self, upto: int = -1):
|
||||||
h = MainHeaders(':memory:')
|
h = MainHeaders(':memory:')
|
||||||
h.io.write(self.get_bytes(upto))
|
h.io.write(self.get_bytes(upto))
|
||||||
return h
|
return h
|
||||||
|
@ -47,8 +46,8 @@ class BitcoinHeadersTestCase(unittest.TestCase):
|
||||||
|
|
||||||
class BasicHeadersTests(BitcoinHeadersTestCase):
|
class BasicHeadersTests(BitcoinHeadersTestCase):
|
||||||
|
|
||||||
def test_serialization(self):
|
async def test_serialization(self):
|
||||||
h = self.get_headers()
|
h = await self.get_headers()
|
||||||
self.assertEqual(h[0], {
|
self.assertEqual(h[0], {
|
||||||
'bits': 486604799,
|
'bits': 486604799,
|
||||||
'block_height': 0,
|
'block_height': 0,
|
||||||
|
@ -94,18 +93,16 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
|
||||||
h.get_raw_header(self.RETARGET_BLOCK)
|
h.get_raw_header(self.RETARGET_BLOCK)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
|
||||||
def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
|
|
||||||
headers = MainHeaders(':memory:')
|
headers = MainHeaders(':memory:')
|
||||||
self.assertEqual(headers.height, -1)
|
self.assertEqual(headers.height, -1)
|
||||||
yield headers.connect(0, self.get_bytes(block_bytes(3001)))
|
await headers.connect(0, self.get_bytes(block_bytes(3001)))
|
||||||
self.assertEqual(headers.height, 3000)
|
self.assertEqual(headers.height, 3000)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_connect_9_blocks_passing_a_retarget_at_32256(self):
|
||||||
def test_connect_9_blocks_passing_a_retarget_at_32256(self):
|
|
||||||
retarget = block_bytes(self.RETARGET_BLOCK-5)
|
retarget = block_bytes(self.RETARGET_BLOCK-5)
|
||||||
headers = self.get_headers(upto=retarget)
|
headers = await self.get_headers(upto=retarget)
|
||||||
remainder = self.get_bytes(after=retarget)
|
remainder = self.get_bytes(after=retarget)
|
||||||
self.assertEqual(headers.height, 32250)
|
self.assertEqual(headers.height, 32250)
|
||||||
yield headers.connect(len(headers), remainder)
|
await headers.connect(len(headers), remainder)
|
||||||
self.assertEqual(headers.height, 32259)
|
self.assertEqual(headers.height, 32259)
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from torba.coin.bitcoinsegwit import MainNetLedger
|
from torba.coin.bitcoinsegwit import MainNetLedger
|
||||||
from torba.wallet import Wallet
|
from torba.wallet import Wallet
|
||||||
|
@ -18,32 +17,30 @@ class MockNetwork:
|
||||||
self.get_history_called = []
|
self.get_history_called = []
|
||||||
self.get_transaction_called = []
|
self.get_transaction_called = []
|
||||||
|
|
||||||
def get_history(self, address):
|
async def get_history(self, address):
|
||||||
self.get_history_called.append(address)
|
self.get_history_called.append(address)
|
||||||
self.address = address
|
self.address = address
|
||||||
return defer.succeed(self.history)
|
return self.history
|
||||||
|
|
||||||
def get_merkle(self, txid, height):
|
async def get_merkle(self, txid, height):
|
||||||
return defer.succeed({'merkle': ['abcd01'], 'pos': 1})
|
return {'merkle': ['abcd01'], 'pos': 1}
|
||||||
|
|
||||||
def get_transaction(self, tx_hash):
|
async def get_transaction(self, tx_hash):
|
||||||
self.get_transaction_called.append(tx_hash)
|
self.get_transaction_called.append(tx_hash)
|
||||||
return defer.succeed(self.transaction[tx_hash])
|
return self.transaction[tx_hash]
|
||||||
|
|
||||||
|
|
||||||
class LedgerTestCase(BitcoinHeadersTestCase):
|
class LedgerTestCase(BitcoinHeadersTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
async def asyncSetUp(self):
|
||||||
super().setUp()
|
|
||||||
self.ledger = MainNetLedger({
|
self.ledger = MainNetLedger({
|
||||||
'db': MainNetLedger.database_class(':memory:'),
|
'db': MainNetLedger.database_class(':memory:'),
|
||||||
'headers': MainNetLedger.headers_class(':memory:')
|
'headers': MainNetLedger.headers_class(':memory:')
|
||||||
})
|
})
|
||||||
return self.ledger.db.open()
|
await self.ledger.db.open()
|
||||||
|
|
||||||
def tearDown(self):
|
async def asyncTearDown(self):
|
||||||
super().tearDown()
|
await self.ledger.db.close()
|
||||||
return self.ledger.db.close()
|
|
||||||
|
|
||||||
def make_header(self, **kwargs):
|
def make_header(self, **kwargs):
|
||||||
header = {
|
header = {
|
||||||
|
@ -69,11 +66,10 @@ class LedgerTestCase(BitcoinHeadersTestCase):
|
||||||
|
|
||||||
class TestSynchronization(LedgerTestCase):
|
class TestSynchronization(LedgerTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_update_history(self):
|
||||||
def test_update_history(self):
|
|
||||||
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
||||||
address = yield account.receiving.get_or_create_usable_address()
|
address = await account.receiving.get_or_create_usable_address()
|
||||||
address_details = yield self.ledger.db.get_address(address=address)
|
address_details = await self.ledger.db.get_address(address=address)
|
||||||
self.assertEqual(address_details['history'], None)
|
self.assertEqual(address_details['history'], None)
|
||||||
|
|
||||||
self.add_header(block_height=0, merkle_root=b'abcd04')
|
self.add_header(block_height=0, merkle_root=b'abcd04')
|
||||||
|
@ -89,16 +85,16 @@ class TestSynchronization(LedgerTestCase):
|
||||||
'abcd02': hexlify(get_transaction(get_output(2)).raw),
|
'abcd02': hexlify(get_transaction(get_output(2)).raw),
|
||||||
'abcd03': hexlify(get_transaction(get_output(3)).raw),
|
'abcd03': hexlify(get_transaction(get_output(3)).raw),
|
||||||
})
|
})
|
||||||
yield self.ledger.update_history(address)
|
await self.ledger.update_history(address)
|
||||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||||
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
|
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
|
||||||
|
|
||||||
address_details = yield self.ledger.db.get_address(address=address)
|
address_details = await self.ledger.db.get_address(address=address)
|
||||||
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
|
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
|
||||||
|
|
||||||
self.ledger.network.get_history_called = []
|
self.ledger.network.get_history_called = []
|
||||||
self.ledger.network.get_transaction_called = []
|
self.ledger.network.get_transaction_called = []
|
||||||
yield self.ledger.update_history(address)
|
await self.ledger.update_history(address)
|
||||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||||
self.assertEqual(self.ledger.network.get_transaction_called, [])
|
self.assertEqual(self.ledger.network.get_transaction_called, [])
|
||||||
|
|
||||||
|
@ -106,10 +102,10 @@ class TestSynchronization(LedgerTestCase):
|
||||||
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
|
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
|
||||||
self.ledger.network.get_history_called = []
|
self.ledger.network.get_history_called = []
|
||||||
self.ledger.network.get_transaction_called = []
|
self.ledger.network.get_transaction_called = []
|
||||||
yield self.ledger.update_history(address)
|
await self.ledger.update_history(address)
|
||||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||||
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
|
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
|
||||||
address_details = yield self.ledger.db.get_address(address=address)
|
address_details = await self.ledger.db.get_address(address=address)
|
||||||
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
|
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,14 +113,13 @@ class MocHeaderNetwork:
|
||||||
def __init__(self, responses):
|
def __init__(self, responses):
|
||||||
self.responses = responses
|
self.responses = responses
|
||||||
|
|
||||||
def get_headers(self, height, blocks):
|
async def get_headers(self, height, blocks):
|
||||||
return self.responses[height]
|
return self.responses[height]
|
||||||
|
|
||||||
|
|
||||||
class BlockchainReorganizationTests(LedgerTestCase):
|
class BlockchainReorganizationTests(LedgerTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_1_block_reorganization(self):
|
||||||
def test_1_block_reorganization(self):
|
|
||||||
self.ledger.network = MocHeaderNetwork({
|
self.ledger.network = MocHeaderNetwork({
|
||||||
20: {'height': 20, 'count': 5, 'hex': hexlify(
|
20: {'height': 20, 'count': 5, 'hex': hexlify(
|
||||||
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
|
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
|
||||||
|
@ -132,15 +127,14 @@ class BlockchainReorganizationTests(LedgerTestCase):
|
||||||
25: {'height': 25, 'count': 0, 'hex': b''}
|
25: {'height': 25, 'count': 0, 'hex': b''}
|
||||||
})
|
})
|
||||||
headers = self.ledger.headers
|
headers = self.ledger.headers
|
||||||
yield headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
||||||
self.add_header(block_height=len(headers))
|
self.add_header(block_height=len(headers))
|
||||||
self.assertEqual(headers.height, 20)
|
self.assertEqual(headers.height, 20)
|
||||||
yield self.ledger.receive_header([{
|
await self.ledger.receive_header([{
|
||||||
'height': 21, 'hex': hexlify(self.make_header(block_height=21))
|
'height': 21, 'hex': hexlify(self.make_header(block_height=21))
|
||||||
}])
|
}])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_3_block_reorganization(self):
|
||||||
def test_3_block_reorganization(self):
|
|
||||||
self.ledger.network = MocHeaderNetwork({
|
self.ledger.network = MocHeaderNetwork({
|
||||||
20: {'height': 20, 'count': 5, 'hex': hexlify(
|
20: {'height': 20, 'count': 5, 'hex': hexlify(
|
||||||
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
|
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
|
||||||
|
@ -150,11 +144,11 @@ class BlockchainReorganizationTests(LedgerTestCase):
|
||||||
25: {'height': 25, 'count': 0, 'hex': b''}
|
25: {'height': 25, 'count': 0, 'hex': b''}
|
||||||
})
|
})
|
||||||
headers = self.ledger.headers
|
headers = self.ledger.headers
|
||||||
yield headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
||||||
self.add_header(block_height=len(headers))
|
self.add_header(block_height=len(headers))
|
||||||
self.add_header(block_height=len(headers))
|
self.add_header(block_height=len(headers))
|
||||||
self.add_header(block_height=len(headers))
|
self.add_header(block_height=len(headers))
|
||||||
self.assertEqual(headers.height, 22)
|
self.assertEqual(headers.height, 22)
|
||||||
yield self.ledger.receive_header(({
|
await self.ledger.receive_header(({
|
||||||
'height': 23, 'hex': hexlify(self.make_header(block_height=23))
|
'height': 23, 'hex': hexlify(self.make_header(block_height=23))
|
||||||
},))
|
},))
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
import unittest
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from twisted.trial import unittest
|
|
||||||
|
|
||||||
from torba.bcd_data_stream import BCDataStream
|
from torba.bcd_data_stream import BCDataStream
|
||||||
from torba.basescript import Template, ParseError, tokenize, push_data
|
from torba.basescript import Template, ParseError, tokenize, push_data
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
|
import unittest
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from twisted.trial import unittest
|
|
||||||
from twisted.internet import defer
|
from orchstr8.testcase import AsyncioTestCase
|
||||||
|
|
||||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
from torba.wallet import Wallet
|
from torba.wallet import Wallet
|
||||||
|
@ -29,9 +30,9 @@ def get_transaction(txo=None):
|
||||||
.add_outputs([txo or ledger_class.transaction_class.output_class.pay_pubkey_hash(CENT, NULL_HASH)])
|
.add_outputs([txo or ledger_class.transaction_class.output_class.pay_pubkey_hash(CENT, NULL_HASH)])
|
||||||
|
|
||||||
|
|
||||||
class TestSizeAndFeeEstimation(unittest.TestCase):
|
class TestSizeAndFeeEstimation(AsyncioTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
async def asyncSetUp(self):
|
||||||
self.ledger = ledger_class({
|
self.ledger = ledger_class({
|
||||||
'db': ledger_class.database_class(':memory:'),
|
'db': ledger_class.database_class(':memory:'),
|
||||||
'headers': ledger_class.headers_class(':memory:'),
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
@ -181,17 +182,19 @@ class TestTransactionSerialization(unittest.TestCase):
|
||||||
self.assertEqual(tx.raw, raw)
|
self.assertEqual(tx.raw, raw)
|
||||||
|
|
||||||
|
|
||||||
class TestTransactionSigning(unittest.TestCase):
|
class TestTransactionSigning(AsyncioTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
async def asyncSetUp(self):
|
||||||
self.ledger = ledger_class({
|
self.ledger = ledger_class({
|
||||||
'db': ledger_class.database_class(':memory:'),
|
'db': ledger_class.database_class(':memory:'),
|
||||||
'headers': ledger_class.headers_class(':memory:'),
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
})
|
})
|
||||||
return self.ledger.db.open()
|
await self.ledger.db.open()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def asyncTearDown(self):
|
||||||
def test_sign(self):
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
async def test_sign(self):
|
||||||
account = self.ledger.account_class.from_dict(
|
account = self.ledger.account_class.from_dict(
|
||||||
self.ledger, Wallet(), {
|
self.ledger, Wallet(), {
|
||||||
"seed": "carbon smart garage balance margin twelve chest sword "
|
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||||
|
@ -200,8 +203,8 @@ class TestTransactionSigning(unittest.TestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
yield account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
address1, address2 = yield account.receiving.get_addresses(limit=2)
|
address1, address2 = await account.receiving.get_addresses(limit=2)
|
||||||
pubkey_hash1 = self.ledger.address_to_hash160(address1)
|
pubkey_hash1 = self.ledger.address_to_hash160(address1)
|
||||||
pubkey_hash2 = self.ledger.address_to_hash160(address2)
|
pubkey_hash2 = self.ledger.address_to_hash160(address2)
|
||||||
|
|
||||||
|
@ -211,9 +214,8 @@ class TestTransactionSigning(unittest.TestCase):
|
||||||
.add_inputs([tx_class.input_class.spend(get_output(2*COIN, pubkey_hash1))]) \
|
.add_inputs([tx_class.input_class.spend(get_output(2*COIN, pubkey_hash1))]) \
|
||||||
.add_outputs([tx_class.output_class.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) \
|
.add_outputs([tx_class.output_class.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) \
|
||||||
|
|
||||||
yield tx.sign([account])
|
await tx.sign([account])
|
||||||
|
|
||||||
print(hexlify(tx.inputs[0].script.values['signature']))
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
hexlify(tx.inputs[0].script.values['signature']),
|
hexlify(tx.inputs[0].script.values['signature']),
|
||||||
b'304402205a1df8cd5d2d2fa5934b756883d6c07e4f83e1350c740992d47a12422'
|
b'304402205a1df8cd5d2d2fa5934b756883d6c07e4f83e1350c740992d47a12422'
|
||||||
|
@ -221,15 +223,14 @@ class TestTransactionSigning(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransactionIOBalancing(unittest.TestCase):
|
class TransactionIOBalancing(AsyncioTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def asyncSetUp(self):
|
||||||
def setUp(self):
|
|
||||||
self.ledger = ledger_class({
|
self.ledger = ledger_class({
|
||||||
'db': ledger_class.database_class(':memory:'),
|
'db': ledger_class.database_class(':memory:'),
|
||||||
'headers': ledger_class.headers_class(':memory:'),
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
})
|
})
|
||||||
yield self.ledger.db.open()
|
await self.ledger.db.open()
|
||||||
self.account = self.ledger.account_class.from_dict(
|
self.account = self.ledger.account_class.from_dict(
|
||||||
self.ledger, Wallet(), {
|
self.ledger, Wallet(), {
|
||||||
"seed": "carbon smart garage balance margin twelve chest sword "
|
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||||
|
@ -237,10 +238,13 @@ class TransactionIOBalancing(unittest.TestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
addresses = yield self.account.ensure_address_gap()
|
addresses = await self.account.ensure_address_gap()
|
||||||
self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
|
self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
|
||||||
self.hash_cycler = cycle(self.pubkey_hash)
|
self.hash_cycler = cycle(self.pubkey_hash)
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.ledger.db.close()
|
||||||
|
|
||||||
def txo(self, amount, address=None):
|
def txo(self, amount, address=None):
|
||||||
return get_output(int(amount*COIN), address or next(self.hash_cycler))
|
return get_output(int(amount*COIN), address or next(self.hash_cycler))
|
||||||
|
|
||||||
|
@ -250,8 +254,7 @@ class TransactionIOBalancing(unittest.TestCase):
|
||||||
def tx(self, inputs, outputs):
|
def tx(self, inputs, outputs):
|
||||||
return ledger_class.transaction_class.create(inputs, outputs, [self.account], self.account)
|
return ledger_class.transaction_class.create(inputs, outputs, [self.account], self.account)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_utxos(self, amounts):
|
||||||
def create_utxos(self, amounts):
|
|
||||||
utxos = [self.txo(amount) for amount in amounts]
|
utxos = [self.txo(amount) for amount in amounts]
|
||||||
|
|
||||||
self.funding_tx = ledger_class.transaction_class(is_verified=True) \
|
self.funding_tx = ledger_class.transaction_class(is_verified=True) \
|
||||||
|
@ -260,7 +263,7 @@ class TransactionIOBalancing(unittest.TestCase):
|
||||||
|
|
||||||
save_tx = 'insert'
|
save_tx = 'insert'
|
||||||
for utxo in utxos:
|
for utxo in utxos:
|
||||||
yield self.ledger.db.save_transaction_io(
|
await self.ledger.db.save_transaction_io(
|
||||||
save_tx, self.funding_tx,
|
save_tx, self.funding_tx,
|
||||||
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
|
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
|
||||||
utxo.script.values['pubkey_hash'], ''
|
utxo.script.values['pubkey_hash'], ''
|
||||||
|
@ -277,17 +280,16 @@ class TransactionIOBalancing(unittest.TestCase):
|
||||||
def outputs(tx):
|
def outputs(tx):
|
||||||
return [round(o.amount/COIN, 2) for o in tx.outputs]
|
return [round(o.amount/COIN, 2) for o in tx.outputs]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def test_basic_use_cases(self):
|
||||||
def test_basic_use_cases(self):
|
|
||||||
self.ledger.fee_per_byte = int(.01*CENT)
|
self.ledger.fee_per_byte = int(.01*CENT)
|
||||||
|
|
||||||
# available UTXOs for filling missing inputs
|
# available UTXOs for filling missing inputs
|
||||||
utxos = yield self.create_utxos([
|
utxos = await self.create_utxos([
|
||||||
1, 1, 3, 5, 10
|
1, 1, 3, 5, 10
|
||||||
])
|
])
|
||||||
|
|
||||||
# pay 3 coins (3.02 w/ fees)
|
# pay 3 coins (3.02 w/ fees)
|
||||||
tx = yield self.tx(
|
tx = await self.tx(
|
||||||
[], # inputs
|
[], # inputs
|
||||||
[self.txo(3)] # outputs
|
[self.txo(3)] # outputs
|
||||||
)
|
)
|
||||||
|
@ -296,10 +298,10 @@ class TransactionIOBalancing(unittest.TestCase):
|
||||||
# a change of 1.98 is added to reach balance
|
# a change of 1.98 is added to reach balance
|
||||||
self.assertEqual(self.outputs(tx), [3, 1.98])
|
self.assertEqual(self.outputs(tx), [3, 1.98])
|
||||||
|
|
||||||
yield self.ledger.release_outputs(utxos)
|
await self.ledger.release_outputs(utxos)
|
||||||
|
|
||||||
# pay 2.98 coins (3.00 w/ fees)
|
# pay 2.98 coins (3.00 w/ fees)
|
||||||
tx = yield self.tx(
|
tx = await self.tx(
|
||||||
[], # inputs
|
[], # inputs
|
||||||
[self.txo(2.98)] # outputs
|
[self.txo(2.98)] # outputs
|
||||||
)
|
)
|
||||||
|
@ -307,10 +309,10 @@ class TransactionIOBalancing(unittest.TestCase):
|
||||||
self.assertEqual(self.inputs(tx), [3])
|
self.assertEqual(self.inputs(tx), [3])
|
||||||
self.assertEqual(self.outputs(tx), [2.98])
|
self.assertEqual(self.outputs(tx), [2.98])
|
||||||
|
|
||||||
yield self.ledger.release_outputs(utxos)
|
await self.ledger.release_outputs(utxos)
|
||||||
|
|
||||||
# supplied input and output, but input is not enough to cover output
|
# supplied input and output, but input is not enough to cover output
|
||||||
tx = yield self.tx(
|
tx = await self.tx(
|
||||||
[self.txi(self.txo(10))], # inputs
|
[self.txi(self.txo(10))], # inputs
|
||||||
[self.txo(11)] # outputs
|
[self.txo(11)] # outputs
|
||||||
)
|
)
|
||||||
|
@ -319,10 +321,10 @@ class TransactionIOBalancing(unittest.TestCase):
|
||||||
# change is now needed to consume extra input
|
# change is now needed to consume extra input
|
||||||
self.assertEqual([11, 1.96], self.outputs(tx))
|
self.assertEqual([11, 1.96], self.outputs(tx))
|
||||||
|
|
||||||
yield self.ledger.release_outputs(utxos)
|
await self.ledger.release_outputs(utxos)
|
||||||
|
|
||||||
# liquidating a UTXO
|
# liquidating a UTXO
|
||||||
tx = yield self.tx(
|
tx = await self.tx(
|
||||||
[self.txi(self.txo(10))], # inputs
|
[self.txi(self.txo(10))], # inputs
|
||||||
[] # outputs
|
[] # outputs
|
||||||
)
|
)
|
||||||
|
@ -330,10 +332,10 @@ class TransactionIOBalancing(unittest.TestCase):
|
||||||
# missing change added to consume the amount
|
# missing change added to consume the amount
|
||||||
self.assertEqual([9.98], self.outputs(tx))
|
self.assertEqual([9.98], self.outputs(tx))
|
||||||
|
|
||||||
yield self.ledger.release_outputs(utxos)
|
await self.ledger.release_outputs(utxos)
|
||||||
|
|
||||||
# liquidating at a loss, requires adding extra inputs
|
# liquidating at a loss, requires adding extra inputs
|
||||||
tx = yield self.tx(
|
tx = await self.tx(
|
||||||
[self.txi(self.txo(0.01))], # inputs
|
[self.txi(self.txo(0.01))], # inputs
|
||||||
[] # outputs
|
[] # outputs
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from twisted.trial import unittest
|
import unittest
|
||||||
|
|
||||||
from torba.util import ArithUint256
|
from torba.util import ArithUint256
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
import unittest
|
||||||
import tempfile
|
import tempfile
|
||||||
from twisted.trial import unittest
|
|
||||||
|
|
||||||
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
|
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
|
||||||
from torba.coin.bitcoincash import MainNetLedger as BCHLedger
|
from torba.coin.bitcoincash import MainNetLedger as BCHLedger
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import random
|
import random
|
||||||
import typing
|
import typing
|
||||||
from typing import List, Dict, Tuple, Type, Optional, Any
|
from typing import List, Dict, Tuple, Type, Optional, Any
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from torba.mnemonic import Mnemonic
|
from torba.mnemonic import Mnemonic
|
||||||
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
|
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
|
||||||
|
@ -44,12 +43,8 @@ class AddressManager:
|
||||||
def to_dict_instance(self) -> Optional[dict]:
|
def to_dict_instance(self) -> Optional[dict]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
|
||||||
def db(self):
|
|
||||||
return self.account.ledger.db
|
|
||||||
|
|
||||||
def _query_addresses(self, **constraints):
|
def _query_addresses(self, **constraints):
|
||||||
return self.db.get_addresses(
|
return self.account.ledger.db.get_addresses(
|
||||||
account=self.account,
|
account=self.account,
|
||||||
chain=self.chain_number,
|
chain=self.chain_number,
|
||||||
**constraints
|
**constraints
|
||||||
|
@ -58,26 +53,24 @@ class AddressManager:
|
||||||
def get_private_key(self, index: int) -> PrivateKey:
|
def get_private_key(self, index: int) -> PrivateKey:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_max_gap(self) -> defer.Deferred:
|
async def get_max_gap(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def ensure_address_gap(self) -> defer.Deferred:
|
async def ensure_address_gap(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred:
|
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_addresses(self, only_usable: bool = False, **constraints):
|
||||||
def get_addresses(self, only_usable: bool = False, **constraints) -> defer.Deferred:
|
records = await self.get_address_records(only_usable=only_usable, **constraints)
|
||||||
records = yield self.get_address_records(only_usable=only_usable, **constraints)
|
|
||||||
return [r['address'] for r in records]
|
return [r['address'] for r in records]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_or_create_usable_address(self):
|
||||||
def get_or_create_usable_address(self) -> defer.Deferred:
|
addresses = await self.get_addresses(only_usable=True, limit=10)
|
||||||
addresses = yield self.get_addresses(only_usable=True, limit=10)
|
|
||||||
if addresses:
|
if addresses:
|
||||||
return random.choice(addresses)
|
return random.choice(addresses)
|
||||||
addresses = yield self.ensure_address_gap()
|
addresses = await self.ensure_address_gap()
|
||||||
return addresses[0]
|
return addresses[0]
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,22 +99,20 @@ class HierarchicalDeterministic(AddressManager):
|
||||||
def get_private_key(self, index: int) -> PrivateKey:
|
def get_private_key(self, index: int) -> PrivateKey:
|
||||||
return self.account.private_key.child(self.chain_number).child(index)
|
return self.account.private_key.child(self.chain_number).child(index)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def generate_keys(self, start: int, end: int):
|
||||||
def generate_keys(self, start: int, end: int) -> defer.Deferred:
|
|
||||||
keys_batch, final_keys = [], []
|
keys_batch, final_keys = [], []
|
||||||
for index in range(start, end+1):
|
for index in range(start, end+1):
|
||||||
keys_batch.append((index, self.public_key.child(index)))
|
keys_batch.append((index, self.public_key.child(index)))
|
||||||
if index % 180 == 0 or index == end:
|
if index % 180 == 0 or index == end:
|
||||||
yield self.db.add_keys(
|
await self.account.ledger.db.add_keys(
|
||||||
self.account, self.chain_number, keys_batch
|
self.account, self.chain_number, keys_batch
|
||||||
)
|
)
|
||||||
final_keys.extend(keys_batch)
|
final_keys.extend(keys_batch)
|
||||||
keys_batch.clear()
|
keys_batch.clear()
|
||||||
return [key[1].address for key in final_keys]
|
return [key[1].address for key in final_keys]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_max_gap(self):
|
||||||
def get_max_gap(self) -> defer.Deferred:
|
addresses = await self._query_addresses(order_by="position ASC")
|
||||||
addresses = yield self._query_addresses(order_by="position ASC")
|
|
||||||
max_gap = 0
|
max_gap = 0
|
||||||
current_gap = 0
|
current_gap = 0
|
||||||
for address in addresses:
|
for address in addresses:
|
||||||
|
@ -132,9 +123,8 @@ class HierarchicalDeterministic(AddressManager):
|
||||||
current_gap = 0
|
current_gap = 0
|
||||||
return max_gap
|
return max_gap
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def ensure_address_gap(self):
|
||||||
def ensure_address_gap(self) -> defer.Deferred:
|
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
|
||||||
addresses = yield self._query_addresses(limit=self.gap, order_by="position DESC")
|
|
||||||
|
|
||||||
existing_gap = 0
|
existing_gap = 0
|
||||||
for address in addresses:
|
for address in addresses:
|
||||||
|
@ -148,7 +138,7 @@ class HierarchicalDeterministic(AddressManager):
|
||||||
|
|
||||||
start = addresses[0]['position']+1 if addresses else 0
|
start = addresses[0]['position']+1 if addresses else 0
|
||||||
end = start + (self.gap - existing_gap)
|
end = start + (self.gap - existing_gap)
|
||||||
new_keys = yield self.generate_keys(start, end-1)
|
new_keys = await self.generate_keys(start, end-1)
|
||||||
return new_keys
|
return new_keys
|
||||||
|
|
||||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||||
|
@ -176,20 +166,19 @@ class SingleKey(AddressManager):
|
||||||
def get_private_key(self, index: int) -> PrivateKey:
|
def get_private_key(self, index: int) -> PrivateKey:
|
||||||
return self.account.private_key
|
return self.account.private_key
|
||||||
|
|
||||||
def get_max_gap(self) -> defer.Deferred:
|
async def get_max_gap(self):
|
||||||
return defer.succeed(0)
|
return 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def ensure_address_gap(self):
|
||||||
def ensure_address_gap(self) -> defer.Deferred:
|
exists = await self.get_address_records()
|
||||||
exists = yield self.get_address_records()
|
|
||||||
if not exists:
|
if not exists:
|
||||||
yield self.db.add_keys(
|
await self.account.ledger.db.add_keys(
|
||||||
self.account, self.chain_number, [(0, self.public_key)]
|
self.account, self.chain_number, [(0, self.public_key)]
|
||||||
)
|
)
|
||||||
return [self.public_key.address]
|
return [self.public_key.address]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred:
|
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||||
return self._query_addresses(**constraints)
|
return self._query_addresses(**constraints)
|
||||||
|
|
||||||
|
|
||||||
|
@ -289,9 +278,8 @@ class BaseAccount:
|
||||||
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
|
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_details(self, show_seed=False, **kwargs):
|
||||||
def get_details(self, show_seed=False, **kwargs):
|
satoshis = await self.get_balance(**kwargs)
|
||||||
satoshis = yield self.get_balance(**kwargs)
|
|
||||||
details = {
|
details = {
|
||||||
'id': self.id,
|
'id': self.id,
|
||||||
'name': self.name,
|
'name': self.name,
|
||||||
|
@ -325,23 +313,21 @@ class BaseAccount:
|
||||||
self.password = None
|
self.password = None
|
||||||
self.encrypted = True
|
self.encrypted = True
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def ensure_address_gap(self):
|
||||||
def ensure_address_gap(self):
|
|
||||||
addresses = []
|
addresses = []
|
||||||
for address_manager in self.address_managers:
|
for address_manager in self.address_managers:
|
||||||
new_addresses = yield address_manager.ensure_address_gap()
|
new_addresses = await address_manager.ensure_address_gap()
|
||||||
addresses.extend(new_addresses)
|
addresses.extend(new_addresses)
|
||||||
return addresses
|
return addresses
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_addresses(self, **constraints):
|
||||||
def get_addresses(self, **constraints) -> defer.Deferred:
|
rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
|
||||||
rows = yield self.ledger.db.select_addresses('address', account=self, **constraints)
|
|
||||||
return [r[0] for r in rows]
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
def get_address_records(self, **constraints) -> defer.Deferred:
|
def get_address_records(self, **constraints):
|
||||||
return self.ledger.db.get_addresses(account=self, **constraints)
|
return self.ledger.db.get_addresses(account=self, **constraints)
|
||||||
|
|
||||||
def get_address_count(self, **constraints) -> defer.Deferred:
|
def get_address_count(self, **constraints):
|
||||||
return self.ledger.db.get_address_count(account=self, **constraints)
|
return self.ledger.db.get_address_count(account=self, **constraints)
|
||||||
|
|
||||||
def get_private_key(self, chain: int, index: int) -> PrivateKey:
|
def get_private_key(self, chain: int, index: int) -> PrivateKey:
|
||||||
|
@ -355,10 +341,9 @@ class BaseAccount:
|
||||||
constraints.update({'height__lte': height, 'height__gt': 0})
|
constraints.update({'height__lte': height, 'height__gt': 0})
|
||||||
return self.ledger.db.get_balance(account=self, **constraints)
|
return self.ledger.db.get_balance(account=self, **constraints)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_max_gap(self):
|
||||||
def get_max_gap(self):
|
change_gap = await self.change.get_max_gap()
|
||||||
change_gap = yield self.change.get_max_gap()
|
receiving_gap = await self.receiving.get_max_gap()
|
||||||
receiving_gap = yield self.receiving.get_max_gap()
|
|
||||||
return {
|
return {
|
||||||
'max_change_gap': change_gap,
|
'max_change_gap': change_gap,
|
||||||
'max_receiving_gap': receiving_gap,
|
'max_receiving_gap': receiving_gap,
|
||||||
|
@ -376,24 +361,23 @@ class BaseAccount:
|
||||||
def get_transaction_count(self, **constraints):
|
def get_transaction_count(self, **constraints):
|
||||||
return self.ledger.db.get_transaction_count(account=self, **constraints)
|
return self.ledger.db.get_transaction_count(account=self, **constraints)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def fund(self, to_account, amount=None, everything=False,
|
||||||
def fund(self, to_account, amount=None, everything=False,
|
|
||||||
outputs=1, broadcast=False, **constraints):
|
outputs=1, broadcast=False, **constraints):
|
||||||
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
|
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
|
||||||
tx_class = self.ledger.transaction_class
|
tx_class = self.ledger.transaction_class
|
||||||
if everything:
|
if everything:
|
||||||
utxos = yield self.get_utxos(**constraints)
|
utxos = await self.get_utxos(**constraints)
|
||||||
yield self.ledger.reserve_outputs(utxos)
|
await self.ledger.reserve_outputs(utxos)
|
||||||
tx = yield tx_class.create(
|
tx = await tx_class.create(
|
||||||
inputs=[tx_class.input_class.spend(txo) for txo in utxos],
|
inputs=[tx_class.input_class.spend(txo) for txo in utxos],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
funding_accounts=[self],
|
funding_accounts=[self],
|
||||||
change_account=to_account
|
change_account=to_account
|
||||||
)
|
)
|
||||||
elif amount > 0:
|
elif amount > 0:
|
||||||
to_address = yield to_account.change.get_or_create_usable_address()
|
to_address = await to_account.change.get_or_create_usable_address()
|
||||||
to_hash160 = to_account.ledger.address_to_hash160(to_address)
|
to_hash160 = to_account.ledger.address_to_hash160(to_address)
|
||||||
tx = yield tx_class.create(
|
tx = await tx_class.create(
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[
|
outputs=[
|
||||||
tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
|
tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
|
||||||
|
@ -406,9 +390,9 @@ class BaseAccount:
|
||||||
raise ValueError('An amount is required.')
|
raise ValueError('An amount is required.')
|
||||||
|
|
||||||
if broadcast:
|
if broadcast:
|
||||||
yield self.ledger.broadcast(tx)
|
await self.ledger.broadcast(tx)
|
||||||
else:
|
else:
|
||||||
yield self.ledger.release_outputs(
|
await self.ledger.release_outputs(
|
||||||
[txi.txo_ref.txo for txi in tx.inputs]
|
[txi.txo_ref.txo for txi in tx.inputs]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Tuple, List, Sequence
|
from typing import Tuple, List
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from twisted.internet import defer
|
import aiosqlite
|
||||||
from twisted.enterprise import adbapi
|
|
||||||
|
|
||||||
from torba.hash import TXRefImmutable
|
from torba.hash import TXRefImmutable
|
||||||
from torba.basetransaction import BaseTransaction
|
from torba.basetransaction import BaseTransaction
|
||||||
|
@ -107,25 +106,21 @@ def row_dict_or_default(rows, fields, default=None):
|
||||||
|
|
||||||
class SQLiteMixin:
|
class SQLiteMixin:
|
||||||
|
|
||||||
CREATE_TABLES_QUERY: Sequence[str] = ()
|
CREATE_TABLES_QUERY: str
|
||||||
|
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self._db_path = path
|
self._db_path = path
|
||||||
self.db: adbapi.ConnectionPool = None
|
self.db: aiosqlite.Connection = None
|
||||||
self.ledger = None
|
self.ledger = None
|
||||||
|
|
||||||
def open(self):
|
async def open(self):
|
||||||
log.info("connecting to database: %s", self._db_path)
|
log.info("connecting to database: %s", self._db_path)
|
||||||
self.db = adbapi.ConnectionPool(
|
self.db = aiosqlite.connect(self._db_path)
|
||||||
'sqlite3', self._db_path, cp_min=1, cp_max=1, check_same_thread=False
|
await self.db.__aenter__()
|
||||||
)
|
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
||||||
return self.db.runInteraction(
|
|
||||||
lambda t: t.executescript(self.CREATE_TABLES_QUERY)
|
|
||||||
)
|
|
||||||
|
|
||||||
def close(self):
|
async def close(self):
|
||||||
self.db.close()
|
await self.db.close()
|
||||||
return defer.succeed(True)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
|
def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
|
||||||
|
@ -247,78 +242,75 @@ class BaseDatabase(SQLiteMixin):
|
||||||
'script': sqlite3.Binary(txo.script.source)
|
'script': sqlite3.Binary(txo.script.source)
|
||||||
}
|
}
|
||||||
|
|
||||||
def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
|
async def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
|
||||||
|
|
||||||
def _steps(t):
|
if save_tx == 'insert':
|
||||||
if save_tx == 'insert':
|
await self.db.execute(*self._insert_sql('tx', {
|
||||||
self.execute(t, *self._insert_sql('tx', {
|
'txid': tx.id,
|
||||||
|
'raw': sqlite3.Binary(tx.raw),
|
||||||
|
'height': tx.height,
|
||||||
|
'position': tx.position,
|
||||||
|
'is_verified': tx.is_verified
|
||||||
|
}))
|
||||||
|
elif save_tx == 'update':
|
||||||
|
await self.db.execute(*self._update_sql("tx", {
|
||||||
|
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
||||||
|
}, 'txid = ?', (tx.id,)))
|
||||||
|
|
||||||
|
existing_txos = [r[0] for r in await self.db.execute_fetchall(*query(
|
||||||
|
"SELECT position FROM txo", txid=tx.id
|
||||||
|
))]
|
||||||
|
|
||||||
|
for txo in tx.outputs:
|
||||||
|
if txo.position in existing_txos:
|
||||||
|
continue
|
||||||
|
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
|
||||||
|
await self.db.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
|
||||||
|
elif txo.script.is_pay_script_hash:
|
||||||
|
# TODO: implement script hash payments
|
||||||
|
print('Database.save_transaction_io: pay script hash is not implemented!')
|
||||||
|
|
||||||
|
# lookup the address associated with each TXI (via its TXO)
|
||||||
|
txoid_to_address = {r[0]: r[1] for r in await self.db.execute_fetchall(*query(
|
||||||
|
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
|
||||||
|
))}
|
||||||
|
|
||||||
|
# list of TXIs that have already been added
|
||||||
|
existing_txis = [r[0] for r in await self.db.execute_fetchall(*query(
|
||||||
|
"SELECT txoid FROM txi", txid=tx.id
|
||||||
|
))]
|
||||||
|
|
||||||
|
for txi in tx.inputs:
|
||||||
|
txoid = txi.txo_ref.id
|
||||||
|
new_txi = txoid not in existing_txis
|
||||||
|
address_matches = txoid_to_address.get(txoid) == address
|
||||||
|
if new_txi and address_matches:
|
||||||
|
await self.db.execute(*self._insert_sql("txi", {
|
||||||
'txid': tx.id,
|
'txid': tx.id,
|
||||||
'raw': sqlite3.Binary(tx.raw),
|
'txoid': txoid,
|
||||||
'height': tx.height,
|
'address': address,
|
||||||
'position': tx.position,
|
|
||||||
'is_verified': tx.is_verified
|
|
||||||
}))
|
}))
|
||||||
elif save_tx == 'update':
|
|
||||||
self.execute(t, *self._update_sql("tx", {
|
|
||||||
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
|
||||||
}, 'txid = ?', (tx.id,)))
|
|
||||||
|
|
||||||
existing_txos = [r[0] for r in self.execute(t, *query(
|
await self._set_address_history(address, history)
|
||||||
"SELECT position FROM txo", txid=tx.id
|
|
||||||
)).fetchall()]
|
|
||||||
|
|
||||||
for txo in tx.outputs:
|
async def reserve_outputs(self, txos, is_reserved=True):
|
||||||
if txo.position in existing_txos:
|
|
||||||
continue
|
|
||||||
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
|
|
||||||
self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
|
|
||||||
elif txo.script.is_pay_script_hash:
|
|
||||||
# TODO: implement script hash payments
|
|
||||||
print('Database.save_transaction_io: pay script hash is not implemented!')
|
|
||||||
|
|
||||||
# lookup the address associated with each TXI (via its TXO)
|
|
||||||
txoid_to_address = {r[0]: r[1] for r in self.execute(t, *query(
|
|
||||||
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
|
|
||||||
)).fetchall()}
|
|
||||||
|
|
||||||
# list of TXIs that have already been added
|
|
||||||
existing_txis = [r[0] for r in self.execute(t, *query(
|
|
||||||
"SELECT txoid FROM txi", txid=tx.id
|
|
||||||
)).fetchall()]
|
|
||||||
|
|
||||||
for txi in tx.inputs:
|
|
||||||
txoid = txi.txo_ref.id
|
|
||||||
new_txi = txoid not in existing_txis
|
|
||||||
address_matches = txoid_to_address.get(txoid) == address
|
|
||||||
if new_txi and address_matches:
|
|
||||||
self.execute(t, *self._insert_sql("txi", {
|
|
||||||
'txid': tx.id,
|
|
||||||
'txoid': txoid,
|
|
||||||
'address': address,
|
|
||||||
}))
|
|
||||||
|
|
||||||
self._set_address_history(t, address, history)
|
|
||||||
|
|
||||||
return self.db.runInteraction(_steps)
|
|
||||||
|
|
||||||
def reserve_outputs(self, txos, is_reserved=True):
|
|
||||||
txoids = [txo.id for txo in txos]
|
txoids = [txo.id for txo in txos]
|
||||||
return self.run_operation(
|
await self.db.execute(
|
||||||
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
|
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
|
||||||
', '.join(['?']*len(txoids))
|
', '.join(['?']*len(txoids))
|
||||||
), [is_reserved]+txoids
|
), [is_reserved]+txoids
|
||||||
)
|
)
|
||||||
|
|
||||||
def release_outputs(self, txos):
|
async def release_outputs(self, txos):
|
||||||
return self.reserve_outputs(txos, is_reserved=False)
|
await self.reserve_outputs(txos, is_reserved=False)
|
||||||
|
|
||||||
def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
|
async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
|
||||||
# TODO:
|
# TODO:
|
||||||
# 1. delete transactions above_height
|
# 1. delete transactions above_height
|
||||||
# 2. update address histories removing deleted TXs
|
# 2. update address histories removing deleted TXs
|
||||||
return defer.succeed(True)
|
return True
|
||||||
|
|
||||||
def select_transactions(self, cols, account=None, **constraints):
|
async def select_transactions(self, cols, account=None, **constraints):
|
||||||
if 'txid' not in constraints and account is not None:
|
if 'txid' not in constraints and account is not None:
|
||||||
constraints['$account'] = account.public_key.address
|
constraints['$account'] = account.public_key.address
|
||||||
constraints['txid__in'] = """
|
constraints['txid__in'] = """
|
||||||
|
@ -328,13 +320,14 @@ class BaseDatabase(SQLiteMixin):
|
||||||
SELECT txi.txid FROM txi
|
SELECT txi.txid FROM txi
|
||||||
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account
|
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account
|
||||||
"""
|
"""
|
||||||
return self.run_query(*query("SELECT {} FROM tx".format(cols), **constraints))
|
return await self.db.execute_fetchall(
|
||||||
|
*query("SELECT {} FROM tx".format(cols), **constraints)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_transactions(self, my_account=None, **constraints):
|
||||||
def get_transactions(self, my_account=None, **constraints):
|
|
||||||
my_account = my_account or constraints.get('account', None)
|
my_account = my_account or constraints.get('account', None)
|
||||||
|
|
||||||
tx_rows = yield self.select_transactions(
|
tx_rows = await self.select_transactions(
|
||||||
'txid, raw, height, position, is_verified',
|
'txid, raw, height, position, is_verified',
|
||||||
order_by=["height DESC", "position DESC"],
|
order_by=["height DESC", "position DESC"],
|
||||||
**constraints
|
**constraints
|
||||||
|
@ -352,7 +345,7 @@ class BaseDatabase(SQLiteMixin):
|
||||||
|
|
||||||
annotated_txos = {
|
annotated_txos = {
|
||||||
txo.id: txo for txo in
|
txo.id: txo for txo in
|
||||||
(yield self.get_txos(
|
(await self.get_txos(
|
||||||
my_account=my_account,
|
my_account=my_account,
|
||||||
txid__in=txids
|
txid__in=txids
|
||||||
))
|
))
|
||||||
|
@ -360,7 +353,7 @@ class BaseDatabase(SQLiteMixin):
|
||||||
|
|
||||||
referenced_txos = {
|
referenced_txos = {
|
||||||
txo.id: txo for txo in
|
txo.id: txo for txo in
|
||||||
(yield self.get_txos(
|
(await self.get_txos(
|
||||||
my_account=my_account,
|
my_account=my_account,
|
||||||
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0]
|
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0]
|
||||||
))
|
))
|
||||||
|
@ -380,33 +373,30 @@ class BaseDatabase(SQLiteMixin):
|
||||||
|
|
||||||
return txs
|
return txs
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_transaction_count(self, **constraints):
|
||||||
def get_transaction_count(self, **constraints):
|
|
||||||
constraints.pop('offset', None)
|
constraints.pop('offset', None)
|
||||||
constraints.pop('limit', None)
|
constraints.pop('limit', None)
|
||||||
constraints.pop('order_by', None)
|
constraints.pop('order_by', None)
|
||||||
count = yield self.select_transactions('count(*)', **constraints)
|
count = await self.select_transactions('count(*)', **constraints)
|
||||||
return count[0][0]
|
return count[0][0]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_transaction(self, **constraints):
|
||||||
def get_transaction(self, **constraints):
|
txs = await self.get_transactions(limit=1, **constraints)
|
||||||
txs = yield self.get_transactions(limit=1, **constraints)
|
|
||||||
if txs:
|
if txs:
|
||||||
return txs[0]
|
return txs[0]
|
||||||
|
|
||||||
def select_txos(self, cols, **constraints):
|
async def select_txos(self, cols, **constraints):
|
||||||
return self.run_query(*query(
|
return await self.db.execute_fetchall(*query(
|
||||||
"SELECT {} FROM txo"
|
"SELECT {} FROM txo"
|
||||||
" JOIN pubkey_address USING (address)"
|
" JOIN pubkey_address USING (address)"
|
||||||
" JOIN tx USING (txid)".format(cols), **constraints
|
" JOIN tx USING (txid)".format(cols), **constraints
|
||||||
))
|
))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_txos(self, my_account=None, **constraints):
|
||||||
def get_txos(self, my_account=None, **constraints):
|
|
||||||
my_account = my_account or constraints.get('account', None)
|
my_account = my_account or constraints.get('account', None)
|
||||||
if isinstance(my_account, BaseAccount):
|
if isinstance(my_account, BaseAccount):
|
||||||
my_account = my_account.public_key.address
|
my_account = my_account.public_key.address
|
||||||
rows = yield self.select_txos(
|
rows = await self.select_txos(
|
||||||
"amount, script, txid, tx.height, txo.position, chain, account", **constraints
|
"amount, script, txid, tx.height, txo.position, chain, account", **constraints
|
||||||
)
|
)
|
||||||
output_class = self.ledger.transaction_class.output_class
|
output_class = self.ledger.transaction_class.output_class
|
||||||
|
@ -421,12 +411,11 @@ class BaseDatabase(SQLiteMixin):
|
||||||
) for row in rows
|
) for row in rows
|
||||||
]
|
]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_txo_count(self, **constraints):
|
||||||
def get_txo_count(self, **constraints):
|
|
||||||
constraints.pop('offset', None)
|
constraints.pop('offset', None)
|
||||||
constraints.pop('limit', None)
|
constraints.pop('limit', None)
|
||||||
constraints.pop('order_by', None)
|
constraints.pop('order_by', None)
|
||||||
count = yield self.select_txos('count(*)', **constraints)
|
count = await self.select_txos('count(*)', **constraints)
|
||||||
return count[0][0]
|
return count[0][0]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -442,37 +431,33 @@ class BaseDatabase(SQLiteMixin):
|
||||||
self.constrain_utxo(constraints)
|
self.constrain_utxo(constraints)
|
||||||
return self.get_txo_count(**constraints)
|
return self.get_txo_count(**constraints)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_balance(self, **constraints):
|
||||||
def get_balance(self, **constraints):
|
|
||||||
self.constrain_utxo(constraints)
|
self.constrain_utxo(constraints)
|
||||||
balance = yield self.select_txos('SUM(amount)', **constraints)
|
balance = await self.select_txos('SUM(amount)', **constraints)
|
||||||
return balance[0][0] or 0
|
return balance[0][0] or 0
|
||||||
|
|
||||||
def select_addresses(self, cols, **constraints):
|
async def select_addresses(self, cols, **constraints):
|
||||||
return self.run_query(*query(
|
return await self.db.execute_fetchall(*query(
|
||||||
"SELECT {} FROM pubkey_address".format(cols), **constraints
|
"SELECT {} FROM pubkey_address".format(cols), **constraints
|
||||||
))
|
))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints):
|
||||||
def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints):
|
addresses = await self.select_addresses(', '.join(cols), **constraints)
|
||||||
addresses = yield self.select_addresses(', '.join(cols), **constraints)
|
|
||||||
return rows_to_dict(addresses, cols)
|
return rows_to_dict(addresses, cols)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_address_count(self, **constraints):
|
||||||
def get_address_count(self, **constraints):
|
count = await self.select_addresses('count(*)', **constraints)
|
||||||
count = yield self.select_addresses('count(*)', **constraints)
|
|
||||||
return count[0][0]
|
return count[0][0]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_address(self, **constraints):
|
||||||
def get_address(self, **constraints):
|
addresses = await self.get_addresses(
|
||||||
addresses = yield self.get_addresses(
|
|
||||||
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
|
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
|
||||||
limit=1, **constraints
|
limit=1, **constraints
|
||||||
)
|
)
|
||||||
if addresses:
|
if addresses:
|
||||||
return addresses[0]
|
return addresses[0]
|
||||||
|
|
||||||
def add_keys(self, account, chain, keys):
|
async def add_keys(self, account, chain, keys):
|
||||||
sql = (
|
sql = (
|
||||||
"insert into pubkey_address "
|
"insert into pubkey_address "
|
||||||
"(address, account, chain, position, pubkey) "
|
"(address, account, chain, position, pubkey) "
|
||||||
|
@ -484,14 +469,13 @@ class BaseDatabase(SQLiteMixin):
|
||||||
pubkey.address, account.public_key.address, chain, position,
|
pubkey.address, account.public_key.address, chain, position,
|
||||||
sqlite3.Binary(pubkey.pubkey_bytes)
|
sqlite3.Binary(pubkey.pubkey_bytes)
|
||||||
))
|
))
|
||||||
return self.run_operation(sql, values)
|
await self.db.execute(sql, values)
|
||||||
|
|
||||||
@classmethod
|
async def _set_address_history(self, address, history):
|
||||||
def _set_address_history(cls, t, address, history):
|
await self.db.execute(
|
||||||
cls.execute(
|
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||||
t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
|
||||||
(history, history.count(':')//2, address)
|
(history, history.count(':')//2, address)
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_address_history(self, address, history):
|
async def set_address_history(self, address, history):
|
||||||
return self.db.runInteraction(lambda t: self._set_address_history(t, address, history))
|
await self._set_address_history(address, history)
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional, Iterator, Tuple
|
from typing import Optional, Iterator, Tuple
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
|
|
||||||
from twisted.internet import threads, defer
|
|
||||||
|
|
||||||
from torba.util import ArithUint256
|
from torba.util import ArithUint256
|
||||||
from torba.hash import double_sha256
|
from torba.hash import double_sha256
|
||||||
|
|
||||||
|
@ -36,16 +35,14 @@ class BaseHeaders:
|
||||||
self.io = BytesIO()
|
self.io = BytesIO()
|
||||||
self.path = path
|
self.path = path
|
||||||
self._size: Optional[int] = None
|
self._size: Optional[int] = None
|
||||||
self._header_connect_lock = defer.DeferredLock()
|
self._header_connect_lock = asyncio.Lock()
|
||||||
|
|
||||||
def open(self):
|
async def open(self):
|
||||||
if self.path != ':memory:':
|
if self.path != ':memory:':
|
||||||
self.io = open(self.path, 'a+b')
|
self.io = open(self.path, 'a+b')
|
||||||
return defer.succeed(True)
|
|
||||||
|
|
||||||
def close(self):
|
async def close(self):
|
||||||
self.io.close()
|
self.io.close()
|
||||||
return defer.succeed(True)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def serialize(header: dict) -> bytes:
|
def serialize(header: dict) -> bytes:
|
||||||
|
@ -95,16 +92,15 @@ class BaseHeaders:
|
||||||
return b'0' * 64
|
return b'0' * 64
|
||||||
return hexlify(double_sha256(header)[::-1])
|
return hexlify(double_sha256(header)[::-1])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def connect(self, start: int, headers: bytes) -> int:
|
||||||
def connect(self, start: int, headers: bytes):
|
|
||||||
added = 0
|
added = 0
|
||||||
bail = False
|
bail = False
|
||||||
yield self._header_connect_lock.acquire()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
async with self._header_connect_lock:
|
||||||
for height, chunk in self._iterate_chunks(start, headers):
|
for height, chunk in self._iterate_chunks(start, headers):
|
||||||
try:
|
try:
|
||||||
# validate_chunk() is CPU bound and reads previous chunks from file system
|
# validate_chunk() is CPU bound and reads previous chunks from file system
|
||||||
yield threads.deferToThread(self.validate_chunk, height, chunk)
|
await loop.run_in_executor(None, self.validate_chunk, height, chunk)
|
||||||
except InvalidHeader as e:
|
except InvalidHeader as e:
|
||||||
bail = True
|
bail = True
|
||||||
chunk = chunk[:(height-e.height+1)*self.header_size]
|
chunk = chunk[:(height-e.height+1)*self.header_size]
|
||||||
|
@ -115,14 +111,12 @@ class BaseHeaders:
|
||||||
self.io.truncate()
|
self.io.truncate()
|
||||||
# .seek()/.write()/.truncate() might also .flush() when needed
|
# .seek()/.write()/.truncate() might also .flush() when needed
|
||||||
# the goal here is mainly to ensure we're definitely flush()'ing
|
# the goal here is mainly to ensure we're definitely flush()'ing
|
||||||
yield threads.deferToThread(self.io.flush)
|
await loop.run_in_executor(None, self.io.flush)
|
||||||
self._size = None
|
self._size = None
|
||||||
added += written
|
added += written
|
||||||
if bail:
|
if bail:
|
||||||
break
|
break
|
||||||
finally:
|
return added
|
||||||
self._header_connect_lock.release()
|
|
||||||
defer.returnValue(added)
|
|
||||||
|
|
||||||
def validate_chunk(self, height, chunk):
|
def validate_chunk(self, height, chunk):
|
||||||
previous_hash, previous_header, previous_previous_header = None, None, None
|
previous_hash, previous_header, previous_previous_header = None, None, None
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
@ -7,8 +8,6 @@ from typing import Dict, Type, Iterable
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from torba import baseaccount
|
from torba import baseaccount
|
||||||
from torba import basenetwork
|
from torba import basenetwork
|
||||||
from torba import basetransaction
|
from torba import basetransaction
|
||||||
|
@ -104,8 +103,8 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
)
|
)
|
||||||
|
|
||||||
self._transaction_processing_locks = {}
|
self._transaction_processing_locks = {}
|
||||||
self._utxo_reservation_lock = defer.DeferredLock()
|
self._utxo_reservation_lock = asyncio.Lock()
|
||||||
self._header_processing_lock = defer.DeferredLock()
|
self._header_processing_lock = asyncio.Lock()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_id(cls):
|
def get_id(cls):
|
||||||
|
@ -135,41 +134,32 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
def add_account(self, account: baseaccount.BaseAccount):
|
def add_account(self, account: baseaccount.BaseAccount):
|
||||||
self.accounts.append(account)
|
self.accounts.append(account)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_private_key_for_address(self, address):
|
||||||
def get_private_key_for_address(self, address):
|
match = await self.db.get_address(address=address)
|
||||||
match = yield self.db.get_address(address=address)
|
|
||||||
if match:
|
if match:
|
||||||
for account in self.accounts:
|
for account in self.accounts:
|
||||||
if match['account'] == account.public_key.address:
|
if match['account'] == account.public_key.address:
|
||||||
return account.get_private_key(match['chain'], match['position'])
|
return account.get_private_key(match['chain'], match['position'])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
|
||||||
def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
|
|
||||||
estimators = []
|
estimators = []
|
||||||
for account in funding_accounts:
|
for account in funding_accounts:
|
||||||
utxos = yield account.get_utxos()
|
utxos = await account.get_utxos()
|
||||||
for utxo in utxos:
|
for utxo in utxos:
|
||||||
estimators.append(utxo.get_estimator(self))
|
estimators.append(utxo.get_estimator(self))
|
||||||
return estimators
|
return estimators
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_spendable_utxos(self, amount: int, funding_accounts):
|
||||||
def get_spendable_utxos(self, amount: int, funding_accounts):
|
async with self._utxo_reservation_lock:
|
||||||
yield self._utxo_reservation_lock.acquire()
|
txos = await self.get_effective_amount_estimators(funding_accounts)
|
||||||
try:
|
|
||||||
txos = yield self.get_effective_amount_estimators(funding_accounts)
|
|
||||||
selector = CoinSelector(
|
selector = CoinSelector(
|
||||||
txos, amount,
|
txos, amount,
|
||||||
self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
|
self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
|
||||||
)
|
)
|
||||||
spendables = selector.select()
|
spendables = selector.select()
|
||||||
if spendables:
|
if spendables:
|
||||||
yield self.reserve_outputs(s.txo for s in spendables)
|
await self.reserve_outputs(s.txo for s in spendables)
|
||||||
except Exception:
|
return spendables
|
||||||
log.exception('Failed to get spendable utxos:')
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._utxo_reservation_lock.release()
|
|
||||||
return spendables
|
|
||||||
|
|
||||||
def reserve_outputs(self, txos):
|
def reserve_outputs(self, txos):
|
||||||
return self.db.reserve_outputs(txos)
|
return self.db.reserve_outputs(txos)
|
||||||
|
@ -177,16 +167,14 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
def release_outputs(self, txos):
|
def release_outputs(self, txos):
|
||||||
return self.db.release_outputs(txos)
|
return self.db.release_outputs(txos)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_local_status(self, address):
|
||||||
def get_local_status(self, address):
|
address_details = await self.db.get_address(address=address)
|
||||||
address_details = yield self.db.get_address(address=address)
|
|
||||||
history = address_details['history'] or ''
|
history = address_details['history'] or ''
|
||||||
h = sha256(history.encode())
|
h = sha256(history.encode())
|
||||||
return hexlify(h)
|
return hexlify(h)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_local_history(self, address):
|
||||||
def get_local_history(self, address):
|
address_details = await self.db.get_address(address=address)
|
||||||
address_details = yield self.db.get_address(address=address)
|
|
||||||
history = address_details['history'] or ''
|
history = address_details['history'] or ''
|
||||||
parts = history.split(':')[:-1]
|
parts = history.split(':')[:-1]
|
||||||
return list(zip(parts[0::2], map(int, parts[1::2])))
|
return list(zip(parts[0::2], map(int, parts[1::2])))
|
||||||
|
@ -203,43 +191,40 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
working_branch = double_sha256(combined)
|
working_branch = double_sha256(combined)
|
||||||
return hexlify(working_branch[::-1])
|
return hexlify(working_branch[::-1])
|
||||||
|
|
||||||
def validate_transaction_and_set_position(self, tx, height, merkle):
|
async def validate_transaction_and_set_position(self, tx, height):
|
||||||
if not height <= len(self.headers):
|
if not height <= len(self.headers):
|
||||||
return False
|
return False
|
||||||
|
merkle = await self.network.get_merkle(tx.id, height)
|
||||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||||
header = self.headers[height]
|
header = self.headers[height]
|
||||||
tx.position = merkle['pos']
|
tx.position = merkle['pos']
|
||||||
tx.is_verified = merkle_root == header['merkle_root']
|
tx.is_verified = merkle_root == header['merkle_root']
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def start(self):
|
||||||
def start(self):
|
|
||||||
if not os.path.exists(self.path):
|
if not os.path.exists(self.path):
|
||||||
os.mkdir(self.path)
|
os.mkdir(self.path)
|
||||||
yield defer.gatherResults([
|
await asyncio.gather(
|
||||||
self.db.open(),
|
self.db.open(),
|
||||||
self.headers.open()
|
self.headers.open()
|
||||||
])
|
)
|
||||||
first_connection = self.network.on_connected.first
|
first_connection = self.network.on_connected.first
|
||||||
self.network.start()
|
asyncio.ensure_future(self.network.start())
|
||||||
yield first_connection
|
await first_connection
|
||||||
yield self.join_network()
|
await self.join_network()
|
||||||
self.network.on_connected.listen(self.join_network)
|
self.network.on_connected.listen(self.join_network)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def join_network(self, *args):
|
||||||
def join_network(self, *args):
|
|
||||||
log.info("Subscribing and updating accounts.")
|
log.info("Subscribing and updating accounts.")
|
||||||
yield self.update_headers()
|
await self.update_headers()
|
||||||
yield self.network.subscribe_headers()
|
await self.network.subscribe_headers()
|
||||||
yield self.update_accounts()
|
await self.update_accounts()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def stop(self):
|
||||||
def stop(self):
|
await self.network.stop()
|
||||||
yield self.network.stop()
|
await self.db.close()
|
||||||
yield self.db.close()
|
await self.headers.close()
|
||||||
yield self.headers.close()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||||
def update_headers(self, height=None, headers=None, subscription_update=False):
|
|
||||||
rewound = 0
|
rewound = 0
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
|
@ -251,14 +236,14 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
subscription_update = False
|
subscription_update = False
|
||||||
|
|
||||||
if not headers:
|
if not headers:
|
||||||
header_response = yield self.network.get_headers(height, 2001)
|
header_response = await self.network.get_headers(height, 2001)
|
||||||
headers = header_response['hex']
|
headers = header_response['hex']
|
||||||
|
|
||||||
if not headers:
|
if not headers:
|
||||||
# Nothing to do, network thinks we're already at the latest height.
|
# Nothing to do, network thinks we're already at the latest height.
|
||||||
return
|
return
|
||||||
|
|
||||||
added = yield self.headers.connect(height, unhexlify(headers))
|
added = await self.headers.connect(height, unhexlify(headers))
|
||||||
if added > 0:
|
if added > 0:
|
||||||
height += added
|
height += added
|
||||||
self._on_header_controller.add(
|
self._on_header_controller.add(
|
||||||
|
@ -268,7 +253,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
# we started rewinding blocks and apparently found
|
# we started rewinding blocks and apparently found
|
||||||
# a new chain
|
# a new chain
|
||||||
rewound = 0
|
rewound = 0
|
||||||
yield self.db.rewind_blockchain(height)
|
await self.db.rewind_blockchain(height)
|
||||||
|
|
||||||
if subscription_update:
|
if subscription_update:
|
||||||
# subscription updates are for latest header already
|
# subscription updates are for latest header already
|
||||||
|
@ -310,66 +295,37 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
# robust sync, turn off subscription update shortcut
|
# robust sync, turn off subscription update shortcut
|
||||||
subscription_update = False
|
subscription_update = False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def receive_header(self, response):
|
||||||
def receive_header(self, response):
|
async with self._header_processing_lock:
|
||||||
yield self._header_processing_lock.acquire()
|
|
||||||
try:
|
|
||||||
header = response[0]
|
header = response[0]
|
||||||
yield self.update_headers(
|
await self.update_headers(
|
||||||
height=header['height'], headers=header['hex'], subscription_update=True
|
height=header['height'], headers=header['hex'], subscription_update=True
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
self._header_processing_lock.release()
|
|
||||||
|
|
||||||
def update_accounts(self):
|
async def update_accounts(self):
|
||||||
return defer.DeferredList([
|
return await asyncio.gather(*(
|
||||||
self.update_account(a) for a in self.accounts
|
self.update_account(a) for a in self.accounts
|
||||||
])
|
))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def update_account(self, account: baseaccount.BaseAccount):
|
||||||
def update_account(self, account): # type: (baseaccount.BaseAccount) -> defer.Defferred
|
|
||||||
# Before subscribing, download history for any addresses that don't have any,
|
# Before subscribing, download history for any addresses that don't have any,
|
||||||
# this avoids situation where we're getting status updates to addresses we know
|
# this avoids situation where we're getting status updates to addresses we know
|
||||||
# need to update anyways. Continue to get history and create more addresses until
|
# need to update anyways. Continue to get history and create more addresses until
|
||||||
# all missing addresses are created and history for them is fully restored.
|
# all missing addresses are created and history for them is fully restored.
|
||||||
yield account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
addresses = yield account.get_addresses(used_times=0)
|
addresses = await account.get_addresses(used_times=0)
|
||||||
while addresses:
|
while addresses:
|
||||||
yield defer.DeferredList([
|
await asyncio.gather(*(self.update_history(a) for a in addresses))
|
||||||
self.update_history(a) for a in addresses
|
addresses = await account.ensure_address_gap()
|
||||||
])
|
|
||||||
addresses = yield account.ensure_address_gap()
|
|
||||||
|
|
||||||
# By this point all of the addresses should be restored and we
|
# By this point all of the addresses should be restored and we
|
||||||
# can now subscribe all of them to receive updates.
|
# can now subscribe all of them to receive updates.
|
||||||
all_addresses = yield account.get_addresses()
|
all_addresses = await account.get_addresses()
|
||||||
yield defer.DeferredList(
|
await asyncio.gather(*(self.subscribe_history(a) for a in all_addresses))
|
||||||
list(map(self.subscribe_history, all_addresses))
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def update_history(self, address):
|
||||||
def _prefetch_history(self, remote_history, local_history):
|
remote_history = await self.network.get_history(address)
|
||||||
proofs, network_txs, deferreds = {}, {}, []
|
local_history = await self.get_local_history(address)
|
||||||
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
|
||||||
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
|
|
||||||
continue
|
|
||||||
if remote_height > 0:
|
|
||||||
deferreds.append(
|
|
||||||
self.network.get_merkle(hex_id, remote_height).addBoth(
|
|
||||||
lambda result, txid: proofs.__setitem__(txid, result), hex_id)
|
|
||||||
)
|
|
||||||
deferreds.append(
|
|
||||||
self.network.get_transaction(hex_id).addBoth(
|
|
||||||
lambda result, txid: network_txs.__setitem__(txid, result), hex_id)
|
|
||||||
)
|
|
||||||
yield defer.DeferredList(deferreds)
|
|
||||||
return proofs, network_txs
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def update_history(self, address):
|
|
||||||
remote_history = yield self.network.get_history(address)
|
|
||||||
local_history = yield self.get_local_history(address)
|
|
||||||
proofs, network_txs = yield self._prefetch_history(remote_history, local_history)
|
|
||||||
|
|
||||||
synced_history = StringIO()
|
synced_history = StringIO()
|
||||||
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
||||||
|
@ -379,30 +335,29 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
|
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())
|
lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock())
|
||||||
|
|
||||||
yield lock.acquire()
|
await lock.acquire()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# see if we have a local copy of transaction, otherwise fetch it from server
|
# see if we have a local copy of transaction, otherwise fetch it from server
|
||||||
tx = yield self.db.get_transaction(txid=hex_id)
|
tx = await self.db.get_transaction(txid=hex_id)
|
||||||
save_tx = None
|
save_tx = None
|
||||||
if tx is None:
|
if tx is None:
|
||||||
_raw = network_txs[hex_id]
|
_raw = await self.network.get_transaction(hex_id)
|
||||||
tx = self.transaction_class(unhexlify(_raw))
|
tx = self.transaction_class(unhexlify(_raw))
|
||||||
save_tx = 'insert'
|
save_tx = 'insert'
|
||||||
|
|
||||||
tx.height = remote_height
|
tx.height = remote_height
|
||||||
|
|
||||||
if remote_height > 0 and (not tx.is_verified or tx.position == -1):
|
if remote_height > 0 and (not tx.is_verified or tx.position == -1):
|
||||||
self.validate_transaction_and_set_position(tx, remote_height, proofs[hex_id])
|
await self.validate_transaction_and_set_position(tx, remote_height)
|
||||||
if save_tx is None:
|
if save_tx is None:
|
||||||
save_tx = 'update'
|
save_tx = 'update'
|
||||||
|
|
||||||
yield self.db.save_transaction_io(
|
await self.db.save_transaction_io(
|
||||||
save_tx, tx, address, self.address_to_hash160(address),
|
save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue()
|
||||||
synced_history.getvalue()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
log.debug(
|
log.debug(
|
||||||
|
@ -412,28 +367,22 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
self._on_transaction_controller.add(TransactionEvent(address, tx))
|
self._on_transaction_controller.add(TransactionEvent(address, tx))
|
||||||
|
|
||||||
except Exception:
|
|
||||||
log.exception('Failed to synchronize transaction:')
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
lock.release()
|
lock.release()
|
||||||
if not lock.locked and hex_id in self._transaction_processing_locks:
|
if not lock.locked() and hex_id in self._transaction_processing_locks:
|
||||||
del self._transaction_processing_locks[hex_id]
|
del self._transaction_processing_locks[hex_id]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def subscribe_history(self, address):
|
||||||
def subscribe_history(self, address):
|
remote_status = await self.network.subscribe_address(address)
|
||||||
remote_status = yield self.network.subscribe_address(address)
|
local_status = await self.get_local_status(address)
|
||||||
local_status = yield self.get_local_status(address)
|
|
||||||
if local_status != remote_status:
|
if local_status != remote_status:
|
||||||
yield self.update_history(address)
|
await self.update_history(address)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def receive_status(self, response):
|
||||||
def receive_status(self, response):
|
|
||||||
address, remote_status = response
|
address, remote_status = response
|
||||||
local_status = yield self.get_local_status(address)
|
local_status = await self.get_local_status(address)
|
||||||
if local_status != remote_status:
|
if local_status != remote_status:
|
||||||
yield self.update_history(address)
|
await self.update_history(address)
|
||||||
|
|
||||||
def broadcast(self, tx):
|
def broadcast(self, tx):
|
||||||
return self.network.broadcast(hexlify(tx.raw).decode())
|
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Type, MutableSequence, MutableMapping
|
from typing import Type, MutableSequence, MutableMapping
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from torba.baseledger import BaseLedger, LedgerRegistry
|
from torba.baseledger import BaseLedger, LedgerRegistry
|
||||||
from torba.wallet import Wallet, WalletStorage
|
from torba.wallet import Wallet, WalletStorage
|
||||||
|
@ -41,11 +41,10 @@ class BaseWalletManager:
|
||||||
self.wallets.append(wallet)
|
self.wallets.append(wallet)
|
||||||
return wallet
|
return wallet
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_detailed_accounts(self, confirmations=6, show_seed=False):
|
||||||
def get_detailed_accounts(self, confirmations=6, show_seed=False):
|
|
||||||
ledgers = {}
|
ledgers = {}
|
||||||
for i, account in enumerate(self.accounts):
|
for i, account in enumerate(self.accounts):
|
||||||
details = yield account.get_details(confirmations=confirmations, show_seed=True)
|
details = await account.get_details(confirmations=confirmations, show_seed=True)
|
||||||
details['is_default_account'] = i == 0
|
details['is_default_account'] = i == 0
|
||||||
ledger_id = account.ledger.get_id()
|
ledger_id = account.ledger.get_id()
|
||||||
ledgers.setdefault(ledger_id, [])
|
ledgers.setdefault(ledger_id, [])
|
||||||
|
@ -68,16 +67,14 @@ class BaseWalletManager:
|
||||||
for account in wallet.accounts:
|
for account in wallet.accounts:
|
||||||
yield account
|
yield account
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def start(self):
|
||||||
def start(self):
|
|
||||||
self.running = True
|
self.running = True
|
||||||
yield defer.DeferredList([
|
await asyncio.gather(*(
|
||||||
l.start() for l in self.ledgers.values()
|
l.start() for l in self.ledgers.values()
|
||||||
])
|
))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def stop(self):
|
||||||
def stop(self):
|
await asyncio.gather(*(
|
||||||
yield defer.DeferredList([
|
|
||||||
l.stop() for l in self.ledgers.values()
|
l.stop() for l in self.ledgers.values()
|
||||||
])
|
))
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
|
@ -1,145 +1,36 @@
|
||||||
import json
|
|
||||||
import socket
|
|
||||||
import logging
|
import logging
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from twisted.internet import defer, reactor, protocol
|
|
||||||
from twisted.application.internet import ClientService, CancelledError
|
from aiorpcx import ClientSession as BaseClientSession
|
||||||
from twisted.internet.endpoints import clientFromString
|
|
||||||
from twisted.protocols.basic import LineOnlyReceiver
|
|
||||||
from twisted.python import failure
|
|
||||||
|
|
||||||
from torba import __version__
|
from torba import __version__
|
||||||
from torba.stream import StreamController
|
from torba.stream import StreamController
|
||||||
from torba.constants import TIMEOUT
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StratumClientProtocol(LineOnlyReceiver):
|
class ClientSession(BaseClientSession):
|
||||||
delimiter = b'\n'
|
|
||||||
MAX_LENGTH = 2000000
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, *args, network, **kwargs):
|
||||||
self.request_id = 0
|
|
||||||
self.lookup_table = {}
|
|
||||||
self.session = {}
|
|
||||||
self.network = None
|
|
||||||
|
|
||||||
self.on_disconnected_controller = StreamController()
|
|
||||||
self.on_disconnected = self.on_disconnected_controller.stream
|
|
||||||
|
|
||||||
def _get_id(self):
|
|
||||||
self.request_id += 1
|
|
||||||
return self.request_id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _ip(self):
|
|
||||||
return self.transport.getPeer().host
|
|
||||||
|
|
||||||
def get_session(self):
|
|
||||||
return self.session
|
|
||||||
|
|
||||||
def connectionMade(self):
|
|
||||||
try:
|
|
||||||
self.transport.setTcpNoDelay(True)
|
|
||||||
self.transport.setTcpKeepAlive(True)
|
|
||||||
if hasattr(socket, "TCP_KEEPIDLE"):
|
|
||||||
self.transport.socket.setsockopt(
|
|
||||||
socket.SOL_TCP, socket.TCP_KEEPIDLE, 120
|
|
||||||
# Seconds before sending keepalive probes
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log.debug("TCP_KEEPIDLE not available")
|
|
||||||
if hasattr(socket, "TCP_KEEPINTVL"):
|
|
||||||
self.transport.socket.setsockopt(
|
|
||||||
socket.SOL_TCP, socket.TCP_KEEPINTVL, 1
|
|
||||||
# Interval in seconds between keepalive probes
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log.debug("TCP_KEEPINTVL not available")
|
|
||||||
if hasattr(socket, "TCP_KEEPCNT"):
|
|
||||||
self.transport.socket.setsockopt(
|
|
||||||
socket.SOL_TCP, socket.TCP_KEEPCNT, 5
|
|
||||||
# Failed keepalive probles before declaring other end dead
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log.debug("TCP_KEEPCNT not available")
|
|
||||||
|
|
||||||
except Exception as err: # pylint: disable=broad-except
|
|
||||||
# Supported only by the socket transport,
|
|
||||||
# but there's really no better place in code to trigger this.
|
|
||||||
log.warning("Error setting up socket: %s", err)
|
|
||||||
|
|
||||||
def connectionLost(self, reason=None):
|
|
||||||
self.connected = 0
|
|
||||||
self.on_disconnected_controller.add(True)
|
|
||||||
for deferred in self.lookup_table.values():
|
|
||||||
if not deferred.called:
|
|
||||||
deferred.errback(TimeoutError("Connection dropped."))
|
|
||||||
|
|
||||||
def lineReceived(self, line):
|
|
||||||
log.debug('received: %s', line)
|
|
||||||
|
|
||||||
try:
|
|
||||||
message = json.loads(line)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
raise ValueError("Cannot decode message '{}'".format(line.strip()))
|
|
||||||
|
|
||||||
if message.get('id'):
|
|
||||||
try:
|
|
||||||
d = self.lookup_table.pop(message['id'])
|
|
||||||
if message.get('error'):
|
|
||||||
d.errback(RuntimeError(message['error']))
|
|
||||||
else:
|
|
||||||
d.callback(message.get('result'))
|
|
||||||
except KeyError:
|
|
||||||
raise LookupError(
|
|
||||||
"Lookup for deferred object for message ID '{}' failed.".format(message['id']))
|
|
||||||
elif message.get('method') in self.network.subscription_controllers:
|
|
||||||
controller = self.network.subscription_controllers[message['method']]
|
|
||||||
controller.add(message.get('params'))
|
|
||||||
else:
|
|
||||||
log.warning("Cannot handle message '%s'", line)
|
|
||||||
|
|
||||||
def rpc(self, method, *args):
|
|
||||||
message_id = self._get_id()
|
|
||||||
message = json.dumps({
|
|
||||||
'id': message_id,
|
|
||||||
'method': method,
|
|
||||||
'params': args
|
|
||||||
})
|
|
||||||
log.debug('sent: %s', message)
|
|
||||||
self.sendLine(message.encode('latin-1'))
|
|
||||||
d = self.lookup_table[message_id] = defer.Deferred()
|
|
||||||
d.addTimeout(
|
|
||||||
TIMEOUT, reactor, onTimeoutCancel=lambda *_: failure.Failure(TimeoutError(
|
|
||||||
"Timeout: Stratum request for '%s' took more than %s seconds" % (method, TIMEOUT)))
|
|
||||||
)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
class StratumClientFactory(protocol.ClientFactory):
|
|
||||||
|
|
||||||
protocol = StratumClientProtocol
|
|
||||||
|
|
||||||
def __init__(self, network):
|
|
||||||
self.network = network
|
self.network = network
|
||||||
self.client = None
|
super().__init__(*args, **kwargs)
|
||||||
|
self._on_disconnect_controller = StreamController()
|
||||||
|
self.on_disconnected = self._on_disconnect_controller.stream
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
async def handle_request(self, request):
|
||||||
client = self.protocol()
|
controller = self.network.subscription_controllers[request.method]
|
||||||
client.factory = self
|
controller.add(request.args)
|
||||||
client.network = self.network
|
|
||||||
self.client = client
|
def connection_lost(self, exc):
|
||||||
return client
|
super().connection_lost(exc)
|
||||||
|
self._on_disconnect_controller.add(True)
|
||||||
|
|
||||||
|
|
||||||
class BaseNetwork:
|
class BaseNetwork:
|
||||||
|
|
||||||
def __init__(self, ledger):
|
def __init__(self, ledger):
|
||||||
self.config = ledger.config
|
self.config = ledger.config
|
||||||
self.client = None
|
self.client: ClientSession = None
|
||||||
self.service = None
|
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
self._on_connected_controller = StreamController()
|
self._on_connected_controller = StreamController()
|
||||||
|
@ -156,48 +47,35 @@ class BaseNetwork:
|
||||||
'blockchain.address.subscribe': self._on_status_controller,
|
'blockchain.address.subscribe': self._on_status_controller,
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def start(self):
|
||||||
def start(self):
|
|
||||||
self.running = True
|
self.running = True
|
||||||
for server in cycle(self.config['default_servers']):
|
for server in cycle(self.config['default_servers']):
|
||||||
connection_string = 'tcp:{}:{}'.format(*server)
|
connection_string = 'tcp:{}:{}'.format(*server)
|
||||||
endpoint = clientFromString(reactor, connection_string)
|
self.client = ClientSession(*server, network=self)
|
||||||
log.debug("Attempting connection to SPV wallet server: %s", connection_string)
|
|
||||||
self.service = ClientService(endpoint, StratumClientFactory(self))
|
|
||||||
self.service.startService()
|
|
||||||
try:
|
try:
|
||||||
self.client = yield self.service.whenConnected(failAfterFailures=2)
|
await self.client.create_connection()
|
||||||
yield self.ensure_server_version()
|
await self.ensure_server_version()
|
||||||
log.info("Successfully connected to SPV wallet server: %s", connection_string)
|
log.info("Successfully connected to SPV wallet server: %s", )
|
||||||
self._on_connected_controller.add(True)
|
self._on_connected_controller.add(True)
|
||||||
yield self.client.on_disconnected.first
|
await self.client.on_disconnected.first
|
||||||
except CancelledError:
|
|
||||||
return
|
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
log.exception("Connecting to %s raised an exception:", connection_string)
|
log.exception("Connecting to %s raised an exception:", connection_string)
|
||||||
finally:
|
|
||||||
self.client = None
|
|
||||||
if self.service is not None:
|
|
||||||
self.service.stopService()
|
|
||||||
if not self.running:
|
if not self.running:
|
||||||
return
|
return
|
||||||
|
|
||||||
def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
if self.service is not None:
|
|
||||||
self.service.stopService()
|
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
return self.client.on_disconnected.first
|
await self.client.close()
|
||||||
else:
|
await self.client.on_disconnected.first
|
||||||
return defer.succeed(True)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self.client is not None and self.client.connected
|
return self.client is not None and not self.client.is_closing()
|
||||||
|
|
||||||
def rpc(self, list_or_method, *args):
|
def rpc(self, list_or_method, *args):
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
return self.client.rpc(list_or_method, *args)
|
return self.client.send_request(list_or_method, args)
|
||||||
else:
|
else:
|
||||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,6 @@ import typing
|
||||||
from typing import List, Iterable, Optional
|
from typing import List, Iterable, Optional
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from torba.basescript import BaseInputScript, BaseOutputScript
|
from torba.basescript import BaseInputScript, BaseOutputScript
|
||||||
from torba.baseaccount import BaseAccount
|
from torba.baseaccount import BaseAccount
|
||||||
from torba.constants import COIN, NULL_HASH32
|
from torba.constants import COIN, NULL_HASH32
|
||||||
|
@ -426,8 +424,7 @@ class BaseTransaction:
|
||||||
return ledger
|
return ledger
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@defer.inlineCallbacks
|
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
|
||||||
def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
|
|
||||||
funding_accounts: Iterable[BaseAccount], change_account: BaseAccount):
|
funding_accounts: Iterable[BaseAccount], change_account: BaseAccount):
|
||||||
""" Find optimal set of inputs when only outputs are provided; add change
|
""" Find optimal set of inputs when only outputs are provided; add change
|
||||||
outputs if only inputs are provided or if inputs are greater than outputs. """
|
outputs if only inputs are provided or if inputs are greater than outputs. """
|
||||||
|
@ -450,7 +447,7 @@ class BaseTransaction:
|
||||||
|
|
||||||
if payment < cost:
|
if payment < cost:
|
||||||
deficit = cost - payment
|
deficit = cost - payment
|
||||||
spendables = yield ledger.get_spendable_utxos(deficit, funding_accounts)
|
spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
|
||||||
if not spendables:
|
if not spendables:
|
||||||
raise ValueError('Not enough funds to cover this transaction.')
|
raise ValueError('Not enough funds to cover this transaction.')
|
||||||
payment += sum(s.effective_amount for s in spendables)
|
payment += sum(s.effective_amount for s in spendables)
|
||||||
|
@ -463,28 +460,27 @@ class BaseTransaction:
|
||||||
)
|
)
|
||||||
change = payment - cost
|
change = payment - cost
|
||||||
if change > cost_of_change:
|
if change > cost_of_change:
|
||||||
change_address = yield change_account.change.get_or_create_usable_address()
|
change_address = await change_account.change.get_or_create_usable_address()
|
||||||
change_hash160 = change_account.ledger.address_to_hash160(change_address)
|
change_hash160 = change_account.ledger.address_to_hash160(change_address)
|
||||||
change_amount = change - cost_of_change
|
change_amount = change - cost_of_change
|
||||||
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
|
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
|
||||||
change_output.is_change = True
|
change_output.is_change = True
|
||||||
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
|
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
|
||||||
|
|
||||||
yield tx.sign(funding_accounts)
|
await tx.sign(funding_accounts)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception('Failed to synchronize transaction:')
|
log.exception('Failed to synchronize transaction:')
|
||||||
yield ledger.release_outputs(tx.outputs)
|
await ledger.release_outputs(tx.outputs)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
defer.returnValue(tx)
|
return tx
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def signature_hash_type(hash_type):
|
def signature_hash_type(hash_type):
|
||||||
return hash_type
|
return hash_type
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def sign(self, funding_accounts: Iterable[BaseAccount]):
|
||||||
def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred:
|
|
||||||
ledger = self.ensure_all_have_same_ledger(funding_accounts)
|
ledger = self.ensure_all_have_same_ledger(funding_accounts)
|
||||||
for i, txi in enumerate(self._inputs):
|
for i, txi in enumerate(self._inputs):
|
||||||
assert txi.script is not None
|
assert txi.script is not None
|
||||||
|
@ -492,7 +488,7 @@ class BaseTransaction:
|
||||||
txo_script = txi.txo_ref.txo.script
|
txo_script = txi.txo_ref.txo.script
|
||||||
if txo_script.is_pay_pubkey_hash:
|
if txo_script.is_pay_pubkey_hash:
|
||||||
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
|
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
|
||||||
private_key = yield ledger.get_private_key_for_address(address)
|
private_key = await ledger.get_private_key_for_address(address)
|
||||||
tx = self._serialize_for_signature(i)
|
tx = self._serialize_for_signature(i)
|
||||||
txi.script.values['signature'] = \
|
txi.script.values['signature'] = \
|
||||||
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
|
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from twisted.internet.defer import Deferred
|
|
||||||
from twisted.python.failure import Failure
|
|
||||||
|
|
||||||
|
|
||||||
class BroadcastSubscription:
|
class BroadcastSubscription:
|
||||||
|
@ -31,11 +29,13 @@ class BroadcastSubscription:
|
||||||
|
|
||||||
def _add(self, data):
|
def _add(self, data):
|
||||||
if self.can_fire and self._on_data is not None:
|
if self.can_fire and self._on_data is not None:
|
||||||
self._on_data(data)
|
maybe_coroutine = self._on_data(data)
|
||||||
|
if asyncio.iscoroutine(maybe_coroutine):
|
||||||
|
asyncio.ensure_future(maybe_coroutine)
|
||||||
|
|
||||||
def _add_error(self, error, traceback):
|
def _add_error(self, exception):
|
||||||
if self.can_fire and self._on_error is not None:
|
if self.can_fire and self._on_error is not None:
|
||||||
self._on_error(error, traceback)
|
self._on_error(exception)
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
if self.can_fire and self._on_done is not None:
|
if self.can_fire and self._on_done is not None:
|
||||||
|
@ -66,9 +66,9 @@ class StreamController:
|
||||||
for subscription in self._iterate_subscriptions:
|
for subscription in self._iterate_subscriptions:
|
||||||
subscription._add(event)
|
subscription._add(event)
|
||||||
|
|
||||||
def add_error(self, error, traceback):
|
def add_error(self, exception):
|
||||||
for subscription in self._iterate_subscriptions:
|
for subscription in self._iterate_subscriptions:
|
||||||
subscription._add_error(error, traceback)
|
subscription._add_error(exception)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
for subscription in self._iterate_subscriptions:
|
for subscription in self._iterate_subscriptions:
|
||||||
|
@ -108,38 +108,35 @@ class Stream:
|
||||||
def listen(self, on_data, on_error=None, on_done=None):
|
def listen(self, on_data, on_error=None, on_done=None):
|
||||||
return self._controller._listen(on_data, on_error, on_done)
|
return self._controller._listen(on_data, on_error, on_done)
|
||||||
|
|
||||||
def deferred_where(self, condition):
|
def where(self, condition) -> asyncio.Future:
|
||||||
deferred = Deferred()
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
|
||||||
def where_test(value):
|
def where_test(value):
|
||||||
if condition(value):
|
if condition(value):
|
||||||
self._cancel_and_callback(subscription, deferred, value)
|
self._cancel_and_callback(subscription, future, value)
|
||||||
|
|
||||||
subscription = self.listen(
|
subscription = self.listen(
|
||||||
where_test,
|
where_test,
|
||||||
lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback)
|
lambda exception: self._cancel_and_error(subscription, future, exception)
|
||||||
)
|
)
|
||||||
|
|
||||||
return deferred
|
return future
|
||||||
|
|
||||||
def where(self, condition):
|
|
||||||
return self.deferred_where(condition).asFuture(asyncio.get_event_loop())
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def first(self):
|
def first(self):
|
||||||
deferred = Deferred()
|
future = asyncio.get_event_loop().create_future()
|
||||||
subscription = self.listen(
|
subscription = self.listen(
|
||||||
lambda value: self._cancel_and_callback(subscription, deferred, value),
|
lambda value: self._cancel_and_callback(subscription, future, value),
|
||||||
lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback)
|
lambda exception: self._cancel_and_error(subscription, future, exception)
|
||||||
)
|
)
|
||||||
return deferred
|
return future
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _cancel_and_callback(subscription, deferred, value):
|
def _cancel_and_callback(subscription: BroadcastSubscription, future: asyncio.Future, value):
|
||||||
subscription.cancel()
|
subscription.cancel()
|
||||||
deferred.callback(value)
|
future.set_result(value)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _cancel_and_error(subscription, deferred, error, traceback):
|
def _cancel_and_error(subscription: BroadcastSubscription, future: asyncio.Future, exception):
|
||||||
subscription.cancel()
|
subscription.cancel()
|
||||||
deferred.errback(Failure(error, exc_tb=traceback))
|
future.set_exception(exception)
|
||||||
|
|
Loading…
Reference in a new issue