diff --git a/torba/rpc/curio.py b/torba/rpc/curio.py index cccac1d3b..b07f44b29 100644 --- a/torba/rpc/curio.py +++ b/torba/rpc/curio.py @@ -40,25 +40,9 @@ import logging import asyncio from collections import deque from contextlib import suppress -from functools import partial - -from .util import normalize_corofunc, check_task -__all__ = ( - 'spawn_sync', 'TaskGroup', - 'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError', - 'timeout_after', 'ignore_after' -) - - -def spawn_sync(coro, *args, loop=None, report_crash=True): - coro = normalize_corofunc(coro, args) - loop = loop or asyncio.get_event_loop() - task = loop.create_task(coro) - if report_crash: - task.add_done_callback(partial(check_task, logging)) - return task +__all__ = 'TaskGroup', class TaskGroup: @@ -120,7 +104,7 @@ class TaskGroup: '''Create a new task that’s part of the group. Returns a Task instance. ''' - task = spawn_sync(coro, *args, report_crash=False) + task = asyncio.create_task(coro) self._add_task(task) return task @@ -214,151 +198,3 @@ class TaskGroup: await self.cancel_remaining() else: await self.join() - - -class TaskTimeout(asyncio.CancelledError): - - def __init__(self, secs): - self.secs = secs - - def __str__(self): - return f'task timed out after {self.args[0]}s' - - -class TimeoutCancellationError(asyncio.CancelledError): - pass - - -class UncaughtTimeoutError(Exception): - pass - - -def _set_new_deadline(task, deadline): - def timeout_task(): - # Unfortunately task.cancel is all we can do with asyncio - task.cancel() - task._timed_out = deadline - task._deadline_handle = task._loop.call_at(deadline, timeout_task) - - -def _set_task_deadline(task, deadline): - deadlines = getattr(task, '_deadlines', []) - if deadlines: - if deadline < min(deadlines): - task._deadline_handle.cancel() - _set_new_deadline(task, deadline) - else: - _set_new_deadline(task, deadline) - deadlines.append(deadline) - task._deadlines = deadlines - task._timed_out = None - - -def _unset_task_deadline(task): - deadlines = task._deadlines - timed_out_deadline = task._timed_out - uncaught = timed_out_deadline not in deadlines - task._deadline_handle.cancel() - deadlines.pop() - if deadlines: - _set_new_deadline(task, min(deadlines)) - return timed_out_deadline, uncaught - - -class TimeoutAfter(object): - - def __init__(self, deadline, *, ignore=False, absolute=False): - self._deadline = deadline - self._ignore = ignore - self._absolute = absolute - self.expired = False - - async def __aenter__(self): - task = asyncio.current_task() - loop_time = task._loop.time() - if self._absolute: - self._secs = self._deadline - loop_time - else: - self._secs = self._deadline - self._deadline += loop_time - _set_task_deadline(task, self._deadline) - self.expired = False - self._task = task - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - timed_out_deadline, uncaught = _unset_task_deadline(self._task) - if exc_type not in (asyncio.CancelledError, TaskTimeout, - TimeoutCancellationError): - return False - if timed_out_deadline == self._deadline: - self.expired = True - if self._ignore: - return True - raise TaskTimeout(self._secs) from None - if timed_out_deadline is None: - assert exc_type is asyncio.CancelledError - return False - if uncaught: - raise UncaughtTimeoutError('uncaught timeout received') - if exc_type is TimeoutCancellationError: - return False - raise TimeoutCancellationError(timed_out_deadline) from None - - -async def _timeout_after_func(seconds, absolute, coro, args): - coro = normalize_corofunc(coro, args) - async with TimeoutAfter(seconds, absolute=absolute): - return await coro - - -def timeout_after(seconds, coro=None, *args): - '''Execute the specified coroutine and return its result. However, - issue a cancellation request to the calling task after seconds - have elapsed. When this happens, a TaskTimeout exception is - raised. If coro is None, the result of this function serves - as an asynchronous context manager that applies a timeout to a - block of statements. - - timeout_after() may be composed with other timeout_after() - operations (i.e., nested timeouts). If an outer timeout expires - first, then TimeoutCancellationError is raised instead of - TaskTimeout. If an inner timeout expires and fails to properly - TaskTimeout, a UncaughtTimeoutError is raised in the outer - timeout. - - ''' - if coro: - return _timeout_after_func(seconds, False, coro, args) - - return TimeoutAfter(seconds) - - -async def _ignore_after_func(seconds, absolute, coro, args, timeout_result): - coro = normalize_corofunc(coro, args) - async with TimeoutAfter(seconds, absolute=absolute, ignore=True): - return await coro - - return timeout_result - - -def ignore_after(seconds, coro=None, *args, timeout_result=None): - '''Execute the specified coroutine and return its result. Issue a - cancellation request after seconds have elapsed. When a timeout - occurs, no exception is raised. Instead, timeout_result is - returned. - - If coro is None, the result is an asynchronous context manager - that applies a timeout to a block of statements. For the context - manager case, the resulting context manager object has an expired - attribute set to True if time expired. - - Note: ignore_after() may also be composed with other timeout - operations. TimeoutCancellationError and UncaughtTimeoutError - exceptions might be raised according to the same rules as for - timeout_after(). - ''' - if coro: - return _ignore_after_func(seconds, False, coro, args, timeout_result) - - return TimeoutAfter(seconds, ignore=True) diff --git a/torba/rpc/session.py b/torba/rpc/session.py index 5be1a7dfc..3c465abd6 100644 --- a/torba/rpc/session.py +++ b/torba/rpc/session.py @@ -36,7 +36,7 @@ from contextlib import suppress from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification from .jsonrpc import RPCError, ProtocolError -from .curio import TaskGroup, TaskTimeout, spawn_sync, ignore_after, timeout_after +from .curio import TaskGroup from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer from .util import Concurrency @@ -151,17 +151,15 @@ class SessionBase(asyncio.Protocol): task_group = self._task_group async with task_group: - await self.spawn(self._receive_messages) - await self.spawn(collect_tasks) + await self.spawn(self._receive_messages()) + await self.spawn(collect_tasks()) async def _limited_wait(self, secs): - # Wait at most secs seconds to send, otherwise abort the connection try: - async with timeout_after(secs): - await self._can_send.wait() - except TaskTimeout: + await asyncio.wait_for(self._can_send.wait(), secs) + except asyncio.TimeoutError: self.abort() - raise + raise asyncio.CancelledError(f'task timed out after {secs}s') async def _send_message(self, message): if not self._can_send.is_set(): @@ -221,7 +219,7 @@ class SessionBase(asyncio.Protocol): self._proxy_address = peer_address else: self._address = peer_address - self._pm_task = spawn_sync(self._process_messages(), loop=self.loop) + self._pm_task = self.loop.create_task(self._process_messages()) def connection_lost(self, exc): '''Called by asyncio when the connection closes. @@ -281,8 +279,7 @@ class SessionBase(asyncio.Protocol): self._close() if self._pm_task: with suppress(CancelledError): - async with ignore_after(force_after): - await self._pm_task + await asyncio.wait([self._pm_task], timeout=force_after) self.abort() await self._pm_task diff --git a/torba/server/peers.py b/torba/server/peers.py index af319c29a..f764ba548 100644 --- a/torba/server/peers.py +++ b/torba/server/peers.py @@ -16,7 +16,7 @@ from collections import defaultdict, Counter from torba.rpc import ( Connector, RPCSession, SOCKSProxy, Notification, handler_invocation, - SOCKSError, RPCError, TaskTimeout, TaskGroup, ignore_after, timeout_after + SOCKSError, RPCError, TaskGroup ) from torba.server.peer import Peer from torba.server.util import class_logger, protocol_tuple @@ -194,8 +194,8 @@ class PeerManager: pause = STALE_SECS - WAKEUP_SECS * 2 else: pause = WAKEUP_SECS * 2 ** peer.try_count - async with ignore_after(pause): - await peer.retry_event.wait() + pending, done = await asyncio.wait([peer.retry_event.wait()], timeout=pause) + if done: peer.retry_event.clear() async def _should_drop_peer(self, peer): @@ -224,10 +224,12 @@ class PeerManager: peer_text = f'[{peer}:{port} {kind}]' try: - async with timeout_after(120 if peer.is_tor else 30): - async with Connector(PeerSession, peer.host, port, - **kwargs) as session: - await self._verify_peer(session, peer) + async with Connector(PeerSession, peer.host, port, + **kwargs) as session: + await asyncio.wait_for( + self._verify_peer(session, peer), + 120 if peer.is_tor else 30 + ) is_good = True break except BadPeerError as e: @@ -237,7 +239,7 @@ class PeerManager: except RPCError as e: self.logger.error(f'{peer_text} RPC error: {e.message} ' f'({e.code})') - except (OSError, SOCKSError, ConnectionError, TaskTimeout) as e: + except (OSError, SOCKSError, ConnectionError, asyncio.TimeoutError) as e: self.logger.info(f'{peer_text} {e}') if is_good: