diff --git a/tests/integration/test_transactions.py b/tests/integration/test_transactions.py index 697397ba1..8b5c7a263 100644 --- a/tests/integration/test_transactions.py +++ b/tests/integration/test_transactions.py @@ -1,5 +1,5 @@ import asyncio -from orchstr8.testcase import IntegrationTestCase, d2f +from orchstr8.testcase import IntegrationTestCase from torba.constants import COIN @@ -9,14 +9,14 @@ class BasicTransactionTests(IntegrationTestCase): async def test_sending_and_receiving(self): account1, account2 = self.account, self.wallet.generate_account(self.ledger) - await d2f(self.ledger.update_account(account2)) + await self.ledger.update_account(account2) self.assertEqual(await self.get_balance(account1), 0) self.assertEqual(await self.get_balance(account2), 0) sendtxids = [] for i in range(5): - address1 = await d2f(account1.receiving.get_or_create_usable_address()) + address1 = await account1.receiving.get_or_create_usable_address() sendtxid = await self.blockchain.send_to_address(address1, 1.1) sendtxids.append(sendtxid) await self.on_transaction_id(sendtxid) # mempool @@ -28,13 +28,13 @@ class BasicTransactionTests(IntegrationTestCase): self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5) self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0) - address2 = await d2f(account2.receiving.get_or_create_usable_address()) + address2 = await account2.receiving.get_or_create_usable_address() hash2 = self.ledger.address_to_hash160(address2) - tx = await d2f(self.ledger.transaction_class.create( + tx = await self.ledger.transaction_class.create( [], [self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash2)], [account1], account1 - )) + ) await self.broadcast(tx) await self.on_transaction(tx) # mempool await self.blockchain.generate(1) @@ -43,18 +43,18 @@ class BasicTransactionTests(IntegrationTestCase): self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5) self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0) - utxos = await d2f(self.account.get_utxos()) - tx = await d2f(self.ledger.transaction_class.create( + utxos = await self.account.get_utxos() + tx = await self.ledger.transaction_class.create( [self.ledger.transaction_class.input_class.spend(utxos[0])], [], [account1], account1 - )) + ) await self.broadcast(tx) await self.on_transaction(tx) # mempool await self.blockchain.generate(1) await self.on_transaction(tx) # confirmed - txs = await d2f(account1.get_transactions()) + txs = await account1.get_transactions() tx = txs[1] self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1) self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1) diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 2e840c20b..347c0f843 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,25 +1,26 @@ from binascii import hexlify -from twisted.trial import unittest -from twisted.internet import defer + +from orchstr8.testcase import AsyncioTestCase from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.baseaccount import HierarchicalDeterministic, SingleKey from torba.wallet import Wallet -class TestHierarchicalDeterministicAccount(unittest.TestCase): +class TestHierarchicalDeterministicAccount(AsyncioTestCase): - @defer.inlineCallbacks - def setUp(self): + async def asyncSetUp(self): self.ledger = ledger_class({ 'db': ledger_class.database_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'), }) - yield self.ledger.db.open() + await self.ledger.db.open() self.account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba") - @defer.inlineCallbacks - def test_generate_account(self): + async def asyncTearDown(self): + await self.ledger.db.close() + + async def test_generate_account(self): account = self.account self.assertEqual(account.ledger, self.ledger) @@ -27,81 +28,77 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase): self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.private_key.public_key, account.public_key) - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 0) - addresses = yield account.change.get_addresses() + addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 0) - yield account.ensure_address_gap() + await account.ensure_address_gap() - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 20) - addresses = yield account.change.get_addresses() + addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 6) - addresses = yield account.get_addresses() + addresses = await account.get_addresses() self.assertEqual(len(addresses), 26) - @defer.inlineCallbacks - def test_generate_keys_over_batch_threshold_saves_it_properly(self): - yield self.account.receiving.generate_keys(0, 200) - records = yield self.account.receiving.get_address_records() + async def test_generate_keys_over_batch_threshold_saves_it_properly(self): + await self.account.receiving.generate_keys(0, 200) + records = await self.account.receiving.get_address_records() self.assertEqual(201, len(records)) - @defer.inlineCallbacks - def test_ensure_address_gap(self): + async def test_ensure_address_gap(self): account = self.account self.assertIsInstance(account.receiving, HierarchicalDeterministic) - yield account.receiving.generate_keys(4, 7) - yield account.receiving.generate_keys(0, 3) - yield account.receiving.generate_keys(8, 11) - records = yield account.receiving.get_address_records() + await account.receiving.generate_keys(4, 7) + await account.receiving.generate_keys(0, 3) + await account.receiving.generate_keys(8, 11) + records = await account.receiving.get_address_records() self.assertEqual( [r['position'] for r in records], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] ) # we have 12, but default gap is 20 - new_keys = yield account.receiving.ensure_address_gap() + new_keys = await account.receiving.ensure_address_gap() self.assertEqual(len(new_keys), 8) - records = yield account.receiving.get_address_records() + records = await account.receiving.get_address_records() self.assertEqual( [r['position'] for r in records], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] ) # case #1: no new addresses needed - empty = yield account.receiving.ensure_address_gap() + empty = await account.receiving.ensure_address_gap() self.assertEqual(len(empty), 0) # case #2: only one new addressed needed - records = yield account.receiving.get_address_records() - yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:') - new_keys = yield account.receiving.ensure_address_gap() + records = await account.receiving.get_address_records() + await self.ledger.db.set_address_history(records[0]['address'], 'a:1:') + new_keys = await account.receiving.ensure_address_gap() self.assertEqual(len(new_keys), 1) # case #3: 20 addresses needed - yield self.ledger.db.set_address_history(new_keys[0], 'a:1:') - new_keys = yield account.receiving.ensure_address_gap() + await self.ledger.db.set_address_history(new_keys[0], 'a:1:') + new_keys = await account.receiving.ensure_address_gap() self.assertEqual(len(new_keys), 20) - @defer.inlineCallbacks - def test_get_or_create_usable_address(self): + async def test_get_or_create_usable_address(self): account = self.account - keys = yield account.receiving.get_addresses() + keys = await account.receiving.get_addresses() self.assertEqual(len(keys), 0) - address = yield account.receiving.get_or_create_usable_address() + address = await account.receiving.get_or_create_usable_address() self.assertIsNotNone(address) - keys = yield account.receiving.get_addresses() + keys = await account.receiving.get_addresses() self.assertEqual(len(keys), 20) - @defer.inlineCallbacks - def test_generate_account_from_seed(self): + async def test_generate_account_from_seed(self): account = self.ledger.account_class.from_dict( self.ledger, Wallet(), { "seed": "carbon smart garage balance margin twelve chest sword " @@ -123,17 +120,17 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase): 'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7UbpV' 'NzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g' ) - address = yield account.receiving.ensure_address_gap() + address = await account.receiving.ensure_address_gap() self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J') - private_key = yield self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J') + private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J') self.assertEqual( private_key.extended_key_string(), 'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o' 'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi' ) - invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') + invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') self.assertIsNone(invalid_key) self.assertEqual( @@ -141,8 +138,7 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase): b'1c01ae1e4c7d89e39f6d3aa7792c097a30ca7d40be249b6de52c81ec8cf9aab48b01' ) - @defer.inlineCallbacks - def test_load_and_save_account(self): + async def test_load_and_save_account(self): account_data = { 'name': 'My Account', 'seed': @@ -164,11 +160,11 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase): account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data) - yield account.ensure_address_gap() + await account.ensure_address_gap() - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 5) - addresses = yield account.change.get_addresses() + addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 5) self.maxDiff = None @@ -176,20 +172,21 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase): self.assertDictEqual(account_data, account.to_dict()) -class TestSingleKeyAccount(unittest.TestCase): +class TestSingleKeyAccount(AsyncioTestCase): - @defer.inlineCallbacks - def setUp(self): + async def asyncSetUp(self): self.ledger = ledger_class({ 'db': ledger_class.database_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'), }) - yield self.ledger.db.open() + await self.ledger.db.open() self.account = self.ledger.account_class.generate( self.ledger, Wallet(), "torba", {'name': 'single-address'}) - @defer.inlineCallbacks - def test_generate_account(self): + async def asyncTearDown(self): + await self.ledger.db.close() + + async def test_generate_account(self): account = self.account self.assertEqual(account.ledger, self.ledger) @@ -197,37 +194,36 @@ class TestSingleKeyAccount(unittest.TestCase): self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.private_key.public_key, account.public_key) - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 0) - addresses = yield account.change.get_addresses() + addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 0) - yield account.ensure_address_gap() + await account.ensure_address_gap() - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 1) self.assertEqual(addresses[0], account.public_key.address) - addresses = yield account.change.get_addresses() + addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 1) self.assertEqual(addresses[0], account.public_key.address) - addresses = yield account.get_addresses() + addresses = await account.get_addresses() self.assertEqual(len(addresses), 1) self.assertEqual(addresses[0], account.public_key.address) - @defer.inlineCallbacks - def test_ensure_address_gap(self): + async def test_ensure_address_gap(self): account = self.account self.assertIsInstance(account.receiving, SingleKey) - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(addresses, []) # we have 12, but default gap is 20 - new_keys = yield account.receiving.ensure_address_gap() + new_keys = await account.receiving.ensure_address_gap() self.assertEqual(len(new_keys), 1) self.assertEqual(new_keys[0], account.public_key.address) - records = yield account.receiving.get_address_records() + records = await account.receiving.get_address_records() self.assertEqual(records, [{ 'position': 0, 'chain': 0, 'account': account.public_key.address, @@ -236,37 +232,35 @@ class TestSingleKeyAccount(unittest.TestCase): }]) # case #1: no new addresses needed - empty = yield account.receiving.ensure_address_gap() + empty = await account.receiving.ensure_address_gap() self.assertEqual(len(empty), 0) # case #2: after use, still no new address needed - records = yield account.receiving.get_address_records() - yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:') - empty = yield account.receiving.ensure_address_gap() + records = await account.receiving.get_address_records() + await self.ledger.db.set_address_history(records[0]['address'], 'a:1:') + empty = await account.receiving.ensure_address_gap() self.assertEqual(len(empty), 0) - @defer.inlineCallbacks - def test_get_or_create_usable_address(self): + async def test_get_or_create_usable_address(self): account = self.account - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 0) - address1 = yield account.receiving.get_or_create_usable_address() + address1 = await account.receiving.get_or_create_usable_address() self.assertIsNotNone(address1) - yield self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:') - records = yield account.receiving.get_address_records() + await self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:') + records = await account.receiving.get_address_records() self.assertEqual(records[0]['used_times'], 3) - address2 = yield account.receiving.get_or_create_usable_address() + address2 = await account.receiving.get_or_create_usable_address() self.assertEqual(address1, address2) - keys = yield account.receiving.get_addresses() + keys = await account.receiving.get_addresses() self.assertEqual(len(keys), 1) - @defer.inlineCallbacks - def test_generate_account_from_seed(self): + async def test_generate_account_from_seed(self): account = self.ledger.account_class.from_dict( self.ledger, Wallet(), { "seed": @@ -285,17 +279,17 @@ class TestSingleKeyAccount(unittest.TestCase): 'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7' 'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g', ) - address = yield account.receiving.ensure_address_gap() + address = await account.receiving.ensure_address_gap() self.assertEqual(address[0], account.public_key.address) - private_key = yield self.ledger.get_private_key_for_address(address[0]) + private_key = await self.ledger.get_private_key_for_address(address[0]) self.assertEqual( private_key.extended_key_string(), 'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp' '5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna', ) - invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') + invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') self.assertIsNone(invalid_key) self.assertEqual( @@ -303,8 +297,7 @@ class TestSingleKeyAccount(unittest.TestCase): b'1c92caa0ef99bfd5e2ceb73b66da8cd726a9370be8c368d448a322f3c5b23aaab901' ) - @defer.inlineCallbacks - def test_load_and_save_account(self): + async def test_load_and_save_account(self): account_data = { 'name': 'My Account', 'seed': @@ -322,11 +315,11 @@ class TestSingleKeyAccount(unittest.TestCase): account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data) - yield account.ensure_address_gap() + await account.ensure_address_gap() - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 1) - addresses = yield account.change.get_addresses() + addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 1) self.maxDiff = None @@ -334,7 +327,7 @@ class TestSingleKeyAccount(unittest.TestCase): self.assertDictEqual(account_data, account.to_dict()) -class AccountEncryptionTests(unittest.TestCase): +class AccountEncryptionTests(AsyncioTestCase): password = "password" init_vector = b'0000000000000000' unencrypted_account = { @@ -368,7 +361,7 @@ class AccountEncryptionTests(unittest.TestCase): 'address_generator': {'name': 'single-address'} } - def setUp(self): + async def asyncSetUp(self): self.ledger = ledger_class({ 'db': ledger_class.database_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'), diff --git a/tests/unit/test_bcd_data_stream.py b/tests/unit/test_bcd_data_stream.py index 577c0ebf2..d15a92f4f 100644 --- a/tests/unit/test_bcd_data_stream.py +++ b/tests/unit/test_bcd_data_stream.py @@ -1,4 +1,4 @@ -from twisted.trial import unittest +import unittest from torba.bcd_data_stream import BCDataStream diff --git a/tests/unit/test_bip32.py b/tests/unit/test_bip32.py index 12d69089e..b96c5a894 100644 --- a/tests/unit/test_bip32.py +++ b/tests/unit/test_bip32.py @@ -1,5 +1,5 @@ +import unittest from binascii import unhexlify, hexlify -from twisted.trial import unittest from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys from torba.bip32 import PubKey, PrivateKey, from_extended_key_string diff --git a/tests/unit/test_coinselection.py b/tests/unit/test_coinselection.py index b3b577ce8..610a20715 100644 --- a/tests/unit/test_coinselection.py +++ b/tests/unit/test_coinselection.py @@ -1,4 +1,4 @@ -from twisted.trial import unittest +import unittest from types import GeneratorType from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 9c3409d8b..860223bea 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1,11 +1,12 @@ -from twisted.trial import unittest -from twisted.internet import defer +import unittest from torba.wallet import Wallet from torba.constants import COIN from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.basedatabase import query, constraints_to_sql +from orchstr8.testcase import AsyncioTestCase + from .test_transaction import get_output, NULL_HASH @@ -100,53 +101,52 @@ class TestQueryBuilder(unittest.TestCase): ) -class TestQueries(unittest.TestCase): +class TestQueries(AsyncioTestCase): - def setUp(self): + async def asyncSetUp(self): self.ledger = ledger_class({ 'db': ledger_class.database_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'), }) - return self.ledger.db.open() + await self.ledger.db.open() - @defer.inlineCallbacks - def create_account(self): + async def asyncTearDown(self): + await self.ledger.db.close() + + async def create_account(self): account = self.ledger.account_class.generate(self.ledger, Wallet()) - yield account.ensure_address_gap() + await account.ensure_address_gap() return account - @defer.inlineCallbacks - def create_tx_from_nothing(self, my_account, height): - to_address = yield my_account.receiving.get_or_create_usable_address() + async def create_tx_from_nothing(self, my_account, height): + to_address = await my_account.receiving.get_or_create_usable_address() to_hash = ledger_class.address_to_hash160(to_address) tx = ledger_class.transaction_class(height=height, is_verified=True) \ .add_inputs([self.txi(self.txo(1, NULL_HASH))]) \ .add_outputs([self.txo(1, to_hash)]) - yield self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '') + await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '') return tx - @defer.inlineCallbacks - def create_tx_from_txo(self, txo, to_account, height): + async def create_tx_from_txo(self, txo, to_account, height): from_hash = txo.script.values['pubkey_hash'] from_address = self.ledger.hash160_to_address(from_hash) - to_address = yield to_account.receiving.get_or_create_usable_address() + to_address = await to_account.receiving.get_or_create_usable_address() to_hash = ledger_class.address_to_hash160(to_address) tx = ledger_class.transaction_class(height=height, is_verified=True) \ .add_inputs([self.txi(txo)]) \ .add_outputs([self.txo(1, to_hash)]) - yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '') - yield self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '') + await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '') + await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '') return tx - @defer.inlineCallbacks - def create_tx_to_nowhere(self, txo, height): + async def create_tx_to_nowhere(self, txo, height): from_hash = txo.script.values['pubkey_hash'] from_address = self.ledger.hash160_to_address(from_hash) to_hash = NULL_HASH tx = ledger_class.transaction_class(height=height, is_verified=True) \ .add_inputs([self.txi(txo)]) \ .add_outputs([self.txo(1, to_hash)]) - yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '') + await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '') return tx def txo(self, amount, address): @@ -155,39 +155,38 @@ class TestQueries(unittest.TestCase): def txi(self, txo): return ledger_class.transaction_class.input_class.spend(txo) - @defer.inlineCallbacks - def test_get_transactions(self): - account1 = yield self.create_account() - account2 = yield self.create_account() - tx1 = yield self.create_tx_from_nothing(account1, 1) - tx2 = yield self.create_tx_from_txo(tx1.outputs[0], account2, 2) - tx3 = yield self.create_tx_to_nowhere(tx2.outputs[0], 3) + async def test_get_transactions(self): + account1 = await self.create_account() + account2 = await self.create_account() + tx1 = await self.create_tx_from_nothing(account1, 1) + tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2) + tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3) - txs = yield self.ledger.db.get_transactions() + txs = await self.ledger.db.get_transactions() self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([3, 2, 1], [tx.height for tx in txs]) - txs = yield self.ledger.db.get_transactions(account=account1) + txs = await self.ledger.db.get_transactions(account=account1) self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].outputs[0].is_my_account, True) - txs = yield self.ledger.db.get_transactions(account=account2) + txs = await self.ledger.db.get_transactions(account=account2) self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs]) self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].outputs[0].is_my_account, True) - tx = yield self.ledger.db.get_transaction(txid=tx2.id) + tx = await self.ledger.db.get_transaction(txid=tx2.id) self.assertEqual(tx.id, tx2.id) self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False) - tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account1) + tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account1) self.assertEqual(tx.inputs[0].is_my_account, True) self.assertEqual(tx.outputs[0].is_my_account, False) - tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account2) + tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2) self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, True) diff --git a/tests/unit/test_hash.py b/tests/unit/test_hash.py index 4f4d31526..4a7f2331d 100644 --- a/tests/unit/test_hash.py +++ b/tests/unit/test_hash.py @@ -1,11 +1,6 @@ -from unittest import TestCase +from unittest import TestCase, mock from torba.hash import aes_decrypt, aes_encrypt -try: - from unittest import mock -except ImportError: - import mock - class TestAESEncryptDecrypt(TestCase): message = 'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks' diff --git a/tests/unit/test_headers.py b/tests/unit/test_headers.py index 92654bde5..46489e733 100644 --- a/tests/unit/test_headers.py +++ b/tests/unit/test_headers.py @@ -1,8 +1,7 @@ import os from urllib.request import Request, urlopen -from twisted.trial import unittest -from twisted.internet import defer +from orchstr8.testcase import AsyncioTestCase from torba.coin.bitcoinsegwit import MainHeaders @@ -11,7 +10,7 @@ def block_bytes(blocks): return blocks * MainHeaders.header_size -class BitcoinHeadersTestCase(unittest.TestCase): +class BitcoinHeadersTestCase(AsyncioTestCase): # Download headers instead of storing them in git. HEADER_URL = 'http://headers.electrum.org/blockchain_headers' @@ -39,7 +38,7 @@ class BitcoinHeadersTestCase(unittest.TestCase): headers.seek(after, os.SEEK_SET) return headers.read(upto) - def get_headers(self, upto: int = -1): + async def get_headers(self, upto: int = -1): h = MainHeaders(':memory:') h.io.write(self.get_bytes(upto)) return h @@ -47,8 +46,8 @@ class BitcoinHeadersTestCase(unittest.TestCase): class BasicHeadersTests(BitcoinHeadersTestCase): - def test_serialization(self): - h = self.get_headers() + async def test_serialization(self): + h = await self.get_headers() self.assertEqual(h[0], { 'bits': 486604799, 'block_height': 0, @@ -94,18 +93,16 @@ class BasicHeadersTests(BitcoinHeadersTestCase): h.get_raw_header(self.RETARGET_BLOCK) ) - @defer.inlineCallbacks - def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self): + async def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self): headers = MainHeaders(':memory:') self.assertEqual(headers.height, -1) - yield headers.connect(0, self.get_bytes(block_bytes(3001))) + await headers.connect(0, self.get_bytes(block_bytes(3001))) self.assertEqual(headers.height, 3000) - @defer.inlineCallbacks - def test_connect_9_blocks_passing_a_retarget_at_32256(self): + async def test_connect_9_blocks_passing_a_retarget_at_32256(self): retarget = block_bytes(self.RETARGET_BLOCK-5) - headers = self.get_headers(upto=retarget) + headers = await self.get_headers(upto=retarget) remainder = self.get_bytes(after=retarget) self.assertEqual(headers.height, 32250) - yield headers.connect(len(headers), remainder) + await headers.connect(len(headers), remainder) self.assertEqual(headers.height, 32259) diff --git a/tests/unit/test_ledger.py b/tests/unit/test_ledger.py index 771ab1426..3f6d3ba7c 100644 --- a/tests/unit/test_ledger.py +++ b/tests/unit/test_ledger.py @@ -1,6 +1,5 @@ import os from binascii import hexlify -from twisted.internet import defer from torba.coin.bitcoinsegwit import MainNetLedger from torba.wallet import Wallet @@ -18,32 +17,30 @@ class MockNetwork: self.get_history_called = [] self.get_transaction_called = [] - def get_history(self, address): + async def get_history(self, address): self.get_history_called.append(address) self.address = address - return defer.succeed(self.history) + return self.history - def get_merkle(self, txid, height): - return defer.succeed({'merkle': ['abcd01'], 'pos': 1}) + async def get_merkle(self, txid, height): + return {'merkle': ['abcd01'], 'pos': 1} - def get_transaction(self, tx_hash): + async def get_transaction(self, tx_hash): self.get_transaction_called.append(tx_hash) - return defer.succeed(self.transaction[tx_hash]) + return self.transaction[tx_hash] class LedgerTestCase(BitcoinHeadersTestCase): - def setUp(self): - super().setUp() + async def asyncSetUp(self): self.ledger = MainNetLedger({ 'db': MainNetLedger.database_class(':memory:'), 'headers': MainNetLedger.headers_class(':memory:') }) - return self.ledger.db.open() + await self.ledger.db.open() - def tearDown(self): - super().tearDown() - return self.ledger.db.close() + async def asyncTearDown(self): + await self.ledger.db.close() def make_header(self, **kwargs): header = { @@ -69,11 +66,10 @@ class LedgerTestCase(BitcoinHeadersTestCase): class TestSynchronization(LedgerTestCase): - @defer.inlineCallbacks - def test_update_history(self): + async def test_update_history(self): account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba") - address = yield account.receiving.get_or_create_usable_address() - address_details = yield self.ledger.db.get_address(address=address) + address = await account.receiving.get_or_create_usable_address() + address_details = await self.ledger.db.get_address(address=address) self.assertEqual(address_details['history'], None) self.add_header(block_height=0, merkle_root=b'abcd04') @@ -89,16 +85,16 @@ class TestSynchronization(LedgerTestCase): 'abcd02': hexlify(get_transaction(get_output(2)).raw), 'abcd03': hexlify(get_transaction(get_output(3)).raw), }) - yield self.ledger.update_history(address) + await self.ledger.update_history(address) self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03']) - address_details = yield self.ledger.db.get_address(address=address) + address_details = await self.ledger.db.get_address(address=address) self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:') self.ledger.network.get_history_called = [] self.ledger.network.get_transaction_called = [] - yield self.ledger.update_history(address) + await self.ledger.update_history(address) self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_transaction_called, []) @@ -106,10 +102,10 @@ class TestSynchronization(LedgerTestCase): self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw) self.ledger.network.get_history_called = [] self.ledger.network.get_transaction_called = [] - yield self.ledger.update_history(address) + await self.ledger.update_history(address) self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04']) - address_details = yield self.ledger.db.get_address(address=address) + address_details = await self.ledger.db.get_address(address=address) self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:') @@ -117,14 +113,13 @@ class MocHeaderNetwork: def __init__(self, responses): self.responses = responses - def get_headers(self, height, blocks): + async def get_headers(self, height, blocks): return self.responses[height] class BlockchainReorganizationTests(LedgerTestCase): - @defer.inlineCallbacks - def test_1_block_reorganization(self): + async def test_1_block_reorganization(self): self.ledger.network = MocHeaderNetwork({ 20: {'height': 20, 'count': 5, 'hex': hexlify( self.get_bytes(after=block_bytes(20), upto=block_bytes(5)) @@ -132,15 +127,14 @@ class BlockchainReorganizationTests(LedgerTestCase): 25: {'height': 25, 'count': 0, 'hex': b''} }) headers = self.ledger.headers - yield headers.connect(0, self.get_bytes(upto=block_bytes(20))) + await headers.connect(0, self.get_bytes(upto=block_bytes(20))) self.add_header(block_height=len(headers)) self.assertEqual(headers.height, 20) - yield self.ledger.receive_header([{ + await self.ledger.receive_header([{ 'height': 21, 'hex': hexlify(self.make_header(block_height=21)) }]) - @defer.inlineCallbacks - def test_3_block_reorganization(self): + async def test_3_block_reorganization(self): self.ledger.network = MocHeaderNetwork({ 20: {'height': 20, 'count': 5, 'hex': hexlify( self.get_bytes(after=block_bytes(20), upto=block_bytes(5)) @@ -150,11 +144,11 @@ class BlockchainReorganizationTests(LedgerTestCase): 25: {'height': 25, 'count': 0, 'hex': b''} }) headers = self.ledger.headers - yield headers.connect(0, self.get_bytes(upto=block_bytes(20))) + await headers.connect(0, self.get_bytes(upto=block_bytes(20))) self.add_header(block_height=len(headers)) self.add_header(block_height=len(headers)) self.add_header(block_height=len(headers)) self.assertEqual(headers.height, 22) - yield self.ledger.receive_header(({ + await self.ledger.receive_header(({ 'height': 23, 'hex': hexlify(self.make_header(block_height=23)) },)) diff --git a/tests/unit/test_script.py b/tests/unit/test_script.py index 1bb832123..ffcbc5b12 100644 --- a/tests/unit/test_script.py +++ b/tests/unit/test_script.py @@ -1,5 +1,5 @@ +import unittest from binascii import hexlify, unhexlify -from twisted.trial import unittest from torba.bcd_data_stream import BCDataStream from torba.basescript import Template, ParseError, tokenize, push_data diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 26c5f6d8d..b082b6035 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -1,7 +1,8 @@ +import unittest from binascii import hexlify, unhexlify from itertools import cycle -from twisted.trial import unittest -from twisted.internet import defer + +from orchstr8.testcase import AsyncioTestCase from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.wallet import Wallet @@ -29,9 +30,9 @@ def get_transaction(txo=None): .add_outputs([txo or ledger_class.transaction_class.output_class.pay_pubkey_hash(CENT, NULL_HASH)]) -class TestSizeAndFeeEstimation(unittest.TestCase): +class TestSizeAndFeeEstimation(AsyncioTestCase): - def setUp(self): + async def asyncSetUp(self): self.ledger = ledger_class({ 'db': ledger_class.database_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'), @@ -181,17 +182,19 @@ class TestTransactionSerialization(unittest.TestCase): self.assertEqual(tx.raw, raw) -class TestTransactionSigning(unittest.TestCase): +class TestTransactionSigning(AsyncioTestCase): - def setUp(self): + async def asyncSetUp(self): self.ledger = ledger_class({ 'db': ledger_class.database_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'), }) - return self.ledger.db.open() + await self.ledger.db.open() - @defer.inlineCallbacks - def test_sign(self): + async def asyncTearDown(self): + await self.ledger.db.close() + + async def test_sign(self): account = self.ledger.account_class.from_dict( self.ledger, Wallet(), { "seed": "carbon smart garage balance margin twelve chest sword " @@ -200,8 +203,8 @@ class TestTransactionSigning(unittest.TestCase): } ) - yield account.ensure_address_gap() - address1, address2 = yield account.receiving.get_addresses(limit=2) + await account.ensure_address_gap() + address1, address2 = await account.receiving.get_addresses(limit=2) pubkey_hash1 = self.ledger.address_to_hash160(address1) pubkey_hash2 = self.ledger.address_to_hash160(address2) @@ -211,9 +214,8 @@ class TestTransactionSigning(unittest.TestCase): .add_inputs([tx_class.input_class.spend(get_output(2*COIN, pubkey_hash1))]) \ .add_outputs([tx_class.output_class.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) \ - yield tx.sign([account]) + await tx.sign([account]) - print(hexlify(tx.inputs[0].script.values['signature'])) self.assertEqual( hexlify(tx.inputs[0].script.values['signature']), b'304402205a1df8cd5d2d2fa5934b756883d6c07e4f83e1350c740992d47a12422' @@ -221,15 +223,14 @@ class TestTransactionSigning(unittest.TestCase): ) -class TransactionIOBalancing(unittest.TestCase): +class TransactionIOBalancing(AsyncioTestCase): - @defer.inlineCallbacks - def setUp(self): + async def asyncSetUp(self): self.ledger = ledger_class({ 'db': ledger_class.database_class(':memory:'), 'headers': ledger_class.headers_class(':memory:'), }) - yield self.ledger.db.open() + await self.ledger.db.open() self.account = self.ledger.account_class.from_dict( self.ledger, Wallet(), { "seed": "carbon smart garage balance margin twelve chest sword " @@ -237,10 +238,13 @@ class TransactionIOBalancing(unittest.TestCase): } ) - addresses = yield self.account.ensure_address_gap() + addresses = await self.account.ensure_address_gap() self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses] self.hash_cycler = cycle(self.pubkey_hash) + async def asyncTearDown(self): + await self.ledger.db.close() + def txo(self, amount, address=None): return get_output(int(amount*COIN), address or next(self.hash_cycler)) @@ -250,8 +254,7 @@ class TransactionIOBalancing(unittest.TestCase): def tx(self, inputs, outputs): return ledger_class.transaction_class.create(inputs, outputs, [self.account], self.account) - @defer.inlineCallbacks - def create_utxos(self, amounts): + async def create_utxos(self, amounts): utxos = [self.txo(amount) for amount in amounts] self.funding_tx = ledger_class.transaction_class(is_verified=True) \ @@ -260,7 +263,7 @@ class TransactionIOBalancing(unittest.TestCase): save_tx = 'insert' for utxo in utxos: - yield self.ledger.db.save_transaction_io( + await self.ledger.db.save_transaction_io( save_tx, self.funding_tx, self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']), utxo.script.values['pubkey_hash'], '' @@ -277,17 +280,16 @@ class TransactionIOBalancing(unittest.TestCase): def outputs(tx): return [round(o.amount/COIN, 2) for o in tx.outputs] - @defer.inlineCallbacks - def test_basic_use_cases(self): + async def test_basic_use_cases(self): self.ledger.fee_per_byte = int(.01*CENT) # available UTXOs for filling missing inputs - utxos = yield self.create_utxos([ + utxos = await self.create_utxos([ 1, 1, 3, 5, 10 ]) # pay 3 coins (3.02 w/ fees) - tx = yield self.tx( + tx = await self.tx( [], # inputs [self.txo(3)] # outputs ) @@ -296,10 +298,10 @@ class TransactionIOBalancing(unittest.TestCase): # a change of 1.98 is added to reach balance self.assertEqual(self.outputs(tx), [3, 1.98]) - yield self.ledger.release_outputs(utxos) + await self.ledger.release_outputs(utxos) # pay 2.98 coins (3.00 w/ fees) - tx = yield self.tx( + tx = await self.tx( [], # inputs [self.txo(2.98)] # outputs ) @@ -307,10 +309,10 @@ class TransactionIOBalancing(unittest.TestCase): self.assertEqual(self.inputs(tx), [3]) self.assertEqual(self.outputs(tx), [2.98]) - yield self.ledger.release_outputs(utxos) + await self.ledger.release_outputs(utxos) # supplied input and output, but input is not enough to cover output - tx = yield self.tx( + tx = await self.tx( [self.txi(self.txo(10))], # inputs [self.txo(11)] # outputs ) @@ -319,10 +321,10 @@ class TransactionIOBalancing(unittest.TestCase): # change is now needed to consume extra input self.assertEqual([11, 1.96], self.outputs(tx)) - yield self.ledger.release_outputs(utxos) + await self.ledger.release_outputs(utxos) # liquidating a UTXO - tx = yield self.tx( + tx = await self.tx( [self.txi(self.txo(10))], # inputs [] # outputs ) @@ -330,10 +332,10 @@ class TransactionIOBalancing(unittest.TestCase): # missing change added to consume the amount self.assertEqual([9.98], self.outputs(tx)) - yield self.ledger.release_outputs(utxos) + await self.ledger.release_outputs(utxos) # liquidating at a loss, requires adding extra inputs - tx = yield self.tx( + tx = await self.tx( [self.txi(self.txo(0.01))], # inputs [] # outputs ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index d46c9d5d9..197591ce5 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,4 @@ -from twisted.trial import unittest +import unittest from torba.util import ArithUint256 diff --git a/tests/unit/test_wallet.py b/tests/unit/test_wallet.py index 9f825834f..bebcf1585 100644 --- a/tests/unit/test_wallet.py +++ b/tests/unit/test_wallet.py @@ -1,5 +1,5 @@ +import unittest import tempfile -from twisted.trial import unittest from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger from torba.coin.bitcoincash import MainNetLedger as BCHLedger diff --git a/torba/baseaccount.py b/torba/baseaccount.py index fe35d9b37..7369f9394 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -1,7 +1,6 @@ import random import typing from typing import List, Dict, Tuple, Type, Optional, Any -from twisted.internet import defer from torba.mnemonic import Mnemonic from torba.bip32 import PrivateKey, PubKey, from_extended_key_string @@ -44,12 +43,8 @@ class AddressManager: def to_dict_instance(self) -> Optional[dict]: raise NotImplementedError - @property - def db(self): - return self.account.ledger.db - def _query_addresses(self, **constraints): - return self.db.get_addresses( + return self.account.ledger.db.get_addresses( account=self.account, chain=self.chain_number, **constraints @@ -58,26 +53,24 @@ class AddressManager: def get_private_key(self, index: int) -> PrivateKey: raise NotImplementedError - def get_max_gap(self) -> defer.Deferred: + async def get_max_gap(self): raise NotImplementedError - def ensure_address_gap(self) -> defer.Deferred: + async def ensure_address_gap(self): raise NotImplementedError - def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred: + def get_address_records(self, only_usable: bool = False, **constraints): raise NotImplementedError - @defer.inlineCallbacks - def get_addresses(self, only_usable: bool = False, **constraints) -> defer.Deferred: - records = yield self.get_address_records(only_usable=only_usable, **constraints) + async def get_addresses(self, only_usable: bool = False, **constraints): + records = await self.get_address_records(only_usable=only_usable, **constraints) return [r['address'] for r in records] - @defer.inlineCallbacks - def get_or_create_usable_address(self) -> defer.Deferred: - addresses = yield self.get_addresses(only_usable=True, limit=10) + async def get_or_create_usable_address(self): + addresses = await self.get_addresses(only_usable=True, limit=10) if addresses: return random.choice(addresses) - addresses = yield self.ensure_address_gap() + addresses = await self.ensure_address_gap() return addresses[0] @@ -106,22 +99,20 @@ class HierarchicalDeterministic(AddressManager): def get_private_key(self, index: int) -> PrivateKey: return self.account.private_key.child(self.chain_number).child(index) - @defer.inlineCallbacks - def generate_keys(self, start: int, end: int) -> defer.Deferred: + async def generate_keys(self, start: int, end: int): keys_batch, final_keys = [], [] for index in range(start, end+1): keys_batch.append((index, self.public_key.child(index))) if index % 180 == 0 or index == end: - yield self.db.add_keys( + await self.account.ledger.db.add_keys( self.account, self.chain_number, keys_batch ) final_keys.extend(keys_batch) keys_batch.clear() return [key[1].address for key in final_keys] - @defer.inlineCallbacks - def get_max_gap(self) -> defer.Deferred: - addresses = yield self._query_addresses(order_by="position ASC") + async def get_max_gap(self): + addresses = await self._query_addresses(order_by="position ASC") max_gap = 0 current_gap = 0 for address in addresses: @@ -132,9 +123,8 @@ class HierarchicalDeterministic(AddressManager): current_gap = 0 return max_gap - @defer.inlineCallbacks - def ensure_address_gap(self) -> defer.Deferred: - addresses = yield self._query_addresses(limit=self.gap, order_by="position DESC") + async def ensure_address_gap(self): + addresses = await self._query_addresses(limit=self.gap, order_by="position DESC") existing_gap = 0 for address in addresses: @@ -148,7 +138,7 @@ class HierarchicalDeterministic(AddressManager): start = addresses[0]['position']+1 if addresses else 0 end = start + (self.gap - existing_gap) - new_keys = yield self.generate_keys(start, end-1) + new_keys = await self.generate_keys(start, end-1) return new_keys def get_address_records(self, only_usable: bool = False, **constraints): @@ -176,20 +166,19 @@ class SingleKey(AddressManager): def get_private_key(self, index: int) -> PrivateKey: return self.account.private_key - def get_max_gap(self) -> defer.Deferred: - return defer.succeed(0) + async def get_max_gap(self): + return 0 - @defer.inlineCallbacks - def ensure_address_gap(self) -> defer.Deferred: - exists = yield self.get_address_records() + async def ensure_address_gap(self): + exists = await self.get_address_records() if not exists: - yield self.db.add_keys( + await self.account.ledger.db.add_keys( self.account, self.chain_number, [(0, self.public_key)] ) return [self.public_key.address] return [] - def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred: + def get_address_records(self, only_usable: bool = False, **constraints): return self._query_addresses(**constraints) @@ -289,9 +278,8 @@ class BaseAccount: 'address_generator': self.address_generator.to_dict(self.receiving, self.change) } - @defer.inlineCallbacks - def get_details(self, show_seed=False, **kwargs): - satoshis = yield self.get_balance(**kwargs) + async def get_details(self, show_seed=False, **kwargs): + satoshis = await self.get_balance(**kwargs) details = { 'id': self.id, 'name': self.name, @@ -325,23 +313,21 @@ class BaseAccount: self.password = None self.encrypted = True - @defer.inlineCallbacks - def ensure_address_gap(self): + async def ensure_address_gap(self): addresses = [] for address_manager in self.address_managers: - new_addresses = yield address_manager.ensure_address_gap() + new_addresses = await address_manager.ensure_address_gap() addresses.extend(new_addresses) return addresses - @defer.inlineCallbacks - def get_addresses(self, **constraints) -> defer.Deferred: - rows = yield self.ledger.db.select_addresses('address', account=self, **constraints) + async def get_addresses(self, **constraints): + rows = await self.ledger.db.select_addresses('address', account=self, **constraints) return [r[0] for r in rows] - def get_address_records(self, **constraints) -> defer.Deferred: + def get_address_records(self, **constraints): return self.ledger.db.get_addresses(account=self, **constraints) - def get_address_count(self, **constraints) -> defer.Deferred: + def get_address_count(self, **constraints): return self.ledger.db.get_address_count(account=self, **constraints) def get_private_key(self, chain: int, index: int) -> PrivateKey: @@ -355,10 +341,9 @@ class BaseAccount: constraints.update({'height__lte': height, 'height__gt': 0}) return self.ledger.db.get_balance(account=self, **constraints) - @defer.inlineCallbacks - def get_max_gap(self): - change_gap = yield self.change.get_max_gap() - receiving_gap = yield self.receiving.get_max_gap() + async def get_max_gap(self): + change_gap = await self.change.get_max_gap() + receiving_gap = await self.receiving.get_max_gap() return { 'max_change_gap': change_gap, 'max_receiving_gap': receiving_gap, @@ -376,24 +361,23 @@ class BaseAccount: def get_transaction_count(self, **constraints): return self.ledger.db.get_transaction_count(account=self, **constraints) - @defer.inlineCallbacks - def fund(self, to_account, amount=None, everything=False, + async def fund(self, to_account, amount=None, everything=False, outputs=1, broadcast=False, **constraints): assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.' tx_class = self.ledger.transaction_class if everything: - utxos = yield self.get_utxos(**constraints) - yield self.ledger.reserve_outputs(utxos) - tx = yield tx_class.create( + utxos = await self.get_utxos(**constraints) + await self.ledger.reserve_outputs(utxos) + tx = await tx_class.create( inputs=[tx_class.input_class.spend(txo) for txo in utxos], outputs=[], funding_accounts=[self], change_account=to_account ) elif amount > 0: - to_address = yield to_account.change.get_or_create_usable_address() + to_address = await to_account.change.get_or_create_usable_address() to_hash160 = to_account.ledger.address_to_hash160(to_address) - tx = yield tx_class.create( + tx = await tx_class.create( inputs=[], outputs=[ tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160) @@ -406,9 +390,9 @@ class BaseAccount: raise ValueError('An amount is required.') if broadcast: - yield self.ledger.broadcast(tx) + await self.ledger.broadcast(tx) else: - yield self.ledger.release_outputs( + await self.ledger.release_outputs( [txi.txo_ref.txo for txi in tx.inputs] ) diff --git a/torba/basedatabase.py b/torba/basedatabase.py index 186c98d56..e6b74ab57 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -1,9 +1,8 @@ import logging -from typing import Tuple, List, Sequence +from typing import Tuple, List import sqlite3 -from twisted.internet import defer -from twisted.enterprise import adbapi +import aiosqlite from torba.hash import TXRefImmutable from torba.basetransaction import BaseTransaction @@ -107,25 +106,21 @@ def row_dict_or_default(rows, fields, default=None): class SQLiteMixin: - CREATE_TABLES_QUERY: Sequence[str] = () + CREATE_TABLES_QUERY: str def __init__(self, path): self._db_path = path - self.db: adbapi.ConnectionPool = None + self.db: aiosqlite.Connection = None self.ledger = None - def open(self): + async def open(self): log.info("connecting to database: %s", self._db_path) - self.db = adbapi.ConnectionPool( - 'sqlite3', self._db_path, cp_min=1, cp_max=1, check_same_thread=False - ) - return self.db.runInteraction( - lambda t: t.executescript(self.CREATE_TABLES_QUERY) - ) + self.db = aiosqlite.connect(self._db_path) + await self.db.__aenter__() + await self.db.executescript(self.CREATE_TABLES_QUERY) - def close(self): - self.db.close() - return defer.succeed(True) + async def close(self): + await self.db.close() @staticmethod def _insert_sql(table: str, data: dict) -> Tuple[str, List]: @@ -247,78 +242,75 @@ class BaseDatabase(SQLiteMixin): 'script': sqlite3.Binary(txo.script.source) } - def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history): + async def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history): - def _steps(t): - if save_tx == 'insert': - self.execute(t, *self._insert_sql('tx', { + if save_tx == 'insert': + await self.db.execute(*self._insert_sql('tx', { + 'txid': tx.id, + 'raw': sqlite3.Binary(tx.raw), + 'height': tx.height, + 'position': tx.position, + 'is_verified': tx.is_verified + })) + elif save_tx == 'update': + await self.db.execute(*self._update_sql("tx", { + 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified + }, 'txid = ?', (tx.id,))) + + existing_txos = [r[0] for r in await self.db.execute_fetchall(*query( + "SELECT position FROM txo", txid=tx.id + ))] + + for txo in tx.outputs: + if txo.position in existing_txos: + continue + if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash: + await self.db.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo))) + elif txo.script.is_pay_script_hash: + # TODO: implement script hash payments + print('Database.save_transaction_io: pay script hash is not implemented!') + + # lookup the address associated with each TXI (via its TXO) + txoid_to_address = {r[0]: r[1] for r in await self.db.execute_fetchall(*query( + "SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs] + ))} + + # list of TXIs that have already been added + existing_txis = [r[0] for r in await self.db.execute_fetchall(*query( + "SELECT txoid FROM txi", txid=tx.id + ))] + + for txi in tx.inputs: + txoid = txi.txo_ref.id + new_txi = txoid not in existing_txis + address_matches = txoid_to_address.get(txoid) == address + if new_txi and address_matches: + await self.db.execute(*self._insert_sql("txi", { 'txid': tx.id, - 'raw': sqlite3.Binary(tx.raw), - 'height': tx.height, - 'position': tx.position, - 'is_verified': tx.is_verified + 'txoid': txoid, + 'address': address, })) - elif save_tx == 'update': - self.execute(t, *self._update_sql("tx", { - 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified - }, 'txid = ?', (tx.id,))) - existing_txos = [r[0] for r in self.execute(t, *query( - "SELECT position FROM txo", txid=tx.id - )).fetchall()] + await self._set_address_history(address, history) - for txo in tx.outputs: - if txo.position in existing_txos: - continue - if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash: - self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo))) - elif txo.script.is_pay_script_hash: - # TODO: implement script hash payments - print('Database.save_transaction_io: pay script hash is not implemented!') - - # lookup the address associated with each TXI (via its TXO) - txoid_to_address = {r[0]: r[1] for r in self.execute(t, *query( - "SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs] - )).fetchall()} - - # list of TXIs that have already been added - existing_txis = [r[0] for r in self.execute(t, *query( - "SELECT txoid FROM txi", txid=tx.id - )).fetchall()] - - for txi in tx.inputs: - txoid = txi.txo_ref.id - new_txi = txoid not in existing_txis - address_matches = txoid_to_address.get(txoid) == address - if new_txi and address_matches: - self.execute(t, *self._insert_sql("txi", { - 'txid': tx.id, - 'txoid': txoid, - 'address': address, - })) - - self._set_address_history(t, address, history) - - return self.db.runInteraction(_steps) - - def reserve_outputs(self, txos, is_reserved=True): + async def reserve_outputs(self, txos, is_reserved=True): txoids = [txo.id for txo in txos] - return self.run_operation( + await self.db.execute( "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( ', '.join(['?']*len(txoids)) ), [is_reserved]+txoids ) - def release_outputs(self, txos): - return self.reserve_outputs(txos, is_reserved=False) + async def release_outputs(self, txos): + await self.reserve_outputs(txos, is_reserved=False) - def rewind_blockchain(self, above_height): # pylint: disable=no-self-use + async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use # TODO: # 1. delete transactions above_height # 2. update address histories removing deleted TXs - return defer.succeed(True) + return True - def select_transactions(self, cols, account=None, **constraints): + async def select_transactions(self, cols, account=None, **constraints): if 'txid' not in constraints and account is not None: constraints['$account'] = account.public_key.address constraints['txid__in'] = """ @@ -328,13 +320,14 @@ class BaseDatabase(SQLiteMixin): SELECT txi.txid FROM txi JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account """ - return self.run_query(*query("SELECT {} FROM tx".format(cols), **constraints)) + return await self.db.execute_fetchall( + *query("SELECT {} FROM tx".format(cols), **constraints) + ) - @defer.inlineCallbacks - def get_transactions(self, my_account=None, **constraints): + async def get_transactions(self, my_account=None, **constraints): my_account = my_account or constraints.get('account', None) - tx_rows = yield self.select_transactions( + tx_rows = await self.select_transactions( 'txid, raw, height, position, is_verified', order_by=["height DESC", "position DESC"], **constraints @@ -352,7 +345,7 @@ class BaseDatabase(SQLiteMixin): annotated_txos = { txo.id: txo for txo in - (yield self.get_txos( + (await self.get_txos( my_account=my_account, txid__in=txids )) @@ -360,7 +353,7 @@ class BaseDatabase(SQLiteMixin): referenced_txos = { txo.id: txo for txo in - (yield self.get_txos( + (await self.get_txos( my_account=my_account, txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0] )) @@ -380,33 +373,30 @@ class BaseDatabase(SQLiteMixin): return txs - @defer.inlineCallbacks - def get_transaction_count(self, **constraints): + async def get_transaction_count(self, **constraints): constraints.pop('offset', None) constraints.pop('limit', None) constraints.pop('order_by', None) - count = yield self.select_transactions('count(*)', **constraints) + count = await self.select_transactions('count(*)', **constraints) return count[0][0] - @defer.inlineCallbacks - def get_transaction(self, **constraints): - txs = yield self.get_transactions(limit=1, **constraints) + async def get_transaction(self, **constraints): + txs = await self.get_transactions(limit=1, **constraints) if txs: return txs[0] - def select_txos(self, cols, **constraints): - return self.run_query(*query( + async def select_txos(self, cols, **constraints): + return await self.db.execute_fetchall(*query( "SELECT {} FROM txo" " JOIN pubkey_address USING (address)" " JOIN tx USING (txid)".format(cols), **constraints )) - @defer.inlineCallbacks - def get_txos(self, my_account=None, **constraints): + async def get_txos(self, my_account=None, **constraints): my_account = my_account or constraints.get('account', None) if isinstance(my_account, BaseAccount): my_account = my_account.public_key.address - rows = yield self.select_txos( + rows = await self.select_txos( "amount, script, txid, tx.height, txo.position, chain, account", **constraints ) output_class = self.ledger.transaction_class.output_class @@ -421,12 +411,11 @@ class BaseDatabase(SQLiteMixin): ) for row in rows ] - @defer.inlineCallbacks - def get_txo_count(self, **constraints): + async def get_txo_count(self, **constraints): constraints.pop('offset', None) constraints.pop('limit', None) constraints.pop('order_by', None) - count = yield self.select_txos('count(*)', **constraints) + count = await self.select_txos('count(*)', **constraints) return count[0][0] @staticmethod @@ -442,37 +431,33 @@ class BaseDatabase(SQLiteMixin): self.constrain_utxo(constraints) return self.get_txo_count(**constraints) - @defer.inlineCallbacks - def get_balance(self, **constraints): + async def get_balance(self, **constraints): self.constrain_utxo(constraints) - balance = yield self.select_txos('SUM(amount)', **constraints) + balance = await self.select_txos('SUM(amount)', **constraints) return balance[0][0] or 0 - def select_addresses(self, cols, **constraints): - return self.run_query(*query( - "SELECT {} FROM pubkey_address".format(cols), **constraints - )) + async def select_addresses(self, cols, **constraints): + return await self.db.execute_fetchall(*query( + "SELECT {} FROM pubkey_address".format(cols), **constraints + )) - @defer.inlineCallbacks - def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints): - addresses = yield self.select_addresses(', '.join(cols), **constraints) + async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints): + addresses = await self.select_addresses(', '.join(cols), **constraints) return rows_to_dict(addresses, cols) - @defer.inlineCallbacks - def get_address_count(self, **constraints): - count = yield self.select_addresses('count(*)', **constraints) + async def get_address_count(self, **constraints): + count = await self.select_addresses('count(*)', **constraints) return count[0][0] - @defer.inlineCallbacks - def get_address(self, **constraints): - addresses = yield self.get_addresses( + async def get_address(self, **constraints): + addresses = await self.get_addresses( cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'), limit=1, **constraints ) if addresses: return addresses[0] - def add_keys(self, account, chain, keys): + async def add_keys(self, account, chain, keys): sql = ( "insert into pubkey_address " "(address, account, chain, position, pubkey) " @@ -484,14 +469,13 @@ class BaseDatabase(SQLiteMixin): pubkey.address, account.public_key.address, chain, position, sqlite3.Binary(pubkey.pubkey_bytes) )) - return self.run_operation(sql, values) + await self.db.execute(sql, values) - @classmethod - def _set_address_history(cls, t, address, history): - cls.execute( - t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", + async def _set_address_history(self, address, history): + await self.db.execute( + "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", (history, history.count(':')//2, address) ) - def set_address_history(self, address, history): - return self.db.runInteraction(lambda t: self._set_address_history(t, address, history)) + async def set_address_history(self, address, history): + await self._set_address_history(address, history) diff --git a/torba/baseheader.py b/torba/baseheader.py index 9fca794f0..5526c3fb6 100644 --- a/torba/baseheader.py +++ b/torba/baseheader.py @@ -1,11 +1,10 @@ import os +import asyncio import logging from io import BytesIO from typing import Optional, Iterator, Tuple from binascii import hexlify -from twisted.internet import threads, defer - from torba.util import ArithUint256 from torba.hash import double_sha256 @@ -36,16 +35,14 @@ class BaseHeaders: self.io = BytesIO() self.path = path self._size: Optional[int] = None - self._header_connect_lock = defer.DeferredLock() + self._header_connect_lock = asyncio.Lock() - def open(self): + async def open(self): if self.path != ':memory:': self.io = open(self.path, 'a+b') - return defer.succeed(True) - def close(self): + async def close(self): self.io.close() - return defer.succeed(True) @staticmethod def serialize(header: dict) -> bytes: @@ -95,16 +92,15 @@ class BaseHeaders: return b'0' * 64 return hexlify(double_sha256(header)[::-1]) - @defer.inlineCallbacks - def connect(self, start: int, headers: bytes): + async def connect(self, start: int, headers: bytes) -> int: added = 0 bail = False - yield self._header_connect_lock.acquire() - try: + loop = asyncio.get_running_loop() + async with self._header_connect_lock: for height, chunk in self._iterate_chunks(start, headers): try: # validate_chunk() is CPU bound and reads previous chunks from file system - yield threads.deferToThread(self.validate_chunk, height, chunk) + await loop.run_in_executor(None, self.validate_chunk, height, chunk) except InvalidHeader as e: bail = True chunk = chunk[:(height-e.height+1)*self.header_size] @@ -115,14 +111,12 @@ class BaseHeaders: self.io.truncate() # .seek()/.write()/.truncate() might also .flush() when needed # the goal here is mainly to ensure we're definitely flush()'ing - yield threads.deferToThread(self.io.flush) + await loop.run_in_executor(None, self.io.flush) self._size = None added += written if bail: break - finally: - self._header_connect_lock.release() - defer.returnValue(added) + return added def validate_chunk(self, height, chunk): previous_hash, previous_header, previous_previous_header = None, None, None diff --git a/torba/baseledger.py b/torba/baseledger.py index ab99bd3ca..b168e1ec6 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -1,4 +1,5 @@ import os +import asyncio import logging from binascii import hexlify, unhexlify from io import StringIO @@ -7,8 +8,6 @@ from typing import Dict, Type, Iterable from operator import itemgetter from collections import namedtuple -from twisted.internet import defer - from torba import baseaccount from torba import basenetwork from torba import basetransaction @@ -104,8 +103,8 @@ class BaseLedger(metaclass=LedgerRegistry): ) self._transaction_processing_locks = {} - self._utxo_reservation_lock = defer.DeferredLock() - self._header_processing_lock = defer.DeferredLock() + self._utxo_reservation_lock = asyncio.Lock() + self._header_processing_lock = asyncio.Lock() @classmethod def get_id(cls): @@ -135,41 +134,32 @@ class BaseLedger(metaclass=LedgerRegistry): def add_account(self, account: baseaccount.BaseAccount): self.accounts.append(account) - @defer.inlineCallbacks - def get_private_key_for_address(self, address): - match = yield self.db.get_address(address=address) + async def get_private_key_for_address(self, address): + match = await self.db.get_address(address=address) if match: for account in self.accounts: if match['account'] == account.public_key.address: return account.get_private_key(match['chain'], match['position']) - @defer.inlineCallbacks - def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]): + async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]): estimators = [] for account in funding_accounts: - utxos = yield account.get_utxos() + utxos = await account.get_utxos() for utxo in utxos: estimators.append(utxo.get_estimator(self)) return estimators - @defer.inlineCallbacks - def get_spendable_utxos(self, amount: int, funding_accounts): - yield self._utxo_reservation_lock.acquire() - try: - txos = yield self.get_effective_amount_estimators(funding_accounts) + async def get_spendable_utxos(self, amount: int, funding_accounts): + async with self._utxo_reservation_lock: + txos = await self.get_effective_amount_estimators(funding_accounts) selector = CoinSelector( txos, amount, self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self) ) spendables = selector.select() if spendables: - yield self.reserve_outputs(s.txo for s in spendables) - except Exception: - log.exception('Failed to get spendable utxos:') - raise - finally: - self._utxo_reservation_lock.release() - return spendables + await self.reserve_outputs(s.txo for s in spendables) + return spendables def reserve_outputs(self, txos): return self.db.reserve_outputs(txos) @@ -177,16 +167,14 @@ class BaseLedger(metaclass=LedgerRegistry): def release_outputs(self, txos): return self.db.release_outputs(txos) - @defer.inlineCallbacks - def get_local_status(self, address): - address_details = yield self.db.get_address(address=address) + async def get_local_status(self, address): + address_details = await self.db.get_address(address=address) history = address_details['history'] or '' h = sha256(history.encode()) return hexlify(h) - @defer.inlineCallbacks - def get_local_history(self, address): - address_details = yield self.db.get_address(address=address) + async def get_local_history(self, address): + address_details = await self.db.get_address(address=address) history = address_details['history'] or '' parts = history.split(':')[:-1] return list(zip(parts[0::2], map(int, parts[1::2]))) @@ -203,43 +191,40 @@ class BaseLedger(metaclass=LedgerRegistry): working_branch = double_sha256(combined) return hexlify(working_branch[::-1]) - def validate_transaction_and_set_position(self, tx, height, merkle): + async def validate_transaction_and_set_position(self, tx, height): if not height <= len(self.headers): return False + merkle = await self.network.get_merkle(tx.id, height) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) header = self.headers[height] tx.position = merkle['pos'] tx.is_verified = merkle_root == header['merkle_root'] - @defer.inlineCallbacks - def start(self): + async def start(self): if not os.path.exists(self.path): os.mkdir(self.path) - yield defer.gatherResults([ + await asyncio.gather( self.db.open(), self.headers.open() - ]) + ) first_connection = self.network.on_connected.first - self.network.start() - yield first_connection - yield self.join_network() + asyncio.ensure_future(self.network.start()) + await first_connection + await self.join_network() self.network.on_connected.listen(self.join_network) - @defer.inlineCallbacks - def join_network(self, *args): + async def join_network(self, *args): log.info("Subscribing and updating accounts.") - yield self.update_headers() - yield self.network.subscribe_headers() - yield self.update_accounts() + await self.update_headers() + await self.network.subscribe_headers() + await self.update_accounts() - @defer.inlineCallbacks - def stop(self): - yield self.network.stop() - yield self.db.close() - yield self.headers.close() + async def stop(self): + await self.network.stop() + await self.db.close() + await self.headers.close() - @defer.inlineCallbacks - def update_headers(self, height=None, headers=None, subscription_update=False): + async def update_headers(self, height=None, headers=None, subscription_update=False): rewound = 0 while True: @@ -251,14 +236,14 @@ class BaseLedger(metaclass=LedgerRegistry): subscription_update = False if not headers: - header_response = yield self.network.get_headers(height, 2001) + header_response = await self.network.get_headers(height, 2001) headers = header_response['hex'] if not headers: # Nothing to do, network thinks we're already at the latest height. return - added = yield self.headers.connect(height, unhexlify(headers)) + added = await self.headers.connect(height, unhexlify(headers)) if added > 0: height += added self._on_header_controller.add( @@ -268,7 +253,7 @@ class BaseLedger(metaclass=LedgerRegistry): # we started rewinding blocks and apparently found # a new chain rewound = 0 - yield self.db.rewind_blockchain(height) + await self.db.rewind_blockchain(height) if subscription_update: # subscription updates are for latest header already @@ -310,66 +295,37 @@ class BaseLedger(metaclass=LedgerRegistry): # robust sync, turn off subscription update shortcut subscription_update = False - @defer.inlineCallbacks - def receive_header(self, response): - yield self._header_processing_lock.acquire() - try: + async def receive_header(self, response): + async with self._header_processing_lock: header = response[0] - yield self.update_headers( + await self.update_headers( height=header['height'], headers=header['hex'], subscription_update=True ) - finally: - self._header_processing_lock.release() - def update_accounts(self): - return defer.DeferredList([ + async def update_accounts(self): + return await asyncio.gather(*( self.update_account(a) for a in self.accounts - ]) + )) - @defer.inlineCallbacks - def update_account(self, account): # type: (baseaccount.BaseAccount) -> defer.Defferred + async def update_account(self, account: baseaccount.BaseAccount): # Before subscribing, download history for any addresses that don't have any, # this avoids situation where we're getting status updates to addresses we know # need to update anyways. Continue to get history and create more addresses until # all missing addresses are created and history for them is fully restored. - yield account.ensure_address_gap() - addresses = yield account.get_addresses(used_times=0) + await account.ensure_address_gap() + addresses = await account.get_addresses(used_times=0) while addresses: - yield defer.DeferredList([ - self.update_history(a) for a in addresses - ]) - addresses = yield account.ensure_address_gap() + await asyncio.gather(*(self.update_history(a) for a in addresses)) + addresses = await account.ensure_address_gap() # By this point all of the addresses should be restored and we # can now subscribe all of them to receive updates. - all_addresses = yield account.get_addresses() - yield defer.DeferredList( - list(map(self.subscribe_history, all_addresses)) - ) + all_addresses = await account.get_addresses() + await asyncio.gather(*(self.subscribe_history(a) for a in all_addresses)) - @defer.inlineCallbacks - def _prefetch_history(self, remote_history, local_history): - proofs, network_txs, deferreds = {}, {}, [] - for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): - if i < len(local_history) and local_history[i] == (hex_id, remote_height): - continue - if remote_height > 0: - deferreds.append( - self.network.get_merkle(hex_id, remote_height).addBoth( - lambda result, txid: proofs.__setitem__(txid, result), hex_id) - ) - deferreds.append( - self.network.get_transaction(hex_id).addBoth( - lambda result, txid: network_txs.__setitem__(txid, result), hex_id) - ) - yield defer.DeferredList(deferreds) - return proofs, network_txs - - @defer.inlineCallbacks - def update_history(self, address): - remote_history = yield self.network.get_history(address) - local_history = yield self.get_local_history(address) - proofs, network_txs = yield self._prefetch_history(remote_history, local_history) + async def update_history(self, address): + remote_history = await self.network.get_history(address) + local_history = await self.get_local_history(address) synced_history = StringIO() for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): @@ -379,30 +335,29 @@ class BaseLedger(metaclass=LedgerRegistry): if i < len(local_history) and local_history[i] == (hex_id, remote_height): continue - lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock()) + lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock()) - yield lock.acquire() + await lock.acquire() try: # see if we have a local copy of transaction, otherwise fetch it from server - tx = yield self.db.get_transaction(txid=hex_id) + tx = await self.db.get_transaction(txid=hex_id) save_tx = None if tx is None: - _raw = network_txs[hex_id] + _raw = await self.network.get_transaction(hex_id) tx = self.transaction_class(unhexlify(_raw)) save_tx = 'insert' tx.height = remote_height if remote_height > 0 and (not tx.is_verified or tx.position == -1): - self.validate_transaction_and_set_position(tx, remote_height, proofs[hex_id]) + await self.validate_transaction_and_set_position(tx, remote_height) if save_tx is None: save_tx = 'update' - yield self.db.save_transaction_io( - save_tx, tx, address, self.address_to_hash160(address), - synced_history.getvalue() + await self.db.save_transaction_io( + save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue() ) log.debug( @@ -412,28 +367,22 @@ class BaseLedger(metaclass=LedgerRegistry): self._on_transaction_controller.add(TransactionEvent(address, tx)) - except Exception: - log.exception('Failed to synchronize transaction:') - raise - finally: lock.release() - if not lock.locked and hex_id in self._transaction_processing_locks: + if not lock.locked() and hex_id in self._transaction_processing_locks: del self._transaction_processing_locks[hex_id] - @defer.inlineCallbacks - def subscribe_history(self, address): - remote_status = yield self.network.subscribe_address(address) - local_status = yield self.get_local_status(address) + async def subscribe_history(self, address): + remote_status = await self.network.subscribe_address(address) + local_status = await self.get_local_status(address) if local_status != remote_status: - yield self.update_history(address) + await self.update_history(address) - @defer.inlineCallbacks - def receive_status(self, response): + async def receive_status(self, response): address, remote_status = response - local_status = yield self.get_local_status(address) + local_status = await self.get_local_status(address) if local_status != remote_status: - yield self.update_history(address) + await self.update_history(address) def broadcast(self, tx): return self.network.broadcast(hexlify(tx.raw).decode()) diff --git a/torba/basemanager.py b/torba/basemanager.py index 40e2acafa..706a6b296 100644 --- a/torba/basemanager.py +++ b/torba/basemanager.py @@ -1,6 +1,6 @@ +import asyncio import logging from typing import Type, MutableSequence, MutableMapping -from twisted.internet import defer from torba.baseledger import BaseLedger, LedgerRegistry from torba.wallet import Wallet, WalletStorage @@ -41,11 +41,10 @@ class BaseWalletManager: self.wallets.append(wallet) return wallet - @defer.inlineCallbacks - def get_detailed_accounts(self, confirmations=6, show_seed=False): + async def get_detailed_accounts(self, confirmations=6, show_seed=False): ledgers = {} for i, account in enumerate(self.accounts): - details = yield account.get_details(confirmations=confirmations, show_seed=True) + details = await account.get_details(confirmations=confirmations, show_seed=True) details['is_default_account'] = i == 0 ledger_id = account.ledger.get_id() ledgers.setdefault(ledger_id, []) @@ -68,16 +67,14 @@ class BaseWalletManager: for account in wallet.accounts: yield account - @defer.inlineCallbacks - def start(self): + async def start(self): self.running = True - yield defer.DeferredList([ + await asyncio.gather(*( l.start() for l in self.ledgers.values() - ]) + )) - @defer.inlineCallbacks - def stop(self): - yield defer.DeferredList([ + async def stop(self): + await asyncio.gather(*( l.stop() for l in self.ledgers.values() - ]) + )) self.running = False diff --git a/torba/basenetwork.py b/torba/basenetwork.py index 452386120..689cab088 100644 --- a/torba/basenetwork.py +++ b/torba/basenetwork.py @@ -1,145 +1,36 @@ -import json -import socket import logging from itertools import cycle -from twisted.internet import defer, reactor, protocol -from twisted.application.internet import ClientService, CancelledError -from twisted.internet.endpoints import clientFromString -from twisted.protocols.basic import LineOnlyReceiver -from twisted.python import failure + +from aiorpcx import ClientSession as BaseClientSession from torba import __version__ from torba.stream import StreamController -from torba.constants import TIMEOUT log = logging.getLogger(__name__) -class StratumClientProtocol(LineOnlyReceiver): - delimiter = b'\n' - MAX_LENGTH = 2000000 +class ClientSession(BaseClientSession): - def __init__(self): - self.request_id = 0 - self.lookup_table = {} - self.session = {} - self.network = None - - self.on_disconnected_controller = StreamController() - self.on_disconnected = self.on_disconnected_controller.stream - - def _get_id(self): - self.request_id += 1 - return self.request_id - - @property - def _ip(self): - return self.transport.getPeer().host - - def get_session(self): - return self.session - - def connectionMade(self): - try: - self.transport.setTcpNoDelay(True) - self.transport.setTcpKeepAlive(True) - if hasattr(socket, "TCP_KEEPIDLE"): - self.transport.socket.setsockopt( - socket.SOL_TCP, socket.TCP_KEEPIDLE, 120 - # Seconds before sending keepalive probes - ) - else: - log.debug("TCP_KEEPIDLE not available") - if hasattr(socket, "TCP_KEEPINTVL"): - self.transport.socket.setsockopt( - socket.SOL_TCP, socket.TCP_KEEPINTVL, 1 - # Interval in seconds between keepalive probes - ) - else: - log.debug("TCP_KEEPINTVL not available") - if hasattr(socket, "TCP_KEEPCNT"): - self.transport.socket.setsockopt( - socket.SOL_TCP, socket.TCP_KEEPCNT, 5 - # Failed keepalive probles before declaring other end dead - ) - else: - log.debug("TCP_KEEPCNT not available") - - except Exception as err: # pylint: disable=broad-except - # Supported only by the socket transport, - # but there's really no better place in code to trigger this. - log.warning("Error setting up socket: %s", err) - - def connectionLost(self, reason=None): - self.connected = 0 - self.on_disconnected_controller.add(True) - for deferred in self.lookup_table.values(): - if not deferred.called: - deferred.errback(TimeoutError("Connection dropped.")) - - def lineReceived(self, line): - log.debug('received: %s', line) - - try: - message = json.loads(line) - except (ValueError, TypeError): - raise ValueError("Cannot decode message '{}'".format(line.strip())) - - if message.get('id'): - try: - d = self.lookup_table.pop(message['id']) - if message.get('error'): - d.errback(RuntimeError(message['error'])) - else: - d.callback(message.get('result')) - except KeyError: - raise LookupError( - "Lookup for deferred object for message ID '{}' failed.".format(message['id'])) - elif message.get('method') in self.network.subscription_controllers: - controller = self.network.subscription_controllers[message['method']] - controller.add(message.get('params')) - else: - log.warning("Cannot handle message '%s'", line) - - def rpc(self, method, *args): - message_id = self._get_id() - message = json.dumps({ - 'id': message_id, - 'method': method, - 'params': args - }) - log.debug('sent: %s', message) - self.sendLine(message.encode('latin-1')) - d = self.lookup_table[message_id] = defer.Deferred() - d.addTimeout( - TIMEOUT, reactor, onTimeoutCancel=lambda *_: failure.Failure(TimeoutError( - "Timeout: Stratum request for '%s' took more than %s seconds" % (method, TIMEOUT))) - ) - return d - - -class StratumClientFactory(protocol.ClientFactory): - - protocol = StratumClientProtocol - - def __init__(self, network): + def __init__(self, *args, network, **kwargs): self.network = network - self.client = None + super().__init__(*args, **kwargs) + self._on_disconnect_controller = StreamController() + self.on_disconnected = self._on_disconnect_controller.stream - def buildProtocol(self, addr): - client = self.protocol() - client.factory = self - client.network = self.network - self.client = client - return client + async def handle_request(self, request): + controller = self.network.subscription_controllers[request.method] + controller.add(request.args) + + def connection_lost(self, exc): + super().connection_lost(exc) + self._on_disconnect_controller.add(True) class BaseNetwork: def __init__(self, ledger): self.config = ledger.config - self.client = None - self.service = None + self.client: ClientSession = None self.running = False self._on_connected_controller = StreamController() @@ -156,48 +47,35 @@ class BaseNetwork: 'blockchain.address.subscribe': self._on_status_controller, } - @defer.inlineCallbacks - def start(self): + async def start(self): self.running = True for server in cycle(self.config['default_servers']): connection_string = 'tcp:{}:{}'.format(*server) - endpoint = clientFromString(reactor, connection_string) - log.debug("Attempting connection to SPV wallet server: %s", connection_string) - self.service = ClientService(endpoint, StratumClientFactory(self)) - self.service.startService() + self.client = ClientSession(*server, network=self) try: - self.client = yield self.service.whenConnected(failAfterFailures=2) - yield self.ensure_server_version() - log.info("Successfully connected to SPV wallet server: %s", connection_string) + await self.client.create_connection() + await self.ensure_server_version() + log.info("Successfully connected to SPV wallet server: %s", ) self._on_connected_controller.add(True) - yield self.client.on_disconnected.first - except CancelledError: - return + await self.client.on_disconnected.first except Exception: # pylint: disable=broad-except log.exception("Connecting to %s raised an exception:", connection_string) - finally: - self.client = None - if self.service is not None: - self.service.stopService() if not self.running: return - def stop(self): + async def stop(self): self.running = False - if self.service is not None: - self.service.stopService() if self.is_connected: - return self.client.on_disconnected.first - else: - return defer.succeed(True) + await self.client.close() + await self.client.on_disconnected.first @property def is_connected(self): - return self.client is not None and self.client.connected + return self.client is not None and not self.client.is_closing() def rpc(self, list_or_method, *args): if self.is_connected: - return self.client.rpc(list_or_method, *args) + return self.client.send_request(list_or_method, args) else: raise ConnectionError("Attempting to send rpc request when connection is not available.") diff --git a/torba/basetransaction.py b/torba/basetransaction.py index a5fa2fce1..dca0fb2f8 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -3,8 +3,6 @@ import typing from typing import List, Iterable, Optional from binascii import hexlify -from twisted.internet import defer - from torba.basescript import BaseInputScript, BaseOutputScript from torba.baseaccount import BaseAccount from torba.constants import COIN, NULL_HASH32 @@ -426,8 +424,7 @@ class BaseTransaction: return ledger @classmethod - @defer.inlineCallbacks - def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput], + async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput], funding_accounts: Iterable[BaseAccount], change_account: BaseAccount): """ Find optimal set of inputs when only outputs are provided; add change outputs if only inputs are provided or if inputs are greater than outputs. """ @@ -450,7 +447,7 @@ class BaseTransaction: if payment < cost: deficit = cost - payment - spendables = yield ledger.get_spendable_utxos(deficit, funding_accounts) + spendables = await ledger.get_spendable_utxos(deficit, funding_accounts) if not spendables: raise ValueError('Not enough funds to cover this transaction.') payment += sum(s.effective_amount for s in spendables) @@ -463,28 +460,27 @@ class BaseTransaction: ) change = payment - cost if change > cost_of_change: - change_address = yield change_account.change.get_or_create_usable_address() + change_address = await change_account.change.get_or_create_usable_address() change_hash160 = change_account.ledger.address_to_hash160(change_address) change_amount = change - cost_of_change change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160) change_output.is_change = True tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) - yield tx.sign(funding_accounts) + await tx.sign(funding_accounts) except Exception as e: log.exception('Failed to synchronize transaction:') - yield ledger.release_outputs(tx.outputs) + await ledger.release_outputs(tx.outputs) raise e - defer.returnValue(tx) + return tx @staticmethod def signature_hash_type(hash_type): return hash_type - @defer.inlineCallbacks - def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred: + async def sign(self, funding_accounts: Iterable[BaseAccount]): ledger = self.ensure_all_have_same_ledger(funding_accounts) for i, txi in enumerate(self._inputs): assert txi.script is not None @@ -492,7 +488,7 @@ class BaseTransaction: txo_script = txi.txo_ref.txo.script if txo_script.is_pay_pubkey_hash: address = ledger.hash160_to_address(txo_script.values['pubkey_hash']) - private_key = yield ledger.get_private_key_for_address(address) + private_key = await ledger.get_private_key_for_address(address) tx = self._serialize_for_signature(i) txi.script.values['signature'] = \ private_key.sign(tx) + bytes((self.signature_hash_type(1),)) diff --git a/torba/stream.py b/torba/stream.py index 79fbba796..64c1f5e10 100644 --- a/torba/stream.py +++ b/torba/stream.py @@ -1,6 +1,4 @@ import asyncio -from twisted.internet.defer import Deferred -from twisted.python.failure import Failure class BroadcastSubscription: @@ -31,11 +29,13 @@ class BroadcastSubscription: def _add(self, data): if self.can_fire and self._on_data is not None: - self._on_data(data) + maybe_coroutine = self._on_data(data) + if asyncio.iscoroutine(maybe_coroutine): + asyncio.ensure_future(maybe_coroutine) - def _add_error(self, error, traceback): + def _add_error(self, exception): if self.can_fire and self._on_error is not None: - self._on_error(error, traceback) + self._on_error(exception) def _close(self): if self.can_fire and self._on_done is not None: @@ -66,9 +66,9 @@ class StreamController: for subscription in self._iterate_subscriptions: subscription._add(event) - def add_error(self, error, traceback): + def add_error(self, exception): for subscription in self._iterate_subscriptions: - subscription._add_error(error, traceback) + subscription._add_error(exception) def close(self): for subscription in self._iterate_subscriptions: @@ -108,38 +108,35 @@ class Stream: def listen(self, on_data, on_error=None, on_done=None): return self._controller._listen(on_data, on_error, on_done) - def deferred_where(self, condition): - deferred = Deferred() + def where(self, condition) -> asyncio.Future: + future = asyncio.get_event_loop().create_future() def where_test(value): if condition(value): - self._cancel_and_callback(subscription, deferred, value) + self._cancel_and_callback(subscription, future, value) subscription = self.listen( where_test, - lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback) + lambda exception: self._cancel_and_error(subscription, future, exception) ) - return deferred - - def where(self, condition): - return self.deferred_where(condition).asFuture(asyncio.get_event_loop()) + return future @property def first(self): - deferred = Deferred() + future = asyncio.get_event_loop().create_future() subscription = self.listen( - lambda value: self._cancel_and_callback(subscription, deferred, value), - lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback) + lambda value: self._cancel_and_callback(subscription, future, value), + lambda exception: self._cancel_and_error(subscription, future, exception) ) - return deferred + return future @staticmethod - def _cancel_and_callback(subscription, deferred, value): + def _cancel_and_callback(subscription: BroadcastSubscription, future: asyncio.Future, value): subscription.cancel() - deferred.callback(value) + future.set_result(value) @staticmethod - def _cancel_and_error(subscription, deferred, error, traceback): + def _cancel_and_error(subscription: BroadcastSubscription, future: asyncio.Future, exception): subscription.cancel() - deferred.errback(Failure(error, exc_tb=traceback)) + future.set_exception(exception)