twisted -> asyncio

This commit is contained in:
Lex Berezhny 2018-10-14 22:16:51 -04:00
parent 0ce4b9a7de
commit 2c5fd4aade
21 changed files with 480 additions and 721 deletions

View file

@ -1,5 +1,5 @@
import asyncio import asyncio
from orchstr8.testcase import IntegrationTestCase, d2f from orchstr8.testcase import IntegrationTestCase
from torba.constants import COIN from torba.constants import COIN
@ -9,14 +9,14 @@ class BasicTransactionTests(IntegrationTestCase):
async def test_sending_and_receiving(self): async def test_sending_and_receiving(self):
account1, account2 = self.account, self.wallet.generate_account(self.ledger) account1, account2 = self.account, self.wallet.generate_account(self.ledger)
await d2f(self.ledger.update_account(account2)) await self.ledger.update_account(account2)
self.assertEqual(await self.get_balance(account1), 0) self.assertEqual(await self.get_balance(account1), 0)
self.assertEqual(await self.get_balance(account2), 0) self.assertEqual(await self.get_balance(account2), 0)
sendtxids = [] sendtxids = []
for i in range(5): for i in range(5):
address1 = await d2f(account1.receiving.get_or_create_usable_address()) address1 = await account1.receiving.get_or_create_usable_address()
sendtxid = await self.blockchain.send_to_address(address1, 1.1) sendtxid = await self.blockchain.send_to_address(address1, 1.1)
sendtxids.append(sendtxid) sendtxids.append(sendtxid)
await self.on_transaction_id(sendtxid) # mempool await self.on_transaction_id(sendtxid) # mempool
@ -28,13 +28,13 @@ class BasicTransactionTests(IntegrationTestCase):
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5) self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0) self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
address2 = await d2f(account2.receiving.get_or_create_usable_address()) address2 = await account2.receiving.get_or_create_usable_address()
hash2 = self.ledger.address_to_hash160(address2) hash2 = self.ledger.address_to_hash160(address2)
tx = await d2f(self.ledger.transaction_class.create( tx = await self.ledger.transaction_class.create(
[], [],
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash2)], [self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash2)],
[account1], account1 [account1], account1
)) )
await self.broadcast(tx) await self.broadcast(tx)
await self.on_transaction(tx) # mempool await self.on_transaction(tx) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
@ -43,18 +43,18 @@ class BasicTransactionTests(IntegrationTestCase):
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5) self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0) self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
utxos = await d2f(self.account.get_utxos()) utxos = await self.account.get_utxos()
tx = await d2f(self.ledger.transaction_class.create( tx = await self.ledger.transaction_class.create(
[self.ledger.transaction_class.input_class.spend(utxos[0])], [self.ledger.transaction_class.input_class.spend(utxos[0])],
[], [],
[account1], account1 [account1], account1
)) )
await self.broadcast(tx) await self.broadcast(tx)
await self.on_transaction(tx) # mempool await self.on_transaction(tx) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction(tx) # confirmed await self.on_transaction(tx) # confirmed
txs = await d2f(account1.get_transactions()) txs = await account1.get_transactions()
tx = txs[1] tx = txs[1]
self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1) self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1)
self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1) self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1)

View file

