dropped everything in curio.py except TaskGroup so far, TaskGroup is next

This commit is contained in:
Lex Berezhny 2018-12-06 16:48:17 -05:00
parent b93f9d4c94
commit 5ed478ed11
3 changed files with 20 additions and 185 deletions

View file

@ -40,25 +40,9 @@ import logging
import asyncio
from collections import deque
from contextlib import suppress
from functools import partial
from .util import normalize_corofunc, check_task
__all__ = (
'spawn_sync', 'TaskGroup',
'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
'timeout_after', 'ignore_after'
)
def spawn_sync(coro, *args, loop=None, report_crash=True):
coro = normalize_corofunc(coro, args)
loop = loop or asyncio.get_event_loop()
task = loop.create_task(coro)
if report_crash:
task.add_done_callback(partial(check_task, logging))
return task
__all__ = 'TaskGroup',
class TaskGroup:
@ -120,7 +104,7 @@ class TaskGroup:
'''Create a new task thats part of the group. Returns a Task
instance.
'''
task = spawn_sync(coro, *args, report_crash=False)
task = asyncio.create_task(coro)
self._add_task(task)
return task
@ -214,151 +198,3 @@ class TaskGroup:
await self.cancel_remaining()
else:
await self.join()
class TaskTimeout(asyncio.CancelledError):
def __init__(self, secs):
self.secs = secs
def __str__(self):
return f'task timed out after {self.args[0]}s'
class TimeoutCancellationError(asyncio.CancelledError):
pass
class UncaughtTimeoutError(Exception):
pass
def _set_new_deadline(task, deadline):
def timeout_task():
# Unfortunately task.cancel is all we can do with asyncio
task.cancel()
task._timed_out = deadline
task._deadline_handle = task._loop.call_at(deadline, timeout_task)
def _set_task_deadline(task, deadline):
deadlines = getattr(task, '_deadlines', [])
if deadlines:
if deadline < min(deadlines):
task._deadline_handle.cancel()
_set_new_deadline(task, deadline)
else:
_set_new_deadline(task, deadline)
deadlines.append(deadline)
task._deadlines = deadlines
task._timed_out = None
def _unset_task_deadline(task):
deadlines = task._deadlines
timed_out_deadline = task._timed_out
uncaught = timed_out_deadline not in deadlines
task._deadline_handle.cancel()
deadlines.pop()
if deadlines:
_set_new_deadline(task, min(deadlines))
return timed_out_deadline, uncaught
class TimeoutAfter(object):
def __init__(self, deadline, *, ignore=False, absolute=False):
self._deadline = deadline
self._ignore = ignore
self._absolute = absolute
self.expired = False
async def __aenter__(self):
task = asyncio.current_task()
loop_time = task._loop.time()
if self._absolute:
self._secs = self._deadline - loop_time
else:
self._secs = self._deadline
self._deadline += loop_time
_set_task_deadline(task, self._deadline)
self.expired = False
self._task = task
return self
async def __aexit__(self, exc_type, exc_value, traceback):
timed_out_deadline, uncaught = _unset_task_deadline(self._task)
if exc_type not in (asyncio.CancelledError, TaskTimeout,
TimeoutCancellationError):
return False
if timed_out_deadline == self._deadline:
self.expired = True
if self._ignore:
return True
raise TaskTimeout(self._secs) from None
if timed_out_deadline is None:
assert exc_type is asyncio.CancelledError
return False
if uncaught:
raise UncaughtTimeoutError('uncaught timeout received')
if exc_type is TimeoutCancellationError:
return False
raise TimeoutCancellationError(timed_out_deadline) from None
async def _timeout_after_func(seconds, absolute, coro, args):
coro = normalize_corofunc(coro, args)
async with TimeoutAfter(seconds, absolute=absolute):
return await coro
def timeout_after(seconds, coro=None, *args):
'''Execute the specified coroutine and return its result. However,
issue a cancellation request to the calling task after seconds
have elapsed. When this happens, a TaskTimeout exception is
raised. If coro is None, the result of this function serves
as an asynchronous context manager that applies a timeout to a
block of statements.
timeout_after() may be composed with other timeout_after()
operations (i.e., nested timeouts). If an outer timeout expires
first, then TimeoutCancellationError is raised instead of
TaskTimeout. If an inner timeout expires and fails to properly
TaskTimeout, a UncaughtTimeoutError is raised in the outer
timeout.
'''
if coro:
return _timeout_after_func(seconds, False, coro, args)
return TimeoutAfter(seconds)
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
coro = normalize_corofunc(coro, args)
async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
return await coro
return timeout_result
def ignore_after(seconds, coro=None, *args, timeout_result=None):
'''Execute the specified coroutine and return its result. Issue a
cancellation request after seconds have elapsed. When a timeout
occurs, no exception is raised. Instead, timeout_result is
returned.
If coro is None, the result is an asynchronous context manager
that applies a timeout to a block of statements. For the context
manager case, the resulting context manager object has an expired
attribute set to True if time expired.
Note: ignore_after() may also be composed with other timeout
operations. TimeoutCancellationError and UncaughtTimeoutError
exceptions might be raised according to the same rules as for
timeout_after().
'''
if coro:
return _ignore_after_func(seconds, False, coro, args, timeout_result)
return TimeoutAfter(seconds, ignore=True)

