From f567aca532a0eaa722f3692892f8d08d66ef4627 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Sun, 18 Aug 2019 15:40:38 -0300 Subject: [PATCH] retry and batch requests, fix some loose ends --- torba/tests/client_tests/unit/test_ledger.py | 6 ++- torba/torba/client/baseledger.py | 19 ++++----- torba/torba/client/basenetwork.py | 45 ++++++++++++-------- torba/torba/rpc/jsonrpc.py | 6 +-- torba/torba/rpc/session.py | 7 +-- 5 files changed, 47 insertions(+), 36 deletions(-) diff --git a/torba/tests/client_tests/unit/test_ledger.py b/torba/tests/client_tests/unit/test_ledger.py index 0e077b441..beb26a5d1 100644 --- a/torba/tests/client_tests/unit/test_ledger.py +++ b/torba/tests/client_tests/unit/test_ledger.py @@ -18,6 +18,9 @@ class MockNetwork: self.get_transaction_called = [] self.is_connected = False + def retriable_call(self, function, *args, **kwargs): + return function(*args, **kwargs) + async def get_history(self, address): self.get_history_called.append(address) self.address = address @@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase): ) -class MocHeaderNetwork: +class MocHeaderNetwork(MockNetwork): def __init__(self, responses): + super().__init__(None, None) self.responses = responses async def get_headers(self, height, blocks): diff --git a/torba/torba/client/baseledger.py b/torba/torba/client/baseledger.py index 8f16455d9..52dba6c54 100644 --- a/torba/torba/client/baseledger.py +++ b/torba/torba/client/baseledger.py @@ -310,7 +310,7 @@ class BaseLedger(metaclass=LedgerRegistry): subscription_update = False if not headers: - header_response = await self.network.get_headers(height, 2001) + header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) headers = header_response['hex'] if not headers: @@ -395,13 +395,9 @@ class BaseLedger(metaclass=LedgerRegistry): async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]): if self.network.is_connected and addresses: - await asyncio.wait([ - self.subscribe_address(address_manager, address) for address in addresses - ]) - - async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str): - remote_status = await self.network.subscribe_address(address) - self._update_tasks.add(self.update_history(address, remote_status, address_manager)) + async for address, remote_status in self.network.subscribe_address(*addresses): + # subscribe isnt a retriable call as it happens right after a connection is made + self._update_tasks.add(self.update_history(address, remote_status, address_manager)) def process_status_update(self, update): address, remote_status = update @@ -417,7 +413,7 @@ class BaseLedger(metaclass=LedgerRegistry): if local_status == remote_status: return - remote_history = await self.network.get_history(address) + remote_history = await self.network.retriable_call(self.network.get_history, address) cache_tasks = [] synced_history = StringIO() @@ -489,7 +485,7 @@ class BaseLedger(metaclass=LedgerRegistry): if tx is None: # fetch from network - _raw = await self.network.get_transaction(txid) + _raw = await self.network.retriable_call(self.network.get_transaction, txid) if _raw: tx = self.transaction_class(unhexlify(_raw)) await self.maybe_verify_transaction(tx, remote_height) @@ -510,7 +506,7 @@ class BaseLedger(metaclass=LedgerRegistry): async def maybe_verify_transaction(self, tx, remote_height): tx.height = remote_height if 0 < remote_height <= len(self.headers): - merkle = await self.network.get_merkle(tx.id, remote_height) + merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) header = self.headers[remote_height] tx.position = merkle['pos'] @@ -524,6 +520,7 @@ class BaseLedger(metaclass=LedgerRegistry): return None def broadcast(self, tx): + # broadcast cant be a retriable call yet return self.network.broadcast(hexlify(tx.raw).decode()) async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None): diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 39ed1fbdf..5ae9f1a47 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -135,31 +135,33 @@ class BaseNetwork: self.running = False if self.session_pool: self.session_pool.stop() - if self.is_connected: - disconnected = self.client.on_disconnected.first - await self.client.close() - await disconnected @property def is_connected(self): return self.client and not self.client.is_closing() - def rpc(self, list_or_method, args): - fastest = self.session_pool.fastest_session - if fastest is not None and self.client != fastest: - self.switch_event.set() - if self.is_connected: - return self.client.send_request(list_or_method, args) + def rpc(self, list_or_method, args, session=None): + session = session or self.session_pool.fastest_session + if session: + return session.send_request(list_or_method, args) else: self.session_pool.trigger_nodelay_connect() raise ConnectionError("Attempting to send rpc request when connection is not available.") + async def retriable_call(self, function, *args, **kwargs): + while True: + try: + return await function(*args, **kwargs) + except asyncio.TimeoutError: + log.warning("Wallet server call timed out, retrying.") + except ConnectionError: + if not self.is_connected: + log.warning("Wallet server unavailable, waiting for it to come back and retry.") + await self.on_connected.first + def _update_remote_height(self, header_args): self.remote_height = header_args[0]["height"] - def broadcast(self, raw_transaction): - return self.rpc('blockchain.transaction.broadcast', [raw_transaction]) - def get_history(self, address): return self.rpc('blockchain.address.get_history', [address]) @@ -175,11 +177,19 @@ class BaseNetwork: def get_headers(self, height, count=10000): return self.rpc('blockchain.block.headers', [height, count]) - def subscribe_headers(self): - return self.rpc('blockchain.headers.subscribe', [True]) + # --- Subscribes and broadcasts are always aimed towards the master client directly + def broadcast(self, raw_transaction): + return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client) - def subscribe_address(self, address): - return self.rpc('blockchain.address.subscribe', [address]) + def subscribe_headers(self): + return self.rpc('blockchain.headers.subscribe', [True], session=self.client) + + async def subscribe_address(self, *addresses): + async with self.client.send_batch() as batch: + for address in addresses: + batch.add_request('blockchain.address.subscribe', [address]) + for address, status in zip(addresses, batch.results): + yield address, status class SessionPool: @@ -218,6 +228,7 @@ class SessionPool: def stop(self): for session, task in self.sessions.items(): task.cancel() + session.connection_lost(asyncio.CancelledError()) session.abort() self.sessions.clear() diff --git a/torba/torba/rpc/jsonrpc.py b/torba/torba/rpc/jsonrpc.py index 4e5cca8ca..2e8bfa2a7 100644 --- a/torba/torba/rpc/jsonrpc.py +++ b/torba/torba/rpc/jsonrpc.py @@ -746,10 +746,8 @@ class JSONRPCConnection(object): self._protocol = item return self.receive_message(message) - def time_out_pending_requests(self): - """Times out all pending requests.""" - # this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing? - exception = asyncio.TimeoutError() + def raise_pending_requests(self, exception): + exception = exception or asyncio.TimeoutError() for request, event in self._requests.values(): event.result = exception event.set() diff --git a/torba/torba/rpc/session.py b/torba/torba/rpc/session.py index e16b6bbb4..9ff3e7ed0 100644 --- a/torba/torba/rpc/session.py +++ b/torba/torba/rpc/session.py @@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol): await asyncio.wait_for(self._can_send.wait(), secs) except asyncio.TimeoutError: self.abort() - raise asyncio.CancelledError(f'task timed out after {secs}s') + raise asyncio.TimeoutError(f'task timed out after {secs}s') async def _send_message(self, message): if not self._can_send.is_set(): @@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol): self._address = None self.transport = None self._task_group.cancel() - self._pm_task.cancel() + if self._pm_task: + self._pm_task.cancel() # Release waiting tasks self._can_send.set() @@ -456,7 +457,7 @@ class RPCSession(SessionBase): def connection_lost(self, exc): # Cancel pending requests and message processing - self.connection.time_out_pending_requests() + self.connection.raise_pending_requests(exc) super().connection_lost(exc) # External API