twisted -> asyncio

Lex Berezhny 2018-10-14 22:16:51 -04:00
parent 0ce4b9a7de
commit 2c5fd4aade
21 changed files with 480 additions and 721 deletions
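The changes below apply one mechanical translation throughout: twisted.trial test cases driven by @defer.inlineCallbacks generators become orchstr8 AsyncioTestCase subclasses with native coroutines, and every yield on a Deferred becomes an await on a coroutine. A minimal before/after sketch of that pattern, assuming orchstr8's AsyncioTestCase is importable as the tests below do; the ExampleTests class and get_value helper are hypothetical and not part of this commit:

# Before: Twisted
from twisted.trial import unittest
from twisted.internet import defer

def get_value():
    return defer.succeed(42)  # helper returning a Deferred

class ExampleTests(unittest.TestCase):
    @defer.inlineCallbacks
    def test_value(self):
        value = yield get_value()
        self.assertEqual(value, 42)

# After: asyncio
from orchstr8.testcase import AsyncioTestCase

async def get_value():
    return 42  # plain coroutine, nothing to wrap in a Deferred

class ExampleTests(AsyncioTestCase):
    async def asyncSetUp(self):
        pass  # setUp/tearDown become asyncSetUp/asyncTearDown coroutines

    async def test_value(self):
        value = await get_value()
        self.assertEqual(value, 42)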

View file

@ -1,5 +1,5 @@
import asyncio
from orchstr8.testcase import IntegrationTestCase, d2f
from orchstr8.testcase import IntegrationTestCase
from torba.constants import COIN
@ -9,14 +9,14 @@ class BasicTransactionTests(IntegrationTestCase):
async def test_sending_and_receiving(self):
account1, account2 = self.account, self.wallet.generate_account(self.ledger)
await d2f(self.ledger.update_account(account2))
await self.ledger.update_account(account2)
self.assertEqual(await self.get_balance(account1), 0)
self.assertEqual(await self.get_balance(account2), 0)
sendtxids = []
for i in range(5):
address1 = await d2f(account1.receiving.get_or_create_usable_address())
address1 = await account1.receiving.get_or_create_usable_address()
sendtxid = await self.blockchain.send_to_address(address1, 1.1)
sendtxids.append(sendtxid)
await self.on_transaction_id(sendtxid) # mempool
@ -28,13 +28,13 @@ class BasicTransactionTests(IntegrationTestCase):
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
address2 = await d2f(account2.receiving.get_or_create_usable_address())
address2 = await account2.receiving.get_or_create_usable_address()
hash2 = self.ledger.address_to_hash160(address2)
tx = await d2f(self.ledger.transaction_class.create(
tx = await self.ledger.transaction_class.create(
[],
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash2)],
[account1], account1
))
)
await self.broadcast(tx)
await self.on_transaction(tx) # mempool
await self.blockchain.generate(1)
@ -43,18 +43,18 @@ class BasicTransactionTests(IntegrationTestCase):
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
utxos = await d2f(self.account.get_utxos())
tx = await d2f(self.ledger.transaction_class.create(
utxos = await self.account.get_utxos()
tx = await self.ledger.transaction_class.create(
[self.ledger.transaction_class.input_class.spend(utxos[0])],
[],
[account1], account1
))
)
await self.broadcast(tx)
await self.on_transaction(tx) # mempool
await self.blockchain.generate(1)
await self.on_transaction(tx) # confirmed
txs = await d2f(account1.get_transactions())
txs = await account1.get_transactions()
tx = txs[1]
self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1)
self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1)

View file

@ -1,25 +1,26 @@
from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import defer
from orchstr8.testcase import AsyncioTestCase
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.baseaccount import HierarchicalDeterministic, SingleKey
from torba.wallet import Wallet
class TestHierarchicalDeterministicAccount(unittest.TestCase):
class TestHierarchicalDeterministicAccount(AsyncioTestCase):
@defer.inlineCallbacks
def setUp(self):
async def asyncSetUp(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),
})
yield self.ledger.db.open()
await self.ledger.db.open()
self.account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
@defer.inlineCallbacks
def test_generate_account(self):
async def asyncTearDown(self):
await self.ledger.db.close()
async def test_generate_account(self):
account = self.account
self.assertEqual(account.ledger, self.ledger)
@ -27,81 +28,77 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key)
addresses = yield account.receiving.get_addresses()
addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 0)
addresses = yield account.change.get_addresses()
addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 0)
yield account.ensure_address_gap()
await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses()
addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 20)
addresses = yield account.change.get_addresses()
addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 6)
addresses = yield account.get_addresses()
addresses = await account.get_addresses()
self.assertEqual(len(addresses), 26)
@defer.inlineCallbacks
def test_generate_keys_over_batch_threshold_saves_it_properly(self):
yield self.account.receiving.generate_keys(0, 200)
records = yield self.account.receiving.get_address_records()
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
await self.account.receiving.generate_keys(0, 200)
records = await self.account.receiving.get_address_records()
self.assertEqual(201, len(records))
@defer.inlineCallbacks
def test_ensure_address_gap(self):
async def test_ensure_address_gap(self):
account = self.account
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
yield account.receiving.generate_keys(4, 7)
yield account.receiving.generate_keys(0, 3)
yield account.receiving.generate_keys(8, 11)
records = yield account.receiving.get_address_records()
await account.receiving.generate_keys(4, 7)
await account.receiving.generate_keys(0, 3)
await account.receiving.generate_keys(8, 11)
records = await account.receiving.get_address_records()
self.assertEqual(
[r['position'] for r in records],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
)
# we have 12, but default gap is 20
new_keys = yield account.receiving.ensure_address_gap()
new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 8)
records = yield account.receiving.get_address_records()
records = await account.receiving.get_address_records()
self.assertEqual(
[r['position'] for r in records],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
)
# case #1: no new addresses needed
empty = yield account.receiving.ensure_address_gap()
empty = await account.receiving.ensure_address_gap()
self.assertEqual(len(empty), 0)
# case #2: only one new address needed
records = yield account.receiving.get_address_records()
yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
new_keys = yield account.receiving.ensure_address_gap()
records = await account.receiving.get_address_records()
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 1)
# case #3: 20 addresses needed
yield self.ledger.db.set_address_history(new_keys[0], 'a:1:')
new_keys = yield account.receiving.ensure_address_gap()
await self.ledger.db.set_address_history(new_keys[0], 'a:1:')
new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 20)
@defer.inlineCallbacks
def test_get_or_create_usable_address(self):
async def test_get_or_create_usable_address(self):
account = self.account
keys = yield account.receiving.get_addresses()
keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 0)
address = yield account.receiving.get_or_create_usable_address()
address = await account.receiving.get_or_create_usable_address()
self.assertIsNotNone(address)
keys = yield account.receiving.get_addresses()
keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 20)
@defer.inlineCallbacks
def test_generate_account_from_seed(self):
async def test_generate_account_from_seed(self):
account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), {
"seed": "carbon smart garage balance margin twelve chest sword "
@ -123,17 +120,17 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7UbpV'
'NzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g'
)
address = yield account.receiving.ensure_address_gap()
address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
private_key = yield self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
self.assertEqual(
private_key.extended_key_string(),
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
)
invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
self.assertIsNone(invalid_key)
self.assertEqual(
@ -141,8 +138,7 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
b'1c01ae1e4c7d89e39f6d3aa7792c097a30ca7d40be249b6de52c81ec8cf9aab48b01'
)
@defer.inlineCallbacks
def test_load_and_save_account(self):
async def test_load_and_save_account(self):
account_data = {
'name': 'My Account',
'seed':
@ -164,11 +160,11 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
yield account.ensure_address_gap()
await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses()
addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 5)
addresses = yield account.change.get_addresses()
addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 5)
self.maxDiff = None
@ -176,20 +172,21 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
self.assertDictEqual(account_data, account.to_dict())
class TestSingleKeyAccount(unittest.TestCase):
class TestSingleKeyAccount(AsyncioTestCase):
@defer.inlineCallbacks
def setUp(self):
async def asyncSetUp(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),
})
yield self.ledger.db.open()
await self.ledger.db.open()
self.account = self.ledger.account_class.generate(
self.ledger, Wallet(), "torba", {'name': 'single-address'})
@defer.inlineCallbacks
def test_generate_account(self):
async def asyncTearDown(self):
await self.ledger.db.close()
async def test_generate_account(self):
account = self.account
self.assertEqual(account.ledger, self.ledger)
@ -197,37 +194,36 @@ class TestSingleKeyAccount(unittest.TestCase):
self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key)
addresses = yield account.receiving.get_addresses()
addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 0)
addresses = yield account.change.get_addresses()
addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 0)
yield account.ensure_address_gap()
await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses()
addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 1)
self.assertEqual(addresses[0], account.public_key.address)
addresses = yield account.change.get_addresses()
addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 1)
self.assertEqual(addresses[0], account.public_key.address)
addresses = yield account.get_addresses()
addresses = await account.get_addresses()
self.assertEqual(len(addresses), 1)
self.assertEqual(addresses[0], account.public_key.address)
@defer.inlineCallbacks
def test_ensure_address_gap(self):
async def test_ensure_address_gap(self):
account = self.account
self.assertIsInstance(account.receiving, SingleKey)
addresses = yield account.receiving.get_addresses()
addresses = await account.receiving.get_addresses()
self.assertEqual(addresses, [])
# we have 12, but default gap is 20
new_keys = yield account.receiving.ensure_address_gap()
new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 1)
self.assertEqual(new_keys[0], account.public_key.address)
records = yield account.receiving.get_address_records()
records = await account.receiving.get_address_records()
self.assertEqual(records, [{
'position': 0, 'chain': 0,
'account': account.public_key.address,
@ -236,37 +232,35 @@ class TestSingleKeyAccount(unittest.TestCase):
}])
# case #1: no new addresses needed
empty = yield account.receiving.ensure_address_gap()
empty = await account.receiving.ensure_address_gap()
self.assertEqual(len(empty), 0)
# case #2: after use, still no new address needed
records = yield account.receiving.get_address_records()
yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
empty = yield account.receiving.ensure_address_gap()
records = await account.receiving.get_address_records()
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
empty = await account.receiving.ensure_address_gap()
self.assertEqual(len(empty), 0)
@defer.inlineCallbacks
def test_get_or_create_usable_address(self):
async def test_get_or_create_usable_address(self):
account = self.account
addresses = yield account.receiving.get_addresses()
addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 0)
address1 = yield account.receiving.get_or_create_usable_address()
address1 = await account.receiving.get_or_create_usable_address()
self.assertIsNotNone(address1)
yield self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:')
records = yield account.receiving.get_address_records()
await self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:')
records = await account.receiving.get_address_records()
self.assertEqual(records[0]['used_times'], 3)
address2 = yield account.receiving.get_or_create_usable_address()
address2 = await account.receiving.get_or_create_usable_address()
self.assertEqual(address1, address2)
keys = yield account.receiving.get_addresses()
keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 1)
@defer.inlineCallbacks
def test_generate_account_from_seed(self):
async def test_generate_account_from_seed(self):
account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), {
"seed":
@ -285,17 +279,17 @@ class TestSingleKeyAccount(unittest.TestCase):
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
)
address = yield account.receiving.ensure_address_gap()
address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], account.public_key.address)
private_key = yield self.ledger.get_private_key_for_address(address[0])
private_key = await self.ledger.get_private_key_for_address(address[0])
self.assertEqual(
private_key.extended_key_string(),
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
)
invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
self.assertIsNone(invalid_key)
self.assertEqual(
@ -303,8 +297,7 @@ class TestSingleKeyAccount(unittest.TestCase):
b'1c92caa0ef99bfd5e2ceb73b66da8cd726a9370be8c368d448a322f3c5b23aaab901'
)
@defer.inlineCallbacks
def test_load_and_save_account(self):
async def test_load_and_save_account(self):
account_data = {
'name': 'My Account',
'seed':
@ -322,11 +315,11 @@ class TestSingleKeyAccount(unittest.TestCase):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
yield account.ensure_address_gap()
await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses()
addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 1)
addresses = yield account.change.get_addresses()
addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 1)
self.maxDiff = None
@ -334,7 +327,7 @@ class TestSingleKeyAccount(unittest.TestCase):
self.assertDictEqual(account_data, account.to_dict())
class AccountEncryptionTests(unittest.TestCase):
class AccountEncryptionTests(AsyncioTestCase):
password = "password"
init_vector = b'0000000000000000'
unencrypted_account = {
@ -368,7 +361,7 @@ class AccountEncryptionTests(unittest.TestCase):
'address_generator': {'name': 'single-address'}
}
def setUp(self):
async def asyncSetUp(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),

View file

@ -1,4 +1,4 @@
from twisted.trial import unittest
import unittest
from torba.bcd_data_stream import BCDataStream

View file

@ -1,5 +1,5 @@
import unittest
from binascii import unhexlify, hexlify
from twisted.trial import unittest
from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
from torba.bip32 import PubKey, PrivateKey, from_extended_key_string

View file

@ -1,4 +1,4 @@
from twisted.trial import unittest
import unittest
from types import GeneratorType
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class

View file

@ -1,11 +1,12 @@
from twisted.trial import unittest
from twisted.internet import defer
import unittest
from torba.wallet import Wallet
from torba.constants import COIN
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.basedatabase import query, constraints_to_sql
from orchstr8.testcase import AsyncioTestCase
from .test_transaction import get_output, NULL_HASH
@ -100,53 +101,52 @@ class TestQueryBuilder(unittest.TestCase):
)
class TestQueries(unittest.TestCase):
class TestQueries(AsyncioTestCase):
def setUp(self):
async def asyncSetUp(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),
})
return self.ledger.db.open()
await self.ledger.db.open()
@defer.inlineCallbacks
def create_account(self):
async def asyncTearDown(self):
await self.ledger.db.close()
async def create_account(self):
account = self.ledger.account_class.generate(self.ledger, Wallet())
yield account.ensure_address_gap()
await account.ensure_address_gap()
return account
@defer.inlineCallbacks
def create_tx_from_nothing(self, my_account, height):
to_address = yield my_account.receiving.get_or_create_usable_address()
async def create_tx_from_nothing(self, my_account, height):
to_address = await my_account.receiving.get_or_create_usable_address()
to_hash = ledger_class.address_to_hash160(to_address)
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
.add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
return tx
@defer.inlineCallbacks
def create_tx_from_txo(self, txo, to_account, height):
async def create_tx_from_txo(self, txo, to_account, height):
from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.hash160_to_address(from_hash)
to_address = yield to_account.receiving.get_or_create_usable_address()
to_address = await to_account.receiving.get_or_create_usable_address()
to_hash = ledger_class.address_to_hash160(to_address)
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
yield self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
return tx
@defer.inlineCallbacks
def create_tx_to_nowhere(self, txo, height):
async def create_tx_to_nowhere(self, txo, height):
from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.hash160_to_address(from_hash)
to_hash = NULL_HASH
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
return tx
def txo(self, amount, address):
@ -155,39 +155,38 @@ class TestQueries(unittest.TestCase):
def txi(self, txo):
return ledger_class.transaction_class.input_class.spend(txo)
@defer.inlineCallbacks
def test_get_transactions(self):
account1 = yield self.create_account()
account2 = yield self.create_account()
tx1 = yield self.create_tx_from_nothing(account1, 1)
tx2 = yield self.create_tx_from_txo(tx1.outputs[0], account2, 2)
tx3 = yield self.create_tx_to_nowhere(tx2.outputs[0], 3)
async def test_get_transactions(self):
account1 = await self.create_account()
account2 = await self.create_account()
tx1 = await self.create_tx_from_nothing(account1, 1)
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
txs = yield self.ledger.db.get_transactions()
txs = await self.ledger.db.get_transactions()
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual([3, 2, 1], [tx.height for tx in txs])
txs = yield self.ledger.db.get_transactions(account=account1)
txs = await self.ledger.db.get_transactions(account=account1)
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True)
txs = yield self.ledger.db.get_transactions(account=account2)
txs = await self.ledger.db.get_transactions(account=account2)
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True)
tx = yield self.ledger.db.get_transaction(txid=tx2.id)
tx = await self.ledger.db.get_transaction(txid=tx2.id)
self.assertEqual(tx.id, tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, False)
tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account1)
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account1)
self.assertEqual(tx.inputs[0].is_my_account, True)
self.assertEqual(tx.outputs[0].is_my_account, False)
tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account2)
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2)
self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, True)

View file

@ -1,11 +1,6 @@
from unittest import TestCase
from unittest import TestCase, mock
from torba.hash import aes_decrypt, aes_encrypt
try:
from unittest import mock
except ImportError:
import mock
class TestAESEncryptDecrypt(TestCase):
message = 'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks'

View file

@ -1,8 +1,7 @@
import os
from urllib.request import Request, urlopen
from twisted.trial import unittest
from twisted.internet import defer
from orchstr8.testcase import AsyncioTestCase
from torba.coin.bitcoinsegwit import MainHeaders
@ -11,7 +10,7 @@ def block_bytes(blocks):
return blocks * MainHeaders.header_size
class BitcoinHeadersTestCase(unittest.TestCase):
class BitcoinHeadersTestCase(AsyncioTestCase):
# Download headers instead of storing them in git.
HEADER_URL = 'http://headers.electrum.org/blockchain_headers'
@ -39,7 +38,7 @@ class BitcoinHeadersTestCase(unittest.TestCase):
headers.seek(after, os.SEEK_SET)
return headers.read(upto)
def get_headers(self, upto: int = -1):
async def get_headers(self, upto: int = -1):
h = MainHeaders(':memory:')
h.io.write(self.get_bytes(upto))
return h
@ -47,8 +46,8 @@ class BitcoinHeadersTestCase(unittest.TestCase):
class BasicHeadersTests(BitcoinHeadersTestCase):
def test_serialization(self):
h = self.get_headers()
async def test_serialization(self):
h = await self.get_headers()
self.assertEqual(h[0], {
'bits': 486604799,
'block_height': 0,
@ -94,18 +93,16 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
h.get_raw_header(self.RETARGET_BLOCK)
)
@defer.inlineCallbacks
def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
async def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
headers = MainHeaders(':memory:')
self.assertEqual(headers.height, -1)
yield headers.connect(0, self.get_bytes(block_bytes(3001)))
await headers.connect(0, self.get_bytes(block_bytes(3001)))
self.assertEqual(headers.height, 3000)
@defer.inlineCallbacks
def test_connect_9_blocks_passing_a_retarget_at_32256(self):
async def test_connect_9_blocks_passing_a_retarget_at_32256(self):
retarget = block_bytes(self.RETARGET_BLOCK-5)
headers = self.get_headers(upto=retarget)
headers = await self.get_headers(upto=retarget)
remainder = self.get_bytes(after=retarget)
self.assertEqual(headers.height, 32250)
yield headers.connect(len(headers), remainder)
await headers.connect(len(headers), remainder)
self.assertEqual(headers.height, 32259)

View file

@ -1,6 +1,5 @@
import os
from binascii import hexlify
from twisted.internet import defer
from torba.coin.bitcoinsegwit import MainNetLedger
from torba.wallet import Wallet
@ -18,32 +17,30 @@ class MockNetwork:
self.get_history_called = []
self.get_transaction_called = []
def get_history(self, address):
async def get_history(self, address):
self.get_history_called.append(address)
self.address = address
return defer.succeed(self.history)
return self.history
def get_merkle(self, txid, height):
return defer.succeed({'merkle': ['abcd01'], 'pos': 1})
async def get_merkle(self, txid, height):
return {'merkle': ['abcd01'], 'pos': 1}
def get_transaction(self, tx_hash):
async def get_transaction(self, tx_hash):
self.get_transaction_called.append(tx_hash)
return defer.succeed(self.transaction[tx_hash])
return self.transaction[tx_hash]
class LedgerTestCase(BitcoinHeadersTestCase):
def setUp(self):
super().setUp()
async def asyncSetUp(self):
self.ledger = MainNetLedger({
'db': MainNetLedger.database_class(':memory:'),
'headers': MainNetLedger.headers_class(':memory:')
})
return self.ledger.db.open()
await self.ledger.db.open()
def tearDown(self):
super().tearDown()
return self.ledger.db.close()
async def asyncTearDown(self):
await self.ledger.db.close()
def make_header(self, **kwargs):
header = {
@ -69,11 +66,10 @@ class LedgerTestCase(BitcoinHeadersTestCase):
class TestSynchronization(LedgerTestCase):
@defer.inlineCallbacks
def test_update_history(self):
async def test_update_history(self):
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
address = yield account.receiving.get_or_create_usable_address()
address_details = yield self.ledger.db.get_address(address=address)
address = await account.receiving.get_or_create_usable_address()
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], None)
self.add_header(block_height=0, merkle_root=b'abcd04')
@ -89,16 +85,16 @@ class TestSynchronization(LedgerTestCase):
'abcd02': hexlify(get_transaction(get_output(2)).raw),
'abcd03': hexlify(get_transaction(get_output(3)).raw),
})
yield self.ledger.update_history(address)
await self.ledger.update_history(address)
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
address_details = yield self.ledger.db.get_address(address=address)
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
yield self.ledger.update_history(address)
await self.ledger.update_history(address)
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, [])
@ -106,10 +102,10 @@ class TestSynchronization(LedgerTestCase):
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
yield self.ledger.update_history(address)
await self.ledger.update_history(address)
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
address_details = yield self.ledger.db.get_address(address=address)
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
@ -117,14 +113,13 @@ class MocHeaderNetwork:
def __init__(self, responses):
self.responses = responses
def get_headers(self, height, blocks):
async def get_headers(self, height, blocks):
return self.responses[height]
class BlockchainReorganizationTests(LedgerTestCase):
@defer.inlineCallbacks
def test_1_block_reorganization(self):
async def test_1_block_reorganization(self):
self.ledger.network = MocHeaderNetwork({
20: {'height': 20, 'count': 5, 'hex': hexlify(
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
@ -132,15 +127,14 @@ class BlockchainReorganizationTests(LedgerTestCase):
25: {'height': 25, 'count': 0, 'hex': b''}
})
headers = self.ledger.headers
yield headers.connect(0, self.get_bytes(upto=block_bytes(20)))
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
self.add_header(block_height=len(headers))
self.assertEqual(headers.height, 20)
yield self.ledger.receive_header([{
await self.ledger.receive_header([{
'height': 21, 'hex': hexlify(self.make_header(block_height=21))
}])
@defer.inlineCallbacks
def test_3_block_reorganization(self):
async def test_3_block_reorganization(self):
self.ledger.network = MocHeaderNetwork({
20: {'height': 20, 'count': 5, 'hex': hexlify(
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
@ -150,11 +144,11 @@ class BlockchainReorganizationTests(LedgerTestCase):
25: {'height': 25, 'count': 0, 'hex': b''}
})
headers = self.ledger.headers
yield headers.connect(0, self.get_bytes(upto=block_bytes(20)))
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
self.add_header(block_height=len(headers))
self.add_header(block_height=len(headers))
self.add_header(block_height=len(headers))
self.assertEqual(headers.height, 22)
yield self.ledger.receive_header(({
await self.ledger.receive_header(({
'height': 23, 'hex': hexlify(self.make_header(block_height=23))
},))

View file

@ -1,5 +1,5 @@
import unittest
from binascii import hexlify, unhexlify
from twisted.trial import unittest
from torba.bcd_data_stream import BCDataStream
from torba.basescript import Template, ParseError, tokenize, push_data

View file

@ -1,7 +1,8 @@
import unittest
from binascii import hexlify, unhexlify
from itertools import cycle
from twisted.trial import unittest
from twisted.internet import defer
from orchstr8.testcase import AsyncioTestCase
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.wallet import Wallet
@ -29,9 +30,9 @@ def get_transaction(txo=None):
.add_outputs([txo or ledger_class.transaction_class.output_class.pay_pubkey_hash(CENT, NULL_HASH)])
class TestSizeAndFeeEstimation(unittest.TestCase):
class TestSizeAndFeeEstimation(AsyncioTestCase):
def setUp(self):
async def asyncSetUp(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),
@ -181,17 +182,19 @@ class TestTransactionSerialization(unittest.TestCase):
self.assertEqual(tx.raw, raw)
class TestTransactionSigning(unittest.TestCase):
class TestTransactionSigning(AsyncioTestCase):
def setUp(self):
async def asyncSetUp(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),
})
return self.ledger.db.open()
await self.ledger.db.open()
@defer.inlineCallbacks
def test_sign(self):
async def asyncTearDown(self):
await self.ledger.db.close()
async def test_sign(self):
account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), {
"seed": "carbon smart garage balance margin twelve chest sword "
@ -200,8 +203,8 @@ class TestTransactionSigning(unittest.TestCase):
}
)
yield account.ensure_address_gap()
address1, address2 = yield account.receiving.get_addresses(limit=2)
await account.ensure_address_gap()
address1, address2 = await account.receiving.get_addresses(limit=2)
pubkey_hash1 = self.ledger.address_to_hash160(address1)
pubkey_hash2 = self.ledger.address_to_hash160(address2)
@ -211,9 +214,8 @@ class TestTransactionSigning(unittest.TestCase):
.add_inputs([tx_class.input_class.spend(get_output(2*COIN, pubkey_hash1))]) \
.add_outputs([tx_class.output_class.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) \
yield tx.sign([account])
await tx.sign([account])
print(hexlify(tx.inputs[0].script.values['signature']))
self.assertEqual(
hexlify(tx.inputs[0].script.values['signature']),
b'304402205a1df8cd5d2d2fa5934b756883d6c07e4f83e1350c740992d47a12422'
@ -221,15 +223,14 @@ class TestTransactionSigning(unittest.TestCase):
)
class TransactionIOBalancing(unittest.TestCase):
class TransactionIOBalancing(AsyncioTestCase):
@defer.inlineCallbacks
def setUp(self):
async def asyncSetUp(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),
})
yield self.ledger.db.open()
await self.ledger.db.open()
self.account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), {
"seed": "carbon smart garage balance margin twelve chest sword "
@ -237,10 +238,13 @@ class TransactionIOBalancing(unittest.TestCase):
}
)
addresses = yield self.account.ensure_address_gap()
addresses = await self.account.ensure_address_gap()
self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
self.hash_cycler = cycle(self.pubkey_hash)
async def asyncTearDown(self):
await self.ledger.db.close()
def txo(self, amount, address=None):
return get_output(int(amount*COIN), address or next(self.hash_cycler))
@ -250,8 +254,7 @@ class TransactionIOBalancing(unittest.TestCase):
def tx(self, inputs, outputs):
return ledger_class.transaction_class.create(inputs, outputs, [self.account], self.account)
@defer.inlineCallbacks
def create_utxos(self, amounts):
async def create_utxos(self, amounts):
utxos = [self.txo(amount) for amount in amounts]
self.funding_tx = ledger_class.transaction_class(is_verified=True) \
@ -260,7 +263,7 @@ class TransactionIOBalancing(unittest.TestCase):
save_tx = 'insert'
for utxo in utxos:
yield self.ledger.db.save_transaction_io(
await self.ledger.db.save_transaction_io(
save_tx, self.funding_tx,
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
utxo.script.values['pubkey_hash'], ''
@ -277,17 +280,16 @@ class TransactionIOBalancing(unittest.TestCase):
def outputs(tx):
return [round(o.amount/COIN, 2) for o in tx.outputs]
@defer.inlineCallbacks
def test_basic_use_cases(self):
async def test_basic_use_cases(self):
self.ledger.fee_per_byte = int(.01*CENT)
# available UTXOs for filling missing inputs
utxos = yield self.create_utxos([
utxos = await self.create_utxos([
1, 1, 3, 5, 10
])
# pay 3 coins (3.02 w/ fees)
tx = yield self.tx(
tx = await self.tx(
[], # inputs
[self.txo(3)] # outputs
)
@ -296,10 +298,10 @@ class TransactionIOBalancing(unittest.TestCase):
# a change of 1.98 is added to reach balance
self.assertEqual(self.outputs(tx), [3, 1.98])
yield self.ledger.release_outputs(utxos)
await self.ledger.release_outputs(utxos)
# pay 2.98 coins (3.00 w/ fees)
tx = yield self.tx(
tx = await self.tx(
[], # inputs
[self.txo(2.98)] # outputs
)
@ -307,10 +309,10 @@ class TransactionIOBalancing(unittest.TestCase):
self.assertEqual(self.inputs(tx), [3])
self.assertEqual(self.outputs(tx), [2.98])
yield self.ledger.release_outputs(utxos)
await self.ledger.release_outputs(utxos)
# supplied input and output, but input is not enough to cover output
tx = yield self.tx(
tx = await self.tx(
[self.txi(self.txo(10))], # inputs
[self.txo(11)] # outputs
)
@ -319,10 +321,10 @@ class TransactionIOBalancing(unittest.TestCase):
# change is now needed to consume extra input
self.assertEqual([11, 1.96], self.outputs(tx))
yield self.ledger.release_outputs(utxos)
await self.ledger.release_outputs(utxos)
# liquidating a UTXO
tx = yield self.tx(
tx = await self.tx(
[self.txi(self.txo(10))], # inputs
[] # outputs
)
@ -330,10 +332,10 @@ class TransactionIOBalancing(unittest.TestCase):
# missing change added to consume the amount
self.assertEqual([9.98], self.outputs(tx))
yield self.ledger.release_outputs(utxos)
await self.ledger.release_outputs(utxos)
# liquidating at a loss, requires adding extra inputs
tx = yield self.tx(
tx = await self.tx(
[self.txi(self.txo(0.01))], # inputs
[] # outputs
)

View file

@ -1,4 +1,4 @@
from twisted.trial import unittest
import unittest
from torba.util import ArithUint256

View file

@ -1,5 +1,5 @@
import unittest
import tempfile
from twisted.trial import unittest
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
from torba.coin.bitcoincash import MainNetLedger as BCHLedger

View file

@ -1,7 +1,6 @@
import random
import typing
from typing import List, Dict, Tuple, Type, Optional, Any
from twisted.internet import defer
from torba.mnemonic import Mnemonic
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
@ -44,12 +43,8 @@ class AddressManager:
def to_dict_instance(self) -> Optional[dict]:
raise NotImplementedError
@property
def db(self):
return self.account.ledger.db
def _query_addresses(self, **constraints):
return self.db.get_addresses(
return self.account.ledger.db.get_addresses(
account=self.account,
chain=self.chain_number,
**constraints
@ -58,26 +53,24 @@ class AddressManager:
def get_private_key(self, index: int) -> PrivateKey:
raise NotImplementedError
def get_max_gap(self) -> defer.Deferred:
async def get_max_gap(self):
raise NotImplementedError
def ensure_address_gap(self) -> defer.Deferred:
async def ensure_address_gap(self):
raise NotImplementedError
def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred:
def get_address_records(self, only_usable: bool = False, **constraints):
raise NotImplementedError
@defer.inlineCallbacks
def get_addresses(self, only_usable: bool = False, **constraints) -> defer.Deferred:
records = yield self.get_address_records(only_usable=only_usable, **constraints)
async def get_addresses(self, only_usable: bool = False, **constraints):
records = await self.get_address_records(only_usable=only_usable, **constraints)
return [r['address'] for r in records]
@defer.inlineCallbacks
def get_or_create_usable_address(self) -> defer.Deferred:
addresses = yield self.get_addresses(only_usable=True, limit=10)
async def get_or_create_usable_address(self):
addresses = await self.get_addresses(only_usable=True, limit=10)
if addresses:
return random.choice(addresses)
addresses = yield self.ensure_address_gap()
addresses = await self.ensure_address_gap()
return addresses[0]
@ -106,22 +99,20 @@ class HierarchicalDeterministic(AddressManager):
def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key.child(self.chain_number).child(index)
@defer.inlineCallbacks
def generate_keys(self, start: int, end: int) -> defer.Deferred:
async def generate_keys(self, start: int, end: int):
keys_batch, final_keys = [], []
for index in range(start, end+1):
keys_batch.append((index, self.public_key.child(index)))
if index % 180 == 0 or index == end:
yield self.db.add_keys(
await self.account.ledger.db.add_keys(
self.account, self.chain_number, keys_batch
)
final_keys.extend(keys_batch)
keys_batch.clear()
return [key[1].address for key in final_keys]
@defer.inlineCallbacks
def get_max_gap(self) -> defer.Deferred:
addresses = yield self._query_addresses(order_by="position ASC")
async def get_max_gap(self):
addresses = await self._query_addresses(order_by="position ASC")
max_gap = 0
current_gap = 0
for address in addresses:
@ -132,9 +123,8 @@ class HierarchicalDeterministic(AddressManager):
current_gap = 0
return max_gap
@defer.inlineCallbacks
def ensure_address_gap(self) -> defer.Deferred:
addresses = yield self._query_addresses(limit=self.gap, order_by="position DESC")
async def ensure_address_gap(self):
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
existing_gap = 0
for address in addresses:
@ -148,7 +138,7 @@ class HierarchicalDeterministic(AddressManager):
start = addresses[0]['position']+1 if addresses else 0
end = start + (self.gap - existing_gap)
new_keys = yield self.generate_keys(start, end-1)
new_keys = await self.generate_keys(start, end-1)
return new_keys
def get_address_records(self, only_usable: bool = False, **constraints):
@ -176,20 +166,19 @@ class SingleKey(AddressManager):
def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key
def get_max_gap(self) -> defer.Deferred:
return defer.succeed(0)
async def get_max_gap(self):
return 0
@defer.inlineCallbacks
def ensure_address_gap(self) -> defer.Deferred:
exists = yield self.get_address_records()
async def ensure_address_gap(self):
exists = await self.get_address_records()
if not exists:
yield self.db.add_keys(
await self.account.ledger.db.add_keys(
self.account, self.chain_number, [(0, self.public_key)]
)
return [self.public_key.address]
return []
def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred:
def get_address_records(self, only_usable: bool = False, **constraints):
return self._query_addresses(**constraints)
@ -289,9 +278,8 @@ class BaseAccount:
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
}
@defer.inlineCallbacks
def get_details(self, show_seed=False, **kwargs):
satoshis = yield self.get_balance(**kwargs)
async def get_details(self, show_seed=False, **kwargs):
satoshis = await self.get_balance(**kwargs)
details = {
'id': self.id,
'name': self.name,
@ -325,23 +313,21 @@ class BaseAccount:
self.password = None
self.encrypted = True
@defer.inlineCallbacks
def ensure_address_gap(self):
async def ensure_address_gap(self):
addresses = []
for address_manager in self.address_managers:
new_addresses = yield address_manager.ensure_address_gap()
new_addresses = await address_manager.ensure_address_gap()
addresses.extend(new_addresses)
return addresses
@defer.inlineCallbacks
def get_addresses(self, **constraints) -> defer.Deferred:
rows = yield self.ledger.db.select_addresses('address', account=self, **constraints)
async def get_addresses(self, **constraints):
rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
return [r[0] for r in rows]
def get_address_records(self, **constraints) -> defer.Deferred:
def get_address_records(self, **constraints):
return self.ledger.db.get_addresses(account=self, **constraints)
def get_address_count(self, **constraints) -> defer.Deferred:
def get_address_count(self, **constraints):
return self.ledger.db.get_address_count(account=self, **constraints)
def get_private_key(self, chain: int, index: int) -> PrivateKey:
@ -355,10 +341,9 @@ class BaseAccount:
constraints.update({'height__lte': height, 'height__gt': 0})
return self.ledger.db.get_balance(account=self, **constraints)
@defer.inlineCallbacks
def get_max_gap(self):
change_gap = yield self.change.get_max_gap()
receiving_gap = yield self.receiving.get_max_gap()
async def get_max_gap(self):
change_gap = await self.change.get_max_gap()
receiving_gap = await self.receiving.get_max_gap()
return {
'max_change_gap': change_gap,
'max_receiving_gap': receiving_gap,
@ -376,24 +361,23 @@ class BaseAccount:
def get_transaction_count(self, **constraints):
return self.ledger.db.get_transaction_count(account=self, **constraints)
@defer.inlineCallbacks
def fund(self, to_account, amount=None, everything=False,
async def fund(self, to_account, amount=None, everything=False,
outputs=1, broadcast=False, **constraints):
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
tx_class = self.ledger.transaction_class
if everything:
utxos = yield self.get_utxos(**constraints)
yield self.ledger.reserve_outputs(utxos)
tx = yield tx_class.create(
utxos = await self.get_utxos(**constraints)
await self.ledger.reserve_outputs(utxos)
tx = await tx_class.create(
inputs=[tx_class.input_class.spend(txo) for txo in utxos],
outputs=[],
funding_accounts=[self],
change_account=to_account
)
elif amount > 0:
to_address = yield to_account.change.get_or_create_usable_address()
to_address = await to_account.change.get_or_create_usable_address()
to_hash160 = to_account.ledger.address_to_hash160(to_address)
tx = yield tx_class.create(
tx = await tx_class.create(
inputs=[],
outputs=[
tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
@ -406,9 +390,9 @@ class BaseAccount:
raise ValueError('An amount is required.')
if broadcast:
yield self.ledger.broadcast(tx)
await self.ledger.broadcast(tx)
else:
yield self.ledger.release_outputs(
await self.ledger.release_outputs(
[txi.txo_ref.txo for txi in tx.inputs]
)

View file

@ -1,9 +1,8 @@
import logging
from typing import Tuple, List, Sequence
from typing import Tuple, List
import sqlite3
from twisted.internet import defer
from twisted.enterprise import adbapi
import aiosqlite
from torba.hash import TXRefImmutable
from torba.basetransaction import BaseTransaction
@ -107,25 +106,21 @@ def row_dict_or_default(rows, fields, default=None):
class SQLiteMixin:
CREATE_TABLES_QUERY: Sequence[str] = ()
CREATE_TABLES_QUERY: str
def __init__(self, path):
self._db_path = path
self.db: adbapi.ConnectionPool = None
self.db: aiosqlite.Connection = None
self.ledger = None
def open(self):
async def open(self):
log.info("connecting to database: %s", self._db_path)
self.db = adbapi.ConnectionPool(
'sqlite3', self._db_path, cp_min=1, cp_max=1, check_same_thread=False
)
return self.db.runInteraction(
lambda t: t.executescript(self.CREATE_TABLES_QUERY)
)
self.db = aiosqlite.connect(self._db_path)
await self.db.__aenter__()
await self.db.executescript(self.CREATE_TABLES_QUERY)
def close(self):
self.db.close()
return defer.succeed(True)
async def close(self):
await self.db.close()
@staticmethod
def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
@ -247,78 +242,75 @@ class BaseDatabase(SQLiteMixin):
'script': sqlite3.Binary(txo.script.source)
}
def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
async def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
def _steps(t):
if save_tx == 'insert':
self.execute(t, *self._insert_sql('tx', {
if save_tx == 'insert':
await self.db.execute(*self._insert_sql('tx', {
'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified
}))
elif save_tx == 'update':
await self.db.execute(*self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
existing_txos = [r[0] for r in await self.db.execute_fetchall(*query(
"SELECT position FROM txo", txid=tx.id
))]
for txo in tx.outputs:
if txo.position in existing_txos:
continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
await self.db.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
print('Database.save_transaction_io: pay script hash is not implemented!')
# lookup the address associated with each TXI (via its TXO)
txoid_to_address = {r[0]: r[1] for r in await self.db.execute_fetchall(*query(
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
))}
# list of TXIs that have already been added
existing_txis = [r[0] for r in await self.db.execute_fetchall(*query(
"SELECT txoid FROM txi", txid=tx.id
))]
for txi in tx.inputs:
txoid = txi.txo_ref.id
new_txi = txoid not in existing_txis
address_matches = txoid_to_address.get(txoid) == address
if new_txi and address_matches:
await self.db.execute(*self._insert_sql("txi", {
'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified
'txoid': txoid,
'address': address,
}))
elif save_tx == 'update':
self.execute(t, *self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
existing_txos = [r[0] for r in self.execute(t, *query(
"SELECT position FROM txo", txid=tx.id
)).fetchall()]
await self._set_address_history(address, history)
for txo in tx.outputs:
if txo.position in existing_txos:
continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
print('Database.save_transaction_io: pay script hash is not implemented!')
# lookup the address associated with each TXI (via its TXO)
txoid_to_address = {r[0]: r[1] for r in self.execute(t, *query(
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
)).fetchall()}
# list of TXIs that have already been added
existing_txis = [r[0] for r in self.execute(t, *query(
"SELECT txoid FROM txi", txid=tx.id
)).fetchall()]
for txi in tx.inputs:
txoid = txi.txo_ref.id
new_txi = txoid not in existing_txis
address_matches = txoid_to_address.get(txoid) == address
if new_txi and address_matches:
self.execute(t, *self._insert_sql("txi", {
'txid': tx.id,
'txoid': txoid,
'address': address,
}))
self._set_address_history(t, address, history)
return self.db.runInteraction(_steps)
def reserve_outputs(self, txos, is_reserved=True):
async def reserve_outputs(self, txos, is_reserved=True):
txoids = [txo.id for txo in txos]
return self.run_operation(
await self.db.execute(
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
', '.join(['?']*len(txoids))
), [is_reserved]+txoids
)
def release_outputs(self, txos):
return self.reserve_outputs(txos, is_reserved=False)
async def release_outputs(self, txos):
await self.reserve_outputs(txos, is_reserved=False)
def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
# TODO:
# 1. delete transactions above_height
# 2. update address histories removing deleted TXs
return defer.succeed(True)
return True
def select_transactions(self, cols, account=None, **constraints):
async def select_transactions(self, cols, account=None, **constraints):
if 'txid' not in constraints and account is not None:
constraints['$account'] = account.public_key.address
constraints['txid__in'] = """
@ -328,13 +320,14 @@ class BaseDatabase(SQLiteMixin):
SELECT txi.txid FROM txi
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account
"""
return self.run_query(*query("SELECT {} FROM tx".format(cols), **constraints))
return await self.db.execute_fetchall(
*query("SELECT {} FROM tx".format(cols), **constraints)
)
@defer.inlineCallbacks
def get_transactions(self, my_account=None, **constraints):
async def get_transactions(self, my_account=None, **constraints):
my_account = my_account or constraints.get('account', None)
tx_rows = yield self.select_transactions(
tx_rows = await self.select_transactions(
'txid, raw, height, position, is_verified',
order_by=["height DESC", "position DESC"],
**constraints
@ -352,7 +345,7 @@ class BaseDatabase(SQLiteMixin):
annotated_txos = {
txo.id: txo for txo in
(yield self.get_txos(
(await self.get_txos(
my_account=my_account,
txid__in=txids
))
@ -360,7 +353,7 @@ class BaseDatabase(SQLiteMixin):
referenced_txos = {
txo.id: txo for txo in
(yield self.get_txos(
(await self.get_txos(
my_account=my_account,
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0]
))
@ -380,33 +373,30 @@ class BaseDatabase(SQLiteMixin):
return txs
@defer.inlineCallbacks
def get_transaction_count(self, **constraints):
async def get_transaction_count(self, **constraints):
constraints.pop('offset', None)
constraints.pop('limit', None)
constraints.pop('order_by', None)
count = yield self.select_transactions('count(*)', **constraints)
count = await self.select_transactions('count(*)', **constraints)
return count[0][0]
@defer.inlineCallbacks
def get_transaction(self, **constraints):
txs = yield self.get_transactions(limit=1, **constraints)
async def get_transaction(self, **constraints):
txs = await self.get_transactions(limit=1, **constraints)
if txs:
return txs[0]
def select_txos(self, cols, **constraints):
return self.run_query(*query(
async def select_txos(self, cols, **constraints):
return await self.db.execute_fetchall(*query(
"SELECT {} FROM txo"
" JOIN pubkey_address USING (address)"
" JOIN tx USING (txid)".format(cols), **constraints
))
@defer.inlineCallbacks
def get_txos(self, my_account=None, **constraints):
async def get_txos(self, my_account=None, **constraints):
my_account = my_account or constraints.get('account', None)
if isinstance(my_account, BaseAccount):
my_account = my_account.public_key.address
rows = yield self.select_txos(
rows = await self.select_txos(
"amount, script, txid, tx.height, txo.position, chain, account", **constraints
)
output_class = self.ledger.transaction_class.output_class
@ -421,12 +411,11 @@ class BaseDatabase(SQLiteMixin):
) for row in rows
]
@defer.inlineCallbacks
def get_txo_count(self, **constraints):
async def get_txo_count(self, **constraints):
constraints.pop('offset', None)
constraints.pop('limit', None)
constraints.pop('order_by', None)
count = yield self.select_txos('count(*)', **constraints)
count = await self.select_txos('count(*)', **constraints)
return count[0][0]
@staticmethod
@ -442,37 +431,33 @@ class BaseDatabase(SQLiteMixin):
self.constrain_utxo(constraints)
return self.get_txo_count(**constraints)
@defer.inlineCallbacks
def get_balance(self, **constraints):
async def get_balance(self, **constraints):
self.constrain_utxo(constraints)
balance = yield self.select_txos('SUM(amount)', **constraints)
balance = await self.select_txos('SUM(amount)', **constraints)
return balance[0][0] or 0
def select_addresses(self, cols, **constraints):
return self.run_query(*query(
"SELECT {} FROM pubkey_address".format(cols), **constraints
))
async def select_addresses(self, cols, **constraints):
return await self.db.execute_fetchall(*query(
"SELECT {} FROM pubkey_address".format(cols), **constraints
))
@defer.inlineCallbacks
def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints):
addresses = yield self.select_addresses(', '.join(cols), **constraints)
async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints):
addresses = await self.select_addresses(', '.join(cols), **constraints)
return rows_to_dict(addresses, cols)
@defer.inlineCallbacks
def get_address_count(self, **constraints):
count = yield self.select_addresses('count(*)', **constraints)
async def get_address_count(self, **constraints):
count = await self.select_addresses('count(*)', **constraints)
return count[0][0]
@defer.inlineCallbacks
def get_address(self, **constraints):
addresses = yield self.get_addresses(
async def get_address(self, **constraints):
addresses = await self.get_addresses(
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
limit=1, **constraints
)
if addresses:
return addresses[0]
def add_keys(self, account, chain, keys):
async def add_keys(self, account, chain, keys):
sql = (
"insert into pubkey_address "
"(address, account, chain, position, pubkey) "
@ -484,14 +469,13 @@ class BaseDatabase(SQLiteMixin):
pubkey.address, account.public_key.address, chain, position,
sqlite3.Binary(pubkey.pubkey_bytes)
))
return self.run_operation(sql, values)
await self.db.execute(sql, values)
@classmethod
def _set_address_history(cls, t, address, history):
cls.execute(
t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
async def _set_address_history(self, address, history):
await self.db.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address)
)
def set_address_history(self, address, history):
return self.db.runInteraction(lambda t: self._set_address_history(t, address, history))
async def set_address_history(self, address, history):
await self._set_address_history(address, history)
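The database layer above drops Twisted's adbapi.ConnectionPool and its runInteraction/runQuery helpers in favour of a single aiosqlite connection awaited directly. A minimal sketch of the aiosqlite calls used in this file, against a simplified, hypothetical pubkey_address table:

import asyncio
import aiosqlite

async def main():
    db = aiosqlite.connect(':memory:')
    await db.__aenter__()  # open the connection, mirroring SQLiteMixin.open()
    await db.executescript(
        "create table pubkey_address (address text, history text, used_times integer)"
    )
    await db.execute(
        "insert into pubkey_address (address, history, used_times) values (?, ?, ?)",
        ('example-address', 'a:1:', 1)
    )
    rows = await db.execute_fetchall("select address, used_times from pubkey_address")
    print(list(rows))  # [('example-address', 1)]
    await db.close()

asyncio.run(main())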

View file

@ -1,11 +1,10 @@
import os
import asyncio
import logging
from io import BytesIO
from typing import Optional, Iterator, Tuple
from binascii import hexlify
from twisted.internet import threads, defer
from torba.util import ArithUint256
from torba.hash import double_sha256
@@ -36,16 +35,14 @@ class BaseHeaders:
self.io = BytesIO()
self.path = path
self._size: Optional[int] = None
self._header_connect_lock = defer.DeferredLock()
self._header_connect_lock = asyncio.Lock()
def open(self):
async def open(self):
if self.path != ':memory:':
self.io = open(self.path, 'a+b')
return defer.succeed(True)
def close(self):
async def close(self):
self.io.close()
return defer.succeed(True)
@staticmethod
def serialize(header: dict) -> bytes:
@@ -95,16 +92,15 @@ class BaseHeaders:
return b'0' * 64
return hexlify(double_sha256(header)[::-1])
@defer.inlineCallbacks
def connect(self, start: int, headers: bytes):
async def connect(self, start: int, headers: bytes) -> int:
added = 0
bail = False
yield self._header_connect_lock.acquire()
try:
loop = asyncio.get_running_loop()
async with self._header_connect_lock:
for height, chunk in self._iterate_chunks(start, headers):
try:
# validate_chunk() is CPU bound and reads previous chunks from file system
yield threads.deferToThread(self.validate_chunk, height, chunk)
await loop.run_in_executor(None, self.validate_chunk, height, chunk)
except InvalidHeader as e:
bail = True
chunk = chunk[:(height-e.height+1)*self.header_size]
@@ -115,14 +111,12 @@ class BaseHeaders:
self.io.truncate()
# .seek()/.write()/.truncate() might also .flush() when needed
# the goal here is mainly to ensure we're definitely flush()'ing
yield threads.deferToThread(self.io.flush)
await loop.run_in_executor(None, self.io.flush)
self._size = None
added += written
if bail:
break
finally:
self._header_connect_lock.release()
defer.returnValue(added)
return added
def validate_chunk(self, height, chunk):
previous_hash, previous_header, previous_previous_header = None, None, None

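Editor's note: the connect() rewrite above combines an asyncio.Lock with run_in_executor so that CPU-bound chunk validation stays off the event loop. A stripped-down sketch of that pattern follows; the real validate_chunk performs header validation, the hash below is only a stand-in.

    import asyncio
    import hashlib

    _header_connect_lock = asyncio.Lock()

    def validate_chunk(height, chunk):
        # stand-in for the CPU-bound work done by BaseHeaders.validate_chunk
        return height, hashlib.sha256(chunk).hexdigest()

    async def connect(start, headers):
        loop = asyncio.get_running_loop()
        async with _header_connect_lock:  # replaces DeferredLock.acquire()/release()
            # replaces twisted.internet.threads.deferToThread(...)
            return await loop.run_in_executor(None, validate_chunk, start, headers)

    print(asyncio.run(connect(0, b'\x00' * 80)))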
View file

@@ -1,4 +1,5 @@
import os
import asyncio
import logging
from binascii import hexlify, unhexlify
from io import StringIO
@@ -7,8 +8,6 @@ from typing import Dict, Type, Iterable
from operator import itemgetter
from collections import namedtuple
from twisted.internet import defer
from torba import baseaccount
from torba import basenetwork
from torba import basetransaction
@@ -104,8 +103,8 @@ class BaseLedger(metaclass=LedgerRegistry):
)
self._transaction_processing_locks = {}
self._utxo_reservation_lock = defer.DeferredLock()
self._header_processing_lock = defer.DeferredLock()
self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock()
@classmethod
def get_id(cls):
@@ -135,41 +134,32 @@ class BaseLedger(metaclass=LedgerRegistry):
def add_account(self, account: baseaccount.BaseAccount):
self.accounts.append(account)
@defer.inlineCallbacks
def get_private_key_for_address(self, address):
match = yield self.db.get_address(address=address)
async def get_private_key_for_address(self, address):
match = await self.db.get_address(address=address)
if match:
for account in self.accounts:
if match['account'] == account.public_key.address:
return account.get_private_key(match['chain'], match['position'])
@defer.inlineCallbacks
def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
estimators = []
for account in funding_accounts:
utxos = yield account.get_utxos()
utxos = await account.get_utxos()
for utxo in utxos:
estimators.append(utxo.get_estimator(self))
return estimators
@defer.inlineCallbacks
def get_spendable_utxos(self, amount: int, funding_accounts):
yield self._utxo_reservation_lock.acquire()
try:
txos = yield self.get_effective_amount_estimators(funding_accounts)
async def get_spendable_utxos(self, amount: int, funding_accounts):
async with self._utxo_reservation_lock:
txos = await self.get_effective_amount_estimators(funding_accounts)
selector = CoinSelector(
txos, amount,
self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
)
spendables = selector.select()
if spendables:
yield self.reserve_outputs(s.txo for s in spendables)
except Exception:
log.exception('Failed to get spendable utxos:')
raise
finally:
self._utxo_reservation_lock.release()
return spendables
await self.reserve_outputs(s.txo for s in spendables)
return spendables
def reserve_outputs(self, txos):
return self.db.reserve_outputs(txos)
@@ -177,16 +167,14 @@ class BaseLedger(metaclass=LedgerRegistry):
def release_outputs(self, txos):
return self.db.release_outputs(txos)
@defer.inlineCallbacks
def get_local_status(self, address):
address_details = yield self.db.get_address(address=address)
async def get_local_status(self, address):
address_details = await self.db.get_address(address=address)
history = address_details['history'] or ''
h = sha256(history.encode())
return hexlify(h)
@defer.inlineCallbacks
def get_local_history(self, address):
address_details = yield self.db.get_address(address=address)
async def get_local_history(self, address):
address_details = await self.db.get_address(address=address)
history = address_details['history'] or ''
parts = history.split(':')[:-1]
return list(zip(parts[0::2], map(int, parts[1::2])))
@@ -203,43 +191,40 @@ class BaseLedger(metaclass=LedgerRegistry):
working_branch = double_sha256(combined)
return hexlify(working_branch[::-1])
def validate_transaction_and_set_position(self, tx, height, merkle):
async def validate_transaction_and_set_position(self, tx, height):
if not height <= len(self.headers):
return False
merkle = await self.network.get_merkle(tx.id, height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height]
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
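Editor's note: for readers unfamiliar with the merkle proof checked in validate_transaction_and_set_position, the following is a generic Bitcoin-style branch verification, roughly what get_root_of_merkle_tree computes; it is shown for illustration only and the exact byte handling in torba may differ.

    from binascii import hexlify, unhexlify
    from hashlib import sha256

    def double_sha256(data: bytes) -> bytes:
        return sha256(sha256(data).digest()).digest()

    def merkle_root_from_branch(tx_hash: bytes, branch, position: int) -> bytes:
        # branch entries are hex-encoded sibling hashes, as returned by the server
        working = tx_hash
        for sibling_hex in branch:
            sibling = unhexlify(sibling_hex)[::-1]
            if position & 1:
                working = double_sha256(sibling + working)
            else:
                working = double_sha256(working + sibling)
            position >>= 1
        return hexlify(working[::-1])

    # with an empty branch the root is just the (reversed, hex-encoded) tx hash itself
    assert merkle_root_from_branch(b'\x01' * 32, [], 0) == hexlify((b'\x01' * 32)[::-1])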
@defer.inlineCallbacks
def start(self):
async def start(self):
if not os.path.exists(self.path):
os.mkdir(self.path)
yield defer.gatherResults([
await asyncio.gather(
self.db.open(),
self.headers.open()
])
)
first_connection = self.network.on_connected.first
self.network.start()
yield first_connection
yield self.join_network()
asyncio.ensure_future(self.network.start())
await first_connection
await self.join_network()
self.network.on_connected.listen(self.join_network)
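Editor's note: the start() rewrite above shows two idioms that recur throughout this commit: defer.gatherResults becomes asyncio.gather, and a long-running service is kicked off with asyncio.ensure_future instead of being awaited. A toy sketch of that shape; the names below are placeholders, not torba APIs.

    import asyncio

    async def open_db():
        await asyncio.sleep(0.1)   # pretend to open the database

    async def open_headers():
        await asyncio.sleep(0.1)   # pretend to open the headers file

    async def network_loop():
        while True:                # stand-in for the reconnect loop in BaseNetwork.start()
            await asyncio.sleep(1)

    async def start():
        await asyncio.gather(open_db(), open_headers())       # concurrent, like db/headers open
        network_task = asyncio.ensure_future(network_loop())  # fire-and-forget background task
        await asyncio.sleep(0.5)   # ... the rest of startup would happen here ...
        network_task.cancel()
        try:
            await network_task
        except asyncio.CancelledError:
            pass

    asyncio.run(start())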
@defer.inlineCallbacks
def join_network(self, *args):
async def join_network(self, *args):
log.info("Subscribing and updating accounts.")
yield self.update_headers()
yield self.network.subscribe_headers()
yield self.update_accounts()
await self.update_headers()
await self.network.subscribe_headers()
await self.update_accounts()
@defer.inlineCallbacks
def stop(self):
yield self.network.stop()
yield self.db.close()
yield self.headers.close()
async def stop(self):
await self.network.stop()
await self.db.close()
await self.headers.close()
@defer.inlineCallbacks
def update_headers(self, height=None, headers=None, subscription_update=False):
async def update_headers(self, height=None, headers=None, subscription_update=False):
rewound = 0
while True:
@@ -251,14 +236,14 @@ class BaseLedger(metaclass=LedgerRegistry):
subscription_update = False
if not headers:
header_response = yield self.network.get_headers(height, 2001)
header_response = await self.network.get_headers(height, 2001)
headers = header_response['hex']
if not headers:
# Nothing to do, network thinks we're already at the latest height.
return
added = yield self.headers.connect(height, unhexlify(headers))
added = await self.headers.connect(height, unhexlify(headers))
if added > 0:
height += added
self._on_header_controller.add(
@@ -268,7 +253,7 @@ class BaseLedger(metaclass=LedgerRegistry):
# we started rewinding blocks and apparently found
# a new chain
rewound = 0
yield self.db.rewind_blockchain(height)
await self.db.rewind_blockchain(height)
if subscription_update:
# subscription updates are for latest header already
@@ -310,66 +295,37 @@ class BaseLedger(metaclass=LedgerRegistry):
# robust sync, turn off subscription update shortcut
subscription_update = False
@defer.inlineCallbacks
def receive_header(self, response):
yield self._header_processing_lock.acquire()
try:
async def receive_header(self, response):
async with self._header_processing_lock:
header = response[0]
yield self.update_headers(
await self.update_headers(
height=header['height'], headers=header['hex'], subscription_update=True
)
finally:
self._header_processing_lock.release()
def update_accounts(self):
return defer.DeferredList([
async def update_accounts(self):
return await asyncio.gather(*(
self.update_account(a) for a in self.accounts
])
))
@defer.inlineCallbacks
def update_account(self, account): # type: (baseaccount.BaseAccount) -> defer.Deferred
async def update_account(self, account: baseaccount.BaseAccount):
# Before subscribing, download history for any addresses that don't have any,
# this avoids situation where we're getting status updates to addresses we know
# need to update anyways. Continue to get history and create more addresses until
# all missing addresses are created and history for them is fully restored.
yield account.ensure_address_gap()
addresses = yield account.get_addresses(used_times=0)
await account.ensure_address_gap()
addresses = await account.get_addresses(used_times=0)
while addresses:
yield defer.DeferredList([
self.update_history(a) for a in addresses
])
addresses = yield account.ensure_address_gap()
await asyncio.gather(*(self.update_history(a) for a in addresses))
addresses = await account.ensure_address_gap()
# By this point all of the addresses should be restored and we
# can now subscribe all of them to receive updates.
all_addresses = yield account.get_addresses()
yield defer.DeferredList(
list(map(self.subscribe_history, all_addresses))
)
all_addresses = await account.get_addresses()
await asyncio.gather(*(self.subscribe_history(a) for a in all_addresses))
@defer.inlineCallbacks
def _prefetch_history(self, remote_history, local_history):
proofs, network_txs, deferreds = {}, {}, []
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
continue
if remote_height > 0:
deferreds.append(
self.network.get_merkle(hex_id, remote_height).addBoth(
lambda result, txid: proofs.__setitem__(txid, result), hex_id)
)
deferreds.append(
self.network.get_transaction(hex_id).addBoth(
lambda result, txid: network_txs.__setitem__(txid, result), hex_id)
)
yield defer.DeferredList(deferreds)
return proofs, network_txs
@defer.inlineCallbacks
def update_history(self, address):
remote_history = yield self.network.get_history(address)
local_history = yield self.get_local_history(address)
proofs, network_txs = yield self._prefetch_history(remote_history, local_history)
async def update_history(self, address):
remote_history = await self.network.get_history(address)
local_history = await self.get_local_history(address)
synced_history = StringIO()
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
@@ -379,30 +335,29 @@ class BaseLedger(metaclass=LedgerRegistry):
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
continue
lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())
lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock())
yield lock.acquire()
await lock.acquire()
try:
# see if we have a local copy of transaction, otherwise fetch it from server
tx = yield self.db.get_transaction(txid=hex_id)
tx = await self.db.get_transaction(txid=hex_id)
save_tx = None
if tx is None:
_raw = network_txs[hex_id]
_raw = await self.network.get_transaction(hex_id)
tx = self.transaction_class(unhexlify(_raw))
save_tx = 'insert'
tx.height = remote_height
if remote_height > 0 and (not tx.is_verified or tx.position == -1):
self.validate_transaction_and_set_position(tx, remote_height, proofs[hex_id])
await self.validate_transaction_and_set_position(tx, remote_height)
if save_tx is None:
save_tx = 'update'
yield self.db.save_transaction_io(
save_tx, tx, address, self.address_to_hash160(address),
synced_history.getvalue()
await self.db.save_transaction_io(
save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue()
)
log.debug(
@@ -412,28 +367,22 @@ class BaseLedger(metaclass=LedgerRegistry):
self._on_transaction_controller.add(TransactionEvent(address, tx))
except Exception:
log.exception('Failed to synchronize transaction:')
raise
finally:
lock.release()
if not lock.locked and hex_id in self._transaction_processing_locks:
if not lock.locked() and hex_id in self._transaction_processing_locks:
del self._transaction_processing_locks[hex_id]
@defer.inlineCallbacks
def subscribe_history(self, address):
remote_status = yield self.network.subscribe_address(address)
local_status = yield self.get_local_status(address)
async def subscribe_history(self, address):
remote_status = await self.network.subscribe_address(address)
local_status = await self.get_local_status(address)
if local_status != remote_status:
yield self.update_history(address)
await self.update_history(address)
@defer.inlineCallbacks
def receive_status(self, response):
async def receive_status(self, response):
address, remote_status = response
local_status = yield self.get_local_status(address)
local_status = await self.get_local_status(address)
if local_status != remote_status:
yield self.update_history(address)
await self.update_history(address)
def broadcast(self, tx):
return self.network.broadcast(hexlify(tx.raw).decode())

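Editor's note: update_history above keeps one asyncio.Lock per transaction id so the same transaction is never processed concurrently, and drops the entry once nobody holds the lock. The pattern in isolation, with a hypothetical worker in place of torba's fetch/save logic:

    import asyncio

    _locks = {}

    async def process(txid):
        lock = _locks.setdefault(txid, asyncio.Lock())
        await lock.acquire()
        try:
            await asyncio.sleep(0.01)   # stand-in for fetching and saving the transaction
            print('processed', txid)
        finally:
            lock.release()
            # clean up the entry once the lock is free again
            if not lock.locked() and txid in _locks:
                del _locks[txid]

    async def main():
        await asyncio.gather(process('aa' * 32), process('aa' * 32), process('bb' * 32))

    asyncio.run(main())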
View file

@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Type, MutableSequence, MutableMapping
from twisted.internet import defer
from torba.baseledger import BaseLedger, LedgerRegistry
from torba.wallet import Wallet, WalletStorage
@@ -41,11 +41,10 @@ class BaseWalletManager:
self.wallets.append(wallet)
return wallet
@defer.inlineCallbacks
def get_detailed_accounts(self, confirmations=6, show_seed=False):
async def get_detailed_accounts(self, confirmations=6, show_seed=False):
ledgers = {}
for i, account in enumerate(self.accounts):
details = yield account.get_details(confirmations=confirmations, show_seed=True)
details = await account.get_details(confirmations=confirmations, show_seed=True)
details['is_default_account'] = i == 0
ledger_id = account.ledger.get_id()
ledgers.setdefault(ledger_id, [])
@@ -68,16 +67,14 @@ class BaseWalletManager:
for account in wallet.accounts:
yield account
@defer.inlineCallbacks
def start(self):
async def start(self):
self.running = True
yield defer.DeferredList([
await asyncio.gather(*(
l.start() for l in self.ledgers.values()
])
))
@defer.inlineCallbacks
def stop(self):
yield defer.DeferredList([
async def stop(self):
await asyncio.gather(*(
l.stop() for l in self.ledgers.values()
])
))
self.running = False

View file

@@ -1,145 +1,36 @@
import json
import socket
import logging
from itertools import cycle
from twisted.internet import defer, reactor, protocol
from twisted.application.internet import ClientService, CancelledError
from twisted.internet.endpoints import clientFromString
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python import failure
from aiorpcx import ClientSession as BaseClientSession
from torba import __version__
from torba.stream import StreamController
from torba.constants import TIMEOUT
log = logging.getLogger(__name__)
class StratumClientProtocol(LineOnlyReceiver):
delimiter = b'\n'
MAX_LENGTH = 2000000
class ClientSession(BaseClientSession):
def __init__(self):
self.request_id = 0
self.lookup_table = {}
self.session = {}
self.network = None
self.on_disconnected_controller = StreamController()
self.on_disconnected = self.on_disconnected_controller.stream
def _get_id(self):
self.request_id += 1
return self.request_id
@property
def _ip(self):
return self.transport.getPeer().host
def get_session(self):
return self.session
def connectionMade(self):
try:
self.transport.setTcpNoDelay(True)
self.transport.setTcpKeepAlive(True)
if hasattr(socket, "TCP_KEEPIDLE"):
self.transport.socket.setsockopt(
socket.SOL_TCP, socket.TCP_KEEPIDLE, 120
# Seconds before sending keepalive probes
)
else:
log.debug("TCP_KEEPIDLE not available")
if hasattr(socket, "TCP_KEEPINTVL"):
self.transport.socket.setsockopt(
socket.SOL_TCP, socket.TCP_KEEPINTVL, 1
# Interval in seconds between keepalive probes
)
else:
log.debug("TCP_KEEPINTVL not available")
if hasattr(socket, "TCP_KEEPCNT"):
self.transport.socket.setsockopt(
socket.SOL_TCP, socket.TCP_KEEPCNT, 5
# Failed keepalive probes before declaring other end dead
)
else:
log.debug("TCP_KEEPCNT not available")
except Exception as err: # pylint: disable=broad-except
# Supported only by the socket transport,
# but there's really no better place in code to trigger this.
log.warning("Error setting up socket: %s", err)
def connectionLost(self, reason=None):
self.connected = 0
self.on_disconnected_controller.add(True)
for deferred in self.lookup_table.values():
if not deferred.called:
deferred.errback(TimeoutError("Connection dropped."))
def lineReceived(self, line):
log.debug('received: %s', line)
try:
message = json.loads(line)
except (ValueError, TypeError):
raise ValueError("Cannot decode message '{}'".format(line.strip()))
if message.get('id'):
try:
d = self.lookup_table.pop(message['id'])
if message.get('error'):
d.errback(RuntimeError(message['error']))
else:
d.callback(message.get('result'))
except KeyError:
raise LookupError(
"Lookup for deferred object for message ID '{}' failed.".format(message['id']))
elif message.get('method') in self.network.subscription_controllers:
controller = self.network.subscription_controllers[message['method']]
controller.add(message.get('params'))
else:
log.warning("Cannot handle message '%s'", line)
def rpc(self, method, *args):
message_id = self._get_id()
message = json.dumps({
'id': message_id,
'method': method,
'params': args
})
log.debug('sent: %s', message)
self.sendLine(message.encode('latin-1'))
d = self.lookup_table[message_id] = defer.Deferred()
d.addTimeout(
TIMEOUT, reactor, onTimeoutCancel=lambda *_: failure.Failure(TimeoutError(
"Timeout: Stratum request for '%s' took more than %s seconds" % (method, TIMEOUT)))
)
return d
class StratumClientFactory(protocol.ClientFactory):
protocol = StratumClientProtocol
def __init__(self, network):
def __init__(self, *args, network, **kwargs):
self.network = network
self.client = None
super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
def buildProtocol(self, addr):
client = self.protocol()
client.factory = self
client.network = self.network
self.client = client
return client
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
controller.add(request.args)
def connection_lost(self, exc):
super().connection_lost(exc)
self._on_disconnect_controller.add(True)
class BaseNetwork:
def __init__(self, ledger):
self.config = ledger.config
self.client = None
self.service = None
self.client: ClientSession = None
self.running = False
self._on_connected_controller = StreamController()
@@ -156,48 +47,35 @@ class BaseNetwork:
'blockchain.address.subscribe': self._on_status_controller,
}
@defer.inlineCallbacks
def start(self):
async def start(self):
self.running = True
for server in cycle(self.config['default_servers']):
connection_string = 'tcp:{}:{}'.format(*server)
endpoint = clientFromString(reactor, connection_string)
log.debug("Attempting connection to SPV wallet server: %s", connection_string)
self.service = ClientService(endpoint, StratumClientFactory(self))
self.service.startService()
self.client = ClientSession(*server, network=self)
try:
self.client = yield self.service.whenConnected(failAfterFailures=2)
yield self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: %s", connection_string)
await self.client.create_connection()
await self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: %s", )
self._on_connected_controller.add(True)
yield self.client.on_disconnected.first
except CancelledError:
return
await self.client.on_disconnected.first
except Exception: # pylint: disable=broad-except
log.exception("Connecting to %s raised an exception:", connection_string)
finally:
self.client = None
if self.service is not None:
self.service.stopService()
if not self.running:
return
def stop(self):
async def stop(self):
self.running = False
if self.service is not None:
self.service.stopService()
if self.is_connected:
return self.client.on_disconnected.first
else:
return defer.succeed(True)
await self.client.close()
await self.client.on_disconnected.first
@property
def is_connected(self):
return self.client is not None and self.client.connected
return self.client is not None and not self.client.is_closing()
def rpc(self, list_or_method, *args):
if self.is_connected:
return self.client.rpc(list_or_method, *args)
return self.client.send_request(list_or_method, args)
else:
raise ConnectionError("Attempting to send rpc request when connection is not available.")

View file

@@ -3,8 +3,6 @@ import typing
from typing import List, Iterable, Optional
from binascii import hexlify
from twisted.internet import defer
from torba.basescript import BaseInputScript, BaseOutputScript
from torba.baseaccount import BaseAccount
from torba.constants import COIN, NULL_HASH32
@@ -426,8 +424,7 @@ class BaseTransaction:
return ledger
@classmethod
@defer.inlineCallbacks
def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
funding_accounts: Iterable[BaseAccount], change_account: BaseAccount):
""" Find optimal set of inputs when only outputs are provided; add change
outputs if only inputs are provided or if inputs are greater than outputs. """
@@ -450,7 +447,7 @@ class BaseTransaction:
if payment < cost:
deficit = cost - payment
spendables = yield ledger.get_spendable_utxos(deficit, funding_accounts)
spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
if not spendables:
raise ValueError('Not enough funds to cover this transaction.')
payment += sum(s.effective_amount for s in spendables)
@@ -463,28 +460,27 @@ class BaseTransaction:
)
change = payment - cost
if change > cost_of_change:
change_address = yield change_account.change.get_or_create_usable_address()
change_address = await change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = change - cost_of_change
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
change_output.is_change = True
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
yield tx.sign(funding_accounts)
await tx.sign(funding_accounts)
except Exception as e:
log.exception('Failed to synchronize transaction:')
yield ledger.release_outputs(tx.outputs)
await ledger.release_outputs(tx.outputs)
raise e
defer.returnValue(tx)
return tx
@staticmethod
def signature_hash_type(hash_type):
return hash_type
@defer.inlineCallbacks
def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred:
async def sign(self, funding_accounts: Iterable[BaseAccount]):
ledger = self.ensure_all_have_same_ledger(funding_accounts)
for i, txi in enumerate(self._inputs):
assert txi.script is not None
@@ -492,7 +488,7 @@ class BaseTransaction:
txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash:
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = yield ledger.get_private_key_for_address(address)
private_key = await ledger.get_private_key_for_address(address)
tx = self._serialize_for_signature(i)
txi.script.values['signature'] = \
private_key.sign(tx) + bytes((self.signature_hash_type(1),))

View file

@@ -1,6 +1,4 @@
import asyncio
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
class BroadcastSubscription:
@@ -31,11 +29,13 @@ class BroadcastSubscription:
def _add(self, data):
if self.can_fire and self._on_data is not None:
self._on_data(data)
maybe_coroutine = self._on_data(data)
if asyncio.iscoroutine(maybe_coroutine):
asyncio.ensure_future(maybe_coroutine)
def _add_error(self, error, traceback):
def _add_error(self, exception):
if self.can_fire and self._on_error is not None:
self._on_error(error, traceback)
self._on_error(exception)
def _close(self):
if self.can_fire and self._on_done is not None:
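Editor's note: the new _add above lets listeners be either plain functions or coroutine functions: if calling the listener returns a coroutine, it is scheduled with ensure_future rather than awaited inline. The same dispatch in isolation:

    import asyncio

    def dispatch(callback, data):
        # mirrors BroadcastSubscription._add
        maybe_coroutine = callback(data)
        if asyncio.iscoroutine(maybe_coroutine):
            asyncio.ensure_future(maybe_coroutine)

    def sync_listener(data):
        print('sync listener got', data)

    async def async_listener(data):
        print('async listener got', data)

    async def main():
        dispatch(sync_listener, 1)
        dispatch(async_listener, 2)
        await asyncio.sleep(0)   # give the scheduled coroutine a chance to run

    asyncio.run(main())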
@@ -66,9 +66,9 @@ class StreamController:
for subscription in self._iterate_subscriptions:
subscription._add(event)
def add_error(self, error, traceback):
def add_error(self, exception):
for subscription in self._iterate_subscriptions:
subscription._add_error(error, traceback)
subscription._add_error(exception)
def close(self):
for subscription in self._iterate_subscriptions:
@@ -108,38 +108,35 @@ class Stream:
def listen(self, on_data, on_error=None, on_done=None):
return self._controller._listen(on_data, on_error, on_done)
def deferred_where(self, condition):
deferred = Deferred()
def where(self, condition) -> asyncio.Future:
future = asyncio.get_event_loop().create_future()
def where_test(value):
if condition(value):
self._cancel_and_callback(subscription, deferred, value)
self._cancel_and_callback(subscription, future, value)
subscription = self.listen(
where_test,
lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback)
lambda exception: self._cancel_and_error(subscription, future, exception)
)
return deferred
def where(self, condition):
return self.deferred_where(condition).asFuture(asyncio.get_event_loop())
return future
@property
def first(self):
deferred = Deferred()
future = asyncio.get_event_loop().create_future()
subscription = self.listen(
lambda value: self._cancel_and_callback(subscription, deferred, value),
lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback)
lambda value: self._cancel_and_callback(subscription, future, value),
lambda exception: self._cancel_and_error(subscription, future, exception)
)
return deferred
return future
@staticmethod
def _cancel_and_callback(subscription, deferred, value):
def _cancel_and_callback(subscription: BroadcastSubscription, future: asyncio.Future, value):
subscription.cancel()
deferred.callback(value)
future.set_result(value)
@staticmethod
def _cancel_and_error(subscription, deferred, error, traceback):
def _cancel_and_error(subscription: BroadcastSubscription, future: asyncio.Future, exception):
subscription.cancel()
deferred.errback(Failure(error, exc_tb=traceback))
future.set_exception(exception)
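Editor's note: taken together, the where()/first changes replace a Deferred with an asyncio.Future that the subscription resolves before cancelling itself. A reduced, self-contained sketch of that shape follows; TinyStream is illustrative only and skips the on_error/on_done plumbing of the real StreamController.

    import asyncio

    class TinyStream:
        def __init__(self):
            self._listeners = []

        def listen(self, on_data):
            self._listeners.append(on_data)
            return on_data             # the returned callback doubles as the "subscription"

        def cancel(self, subscription):
            self._listeners.remove(subscription)

        def add(self, value):
            for on_data in list(self._listeners):
                on_data(value)

        def where(self, condition) -> asyncio.Future:
            future = asyncio.get_event_loop().create_future()
            def where_test(value):
                if condition(value) and not future.done():
                    self.cancel(where_test)
                    future.set_result(value)
            self.listen(where_test)
            return future

        @property
        def first(self):
            return self.where(lambda _: True)

    async def main():
        stream = TinyStream()
        big = stream.where(lambda v: v > 2)
        for v in (1, 2, 3):
            stream.add(v)
        print(await big)   # -> 3

    asyncio.run(main())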