@ -1,25 +1,26 @@
from binascii import hexlify from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import defer from orchstr8.testcase import AsyncioTestCase
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.baseaccount import HierarchicalDeterministic, SingleKey from torba.baseaccount import HierarchicalDeterministic, SingleKey
from torba.wallet import Wallet from torba.wallet import Wallet
class TestHierarchicalDeterministicAccount(unittest.TestCase): class TestHierarchicalDeterministicAccount(AsyncioTestCase):
@defer.inlineCallbacks async def asyncSetUp(self):
def setUp(self):
self.ledger = ledger_class({ self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'),
}) })
yield self.ledger.db.open() await self.ledger.db.open()
self.account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba") self.account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
@defer.inlineCallbacks async def asyncTearDown(self):
def test_generate_account(self): await self.ledger.db.close()
async def test_generate_account(self):
account = self.account account = self.account
self.assertEqual(account.ledger, self.ledger) self.assertEqual(account.ledger, self.ledger)
@ -27,81 +28,77 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key) self.assertEqual(account.private_key.public_key, account.public_key)
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 0) self.assertEqual(len(addresses), 0)
addresses = yield account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 0) self.assertEqual(len(addresses), 0)
yield account.ensure_address_gap() await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 20) self.assertEqual(len(addresses), 20)
addresses = yield account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 6) self.assertEqual(len(addresses), 6)
addresses = yield account.get_addresses() addresses = await account.get_addresses()
self.assertEqual(len(addresses), 26) self.assertEqual(len(addresses), 26)
@defer.inlineCallbacks async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
def test_generate_keys_over_batch_threshold_saves_it_properly(self): await self.account.receiving.generate_keys(0, 200)
yield self.account.receiving.generate_keys(0, 200) records = await self.account.receiving.get_address_records()
records = yield self.account.receiving.get_address_records()
self.assertEqual(201, len(records)) self.assertEqual(201, len(records))
@defer.inlineCallbacks async def test_ensure_address_gap(self):
def test_ensure_address_gap(self):
account = self.account account = self.account
self.assertIsInstance(account.receiving, HierarchicalDeterministic) self.assertIsInstance(account.receiving, HierarchicalDeterministic)
yield account.receiving.generate_keys(4, 7) await account.receiving.generate_keys(4, 7)
yield account.receiving.generate_keys(0, 3) await account.receiving.generate_keys(0, 3)
yield account.receiving.generate_keys(8, 11) await account.receiving.generate_keys(8, 11)
records = yield account.receiving.get_address_records() records = await account.receiving.get_address_records()
self.assertEqual( self.assertEqual(
[r['position'] for r in records], [r['position'] for r in records],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
) )
# we have 12, but default gap is 20 # we have 12, but default gap is 20
new_keys = yield account.receiving.ensure_address_gap() new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 8) self.assertEqual(len(new_keys), 8)
records = yield account.receiving.get_address_records() records = await account.receiving.get_address_records()
self.assertEqual( self.assertEqual(
[r['position'] for r in records], [r['position'] for r in records],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
) )
# case #1: no new addresses needed # case #1: no new addresses needed
empty = yield account.receiving.ensure_address_gap() empty = await account.receiving.ensure_address_gap()
self.assertEqual(len(empty), 0) self.assertEqual(len(empty), 0)
# case #2: only one new addressed needed # case #2: only one new addressed needed
records = yield account.receiving.get_address_records() records = await account.receiving.get_address_records()
yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:') await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
new_keys = yield account.receiving.ensure_address_gap() new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 1) self.assertEqual(len(new_keys), 1)
# case #3: 20 addresses needed # case #3: 20 addresses needed
yield self.ledger.db.set_address_history(new_keys[0], 'a:1:') await self.ledger.db.set_address_history(new_keys[0], 'a:1:')
new_keys = yield account.receiving.ensure_address_gap() new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 20) self.assertEqual(len(new_keys), 20)
@defer.inlineCallbacks async def test_get_or_create_usable_address(self):
def test_get_or_create_usable_address(self):
account = self.account account = self.account
keys = yield account.receiving.get_addresses() keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 0) self.assertEqual(len(keys), 0)
address = yield account.receiving.get_or_create_usable_address() address = await account.receiving.get_or_create_usable_address()
self.assertIsNotNone(address) self.assertIsNotNone(address)
keys = yield account.receiving.get_addresses() keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 20) self.assertEqual(len(keys), 20)
@defer.inlineCallbacks async def test_generate_account_from_seed(self):
def test_generate_account_from_seed(self):
account = self.ledger.account_class.from_dict( account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "carbon smart garage balance margin twelve chest sword " "seed": "carbon smart garage balance margin twelve chest sword "
@ -123,17 +120,17 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7UbpV' 'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7UbpV'
'NzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g' 'NzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g'
) )
address = yield account.receiving.ensure_address_gap() address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J') self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
private_key = yield self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J') private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
self.assertEqual( self.assertEqual(
private_key.extended_key_string(), private_key.extended_key_string(),
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o' 'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi' 'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
) )
invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
self.assertIsNone(invalid_key) self.assertIsNone(invalid_key)
self.assertEqual( self.assertEqual(
@ -141,8 +138,7 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
b'1c01ae1e4c7d89e39f6d3aa7792c097a30ca7d40be249b6de52c81ec8cf9aab48b01' b'1c01ae1e4c7d89e39f6d3aa7792c097a30ca7d40be249b6de52c81ec8cf9aab48b01'
) )
@defer.inlineCallbacks async def test_load_and_save_account(self):
def test_load_and_save_account(self):
account_data = { account_data = {
'name': 'My Account', 'name': 'My Account',
'seed': 'seed':
@ -164,11 +160,11 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data) account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
yield account.ensure_address_gap() await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 5) self.assertEqual(len(addresses), 5)
addresses = yield account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 5) self.assertEqual(len(addresses), 5)
self.maxDiff = None self.maxDiff = None
@ -176,20 +172,21 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
self.assertDictEqual(account_data, account.to_dict()) self.assertDictEqual(account_data, account.to_dict())
class TestSingleKeyAccount(unittest.TestCase): class TestSingleKeyAccount(AsyncioTestCase):
@defer.inlineCallbacks async def asyncSetUp(self):
def setUp(self):
self.ledger = ledger_class({ self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'),
}) })
yield self.ledger.db.open() await self.ledger.db.open()
self.account = self.ledger.account_class.generate( self.account = self.ledger.account_class.generate(
self.ledger, Wallet(), "torba", {'name': 'single-address'}) self.ledger, Wallet(), "torba", {'name': 'single-address'})
@defer.inlineCallbacks async def asyncTearDown(self):
def test_generate_account(self): await self.ledger.db.close()
async def test_generate_account(self):
account = self.account account = self.account
self.assertEqual(account.ledger, self.ledger) self.assertEqual(account.ledger, self.ledger)
@ -197,37 +194,36 @@ class TestSingleKeyAccount(unittest.TestCase):
self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key) self.assertEqual(account.private_key.public_key, account.public_key)
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 0) self.assertEqual(len(addresses), 0)
addresses = yield account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 0) self.assertEqual(len(addresses), 0)
yield account.ensure_address_gap() await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 1) self.assertEqual(len(addresses), 1)
self.assertEqual(addresses[0], account.public_key.address) self.assertEqual(addresses[0], account.public_key.address)
addresses = yield account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 1) self.assertEqual(len(addresses), 1)
self.assertEqual(addresses[0], account.public_key.address) self.assertEqual(addresses[0], account.public_key.address)
addresses = yield account.get_addresses() addresses = await account.get_addresses()
self.assertEqual(len(addresses), 1) self.assertEqual(len(addresses), 1)
self.assertEqual(addresses[0], account.public_key.address) self.assertEqual(addresses[0], account.public_key.address)
@defer.inlineCallbacks async def test_ensure_address_gap(self):
def test_ensure_address_gap(self):
account = self.account account = self.account
self.assertIsInstance(account.receiving, SingleKey) self.assertIsInstance(account.receiving, SingleKey)
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(addresses, []) self.assertEqual(addresses, [])
# we have 12, but default gap is 20 # we have 12, but default gap is 20
new_keys = yield account.receiving.ensure_address_gap() new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 1) self.assertEqual(len(new_keys), 1)
self.assertEqual(new_keys[0], account.public_key.address) self.assertEqual(new_keys[0], account.public_key.address)
records = yield account.receiving.get_address_records() records = await account.receiving.get_address_records()
self.assertEqual(records, [{ self.assertEqual(records, [{
'position': 0, 'chain': 0, 'position': 0, 'chain': 0,
'account': account.public_key.address, 'account': account.public_key.address,
@ -236,37 +232,35 @@ class TestSingleKeyAccount(unittest.TestCase):
}]) }])
# case #1: no new addresses needed # case #1: no new addresses needed
empty = yield account.receiving.ensure_address_gap() empty = await account.receiving.ensure_address_gap()
self.assertEqual(len(empty), 0) self.assertEqual(len(empty), 0)
# case #2: after use, still no new address needed # case #2: after use, still no new address needed
records = yield account.receiving.get_address_records() records = await account.receiving.get_address_records()
yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:') await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
empty = yield account.receiving.ensure_address_gap() empty = await account.receiving.ensure_address_gap()
self.assertEqual(len(empty), 0) self.assertEqual(len(empty), 0)
@defer.inlineCallbacks async def test_get_or_create_usable_address(self):
def test_get_or_create_usable_address(self):
account = self.account account = self.account
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 0) self.assertEqual(len(addresses), 0)
address1 = yield account.receiving.get_or_create_usable_address() address1 = await account.receiving.get_or_create_usable_address()
self.assertIsNotNone(address1) self.assertIsNotNone(address1)
yield self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:') await self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:')
records = yield account.receiving.get_address_records() records = await account.receiving.get_address_records()
self.assertEqual(records[0]['used_times'], 3) self.assertEqual(records[0]['used_times'], 3)
address2 = yield account.receiving.get_or_create_usable_address() address2 = await account.receiving.get_or_create_usable_address()
self.assertEqual(address1, address2) self.assertEqual(address1, address2)
keys = yield account.receiving.get_addresses() keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 1) self.assertEqual(len(keys), 1)
@defer.inlineCallbacks async def test_generate_account_from_seed(self):
def test_generate_account_from_seed(self):
account = self.ledger.account_class.from_dict( account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "seed":
@ -285,17 +279,17 @@ class TestSingleKeyAccount(unittest.TestCase):
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7' 'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g', 'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
) )
address = yield account.receiving.ensure_address_gap() address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], account.public_key.address) self.assertEqual(address[0], account.public_key.address)
private_key = yield self.ledger.get_private_key_for_address(address[0]) private_key = await self.ledger.get_private_key_for_address(address[0])
self.assertEqual( self.assertEqual(
private_key.extended_key_string(), private_key.extended_key_string(),
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp' 'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna', '5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
) )
invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
self.assertIsNone(invalid_key) self.assertIsNone(invalid_key)
self.assertEqual( self.assertEqual(
@ -303,8 +297,7 @@ class TestSingleKeyAccount(unittest.TestCase):
b'1c92caa0ef99bfd5e2ceb73b66da8cd726a9370be8c368d448a322f3c5b23aaab901' b'1c92caa0ef99bfd5e2ceb73b66da8cd726a9370be8c368d448a322f3c5b23aaab901'
) )
@defer.inlineCallbacks async def test_load_and_save_account(self):
def test_load_and_save_account(self):
account_data = { account_data = {
'name': 'My Account', 'name': 'My Account',
'seed': 'seed':
@ -322,11 +315,11 @@ class TestSingleKeyAccount(unittest.TestCase):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data) account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
yield account.ensure_address_gap() await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 1) self.assertEqual(len(addresses), 1)
addresses = yield account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 1) self.assertEqual(len(addresses), 1)
self.maxDiff = None self.maxDiff = None
@ -334,7 +327,7 @@ class TestSingleKeyAccount(unittest.TestCase):
self.assertDictEqual(account_data, account.to_dict()) self.assertDictEqual(account_data, account.to_dict())
class AccountEncryptionTests(unittest.TestCase): class AccountEncryptionTests(AsyncioTestCase):
password = "password" password = "password"
init_vector = b'0000000000000000' init_vector = b'0000000000000000'
unencrypted_account = { unencrypted_account = {
@ -368,7 +361,7 @@ class AccountEncryptionTests(unittest.TestCase):
'address_generator': {'name': 'single-address'} 'address_generator': {'name': 'single-address'}
} }
def setUp(self): async def asyncSetUp(self):
self.ledger = ledger_class({ self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'),

View file

@ -1,4 +1,4 @@
from twisted.trial import unittest import unittest
from torba.bcd_data_stream import BCDataStream from torba.bcd_data_stream import BCDataStream

View file

@ -1,5 +1,5 @@
import unittest
from binascii import unhexlify, hexlify from binascii import unhexlify, hexlify
from twisted.trial import unittest
from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
from torba.bip32 import PubKey, PrivateKey, from_extended_key_string from torba.bip32 import PubKey, PrivateKey, from_extended_key_string

View file

@ -1,4 +1,4 @@
from twisted.trial import unittest import unittest
from types import GeneratorType from types import GeneratorType
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class

View file

@ -1,11 +1,12 @@
from twisted.trial import unittest import unittest
from twisted.internet import defer
from torba.wallet import Wallet from torba.wallet import Wallet
from torba.constants import COIN from torba.constants import COIN
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.basedatabase import query, constraints_to_sql from torba.basedatabase import query, constraints_to_sql
from orchstr8.testcase import AsyncioTestCase
from .test_transaction import get_output, NULL_HASH from .test_transaction import get_output, NULL_HASH
@ -100,53 +101,52 @@ class TestQueryBuilder(unittest.TestCase):
) )
class TestQueries(unittest.TestCase): class TestQueries(AsyncioTestCase):
def setUp(self): async def asyncSetUp(self):
self.ledger = ledger_class({ self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'),
}) })
return self.ledger.db.open() await self.ledger.db.open()
@defer.inlineCallbacks async def asyncTearDown(self):
def create_account(self): await self.ledger.db.close()
async def create_account(self):
account = self.ledger.account_class.generate(self.ledger, Wallet()) account = self.ledger.account_class.generate(self.ledger, Wallet())
yield account.ensure_address_gap() await account.ensure_address_gap()
return account return account
@defer.inlineCallbacks async def create_tx_from_nothing(self, my_account, height):
def create_tx_from_nothing(self, my_account, height): to_address = await my_account.receiving.get_or_create_usable_address()
to_address = yield my_account.receiving.get_or_create_usable_address()
to_hash = ledger_class.address_to_hash160(to_address) to_hash = ledger_class.address_to_hash160(to_address)
tx = ledger_class.transaction_class(height=height, is_verified=True) \ tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \ .add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
.add_outputs([self.txo(1, to_hash)]) .add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '') await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
return tx return tx
@defer.inlineCallbacks async def create_tx_from_txo(self, txo, to_account, height):
def create_tx_from_txo(self, txo, to_account, height):
from_hash = txo.script.values['pubkey_hash'] from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.hash160_to_address(from_hash) from_address = self.ledger.hash160_to_address(from_hash)
to_address = yield to_account.receiving.get_or_create_usable_address() to_address = await to_account.receiving.get_or_create_usable_address()
to_hash = ledger_class.address_to_hash160(to_address) to_hash = ledger_class.address_to_hash160(to_address)
tx = ledger_class.transaction_class(height=height, is_verified=True) \ tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \ .add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)]) .add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '') await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
yield self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '') await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
return tx return tx
@defer.inlineCallbacks async def create_tx_to_nowhere(self, txo, height):
def create_tx_to_nowhere(self, txo, height):
from_hash = txo.script.values['pubkey_hash'] from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.hash160_to_address(from_hash) from_address = self.ledger.hash160_to_address(from_hash)
to_hash = NULL_HASH to_hash = NULL_HASH
tx = ledger_class.transaction_class(height=height, is_verified=True) \ tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \ .add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)]) .add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '') await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
return tx return tx
def txo(self, amount, address): def txo(self, amount, address):
@ -155,39 +155,38 @@ class TestQueries(unittest.TestCase):
def txi(self, txo): def txi(self, txo):
return ledger_class.transaction_class.input_class.spend(txo) return ledger_class.transaction_class.input_class.spend(txo)
@defer.inlineCallbacks async def test_get_transactions(self):
def test_get_transactions(self): account1 = await self.create_account()
account1 = yield self.create_account() account2 = await self.create_account()
account2 = yield self.create_account() tx1 = await self.create_tx_from_nothing(account1, 1)
tx1 = yield self.create_tx_from_nothing(account1, 1) tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
tx2 = yield self.create_tx_from_txo(tx1.outputs[0], account2, 2) tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
tx3 = yield self.create_tx_to_nowhere(tx2.outputs[0], 3)
txs = yield self.ledger.db.get_transactions() txs = await self.ledger.db.get_transactions()
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual([3, 2, 1], [tx.height for tx in txs]) self.assertEqual([3, 2, 1], [tx.height for tx in txs])
txs = yield self.ledger.db.get_transactions(account=account1) txs = await self.ledger.db.get_transactions(account=account1)
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True) self.assertEqual(txs[1].outputs[0].is_my_account, True)
txs = yield self.ledger.db.get_transactions(account=account2) txs = await self.ledger.db.get_transactions(account=account2)
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs]) self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True) self.assertEqual(txs[1].outputs[0].is_my_account, True)
tx = yield self.ledger.db.get_transaction(txid=tx2.id) tx = await self.ledger.db.get_transaction(txid=tx2.id)
self.assertEqual(tx.id, tx2.id) self.assertEqual(tx.id, tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False)
tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account1) tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account1)
self.assertEqual(tx.inputs[0].is_my_account, True) self.assertEqual(tx.inputs[0].is_my_account, True)
self.assertEqual(tx.outputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False)
tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account2) tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2)
self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, True) self.assertEqual(tx.outputs[0].is_my_account, True)

View file

@ -1,11 +1,6 @@
from unittest import TestCase from unittest import TestCase, mock
from torba.hash import aes_decrypt, aes_encrypt from torba.hash import aes_decrypt, aes_encrypt
try:
from unittest import mock
except ImportError:
import mock
class TestAESEncryptDecrypt(TestCase): class TestAESEncryptDecrypt(TestCase):
message = 'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks' message = 'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks'

View file

