forked from LBRYCommunity/lbry-sdk
dropped everything in curio.py except TaskGroup so far; TaskGroup is next
This commit is contained in:
parent b93f9d4c94
commit 5ed478ed11
3 changed files with 20 additions and 185 deletions
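The commit swaps the curio-style helpers that curio.py provided (spawn_sync, timeout_after, ignore_after and the TaskTimeout machinery) for plain asyncio primitives, keeping only TaskGroup. A minimal sketch of the substitution applied at each call site below; the helper name is illustrative, not part of the commit:

import asyncio

async def bounded_wait(awaitable, secs):
    # Old style: `async with timeout_after(secs): await awaitable` raised
    # TaskTimeout; asyncio.wait_for() instead cancels the awaitable and
    # raises asyncio.TimeoutError when secs elapse.
    return await asyncio.wait_for(awaitable, secs)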
@@ -40,25 +40,9 @@ import logging
import asyncio
from collections import deque
from contextlib import suppress
from functools import partial

from .util import normalize_corofunc, check_task


__all__ = (
    'spawn_sync', 'TaskGroup',
    'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
    'timeout_after', 'ignore_after'
)


def spawn_sync(coro, *args, loop=None, report_crash=True):
    coro = normalize_corofunc(coro, args)
    loop = loop or asyncio.get_event_loop()
    task = loop.create_task(coro)
    if report_crash:
        task.add_done_callback(partial(check_task, logging))
    return task
__all__ = 'TaskGroup',


class TaskGroup:
@@ -120,7 +104,7 @@ class TaskGroup:
        '''Create a new task that’s part of the group. Returns a Task
        instance.
        '''
        task = spawn_sync(coro, *args, report_crash=False)
        task = asyncio.create_task(coro)
        self._add_task(task)
        return task
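With spawn_sync gone, spawn() wraps asyncio.create_task() directly, so callers now pass a coroutine object rather than a coroutine function plus arguments (hence the later session.py change from self.spawn(self._receive_messages) to self.spawn(self._receive_messages())). A sketch of the resulting shape, with a minimal stand-in class since the rest of TaskGroup is not shown here:

import asyncio

class _GroupSketch:
    # Stand-in for TaskGroup; only spawn() is sketched.
    def __init__(self):
        self._tasks = set()

    def _add_task(self, task):
        self._tasks.add(task)

    async def spawn(self, coro, *args):
        # *args is no longer consumed; pass a coroutine object, e.g.
        # group.spawn(worker()) instead of group.spawn(worker).
        task = asyncio.create_task(coro)
        self._add_task(task)
        return task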
@@ -214,151 +198,3 @@ class TaskGroup:
            await self.cancel_remaining()
        else:
            await self.join()


class TaskTimeout(asyncio.CancelledError):

    def __init__(self, secs):
        self.secs = secs

    def __str__(self):
        return f'task timed out after {self.args[0]}s'


class TimeoutCancellationError(asyncio.CancelledError):
    pass


class UncaughtTimeoutError(Exception):
    pass


def _set_new_deadline(task, deadline):
    def timeout_task():
        # Unfortunately task.cancel is all we can do with asyncio
        task.cancel()
        task._timed_out = deadline
    task._deadline_handle = task._loop.call_at(deadline, timeout_task)


def _set_task_deadline(task, deadline):
    deadlines = getattr(task, '_deadlines', [])
    if deadlines:
        if deadline < min(deadlines):
            task._deadline_handle.cancel()
            _set_new_deadline(task, deadline)
    else:
        _set_new_deadline(task, deadline)
    deadlines.append(deadline)
    task._deadlines = deadlines
    task._timed_out = None


def _unset_task_deadline(task):
    deadlines = task._deadlines
    timed_out_deadline = task._timed_out
    uncaught = timed_out_deadline not in deadlines
    task._deadline_handle.cancel()
    deadlines.pop()
    if deadlines:
        _set_new_deadline(task, min(deadlines))
    return timed_out_deadline, uncaught


class TimeoutAfter(object):

    def __init__(self, deadline, *, ignore=False, absolute=False):
        self._deadline = deadline
        self._ignore = ignore
        self._absolute = absolute
        self.expired = False

    async def __aenter__(self):
        task = asyncio.current_task()
        loop_time = task._loop.time()
        if self._absolute:
            self._secs = self._deadline - loop_time
        else:
            self._secs = self._deadline
            self._deadline += loop_time
        _set_task_deadline(task, self._deadline)
        self.expired = False
        self._task = task
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        timed_out_deadline, uncaught = _unset_task_deadline(self._task)
        if exc_type not in (asyncio.CancelledError, TaskTimeout,
                            TimeoutCancellationError):
            return False
        if timed_out_deadline == self._deadline:
            self.expired = True
            if self._ignore:
                return True
            raise TaskTimeout(self._secs) from None
        if timed_out_deadline is None:
            assert exc_type is asyncio.CancelledError
            return False
        if uncaught:
            raise UncaughtTimeoutError('uncaught timeout received')
        if exc_type is TimeoutCancellationError:
            return False
        raise TimeoutCancellationError(timed_out_deadline) from None


async def _timeout_after_func(seconds, absolute, coro, args):
    coro = normalize_corofunc(coro, args)
    async with TimeoutAfter(seconds, absolute=absolute):
        return await coro


def timeout_after(seconds, coro=None, *args):
    '''Execute the specified coroutine and return its result. However,
    issue a cancellation request to the calling task after seconds
    have elapsed. When this happens, a TaskTimeout exception is
    raised. If coro is None, the result of this function serves
    as an asynchronous context manager that applies a timeout to a
    block of statements.

    timeout_after() may be composed with other timeout_after()
    operations (i.e., nested timeouts). If an outer timeout expires
    first, then TimeoutCancellationError is raised instead of
    TaskTimeout. If an inner timeout expires and its TaskTimeout is
    not properly caught, an UncaughtTimeoutError is raised in the
    outer timeout.

    '''
    if coro:
        return _timeout_after_func(seconds, False, coro, args)

    return TimeoutAfter(seconds)
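Call sites lose timeout_after(), but asyncio.wait_for() covers the common single-coroutine case; the nested-deadline bookkeeping above has no direct stdlib counterpart. A rough substitute, for orientation only:

import asyncio

async def timeout_after_substitute(seconds, coro):
    # wait_for() cancels coro and raises asyncio.TimeoutError on expiry,
    # where the removed helper raised TaskTimeout (a CancelledError).
    return await asyncio.wait_for(coro, seconds)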
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
    coro = normalize_corofunc(coro, args)
    async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
        return await coro

    return timeout_result


def ignore_after(seconds, coro=None, *args, timeout_result=None):
    '''Execute the specified coroutine and return its result. Issue a
    cancellation request after seconds have elapsed. When a timeout
    occurs, no exception is raised. Instead, timeout_result is
    returned.

    If coro is None, the result is an asynchronous context manager
    that applies a timeout to a block of statements. For the context
    manager case, the resulting context manager object has an expired
    attribute set to True if time expired.

    Note: ignore_after() may also be composed with other timeout
    operations. TimeoutCancellationError and UncaughtTimeoutError
    exceptions might be raised according to the same rules as for
    timeout_after().
    '''
    if coro:
        return _ignore_after_func(seconds, False, coro, args, timeout_result)

    return TimeoutAfter(seconds, ignore=True)
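ignore_after() likewise disappears; the rest of the diff approximates it either by swallowing the timeout around asyncio.wait_for(), or with asyncio.wait() plus a timeout, as in SessionBase.close() and the peer retry loop further down. Sketches of both patterns, with illustrative names:

import asyncio

async def ignore_after_substitute(seconds, coro, timeout_result=None):
    # Return timeout_result instead of raising, as ignore_after() did.
    try:
        return await asyncio.wait_for(coro, seconds)
    except asyncio.TimeoutError:
        return timeout_result

async def wait_without_cancelling(task, seconds):
    # asyncio.wait() neither cancels nor raises on timeout; it returns the
    # (done, pending) sets, which is how the retry loop below decides
    # whether the event fired in time.
    done, pending = await asyncio.wait([task], timeout=seconds)
    return task in done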
@@ -36,7 +36,7 @@ from contextlib import suppress

from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
from .jsonrpc import RPCError, ProtocolError
from .curio import TaskGroup, TaskTimeout, spawn_sync, ignore_after, timeout_after
from .curio import TaskGroup
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
from .util import Concurrency
@@ -151,17 +151,15 @@ class SessionBase(asyncio.Protocol):

        task_group = self._task_group
        async with task_group:
            await self.spawn(self._receive_messages)
            await self.spawn(collect_tasks)
            await self.spawn(self._receive_messages())
            await self.spawn(collect_tasks())

    async def _limited_wait(self, secs):
        # Wait at most secs seconds to send, otherwise abort the connection
        try:
            async with timeout_after(secs):
                await self._can_send.wait()
        except TaskTimeout:
            await asyncio.wait_for(self._can_send.wait(), secs)
        except asyncio.TimeoutError:
            self.abort()
            raise
            raise asyncio.CancelledError(f'task timed out after {secs}s')

    async def _send_message(self, message):
        if not self._can_send.is_set():
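Reassembled for readability, and assuming the bare raise is the line being dropped, the rewritten method waits with asyncio.wait_for and converts the timeout into a CancelledError carrying the old TaskTimeout message:

    async def _limited_wait(self, secs):
        # Wait at most secs seconds to send, otherwise abort the connection
        try:
            await asyncio.wait_for(self._can_send.wait(), secs)
        except asyncio.TimeoutError:
            self.abort()
            raise asyncio.CancelledError(f'task timed out after {secs}s')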
@@ -221,7 +219,7 @@ class SessionBase(asyncio.Protocol):
            self._proxy_address = peer_address
        else:
            self._address = peer_address
        self._pm_task = spawn_sync(self._process_messages(), loop=self.loop)
        self._pm_task = self.loop.create_task(self._process_messages())

    def connection_lost(self, exc):
        '''Called by asyncio when the connection closes.
@@ -281,8 +279,7 @@ class SessionBase(asyncio.Protocol):
        self._close()
        if self._pm_task:
            with suppress(CancelledError):
                async with ignore_after(force_after):
                    await self._pm_task
                await asyncio.wait([self._pm_task], timeout=force_after)
                self.abort()
                await self._pm_task
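close() keeps its force_after grace period by switching from ignore_after to asyncio.wait, which simply returns when the timeout lapses so the task can then be aborted and awaited. A sketch of how the block reads after the change, assuming the surrounding method is otherwise untouched:

        self._close()
        if self._pm_task:
            with suppress(CancelledError):
                # wait() returns after force_after even if the task is still
                # running; abort() then forces it down and we await it.
                await asyncio.wait([self._pm_task], timeout=force_after)
                self.abort()
                await self._pm_task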
@@ -16,7 +16,7 @@ from collections import defaultdict, Counter

from torba.rpc import (
    Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
    SOCKSError, RPCError, TaskTimeout, TaskGroup, ignore_after, timeout_after
    SOCKSError, RPCError, TaskGroup
)
from torba.server.peer import Peer
from torba.server.util import class_logger, protocol_tuple
@@ -194,8 +194,8 @@ class PeerManager:
                pause = STALE_SECS - WAKEUP_SECS * 2
            else:
                pause = WAKEUP_SECS * 2 ** peer.try_count
            async with ignore_after(pause):
                await peer.retry_event.wait()
            done, pending = await asyncio.wait([peer.retry_event.wait()], timeout=pause)
            if done:
                peer.retry_event.clear()

    async def _should_drop_peer(self, peer):
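The retry loop now waits on the event with asyncio.wait instead of ignore_after. asyncio.wait returns the done set first and never raises on timeout, so the event should only be cleared when the wait actually completed within pause; a self-contained sketch of that pattern, with an illustrative helper name:

import asyncio

async def event_within(event, pause):
    # True (and the event is cleared) only if it was set within pause seconds.
    waiter = asyncio.ensure_future(event.wait())
    done, _pending = await asyncio.wait([waiter], timeout=pause)
    if done:
        event.clear()
        return True
    waiter.cancel()  # drop the leftover waiter on timeout
    return False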
@@ -224,10 +224,12 @@ class PeerManager:

        peer_text = f'[{peer}:{port} {kind}]'
        try:
            async with timeout_after(120 if peer.is_tor else 30):
                async with Connector(PeerSession, peer.host, port,
                                     **kwargs) as session:
                    await self._verify_peer(session, peer)
            async with Connector(PeerSession, peer.host, port,
                                 **kwargs) as session:
                await asyncio.wait_for(
                    self._verify_peer(session, peer),
                    120 if peer.is_tor else 30
                )
            is_good = True
            break
        except BadPeerError as e:
@@ -237,7 +239,7 @@ class PeerManager:
        except RPCError as e:
            self.logger.error(f'{peer_text} RPC error: {e.message} '
                              f'({e.code})')
        except (OSError, SOCKSError, ConnectionError, TaskTimeout) as e:
        except (OSError, SOCKSError, ConnectionError, asyncio.TimeoutError) as e:
            self.logger.info(f'{peer_text} {e}')

        if is_good:
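With timeout_after gone, the verification timeout moves inside the connection via asyncio.wait_for, and the matching except clause catches asyncio.TimeoutError instead of TaskTimeout. Note the bound now covers only _verify_peer, not the Connector handshake that the old timeout_after block also wrapped. The shape of the new control flow, with the other handlers abridged:

        try:
            async with Connector(PeerSession, peer.host, port,
                                 **kwargs) as session:
                # Only the verification step is bounded now.
                await asyncio.wait_for(
                    self._verify_peer(session, peer),
                    120 if peer.is_tor else 30
                )
            is_good = True
            break
        except (OSError, SOCKSError, ConnectionError, asyncio.TimeoutError) as e:
            self.logger.info(f'{peer_text} {e}')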