dropped everything in curio.py except TaskGroup so far, TaskGroup is next

This commit is contained in:
Lex Berezhny 2018-12-06 16:48:17 -05:00
parent b93f9d4c94
commit 5ed478ed11
3 changed files with 20 additions and 185 deletions

View file

@ -40,25 +40,9 @@ import logging
import asyncio import asyncio
from collections import deque from collections import deque
from contextlib import suppress from contextlib import suppress
from functools import partial
from .util import normalize_corofunc, check_task
__all__ = ( __all__ = 'TaskGroup',
'spawn_sync', 'TaskGroup',
'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
'timeout_after', 'ignore_after'
)
def spawn_sync(coro, *args, loop=None, report_crash=True):
coro = normalize_corofunc(coro, args)
loop = loop or asyncio.get_event_loop()
task = loop.create_task(coro)
if report_crash:
task.add_done_callback(partial(check_task, logging))
return task
class TaskGroup: class TaskGroup:
@ -120,7 +104,7 @@ class TaskGroup:
'''Create a new task thats part of the group. Returns a Task '''Create a new task thats part of the group. Returns a Task
instance. instance.
''' '''
task = spawn_sync(coro, *args, report_crash=False) task = asyncio.create_task(coro)
self._add_task(task) self._add_task(task)
return task return task
@ -214,151 +198,3 @@ class TaskGroup:
await self.cancel_remaining() await self.cancel_remaining()
else: else:
await self.join() await self.join()
class TaskTimeout(asyncio.CancelledError):
def __init__(self, secs):
self.secs = secs
def __str__(self):
return f'task timed out after {self.args[0]}s'
class TimeoutCancellationError(asyncio.CancelledError):
pass
class UncaughtTimeoutError(Exception):
pass
def _set_new_deadline(task, deadline):
def timeout_task():
# Unfortunately task.cancel is all we can do with asyncio
task.cancel()
task._timed_out = deadline
task._deadline_handle = task._loop.call_at(deadline, timeout_task)
def _set_task_deadline(task, deadline):
deadlines = getattr(task, '_deadlines', [])
if deadlines:
if deadline < min(deadlines):
task._deadline_handle.cancel()
_set_new_deadline(task, deadline)
else:
_set_new_deadline(task, deadline)
deadlines.append(deadline)
task._deadlines = deadlines
task._timed_out = None
def _unset_task_deadline(task):
deadlines = task._deadlines
timed_out_deadline = task._timed_out
uncaught = timed_out_deadline not in deadlines
task._deadline_handle.cancel()
deadlines.pop()
if deadlines:
_set_new_deadline(task, min(deadlines))
return timed_out_deadline, uncaught
class TimeoutAfter(object):
def __init__(self, deadline, *, ignore=False, absolute=False):
self._deadline = deadline
self._ignore = ignore
self._absolute = absolute
self.expired = False
async def __aenter__(self):
task = asyncio.current_task()
loop_time = task._loop.time()
if self._absolute:
self._secs = self._deadline - loop_time
else:
self._secs = self._deadline
self._deadline += loop_time
_set_task_deadline(task, self._deadline)
self.expired = False
self._task = task
return self
async def __aexit__(self, exc_type, exc_value, traceback):
timed_out_deadline, uncaught = _unset_task_deadline(self._task)
if exc_type not in (asyncio.CancelledError, TaskTimeout,
TimeoutCancellationError):
return False
if timed_out_deadline == self._deadline:
self.expired = True
if self._ignore:
return True
raise TaskTimeout(self._secs) from None
if timed_out_deadline is None:
assert exc_type is asyncio.CancelledError
return False
if uncaught:
raise UncaughtTimeoutError('uncaught timeout received')
if exc_type is TimeoutCancellationError:
return False
raise TimeoutCancellationError(timed_out_deadline) from None
async def _timeout_after_func(seconds, absolute, coro, args):
coro = normalize_corofunc(coro, args)
async with TimeoutAfter(seconds, absolute=absolute):
return await coro
def timeout_after(seconds, coro=None, *args):
'''Execute the specified coroutine and return its result. However,
issue a cancellation request to the calling task after seconds
have elapsed. When this happens, a TaskTimeout exception is
raised. If coro is None, the result of this function serves
as an asynchronous context manager that applies a timeout to a
block of statements.
timeout_after() may be composed with other timeout_after()
operations (i.e., nested timeouts). If an outer timeout expires
first, then TimeoutCancellationError is raised instead of
TaskTimeout. If an inner timeout expires and fails to properly
TaskTimeout, a UncaughtTimeoutError is raised in the outer
timeout.
'''
if coro:
return _timeout_after_func(seconds, False, coro, args)
return TimeoutAfter(seconds)
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
coro = normalize_corofunc(coro, args)
async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
return await coro
return timeout_result
def ignore_after(seconds, coro=None, *args, timeout_result=None):
'''Execute the specified coroutine and return its result. Issue a
cancellation request after seconds have elapsed. When a timeout
occurs, no exception is raised. Instead, timeout_result is
returned.
If coro is None, the result is an asynchronous context manager
that applies a timeout to a block of statements. For the context
manager case, the resulting context manager object has an expired
attribute set to True if time expired.
Note: ignore_after() may also be composed with other timeout
operations. TimeoutCancellationError and UncaughtTimeoutError
exceptions might be raised according to the same rules as for
timeout_after().
'''
if coro:
return _ignore_after_func(seconds, False, coro, args, timeout_result)
return TimeoutAfter(seconds, ignore=True)