@ -1,8 +1,7 @@
import os import os
from urllib.request import Request, urlopen from urllib.request import Request, urlopen
from twisted.trial import unittest from orchstr8.testcase import AsyncioTestCase
from twisted.internet import defer
from torba.coin.bitcoinsegwit import MainHeaders from torba.coin.bitcoinsegwit import MainHeaders
@ -11,7 +10,7 @@ def block_bytes(blocks):
return blocks * MainHeaders.header_size return blocks * MainHeaders.header_size
class BitcoinHeadersTestCase(unittest.TestCase): class BitcoinHeadersTestCase(AsyncioTestCase):
# Download headers instead of storing them in git. # Download headers instead of storing them in git.
HEADER_URL = 'http://headers.electrum.org/blockchain_headers' HEADER_URL = 'http://headers.electrum.org/blockchain_headers'
@ -39,7 +38,7 @@ class BitcoinHeadersTestCase(unittest.TestCase):
headers.seek(after, os.SEEK_SET) headers.seek(after, os.SEEK_SET)
return headers.read(upto) return headers.read(upto)
def get_headers(self, upto: int = -1): async def get_headers(self, upto: int = -1):
h = MainHeaders(':memory:') h = MainHeaders(':memory:')
h.io.write(self.get_bytes(upto)) h.io.write(self.get_bytes(upto))
return h return h
@ -47,8 +46,8 @@ class BitcoinHeadersTestCase(unittest.TestCase):
class BasicHeadersTests(BitcoinHeadersTestCase): class BasicHeadersTests(BitcoinHeadersTestCase):
def test_serialization(self): async def test_serialization(self):
h = self.get_headers() h = await self.get_headers()
self.assertEqual(h[0], { self.assertEqual(h[0], {
'bits': 486604799, 'bits': 486604799,
'block_height': 0, 'block_height': 0,
@ -94,18 +93,16 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
h.get_raw_header(self.RETARGET_BLOCK) h.get_raw_header(self.RETARGET_BLOCK)
) )
@defer.inlineCallbacks async def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
headers = MainHeaders(':memory:') headers = MainHeaders(':memory:')
self.assertEqual(headers.height, -1) self.assertEqual(headers.height, -1)
yield headers.connect(0, self.get_bytes(block_bytes(3001))) await headers.connect(0, self.get_bytes(block_bytes(3001)))
self.assertEqual(headers.height, 3000) self.assertEqual(headers.height, 3000)
@defer.inlineCallbacks async def test_connect_9_blocks_passing_a_retarget_at_32256(self):
def test_connect_9_blocks_passing_a_retarget_at_32256(self):
retarget = block_bytes(self.RETARGET_BLOCK-5) retarget = block_bytes(self.RETARGET_BLOCK-5)
headers = self.get_headers(upto=retarget) headers = await self.get_headers(upto=retarget)
remainder = self.get_bytes(after=retarget) remainder = self.get_bytes(after=retarget)
self.assertEqual(headers.height, 32250) self.assertEqual(headers.height, 32250)
yield headers.connect(len(headers), remainder) await headers.connect(len(headers), remainder)
self.assertEqual(headers.height, 32259) self.assertEqual(headers.height, 32259)

View file

@ -1,6 +1,5 @@
import os import os
from binascii import hexlify from binascii import hexlify
from twisted.internet import defer
from torba.coin.bitcoinsegwit import MainNetLedger from torba.coin.bitcoinsegwit import MainNetLedger
from torba.wallet import Wallet from torba.wallet import Wallet
@ -18,32 +17,30 @@ class MockNetwork:
self.get_history_called = [] self.get_history_called = []
self.get_transaction_called = [] self.get_transaction_called = []
def get_history(self, address): async def get_history(self, address):
self.get_history_called.append(address) self.get_history_called.append(address)
self.address = address self.address = address
return defer.succeed(self.history) return self.history
def get_merkle(self, txid, height): async def get_merkle(self, txid, height):
return defer.succeed({'merkle': ['abcd01'], 'pos': 1}) return {'merkle': ['abcd01'], 'pos': 1}
def get_transaction(self, tx_hash): async def get_transaction(self, tx_hash):
self.get_transaction_called.append(tx_hash) self.get_transaction_called.append(tx_hash)
return defer.succeed(self.transaction[tx_hash]) return self.transaction[tx_hash]
class LedgerTestCase(BitcoinHeadersTestCase): class LedgerTestCase(BitcoinHeadersTestCase):
def setUp(self): async def asyncSetUp(self):
super().setUp()
self.ledger = MainNetLedger({ self.ledger = MainNetLedger({
'db': MainNetLedger.database_class(':memory:'), 'db': MainNetLedger.database_class(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': MainNetLedger.headers_class(':memory:')
}) })
return self.ledger.db.open() await self.ledger.db.open()
def tearDown(self): async def asyncTearDown(self):
super().tearDown() await self.ledger.db.close()
return self.ledger.db.close()
def make_header(self, **kwargs): def make_header(self, **kwargs):
header = { header = {
@ -69,11 +66,10 @@ class LedgerTestCase(BitcoinHeadersTestCase):
class TestSynchronization(LedgerTestCase): class TestSynchronization(LedgerTestCase):
@defer.inlineCallbacks async def test_update_history(self):
def test_update_history(self):
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba") account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
address = yield account.receiving.get_or_create_usable_address() address = await account.receiving.get_or_create_usable_address()
address_details = yield self.ledger.db.get_address(address=address) address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], None) self.assertEqual(address_details['history'], None)
self.add_header(block_height=0, merkle_root=b'abcd04') self.add_header(block_height=0, merkle_root=b'abcd04')
@ -89,16 +85,16 @@ class TestSynchronization(LedgerTestCase):
'abcd02': hexlify(get_transaction(get_output(2)).raw), 'abcd02': hexlify(get_transaction(get_output(2)).raw),
'abcd03': hexlify(get_transaction(get_output(3)).raw), 'abcd03': hexlify(get_transaction(get_output(3)).raw),
}) })
yield self.ledger.update_history(address) await self.ledger.update_history(address)
self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03']) self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
address_details = yield self.ledger.db.get_address(address=address) address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:') self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
self.ledger.network.get_history_called = [] self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = [] self.ledger.network.get_transaction_called = []
yield self.ledger.update_history(address) await self.ledger.update_history(address)
self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, []) self.assertEqual(self.ledger.network.get_transaction_called, [])
@ -106,10 +102,10 @@ class TestSynchronization(LedgerTestCase):
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw) self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = [] self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = [] self.ledger.network.get_transaction_called = []
yield self.ledger.update_history(address) await self.ledger.update_history(address)
self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04']) self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
address_details = yield self.ledger.db.get_address(address=address) address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:') self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
@ -117,14 +113,13 @@ class MocHeaderNetwork:
def __init__(self, responses): def __init__(self, responses):
self.responses = responses self.responses = responses
def get_headers(self, height, blocks): async def get_headers(self, height, blocks):
return self.responses[height] return self.responses[height]
class BlockchainReorganizationTests(LedgerTestCase): class BlockchainReorganizationTests(LedgerTestCase):
@defer.inlineCallbacks async def test_1_block_reorganization(self):
def test_1_block_reorganization(self):
self.ledger.network = MocHeaderNetwork({ self.ledger.network = MocHeaderNetwork({
20: {'height': 20, 'count': 5, 'hex': hexlify( 20: {'height': 20, 'count': 5, 'hex': hexlify(
self.get_bytes(after=block_bytes(20), upto=block_bytes(5)) self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
@ -132,15 +127,14 @@ class BlockchainReorganizationTests(LedgerTestCase):
25: {'height': 25, 'count': 0, 'hex': b''} 25: {'height': 25, 'count': 0, 'hex': b''}
}) })
headers = self.ledger.headers headers = self.ledger.headers
yield headers.connect(0, self.get_bytes(upto=block_bytes(20))) await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
self.add_header(block_height=len(headers)) self.add_header(block_height=len(headers))
self.assertEqual(headers.height, 20) self.assertEqual(headers.height, 20)
yield self.ledger.receive_header([{ await self.ledger.receive_header([{
'height': 21, 'hex': hexlify(self.make_header(block_height=21)) 'height': 21, 'hex': hexlify(self.make_header(block_height=21))
}]) }])
@defer.inlineCallbacks async def test_3_block_reorganization(self):
def test_3_block_reorganization(self):
self.ledger.network = MocHeaderNetwork({ self.ledger.network = MocHeaderNetwork({
20: {'height': 20, 'count': 5, 'hex': hexlify( 20: {'height': 20, 'count': 5, 'hex': hexlify(
self.get_bytes(after=block_bytes(20), upto=block_bytes(5)) self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
@ -150,11 +144,11 @@ class BlockchainReorganizationTests(LedgerTestCase):
25: {'height': 25, 'count': 0, 'hex': b''} 25: {'height': 25, 'count': 0, 'hex': b''}
}) })
headers = self.ledger.headers headers = self.ledger.headers
yield headers.connect(0, self.get_bytes(upto=block_bytes(20))) await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
self.add_header(block_height=len(headers)) self.add_header(block_height=len(headers))
self.add_header(block_height=len(headers)) self.add_header(block_height=len(headers))
self.add_header(block_height=len(headers)) self.add_header(block_height=len(headers))
self.assertEqual(headers.height, 22) self.assertEqual(headers.height, 22)
yield self.ledger.receive_header(({ await self.ledger.receive_header(({
'height': 23, 'hex': hexlify(self.make_header(block_height=23)) 'height': 23, 'hex': hexlify(self.make_header(block_height=23))
},)) },))

View file

@ -1,5 +1,5 @@
import unittest
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest
from torba.bcd_data_stream import BCDataStream from torba.bcd_data_stream import BCDataStream
from torba.basescript import Template, ParseError, tokenize, push_data from torba.basescript import Template, ParseError, tokenize, push_data

View file

@ -1,7 +1,8 @@
import unittest
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from itertools import cycle from itertools import cycle
from twisted.trial import unittest
from twisted.internet import defer from orchstr8.testcase import AsyncioTestCase
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.wallet import Wallet from torba.wallet import Wallet
@ -29,9 +30,9 @@ def get_transaction(txo=None):
.add_outputs([txo or ledger_class.transaction_class.output_class.pay_pubkey_hash(CENT, NULL_HASH)]) .add_outputs([txo or ledger_class.transaction_class.output_class.pay_pubkey_hash(CENT, NULL_HASH)])
class TestSizeAndFeeEstimation(unittest.TestCase): class TestSizeAndFeeEstimation(AsyncioTestCase):
def setUp(self): async def asyncSetUp(self):
self.ledger = ledger_class({ self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'),
@ -181,17 +182,19 @@ class TestTransactionSerialization(unittest.TestCase):
self.assertEqual(tx.raw, raw) self.assertEqual(tx.raw, raw)
class TestTransactionSigning(unittest.TestCase): class TestTransactionSigning(AsyncioTestCase):
def setUp(self): async def asyncSetUp(self):
self.ledger = ledger_class({ self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'),
}) })
return self.ledger.db.open() await self.ledger.db.open()
@defer.inlineCallbacks async def asyncTearDown(self):
def test_sign(self): await self.ledger.db.close()
async def test_sign(self):
account = self.ledger.account_class.from_dict( account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "carbon smart garage balance margin twelve chest sword " "seed": "carbon smart garage balance margin twelve chest sword "
@ -200,8 +203,8 @@ class TestTransactionSigning(unittest.TestCase):
} }
) )
yield account.ensure_address_gap() await account.ensure_address_gap()
address1, address2 = yield account.receiving.get_addresses(limit=2) address1, address2 = await account.receiving.get_addresses(limit=2)
pubkey_hash1 = self.ledger.address_to_hash160(address1) pubkey_hash1 = self.ledger.address_to_hash160(address1)
pubkey_hash2 = self.ledger.address_to_hash160(address2) pubkey_hash2 = self.ledger.address_to_hash160(address2)
@ -211,9 +214,8 @@ class TestTransactionSigning(unittest.TestCase):
.add_inputs([tx_class.input_class.spend(get_output(2*COIN, pubkey_hash1))]) \ .add_inputs([tx_class.input_class.spend(get_output(2*COIN, pubkey_hash1))]) \
.add_outputs([tx_class.output_class.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) \ .add_outputs([tx_class.output_class.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) \
yield tx.sign([account]) await tx.sign([account])
print(hexlify(tx.inputs[0].script.values['signature']))
self.assertEqual( self.assertEqual(
hexlify(tx.inputs[0].script.values['signature']), hexlify(tx.inputs[0].script.values['signature']),
b'304402205a1df8cd5d2d2fa5934b756883d6c07e4f83e1350c740992d47a12422' b'304402205a1df8cd5d2d2fa5934b756883d6c07e4f83e1350c740992d47a12422'
@ -221,15 +223,14 @@ class TestTransactionSigning(unittest.TestCase):
) )
class TransactionIOBalancing(unittest.TestCase): class TransactionIOBalancing(AsyncioTestCase):
@defer.inlineCallbacks async def asyncSetUp(self):
def setUp(self):
self.ledger = ledger_class({ self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'),
}) })
yield self.ledger.db.open() await self.ledger.db.open()
self.account = self.ledger.account_class.from_dict( self.account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "carbon smart garage balance margin twelve chest sword " "seed": "carbon smart garage balance margin twelve chest sword "
@ -237,10 +238,13 @@ class TransactionIOBalancing(unittest.TestCase):
} }
) )
addresses = yield self.account.ensure_address_gap() addresses = await self.account.ensure_address_gap()
self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses] self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
self.hash_cycler = cycle(self.pubkey_hash) self.hash_cycler = cycle(self.pubkey_hash)
async def asyncTearDown(self):
await self.ledger.db.close()
def txo(self, amount, address=None): def txo(self, amount, address=None):
return get_output(int(amount*COIN), address or next(self.hash_cycler)) return get_output(int(amount*COIN), address or next(self.hash_cycler))
@ -250,8 +254,7 @@ class TransactionIOBalancing(unittest.TestCase):
def tx(self, inputs, outputs): def tx(self, inputs, outputs):
return ledger_class.transaction_class.create(inputs, outputs, [self.account], self.account) return ledger_class.transaction_class.create(inputs, outputs, [self.account], self.account)
@defer.inlineCallbacks async def create_utxos(self, amounts):
def create_utxos(self, amounts):
utxos = [self.txo(amount) for amount in amounts] utxos = [self.txo(amount) for amount in amounts]
self.funding_tx = ledger_class.transaction_class(is_verified=True) \ self.funding_tx = ledger_class.transaction_class(is_verified=True) \
@ -260,7 +263,7 @@ class TransactionIOBalancing(unittest.TestCase):
save_tx = 'insert' save_tx = 'insert'
for utxo in utxos: for utxo in utxos:
yield self.ledger.db.save_transaction_io( await self.ledger.db.save_transaction_io(
save_tx, self.funding_tx, save_tx, self.funding_tx,
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']), self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
utxo.script.values['pubkey_hash'], '' utxo.script.values['pubkey_hash'], ''
@ -277,17 +280,16 @@ class TransactionIOBalancing(unittest.TestCase):
def outputs(tx): def outputs(tx):
return [round(o.amount/COIN, 2) for o in tx.outputs] return [round(o.amount/COIN, 2) for o in tx.outputs]
@defer.inlineCallbacks async def test_basic_use_cases(self):
def test_basic_use_cases(self):
self.ledger.fee_per_byte = int(.01*CENT) self.ledger.fee_per_byte = int(.01*CENT)
# available UTXOs for filling missing inputs # available UTXOs for filling missing inputs
utxos = yield self.create_utxos([ utxos = await self.create_utxos([
1, 1, 3, 5, 10 1, 1, 3, 5, 10
]) ])
# pay 3 coins (3.02 w/ fees) # pay 3 coins (3.02 w/ fees)
tx = yield self.tx( tx = await self.tx(
[], # inputs [], # inputs
[self.txo(3)] # outputs [self.txo(3)] # outputs
) )
@ -296,10 +298,10 @@ class TransactionIOBalancing(unittest.TestCase):
# a change of 1.98 is added to reach balance # a change of 1.98 is added to reach balance
self.assertEqual(self.outputs(tx), [3, 1.98]) self.assertEqual(self.outputs(tx), [3, 1.98])
yield self.ledger.release_outputs(utxos) await self.ledger.release_outputs(utxos)
# pay 2.98 coins (3.00 w/ fees) # pay 2.98 coins (3.00 w/ fees)
tx = yield self.tx( tx = await self.tx(
[], # inputs [], # inputs
[self.txo(2.98)] # outputs [self.txo(2.98)] # outputs
) )
@ -307,10 +309,10 @@ class TransactionIOBalancing(unittest.TestCase):
self.assertEqual(self.inputs(tx), [3]) self.assertEqual(self.inputs(tx), [3])
self.assertEqual(self.outputs(tx), [2.98]) self.assertEqual(self.outputs(tx), [2.98])
yield self.ledger.release_outputs(utxos) await self.ledger.release_outputs(utxos)
# supplied input and output, but input is not enough to cover output # supplied input and output, but input is not enough to cover output
tx = yield self.tx( tx = await self.tx(
[self.txi(self.txo(10))], # inputs [self.txi(self.txo(10))], # inputs
[self.txo(11)] # outputs [self.txo(11)] # outputs
) )
@ -319,10 +321,10 @@ class TransactionIOBalancing(unittest.TestCase):
# change is now needed to consume extra input # change is now needed to consume extra input
self.assertEqual([11, 1.96], self.outputs(tx)) self.assertEqual([11, 1.96], self.outputs(tx))
yield self.ledger.release_outputs(utxos) await self.ledger.release_outputs(utxos)
# liquidating a UTXO # liquidating a UTXO
tx = yield self.tx( tx = await self.tx(
[self.txi(self.txo(10))], # inputs [self.txi(self.txo(10))], # inputs
[] # outputs [] # outputs
) )
@ -330,10 +332,10 @@ class TransactionIOBalancing(unittest.TestCase):
# missing change added to consume the amount # missing change added to consume the amount
self.assertEqual([9.98], self.outputs(tx)) self.assertEqual([9.98], self.outputs(tx))
yield self.ledger.release_outputs(utxos) await self.ledger.release_outputs(utxos)
# liquidating at a loss, requires adding extra inputs # liquidating at a loss, requires adding extra inputs
tx = yield self.tx( tx = await self.tx(
[self.txi(self.txo(0.01))], # inputs [self.txi(self.txo(0.01))], # inputs
[] # outputs [] # outputs
) )

View file

@ -1,4 +1,4 @@
from twisted.trial import unittest import unittest
from torba.util import ArithUint256 from torba.util import ArithUint256

View file

@ -1,5 +1,5 @@
import unittest
import tempfile import tempfile
from twisted.trial import unittest
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
from torba.coin.bitcoincash import MainNetLedger as BCHLedger from torba.coin.bitcoincash import MainNetLedger as BCHLedger

View file

@ -1,7 +1,6 @@
import random import random
import typing import typing
from typing import List, Dict, Tuple, Type, Optional, Any from typing import List, Dict, Tuple, Type, Optional, Any
from twisted.internet import defer
from torba.mnemonic import Mnemonic from torba.mnemonic import Mnemonic
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
@ -44,12 +43,8 @@ class AddressManager:
def to_dict_instance(self) -> Optional[dict]: def to_dict_instance(self) -> Optional[dict]:
raise NotImplementedError raise NotImplementedError
@property
def db(self):
return self.account.ledger.db
def _query_addresses(self, **constraints): def _query_addresses(self, **constraints):
return self.db.get_addresses( return self.account.ledger.db.get_addresses(
account=self.account, account=self.account,
chain=self.chain_number, chain=self.chain_number,
**constraints **constraints
@ -58,26 +53,24 @@ class AddressManager:
def get_private_key(self, index: int) -> PrivateKey: def get_private_key(self, index: int) -> PrivateKey:
raise NotImplementedError raise NotImplementedError
def get_max_gap(self) -> defer.Deferred: async def get_max_gap(self):
raise NotImplementedError raise NotImplementedError
def ensure_address_gap(self) -> defer.Deferred: async def ensure_address_gap(self):
raise NotImplementedError raise NotImplementedError
def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred: def get_address_records(self, only_usable: bool = False, **constraints):
raise NotImplementedError raise NotImplementedError
@defer.inlineCallbacks async def get_addresses(self, only_usable: bool = False, **constraints):
def get_addresses(self, only_usable: bool = False, **constraints) -> defer.Deferred: records = await self.get_address_records(only_usable=only_usable, **constraints)
records = yield self.get_address_records(only_usable=only_usable, **constraints)
return [r['address'] for r in records] return [r['address'] for r in records]
@defer.inlineCallbacks async def get_or_create_usable_address(self):
def get_or_create_usable_address(self) -> defer.Deferred: addresses = await self.get_addresses(only_usable=True, limit=10)
addresses = yield self.get_addresses(only_usable=True, limit=10)
if addresses: if addresses:
return random.choice(addresses) return random.choice(addresses)
addresses = yield self.ensure_address_gap() addresses = await self.ensure_address_gap()
return addresses[0] return addresses[0]
@ -106,22 +99,20 @@ class HierarchicalDeterministic(AddressManager):
def get_private_key(self, index: int) -> PrivateKey: def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key.child(self.chain_number).child(index) return self.account.private_key.child(self.chain_number).child(index)
@defer.inlineCallbacks async def generate_keys(self, start: int, end: int):
def generate_keys(self, start: int, end: int) -> defer.Deferred:
keys_batch, final_keys = [], [] keys_batch, final_keys = [], []
for index in range(start, end+1): for index in range(start, end+1):
keys_batch.append((index, self.public_key.child(index))) keys_batch.append((index, self.public_key.child(index)))
if index % 180 == 0 or index == end: if index % 180 == 0 or index == end:
yield self.db.add_keys( await self.account.ledger.db.add_keys(
self.account, self.chain_number, keys_batch self.account, self.chain_number, keys_batch
) )
final_keys.extend(keys_batch) final_keys.extend(keys_batch)
keys_batch.clear() keys_batch.clear()
return [key[1].address for key in final_keys] return [key[1].address for key in final_keys]
@defer.inlineCallbacks async def get_max_gap(self):
def get_max_gap(self) -> defer.Deferred: addresses = await self._query_addresses(order_by="position ASC")
addresses = yield self._query_addresses(order_by="position ASC")
max_gap = 0 max_gap = 0
current_gap = 0 current_gap = 0
for address in addresses: for address in addresses:
@ -132,9 +123,8 @@ class HierarchicalDeterministic(AddressManager):
current_gap = 0 current_gap = 0
return max_gap return max_gap
@defer.inlineCallbacks async def ensure_address_gap(self):
def ensure_address_gap(self) -> defer.Deferred: addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
addresses = yield self._query_addresses(limit=self.gap, order_by="position DESC")
existing_gap = 0 existing_gap = 0
for address in addresses: for address in addresses:
@ -148,7 +138,7 @@ class HierarchicalDeterministic(AddressManager):
start = addresses[0]['position']+1 if addresses else 0 start = addresses[0]['position']+1 if addresses else 0
end = start + (self.gap - existing_gap) end = start + (self.gap - existing_gap)
new_keys = yield self.generate_keys(start, end-1) new_keys = await self.generate_keys(start, end-1)
return new_keys return new_keys
def get_address_records(self, only_usable: bool = False, **constraints): def get_address_records(self, only_usable: bool = False, **constraints):
@ -176,20 +166,19 @@ class SingleKey(AddressManager):
def get_private_key(self, index: int) -> PrivateKey: def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key return self.account.private_key
def get_max_gap(self) -> defer.Deferred: async def get_max_gap(self):
return defer.succeed(0) return 0
@defer.inlineCallbacks async def ensure_address_gap(self):
def ensure_address_gap(self) -> defer.Deferred: exists = await self.get_address_records()
exists = yield self.get_address_records()
if not exists: if not exists:
yield self.db.add_keys( await self.account.ledger.db.add_keys(
self.account, self.chain_number, [(0, self.public_key)] self.account, self.chain_number, [(0, self.public_key)]
) )
return [self.public_key.address] return [self.public_key.address]
return [] return []
def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred: def get_address_records(self, only_usable: bool = False, **constraints):
return self._query_addresses(**constraints) return self._query_addresses(**constraints)
@ -289,9 +278,8 @@ class BaseAccount:
'address_generator': self.address_generator.to_dict(self.receiving, self.change) 'address_generator': self.address_generator.to_dict(self.receiving, self.change)
} }
@defer.inlineCallbacks async def get_details(self, show_seed=False, **kwargs):
def get_details(self, show_seed=False, **kwargs): satoshis = await self.get_balance(**kwargs)
satoshis = yield self.get_balance(**kwargs)
details = { details = {
'id': self.id, 'id': self.id,
'name': self.name, 'name': self.name,
@ -325,23 +313,21 @@ class BaseAccount:
self.password = None self.password = None
self.encrypted = True self.encrypted = True
@defer.inlineCallbacks async def ensure_address_gap(self):
def ensure_address_gap(self):
addresses = [] addresses = []
for address_manager in self.address_managers: for address_manager in self.address_managers:
new_addresses = yield address_manager.ensure_address_gap() new_addresses = await address_manager.ensure_address_gap()
addresses.extend(new_addresses) addresses.extend(new_addresses)
return addresses return addresses
@defer.inlineCallbacks async def get_addresses(self, **constraints):
def get_addresses(self, **constraints) -> defer.Deferred: rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
rows = yield self.ledger.db.select_addresses('address', account=self, **constraints)
return [r[0] for r in rows] return [r[0] for r in rows]
def get_address_records(self, **constraints) -> defer.Deferred: def get_address_records(self, **constraints):
return self.ledger.db.get_addresses(account=self, **constraints) return self.ledger.db.get_addresses(account=self, **constraints)
def get_address_count(self, **constraints) -> defer.Deferred: def get_address_count(self, **constraints):
return self.ledger.db.get_address_count(account=self, **constraints) return self.ledger.db.get_address_count(account=self, **constraints)
def get_private_key(self, chain: int, index: int) -> PrivateKey: def get_private_key(self, chain: int, index: int) -> PrivateKey:
@ -355,10 +341,9 @@ class BaseAccount:
constraints.update({'height__lte': height, 'height__gt': 0}) constraints.update({'height__lte': height, 'height__gt': 0})
return self.ledger.db.get_balance(account=self, **constraints) return self.ledger.db.get_balance(account=self, **constraints)
@defer.inlineCallbacks async def get_max_gap(self):
def get_max_gap(self): change_gap = await self.change.get_max_gap()
change_gap = yield self.change.get_max_gap() receiving_gap = await self.receiving.get_max_gap()
receiving_gap = yield self.receiving.get_max_gap()
return { return {
'max_change_gap': change_gap, 'max_change_gap': change_gap,
'max_receiving_gap': receiving_gap, 'max_receiving_gap': receiving_gap,
@ -376,24 +361,23 @@ class BaseAccount:
def get_transaction_count(self, **constraints): def get_transaction_count(self, **constraints):
return self.ledger.db.get_transaction_count(account=self, **constraints) return self.ledger.db.get_transaction_count(account=self, **constraints)
@defer.inlineCallbacks async def fund(self, to_account, amount=None, everything=False,
def fund(self, to_account, amount=None, everything=False,
outputs=1, broadcast=False, **constraints): outputs=1, broadcast=False, **constraints):
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.' assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
tx_class = self.ledger.transaction_class tx_class = self.ledger.transaction_class
if everything: if everything:
utxos = yield self.get_utxos(**constraints) utxos = await self.get_utxos(**constraints)
yield self.ledger.reserve_outputs(utxos) await self.ledger.reserve_outputs(utxos)
tx = yield tx_class.create( tx = await tx_class.create(
inputs=[tx_class.input_class.spend(txo) for txo in utxos], inputs=[tx_class.input_class.spend(txo) for txo in utxos],
outputs=[], outputs=[],
funding_accounts=[self], funding_accounts=[self],
change_account=to_account change_account=to_account
) )
elif amount > 0: elif amount > 0:
to_address = yield to_account.change.get_or_create_usable_address() to_address = await to_account.change.get_or_create_usable_address()
to_hash160 = to_account.ledger.address_to_hash160(to_address) to_hash160 = to_account.ledger.address_to_hash160(to_address)
tx = yield tx_class.create( tx = await tx_class.create(
inputs=[], inputs=[],
outputs=[ outputs=[
tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160) tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
@ -406,9 +390,9 @@ class BaseAccount:
raise ValueError('An amount is required.') raise ValueError('An amount is required.')
if broadcast: if broadcast:
yield self.ledger.broadcast(tx) await self.ledger.broadcast(tx)
else: else:
yield self.ledger.release_outputs( await self.ledger.release_outputs(
[txi.txo_ref.txo for txi in tx.inputs] [txi.txo_ref.txo for txi in tx.inputs]
) )

View file

@ -1,9 +1,8 @@
import logging import logging
from typing import Tuple, List, Sequence from typing import Tuple, List
import sqlite3 import sqlite3
from twisted.internet import defer import aiosqlite
from twisted.enterprise import adbapi
from torba.hash import TXRefImmutable from torba.hash import TXRefImmutable
from torba.basetransaction import BaseTransaction from torba.basetransaction import BaseTransaction
@ -107,25 +106,21 @@ def row_dict_or_default(rows, fields, default=None):
class SQLiteMixin: class SQLiteMixin:
CREATE_TABLES_QUERY: Sequence[str] = () CREATE_TABLES_QUERY: str
def __init__(self, path): def __init__(self, path):
self._db_path = path self._db_path = path
self.db: adbapi.ConnectionPool = None self.db: aiosqlite.Connection = None
self.ledger = None self.ledger = None
def open(self): async def open(self):
log.info("connecting to database: %s", self._db_path) log.info("connecting to database: %s", self._db_path)
self.db = adbapi.ConnectionPool( self.db = aiosqlite.connect(self._db_path)
'sqlite3', self._db_path, cp_min=1, cp_max=1, check_same_thread=False await self.db.__aenter__()
) await self.db.executescript(self.CREATE_TABLES_QUERY)
return self.db.runInteraction(
lambda t: t.executescript(self.CREATE_TABLES_QUERY)
)
def close(self): async def close(self):
self.db.close() await self.db.close()
return defer.succeed(True)
@staticmethod @staticmethod
def _insert_sql(table: str, data: dict) -> Tuple[str, List]: def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
@ -247,78 +242,75 @@ class BaseDatabase(SQLiteMixin):
'script': sqlite3.Binary(txo.script.source) 'script': sqlite3.Binary(txo.script.source)
} }
def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history): async def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
def _steps(t): if save_tx == 'insert':
if save_tx == 'insert': await self.db.execute(*self._insert_sql('tx', {
self.execute(t, *self._insert_sql('tx', { 'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified
}))
elif save_tx == 'update':
await self.db.execute(*self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
existing_txos = [r[0] for r in await self.db.execute_fetchall(*query(
"SELECT position FROM txo", txid=tx.id
))]
for txo in tx.outputs:
if txo.position in existing_txos:
continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
await self.db.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
print('Database.save_transaction_io: pay script hash is not implemented!')
# lookup the address associated with each TXI (via its TXO)
txoid_to_address = {r[0]: r[1] for r in await self.db.execute_fetchall(*query(
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
))}
# list of TXIs that have already been added
existing_txis = [r[0] for r in await self.db.execute_fetchall(*query(
"SELECT txoid FROM txi", txid=tx.id
))]
for txi in tx.inputs:
txoid = txi.txo_ref.id
new_txi = txoid not in existing_txis
address_matches = txoid_to_address.get(txoid) == address
if new_txi and address_matches:
await self.db.execute(*self._insert_sql("txi", {
'txid': tx.id, 'txid': tx.id,
'raw': sqlite3.Binary(tx.raw), 'txoid': txoid,
'height': tx.height, 'address': address,
'position': tx.position,
'is_verified': tx.is_verified
})) }))
elif save_tx == 'update':
self.execute(t, *self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
existing_txos = [r[0] for r in self.execute(t, *query( await self._set_address_history(address, history)
"SELECT position FROM txo", txid=tx.id
)).fetchall()]
for txo in tx.outputs: async def reserve_outputs(self, txos, is_reserved=True):
if txo.position in existing_txos:
continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
print('Database.save_transaction_io: pay script hash is not implemented!')
# lookup the address associated with each TXI (via its TXO)
txoid_to_address = {r[0]: r[1] for r in self.execute(t, *query(
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
)).fetchall()}
# list of TXIs that have already been added
existing_txis = [r[0] for r in self.execute(t, *query(
"SELECT txoid FROM txi", txid=tx.id
)).fetchall()]
for txi in tx.inputs:
txoid = txi.txo_ref.id
new_txi = txoid not in existing_txis
address_matches = txoid_to_address.get(txoid) == address
if new_txi and address_matches:
self.execute(t, *self._insert_sql("txi", {
'txid': tx.id,
'txoid': txoid,
'address': address,
}))
self._set_address_history(t, address, history)
return self.db.runInteraction(_steps)
def reserve_outputs(self, txos, is_reserved=True):
txoids = [txo.id for txo in txos] txoids = [txo.id for txo in txos]
return self.run_operation( await self.db.execute(
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
', '.join(['?']*len(txoids)) ', '.join(['?']*len(txoids))
), [is_reserved]+txoids ), [is_reserved]+txoids
) )
def release_outputs(self, txos): async def release_outputs(self, txos):
return self.reserve_outputs(txos, is_reserved=False) await self.reserve_outputs(txos, is_reserved=False)
def rewind_blockchain(self, above_height): # pylint: disable=no-self-use async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
# TODO: # TODO:
# 1. delete transactions above_height # 1. delete transactions above_height
# 2. update address histories removing deleted TXs # 2. update address histories removing deleted TXs
return defer.succeed(True) return True
def select_transactions(self, cols, account=None, **constraints): async def select_transactions(self, cols, account=None, **constraints):
if 'txid' not in constraints and account is not None: if 'txid' not in constraints and account is not None:
constraints['$account'] = account.public_key.address constraints['$account'] = account.public_key.address
constraints['txid__in'] = """ constraints['txid__in'] = """
@ -328,13 +320,14 @@ class BaseDatabase(SQLiteMixin):
SELECT txi.txid FROM txi SELECT txi.txid FROM txi
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account
""" """
return self.run_query(*query("SELECT {} FROM tx".format(cols), **constraints)) return await self.db.execute_fetchall(
*query("SELECT {} FROM tx".format(cols), **constraints)
)
@defer.inlineCallbacks async def get_transactions(self, my_account=None, **constraints):
def get_transactions(self, my_account=None, **constraints):
my_account = my_account or constraints.get('account', None) my_account = my_account or constraints.get('account', None)
tx_rows = yield self.select_transactions( tx_rows = await self.select_transactions(
'txid, raw, height, position, is_verified', 'txid, raw, height, position, is_verified',
order_by=["height DESC", "position DESC"], order_by=["height DESC", "position DESC"],
**constraints **constraints
@ -352,7 +345,7 @@ class BaseDatabase(SQLiteMixin):
annotated_txos = { annotated_txos = {
txo.id: txo for txo in txo.id: txo for txo in
(yield self.get_txos( (await self.get_txos(
my_account=my_account, my_account=my_account,
txid__in=txids txid__in=txids
)) ))
@ -360,7 +353,7 @@ class BaseDatabase(SQLiteMixin):
referenced_txos = { referenced_txos = {
txo.id: txo for txo in txo.id: txo for txo in
(yield self.get_txos( (await self.get_txos(
my_account=my_account, my_account=my_account,
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0] txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0]
)) ))
@ -380,33 +373,30 @@ class BaseDatabase(SQLiteMixin):
return txs return txs
@defer.inlineCallbacks async def get_transaction_count(self, **constraints):
def get_transaction_count(self, **constraints):
constraints.pop('offset', None) constraints.pop('offset', None)
constraints.pop('limit', None) constraints.pop('limit', None)
constraints.pop('order_by', None) constraints.pop('order_by', None)
count = yield self.select_transactions('count(*)', **constraints) count = await self.select_transactions('count(*)', **constraints)
return count[0][0] return count[0][0]
@defer.inlineCallbacks async def get_transaction(self, **constraints):
def get_transaction(self, **constraints): txs = await self.get_transactions(limit=1, **constraints)
txs = yield self.get_transactions(limit=1, **constraints)
if txs: if txs:
return txs[0] return txs[0]
def select_txos(self, cols, **constraints): async def select_txos(self, cols, **constraints):
return self.run_query(*query( return await self.db.execute_fetchall(*query(
"SELECT {} FROM txo" "SELECT {} FROM txo"
" JOIN pubkey_address USING (address)" " JOIN pubkey_address USING (address)"
" JOIN tx USING (txid)".format(cols), **constraints " JOIN tx USING (txid)".format(cols), **constraints
)) ))
@defer.inlineCallbacks async def get_txos(self, my_account=None, **constraints):
def get_txos(self, my_account=None, **constraints):
my_account = my_account or constraints.get('account', None) my_account = my_account or constraints.get('account', None)
if isinstance(my_account, BaseAccount): if isinstance(my_account, BaseAccount):
my_account = my_account.public_key.address my_account = my_account.public_key.address
rows = yield self.select_txos( rows = await self.select_txos(
"amount, script, txid, tx.height, txo.position, chain, account", **constraints "amount, script, txid, tx.height, txo.position, chain, account", **constraints
) )
output_class = self.ledger.transaction_class.output_class output_class = self.ledger.transaction_class.output_class
@ -421,12 +411,11 @@ class BaseDatabase(SQLiteMixin):
) for row in rows ) for row in rows
] ]
@defer.inlineCallbacks async def get_txo_count(self, **constraints):
def get_txo_count(self, **constraints):
constraints.pop('offset', None) constraints.pop('offset', None)
constraints.pop('limit', None) constraints.pop('limit', None)
constraints.pop('order_by', None) constraints.pop('order_by', None)
count = yield self.select_txos('count(*)', **constraints) count = await self.select_txos('count(*)', **constraints)
return count[0][0] return count[0][0]
@staticmethod @staticmethod
@ -442,37 +431,33 @@ class BaseDatabase(SQLiteMixin):
self.constrain_utxo(constraints) self.constrain_utxo(constraints)
return self.get_txo_count(**constraints) return self.get_txo_count(**constraints)
@defer.inlineCallbacks async def get_balance(self, **constraints):
def get_balance(self, **constraints):
self.constrain_utxo(constraints) self.constrain_utxo(constraints)
balance = yield self.select_txos('SUM(amount)', **constraints) balance = await self.select_txos('SUM(amount)', **constraints)
return balance[0][0] or 0 return balance[0][0] or 0
def select_addresses(self, cols, **constraints): async def select_addresses(self, cols, **constraints):
return self.run_query(*query( return await self.db.execute_fetchall(*query(
"SELECT {} FROM pubkey_address".format(cols), **constraints "SELECT {} FROM pubkey_address".format(cols), **constraints
)) ))
@defer.inlineCallbacks async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints):
def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints): addresses = await self.select_addresses(', '.join(cols), **constraints)
addresses = yield self.select_addresses(', '.join(cols), **constraints)
return rows_to_dict(addresses, cols) return rows_to_dict(addresses, cols)
@defer.inlineCallbacks async def get_address_count(self, **constraints):
def get_address_count(self, **constraints): count = await self.select_addresses('count(*)', **constraints)
count = yield self.select_addresses('count(*)', **constraints)
return count[0][0] return count[0][0]
@defer.inlineCallbacks async def get_address(self, **constraints):
def get_address(self, **constraints): addresses = await self.get_addresses(
addresses = yield self.get_addresses(
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'), cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
limit=1, **constraints limit=1, **constraints
) )
if addresses: if addresses:
return addresses[0] return addresses[0]
def add_keys(self, account, chain, keys): async def add_keys(self, account, chain, keys):
sql = ( sql = (
"insert into pubkey_address " "insert into pubkey_address "
"(address, account, chain, position, pubkey) " "(address, account, chain, position, pubkey) "
@ -484,14 +469,13 @@ class BaseDatabase(SQLiteMixin):
pubkey.address, account.public_key.address, chain, position, pubkey.address, account.public_key.address, chain, position,
sqlite3.Binary(pubkey.pubkey_bytes) sqlite3.Binary(pubkey.pubkey_bytes)
)) ))
return self.run_operation(sql, values) await self.db.execute(sql, values)
@classmethod async def _set_address_history(self, address, history):
def _set_address_history(cls, t, address, history): await self.db.execute(
cls.execute( "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address) (history, history.count(':')//2, address)
) )
def set_address_history(self, address, history): async def set_address_history(self, address, history):
return self.db.runInteraction(lambda t: self._set_address_history(t, address, history)) await self._set_address_history(address, history)

View file

@ -1,11 +1,10 @@
import os import os
import asyncio
import logging import logging
from io import BytesIO from io import BytesIO
from typing import Optional, Iterator, Tuple from typing import Optional, Iterator, Tuple
from binascii import hexlify from binascii import hexlify
from twisted.internet import threads, defer
from torba.util import ArithUint256 from torba.util import ArithUint256
from torba.hash import double_sha256 from torba.hash import double_sha256
@ -36,16 +35,14 @@ class BaseHeaders:
self.io = BytesIO() self.io = BytesIO()
self.path = path self.path = path
self._size: Optional[int] = None self._size: Optional[int] = None
self._header_connect_lock = defer.DeferredLock() self._header_connect_lock = asyncio.Lock()
def open(self): async def open(self):
if self.path != ':memory:': if self.path != ':memory:':
self.io = open(self.path, 'a+b') self.io = open(self.path, 'a+b')
return defer.succeed(True)
def close(self): async def close(self):
self.io.close() self.io.close()
return defer.succeed(True)
@staticmethod @staticmethod
def serialize(header: dict) -> bytes: def serialize(header: dict) -> bytes:
@ -95,16 +92,15 @@ class BaseHeaders:
return b'0' * 64 return b'0' * 64
return hexlify(double_sha256(header)[::-1]) return hexlify(double_sha256(header)[::-1])
@defer.inlineCallbacks async def connect(self, start: int, headers: bytes) -> int:
def connect(self, start: int, headers: bytes):
added = 0 added = 0
bail = False bail = False
yield self._header_connect_lock.acquire() loop = asyncio.get_running_loop()
try: async with self._header_connect_lock:
for height, chunk in self._iterate_chunks(start, headers): for height, chunk in self._iterate_chunks(start, headers):
try: try:
# validate_chunk() is CPU bound and reads previous chunks from file system # validate_chunk() is CPU bound and reads previous chunks from file system
yield threads.deferToThread(self.validate_chunk, height, chunk) await loop.run_in_executor(None, self.validate_chunk, height, chunk)
except InvalidHeader as e: except InvalidHeader as e:
bail = True bail = True
chunk = chunk[:(height-e.height+1)*self.header_size] chunk = chunk[:(height-e.height+1)*self.header_size]
@ -115,14 +111,12 @@ class BaseHeaders:
self.io.truncate() self.io.truncate()
# .seek()/.write()/.truncate() might also .flush() when needed # .seek()/.write()/.truncate() might also .flush() when needed
# the goal here is mainly to ensure we're definitely flush()'ing # the goal here is mainly to ensure we're definitely flush()'ing
yield threads.deferToThread(self.io.flush) await loop.run_in_executor(None, self.io.flush)
self._size = None self._size = None
added += written added += written
if bail: if bail:
break break
finally: return added
self._header_connect_lock.release()
defer.returnValue(added)
def validate_chunk(self, height, chunk): def validate_chunk(self, height, chunk):
previous_hash, previous_header, previous_previous_header = None, None, None previous_hash, previous_header, previous_previous_header = None, None, None

View file

@ -1,4 +1,5 @@
import os import os
import asyncio
import logging import logging
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from io import StringIO from io import StringIO
@ -7,8 +8,6 @@ from typing import Dict, Type, Iterable
from operator import itemgetter from operator import itemgetter
from collections import namedtuple from collections import namedtuple
from twisted.internet import defer
from torba import baseaccount from torba import baseaccount
from torba import basenetwork from torba import basenetwork
from torba import basetransaction from torba import basetransaction
@ -104,8 +103,8 @@ class BaseLedger(metaclass=LedgerRegistry):
) )
self._transaction_processing_locks = {} self._transaction_processing_locks = {}
self._utxo_reservation_lock = defer.DeferredLock() self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = defer.DeferredLock() self._header_processing_lock = asyncio.Lock()
@classmethod @classmethod
def get_id(cls): def get_id(cls):
@ -135,41 +134,32 @@ class BaseLedger(metaclass=LedgerRegistry):
def add_account(self, account: baseaccount.BaseAccount): def add_account(self, account: baseaccount.BaseAccount):
self.accounts.append(account) self.accounts.append(account)
@defer.inlineCallbacks async def get_private_key_for_address(self, address):
def get_private_key_for_address(self, address): match = await self.db.get_address(address=address)
match = yield self.db.get_address(address=address)
if match: if match:
for account in self.accounts: for account in self.accounts:
if match['account'] == account.public_key.address: if match['account'] == account.public_key.address:
return account.get_private_key(match['chain'], match['position']) return account.get_private_key(match['chain'], match['position'])
@defer.inlineCallbacks async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
estimators = [] estimators = []
for account in funding_accounts: for account in funding_accounts:
utxos = yield account.get_utxos() utxos = await account.get_utxos()
for utxo in utxos: for utxo in utxos:
estimators.append(utxo.get_estimator(self)) estimators.append(utxo.get_estimator(self))
return estimators return estimators
@defer.inlineCallbacks async def get_spendable_utxos(self, amount: int, funding_accounts):
def get_spendable_utxos(self, amount: int, funding_accounts): async with self._utxo_reservation_lock:
yield self._utxo_reservation_lock.acquire() txos = await self.get_effective_amount_estimators(funding_accounts)
try:
txos = yield self.get_effective_amount_estimators(funding_accounts)
selector = CoinSelector( selector = CoinSelector(
txos, amount, txos, amount,
self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self) self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
) )
spendables = selector.select() spendables = selector.select()
if spendables: if spendables:
yield self.reserve_outputs(s.txo for s in spendables) await self.reserve_outputs(s.txo for s in spendables)
except Exception: return spendables
log.exception('Failed to get spendable utxos:')
raise
finally:
self._utxo_reservation_lock.release()
return spendables
def reserve_outputs(self, txos): def reserve_outputs(self, txos):
return self.db.reserve_outputs(txos) return self.db.reserve_outputs(txos)
@ -177,16 +167,14 @@ class BaseLedger(metaclass=LedgerRegistry):
def release_outputs(self, txos): def release_outputs(self, txos):
return self.db.release_outputs(txos) return self.db.release_outputs(txos)
@defer.inlineCallbacks async def get_local_status(self, address):
def get_local_status(self, address): address_details = await self.db.get_address(address=address)
address_details = yield self.db.get_address(address=address)
history = address_details['history'] or '' history = address_details['history'] or ''
h = sha256(history.encode()) h = sha256(history.encode())
return hexlify(h) return hexlify(h)
@defer.inlineCallbacks async def get_local_history(self, address):
def get_local_history(self, address): address_details = await self.db.get_address(address=address)
address_details = yield self.db.get_address(address=address)
history = address_details['history'] or '' history = address_details['history'] or ''
parts = history.split(':')[:-1] parts = history.split(':')[:-1]
return list(zip(parts[0::2], map(int, parts[1::2]))) return list(zip(parts[0::2], map(int, parts[1::2])))
@ -203,43 +191,40 @@ class BaseLedger(metaclass=LedgerRegistry):
working_branch = double_sha256(combined) working_branch = double_sha256(combined)
return hexlify(working_branch[::-1]) return hexlify(working_branch[::-1])
def validate_transaction_and_set_position(self, tx, height, merkle): async def validate_transaction_and_set_position(self, tx, height):
if not height <= len(self.headers): if not height <= len(self.headers):
return False return False
merkle = await self.network.get_merkle(tx.id, height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height] header = self.headers[height]
tx.position = merkle['pos'] tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root'] tx.is_verified = merkle_root == header['merkle_root']
@defer.inlineCallbacks async def start(self):
def start(self):
if not os.path.exists(self.path): if not os.path.exists(self.path):
os.mkdir(self.path) os.mkdir(self.path)
yield defer.gatherResults([ await asyncio.gather(
self.db.open(), self.db.open(),
self.headers.open() self.headers.open()
]) )
first_connection = self.network.on_connected.first first_connection = self.network.on_connected.first
self.network.start() asyncio.ensure_future(self.network.start())
yield first_connection await first_connection
yield self.join_network() await self.join_network()
self.network.on_connected.listen(self.join_network) self.network.on_connected.listen(self.join_network)
@defer.inlineCallbacks async def join_network(self, *args):
def join_network(self, *args):
log.info("Subscribing and updating accounts.") log.info("Subscribing and updating accounts.")
yield self.update_headers() await self.update_headers()
yield self.network.subscribe_headers() await self.network.subscribe_headers()
yield self.update_accounts() await self.update_accounts()
@defer.inlineCallbacks async def stop(self):
def stop(self): await self.network.stop()
yield self.network.stop() await self.db.close()
yield self.db.close() await self.headers.close()
yield self.headers.close()
@defer.inlineCallbacks async def update_headers(self, height=None, headers=None, subscription_update=False):
def update_headers(self, height=None, headers=None, subscription_update=False):
rewound = 0 rewound = 0
while True: while True:
@ -251,14 +236,14 @@ class BaseLedger(metaclass=LedgerRegistry):
subscription_update = False subscription_update = False
if not headers: if not headers:
header_response = yield self.network.get_headers(height, 2001) header_response = await self.network.get_headers(height, 2001)
headers = header_response['hex'] headers = header_response['hex']
if not headers: if not headers:
# Nothing to do, network thinks we're already at the latest height. # Nothing to do, network thinks we're already at the latest height.
return return
added = yield self.headers.connect(height, unhexlify(headers)) added = await self.headers.connect(height, unhexlify(headers))
if added > 0: if added > 0:
height += added height += added
self._on_header_controller.add( self._on_header_controller.add(
@ -268,7 +253,7 @@ class BaseLedger(metaclass=LedgerRegistry):
# we started rewinding blocks and apparently found # we started rewinding blocks and apparently found
# a new chain # a new chain
rewound = 0 rewound = 0
yield self.db.rewind_blockchain(height) await self.db.rewind_blockchain(height)
if subscription_update: if subscription_update:
# subscription updates are for latest header already # subscription updates are for latest header already
@ -310,66 +295,37 @@ class BaseLedger(metaclass=LedgerRegistry):
# robust sync, turn off subscription update shortcut # robust sync, turn off subscription update shortcut
subscription_update = False subscription_update = False
@defer.inlineCallbacks async def receive_header(self, response):
def receive_header(self, response): async with self._header_processing_lock:
yield self._header_processing_lock.acquire()
try:
header = response[0] header = response[0]
yield self.update_headers( await self.update_headers(
height=header['height'], headers=header['hex'], subscription_update=True height=header['height'], headers=header['hex'], subscription_update=True
) )
finally:
self._header_processing_lock.release()
def update_accounts(self): async def update_accounts(self):
return defer.DeferredList([ return await asyncio.gather(*(
self.update_account(a) for a in self.accounts self.update_account(a) for a in self.accounts
]) ))
@defer.inlineCallbacks async def update_account(self, account: baseaccount.BaseAccount):
def update_account(self, account): # type: (baseaccount.BaseAccount) -> defer.Defferred
# Before subscribing, download history for any addresses that don't have any, # Before subscribing, download history for any addresses that don't have any,
# this avoids situation where we're getting status updates to addresses we know # this avoids situation where we're getting status updates to addresses we know
# need to update anyways. Continue to get history and create more addresses until # need to update anyways. Continue to get history and create more addresses until
# all missing addresses are created and history for them is fully restored. # all missing addresses are created and history for them is fully restored.
yield account.ensure_address_gap() await account.ensure_address_gap()
addresses = yield account.get_addresses(used_times=0) addresses = await account.get_addresses(used_times=0)
while addresses: while addresses:
yield defer.DeferredList([ await asyncio.gather(*(self.update_history(a) for a in addresses))
self.update_history(a) for a in addresses addresses = await account.ensure_address_gap()
])
addresses = yield account.ensure_address_gap()
# By this point all of the addresses should be restored and we # By this point all of the addresses should be restored and we
# can now subscribe all of them to receive updates. # can now subscribe all of them to receive updates.
all_addresses = yield account.get_addresses() all_addresses = await account.get_addresses()
yield defer.DeferredList( await asyncio.gather(*(self.subscribe_history(a) for a in all_addresses))
list(map(self.subscribe_history, all_addresses))
)
@defer.inlineCallbacks async def update_history(self, address):
def _prefetch_history(self, remote_history, local_history): remote_history = await self.network.get_history(address)
proofs, network_txs, deferreds = {}, {}, [] local_history = await self.get_local_history(address)
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
continue
if remote_height > 0:
deferreds.append(
self.network.get_merkle(hex_id, remote_height).addBoth(
lambda result, txid: proofs.__setitem__(txid, result), hex_id)
)
deferreds.append(
self.network.get_transaction(hex_id).addBoth(
lambda result, txid: network_txs.__setitem__(txid, result), hex_id)
)
yield defer.DeferredList(deferreds)
return proofs, network_txs
@defer.inlineCallbacks
def update_history(self, address):
remote_history = yield self.network.get_history(address)
local_history = yield self.get_local_history(address)
proofs, network_txs = yield self._prefetch_history(remote_history, local_history)
synced_history = StringIO() synced_history = StringIO()
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
@ -379,30 +335,29 @@ class BaseLedger(metaclass=LedgerRegistry):
if i < len(local_history) and local_history[i] == (hex_id, remote_height): if i < len(local_history) and local_history[i] == (hex_id, remote_height):
continue continue
lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock()) lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock())
yield lock.acquire() await lock.acquire()
try: try:
# see if we have a local copy of transaction, otherwise fetch it from server # see if we have a local copy of transaction, otherwise fetch it from server
tx = yield self.db.get_transaction(txid=hex_id) tx = await self.db.get_transaction(txid=hex_id)
save_tx = None save_tx = None
if tx is None: if tx is None:
_raw = network_txs[hex_id] _raw = await self.network.get_transaction(hex_id)
tx = self.transaction_class(unhexlify(_raw)) tx = self.transaction_class(unhexlify(_raw))
save_tx = 'insert' save_tx = 'insert'
tx.height = remote_height tx.height = remote_height
if remote_height > 0 and (not tx.is_verified or tx.position == -1): if remote_height > 0 and (not tx.is_verified or tx.position == -1):
self.validate_transaction_and_set_position(tx, remote_height, proofs[hex_id]) await self.validate_transaction_and_set_position(tx, remote_height)
if save_tx is None: if save_tx is None:
save_tx = 'update' save_tx = 'update'
yield self.db.save_transaction_io( await self.db.save_transaction_io(
save_tx, tx, address, self.address_to_hash160(address), save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue()
synced_history.getvalue()
) )
log.debug( log.debug(
@ -412,28 +367,22 @@ class BaseLedger(metaclass=LedgerRegistry):
self._on_transaction_controller.add(TransactionEvent(address, tx)) self._on_transaction_controller.add(TransactionEvent(address, tx))
except Exception:
log.exception('Failed to synchronize transaction:')
raise
finally: finally:
lock.release() lock.release()
if not lock.locked and hex_id in self._transaction_processing_locks: if not lock.locked() and hex_id in self._transaction_processing_locks:
del self._transaction_processing_locks[hex_id] del self._transaction_processing_locks[hex_id]
@defer.inlineCallbacks async def subscribe_history(self, address):
def subscribe_history(self, address): remote_status = await self.network.subscribe_address(address)
remote_status = yield self.network.subscribe_address(address) local_status = await self.get_local_status(address)
local_status = yield self.get_local_status(address)
if local_status != remote_status: if local_status != remote_status:
yield self.update_history(address) await self.update_history(address)
@defer.inlineCallbacks async def receive_status(self, response):
def receive_status(self, response):
address, remote_status = response address, remote_status = response
local_status = yield self.get_local_status(address) local_status = await self.get_local_status(address)
if local_status != remote_status: if local_status != remote_status:
yield self.update_history(address) await self.update_history(address)
def broadcast(self, tx): def broadcast(self, tx):
return self.network.broadcast(hexlify(tx.raw).decode()) return self.network.broadcast(hexlify(tx.raw).decode())

View file

@ -1,6 +1,6 @@
import asyncio
import logging import logging
from typing import Type, MutableSequence, MutableMapping from typing import Type, MutableSequence, MutableMapping
from twisted.internet import defer
from torba.baseledger import BaseLedger, LedgerRegistry from torba.baseledger import BaseLedger, LedgerRegistry
from torba.wallet import Wallet, WalletStorage from torba.wallet import Wallet, WalletStorage
@ -41,11 +41,10 @@ class BaseWalletManager:
self.wallets.append(wallet) self.wallets.append(wallet)
return wallet return wallet
@defer.inlineCallbacks async def get_detailed_accounts(self, confirmations=6, show_seed=False):
def get_detailed_accounts(self, confirmations=6, show_seed=False):
ledgers = {} ledgers = {}
for i, account in enumerate(self.accounts): for i, account in enumerate(self.accounts):
details = yield account.get_details(confirmations=confirmations, show_seed=True) details = await account.get_details(confirmations=confirmations, show_seed=True)
details['is_default_account'] = i == 0 details['is_default_account'] = i == 0
ledger_id = account.ledger.get_id() ledger_id = account.ledger.get_id()
ledgers.setdefault(ledger_id, []) ledgers.setdefault(ledger_id, [])
@ -68,16 +67,14 @@ class BaseWalletManager:
for account in wallet.accounts: for account in wallet.accounts:
yield account yield account
@defer.inlineCallbacks async def start(self):
def start(self):
self.running = True self.running = True
yield defer.DeferredList([ await asyncio.gather(*(
l.start() for l in self.ledgers.values() l.start() for l in self.ledgers.values()
]) ))
@defer.inlineCallbacks async def stop(self):
def stop(self): await asyncio.gather(*(
yield defer.DeferredList([
l.stop() for l in self.ledgers.values() l.stop() for l in self.ledgers.values()
]) ))
self.running = False self.running = False

View file

@ -1,145 +1,36 @@
import json
import socket
import logging import logging
from itertools import cycle from itertools import cycle
from twisted.internet import defer, reactor, protocol
from twisted.application.internet import ClientService, CancelledError from aiorpcx import ClientSession as BaseClientSession
from twisted.internet.endpoints import clientFromString
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python import failure
from torba import __version__ from torba import __version__
from torba.stream import StreamController from torba.stream import StreamController
from torba.constants import TIMEOUT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class StratumClientProtocol(LineOnlyReceiver): class ClientSession(BaseClientSession):
delimiter = b'\n'
MAX_LENGTH = 2000000
def __init__(self): def __init__(self, *args, network, **kwargs):
self.request_id = 0
self.lookup_table = {}
self.session = {}
self.network = None
self.on_disconnected_controller = StreamController()
self.on_disconnected = self.on_disconnected_controller.stream
def _get_id(self):
self.request_id += 1
return self.request_id
@property
def _ip(self):
return self.transport.getPeer().host
def get_session(self):
return self.session
def connectionMade(self):
try:
self.transport.setTcpNoDelay(True)
self.transport.setTcpKeepAlive(True)
if hasattr(socket, "TCP_KEEPIDLE"):
self.transport.socket.setsockopt(
socket.SOL_TCP, socket.TCP_KEEPIDLE, 120
# Seconds before sending keepalive probes
)
else:
log.debug("TCP_KEEPIDLE not available")
if hasattr(socket, "TCP_KEEPINTVL"):
self.transport.socket.setsockopt(
socket.SOL_TCP, socket.TCP_KEEPINTVL, 1
# Interval in seconds between keepalive probes
)
else:
log.debug("TCP_KEEPINTVL not available")
if hasattr(socket, "TCP_KEEPCNT"):
self.transport.socket.setsockopt(
socket.SOL_TCP, socket.TCP_KEEPCNT, 5
# Failed keepalive probles before declaring other end dead
)
else:
log.debug("TCP_KEEPCNT not available")
except Exception as err: # pylint: disable=broad-except
# Supported only by the socket transport,
# but there's really no better place in code to trigger this.
log.warning("Error setting up socket: %s", err)
def connectionLost(self, reason=None):
self.connected = 0
self.on_disconnected_controller.add(True)
for deferred in self.lookup_table.values():
if not deferred.called:
deferred.errback(TimeoutError("Connection dropped."))
def lineReceived(self, line):
log.debug('received: %s', line)
try:
message = json.loads(line)
except (ValueError, TypeError):
raise ValueError("Cannot decode message '{}'".format(line.strip()))
if message.get('id'):
try:
d = self.lookup_table.pop(message['id'])
if message.get('error'):
d.errback(RuntimeError(message['error']))
else:
d.callback(message.get('result'))
except KeyError:
raise LookupError(
"Lookup for deferred object for message ID '{}' failed.".format(message['id']))
elif message.get('method') in self.network.subscription_controllers:
controller = self.network.subscription_controllers[message['method']]
controller.add(message.get('params'))
else:
log.warning("Cannot handle message '%s'", line)
def rpc(self, method, *args):
message_id = self._get_id()
message = json.dumps({
'id': message_id,
'method': method,
'params': args
})
log.debug('sent: %s', message)
self.sendLine(message.encode('latin-1'))
d = self.lookup_table[message_id] = defer.Deferred()
d.addTimeout(
TIMEOUT, reactor, onTimeoutCancel=lambda *_: failure.Failure(TimeoutError(
"Timeout: Stratum request for '%s' took more than %s seconds" % (method, TIMEOUT)))
)
return d
class StratumClientFactory(protocol.ClientFactory):
protocol = StratumClientProtocol
def __init__(self, network):
self.network = network self.network = network
self.client = None super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
def buildProtocol(self, addr): async def handle_request(self, request):
client = self.protocol() controller = self.network.subscription_controllers[request.method]
client.factory = self controller.add(request.args)
client.network = self.network
self.client = client def connection_lost(self, exc):
return client super().connection_lost(exc)
self._on_disconnect_controller.add(True)
class BaseNetwork: class BaseNetwork:
def __init__(self, ledger): def __init__(self, ledger):
self.config = ledger.config self.config = ledger.config
self.client = None self.client: ClientSession = None
self.service = None
self.running = False self.running = False
self._on_connected_controller = StreamController() self._on_connected_controller = StreamController()
@ -156,48 +47,35 @@ class BaseNetwork:
'blockchain.address.subscribe': self._on_status_controller, 'blockchain.address.subscribe': self._on_status_controller,
} }
@defer.inlineCallbacks async def start(self):
def start(self):
self.running = True self.running = True
for server in cycle(self.config['default_servers']): for server in cycle(self.config['default_servers']):
connection_string = 'tcp:{}:{}'.format(*server) connection_string = 'tcp:{}:{}'.format(*server)
endpoint = clientFromString(reactor, connection_string) self.client = ClientSession(*server, network=self)
log.debug("Attempting connection to SPV wallet server: %s", connection_string)
self.service = ClientService(endpoint, StratumClientFactory(self))
self.service.startService()
try: try:
self.client = yield self.service.whenConnected(failAfterFailures=2) await self.client.create_connection()
yield self.ensure_server_version() await self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: %s", connection_string) log.info("Successfully connected to SPV wallet server: %s", )
self._on_connected_controller.add(True) self._on_connected_controller.add(True)
yield self.client.on_disconnected.first await self.client.on_disconnected.first
except CancelledError:
return
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
log.exception("Connecting to %s raised an exception:", connection_string) log.exception("Connecting to %s raised an exception:", connection_string)
finally:
self.client = None
if self.service is not None:
self.service.stopService()
if not self.running: if not self.running:
return return
def stop(self): async def stop(self):
self.running = False self.running = False
if self.service is not None:
self.service.stopService()
if self.is_connected: if self.is_connected:
return self.client.on_disconnected.first await self.client.close()
else: await self.client.on_disconnected.first
return defer.succeed(True)
@property @property
def is_connected(self): def is_connected(self):
return self.client is not None and self.client.connected return self.client is not None and not self.client.is_closing()
def rpc(self, list_or_method, *args): def rpc(self, list_or_method, *args):
if self.is_connected: if self.is_connected:
return self.client.rpc(list_or_method, *args) return self.client.send_request(list_or_method, args)
else: else:
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")

