diff --git a/torba/tests/client_tests/unit/test_database.py b/torba/tests/client_tests/unit/test_database.py index e5217d12c..569831c2a 100644 --- a/torba/tests/client_tests/unit/test_database.py +++ b/torba/tests/client_tests/unit/test_database.py @@ -255,7 +255,7 @@ class TestQueries(AsyncioTestCase): self.ledger.db.db.execute_fetchall = check_parameters_length account = await self.create_account() tx = await self.create_tx_from_nothing(account, 0) - for height in range(1200): + for height in range(1, 1200): tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height) variable_limit = self.ledger.db.MAX_QUERY_VARIABLES for limit in range(variable_limit-2, variable_limit+2): diff --git a/torba/tests/client_tests/unit/test_ledger.py b/torba/tests/client_tests/unit/test_ledger.py index 0e077b441..beb26a5d1 100644 --- a/torba/tests/client_tests/unit/test_ledger.py +++ b/torba/tests/client_tests/unit/test_ledger.py @@ -18,6 +18,9 @@ class MockNetwork: self.get_transaction_called = [] self.is_connected = False + def retriable_call(self, function, *args, **kwargs): + return function(*args, **kwargs) + async def get_history(self, address): self.get_history_called.append(address) self.address = address @@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase): ) -class MocHeaderNetwork: +class MocHeaderNetwork(MockNetwork): def __init__(self, responses): + super().__init__(None, None) self.responses = responses async def get_headers(self, height, blocks): diff --git a/torba/torba/client/baseledger.py b/torba/torba/client/baseledger.py index 8f16455d9..c11c8648b 100644 --- a/torba/torba/client/baseledger.py +++ b/torba/torba/client/baseledger.py @@ -310,7 +310,7 @@ class BaseLedger(metaclass=LedgerRegistry): subscription_update = False if not headers: - header_response = await self.network.get_headers(height, 2001) + header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) headers = header_response['hex'] if not headers: @@ -395,13 +395,9 @@ class BaseLedger(metaclass=LedgerRegistry): async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]): if self.network.is_connected and addresses: - await asyncio.wait([ - self.subscribe_address(address_manager, address) for address in addresses - ]) - - async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str): - remote_status = await self.network.subscribe_address(address) - self._update_tasks.add(self.update_history(address, remote_status, address_manager)) + async for address, remote_status in self.network.subscribe_address(*addresses): + # subscribe isnt a retriable call as it happens right after a connection is made + self._update_tasks.add(self.update_history(address, remote_status, address_manager)) def process_status_update(self, update): address, remote_status = update @@ -417,7 +413,7 @@ class BaseLedger(metaclass=LedgerRegistry): if local_status == remote_status: return - remote_history = await self.network.get_history(address) + remote_history = await self.network.retriable_call(self.network.get_history, address) cache_tasks = [] synced_history = StringIO() @@ -489,7 +485,7 @@ class BaseLedger(metaclass=LedgerRegistry): if tx is None: # fetch from network - _raw = await self.network.get_transaction(txid) + _raw = await self.network.retriable_call(self.network.get_transaction, txid) if _raw: tx = self.transaction_class(unhexlify(_raw)) await self.maybe_verify_transaction(tx, remote_height) @@ -510,7 +506,7 @@ class BaseLedger(metaclass=LedgerRegistry): async def maybe_verify_transaction(self, tx, remote_height): tx.height = remote_height if 0 < remote_height <= len(self.headers): - merkle = await self.network.get_merkle(tx.id, remote_height) + merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) header = self.headers[remote_height] tx.position = merkle['pos'] @@ -524,6 +520,7 @@ class BaseLedger(metaclass=LedgerRegistry): return None def broadcast(self, tx): + # broadcast cant be a retriable call yet return self.network.broadcast(hexlify(tx.raw).decode()) async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None): @@ -545,4 +542,4 @@ class BaseLedger(metaclass=LedgerRegistry): )) for address_record in records ], timeout=timeout) if pending: - raise TimeoutError('Timed out waiting for transaction.') + raise asyncio.TimeoutError('Timed out waiting for transaction.') diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 39ed1fbdf..572d242a7 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -1,8 +1,8 @@ import logging import asyncio from operator import itemgetter -from typing import Dict, Optional -from time import time, perf_counter +from typing import Dict, Optional, Tuple +from time import perf_counter from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError @@ -13,17 +13,20 @@ log = logging.getLogger(__name__) class ClientSession(BaseClientSession): - def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): self.network = network self.server = server super().__init__(*args, **kwargs) self._on_disconnect_controller = StreamController() self.on_disconnected = self._on_disconnect_controller.stream - self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 + self.framer.max_size = self.max_errors = 1 << 32 + self.bw_limit = -1 self.timeout = timeout self.max_seconds_idle = timeout * 2 self.response_time: Optional[float] = None + self.connection_latency: Optional[float] = None + self._response_samples = 0 + self.pending_amount = 0 self._on_connect_cb = on_connect_callback or (lambda: None) self.trigger_urgent_reconnect = asyncio.Event() @@ -31,20 +34,38 @@ class ClientSession(BaseClientSession): def available(self): return not self.is_closing() and self._can_send.is_set() and self.response_time is not None + @property + def server_address_and_port(self) -> Optional[Tuple[str, int]]: + if not self.transport: + return None + return self.transport.get_extra_info('peername') + + async def send_timed_server_version_request(self, args=()): + log.debug("send version request to %s:%i", *self.server) + start = perf_counter() + result = await asyncio.wait_for( + super().send_request('server.version', args), timeout=self.timeout + ) + current_response_time = perf_counter() - start + response_sum = (self.response_time or 0) * self._response_samples + current_response_time + self.response_time = response_sum / (self._response_samples + 1) + self._response_samples += 1 + return result + async def send_request(self, method, args=()): + self.pending_amount += 1 try: - start = perf_counter() - result = await asyncio.wait_for( + if method == 'server.version': + return await self.send_timed_server_version_request(args) + return await asyncio.wait_for( super().send_request(method, args), timeout=self.timeout ) - self.response_time = perf_counter() - start - return result except RPCError as e: - log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) + log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", + *self.server, *e.args) raise e - except TimeoutError: - self.response_time = None - raise + finally: + self.pending_amount -= 1 async def ensure_session(self): # Handles reconnecting and maintaining a session alive @@ -56,8 +77,8 @@ class ClientSession(BaseClientSession): await self.create_connection(self.timeout) await self.ensure_server_version() self._on_connect_cb() - if (time() - self.last_send) > self.max_seconds_idle or self.response_time is None: - await self.send_request('server.banner') + if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None: + await self.ensure_server_version() retry_delay = default_delay except (asyncio.TimeoutError, OSError): await self.close() @@ -67,6 +88,9 @@ class ClientSession(BaseClientSession): await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) except asyncio.TimeoutError: pass + except asyncio.CancelledError: + self.synchronous_close() + raise finally: self.trigger_urgent_reconnect.clear() @@ -75,7 +99,9 @@ class ClientSession(BaseClientSession): async def create_connection(self, timeout=6): connector = Connector(lambda: self, *self.server) + start = perf_counter() await asyncio.wait_for(connector.create_connection(), timeout=timeout) + self.connection_latency = perf_counter() - start async def handle_request(self, request): controller = self.network.subscription_controllers[request.method] @@ -85,6 +111,9 @@ class ClientSession(BaseClientSession): log.debug("Connection lost: %s:%d", *self.server) super().connection_lost(exc) self.response_time = None + self.connection_latency = None + self._response_samples = 0 + self.pending_amount = 0 self._on_disconnect_controller.add(True) @@ -133,33 +162,34 @@ class BaseNetwork: async def stop(self): self.running = False - if self.session_pool: - self.session_pool.stop() - if self.is_connected: - disconnected = self.client.on_disconnected.first - await self.client.close() - await disconnected + self.session_pool.stop() @property def is_connected(self): return self.client and not self.client.is_closing() - def rpc(self, list_or_method, args): - fastest = self.session_pool.fastest_session - if fastest is not None and self.client != fastest: - self.switch_event.set() - if self.is_connected: - return self.client.send_request(list_or_method, args) + def rpc(self, list_or_method, args, session=None): + session = session or self.session_pool.fastest_session + if session: + return session.send_request(list_or_method, args) else: self.session_pool.trigger_nodelay_connect() raise ConnectionError("Attempting to send rpc request when connection is not available.") + async def retriable_call(self, function, *args, **kwargs): + while self.running: + try: + return await function(*args, **kwargs) + except asyncio.TimeoutError: + log.warning("Wallet server call timed out, retrying.") + except ConnectionError: + if not self.is_connected and self.running: + log.warning("Wallet server unavailable, waiting for it to come back and retry.") + await self.on_connected.first + def _update_remote_height(self, header_args): self.remote_height = header_args[0]["height"] - def broadcast(self, raw_transaction): - return self.rpc('blockchain.transaction.broadcast', [raw_transaction]) - def get_history(self, address): return self.rpc('blockchain.address.get_history', [address]) @@ -175,11 +205,19 @@ class BaseNetwork: def get_headers(self, height, count=10000): return self.rpc('blockchain.block.headers', [height, count]) - def subscribe_headers(self): - return self.rpc('blockchain.headers.subscribe', [True]) + # --- Subscribes and broadcasts are always aimed towards the master client directly + def broadcast(self, raw_transaction): + return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client) - def subscribe_address(self, address): - return self.rpc('blockchain.address.subscribe', [address]) + def subscribe_headers(self): + return self.rpc('blockchain.headers.subscribe', [True], session=self.client) + + async def subscribe_address(self, *addresses): + async with self.client.send_batch() as batch: + for address in addresses: + batch.add_request('blockchain.address.subscribe', [address]) + for address, status in zip(addresses, batch.results): + yield address, status class SessionPool: @@ -203,30 +241,61 @@ class SessionPool: if not self.available_sessions: return None return min( - [(session.response_time, session) for session in self.available_sessions], key=itemgetter(0) + [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session) + for session in self.available_sessions], + key=itemgetter(0) )[1] + def _get_session_connect_callback(self, session: ClientSession): + loop = asyncio.get_event_loop() + + def callback(): + duplicate_connections = [ + s for s in self.sessions + if s is not session and s.server_address_and_port == session.server_address_and_port + ] + already_connected = None if not duplicate_connections else duplicate_connections[0] + if already_connected: + self.sessions.pop(session).cancel() + session.synchronous_close() + log.info("wallet server %s resolves to the same server as %s, rechecking in an hour", + session.server[0], already_connected.server[0]) + loop.call_later(3600, self._connect_session, session.server) + return + self.new_connection_event.set() + log.info("connected to %s:%i", *session.server) + + return callback + + def _connect_session(self, server: Tuple[str, int]): + session = None + for s in self.sessions: + if s.server == server: + session = s + break + if not session: + session = ClientSession( + network=self.network, server=server + ) + session._on_connect_cb = self._get_session_connect_callback(session) + task = self.sessions.get(session, None) + if not task or task.done(): + task = asyncio.create_task(session.ensure_session()) + task.add_done_callback(lambda _: self.ensure_connections()) + self.sessions[session] = task + def start(self, default_servers): - callback = self.new_connection_event.set - self.sessions = { - ClientSession( - network=self.network, server=server, on_connect_callback=callback - ): None for server in default_servers - } - self.ensure_connections() + for server in default_servers: + self._connect_session(server) def stop(self): - for session, task in self.sessions.items(): + for task in self.sessions.values(): task.cancel() - session.abort() self.sessions.clear() def ensure_connections(self): - for session, task in list(self.sessions.items()): - if not task or task.done(): - task = asyncio.create_task(session.ensure_session()) - task.add_done_callback(lambda _: self.ensure_connections()) - self.sessions[session] = task + for session in self.sessions: + self._connect_session(session.server) def trigger_nodelay_connect(self): # used when other parts of the system sees we might have internet back diff --git a/torba/torba/rpc/jsonrpc.py b/torba/torba/rpc/jsonrpc.py index 4e5cca8ca..2e8bfa2a7 100644 --- a/torba/torba/rpc/jsonrpc.py +++ b/torba/torba/rpc/jsonrpc.py @@ -746,10 +746,8 @@ class JSONRPCConnection(object): self._protocol = item return self.receive_message(message) - def time_out_pending_requests(self): - """Times out all pending requests.""" - # this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing? - exception = asyncio.TimeoutError() + def raise_pending_requests(self, exception): + exception = exception or asyncio.TimeoutError() for request, event in self._requests.values(): event.result = exception event.set() diff --git a/torba/torba/rpc/session.py b/torba/torba/rpc/session.py index e16b6bbb4..dd9909cfd 100644 --- a/torba/torba/rpc/session.py +++ b/torba/torba/rpc/session.py @@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol): # Force-close a connection if a send doesn't succeed in this time self.max_send_delay = 60 # Statistics. The RPC object also keeps its own statistics. - self.start_time = time.time() + self.start_time = time.perf_counter() self.errors = 0 self.send_count = 0 self.send_size = 0 @@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol): # A non-positive value means not to limit concurrency if self.bw_limit <= 0: return - now = time.time() + now = time.perf_counter() # Reduce the recorded usage in proportion to the elapsed time refund = (now - self.bw_time) * (self.bw_limit / 3600) self.bw_charge = max(0, self.bw_charge - int(refund)) @@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol): await asyncio.wait_for(self._can_send.wait(), secs) except asyncio.TimeoutError: self.abort() - raise asyncio.CancelledError(f'task timed out after {secs}s') + raise asyncio.TimeoutError(f'task timed out after {secs}s') async def _send_message(self, message): if not self._can_send.is_set(): @@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol): self.send_size += len(framed_message) self._using_bandwidth(len(framed_message)) self.send_count += 1 - self.last_send = time.time() + self.last_send = time.perf_counter() if self.verbosity >= 4: self.logger.debug(f'Sending framed message {framed_message}') self.transport.write(framed_message) @@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol): self._address = None self.transport = None self._task_group.cancel() - self._pm_task.cancel() + if self._pm_task: + self._pm_task.cancel() # Release waiting tasks self._can_send.set() @@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol): if self.transport: self.transport.abort() + # TODO: replace with synchronous_close async def close(self, *, force_after=30): """Close the connection and return when closed.""" self._close() @@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol): self.abort() await self._pm_task + def synchronous_close(self): + self._close() + if self._pm_task and not self._pm_task.done(): + self._pm_task.cancel() + class MessageSession(SessionBase): """Session class for protocols where messages are not tied to responses, @@ -296,7 +303,7 @@ class MessageSession(SessionBase): ) self._bump_errors() else: - self.last_recv = time.time() + self.last_recv = time.perf_counter() self.recv_count += 1 if self.recv_count % 10 == 0: await self._update_concurrency() @@ -416,7 +423,7 @@ class RPCSession(SessionBase): self.logger.warning(f'{e!r}') continue - self.last_recv = time.time() + self.last_recv = time.perf_counter() self.recv_count += 1 if self.recv_count % 10 == 0: await self._update_concurrency() @@ -456,7 +463,7 @@ class RPCSession(SessionBase): def connection_lost(self, exc): # Cancel pending requests and message processing - self.connection.time_out_pending_requests() + self.connection.raise_pending_requests(exc) super().connection_lost(exc) # External API diff --git a/torba/torba/server/session.py b/torba/torba/server/session.py index 9c5981d22..db0e6fa85 100644 --- a/torba/torba/server/session.py +++ b/torba/torba/server/session.py @@ -258,7 +258,7 @@ class SessionManager: session_timeout = self.env.session_timeout while True: await sleep(session_timeout // 10) - stale_cutoff = time.time() - session_timeout + stale_cutoff = time.perf_counter() - session_timeout stale_sessions = [session for session in self.sessions if session.last_recv < stale_cutoff] if stale_sessions: