forked from LBRYCommunity/lbry-sdk
dropped everything in curio.py except TaskGroup so far, TaskGroup is next
This commit is contained in:
parent
b93f9d4c94
commit
5ed478ed11
3 changed files with 20 additions and 185 deletions
|
@ -40,25 +40,9 @@ import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from .util import normalize_corofunc, check_task
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = 'TaskGroup',
|
||||||
'spawn_sync', 'TaskGroup',
|
|
||||||
'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
|
|
||||||
'timeout_after', 'ignore_after'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def spawn_sync(coro, *args, loop=None, report_crash=True):
|
|
||||||
coro = normalize_corofunc(coro, args)
|
|
||||||
loop = loop or asyncio.get_event_loop()
|
|
||||||
task = loop.create_task(coro)
|
|
||||||
if report_crash:
|
|
||||||
task.add_done_callback(partial(check_task, logging))
|
|
||||||
return task
|
|
||||||
|
|
||||||
|
|
||||||
class TaskGroup:
|
class TaskGroup:
|
||||||
|
@ -120,7 +104,7 @@ class TaskGroup:
|
||||||
'''Create a new task that’s part of the group. Returns a Task
|
'''Create a new task that’s part of the group. Returns a Task
|
||||||
instance.
|
instance.
|
||||||
'''
|
'''
|
||||||
task = spawn_sync(coro, *args, report_crash=False)
|
task = asyncio.create_task(coro)
|
||||||
self._add_task(task)
|
self._add_task(task)
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
@ -214,151 +198,3 @@ class TaskGroup:
|
||||||
await self.cancel_remaining()
|
await self.cancel_remaining()
|
||||||
else:
|
else:
|
||||||
await self.join()
|
await self.join()
|
||||||
|
|
||||||
|
|
||||||
class TaskTimeout(asyncio.CancelledError):
|
|
||||||
|
|
||||||
def __init__(self, secs):
|
|
||||||
self.secs = secs
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f'task timed out after {self.args[0]}s'
|
|
||||||
|
|
||||||
|
|
||||||
class TimeoutCancellationError(asyncio.CancelledError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UncaughtTimeoutError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _set_new_deadline(task, deadline):
|
|
||||||
def timeout_task():
|
|
||||||
# Unfortunately task.cancel is all we can do with asyncio
|
|
||||||
task.cancel()
|
|
||||||
task._timed_out = deadline
|
|
||||||
task._deadline_handle = task._loop.call_at(deadline, timeout_task)
|
|
||||||
|
|
||||||
|
|
||||||
def _set_task_deadline(task, deadline):
|
|
||||||
deadlines = getattr(task, '_deadlines', [])
|
|
||||||
if deadlines:
|
|
||||||
if deadline < min(deadlines):
|
|
||||||
task._deadline_handle.cancel()
|
|
||||||
_set_new_deadline(task, deadline)
|
|
||||||
else:
|
|
||||||
_set_new_deadline(task, deadline)
|
|
||||||
deadlines.append(deadline)
|
|
||||||
task._deadlines = deadlines
|
|
||||||
task._timed_out = None
|
|
||||||
|
|
||||||
|
|
||||||
def _unset_task_deadline(task):
|
|
||||||
deadlines = task._deadlines
|
|
||||||
timed_out_deadline = task._timed_out
|
|
||||||
uncaught = timed_out_deadline not in deadlines
|
|
||||||
task._deadline_handle.cancel()
|
|
||||||
deadlines.pop()
|
|
||||||
if deadlines:
|
|
||||||
_set_new_deadline(task, min(deadlines))
|
|
||||||
return timed_out_deadline, uncaught
|
|
||||||
|
|
||||||
|
|
||||||
class TimeoutAfter(object):
|
|
||||||
|
|
||||||
def __init__(self, deadline, *, ignore=False, absolute=False):
|
|
||||||
self._deadline = deadline
|
|
||||||
self._ignore = ignore
|
|
||||||
self._absolute = absolute
|
|
||||||
self.expired = False
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
task = asyncio.current_task()
|
|
||||||
loop_time = task._loop.time()
|
|
||||||
if self._absolute:
|
|
||||||
self._secs = self._deadline - loop_time
|
|
||||||
else:
|
|
||||||
self._secs = self._deadline
|
|
||||||
self._deadline += loop_time
|
|
||||||
_set_task_deadline(task, self._deadline)
|
|
||||||
self.expired = False
|
|
||||||
self._task = task
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
||||||
timed_out_deadline, uncaught = _unset_task_deadline(self._task)
|
|
||||||
if exc_type not in (asyncio.CancelledError, TaskTimeout,
|
|
||||||
TimeoutCancellationError):
|
|
||||||
return False
|
|
||||||
if timed_out_deadline == self._deadline:
|
|
||||||
self.expired = True
|
|
||||||
if self._ignore:
|
|
||||||
return True
|
|
||||||
raise TaskTimeout(self._secs) from None
|
|
||||||
if timed_out_deadline is None:
|
|
||||||
assert exc_type is asyncio.CancelledError
|
|
||||||
return False
|
|
||||||
if uncaught:
|
|
||||||
raise UncaughtTimeoutError('uncaught timeout received')
|
|
||||||
if exc_type is TimeoutCancellationError:
|
|
||||||
return False
|
|
||||||
raise TimeoutCancellationError(timed_out_deadline) from None
|
|
||||||
|
|
||||||
|
|
||||||
async def _timeout_after_func(seconds, absolute, coro, args):
|
|
||||||
coro = normalize_corofunc(coro, args)
|
|
||||||
async with TimeoutAfter(seconds, absolute=absolute):
|
|
||||||
return await coro
|
|
||||||
|
|
||||||
|
|
||||||
def timeout_after(seconds, coro=None, *args):
|
|
||||||
'''Execute the specified coroutine and return its result. However,
|
|
||||||
issue a cancellation request to the calling task after seconds
|
|
||||||
have elapsed. When this happens, a TaskTimeout exception is
|
|
||||||
raised. If coro is None, the result of this function serves
|
|
||||||
as an asynchronous context manager that applies a timeout to a
|
|
||||||
block of statements.
|
|
||||||
|
|
||||||
timeout_after() may be composed with other timeout_after()
|
|
||||||
operations (i.e., nested timeouts). If an outer timeout expires
|
|
||||||
first, then TimeoutCancellationError is raised instead of
|
|
||||||
TaskTimeout. If an inner timeout expires and fails to properly
|
|
||||||
TaskTimeout, a UncaughtTimeoutError is raised in the outer
|
|
||||||
timeout.
|
|
||||||
|
|
||||||
'''
|
|
||||||
if coro:
|
|
||||||
return _timeout_after_func(seconds, False, coro, args)
|
|
||||||
|
|
||||||
return TimeoutAfter(seconds)
|
|
||||||
|
|
||||||
|
|
||||||
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
|
|
||||||
coro = normalize_corofunc(coro, args)
|
|
||||||
async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
|
|
||||||
return await coro
|
|
||||||
|
|
||||||
return timeout_result
|
|
||||||
|
|
||||||
|
|
||||||
def ignore_after(seconds, coro=None, *args, timeout_result=None):
|
|
||||||
'''Execute the specified coroutine and return its result. Issue a
|
|
||||||
cancellation request after seconds have elapsed. When a timeout
|
|
||||||
occurs, no exception is raised. Instead, timeout_result is
|
|
||||||
returned.
|
|
||||||
|
|
||||||
If coro is None, the result is an asynchronous context manager
|
|
||||||
that applies a timeout to a block of statements. For the context
|
|
||||||
manager case, the resulting context manager object has an expired
|
|
||||||
attribute set to True if time expired.
|
|
||||||
|
|
||||||
Note: ignore_after() may also be composed with other timeout
|
|
||||||
operations. TimeoutCancellationError and UncaughtTimeoutError
|
|
||||||
exceptions might be raised according to the same rules as for
|
|
||||||
timeout_after().
|
|
||||||
'''
|
|
||||||
if coro:
|
|
||||||
return _ignore_after_func(seconds, False, coro, args, timeout_result)
|
|
||||||
|
|
||||||
return TimeoutAfter(seconds, ignore=True)
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ from contextlib import suppress
|
||||||
|
|
||||||
from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
|
from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
|
||||||
from .jsonrpc import RPCError, ProtocolError
|
from .jsonrpc import RPCError, ProtocolError
|
||||||
from .curio import TaskGroup, TaskTimeout, spawn_sync, ignore_after, timeout_after
|
from .curio import TaskGroup
|
||||||
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
|
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
|
||||||
from .util import Concurrency
|
from .util import Concurrency
|
||||||
|
|
||||||
|
@ -151,17 +151,15 @@ class SessionBase(asyncio.Protocol):
|
||||||
|
|
||||||
task_group = self._task_group
|
task_group = self._task_group
|
||||||
async with task_group:
|
async with task_group:
|
||||||
await self.spawn(self._receive_messages)
|
await self.spawn(self._receive_messages())
|
||||||
await self.spawn(collect_tasks)
|
await self.spawn(collect_tasks())
|
||||||
|
|
||||||
async def _limited_wait(self, secs):
|
async def _limited_wait(self, secs):
|
||||||
# Wait at most secs seconds to send, otherwise abort the connection
|
|
||||||
try:
|
try:
|
||||||
async with timeout_after(secs):
|
await asyncio.wait_for(self._can_send.wait(), secs)
|
||||||
await self._can_send.wait()
|
except asyncio.TimeoutError:
|
||||||
except TaskTimeout:
|
|
||||||
self.abort()
|
self.abort()
|
||||||
raise
|
raise asyncio.CancelledError(f'task timed out after {secs}s')
|
||||||
|
|
||||||
async def _send_message(self, message):
|
async def _send_message(self, message):
|
||||||
if not self._can_send.is_set():
|
if not self._can_send.is_set():
|
||||||
|
@ -221,7 +219,7 @@ class SessionBase(asyncio.Protocol):
|
||||||
self._proxy_address = peer_address
|
self._proxy_address = peer_address
|
||||||
else:
|
else:
|
||||||
self._address = peer_address
|
self._address = peer_address
|
||||||
self._pm_task = spawn_sync(self._process_messages(), loop=self.loop)
|
self._pm_task = self.loop.create_task(self._process_messages())
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
'''Called by asyncio when the connection closes.
|
'''Called by asyncio when the connection closes.
|
||||||
|
@ -281,8 +279,7 @@ class SessionBase(asyncio.Protocol):
|
||||||
self._close()
|
self._close()
|
||||||
if self._pm_task:
|
if self._pm_task:
|
||||||
with suppress(CancelledError):
|
with suppress(CancelledError):
|
||||||
async with ignore_after(force_after):
|
await asyncio.wait([self._pm_task], timeout=force_after)
|
||||||
await self._pm_task
|
|
||||||
self.abort()
|
self.abort()
|
||||||
await self._pm_task
|
await self._pm_task
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from collections import defaultdict, Counter
|
||||||
|
|
||||||
from torba.rpc import (
|
from torba.rpc import (
|
||||||
Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
|
Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
|
||||||
SOCKSError, RPCError, TaskTimeout, TaskGroup, ignore_after, timeout_after
|
SOCKSError, RPCError, TaskGroup
|
||||||
)
|
)
|
||||||
from torba.server.peer import Peer
|
from torba.server.peer import Peer
|
||||||
from torba.server.util import class_logger, protocol_tuple
|
from torba.server.util import class_logger, protocol_tuple
|
||||||
|
@ -194,8 +194,8 @@ class PeerManager:
|
||||||
pause = STALE_SECS - WAKEUP_SECS * 2
|
pause = STALE_SECS - WAKEUP_SECS * 2
|
||||||
else:
|
else:
|
||||||
pause = WAKEUP_SECS * 2 ** peer.try_count
|
pause = WAKEUP_SECS * 2 ** peer.try_count
|
||||||
async with ignore_after(pause):
|
pending, done = await asyncio.wait([peer.retry_event.wait()], timeout=pause)
|
||||||
await peer.retry_event.wait()
|
if done:
|
||||||
peer.retry_event.clear()
|
peer.retry_event.clear()
|
||||||
|
|
||||||
async def _should_drop_peer(self, peer):
|
async def _should_drop_peer(self, peer):
|
||||||
|
@ -224,10 +224,12 @@ class PeerManager:
|
||||||
|
|
||||||
peer_text = f'[{peer}:{port} {kind}]'
|
peer_text = f'[{peer}:{port} {kind}]'
|
||||||
try:
|
try:
|
||||||
async with timeout_after(120 if peer.is_tor else 30):
|
async with Connector(PeerSession, peer.host, port,
|
||||||
async with Connector(PeerSession, peer.host, port,
|
**kwargs) as session:
|
||||||
**kwargs) as session:
|
await asyncio.wait_for(
|
||||||
await self._verify_peer(session, peer)
|
self._verify_peer(session, peer),
|
||||||
|
120 if peer.is_tor else 30
|
||||||
|
)
|
||||||
is_good = True
|
is_good = True
|
||||||
break
|
break
|
||||||
except BadPeerError as e:
|
except BadPeerError as e:
|
||||||
|
@ -237,7 +239,7 @@ class PeerManager:
|
||||||
except RPCError as e:
|
except RPCError as e:
|
||||||
self.logger.error(f'{peer_text} RPC error: {e.message} '
|
self.logger.error(f'{peer_text} RPC error: {e.message} '
|
||||||
f'({e.code})')
|
f'({e.code})')
|
||||||
except (OSError, SOCKSError, ConnectionError, TaskTimeout) as e:
|
except (OSError, SOCKSError, ConnectionError, asyncio.TimeoutError) as e:
|
||||||
self.logger.info(f'{peer_text} {e}')
|
self.logger.info(f'{peer_text} {e}')
|
||||||
|
|
||||||
if is_good:
|
if is_good:
|
||||||
|
|
Loading…
Reference in a new issue