View file

@ -3,8 +3,6 @@ import typing
from typing import List, Iterable, Optional from typing import List, Iterable, Optional
from binascii import hexlify from binascii import hexlify
from twisted.internet import defer
from torba.basescript import BaseInputScript, BaseOutputScript from torba.basescript import BaseInputScript, BaseOutputScript
from torba.baseaccount import BaseAccount from torba.baseaccount import BaseAccount
from torba.constants import COIN, NULL_HASH32 from torba.constants import COIN, NULL_HASH32
@ -426,8 +424,7 @@ class BaseTransaction:
return ledger return ledger
@classmethod @classmethod
@defer.inlineCallbacks async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
funding_accounts: Iterable[BaseAccount], change_account: BaseAccount): funding_accounts: Iterable[BaseAccount], change_account: BaseAccount):
""" Find optimal set of inputs when only outputs are provided; add change """ Find optimal set of inputs when only outputs are provided; add change
outputs if only inputs are provided or if inputs are greater than outputs. """ outputs if only inputs are provided or if inputs are greater than outputs. """
@ -450,7 +447,7 @@ class BaseTransaction:
if payment < cost: if payment < cost:
deficit = cost - payment deficit = cost - payment
spendables = yield ledger.get_spendable_utxos(deficit, funding_accounts) spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
if not spendables: if not spendables:
raise ValueError('Not enough funds to cover this transaction.') raise ValueError('Not enough funds to cover this transaction.')
payment += sum(s.effective_amount for s in spendables) payment += sum(s.effective_amount for s in spendables)
@ -463,28 +460,27 @@ class BaseTransaction:
) )
change = payment - cost change = payment - cost
if change > cost_of_change: if change > cost_of_change:
change_address = yield change_account.change.get_or_create_usable_address() change_address = await change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address) change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = change - cost_of_change change_amount = change - cost_of_change
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160) change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
change_output.is_change = True change_output.is_change = True
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
yield tx.sign(funding_accounts) await tx.sign(funding_accounts)
except Exception as e: except Exception as e:
log.exception('Failed to synchronize transaction:') log.exception('Failed to synchronize transaction:')
yield ledger.release_outputs(tx.outputs) await ledger.release_outputs(tx.outputs)
raise e raise e
defer.returnValue(tx) return tx
@staticmethod @staticmethod
def signature_hash_type(hash_type): def signature_hash_type(hash_type):
return hash_type return hash_type
@defer.inlineCallbacks async def sign(self, funding_accounts: Iterable[BaseAccount]):
def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred:
ledger = self.ensure_all_have_same_ledger(funding_accounts) ledger = self.ensure_all_have_same_ledger(funding_accounts)
for i, txi in enumerate(self._inputs): for i, txi in enumerate(self._inputs):
assert txi.script is not None assert txi.script is not None
@ -492,7 +488,7 @@ class BaseTransaction:
txo_script = txi.txo_ref.txo.script txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash: if txo_script.is_pay_pubkey_hash:
address = ledger.hash160_to_address(txo_script.values['pubkey_hash']) address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = yield ledger.get_private_key_for_address(address) private_key = await ledger.get_private_key_for_address(address)
tx = self._serialize_for_signature(i) tx = self._serialize_for_signature(i)
txi.script.values['signature'] = \ txi.script.values['signature'] = \
private_key.sign(tx) + bytes((self.signature_hash_type(1),)) private_key.sign(tx) + bytes((self.signature_hash_type(1),))

