From 1dfc18683d05f5feca032b769fa0d29accdbd7e4 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Tue, 26 Jun 2018 18:31:42 -0400 Subject: [PATCH] fixing unit tests --- tests/unit/test_account.py | 4 ++-- tests/unit/test_ledger.py | 46 +++++++++++++++++++++++++------------- torba/basedatabase.py | 15 +++++++++---- torba/baseledger.py | 6 ++--- 4 files changed, 47 insertions(+), 24 deletions(-) diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 95d3a68ab..7c492474f 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -41,13 +41,13 @@ class TestKeyChain(unittest.TestCase): # case #2: only one new addressed needed keys = yield account.receiving.get_addresses(None, True) - yield self.ledger.db.set_address_history(keys[19]['address'], b'a:1:') + yield self.ledger.db.set_address_history(keys[19]['address'], 'a:1:') new_keys = yield account.receiving.ensure_address_gap() self.assertEqual(len(new_keys), 1) # case #3: 20 addresses needed keys = yield account.receiving.get_addresses(None, True) - yield self.ledger.db.set_address_history(keys[0]['address'], b'a:1:') + yield self.ledger.db.set_address_history(keys[0]['address'], 'a:1:') new_keys = yield account.receiving.ensure_address_gap() self.assertEqual(len(new_keys), 20) diff --git a/tests/unit/test_ledger.py b/tests/unit/test_ledger.py index b36dd7eb5..4e26d6c2d 100644 --- a/tests/unit/test_ledger.py +++ b/tests/unit/test_ledger.py @@ -5,7 +5,7 @@ from twisted.internet import defer from torba.coin.bitcoinsegwit import MainNetLedger -from .test_transaction import get_transaction +from .test_transaction import get_transaction, get_output if six.PY3: buffer = memoryview @@ -25,15 +25,30 @@ class MockNetwork: self.address = address return defer.succeed(self.history) + def get_merkle(self, txid, height): + return {'merkle': ['abcd01'], 'pos': 1} + def get_transaction(self, tx_hash): self.get_transaction_called.append(tx_hash) - return defer.succeed(self.transaction[tx_hash]) + return defer.succeed(self.transaction[tx_hash.decode()]) + + +class MockHeaders: + def __init__(self, ledger): + self.ledger = ledger + self.height = 1 + + def __len__(self): + return self.height + + def __getitem__(self, height): + return {'merkle_root': 'abcd04'} class TestSynchronization(unittest.TestCase): def setUp(self): - self.ledger = MainNetLedger(db=MainNetLedger.database_class(':memory:')) + self.ledger = MainNetLedger(db=MainNetLedger.database_class(':memory:'), headers_class=MockHeaders) return self.ledger.db.start() @defer.inlineCallbacks @@ -43,21 +58,22 @@ class TestSynchronization(unittest.TestCase): address_details = yield self.ledger.db.get_address(address) self.assertEqual(address_details['history'], None) + self.ledger.headers.height = 3 self.ledger.network = MockNetwork([ - {'tx_hash': b'abc', 'height': 1}, - {'tx_hash': b'def', 'height': 2}, - {'tx_hash': b'ghi', 'height': 3}, + {'tx_hash': b'abcd01', 'height': 1}, + {'tx_hash': b'abcd02', 'height': 2}, + {'tx_hash': b'abcd03', 'height': 3}, ], { - b'abc': hexlify(get_transaction().raw), - b'def': hexlify(get_transaction().raw), - b'ghi': hexlify(get_transaction().raw), + 'abcd01': hexlify(get_transaction(get_output(1)).raw), + 'abcd02': hexlify(get_transaction(get_output(2)).raw), + 'abcd03': hexlify(get_transaction(get_output(3)).raw), }) yield self.ledger.update_history(address) self.assertEqual(self.ledger.network.get_history_called, [address]) - self.assertEqual(self.ledger.network.get_transaction_called, [b'abc', b'def', b'ghi']) + self.assertEqual(self.ledger.network.get_transaction_called, [b'abcd01', b'abcd02', b'abcd03']) address_details = yield self.ledger.db.get_address(address) - self.assertEqual(address_details['history'], buffer(b'abc:1:def:2:ghi:3:')) + self.assertEqual(address_details['history'], 'abcd01:1:abcd02:2:abcd03:3:') self.ledger.network.get_history_called = [] self.ledger.network.get_transaction_called = [] @@ -65,12 +81,12 @@ class TestSynchronization(unittest.TestCase): self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_transaction_called, []) - self.ledger.network.history.append({'tx_hash': b'jkl', 'height': 4}) - self.ledger.network.transaction[b'jkl'] = hexlify(get_transaction().raw) + self.ledger.network.history.append({'tx_hash': b'abcd04', 'height': 4}) + self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw) self.ledger.network.get_history_called = [] self.ledger.network.get_transaction_called = [] yield self.ledger.update_history(address) self.assertEqual(self.ledger.network.get_history_called, [address]) - self.assertEqual(self.ledger.network.get_transaction_called, [b'jkl']) + self.assertEqual(self.ledger.network.get_transaction_called, [b'abcd04']) address_details = yield self.ledger.db.get_address(address) - self.assertEqual(address_details['history'], buffer(b'abc:1:def:2:ghi:3:jkl:4:')) + self.assertEqual(address_details['history'], 'abcd01:1:abcd02:2:abcd03:3:abcd04:4:') diff --git a/torba/basedatabase.py b/torba/basedatabase.py index 6923c6e5e..15092d209 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -208,10 +208,7 @@ class BaseDatabase(SQLiteMixin): 'txoid': txoid[0], })) - t.execute( - "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", - (history, history.count(':')//2, sqlite3.Binary(address)) - ) + self._set_address_history(t, address, history) return self.db.runInteraction(_steps) @@ -318,6 +315,16 @@ class BaseDatabase(SQLiteMixin): sql.append('LIMIT {}'.format(limit)) return ' '.join(sql), params + @staticmethod + def _set_address_history(t, address, history): + t.execute( + "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", + (history, history.count(':')//2, sqlite3.Binary(address)) + ) + + def set_address_history(self, address, history): + return self.db.runInteraction(lambda t: self._set_address_history(t, address, history)) + def get_unused_addresses(self, account, chain): # type: (torba.baseaccount.BaseAccount, int) -> defer.Deferred[List[str]] return self.query_one_value_list(*self._used_address_sql( diff --git a/torba/baseledger.py b/torba/baseledger.py index 745fa6fcb..2ff360b9e 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -61,7 +61,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): default_fee_per_byte = 10 - def __init__(self, config=None, db=None, network=None): + def __init__(self, config=None, db=None, network=None, headers_class=None): self.config = config or {} self.db = db or self.database_class( os.path.join(self.path, "blockchain.db") @@ -70,7 +70,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): self.network.on_header.listen(self.process_header) self.network.on_status.listen(self.process_status) self.accounts = set() - self.headers = self.headers_class(self) + self.headers = (headers_class or self.headers_class)(self) self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte) self._on_transaction_controller = StreamController() @@ -257,7 +257,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): synced_history.append((hex_id, remote_height)) - if i < len(local_history) and local_history[i] == (hex_id, remote_height): + if i < len(local_history) and local_history[i] == (hex_id.decode(), remote_height): continue lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())