fixing unit tests

This commit is contained in:
Lex Berezhny 2018-06-26 18:31:42 -04:00
parent 9a467f8840
commit 1dfc18683d
4 changed files with 47 additions and 24 deletions

View file

@ -41,13 +41,13 @@ class TestKeyChain(unittest.TestCase):
# case #2: only one new addressed needed
keys = yield account.receiving.get_addresses(None, True)
yield self.ledger.db.set_address_history(keys[19]['address'], b'a:1:')
yield self.ledger.db.set_address_history(keys[19]['address'], 'a:1:')
new_keys = yield account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 1)
# case #3: 20 addresses needed
keys = yield account.receiving.get_addresses(None, True)
yield self.ledger.db.set_address_history(keys[0]['address'], b'a:1:')
yield self.ledger.db.set_address_history(keys[0]['address'], 'a:1:')
new_keys = yield account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 20)

View file

@ -5,7 +5,7 @@ from twisted.internet import defer
from torba.coin.bitcoinsegwit import MainNetLedger
from .test_transaction import get_transaction
from .test_transaction import get_transaction, get_output
if six.PY3:
buffer = memoryview
@ -25,15 +25,30 @@ class MockNetwork:
self.address = address
return defer.succeed(self.history)
def get_merkle(self, txid, height):
return {'merkle': ['abcd01'], 'pos': 1}
def get_transaction(self, tx_hash):
self.get_transaction_called.append(tx_hash)
return defer.succeed(self.transaction[tx_hash])
return defer.succeed(self.transaction[tx_hash.decode()])
class MockHeaders:
def __init__(self, ledger):
self.ledger = ledger
self.height = 1
def __len__(self):
return self.height
def __getitem__(self, height):
return {'merkle_root': 'abcd04'}
class TestSynchronization(unittest.TestCase):
def setUp(self):
self.ledger = MainNetLedger(db=MainNetLedger.database_class(':memory:'))
self.ledger = MainNetLedger(db=MainNetLedger.database_class(':memory:'), headers_class=MockHeaders)
return self.ledger.db.start()
@defer.inlineCallbacks
@ -43,21 +58,22 @@ class TestSynchronization(unittest.TestCase):
address_details = yield self.ledger.db.get_address(address)
self.assertEqual(address_details['history'], None)
self.ledger.headers.height = 3
self.ledger.network = MockNetwork([
{'tx_hash': b'abc', 'height': 1},
{'tx_hash': b'def', 'height': 2},
{'tx_hash': b'ghi', 'height': 3},
{'tx_hash': b'abcd01', 'height': 1},
{'tx_hash': b'abcd02', 'height': 2},
{'tx_hash': b'abcd03', 'height': 3},
], {
b'abc': hexlify(get_transaction().raw),
b'def': hexlify(get_transaction().raw),
b'ghi': hexlify(get_transaction().raw),
'abcd01': hexlify(get_transaction(get_output(1)).raw),
'abcd02': hexlify(get_transaction(get_output(2)).raw),
'abcd03': hexlify(get_transaction(get_output(3)).raw),
})
yield self.ledger.update_history(address)
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, [b'abc', b'def', b'ghi'])
self.assertEqual(self.ledger.network.get_transaction_called, [b'abcd01', b'abcd02', b'abcd03'])
address_details = yield self.ledger.db.get_address(address)
self.assertEqual(address_details['history'], buffer(b'abc:1:def:2:ghi:3:'))
self.assertEqual(address_details['history'], 'abcd01:1:abcd02:2:abcd03:3:')
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
@ -65,12 +81,12 @@ class TestSynchronization(unittest.TestCase):
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, [])
self.ledger.network.history.append({'tx_hash': b'jkl', 'height': 4})
self.ledger.network.transaction[b'jkl'] = hexlify(get_transaction().raw)
self.ledger.network.history.append({'tx_hash': b'abcd04', 'height': 4})
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
yield self.ledger.update_history(address)
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, [b'jkl'])
self.assertEqual(self.ledger.network.get_transaction_called, [b'abcd04'])
address_details = yield self.ledger.db.get_address(address)
self.assertEqual(address_details['history'], buffer(b'abc:1:def:2:ghi:3:jkl:4:'))
self.assertEqual(address_details['history'], 'abcd01:1:abcd02:2:abcd03:3:abcd04:4:')

View file

@ -208,10 +208,7 @@ class BaseDatabase(SQLiteMixin):
'txoid': txoid[0],
}))
t.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, sqlite3.Binary(address))
)
self._set_address_history(t, address, history)
return self.db.runInteraction(_steps)
@ -318,6 +315,16 @@ class BaseDatabase(SQLiteMixin):
sql.append('LIMIT {}'.format(limit))
return ' '.join(sql), params
@staticmethod
def _set_address_history(t, address, history):
t.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, sqlite3.Binary(address))
)
def set_address_history(self, address, history):
return self.db.runInteraction(lambda t: self._set_address_history(t, address, history))
def get_unused_addresses(self, account, chain):
# type: (torba.baseaccount.BaseAccount, int) -> defer.Deferred[List[str]]
return self.query_one_value_list(*self._used_address_sql(

View file

@ -61,7 +61,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
default_fee_per_byte = 10
def __init__(self, config=None, db=None, network=None):
def __init__(self, config=None, db=None, network=None, headers_class=None):
self.config = config or {}
self.db = db or self.database_class(
os.path.join(self.path, "blockchain.db")
@ -70,7 +70,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self.network.on_header.listen(self.process_header)
self.network.on_status.listen(self.process_status)
self.accounts = set()
self.headers = self.headers_class(self)
self.headers = (headers_class or self.headers_class)(self)
self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte)
self._on_transaction_controller = StreamController()
@ -257,7 +257,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
synced_history.append((hex_id, remote_height))
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
if i < len(local_history) and local_history[i] == (hex_id.decode(), remote_height):
continue
lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())