From d2cd0ece5ffb7aa06028b4c2365f12519fc6a73e Mon Sep 17 00:00:00 2001
From: Victor Shyba
Date: Fri, 30 Aug 2019 19:53:51 -0300
Subject: [PATCH] restore multi server, improve sync concurrency

---
 .../client_tests/integration/test_network.py |  4 +-
 torba/torba/client/basedatabase.py           | 77 +++++++++++--------
 torba/torba/client/baseledger.py             | 46 ++++++-----
 torba/torba/client/basenetwork.py            | 35 +++++----
 4 files changed, 93 insertions(+), 69 deletions(-)

diff --git a/torba/tests/client_tests/integration/test_network.py b/torba/tests/client_tests/integration/test_network.py
index 75a28a28a..ee7ed4882 100644
--- a/torba/tests/client_tests/integration/test_network.py
+++ b/torba/tests/client_tests/integration/test_network.py
@@ -33,7 +33,7 @@ class ReconnectTests(IntegrationTestCase):
         for session in self.ledger.network.session_pool.sessions:
             session.trigger_urgent_reconnect.set()
         await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1)
-        self.assertEqual(2, len(self.ledger.network.session_pool.available_sessions))
+        self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions)))
         self.assertTrue(self.ledger.network.is_connected)
         switch_event = self.ledger.network.on_connected.first
         await node2.stop(True)
@@ -126,4 +126,4 @@ class ServerPickingTestCase(AsyncioTestCase):
         self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
         # ensure we are connected to all of them after a while
         await asyncio.sleep(1)
-        self.assertEqual(len(network.session_pool.available_sessions), 3)
+        self.assertEqual(len(list(network.session_pool.available_sessions)), 3)
diff --git a/torba/torba/client/basedatabase.py b/torba/torba/client/basedatabase.py
index d0f37ccaf..a9b802dbf 100644
--- a/torba/torba/client/basedatabase.py
+++ b/torba/torba/client/basedatabase.py
@@ -227,16 +227,19 @@ class SQLiteMixin:
         await self.db.close()
 
     @staticmethod
-    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
+    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
+                    replace: bool = False) -> Tuple[str, List]:
         columns, values = [], []
         for column, value in data.items():
             columns.append(column)
             values.append(value)
-        or_ignore = ""
+        policy = ""
         if ignore_duplicate:
-            or_ignore = " OR IGNORE"
+            policy = " OR IGNORE"
+        if replace:
+            policy = " OR REPLACE"
         sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
-            or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
+            policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
         )
         return sql, values
 
@@ -348,35 +351,47 @@ class BaseDatabase(SQLiteMixin):
             'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
         }, 'txid = ?', (tx.id,)))
 
+    def _transaction_io(self, conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
+        conn.execute(*self._insert_sql('tx', {
+            'txid': tx.id,
+            'raw': sqlite3.Binary(tx.raw),
+            'height': tx.height,
+            'position': tx.position,
+            'is_verified': tx.is_verified
+        }, replace=True))
+
+        for txo in tx.outputs:
+            if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
+                conn.execute(*self._insert_sql(
+                    "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
+                ))
+            elif txo.script.is_pay_script_hash:
+                # TODO: implement script hash payments
+                log.warning('Database.save_transaction_io: pay script hash is not implemented!')
+
+        for txi in tx.inputs:
+            if txi.txo_ref.txo is not None:
+                txo = txi.txo_ref.txo
+                if txo.get_address(self.ledger) == address:
+                    conn.execute(*self._insert_sql("txi", {
+                        'txid': tx.id,
+                        'txoid': txo.id,
+                        'address': address,
+                    }, ignore_duplicate=True))
+
+        conn.execute(
+            "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
+            (history, history.count(':') // 2, address)
+        )
+
     def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
+        return self.db.run(self._transaction_io, tx, address, txhash, history)
 
-        def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
-
-            for txo in tx.outputs:
-                if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
-                    conn.execute(*self._insert_sql(
-                        "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
-                    ))
-                elif txo.script.is_pay_script_hash:
-                    # TODO: implement script hash payments
-                    log.warning('Database.save_transaction_io: pay script hash is not implemented!')
-
-            for txi in tx.inputs:
-                if txi.txo_ref.txo is not None:
-                    txo = txi.txo_ref.txo
-                    if txo.get_address(self.ledger) == address:
-                        conn.execute(*self._insert_sql("txi", {
-                            'txid': tx.id,
-                            'txoid': txo.id,
-                            'address': address,
-                        }, ignore_duplicate=True))
-
-            conn.execute(
-                "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
-                (history, history.count(':')//2, address)
-            )
-
-        return self.db.run(_transaction, tx, address, txhash, history)
+    def save_transaction_io_batch(self, txs: Iterable[BaseTransaction], address, txhash, history):
+        def __many(conn):
+            for tx in txs:
+                self._transaction_io(conn, tx, address, txhash, history)
+        return self.db.run(__many)
 
     async def reserve_outputs(self, txos, is_reserved=True):
         txoids = ((is_reserved, txo.id) for txo in txos)
diff --git a/torba/torba/client/baseledger.py b/torba/torba/client/baseledger.py
index 5c47cb79f..c31612aee 100644
--- a/torba/torba/client/baseledger.py
+++ b/torba/torba/client/baseledger.py
@@ -10,6 +10,7 @@ from operator import itemgetter
 from collections import namedtuple
 
 import pylru
+from torba.client.basetransaction import BaseTransaction
 from torba.tasks import TaskGroup
 from torba.client import baseaccount, basenetwork, basetransaction
 from torba.client.basedatabase import BaseDatabase
@@ -251,9 +252,10 @@ class BaseLedger(metaclass=LedgerRegistry):
         self.constraint_account_or_all(constraints)
         return self.db.get_transaction_count(**constraints)
 
-    async def get_local_status_and_history(self, address):
-        address_details = await self.db.get_address(address=address)
-        history = address_details['history'] or ''
+    async def get_local_status_and_history(self, address, history=None):
+        if not history:
+            address_details = await self.db.get_address(address=address)
+            history = address_details['history'] or ''
         parts = history.split(':')[:-1]
         return (
             hexlify(sha256(history.encode())).decode() if history else None,
@@ -420,17 +422,23 @@ class BaseLedger(metaclass=LedgerRegistry):
             return True
 
         remote_history = await self.network.retriable_call(self.network.get_history, address)
+        remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
+        we_need = set(remote_history) - set(local_history)
+        if not we_need:
+            return True
 
-        cache_tasks = []
+        cache_tasks: List[asyncio.Future[BaseTransaction]] = []
         synced_history = StringIO()
-        for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
+        for i, (txid, remote_height) in enumerate(remote_history):
             if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
                 synced_history.write(f'{txid}:{remote_height}:')
             else:
+                check_local = (txid, remote_height) not in we_need
                 cache_tasks.append(asyncio.ensure_future(
-                    self.cache_transaction(txid, remote_height)
+                    self.cache_transaction(txid, remote_height, check_local=check_local)
                 ))
 
+        synced_txs = []
         for task in cache_tasks:
             tx = await task
@@ -459,11 +467,13 @@ class BaseLedger(metaclass=LedgerRegistry):
                     txi.txo_ref = referenced_txo.ref
 
             synced_history.write(f'{tx.id}:{tx.height}:')
+            synced_txs.append(tx)
 
-            await self.db.save_transaction_io(
-                tx, address, self.address_to_hash160(address), synced_history.getvalue()
-            )
+        await self.db.save_transaction_io_batch(
+            synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
+        )
+        for tx in synced_txs:
             await self._on_transaction_controller.add(TransactionEvent(address, tx))
 
         if address_manager is None:
@@ -472,9 +482,10 @@ class BaseLedger(metaclass=LedgerRegistry):
         if address_manager is not None:
             await address_manager.ensure_address_gap()
 
-        local_status, local_history = await self.get_local_status_and_history(address)
+        local_status, local_history = \
+            await self.get_local_status_and_history(address, synced_history.getvalue())
         if local_status != remote_status:
-            if local_history == list(map(itemgetter('tx_hash', 'height'), remote_history)):
+            if local_history == remote_history:
                 return True
             log.warning(
                 "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
@@ -487,7 +498,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         else:
             return True
 
-    async def cache_transaction(self, txid, remote_height):
+    async def cache_transaction(self, txid, remote_height, check_local=True):
         cache_item = self._tx_cache.get(txid)
         if cache_item is None:
             cache_item = self._tx_cache[txid] = TransactionCacheItem()
@@ -500,7 +511,7 @@ class BaseLedger(metaclass=LedgerRegistry):
 
             tx = cache_item.tx
 
-            if tx is None:
+            if tx is None and check_local:
                 # check local db
                 tx = cache_item.tx = await self.db.get_transaction(txid=txid)
 
@@ -509,19 +520,12 @@ class BaseLedger(metaclass=LedgerRegistry):
                 _raw = await self.network.retriable_call(self.network.get_transaction, txid)
                 if _raw:
                     tx = self.transaction_class(unhexlify(_raw))
-                    await self.maybe_verify_transaction(tx, remote_height)
-                    await self.db.insert_transaction(tx)
-                    cache_item.tx = tx  # make sure it's saved before caching it
-                    return tx
+                    cache_item.tx = tx
 
             if tx is None:
                 raise ValueError(f'Transaction {txid} was not in database and not on network.')
 
-            if remote_height > 0 and not tx.is_verified:
-                # tx from cache / db is not up-to-date
-                await self.maybe_verify_transaction(tx, remote_height)
-                await self.db.update_transaction(tx)
-
+            await self.maybe_verify_transaction(tx, remote_height)
             return tx
 
     async def maybe_verify_transaction(self, tx, remote_height):
diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py
index 19dbd58b7..121c3f3d9 100644
--- a/torba/torba/client/basenetwork.py
+++ b/torba/torba/client/basenetwork.py
@@ -30,11 +30,11 @@ class ClientSession(BaseClientSession):
         self._on_connect_cb = on_connect_callback or (lambda: None)
         self.trigger_urgent_reconnect = asyncio.Event()
-        # one request per second of timeout, conservative default
-        self._semaphore = asyncio.Semaphore(self.timeout)
+        # two requests per second of timeout, conservative default
+        self._semaphore = asyncio.Semaphore(self.timeout * 2)
 
     @property
     def available(self):
-        return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
+        return not self.is_closing() and self.response_time is not None
 
     @property
     def server_address_and_port(self) -> Optional[Tuple[str, int]]:
@@ -195,10 +195,8 @@ class BaseNetwork:
     def is_connected(self):
         return self.client and not self.client.is_closing()
 
-    def rpc(self, list_or_method, args, session=None):
-        # fixme: use fastest unloaded session, but for now it causes issues with wallet sync
-        # session = session or self.session_pool.fastest_session
-        session = self.client
+    def rpc(self, list_or_method, args, restricted=False):
+        session = self.client if restricted else self.session_pool.fastest_session
         if session and not session.is_closing():
             return session.send_request(list_or_method, args)
         else:
@@ -225,28 +223,35 @@ class BaseNetwork:
     def get_transaction(self, tx_hash):
         return self.rpc('blockchain.transaction.get', [tx_hash])
 
-    def get_transaction_height(self, tx_hash):
-        return self.rpc('blockchain.transaction.get_height', [tx_hash])
+    def get_transaction_height(self, tx_hash, known_height=None):
+        restricted = True  # by default, check master for consistency
+        if known_height:
+            if 0 < known_height < self.remote_height - 10:
+                restricted = False  # we can get it from any server, it's old
+        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
 
     def get_merkle(self, tx_hash, height):
-        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
+        restricted = True  # by default, check master for consistency
+        if 0 < height < self.remote_height - 10:
+            restricted = False  # we can get it from any server, it's old
+        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
 
     def get_headers(self, height, count=10000):
         return self.rpc('blockchain.block.headers', [height, count])
 
     # --- Subscribes, history and broadcasts are always aimed towards the master client directly
     def get_history(self, address):
-        return self.rpc('blockchain.address.get_history', [address], session=self.client)
+        return self.rpc('blockchain.address.get_history', [address], True)
 
     def broadcast(self, raw_transaction):
-        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
+        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
 
     def subscribe_headers(self):
-        return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
+        return self.rpc('blockchain.headers.subscribe', [True], True)
 
     async def subscribe_address(self, address):
         try:
-            return await self.rpc('blockchain.address.subscribe', [address], session=self.client)
+            return await self.rpc('blockchain.address.subscribe', [address], True)
         except asyncio.TimeoutError:
             # abort and cancel, we cant lose a subscription, it will happen again on reconnect
             self.client.abort()
@@ -267,11 +272,11 @@ class SessionPool:
 
     @property
    def available_sessions(self):
-        return [session for session in self.sessions if session.available]
+        return (session for session in self.sessions if session.available)
 
     @property
     def fastest_session(self):
-        if not self.available_sessions:
+        if not self.online:
             return None
         return min(
             [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
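
To illustrate the load-balancing rule used by SessionPool.fastest_session above, here is a minimal, self-contained sketch; the Session dataclass is a stand-in for torba's ClientSession, not the real class. Each session is scored by its expected latency scaled by queue depth, so a fast but saturated server can lose to a slower idle one:

from dataclasses import dataclass

@dataclass
class Session:
    name: str
    response_time: float       # rolling average of request round-trips, seconds
    connection_latency: float  # connect-time latency, seconds
    pending_amount: int        # requests currently in flight

def fastest(sessions):
    # Same score as the min() above: (response_time + connection_latency)
    # multiplied by (pending_amount + 1); the lowest score wins.
    return min(
        sessions,
        key=lambda s: (s.response_time + s.connection_latency) * (s.pending_amount + 1)
    )

pool = [
    Session('fast-but-busy', response_time=0.05, connection_latency=0.01, pending_amount=9),
    Session('slow-but-idle', response_time=0.20, connection_latency=0.05, pending_amount=0),
]
assert fastest(pool).name == 'slow-but-idle'  # 0.25 beats (0.05 + 0.01) * 10 = 0.60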
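Likewise, the INSERT policy helper this patch adds to SQLiteMixin can be exercised standalone against an in-memory SQLite database; the tx table below is a cut-down stand-in, not the real torba schema:

import sqlite3
from typing import List, Tuple

def insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
               replace: bool = False) -> Tuple[str, List]:
    # Mirrors SQLiteMixin._insert_sql: replace=True takes precedence
    # over ignore_duplicate when both flags are passed.
    columns, values = list(data.keys()), list(data.values())
    policy = ""
    if ignore_duplicate:
        policy = " OR IGNORE"
    if replace:
        policy = " OR REPLACE"
    sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
        policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
    )
    return sql, values

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE tx (txid TEXT PRIMARY KEY, height INTEGER)')
conn.execute(*insert_sql('tx', {'txid': 'aa', 'height': 1}))
# OR IGNORE leaves the existing row untouched ...
conn.execute(*insert_sql('tx', {'txid': 'aa', 'height': 5}, ignore_duplicate=True))
assert conn.execute('SELECT height FROM tx').fetchone() == (1,)
# ... while OR REPLACE overwrites it, which is what lets _transaction_io
# re-save an already known transaction during a re-sync.
conn.execute(*insert_sql('tx', {'txid': 'aa', 'height': 5}, replace=True))
assert conn.execute('SELECT height FROM tx').fetchone() == (5,)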