View file

@ -36,7 +36,7 @@ from contextlib import suppress
from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
from .jsonrpc import RPCError, ProtocolError from .jsonrpc import RPCError, ProtocolError
from .curio import TaskGroup, TaskTimeout, spawn_sync, ignore_after, timeout_after from .curio import TaskGroup
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
from .util import Concurrency from .util import Concurrency
@ -151,17 +151,15 @@ class SessionBase(asyncio.Protocol):
task_group = self._task_group task_group = self._task_group
async with task_group: async with task_group:
await self.spawn(self._receive_messages) await self.spawn(self._receive_messages())
await self.spawn(collect_tasks) await self.spawn(collect_tasks())
async def _limited_wait(self, secs): async def _limited_wait(self, secs):
# Wait at most secs seconds to send, otherwise abort the connection
try: try:
async with timeout_after(secs): await asyncio.wait_for(self._can_send.wait(), secs)
await self._can_send.wait() except asyncio.TimeoutError:
except TaskTimeout:
self.abort() self.abort()
raise raise asyncio.CancelledError(f'task timed out after {secs}s')
async def _send_message(self, message): async def _send_message(self, message):
if not self._can_send.is_set(): if not self._can_send.is_set():
@ -221,7 +219,7 @@ class SessionBase(asyncio.Protocol):
self._proxy_address = peer_address self._proxy_address = peer_address
else: else:
self._address = peer_address self._address = peer_address
self._pm_task = spawn_sync(self._process_messages(), loop=self.loop) self._pm_task = self.loop.create_task(self._process_messages())
def connection_lost(self, exc): def connection_lost(self, exc):
'''Called by asyncio when the connection closes. '''Called by asyncio when the connection closes.
@ -281,8 +279,7 @@ class SessionBase(asyncio.Protocol):
self._close() self._close()
if self._pm_task: if self._pm_task:
with suppress(CancelledError): with suppress(CancelledError):
async with ignore_after(force_after): await asyncio.wait([self._pm_task], timeout=force_after)
await self._pm_task
self.abort() self.abort()
await self._pm_task await self._pm_task

View file

@ -16,7 +16,7 @@ from collections import defaultdict, Counter
from torba.rpc import ( from torba.rpc import (
Connector, RPCSession, SOCKSProxy, Notification, handler_invocation, Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
SOCKSError, RPCError, TaskTimeout, TaskGroup, ignore_after, timeout_after SOCKSError, RPCError, TaskGroup
) )
from torba.server.peer import Peer from torba.server.peer import Peer
from torba.server.util import class_logger, protocol_tuple from torba.server.util import class_logger, protocol_tuple
@ -194,8 +194,8 @@ class PeerManager:
pause = STALE_SECS - WAKEUP_SECS * 2 pause = STALE_SECS - WAKEUP_SECS * 2
else: else:
pause = WAKEUP_SECS * 2 ** peer.try_count pause = WAKEUP_SECS * 2 ** peer.try_count
async with ignore_after(pause): pending, done = await asyncio.wait([peer.retry_event.wait()], timeout=pause)
await peer.retry_event.wait() if done:
peer.retry_event.clear() peer.retry_event.clear()
async def _should_drop_peer(self, peer): async def _should_drop_peer(self, peer):
@ -224,10 +224,12 @@ class PeerManager:
peer_text = f'[{peer}:{port} {kind}]' peer_text = f'[{peer}:{port} {kind}]'
try: try:
async with timeout_after(120 if peer.is_tor else 30): async with Connector(PeerSession, peer.host, port,
async with Connector(PeerSession, peer.host, port, **kwargs) as session:
**kwargs) as session: await asyncio.wait_for(
await self._verify_peer(session, peer) self._verify_peer(session, peer),
120 if peer.is_tor else 30
)
is_good = True is_good = True
break break
except BadPeerError as e: except BadPeerError as e:
@ -237,7 +239,7 @@ class PeerManager:
except RPCError as e: except RPCError as e:
self.logger.error(f'{peer_text} RPC error: {e.message} ' self.logger.error(f'{peer_text} RPC error: {e.message} '
f'({e.code})') f'({e.code})')
except (OSError, SOCKSError, ConnectionError, TaskTimeout) as e: except (OSError, SOCKSError, ConnectionError, asyncio.TimeoutError) as e:
self.logger.info(f'{peer_text} {e}') self.logger.info(f'{peer_text} {e}')
if is_good: if is_good: