fixing unit tests
This commit is contained in:
parent
9a467f8840
commit
1dfc18683d
4 changed files with 47 additions and 24 deletions
|
@ -41,13 +41,13 @@ class TestKeyChain(unittest.TestCase):
|
|||
|
||||
# case #2: only one new addressed needed
|
||||
keys = yield account.receiving.get_addresses(None, True)
|
||||
yield self.ledger.db.set_address_history(keys[19]['address'], b'a:1:')
|
||||
yield self.ledger.db.set_address_history(keys[19]['address'], 'a:1:')
|
||||
new_keys = yield account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(new_keys), 1)
|
||||
|
||||
# case #3: 20 addresses needed
|
||||
keys = yield account.receiving.get_addresses(None, True)
|
||||
yield self.ledger.db.set_address_history(keys[0]['address'], b'a:1:')
|
||||
yield self.ledger.db.set_address_history(keys[0]['address'], 'a:1:')
|
||||
new_keys = yield account.receiving.ensure_address_gap()
|
||||
self.assertEqual(len(new_keys), 20)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from twisted.internet import defer
|
|||
|
||||
from torba.coin.bitcoinsegwit import MainNetLedger
|
||||
|
||||
from .test_transaction import get_transaction
|
||||
from .test_transaction import get_transaction, get_output
|
||||
|
||||
if six.PY3:
|
||||
buffer = memoryview
|
||||
|
@ -25,15 +25,30 @@ class MockNetwork:
|
|||
self.address = address
|
||||
return defer.succeed(self.history)
|
||||
|
||||
def get_merkle(self, txid, height):
|
||||
return {'merkle': ['abcd01'], 'pos': 1}
|
||||
|
||||
def get_transaction(self, tx_hash):
|
||||
self.get_transaction_called.append(tx_hash)
|
||||
return defer.succeed(self.transaction[tx_hash])
|
||||
return defer.succeed(self.transaction[tx_hash.decode()])
|
||||
|
||||
|
||||
class MockHeaders:
|
||||
def __init__(self, ledger):
|
||||
self.ledger = ledger
|
||||
self.height = 1
|
||||
|
||||
def __len__(self):
|
||||
return self.height
|
||||
|
||||
def __getitem__(self, height):
|
||||
return {'merkle_root': 'abcd04'}
|
||||
|
||||
|
||||
class TestSynchronization(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.ledger = MainNetLedger(db=MainNetLedger.database_class(':memory:'))
|
||||
self.ledger = MainNetLedger(db=MainNetLedger.database_class(':memory:'), headers_class=MockHeaders)
|
||||
return self.ledger.db.start()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -43,21 +58,22 @@ class TestSynchronization(unittest.TestCase):
|
|||
address_details = yield self.ledger.db.get_address(address)
|
||||
self.assertEqual(address_details['history'], None)
|
||||
|
||||
self.ledger.headers.height = 3
|
||||
self.ledger.network = MockNetwork([
|
||||
{'tx_hash': b'abc', 'height': 1},
|
||||
{'tx_hash': b'def', 'height': 2},
|
||||
{'tx_hash': b'ghi', 'height': 3},
|
||||
{'tx_hash': b'abcd01', 'height': 1},
|
||||
{'tx_hash': b'abcd02', 'height': 2},
|
||||
{'tx_hash': b'abcd03', 'height': 3},
|
||||
], {
|
||||
b'abc': hexlify(get_transaction().raw),
|
||||
b'def': hexlify(get_transaction().raw),
|
||||
b'ghi': hexlify(get_transaction().raw),
|
||||
'abcd01': hexlify(get_transaction(get_output(1)).raw),
|
||||
'abcd02': hexlify(get_transaction(get_output(2)).raw),
|
||||
'abcd03': hexlify(get_transaction(get_output(3)).raw),
|
||||
})
|
||||
yield self.ledger.update_history(address)
|
||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, [b'abc', b'def', b'ghi'])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, [b'abcd01', b'abcd02', b'abcd03'])
|
||||
|
||||
address_details = yield self.ledger.db.get_address(address)
|
||||
self.assertEqual(address_details['history'], buffer(b'abc:1:def:2:ghi:3:'))
|
||||
self.assertEqual(address_details['history'], 'abcd01:1:abcd02:2:abcd03:3:')
|
||||
|
||||
self.ledger.network.get_history_called = []
|
||||
self.ledger.network.get_transaction_called = []
|
||||
|
@ -65,12 +81,12 @@ class TestSynchronization(unittest.TestCase):
|
|||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, [])
|
||||
|
||||
self.ledger.network.history.append({'tx_hash': b'jkl', 'height': 4})
|
||||
self.ledger.network.transaction[b'jkl'] = hexlify(get_transaction().raw)
|
||||
self.ledger.network.history.append({'tx_hash': b'abcd04', 'height': 4})
|
||||
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
|
||||
self.ledger.network.get_history_called = []
|
||||
self.ledger.network.get_transaction_called = []
|
||||
yield self.ledger.update_history(address)
|
||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, [b'jkl'])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, [b'abcd04'])
|
||||
address_details = yield self.ledger.db.get_address(address)
|
||||
self.assertEqual(address_details['history'], buffer(b'abc:1:def:2:ghi:3:jkl:4:'))
|
||||
self.assertEqual(address_details['history'], 'abcd01:1:abcd02:2:abcd03:3:abcd04:4:')
|
||||
|
|
|
@ -208,10 +208,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
'txoid': txoid[0],
|
||||
}))
|
||||
|
||||
t.execute(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':')//2, sqlite3.Binary(address))
|
||||
)
|
||||
self._set_address_history(t, address, history)
|
||||
|
||||
return self.db.runInteraction(_steps)
|
||||
|
||||
|
@ -318,6 +315,16 @@ class BaseDatabase(SQLiteMixin):
|
|||
sql.append('LIMIT {}'.format(limit))
|
||||
return ' '.join(sql), params
|
||||
|
||||
@staticmethod
|
||||
def _set_address_history(t, address, history):
|
||||
t.execute(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':')//2, sqlite3.Binary(address))
|
||||
)
|
||||
|
||||
def set_address_history(self, address, history):
|
||||
return self.db.runInteraction(lambda t: self._set_address_history(t, address, history))
|
||||
|
||||
def get_unused_addresses(self, account, chain):
|
||||
# type: (torba.baseaccount.BaseAccount, int) -> defer.Deferred[List[str]]
|
||||
return self.query_one_value_list(*self._used_address_sql(
|
||||
|
|
|
@ -61,7 +61,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
|
|||
|
||||
default_fee_per_byte = 10
|
||||
|
||||
def __init__(self, config=None, db=None, network=None):
|
||||
def __init__(self, config=None, db=None, network=None, headers_class=None):
|
||||
self.config = config or {}
|
||||
self.db = db or self.database_class(
|
||||
os.path.join(self.path, "blockchain.db")
|
||||
|
@ -70,7 +70,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
|
|||
self.network.on_header.listen(self.process_header)
|
||||
self.network.on_status.listen(self.process_status)
|
||||
self.accounts = set()
|
||||
self.headers = self.headers_class(self)
|
||||
self.headers = (headers_class or self.headers_class)(self)
|
||||
self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte)
|
||||
|
||||
self._on_transaction_controller = StreamController()
|
||||
|
@ -257,7 +257,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
|
|||
|
||||
synced_history.append((hex_id, remote_height))
|
||||
|
||||
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
|
||||
if i < len(local_history) and local_history[i] == (hex_id.decode(), remote_height):
|
||||
continue
|
||||
|
||||
lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())
|
||||
|
|
Loading…
Reference in a new issue