View file

@ -1,6 +1,4 @@
import asyncio import asyncio
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
class BroadcastSubscription: class BroadcastSubscription:
@ -31,11 +29,13 @@ class BroadcastSubscription:
def _add(self, data): def _add(self, data):
if self.can_fire and self._on_data is not None: if self.can_fire and self._on_data is not None:
self._on_data(data) maybe_coroutine = self._on_data(data)
if asyncio.iscoroutine(maybe_coroutine):
asyncio.ensure_future(maybe_coroutine)
def _add_error(self, error, traceback): def _add_error(self, exception):
if self.can_fire and self._on_error is not None: if self.can_fire and self._on_error is not None:
self._on_error(error, traceback) self._on_error(exception)
def _close(self): def _close(self):
if self.can_fire and self._on_done is not None: if self.can_fire and self._on_done is not None:
@ -66,9 +66,9 @@ class StreamController:
for subscription in self._iterate_subscriptions: for subscription in self._iterate_subscriptions:
subscription._add(event) subscription._add(event)
def add_error(self, error, traceback): def add_error(self, exception):
for subscription in self._iterate_subscriptions: for subscription in self._iterate_subscriptions:
subscription._add_error(error, traceback) subscription._add_error(exception)
def close(self): def close(self):
for subscription in self._iterate_subscriptions: for subscription in self._iterate_subscriptions:
@ -108,38 +108,35 @@ class Stream:
def listen(self, on_data, on_error=None, on_done=None): def listen(self, on_data, on_error=None, on_done=None):
return self._controller._listen(on_data, on_error, on_done) return self._controller._listen(on_data, on_error, on_done)
def deferred_where(self, condition): def where(self, condition) -> asyncio.Future:
deferred = Deferred() future = asyncio.get_event_loop().create_future()
def where_test(value): def where_test(value):
if condition(value): if condition(value):
self._cancel_and_callback(subscription, deferred, value) self._cancel_and_callback(subscription, future, value)
subscription = self.listen( subscription = self.listen(
where_test, where_test,
lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback) lambda exception: self._cancel_and_error(subscription, future, exception)
) )
return deferred return future
def where(self, condition):
return self.deferred_where(condition).asFuture(asyncio.get_event_loop())
@property @property
def first(self): def first(self):
deferred = Deferred() future = asyncio.get_event_loop().create_future()
subscription = self.listen( subscription = self.listen(
lambda value: self._cancel_and_callback(subscription, deferred, value), lambda value: self._cancel_and_callback(subscription, future, value),
lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback) lambda exception: self._cancel_and_error(subscription, future, exception)
) )
return deferred return future
@staticmethod @staticmethod
def _cancel_and_callback(subscription, deferred, value): def _cancel_and_callback(subscription: BroadcastSubscription, future: asyncio.Future, value):
subscription.cancel() subscription.cancel()
deferred.callback(value) future.set_result(value)
@staticmethod @staticmethod
def _cancel_and_error(subscription, deferred, error, traceback): def _cancel_and_error(subscription: BroadcastSubscription, future: asyncio.Future, exception):
subscription.cancel() subscription.cancel()
deferred.errback(Failure(error, exc_tb=traceback)) future.set_exception(exception)