twisted -> asyncio
This commit is contained in:
parent
0ce4b9a7de
commit
2c5fd4aade
21 changed files with 480 additions and 721 deletions
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from orchstr8.testcase import IntegrationTestCase, d2f
|
||||
from orchstr8.testcase import IntegrationTestCase
|
||||
from torba.constants import COIN
|
||||
|
||||
|
||||
|
@ -9,14 +9,14 @@ class BasicTransactionTests(IntegrationTestCase):
|
|||
|
||||
async def test_sending_and_receiving(self):
|
||||
account1, account2 = self.account, self.wallet.generate_account(self.ledger)
|
||||
await d2f(self.ledger.update_account(account2))
|
||||
await self.ledger.update_account(account2)
|
||||
|
||||
self.assertEqual(await self.get_balance(account1), 0)
|
||||
self.assertEqual(await self.get_balance(account2), 0)
|
||||
|
||||
sendtxids = []
|
||||
for i in range(5):
|
||||
address1 = await d2f(account1.receiving.get_or_create_usable_address())
|
||||
address1 = await account1.receiving.get_or_create_usable_address()
|
||||
sendtxid = await self.blockchain.send_to_address(address1, 1.1)
|
||||
sendtxids.append(sendtxid)
|
||||
await self.on_transaction_id(sendtxid) # mempool
|
||||
|
@ -28,13 +28,13 @@ class BasicTransactionTests(IntegrationTestCase):
|
|||
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
|
||||
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
|
||||
|
||||
address2 = await d2f(account2.receiving.get_or_create_usable_address())
|
||||
address2 = await account2.receiving.get_or_create_usable_address()
|
||||
hash2 = self.ledger.address_to_hash160(address2)
|
||||
tx = await d2f(self.ledger.transaction_class.create(
|
||||
tx = await self.ledger.transaction_class.create(
|
||||
[],
|
||||
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash2)],
|
||||
[account1], account1
|
||||
))
|
||||
)
|
||||
await self.broadcast(tx)
|
||||
await self.on_transaction(tx) # mempool
|
||||
await self.blockchain.generate(1)
|
||||
|
@ -43,18 +43,18 @@ class BasicTransactionTests(IntegrationTestCase):
|
|||
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
|
||||
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
|
||||
|
||||
utxos = await d2f(self.account.get_utxos())
|
||||
tx = await d2f(self.ledger.transaction_class.create(
|
||||
utxos = await self.account.get_utxos()
|
||||
tx = await self.ledger.transaction_class.create(
|
||||
[self.ledger.transaction_class.input_class.spend(utxos[0])],
|
||||
[],
|
||||
[account1], account1
|
||||
))
|
||||
)
|
||||
await self.broadcast(tx)
|
||||
await self.on_transaction(tx) # mempool
|
||||
await self.blockchain.generate(1)
|
||||
await self.on_transaction(tx) # confirmed
|
||||
|
||||
txs = await d2f(account1.get_transactions())
|
||||
txs = await account1.get_transactions()
|
||||
tx = txs[1]
|
||||
self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1)
|
||||
self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1)
|
||||
|
|
|
@ -1,25 +1,26 @@
|
|||
from binascii import hexlify
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from orchstr8.testcase import AsyncioTestCase
|
||||
|
||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||
from torba.baseaccount import HierarchicalDeterministic, SingleKey
|
||||
from torba.wallet import Wallet
|
||||
|
||||
|
||||
class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
||||
class TestHierarchicalDeterministicAccount(AsyncioTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
})
|
||||
yield self.ledger.db.open()
|
||||
await self.ledger.db.open()
|
||||
self.account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_generate_account(self):
|
||||
async def asyncTearDown(self):
|
||||
await self.ledger.db.close()
|
||||
|
||||
async def test_generate_account(self):
|
||||
account = self.account
|
||||
|
||||
self.assertEqual(account.ledger, self.ledger)
|
||||
|
@ -27,81 +28,77 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
|||
self.assertEqual(account.public_key.ledger, self.ledger)
|
||||
self.assertEqual(account.private_key.public_key, account.public_key)
|
||||
|
||||
addresses = yield account.receiving.get_addresses()
|
||||
addresses = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(addresses), 0)
|
||||
addresses = yield account.change.get_addresses()
|
||||
addresses = await account.change.get_addresses()
|
||||
self.assertEqual(len(addresses), 0)
|
||||
|
||||
yield account.ensure_address_gap()
|
||||
await account.ensure_address_gap()
|
||||
|
||||
addresses = yield account.receiving.get_addresses()
|
||||
addresses = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(addresses), 20)
|
||||
addresses = yield account.change.get_addresses()
|
||||
addresses = await account.change.get_addresses()
|
||||
self.assertEqual(len(addresses), 6)
|
||||
|
||||
addresses = yield account.get_addresses()
|
||||
addresses = await account.get_addresses()
|
||||
self.assertEqual(len(addresses), 26)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_generate_keys_over_batch_threshold_saves_it_properly(self):
|
||||
yield self.account.receiving.generate_keys(0, 200)
|
||||
records = yield self.account.receiving.get_address_records()
|
||||
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
|
||||
await self.account.receiving.generate_keys(0, 200)
|
||||
records = await self.account.receiving.get_address_records()
|
||||
self.assertEqual(201, len(records))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_ensure_address_gap(self):
|
||||
async def test_ensure_address_gap(self):
|
||||
account = self.account
|
||||
|
||||
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
|
||||
|
||||
yield account.receiving.generate_keys(4, 7)
|
||||
yield account.receiving.generate_keys(0, 3)
|
||||
yield account.receiving.generate_keys(8, 11)
|
||||
records = yield account.receiving.get_address_records()
|
||||
await account.receiving.generate_keys(4, 7)
|
||||
await account.receiving.generate_keys(0, 3)
|
||||
await account.receiving.generate_keys(8, 11)
|
||||
records = await account.receiving.get_address_records()
|
||||
self.assertEqual(
|
||||
[r['position'] for r in records],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
||||
)
|
||||
|
||||
# we have 12, but default gap is 20
|
||||
new_keys = yield account.receiving.ensure_address_gap()
|
||||
new_keys = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(new_keys), 8)
|
||||
records = yield account.receiving.get_address_records()
|
||||
records = await account.receiving.get_address_records()
|
||||
self.assertEqual(
|
||||
[r['position'] for r in records],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
|
||||
)
|
||||
|
||||
# case #1: no new addresses needed
|
||||
empty = yield account.receiving.ensure_address_gap()
|
||||
empty = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(empty), 0)
|
||||
|
||||
# case #2: only one new addressed needed
|
||||
records = yield account.receiving.get_address_records()
|
||||
yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
||||
new_keys = yield account.receiving.ensure_address_gap()
|
||||
records = await account.receiving.get_address_records()
|
||||
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
||||
new_keys = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(new_keys), 1)
|
||||
|
||||
# case #3: 20 addresses needed
|
||||
yield self.ledger.db.set_address_history(new_keys[0], 'a:1:')
|
||||
new_keys = yield account.receiving.ensure_address_gap()
|
||||
await self.ledger.db.set_address_history(new_keys[0], 'a:1:')
|
||||
new_keys = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(new_keys), 20)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_or_create_usable_address(self):
|
||||
async def test_get_or_create_usable_address(self):
|
||||
account = self.account
|
||||
|
||||
keys = yield account.receiving.get_addresses()
|
||||
keys = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(keys), 0)
|
||||
|
||||
address = yield account.receiving.get_or_create_usable_address()
|
||||
address = await account.receiving.get_or_create_usable_address()
|
||||
self.assertIsNotNone(address)
|
||||
|
||||
keys = yield account.receiving.get_addresses()
|
||||
keys = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(keys), 20)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_generate_account_from_seed(self):
|
||||
async def test_generate_account_from_seed(self):
|
||||
account = self.ledger.account_class.from_dict(
|
||||
self.ledger, Wallet(), {
|
||||
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||
|
@ -123,17 +120,17 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
|||
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7UbpV'
|
||||
'NzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g'
|
||||
)
|
||||
address = yield account.receiving.ensure_address_gap()
|
||||
address = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||
|
||||
private_key = yield self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||
private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||
self.assertEqual(
|
||||
private_key.extended_key_string(),
|
||||
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
|
||||
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
|
||||
)
|
||||
|
||||
invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||
self.assertIsNone(invalid_key)
|
||||
|
||||
self.assertEqual(
|
||||
|
@ -141,8 +138,7 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
|||
b'1c01ae1e4c7d89e39f6d3aa7792c097a30ca7d40be249b6de52c81ec8cf9aab48b01'
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_load_and_save_account(self):
|
||||
async def test_load_and_save_account(self):
|
||||
account_data = {
|
||||
'name': 'My Account',
|
||||
'seed':
|
||||
|
@ -164,11 +160,11 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
|||
|
||||
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
||||
|
||||
yield account.ensure_address_gap()
|
||||
await account.ensure_address_gap()
|
||||
|
||||
addresses = yield account.receiving.get_addresses()
|
||||
addresses = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(addresses), 5)
|
||||
addresses = yield account.change.get_addresses()
|
||||
addresses = await account.change.get_addresses()
|
||||
self.assertEqual(len(addresses), 5)
|
||||
|
||||
self.maxDiff = None
|
||||
|
@ -176,20 +172,21 @@ class TestHierarchicalDeterministicAccount(unittest.TestCase):
|
|||
self.assertDictEqual(account_data, account.to_dict())
|
||||
|
||||
|
||||
class TestSingleKeyAccount(unittest.TestCase):
|
||||
class TestSingleKeyAccount(AsyncioTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
})
|
||||
yield self.ledger.db.open()
|
||||
await self.ledger.db.open()
|
||||
self.account = self.ledger.account_class.generate(
|
||||
self.ledger, Wallet(), "torba", {'name': 'single-address'})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_generate_account(self):
|
||||
async def asyncTearDown(self):
|
||||
await self.ledger.db.close()
|
||||
|
||||
async def test_generate_account(self):
|
||||
account = self.account
|
||||
|
||||
self.assertEqual(account.ledger, self.ledger)
|
||||
|
@ -197,37 +194,36 @@ class TestSingleKeyAccount(unittest.TestCase):
|
|||
self.assertEqual(account.public_key.ledger, self.ledger)
|
||||
self.assertEqual(account.private_key.public_key, account.public_key)
|
||||
|
||||
addresses = yield account.receiving.get_addresses()
|
||||
addresses = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(addresses), 0)
|
||||
addresses = yield account.change.get_addresses()
|
||||
addresses = await account.change.get_addresses()
|
||||
self.assertEqual(len(addresses), 0)
|
||||
|
||||
yield account.ensure_address_gap()
|
||||
await account.ensure_address_gap()
|
||||
|
||||
addresses = yield account.receiving.get_addresses()
|
||||
addresses = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(addresses), 1)
|
||||
self.assertEqual(addresses[0], account.public_key.address)
|
||||
addresses = yield account.change.get_addresses()
|
||||
addresses = await account.change.get_addresses()
|
||||
self.assertEqual(len(addresses), 1)
|
||||
self.assertEqual(addresses[0], account.public_key.address)
|
||||
|
||||
addresses = yield account.get_addresses()
|
||||
addresses = await account.get_addresses()
|
||||
self.assertEqual(len(addresses), 1)
|
||||
self.assertEqual(addresses[0], account.public_key.address)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_ensure_address_gap(self):
|
||||
async def test_ensure_address_gap(self):
|
||||
account = self.account
|
||||
|
||||
self.assertIsInstance(account.receiving, SingleKey)
|
||||
addresses = yield account.receiving.get_addresses()
|
||||
addresses = await account.receiving.get_addresses()
|
||||
self.assertEqual(addresses, [])
|
||||
|
||||
# we have 12, but default gap is 20
|
||||
new_keys = yield account.receiving.ensure_address_gap()
|
||||
new_keys = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(new_keys), 1)
|
||||
self.assertEqual(new_keys[0], account.public_key.address)
|
||||
records = yield account.receiving.get_address_records()
|
||||
records = await account.receiving.get_address_records()
|
||||
self.assertEqual(records, [{
|
||||
'position': 0, 'chain': 0,
|
||||
'account': account.public_key.address,
|
||||
|
@ -236,37 +232,35 @@ class TestSingleKeyAccount(unittest.TestCase):
|
|||
}])
|
||||
|
||||
# case #1: no new addresses needed
|
||||
empty = yield account.receiving.ensure_address_gap()
|
||||
empty = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(empty), 0)
|
||||
|
||||
# case #2: after use, still no new address needed
|
||||
records = yield account.receiving.get_address_records()
|
||||
yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
||||
empty = yield account.receiving.ensure_address_gap()
|
||||
records = await account.receiving.get_address_records()
|
||||
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
||||
empty = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(empty), 0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_or_create_usable_address(self):
|
||||
async def test_get_or_create_usable_address(self):
|
||||
account = self.account
|
||||
|
||||
addresses = yield account.receiving.get_addresses()
|
||||
addresses = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(addresses), 0)
|
||||
|
||||
address1 = yield account.receiving.get_or_create_usable_address()
|
||||
address1 = await account.receiving.get_or_create_usable_address()
|
||||
self.assertIsNotNone(address1)
|
||||
|
||||
yield self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:')
|
||||
records = yield account.receiving.get_address_records()
|
||||
await self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:')
|
||||
records = await account.receiving.get_address_records()
|
||||
self.assertEqual(records[0]['used_times'], 3)
|
||||
|
||||
address2 = yield account.receiving.get_or_create_usable_address()
|
||||
address2 = await account.receiving.get_or_create_usable_address()
|
||||
self.assertEqual(address1, address2)
|
||||
|
||||
keys = yield account.receiving.get_addresses()
|
||||
keys = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(keys), 1)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_generate_account_from_seed(self):
|
||||
async def test_generate_account_from_seed(self):
|
||||
account = self.ledger.account_class.from_dict(
|
||||
self.ledger, Wallet(), {
|
||||
"seed":
|
||||
|
@ -285,17 +279,17 @@ class TestSingleKeyAccount(unittest.TestCase):
|
|||
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
||||
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||
)
|
||||
address = yield account.receiving.ensure_address_gap()
|
||||
address = await account.receiving.ensure_address_gap()
|
||||
self.assertEqual(address[0], account.public_key.address)
|
||||
|
||||
private_key = yield self.ledger.get_private_key_for_address(address[0])
|
||||
private_key = await self.ledger.get_private_key_for_address(address[0])
|
||||
self.assertEqual(
|
||||
private_key.extended_key_string(),
|
||||
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||
)
|
||||
|
||||
invalid_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||
self.assertIsNone(invalid_key)
|
||||
|
||||
self.assertEqual(
|
||||
|
@ -303,8 +297,7 @@ class TestSingleKeyAccount(unittest.TestCase):
|
|||
b'1c92caa0ef99bfd5e2ceb73b66da8cd726a9370be8c368d448a322f3c5b23aaab901'
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_load_and_save_account(self):
|
||||
async def test_load_and_save_account(self):
|
||||
account_data = {
|
||||
'name': 'My Account',
|
||||
'seed':
|
||||
|
@ -322,11 +315,11 @@ class TestSingleKeyAccount(unittest.TestCase):
|
|||
|
||||
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
||||
|
||||
yield account.ensure_address_gap()
|
||||
await account.ensure_address_gap()
|
||||
|
||||
addresses = yield account.receiving.get_addresses()
|
||||
addresses = await account.receiving.get_addresses()
|
||||
self.assertEqual(len(addresses), 1)
|
||||
addresses = yield account.change.get_addresses()
|
||||
addresses = await account.change.get_addresses()
|
||||
self.assertEqual(len(addresses), 1)
|
||||
|
||||
self.maxDiff = None
|
||||
|
@ -334,7 +327,7 @@ class TestSingleKeyAccount(unittest.TestCase):
|
|||
self.assertDictEqual(account_data, account.to_dict())
|
||||
|
||||
|
||||
class AccountEncryptionTests(unittest.TestCase):
|
||||
class AccountEncryptionTests(AsyncioTestCase):
|
||||
password = "password"
|
||||
init_vector = b'0000000000000000'
|
||||
unencrypted_account = {
|
||||
|
@ -368,7 +361,7 @@ class AccountEncryptionTests(unittest.TestCase):
|
|||
'address_generator': {'name': 'single-address'}
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from twisted.trial import unittest
|
||||
import unittest
|
||||
|
||||
from torba.bcd_data_stream import BCDataStream
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
from binascii import unhexlify, hexlify
|
||||
from twisted.trial import unittest
|
||||
|
||||
from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
|
||||
from torba.bip32 import PubKey, PrivateKey, from_extended_key_string
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from twisted.trial import unittest
|
||||
import unittest
|
||||
from types import GeneratorType
|
||||
|
||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
from twisted.trial import unittest
|
||||
from twisted.internet import defer
|
||||
import unittest
|
||||
|
||||
from torba.wallet import Wallet
|
||||
from torba.constants import COIN
|
||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||
from torba.basedatabase import query, constraints_to_sql
|
||||
|
||||
from orchstr8.testcase import AsyncioTestCase
|
||||
|
||||
from .test_transaction import get_output, NULL_HASH
|
||||
|
||||
|
||||
|
@ -100,53 +101,52 @@ class TestQueryBuilder(unittest.TestCase):
|
|||
)
|
||||
|
||||
|
||||
class TestQueries(unittest.TestCase):
|
||||
class TestQueries(AsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
})
|
||||
return self.ledger.db.open()
|
||||
await self.ledger.db.open()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_account(self):
|
||||
async def asyncTearDown(self):
|
||||
await self.ledger.db.close()
|
||||
|
||||
async def create_account(self):
|
||||
account = self.ledger.account_class.generate(self.ledger, Wallet())
|
||||
yield account.ensure_address_gap()
|
||||
await account.ensure_address_gap()
|
||||
return account
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_tx_from_nothing(self, my_account, height):
|
||||
to_address = yield my_account.receiving.get_or_create_usable_address()
|
||||
async def create_tx_from_nothing(self, my_account, height):
|
||||
to_address = await my_account.receiving.get_or_create_usable_address()
|
||||
to_hash = ledger_class.address_to_hash160(to_address)
|
||||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
|
||||
.add_outputs([self.txo(1, to_hash)])
|
||||
yield self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
|
||||
await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
|
||||
return tx
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_tx_from_txo(self, txo, to_account, height):
|
||||
async def create_tx_from_txo(self, txo, to_account, height):
|
||||
from_hash = txo.script.values['pubkey_hash']
|
||||
from_address = self.ledger.hash160_to_address(from_hash)
|
||||
to_address = yield to_account.receiving.get_or_create_usable_address()
|
||||
to_address = await to_account.receiving.get_or_create_usable_address()
|
||||
to_hash = ledger_class.address_to_hash160(to_address)
|
||||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||
.add_inputs([self.txi(txo)]) \
|
||||
.add_outputs([self.txo(1, to_hash)])
|
||||
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
||||
yield self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
|
||||
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
||||
await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
|
||||
return tx
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_tx_to_nowhere(self, txo, height):
|
||||
async def create_tx_to_nowhere(self, txo, height):
|
||||
from_hash = txo.script.values['pubkey_hash']
|
||||
from_address = self.ledger.hash160_to_address(from_hash)
|
||||
to_hash = NULL_HASH
|
||||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||
.add_inputs([self.txi(txo)]) \
|
||||
.add_outputs([self.txo(1, to_hash)])
|
||||
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
||||
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
||||
return tx
|
||||
|
||||
def txo(self, amount, address):
|
||||
|
@ -155,39 +155,38 @@ class TestQueries(unittest.TestCase):
|
|||
def txi(self, txo):
|
||||
return ledger_class.transaction_class.input_class.spend(txo)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_transactions(self):
|
||||
account1 = yield self.create_account()
|
||||
account2 = yield self.create_account()
|
||||
tx1 = yield self.create_tx_from_nothing(account1, 1)
|
||||
tx2 = yield self.create_tx_from_txo(tx1.outputs[0], account2, 2)
|
||||
tx3 = yield self.create_tx_to_nowhere(tx2.outputs[0], 3)
|
||||
async def test_get_transactions(self):
|
||||
account1 = await self.create_account()
|
||||
account2 = await self.create_account()
|
||||
tx1 = await self.create_tx_from_nothing(account1, 1)
|
||||
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
|
||||
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
|
||||
|
||||
txs = yield self.ledger.db.get_transactions()
|
||||
txs = await self.ledger.db.get_transactions()
|
||||
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
||||
self.assertEqual([3, 2, 1], [tx.height for tx in txs])
|
||||
|
||||
txs = yield self.ledger.db.get_transactions(account=account1)
|
||||
txs = await self.ledger.db.get_transactions(account=account1)
|
||||
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
|
||||
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
||||
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
||||
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
||||
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
||||
|
||||
txs = yield self.ledger.db.get_transactions(account=account2)
|
||||
txs = await self.ledger.db.get_transactions(account=account2)
|
||||
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
|
||||
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
||||
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
||||
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
||||
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
||||
|
||||
tx = yield self.ledger.db.get_transaction(txid=tx2.id)
|
||||
tx = await self.ledger.db.get_transaction(txid=tx2.id)
|
||||
self.assertEqual(tx.id, tx2.id)
|
||||
self.assertEqual(tx.inputs[0].is_my_account, False)
|
||||
self.assertEqual(tx.outputs[0].is_my_account, False)
|
||||
tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account1)
|
||||
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account1)
|
||||
self.assertEqual(tx.inputs[0].is_my_account, True)
|
||||
self.assertEqual(tx.outputs[0].is_my_account, False)
|
||||
tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account2)
|
||||
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2)
|
||||
self.assertEqual(tx.inputs[0].is_my_account, False)
|
||||
self.assertEqual(tx.outputs[0].is_my_account, True)
|
||||
|
|
|
@ -1,11 +1,6 @@
|
|||
from unittest import TestCase
|
||||
from unittest import TestCase, mock
|
||||
from torba.hash import aes_decrypt, aes_encrypt
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
import mock
|
||||
|
||||
|
||||
class TestAESEncryptDecrypt(TestCase):
|
||||
message = 'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks'
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import os
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer
|
||||
from orchstr8.testcase import AsyncioTestCase
|
||||
|
||||
from torba.coin.bitcoinsegwit import MainHeaders
|
||||
|
||||
|
@ -11,7 +10,7 @@ def block_bytes(blocks):
|
|||
return blocks * MainHeaders.header_size
|
||||
|
||||
|
||||
class BitcoinHeadersTestCase(unittest.TestCase):
|
||||
class BitcoinHeadersTestCase(AsyncioTestCase):
|
||||
|
||||
# Download headers instead of storing them in git.
|
||||
HEADER_URL = 'http://headers.electrum.org/blockchain_headers'
|
||||
|
@ -39,7 +38,7 @@ class BitcoinHeadersTestCase(unittest.TestCase):
|
|||
headers.seek(after, os.SEEK_SET)
|
||||
return headers.read(upto)
|
||||
|
||||
def get_headers(self, upto: int = -1):
|
||||
async def get_headers(self, upto: int = -1):
|
||||
h = MainHeaders(':memory:')
|
||||
h.io.write(self.get_bytes(upto))
|
||||
return h
|
||||
|
@ -47,8 +46,8 @@ class BitcoinHeadersTestCase(unittest.TestCase):
|
|||
|
||||
class BasicHeadersTests(BitcoinHeadersTestCase):
|
||||
|
||||
def test_serialization(self):
|
||||
h = self.get_headers()
|
||||
async def test_serialization(self):
|
||||
h = await self.get_headers()
|
||||
self.assertEqual(h[0], {
|
||||
'bits': 486604799,
|
||||
'block_height': 0,
|
||||
|
@ -94,18 +93,16 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
|
|||
h.get_raw_header(self.RETARGET_BLOCK)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
|
||||
async def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
|
||||
headers = MainHeaders(':memory:')
|
||||
self.assertEqual(headers.height, -1)
|
||||
yield headers.connect(0, self.get_bytes(block_bytes(3001)))
|
||||
await headers.connect(0, self.get_bytes(block_bytes(3001)))
|
||||
self.assertEqual(headers.height, 3000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_connect_9_blocks_passing_a_retarget_at_32256(self):
|
||||
async def test_connect_9_blocks_passing_a_retarget_at_32256(self):
|
||||
retarget = block_bytes(self.RETARGET_BLOCK-5)
|
||||
headers = self.get_headers(upto=retarget)
|
||||
headers = await self.get_headers(upto=retarget)
|
||||
remainder = self.get_bytes(after=retarget)
|
||||
self.assertEqual(headers.height, 32250)
|
||||
yield headers.connect(len(headers), remainder)
|
||||
await headers.connect(len(headers), remainder)
|
||||
self.assertEqual(headers.height, 32259)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
from binascii import hexlify
|
||||
from twisted.internet import defer
|
||||
|
||||
from torba.coin.bitcoinsegwit import MainNetLedger
|
||||
from torba.wallet import Wallet
|
||||
|
@ -18,32 +17,30 @@ class MockNetwork:
|
|||
self.get_history_called = []
|
||||
self.get_transaction_called = []
|
||||
|
||||
def get_history(self, address):
|
||||
async def get_history(self, address):
|
||||
self.get_history_called.append(address)
|
||||
self.address = address
|
||||
return defer.succeed(self.history)
|
||||
return self.history
|
||||
|
||||
def get_merkle(self, txid, height):
|
||||
return defer.succeed({'merkle': ['abcd01'], 'pos': 1})
|
||||
async def get_merkle(self, txid, height):
|
||||
return {'merkle': ['abcd01'], 'pos': 1}
|
||||
|
||||
def get_transaction(self, tx_hash):
|
||||
async def get_transaction(self, tx_hash):
|
||||
self.get_transaction_called.append(tx_hash)
|
||||
return defer.succeed(self.transaction[tx_hash])
|
||||
return self.transaction[tx_hash]
|
||||
|
||||
|
||||
class LedgerTestCase(BitcoinHeadersTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = MainNetLedger({
|
||||
'db': MainNetLedger.database_class(':memory:'),
|
||||
'headers': MainNetLedger.headers_class(':memory:')
|
||||
})
|
||||
return self.ledger.db.open()
|
||||
await self.ledger.db.open()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
return self.ledger.db.close()
|
||||
async def asyncTearDown(self):
|
||||
await self.ledger.db.close()
|
||||
|
||||
def make_header(self, **kwargs):
|
||||
header = {
|
||||
|
@ -69,11 +66,10 @@ class LedgerTestCase(BitcoinHeadersTestCase):
|
|||
|
||||
class TestSynchronization(LedgerTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_history(self):
|
||||
async def test_update_history(self):
|
||||
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
||||
address = yield account.receiving.get_or_create_usable_address()
|
||||
address_details = yield self.ledger.db.get_address(address=address)
|
||||
address = await account.receiving.get_or_create_usable_address()
|
||||
address_details = await self.ledger.db.get_address(address=address)
|
||||
self.assertEqual(address_details['history'], None)
|
||||
|
||||
self.add_header(block_height=0, merkle_root=b'abcd04')
|
||||
|
@ -89,16 +85,16 @@ class TestSynchronization(LedgerTestCase):
|
|||
'abcd02': hexlify(get_transaction(get_output(2)).raw),
|
||||
'abcd03': hexlify(get_transaction(get_output(3)).raw),
|
||||
})
|
||||
yield self.ledger.update_history(address)
|
||||
await self.ledger.update_history(address)
|
||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
|
||||
|
||||
address_details = yield self.ledger.db.get_address(address=address)
|
||||
address_details = await self.ledger.db.get_address(address=address)
|
||||
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
|
||||
|
||||
self.ledger.network.get_history_called = []
|
||||
self.ledger.network.get_transaction_called = []
|
||||
yield self.ledger.update_history(address)
|
||||
await self.ledger.update_history(address)
|
||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, [])
|
||||
|
||||
|
@ -106,10 +102,10 @@ class TestSynchronization(LedgerTestCase):
|
|||
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
|
||||
self.ledger.network.get_history_called = []
|
||||
self.ledger.network.get_transaction_called = []
|
||||
yield self.ledger.update_history(address)
|
||||
await self.ledger.update_history(address)
|
||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
|
||||
address_details = yield self.ledger.db.get_address(address=address)
|
||||
address_details = await self.ledger.db.get_address(address=address)
|
||||
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
|
||||
|
||||
|
||||
|
@ -117,14 +113,13 @@ class MocHeaderNetwork:
|
|||
def __init__(self, responses):
|
||||
self.responses = responses
|
||||
|
||||
def get_headers(self, height, blocks):
|
||||
async def get_headers(self, height, blocks):
|
||||
return self.responses[height]
|
||||
|
||||
|
||||
class BlockchainReorganizationTests(LedgerTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_1_block_reorganization(self):
|
||||
async def test_1_block_reorganization(self):
|
||||
self.ledger.network = MocHeaderNetwork({
|
||||
20: {'height': 20, 'count': 5, 'hex': hexlify(
|
||||
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
|
||||
|
@ -132,15 +127,14 @@ class BlockchainReorganizationTests(LedgerTestCase):
|
|||
25: {'height': 25, 'count': 0, 'hex': b''}
|
||||
})
|
||||
headers = self.ledger.headers
|
||||
yield headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
||||
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
||||
self.add_header(block_height=len(headers))
|
||||
self.assertEqual(headers.height, 20)
|
||||
yield self.ledger.receive_header([{
|
||||
await self.ledger.receive_header([{
|
||||
'height': 21, 'hex': hexlify(self.make_header(block_height=21))
|
||||
}])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_3_block_reorganization(self):
|
||||
async def test_3_block_reorganization(self):
|
||||
self.ledger.network = MocHeaderNetwork({
|
||||
20: {'height': 20, 'count': 5, 'hex': hexlify(
|
||||
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
|
||||
|
@ -150,11 +144,11 @@ class BlockchainReorganizationTests(LedgerTestCase):
|
|||
25: {'height': 25, 'count': 0, 'hex': b''}
|
||||
})
|
||||
headers = self.ledger.headers
|
||||
yield headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
||||
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
||||
self.add_header(block_height=len(headers))
|
||||
self.add_header(block_height=len(headers))
|
||||
self.add_header(block_height=len(headers))
|
||||
self.assertEqual(headers.height, 22)
|
||||
yield self.ledger.receive_header(({
|
||||
await self.ledger.receive_header(({
|
||||
'height': 23, 'hex': hexlify(self.make_header(block_height=23))
|
||||
},))
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
from binascii import hexlify, unhexlify
|
||||
from twisted.trial import unittest
|
||||
|
||||
from torba.bcd_data_stream import BCDataStream
|
||||
from torba.basescript import Template, ParseError, tokenize, push_data
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import unittest
|
||||
from binascii import hexlify, unhexlify
|
||||
from itertools import cycle
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from orchstr8.testcase import AsyncioTestCase
|
||||
|
||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||
from torba.wallet import Wallet
|
||||
|
@ -29,9 +30,9 @@ def get_transaction(txo=None):
|
|||
.add_outputs([txo or ledger_class.transaction_class.output_class.pay_pubkey_hash(CENT, NULL_HASH)])
|
||||
|
||||
|
||||
class TestSizeAndFeeEstimation(unittest.TestCase):
|
||||
class TestSizeAndFeeEstimation(AsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
|
@ -181,17 +182,19 @@ class TestTransactionSerialization(unittest.TestCase):
|
|||
self.assertEqual(tx.raw, raw)
|
||||
|
||||
|
||||
class TestTransactionSigning(unittest.TestCase):
|
||||
class TestTransactionSigning(AsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
})
|
||||
return self.ledger.db.open()
|
||||
await self.ledger.db.open()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_sign(self):
|
||||
async def asyncTearDown(self):
|
||||
await self.ledger.db.close()
|
||||
|
||||
async def test_sign(self):
|
||||
account = self.ledger.account_class.from_dict(
|
||||
self.ledger, Wallet(), {
|
||||
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||
|
@ -200,8 +203,8 @@ class TestTransactionSigning(unittest.TestCase):
|
|||
}
|
||||
)
|
||||
|
||||
yield account.ensure_address_gap()
|
||||
address1, address2 = yield account.receiving.get_addresses(limit=2)
|
||||
await account.ensure_address_gap()
|
||||
address1, address2 = await account.receiving.get_addresses(limit=2)
|
||||
pubkey_hash1 = self.ledger.address_to_hash160(address1)
|
||||
pubkey_hash2 = self.ledger.address_to_hash160(address2)
|
||||
|
||||
|
@ -211,9 +214,8 @@ class TestTransactionSigning(unittest.TestCase):
|
|||
.add_inputs([tx_class.input_class.spend(get_output(2*COIN, pubkey_hash1))]) \
|
||||
.add_outputs([tx_class.output_class.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) \
|
||||
|
||||
yield tx.sign([account])
|
||||
await tx.sign([account])
|
||||
|
||||
print(hexlify(tx.inputs[0].script.values['signature']))
|
||||
self.assertEqual(
|
||||
hexlify(tx.inputs[0].script.values['signature']),
|
||||
b'304402205a1df8cd5d2d2fa5934b756883d6c07e4f83e1350c740992d47a12422'
|
||||
|
@ -221,15 +223,14 @@ class TestTransactionSigning(unittest.TestCase):
|
|||
)
|
||||
|
||||
|
||||
class TransactionIOBalancing(unittest.TestCase):
|
||||
class TransactionIOBalancing(AsyncioTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(':memory:'),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
})
|
||||
yield self.ledger.db.open()
|
||||
await self.ledger.db.open()
|
||||
self.account = self.ledger.account_class.from_dict(
|
||||
self.ledger, Wallet(), {
|
||||
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||
|
@ -237,10 +238,13 @@ class TransactionIOBalancing(unittest.TestCase):
|
|||
}
|
||||
)
|
||||
|
||||
addresses = yield self.account.ensure_address_gap()
|
||||
addresses = await self.account.ensure_address_gap()
|
||||
self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
|
||||
self.hash_cycler = cycle(self.pubkey_hash)
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.ledger.db.close()
|
||||
|
||||
def txo(self, amount, address=None):
|
||||
return get_output(int(amount*COIN), address or next(self.hash_cycler))
|
||||
|
||||
|
@ -250,8 +254,7 @@ class TransactionIOBalancing(unittest.TestCase):
|
|||
def tx(self, inputs, outputs):
|
||||
return ledger_class.transaction_class.create(inputs, outputs, [self.account], self.account)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_utxos(self, amounts):
|
||||
async def create_utxos(self, amounts):
|
||||
utxos = [self.txo(amount) for amount in amounts]
|
||||
|
||||
self.funding_tx = ledger_class.transaction_class(is_verified=True) \
|
||||
|
@ -260,7 +263,7 @@ class TransactionIOBalancing(unittest.TestCase):
|
|||
|
||||
save_tx = 'insert'
|
||||
for utxo in utxos:
|
||||
yield self.ledger.db.save_transaction_io(
|
||||
await self.ledger.db.save_transaction_io(
|
||||
save_tx, self.funding_tx,
|
||||
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
|
||||
utxo.script.values['pubkey_hash'], ''
|
||||
|
@ -277,17 +280,16 @@ class TransactionIOBalancing(unittest.TestCase):
|
|||
def outputs(tx):
|
||||
return [round(o.amount/COIN, 2) for o in tx.outputs]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_basic_use_cases(self):
|
||||
async def test_basic_use_cases(self):
|
||||
self.ledger.fee_per_byte = int(.01*CENT)
|
||||
|
||||
# available UTXOs for filling missing inputs
|
||||
utxos = yield self.create_utxos([
|
||||
utxos = await self.create_utxos([
|
||||
1, 1, 3, 5, 10
|
||||
])
|
||||
|
||||
# pay 3 coins (3.02 w/ fees)
|
||||
tx = yield self.tx(
|
||||
tx = await self.tx(
|
||||
[], # inputs
|
||||
[self.txo(3)] # outputs
|
||||
)
|
||||
|
@ -296,10 +298,10 @@ class TransactionIOBalancing(unittest.TestCase):
|
|||
# a change of 1.98 is added to reach balance
|
||||
self.assertEqual(self.outputs(tx), [3, 1.98])
|
||||
|
||||
yield self.ledger.release_outputs(utxos)
|
||||
await self.ledger.release_outputs(utxos)
|
||||
|
||||
# pay 2.98 coins (3.00 w/ fees)
|
||||
tx = yield self.tx(
|
||||
tx = await self.tx(
|
||||
[], # inputs
|
||||
[self.txo(2.98)] # outputs
|
||||
)
|
||||
|
@ -307,10 +309,10 @@ class TransactionIOBalancing(unittest.TestCase):
|
|||
self.assertEqual(self.inputs(tx), [3])
|
||||
self.assertEqual(self.outputs(tx), [2.98])
|
||||
|
||||
yield self.ledger.release_outputs(utxos)
|
||||
await self.ledger.release_outputs(utxos)
|
||||
|
||||
# supplied input and output, but input is not enough to cover output
|
||||
tx = yield self.tx(
|
||||
tx = await self.tx(
|
||||
[self.txi(self.txo(10))], # inputs
|
||||
[self.txo(11)] # outputs
|
||||
)
|
||||
|
@ -319,10 +321,10 @@ class TransactionIOBalancing(unittest.TestCase):
|
|||
# change is now needed to consume extra input
|
||||
self.assertEqual([11, 1.96], self.outputs(tx))
|
||||
|
||||
yield self.ledger.release_outputs(utxos)
|
||||
await self.ledger.release_outputs(utxos)
|
||||
|
||||
# liquidating a UTXO
|
||||
tx = yield self.tx(
|
||||
tx = await self.tx(
|
||||
[self.txi(self.txo(10))], # inputs
|
||||
[] # outputs
|
||||
)
|
||||
|
@ -330,10 +332,10 @@ class TransactionIOBalancing(unittest.TestCase):
|
|||
# missing change added to consume the amount
|
||||
self.assertEqual([9.98], self.outputs(tx))
|
||||
|
||||
yield self.ledger.release_outputs(utxos)
|
||||
await self.ledger.release_outputs(utxos)
|
||||
|
||||
# liquidating at a loss, requires adding extra inputs
|
||||
tx = yield self.tx(
|
||||
tx = await self.tx(
|
||||
[self.txi(self.txo(0.01))], # inputs
|
||||
[] # outputs
|
||||
)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from twisted.trial import unittest
|
||||
import unittest
|
||||
|
||||
from torba.util import ArithUint256
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
import tempfile
|
||||
from twisted.trial import unittest
|
||||
|
||||
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
|
||||
from torba.coin.bitcoincash import MainNetLedger as BCHLedger
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import random
|
||||
import typing
|
||||
from typing import List, Dict, Tuple, Type, Optional, Any
|
||||
from twisted.internet import defer
|
||||
|
||||
from torba.mnemonic import Mnemonic
|
||||
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
|
||||
|
@ -44,12 +43,8 @@ class AddressManager:
|
|||
def to_dict_instance(self) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
return self.account.ledger.db
|
||||
|
||||
def _query_addresses(self, **constraints):
|
||||
return self.db.get_addresses(
|
||||
return self.account.ledger.db.get_addresses(
|
||||
account=self.account,
|
||||
chain=self.chain_number,
|
||||
**constraints
|
||||
|
@ -58,26 +53,24 @@ class AddressManager:
|
|||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_max_gap(self) -> defer.Deferred:
|
||||
async def get_max_gap(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def ensure_address_gap(self) -> defer.Deferred:
|
||||
async def ensure_address_gap(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred:
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
raise NotImplementedError
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_addresses(self, only_usable: bool = False, **constraints) -> defer.Deferred:
|
||||
records = yield self.get_address_records(only_usable=only_usable, **constraints)
|
||||
async def get_addresses(self, only_usable: bool = False, **constraints):
|
||||
records = await self.get_address_records(only_usable=only_usable, **constraints)
|
||||
return [r['address'] for r in records]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_usable_address(self) -> defer.Deferred:
|
||||
addresses = yield self.get_addresses(only_usable=True, limit=10)
|
||||
async def get_or_create_usable_address(self):
|
||||
addresses = await self.get_addresses(only_usable=True, limit=10)
|
||||
if addresses:
|
||||
return random.choice(addresses)
|
||||
addresses = yield self.ensure_address_gap()
|
||||
addresses = await self.ensure_address_gap()
|
||||
return addresses[0]
|
||||
|
||||
|
||||
|
@ -106,22 +99,20 @@ class HierarchicalDeterministic(AddressManager):
|
|||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
return self.account.private_key.child(self.chain_number).child(index)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_keys(self, start: int, end: int) -> defer.Deferred:
|
||||
async def generate_keys(self, start: int, end: int):
|
||||
keys_batch, final_keys = [], []
|
||||
for index in range(start, end+1):
|
||||
keys_batch.append((index, self.public_key.child(index)))
|
||||
if index % 180 == 0 or index == end:
|
||||
yield self.db.add_keys(
|
||||
await self.account.ledger.db.add_keys(
|
||||
self.account, self.chain_number, keys_batch
|
||||
)
|
||||
final_keys.extend(keys_batch)
|
||||
keys_batch.clear()
|
||||
return [key[1].address for key in final_keys]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_max_gap(self) -> defer.Deferred:
|
||||
addresses = yield self._query_addresses(order_by="position ASC")
|
||||
async def get_max_gap(self):
|
||||
addresses = await self._query_addresses(order_by="position ASC")
|
||||
max_gap = 0
|
||||
current_gap = 0
|
||||
for address in addresses:
|
||||
|
@ -132,9 +123,8 @@ class HierarchicalDeterministic(AddressManager):
|
|||
current_gap = 0
|
||||
return max_gap
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def ensure_address_gap(self) -> defer.Deferred:
|
||||
addresses = yield self._query_addresses(limit=self.gap, order_by="position DESC")
|
||||
async def ensure_address_gap(self):
|
||||
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
|
||||
|
||||
existing_gap = 0
|
||||
for address in addresses:
|
||||
|
@ -148,7 +138,7 @@ class HierarchicalDeterministic(AddressManager):
|
|||
|
||||
start = addresses[0]['position']+1 if addresses else 0
|
||||
end = start + (self.gap - existing_gap)
|
||||
new_keys = yield self.generate_keys(start, end-1)
|
||||
new_keys = await self.generate_keys(start, end-1)
|
||||
return new_keys
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
|
@ -176,20 +166,19 @@ class SingleKey(AddressManager):
|
|||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
return self.account.private_key
|
||||
|
||||
def get_max_gap(self) -> defer.Deferred:
|
||||
return defer.succeed(0)
|
||||
async def get_max_gap(self):
|
||||
return 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def ensure_address_gap(self) -> defer.Deferred:
|
||||
exists = yield self.get_address_records()
|
||||
async def ensure_address_gap(self):
|
||||
exists = await self.get_address_records()
|
||||
if not exists:
|
||||
yield self.db.add_keys(
|
||||
await self.account.ledger.db.add_keys(
|
||||
self.account, self.chain_number, [(0, self.public_key)]
|
||||
)
|
||||
return [self.public_key.address]
|
||||
return []
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred:
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
return self._query_addresses(**constraints)
|
||||
|
||||
|
||||
|
@ -289,9 +278,8 @@ class BaseAccount:
|
|||
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_details(self, show_seed=False, **kwargs):
|
||||
satoshis = yield self.get_balance(**kwargs)
|
||||
async def get_details(self, show_seed=False, **kwargs):
|
||||
satoshis = await self.get_balance(**kwargs)
|
||||
details = {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
|
@ -325,23 +313,21 @@ class BaseAccount:
|
|||
self.password = None
|
||||
self.encrypted = True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def ensure_address_gap(self):
|
||||
async def ensure_address_gap(self):
|
||||
addresses = []
|
||||
for address_manager in self.address_managers:
|
||||
new_addresses = yield address_manager.ensure_address_gap()
|
||||
new_addresses = await address_manager.ensure_address_gap()
|
||||
addresses.extend(new_addresses)
|
||||
return addresses
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_addresses(self, **constraints) -> defer.Deferred:
|
||||
rows = yield self.ledger.db.select_addresses('address', account=self, **constraints)
|
||||
async def get_addresses(self, **constraints):
|
||||
rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def get_address_records(self, **constraints) -> defer.Deferred:
|
||||
def get_address_records(self, **constraints):
|
||||
return self.ledger.db.get_addresses(account=self, **constraints)
|
||||
|
||||
def get_address_count(self, **constraints) -> defer.Deferred:
|
||||
def get_address_count(self, **constraints):
|
||||
return self.ledger.db.get_address_count(account=self, **constraints)
|
||||
|
||||
def get_private_key(self, chain: int, index: int) -> PrivateKey:
|
||||
|
@ -355,10 +341,9 @@ class BaseAccount:
|
|||
constraints.update({'height__lte': height, 'height__gt': 0})
|
||||
return self.ledger.db.get_balance(account=self, **constraints)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_max_gap(self):
|
||||
change_gap = yield self.change.get_max_gap()
|
||||
receiving_gap = yield self.receiving.get_max_gap()
|
||||
async def get_max_gap(self):
|
||||
change_gap = await self.change.get_max_gap()
|
||||
receiving_gap = await self.receiving.get_max_gap()
|
||||
return {
|
||||
'max_change_gap': change_gap,
|
||||
'max_receiving_gap': receiving_gap,
|
||||
|
@ -376,24 +361,23 @@ class BaseAccount:
|
|||
def get_transaction_count(self, **constraints):
|
||||
return self.ledger.db.get_transaction_count(account=self, **constraints)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fund(self, to_account, amount=None, everything=False,
|
||||
async def fund(self, to_account, amount=None, everything=False,
|
||||
outputs=1, broadcast=False, **constraints):
|
||||
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
|
||||
tx_class = self.ledger.transaction_class
|
||||
if everything:
|
||||
utxos = yield self.get_utxos(**constraints)
|
||||
yield self.ledger.reserve_outputs(utxos)
|
||||
tx = yield tx_class.create(
|
||||
utxos = await self.get_utxos(**constraints)
|
||||
await self.ledger.reserve_outputs(utxos)
|
||||
tx = await tx_class.create(
|
||||
inputs=[tx_class.input_class.spend(txo) for txo in utxos],
|
||||
outputs=[],
|
||||
funding_accounts=[self],
|
||||
change_account=to_account
|
||||
)
|
||||
elif amount > 0:
|
||||
to_address = yield to_account.change.get_or_create_usable_address()
|
||||
to_address = await to_account.change.get_or_create_usable_address()
|
||||
to_hash160 = to_account.ledger.address_to_hash160(to_address)
|
||||
tx = yield tx_class.create(
|
||||
tx = await tx_class.create(
|
||||
inputs=[],
|
||||
outputs=[
|
||||
tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
|
||||
|
@ -406,9 +390,9 @@ class BaseAccount:
|
|||
raise ValueError('An amount is required.')
|
||||
|
||||
if broadcast:
|
||||
yield self.ledger.broadcast(tx)
|
||||
await self.ledger.broadcast(tx)
|
||||
else:
|
||||
yield self.ledger.release_outputs(
|
||||
await self.ledger.release_outputs(
|
||||
[txi.txo_ref.txo for txi in tx.inputs]
|
||||
)
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import logging
|
||||
from typing import Tuple, List, Sequence
|
||||
from typing import Tuple, List
|
||||
|
||||
import sqlite3
|
||||
from twisted.internet import defer
|
||||
from twisted.enterprise import adbapi
|
||||
import aiosqlite
|
||||
|
||||
from torba.hash import TXRefImmutable
|
||||
from torba.basetransaction import BaseTransaction
|
||||
|
@ -107,25 +106,21 @@ def row_dict_or_default(rows, fields, default=None):
|
|||
|
||||
class SQLiteMixin:
|
||||
|
||||
CREATE_TABLES_QUERY: Sequence[str] = ()
|
||||
CREATE_TABLES_QUERY: str
|
||||
|
||||
def __init__(self, path):
|
||||
self._db_path = path
|
||||
self.db: adbapi.ConnectionPool = None
|
||||
self.db: aiosqlite.Connection = None
|
||||
self.ledger = None
|
||||
|
||||
def open(self):
|
||||
async def open(self):
|
||||
log.info("connecting to database: %s", self._db_path)
|
||||
self.db = adbapi.ConnectionPool(
|
||||
'sqlite3', self._db_path, cp_min=1, cp_max=1, check_same_thread=False
|
||||
)
|
||||
return self.db.runInteraction(
|
||||
lambda t: t.executescript(self.CREATE_TABLES_QUERY)
|
||||
)
|
||||
self.db = aiosqlite.connect(self._db_path)
|
||||
await self.db.__aenter__()
|
||||
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
||||
|
||||
def close(self):
|
||||
self.db.close()
|
||||
return defer.succeed(True)
|
||||
async def close(self):
|
||||
await self.db.close()
|
||||
|
||||
@staticmethod
|
||||
def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
|
||||
|
@ -247,78 +242,75 @@ class BaseDatabase(SQLiteMixin):
|
|||
'script': sqlite3.Binary(txo.script.source)
|
||||
}
|
||||
|
||||
def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
|
||||
async def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
|
||||
|
||||
def _steps(t):
|
||||
if save_tx == 'insert':
|
||||
self.execute(t, *self._insert_sql('tx', {
|
||||
if save_tx == 'insert':
|
||||
await self.db.execute(*self._insert_sql('tx', {
|
||||
'txid': tx.id,
|
||||
'raw': sqlite3.Binary(tx.raw),
|
||||
'height': tx.height,
|
||||
'position': tx.position,
|
||||
'is_verified': tx.is_verified
|
||||
}))
|
||||
elif save_tx == 'update':
|
||||
await self.db.execute(*self._update_sql("tx", {
|
||||
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
||||
}, 'txid = ?', (tx.id,)))
|
||||
|
||||
existing_txos = [r[0] for r in await self.db.execute_fetchall(*query(
|
||||
"SELECT position FROM txo", txid=tx.id
|
||||
))]
|
||||
|
||||
for txo in tx.outputs:
|
||||
if txo.position in existing_txos:
|
||||
continue
|
||||
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
|
||||
await self.db.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
|
||||
elif txo.script.is_pay_script_hash:
|
||||
# TODO: implement script hash payments
|
||||
print('Database.save_transaction_io: pay script hash is not implemented!')
|
||||
|
||||
# lookup the address associated with each TXI (via its TXO)
|
||||
txoid_to_address = {r[0]: r[1] for r in await self.db.execute_fetchall(*query(
|
||||
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
|
||||
))}
|
||||
|
||||
# list of TXIs that have already been added
|
||||
existing_txis = [r[0] for r in await self.db.execute_fetchall(*query(
|
||||
"SELECT txoid FROM txi", txid=tx.id
|
||||
))]
|
||||
|
||||
for txi in tx.inputs:
|
||||
txoid = txi.txo_ref.id
|
||||
new_txi = txoid not in existing_txis
|
||||
address_matches = txoid_to_address.get(txoid) == address
|
||||
if new_txi and address_matches:
|
||||
await self.db.execute(*self._insert_sql("txi", {
|
||||
'txid': tx.id,
|
||||
'raw': sqlite3.Binary(tx.raw),
|
||||
'height': tx.height,
|
||||
'position': tx.position,
|
||||
'is_verified': tx.is_verified
|
||||
'txoid': txoid,
|
||||
'address': address,
|
||||
}))
|
||||
elif save_tx == 'update':
|
||||
self.execute(t, *self._update_sql("tx", {
|
||||
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
||||
}, 'txid = ?', (tx.id,)))
|
||||
|
||||
existing_txos = [r[0] for r in self.execute(t, *query(
|
||||
"SELECT position FROM txo", txid=tx.id
|
||||
)).fetchall()]
|
||||
await self._set_address_history(address, history)
|
||||
|
||||
for txo in tx.outputs:
|
||||
if txo.position in existing_txos:
|
||||
continue
|
||||
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
|
||||
self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
|
||||
elif txo.script.is_pay_script_hash:
|
||||
# TODO: implement script hash payments
|
||||
print('Database.save_transaction_io: pay script hash is not implemented!')
|
||||
|
||||
# lookup the address associated with each TXI (via its TXO)
|
||||
txoid_to_address = {r[0]: r[1] for r in self.execute(t, *query(
|
||||
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
|
||||
)).fetchall()}
|
||||
|
||||
# list of TXIs that have already been added
|
||||
existing_txis = [r[0] for r in self.execute(t, *query(
|
||||
"SELECT txoid FROM txi", txid=tx.id
|
||||
)).fetchall()]
|
||||
|
||||
for txi in tx.inputs:
|
||||
txoid = txi.txo_ref.id
|
||||
new_txi = txoid not in existing_txis
|
||||
address_matches = txoid_to_address.get(txoid) == address
|
||||
if new_txi and address_matches:
|
||||
self.execute(t, *self._insert_sql("txi", {
|
||||
'txid': tx.id,
|
||||
'txoid': txoid,
|
||||
'address': address,
|
||||
}))
|
||||
|
||||
self._set_address_history(t, address, history)
|
||||
|
||||
return self.db.runInteraction(_steps)
|
||||
|
||||
def reserve_outputs(self, txos, is_reserved=True):
|
||||
async def reserve_outputs(self, txos, is_reserved=True):
|
||||
txoids = [txo.id for txo in txos]
|
||||
return self.run_operation(
|
||||
await self.db.execute(
|
||||
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
|
||||
', '.join(['?']*len(txoids))
|
||||
), [is_reserved]+txoids
|
||||
)
|
||||
|
||||
def release_outputs(self, txos):
|
||||
return self.reserve_outputs(txos, is_reserved=False)
|
||||
async def release_outputs(self, txos):
|
||||
await self.reserve_outputs(txos, is_reserved=False)
|
||||
|
||||
def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
|
||||
async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
|
||||
# TODO:
|
||||
# 1. delete transactions above_height
|
||||
# 2. update address histories removing deleted TXs
|
||||
return defer.succeed(True)
|
||||
return True
|
||||
|
||||
def select_transactions(self, cols, account=None, **constraints):
|
||||
async def select_transactions(self, cols, account=None, **constraints):
|
||||
if 'txid' not in constraints and account is not None:
|
||||
constraints['$account'] = account.public_key.address
|
||||
constraints['txid__in'] = """
|
||||
|
@ -328,13 +320,14 @@ class BaseDatabase(SQLiteMixin):
|
|||
SELECT txi.txid FROM txi
|
||||
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account
|
||||
"""
|
||||
return self.run_query(*query("SELECT {} FROM tx".format(cols), **constraints))
|
||||
return await self.db.execute_fetchall(
|
||||
*query("SELECT {} FROM tx".format(cols), **constraints)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_transactions(self, my_account=None, **constraints):
|
||||
async def get_transactions(self, my_account=None, **constraints):
|
||||
my_account = my_account or constraints.get('account', None)
|
||||
|
||||
tx_rows = yield self.select_transactions(
|
||||
tx_rows = await self.select_transactions(
|
||||
'txid, raw, height, position, is_verified',
|
||||
order_by=["height DESC", "position DESC"],
|
||||
**constraints
|
||||
|
@ -352,7 +345,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
|
||||
annotated_txos = {
|
||||
txo.id: txo for txo in
|
||||
(yield self.get_txos(
|
||||
(await self.get_txos(
|
||||
my_account=my_account,
|
||||
txid__in=txids
|
||||
))
|
||||
|
@ -360,7 +353,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
|
||||
referenced_txos = {
|
||||
txo.id: txo for txo in
|
||||
(yield self.get_txos(
|
||||
(await self.get_txos(
|
||||
my_account=my_account,
|
||||
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0]
|
||||
))
|
||||
|
@ -380,33 +373,30 @@ class BaseDatabase(SQLiteMixin):
|
|||
|
||||
return txs
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_transaction_count(self, **constraints):
|
||||
async def get_transaction_count(self, **constraints):
|
||||
constraints.pop('offset', None)
|
||||
constraints.pop('limit', None)
|
||||
constraints.pop('order_by', None)
|
||||
count = yield self.select_transactions('count(*)', **constraints)
|
||||
count = await self.select_transactions('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_transaction(self, **constraints):
|
||||
txs = yield self.get_transactions(limit=1, **constraints)
|
||||
async def get_transaction(self, **constraints):
|
||||
txs = await self.get_transactions(limit=1, **constraints)
|
||||
if txs:
|
||||
return txs[0]
|
||||
|
||||
def select_txos(self, cols, **constraints):
|
||||
return self.run_query(*query(
|
||||
async def select_txos(self, cols, **constraints):
|
||||
return await self.db.execute_fetchall(*query(
|
||||
"SELECT {} FROM txo"
|
||||
" JOIN pubkey_address USING (address)"
|
||||
" JOIN tx USING (txid)".format(cols), **constraints
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_txos(self, my_account=None, **constraints):
|
||||
async def get_txos(self, my_account=None, **constraints):
|
||||
my_account = my_account or constraints.get('account', None)
|
||||
if isinstance(my_account, BaseAccount):
|
||||
my_account = my_account.public_key.address
|
||||
rows = yield self.select_txos(
|
||||
rows = await self.select_txos(
|
||||
"amount, script, txid, tx.height, txo.position, chain, account", **constraints
|
||||
)
|
||||
output_class = self.ledger.transaction_class.output_class
|
||||
|
@ -421,12 +411,11 @@ class BaseDatabase(SQLiteMixin):
|
|||
) for row in rows
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_txo_count(self, **constraints):
|
||||
async def get_txo_count(self, **constraints):
|
||||
constraints.pop('offset', None)
|
||||
constraints.pop('limit', None)
|
||||
constraints.pop('order_by', None)
|
||||
count = yield self.select_txos('count(*)', **constraints)
|
||||
count = await self.select_txos('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
@staticmethod
|
||||
|
@ -442,37 +431,33 @@ class BaseDatabase(SQLiteMixin):
|
|||
self.constrain_utxo(constraints)
|
||||
return self.get_txo_count(**constraints)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_balance(self, **constraints):
|
||||
async def get_balance(self, **constraints):
|
||||
self.constrain_utxo(constraints)
|
||||
balance = yield self.select_txos('SUM(amount)', **constraints)
|
||||
balance = await self.select_txos('SUM(amount)', **constraints)
|
||||
return balance[0][0] or 0
|
||||
|
||||
def select_addresses(self, cols, **constraints):
|
||||
return self.run_query(*query(
|
||||
"SELECT {} FROM pubkey_address".format(cols), **constraints
|
||||
))
|
||||
async def select_addresses(self, cols, **constraints):
|
||||
return await self.db.execute_fetchall(*query(
|
||||
"SELECT {} FROM pubkey_address".format(cols), **constraints
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints):
|
||||
addresses = yield self.select_addresses(', '.join(cols), **constraints)
|
||||
async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints):
|
||||
addresses = await self.select_addresses(', '.join(cols), **constraints)
|
||||
return rows_to_dict(addresses, cols)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_address_count(self, **constraints):
|
||||
count = yield self.select_addresses('count(*)', **constraints)
|
||||
async def get_address_count(self, **constraints):
|
||||
count = await self.select_addresses('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_address(self, **constraints):
|
||||
addresses = yield self.get_addresses(
|
||||
async def get_address(self, **constraints):
|
||||
addresses = await self.get_addresses(
|
||||
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
|
||||
limit=1, **constraints
|
||||
)
|
||||
if addresses:
|
||||
return addresses[0]
|
||||
|
||||
def add_keys(self, account, chain, keys):
|
||||
async def add_keys(self, account, chain, keys):
|
||||
sql = (
|
||||
"insert into pubkey_address "
|
||||
"(address, account, chain, position, pubkey) "
|
||||
|
@ -484,14 +469,13 @@ class BaseDatabase(SQLiteMixin):
|
|||
pubkey.address, account.public_key.address, chain, position,
|
||||
sqlite3.Binary(pubkey.pubkey_bytes)
|
||||
))
|
||||
return self.run_operation(sql, values)
|
||||
await self.db.execute(sql, values)
|
||||
|
||||
@classmethod
|
||||
def _set_address_history(cls, t, address, history):
|
||||
cls.execute(
|
||||
t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
async def _set_address_history(self, address, history):
|
||||
await self.db.execute(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':')//2, address)
|
||||
)
|
||||
|
||||
def set_address_history(self, address, history):
|
||||
return self.db.runInteraction(lambda t: self._set_address_history(t, address, history))
|
||||
async def set_address_history(self, address, history):
|
||||
await self._set_address_history(address, history)
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Optional, Iterator, Tuple
|
||||
from binascii import hexlify
|
||||
|
||||
from twisted.internet import threads, defer
|
||||
|
||||
from torba.util import ArithUint256
|
||||
from torba.hash import double_sha256
|
||||
|
||||
|
@ -36,16 +35,14 @@ class BaseHeaders:
|
|||
self.io = BytesIO()
|
||||
self.path = path
|
||||
self._size: Optional[int] = None
|
||||
self._header_connect_lock = defer.DeferredLock()
|
||||
self._header_connect_lock = asyncio.Lock()
|
||||
|
||||
def open(self):
|
||||
async def open(self):
|
||||
if self.path != ':memory:':
|
||||
self.io = open(self.path, 'a+b')
|
||||
return defer.succeed(True)
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
self.io.close()
|
||||
return defer.succeed(True)
|
||||
|
||||
@staticmethod
|
||||
def serialize(header: dict) -> bytes:
|
||||
|
@ -95,16 +92,15 @@ class BaseHeaders:
|
|||
return b'0' * 64
|
||||
return hexlify(double_sha256(header)[::-1])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def connect(self, start: int, headers: bytes):
|
||||
async def connect(self, start: int, headers: bytes) -> int:
|
||||
added = 0
|
||||
bail = False
|
||||
yield self._header_connect_lock.acquire()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
async with self._header_connect_lock:
|
||||
for height, chunk in self._iterate_chunks(start, headers):
|
||||
try:
|
||||
# validate_chunk() is CPU bound and reads previous chunks from file system
|
||||
yield threads.deferToThread(self.validate_chunk, height, chunk)
|
||||
await loop.run_in_executor(None, self.validate_chunk, height, chunk)
|
||||
except InvalidHeader as e:
|
||||
bail = True
|
||||
chunk = chunk[:(height-e.height+1)*self.header_size]
|
||||
|
@ -115,14 +111,12 @@ class BaseHeaders:
|
|||
self.io.truncate()
|
||||
# .seek()/.write()/.truncate() might also .flush() when needed
|
||||
# the goal here is mainly to ensure we're definitely flush()'ing
|
||||
yield threads.deferToThread(self.io.flush)
|
||||
await loop.run_in_executor(None, self.io.flush)
|
||||
self._size = None
|
||||
added += written
|
||||
if bail:
|
||||
break
|
||||
finally:
|
||||
self._header_connect_lock.release()
|
||||
defer.returnValue(added)
|
||||
return added
|
||||
|
||||
def validate_chunk(self, height, chunk):
|
||||
previous_hash, previous_header, previous_previous_header = None, None, None
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from binascii import hexlify, unhexlify
|
||||
from io import StringIO
|
||||
|
@ -7,8 +8,6 @@ from typing import Dict, Type, Iterable
|
|||
from operator import itemgetter
|
||||
from collections import namedtuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from torba import baseaccount
|
||||
from torba import basenetwork
|
||||
from torba import basetransaction
|
||||
|
@ -104,8 +103,8 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
)
|
||||
|
||||
self._transaction_processing_locks = {}
|
||||
self._utxo_reservation_lock = defer.DeferredLock()
|
||||
self._header_processing_lock = defer.DeferredLock()
|
||||
self._utxo_reservation_lock = asyncio.Lock()
|
||||
self._header_processing_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def get_id(cls):
|
||||
|
@ -135,41 +134,32 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
def add_account(self, account: baseaccount.BaseAccount):
|
||||
self.accounts.append(account)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_private_key_for_address(self, address):
|
||||
match = yield self.db.get_address(address=address)
|
||||
async def get_private_key_for_address(self, address):
|
||||
match = await self.db.get_address(address=address)
|
||||
if match:
|
||||
for account in self.accounts:
|
||||
if match['account'] == account.public_key.address:
|
||||
return account.get_private_key(match['chain'], match['position'])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
|
||||
async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
|
||||
estimators = []
|
||||
for account in funding_accounts:
|
||||
utxos = yield account.get_utxos()
|
||||
utxos = await account.get_utxos()
|
||||
for utxo in utxos:
|
||||
estimators.append(utxo.get_estimator(self))
|
||||
return estimators
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_spendable_utxos(self, amount: int, funding_accounts):
|
||||
yield self._utxo_reservation_lock.acquire()
|
||||
try:
|
||||
txos = yield self.get_effective_amount_estimators(funding_accounts)
|
||||
async def get_spendable_utxos(self, amount: int, funding_accounts):
|
||||
async with self._utxo_reservation_lock:
|
||||
txos = await self.get_effective_amount_estimators(funding_accounts)
|
||||
selector = CoinSelector(
|
||||
txos, amount,
|
||||
self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
|
||||
)
|
||||
spendables = selector.select()
|
||||
if spendables:
|
||||
yield self.reserve_outputs(s.txo for s in spendables)
|
||||
except Exception:
|
||||
log.exception('Failed to get spendable utxos:')
|
||||
raise
|
||||
finally:
|
||||
self._utxo_reservation_lock.release()
|
||||
return spendables
|
||||
await self.reserve_outputs(s.txo for s in spendables)
|
||||
return spendables
|
||||
|
||||
def reserve_outputs(self, txos):
|
||||
return self.db.reserve_outputs(txos)
|
||||
|
@ -177,16 +167,14 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
def release_outputs(self, txos):
|
||||
return self.db.release_outputs(txos)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_local_status(self, address):
|
||||
address_details = yield self.db.get_address(address=address)
|
||||
async def get_local_status(self, address):
|
||||
address_details = await self.db.get_address(address=address)
|
||||
history = address_details['history'] or ''
|
||||
h = sha256(history.encode())
|
||||
return hexlify(h)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_local_history(self, address):
|
||||
address_details = yield self.db.get_address(address=address)
|
||||
async def get_local_history(self, address):
|
||||
address_details = await self.db.get_address(address=address)
|
||||
history = address_details['history'] or ''
|
||||
parts = history.split(':')[:-1]
|
||||
return list(zip(parts[0::2], map(int, parts[1::2])))
|
||||
|
@ -203,43 +191,40 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
working_branch = double_sha256(combined)
|
||||
return hexlify(working_branch[::-1])
|
||||
|
||||
def validate_transaction_and_set_position(self, tx, height, merkle):
|
||||
async def validate_transaction_and_set_position(self, tx, height):
|
||||
if not height <= len(self.headers):
|
||||
return False
|
||||
merkle = await self.network.get_merkle(tx.id, height)
|
||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||
header = self.headers[height]
|
||||
tx.position = merkle['pos']
|
||||
tx.is_verified = merkle_root == header['merkle_root']
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not os.path.exists(self.path):
|
||||
os.mkdir(self.path)
|
||||
yield defer.gatherResults([
|
||||
await asyncio.gather(
|
||||
self.db.open(),
|
||||
self.headers.open()
|
||||
])
|
||||
)
|
||||
first_connection = self.network.on_connected.first
|
||||
self.network.start()
|
||||
yield first_connection
|
||||
yield self.join_network()
|
||||
asyncio.ensure_future(self.network.start())
|
||||
await first_connection
|
||||
await self.join_network()
|
||||
self.network.on_connected.listen(self.join_network)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def join_network(self, *args):
|
||||
async def join_network(self, *args):
|
||||
log.info("Subscribing and updating accounts.")
|
||||
yield self.update_headers()
|
||||
yield self.network.subscribe_headers()
|
||||
yield self.update_accounts()
|
||||
await self.update_headers()
|
||||
await self.network.subscribe_headers()
|
||||
await self.update_accounts()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def stop(self):
|
||||
yield self.network.stop()
|
||||
yield self.db.close()
|
||||
yield self.headers.close()
|
||||
async def stop(self):
|
||||
await self.network.stop()
|
||||
await self.db.close()
|
||||
await self.headers.close()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||
rewound = 0
|
||||
while True:
|
||||
|
||||
|
@ -251,14 +236,14 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
subscription_update = False
|
||||
|
||||
if not headers:
|
||||
header_response = yield self.network.get_headers(height, 2001)
|
||||
header_response = await self.network.get_headers(height, 2001)
|
||||
headers = header_response['hex']
|
||||
|
||||
if not headers:
|
||||
# Nothing to do, network thinks we're already at the latest height.
|
||||
return
|
||||
|
||||
added = yield self.headers.connect(height, unhexlify(headers))
|
||||
added = await self.headers.connect(height, unhexlify(headers))
|
||||
if added > 0:
|
||||
height += added
|
||||
self._on_header_controller.add(
|
||||
|
@ -268,7 +253,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
# we started rewinding blocks and apparently found
|
||||
# a new chain
|
||||
rewound = 0
|
||||
yield self.db.rewind_blockchain(height)
|
||||
await self.db.rewind_blockchain(height)
|
||||
|
||||
if subscription_update:
|
||||
# subscription updates are for latest header already
|
||||
|
@ -310,66 +295,37 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
# robust sync, turn off subscription update shortcut
|
||||
subscription_update = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def receive_header(self, response):
|
||||
yield self._header_processing_lock.acquire()
|
||||
try:
|
||||
async def receive_header(self, response):
|
||||
async with self._header_processing_lock:
|
||||
header = response[0]
|
||||
yield self.update_headers(
|
||||
await self.update_headers(
|
||||
height=header['height'], headers=header['hex'], subscription_update=True
|
||||
)
|
||||
finally:
|
||||
self._header_processing_lock.release()
|
||||
|
||||
def update_accounts(self):
|
||||
return defer.DeferredList([
|
||||
async def update_accounts(self):
|
||||
return await asyncio.gather(*(
|
||||
self.update_account(a) for a in self.accounts
|
||||
])
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_account(self, account): # type: (baseaccount.BaseAccount) -> defer.Defferred
|
||||
async def update_account(self, account: baseaccount.BaseAccount):
|
||||
# Before subscribing, download history for any addresses that don't have any,
|
||||
# this avoids situation where we're getting status updates to addresses we know
|
||||
# need to update anyways. Continue to get history and create more addresses until
|
||||
# all missing addresses are created and history for them is fully restored.
|
||||
yield account.ensure_address_gap()
|
||||
addresses = yield account.get_addresses(used_times=0)
|
||||
await account.ensure_address_gap()
|
||||
addresses = await account.get_addresses(used_times=0)
|
||||
while addresses:
|
||||
yield defer.DeferredList([
|
||||
self.update_history(a) for a in addresses
|
||||
])
|
||||
addresses = yield account.ensure_address_gap()
|
||||
await asyncio.gather(*(self.update_history(a) for a in addresses))
|
||||
addresses = await account.ensure_address_gap()
|
||||
|
||||
# By this point all of the addresses should be restored and we
|
||||
# can now subscribe all of them to receive updates.
|
||||
all_addresses = yield account.get_addresses()
|
||||
yield defer.DeferredList(
|
||||
list(map(self.subscribe_history, all_addresses))
|
||||
)
|
||||
all_addresses = await account.get_addresses()
|
||||
await asyncio.gather(*(self.subscribe_history(a) for a in all_addresses))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _prefetch_history(self, remote_history, local_history):
|
||||
proofs, network_txs, deferreds = {}, {}, []
|
||||
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
||||
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
|
||||
continue
|
||||
if remote_height > 0:
|
||||
deferreds.append(
|
||||
self.network.get_merkle(hex_id, remote_height).addBoth(
|
||||
lambda result, txid: proofs.__setitem__(txid, result), hex_id)
|
||||
)
|
||||
deferreds.append(
|
||||
self.network.get_transaction(hex_id).addBoth(
|
||||
lambda result, txid: network_txs.__setitem__(txid, result), hex_id)
|
||||
)
|
||||
yield defer.DeferredList(deferreds)
|
||||
return proofs, network_txs
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_history(self, address):
|
||||
remote_history = yield self.network.get_history(address)
|
||||
local_history = yield self.get_local_history(address)
|
||||
proofs, network_txs = yield self._prefetch_history(remote_history, local_history)
|
||||
async def update_history(self, address):
|
||||
remote_history = await self.network.get_history(address)
|
||||
local_history = await self.get_local_history(address)
|
||||
|
||||
synced_history = StringIO()
|
||||
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
||||
|
@ -379,30 +335,29 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
|
||||
continue
|
||||
|
||||
lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())
|
||||
lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock())
|
||||
|
||||
yield lock.acquire()
|
||||
await lock.acquire()
|
||||
|
||||
try:
|
||||
|
||||
# see if we have a local copy of transaction, otherwise fetch it from server
|
||||
tx = yield self.db.get_transaction(txid=hex_id)
|
||||
tx = await self.db.get_transaction(txid=hex_id)
|
||||
save_tx = None
|
||||
if tx is None:
|
||||
_raw = network_txs[hex_id]
|
||||
_raw = await self.network.get_transaction(hex_id)
|
||||
tx = self.transaction_class(unhexlify(_raw))
|
||||
save_tx = 'insert'
|
||||
|
||||
tx.height = remote_height
|
||||
|
||||
if remote_height > 0 and (not tx.is_verified or tx.position == -1):
|
||||
self.validate_transaction_and_set_position(tx, remote_height, proofs[hex_id])
|
||||
await self.validate_transaction_and_set_position(tx, remote_height)
|
||||
if save_tx is None:
|
||||
save_tx = 'update'
|
||||
|
||||
yield self.db.save_transaction_io(
|
||||
save_tx, tx, address, self.address_to_hash160(address),
|
||||
synced_history.getvalue()
|
||||
await self.db.save_transaction_io(
|
||||
save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue()
|
||||
)
|
||||
|
||||
log.debug(
|
||||
|
@ -412,28 +367,22 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
|
||||
self._on_transaction_controller.add(TransactionEvent(address, tx))
|
||||
|
||||
except Exception:
|
||||
log.exception('Failed to synchronize transaction:')
|
||||
raise
|
||||
|
||||
finally:
|
||||
lock.release()
|
||||
if not lock.locked and hex_id in self._transaction_processing_locks:
|
||||
if not lock.locked() and hex_id in self._transaction_processing_locks:
|
||||
del self._transaction_processing_locks[hex_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def subscribe_history(self, address):
|
||||
remote_status = yield self.network.subscribe_address(address)
|
||||
local_status = yield self.get_local_status(address)
|
||||
async def subscribe_history(self, address):
|
||||
remote_status = await self.network.subscribe_address(address)
|
||||
local_status = await self.get_local_status(address)
|
||||
if local_status != remote_status:
|
||||
yield self.update_history(address)
|
||||
await self.update_history(address)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def receive_status(self, response):
|
||||
async def receive_status(self, response):
|
||||
address, remote_status = response
|
||||
local_status = yield self.get_local_status(address)
|
||||
local_status = await self.get_local_status(address)
|
||||
if local_status != remote_status:
|
||||
yield self.update_history(address)
|
||||
await self.update_history(address)
|
||||
|
||||
def broadcast(self, tx):
|
||||
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Type, MutableSequence, MutableMapping
|
||||
from twisted.internet import defer
|
||||
|
||||
from torba.baseledger import BaseLedger, LedgerRegistry
|
||||
from torba.wallet import Wallet, WalletStorage
|
||||
|
@ -41,11 +41,10 @@ class BaseWalletManager:
|
|||
self.wallets.append(wallet)
|
||||
return wallet
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_detailed_accounts(self, confirmations=6, show_seed=False):
|
||||
async def get_detailed_accounts(self, confirmations=6, show_seed=False):
|
||||
ledgers = {}
|
||||
for i, account in enumerate(self.accounts):
|
||||
details = yield account.get_details(confirmations=confirmations, show_seed=True)
|
||||
details = await account.get_details(confirmations=confirmations, show_seed=True)
|
||||
details['is_default_account'] = i == 0
|
||||
ledger_id = account.ledger.get_id()
|
||||
ledgers.setdefault(ledger_id, [])
|
||||
|
@ -68,16 +67,14 @@ class BaseWalletManager:
|
|||
for account in wallet.accounts:
|
||||
yield account
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start(self):
|
||||
async def start(self):
|
||||
self.running = True
|
||||
yield defer.DeferredList([
|
||||
await asyncio.gather(*(
|
||||
l.start() for l in self.ledgers.values()
|
||||
])
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def stop(self):
|
||||
yield defer.DeferredList([
|
||||
async def stop(self):
|
||||
await asyncio.gather(*(
|
||||
l.stop() for l in self.ledgers.values()
|
||||
])
|
||||
))
|
||||
self.running = False
|
||||
|
|
|
@ -1,145 +1,36 @@
|
|||
import json
|
||||
import socket
|
||||
import logging
|
||||
from itertools import cycle
|
||||
from twisted.internet import defer, reactor, protocol
|
||||
from twisted.application.internet import ClientService, CancelledError
|
||||
from twisted.internet.endpoints import clientFromString
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
from twisted.python import failure
|
||||
|
||||
from aiorpcx import ClientSession as BaseClientSession
|
||||
|
||||
from torba import __version__
|
||||
from torba.stream import StreamController
|
||||
from torba.constants import TIMEOUT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StratumClientProtocol(LineOnlyReceiver):
|
||||
delimiter = b'\n'
|
||||
MAX_LENGTH = 2000000
|
||||
class ClientSession(BaseClientSession):
|
||||
|
||||
def __init__(self):
|
||||
self.request_id = 0
|
||||
self.lookup_table = {}
|
||||
self.session = {}
|
||||
self.network = None
|
||||
|
||||
self.on_disconnected_controller = StreamController()
|
||||
self.on_disconnected = self.on_disconnected_controller.stream
|
||||
|
||||
def _get_id(self):
|
||||
self.request_id += 1
|
||||
return self.request_id
|
||||
|
||||
@property
|
||||
def _ip(self):
|
||||
return self.transport.getPeer().host
|
||||
|
||||
def get_session(self):
|
||||
return self.session
|
||||
|
||||
def connectionMade(self):
|
||||
try:
|
||||
self.transport.setTcpNoDelay(True)
|
||||
self.transport.setTcpKeepAlive(True)
|
||||
if hasattr(socket, "TCP_KEEPIDLE"):
|
||||
self.transport.socket.setsockopt(
|
||||
socket.SOL_TCP, socket.TCP_KEEPIDLE, 120
|
||||
# Seconds before sending keepalive probes
|
||||
)
|
||||
else:
|
||||
log.debug("TCP_KEEPIDLE not available")
|
||||
if hasattr(socket, "TCP_KEEPINTVL"):
|
||||
self.transport.socket.setsockopt(
|
||||
socket.SOL_TCP, socket.TCP_KEEPINTVL, 1
|
||||
# Interval in seconds between keepalive probes
|
||||
)
|
||||
else:
|
||||
log.debug("TCP_KEEPINTVL not available")
|
||||
if hasattr(socket, "TCP_KEEPCNT"):
|
||||
self.transport.socket.setsockopt(
|
||||
socket.SOL_TCP, socket.TCP_KEEPCNT, 5
|
||||
# Failed keepalive probles before declaring other end dead
|
||||
)
|
||||
else:
|
||||
log.debug("TCP_KEEPCNT not available")
|
||||
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
# Supported only by the socket transport,
|
||||
# but there's really no better place in code to trigger this.
|
||||
log.warning("Error setting up socket: %s", err)
|
||||
|
||||
def connectionLost(self, reason=None):
|
||||
self.connected = 0
|
||||
self.on_disconnected_controller.add(True)
|
||||
for deferred in self.lookup_table.values():
|
||||
if not deferred.called:
|
||||
deferred.errback(TimeoutError("Connection dropped."))
|
||||
|
||||
def lineReceived(self, line):
|
||||
log.debug('received: %s', line)
|
||||
|
||||
try:
|
||||
message = json.loads(line)
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError("Cannot decode message '{}'".format(line.strip()))
|
||||
|
||||
if message.get('id'):
|
||||
try:
|
||||
d = self.lookup_table.pop(message['id'])
|
||||
if message.get('error'):
|
||||
d.errback(RuntimeError(message['error']))
|
||||
else:
|
||||
d.callback(message.get('result'))
|
||||
except KeyError:
|
||||
raise LookupError(
|
||||
"Lookup for deferred object for message ID '{}' failed.".format(message['id']))
|
||||
elif message.get('method') in self.network.subscription_controllers:
|
||||
controller = self.network.subscription_controllers[message['method']]
|
||||
controller.add(message.get('params'))
|
||||
else:
|
||||
log.warning("Cannot handle message '%s'", line)
|
||||
|
||||
def rpc(self, method, *args):
|
||||
message_id = self._get_id()
|
||||
message = json.dumps({
|
||||
'id': message_id,
|
||||
'method': method,
|
||||
'params': args
|
||||
})
|
||||
log.debug('sent: %s', message)
|
||||
self.sendLine(message.encode('latin-1'))
|
||||
d = self.lookup_table[message_id] = defer.Deferred()
|
||||
d.addTimeout(
|
||||
TIMEOUT, reactor, onTimeoutCancel=lambda *_: failure.Failure(TimeoutError(
|
||||
"Timeout: Stratum request for '%s' took more than %s seconds" % (method, TIMEOUT)))
|
||||
)
|
||||
return d
|
||||
|
||||
|
||||
class StratumClientFactory(protocol.ClientFactory):
|
||||
|
||||
protocol = StratumClientProtocol
|
||||
|
||||
def __init__(self, network):
|
||||
def __init__(self, *args, network, **kwargs):
|
||||
self.network = network
|
||||
self.client = None
|
||||
super().__init__(*args, **kwargs)
|
||||
self._on_disconnect_controller = StreamController()
|
||||
self.on_disconnected = self._on_disconnect_controller.stream
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
client = self.protocol()
|
||||
client.factory = self
|
||||
client.network = self.network
|
||||
self.client = client
|
||||
return client
|
||||
async def handle_request(self, request):
|
||||
controller = self.network.subscription_controllers[request.method]
|
||||
controller.add(request.args)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
super().connection_lost(exc)
|
||||
self._on_disconnect_controller.add(True)
|
||||
|
||||
|
||||
class BaseNetwork:
|
||||
|
||||
def __init__(self, ledger):
|
||||
self.config = ledger.config
|
||||
self.client = None
|
||||
self.service = None
|
||||
self.client: ClientSession = None
|
||||
self.running = False
|
||||
|
||||
self._on_connected_controller = StreamController()
|
||||
|
@ -156,48 +47,35 @@ class BaseNetwork:
|
|||
'blockchain.address.subscribe': self._on_status_controller,
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start(self):
|
||||
async def start(self):
|
||||
self.running = True
|
||||
for server in cycle(self.config['default_servers']):
|
||||
connection_string = 'tcp:{}:{}'.format(*server)
|
||||
endpoint = clientFromString(reactor, connection_string)
|
||||
log.debug("Attempting connection to SPV wallet server: %s", connection_string)
|
||||
self.service = ClientService(endpoint, StratumClientFactory(self))
|
||||
self.service.startService()
|
||||
self.client = ClientSession(*server, network=self)
|
||||
try:
|
||||
self.client = yield self.service.whenConnected(failAfterFailures=2)
|
||||
yield self.ensure_server_version()
|
||||
log.info("Successfully connected to SPV wallet server: %s", connection_string)
|
||||
await self.client.create_connection()
|
||||
await self.ensure_server_version()
|
||||
log.info("Successfully connected to SPV wallet server: %s", )
|
||||
self._on_connected_controller.add(True)
|
||||
yield self.client.on_disconnected.first
|
||||
except CancelledError:
|
||||
return
|
||||
await self.client.on_disconnected.first
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log.exception("Connecting to %s raised an exception:", connection_string)
|
||||
finally:
|
||||
self.client = None
|
||||
if self.service is not None:
|
||||
self.service.stopService()
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
if self.service is not None:
|
||||
self.service.stopService()
|
||||
if self.is_connected:
|
||||
return self.client.on_disconnected.first
|
||||
else:
|
||||
return defer.succeed(True)
|
||||
await self.client.close()
|
||||
await self.client.on_disconnected.first
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
return self.client is not None and self.client.connected
|
||||
return self.client is not None and not self.client.is_closing()
|
||||
|
||||
def rpc(self, list_or_method, *args):
|
||||
if self.is_connected:
|
||||
return self.client.rpc(list_or_method, *args)
|
||||
return self.client.send_request(list_or_method, args)
|
||||
else:
|
||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||
|
||||
|
|
|
@ -3,8 +3,6 @@ import typing
|
|||
from typing import List, Iterable, Optional
|
||||
from binascii import hexlify
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from torba.basescript import BaseInputScript, BaseOutputScript
|
||||
from torba.baseaccount import BaseAccount
|
||||
from torba.constants import COIN, NULL_HASH32
|
||||
|
@ -426,8 +424,7 @@ class BaseTransaction:
|
|||
return ledger
|
||||
|
||||
@classmethod
|
||||
@defer.inlineCallbacks
|
||||
def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
|
||||
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
|
||||
funding_accounts: Iterable[BaseAccount], change_account: BaseAccount):
|
||||
""" Find optimal set of inputs when only outputs are provided; add change
|
||||
outputs if only inputs are provided or if inputs are greater than outputs. """
|
||||
|
@ -450,7 +447,7 @@ class BaseTransaction:
|
|||
|
||||
if payment < cost:
|
||||
deficit = cost - payment
|
||||
spendables = yield ledger.get_spendable_utxos(deficit, funding_accounts)
|
||||
spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
|
||||
if not spendables:
|
||||
raise ValueError('Not enough funds to cover this transaction.')
|
||||
payment += sum(s.effective_amount for s in spendables)
|
||||
|
@ -463,28 +460,27 @@ class BaseTransaction:
|
|||
)
|
||||
change = payment - cost
|
||||
if change > cost_of_change:
|
||||
change_address = yield change_account.change.get_or_create_usable_address()
|
||||
change_address = await change_account.change.get_or_create_usable_address()
|
||||
change_hash160 = change_account.ledger.address_to_hash160(change_address)
|
||||
change_amount = change - cost_of_change
|
||||
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
|
||||
change_output.is_change = True
|
||||
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
|
||||
|
||||
yield tx.sign(funding_accounts)
|
||||
await tx.sign(funding_accounts)
|
||||
|
||||
except Exception as e:
|
||||
log.exception('Failed to synchronize transaction:')
|
||||
yield ledger.release_outputs(tx.outputs)
|
||||
await ledger.release_outputs(tx.outputs)
|
||||
raise e
|
||||
|
||||
defer.returnValue(tx)
|
||||
return tx
|
||||
|
||||
@staticmethod
|
||||
def signature_hash_type(hash_type):
|
||||
return hash_type
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred:
|
||||
async def sign(self, funding_accounts: Iterable[BaseAccount]):
|
||||
ledger = self.ensure_all_have_same_ledger(funding_accounts)
|
||||
for i, txi in enumerate(self._inputs):
|
||||
assert txi.script is not None
|
||||
|
@ -492,7 +488,7 @@ class BaseTransaction:
|
|||
txo_script = txi.txo_ref.txo.script
|
||||
if txo_script.is_pay_pubkey_hash:
|
||||
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
|
||||
private_key = yield ledger.get_private_key_for_address(address)
|
||||
private_key = await ledger.get_private_key_for_address(address)
|
||||
tx = self._serialize_for_signature(i)
|
||||
txi.script.values['signature'] = \
|
||||
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
import asyncio
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
|
||||
class BroadcastSubscription:
|
||||
|
@ -31,11 +29,13 @@ class BroadcastSubscription:
|
|||
|
||||
def _add(self, data):
|
||||
if self.can_fire and self._on_data is not None:
|
||||
self._on_data(data)
|
||||
maybe_coroutine = self._on_data(data)
|
||||
if asyncio.iscoroutine(maybe_coroutine):
|
||||
asyncio.ensure_future(maybe_coroutine)
|
||||
|
||||
def _add_error(self, error, traceback):
|
||||
def _add_error(self, exception):
|
||||
if self.can_fire and self._on_error is not None:
|
||||
self._on_error(error, traceback)
|
||||
self._on_error(exception)
|
||||
|
||||
def _close(self):
|
||||
if self.can_fire and self._on_done is not None:
|
||||
|
@ -66,9 +66,9 @@ class StreamController:
|
|||
for subscription in self._iterate_subscriptions:
|
||||
subscription._add(event)
|
||||
|
||||
def add_error(self, error, traceback):
|
||||
def add_error(self, exception):
|
||||
for subscription in self._iterate_subscriptions:
|
||||
subscription._add_error(error, traceback)
|
||||
subscription._add_error(exception)
|
||||
|
||||
def close(self):
|
||||
for subscription in self._iterate_subscriptions:
|
||||
|
@ -108,38 +108,35 @@ class Stream:
|
|||
def listen(self, on_data, on_error=None, on_done=None):
|
||||
return self._controller._listen(on_data, on_error, on_done)
|
||||
|
||||
def deferred_where(self, condition):
|
||||
deferred = Deferred()
|
||||
def where(self, condition) -> asyncio.Future:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
|
||||
def where_test(value):
|
||||
if condition(value):
|
||||
self._cancel_and_callback(subscription, deferred, value)
|
||||
self._cancel_and_callback(subscription, future, value)
|
||||
|
||||
subscription = self.listen(
|
||||
where_test,
|
||||
lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback)
|
||||
lambda exception: self._cancel_and_error(subscription, future, exception)
|
||||
)
|
||||
|
||||
return deferred
|
||||
|
||||
def where(self, condition):
|
||||
return self.deferred_where(condition).asFuture(asyncio.get_event_loop())
|
||||
return future
|
||||
|
||||
@property
|
||||
def first(self):
|
||||
deferred = Deferred()
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
subscription = self.listen(
|
||||
lambda value: self._cancel_and_callback(subscription, deferred, value),
|
||||
lambda error, traceback: self._cancel_and_error(subscription, deferred, error, traceback)
|
||||
lambda value: self._cancel_and_callback(subscription, future, value),
|
||||
lambda exception: self._cancel_and_error(subscription, future, exception)
|
||||
)
|
||||
return deferred
|
||||
return future
|
||||
|
||||
@staticmethod
|
||||
def _cancel_and_callback(subscription, deferred, value):
|
||||
def _cancel_and_callback(subscription: BroadcastSubscription, future: asyncio.Future, value):
|
||||
subscription.cancel()
|
||||
deferred.callback(value)
|
||||
future.set_result(value)
|
||||
|
||||
@staticmethod
|
||||
def _cancel_and_error(subscription, deferred, error, traceback):
|
||||
def _cancel_and_error(subscription: BroadcastSubscription, future: asyncio.Future, exception):
|
||||
subscription.cancel()
|
||||
deferred.errback(Failure(error, exc_tb=traceback))
|
||||
future.set_exception(exception)
|
||||
|
|
Loading…
Reference in a new issue