View file

@ -36,7 +36,7 @@ from contextlib import suppress
from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
from .jsonrpc import RPCError, ProtocolError
from .curio import TaskGroup, TaskTimeout, spawn_sync, ignore_after, timeout_after
from .curio import TaskGroup
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
from .util import Concurrency
@ -151,17 +151,15 @@ class SessionBase(asyncio.Protocol):
task_group = self._task_group
async with task_group:
await self.spawn(self._receive_messages)
await self.spawn(collect_tasks)
await self.spawn(self._receive_messages())
await self.spawn(collect_tasks())
async def _limited_wait(self, secs):
# Wait at most secs seconds to send, otherwise abort the connection
try:
async with timeout_after(secs):
await self._can_send.wait()
except TaskTimeout:
await asyncio.wait_for(self._can_send.wait(), secs)
except asyncio.TimeoutError:
self.abort()
raise
raise asyncio.CancelledError(f'task timed out after {secs}s')
async def _send_message(self, message):
if not self._can_send.is_set():
@ -221,7 +219,7 @@ class SessionBase(asyncio.Protocol):
self._proxy_address = peer_address
else:
self._address = peer_address
self._pm_task = spawn_sync(self._process_messages(), loop=self.loop)
self._pm_task = self.loop.create_task(self._process_messages())
def connection_lost(self, exc):
'''Called by asyncio when the connection closes.
@ -281,8 +279,7 @@ class SessionBase(asyncio.Protocol):
self._close()
if self._pm_task:
with suppress(CancelledError):
async with ignore_after(force_after):
await self._pm_task
await asyncio.wait([self._pm_task], timeout=force_after)
self.abort()
await self._pm_task

View file

@ -16,7 +16,7 @@ from collections import defaultdict, Counter
from torba.rpc import (
Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
SOCKSError, RPCError, TaskTimeout, TaskGroup, ignore_after, timeout_after
SOCKSError, RPCError, TaskGroup
)
from torba.server.peer import Peer
from torba.server.util import class_logger, protocol_tuple
@ -194,8 +194,8 @@ class PeerManager:
pause = STALE_SECS - WAKEUP_SECS * 2
else:
pause = WAKEUP_SECS * 2 ** peer.try_count
async with ignore_after(pause):
await peer.retry_event.wait()
pending, done = await asyncio.wait([peer.retry_event.wait()], timeout=pause)
if done:
peer.retry_event.clear()
async def _should_drop_peer(self, peer):
@ -224,10 +224,12 @@ class PeerManager:
peer_text = f'[{peer}:{port} {kind}]'
try:
async with timeout_after(120 if peer.is_tor else 30):
async with Connector(PeerSession, peer.host, port,
**kwargs) as session:
await self._verify_peer(session, peer)
async with Connector(PeerSession, peer.host, port,
**kwargs) as session:
await asyncio.wait_for(
self._verify_peer(session, peer),
120 if peer.is_tor else 30
)
is_good = True
break
except BadPeerError as e:
@ -237,7 +239,7 @@ class PeerManager:
except RPCError as e:
self.logger.error(f'{peer_text} RPC error: {e.message} '
f'({e.code})')
except (OSError, SOCKSError, ConnectionError, TaskTimeout) as e:
except (OSError, SOCKSError, ConnectionError, asyncio.TimeoutError) as e:
self.logger.info(f'{peer_text} {e}')
if is_good: