forked from LBRYCommunity/lbry-sdk
completely dropped curio.py
parent 5ed478ed11
commit a6de3a9642
10 changed files with 82 additions and 337 deletions
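
In short: the curio-style TaskGroup vendored in torba/rpc/curio.py is removed. Call sites that spawned a fixed set of coroutines and waited for all of them now use plain asyncio.wait() or asyncio.as_completed(), and long-lived fire-and-forget call sites use a new 24-line TaskGroup added in torba/tasks.py. A minimal sketch of the two replacement patterns (the coroutine names are illustrative, not from the diff; note the diff itself passes bare coroutines to asyncio.wait(), which the Python versions torba targeted at the time accepted):

    import asyncio
    from torba.tasks import TaskGroup

    async def tick(n):
        await asyncio.sleep(0.01)
        return n

    async def main():
        # "run these and wait for all" sites now use plain asyncio
        await asyncio.wait([asyncio.create_task(tick(n)) for n in range(3)])

        # fire-and-forget sites use the new minimal TaskGroup
        group = TaskGroup()
        group.add(tick(99))        # schedules the coroutine, clears group.done
        await group.done.wait()    # done is set once no tasks remain

    asyncio.run(main())
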
@@ -9,6 +9,7 @@ from typing import Dict, Type, Iterable, List, Optional
 from operator import itemgetter
 from collections import namedtuple

+from torba.tasks import TaskGroup
 from torba.client import baseaccount, basenetwork, basetransaction
 from torba.client.basedatabase import BaseDatabase
 from torba.client.baseheader import BaseHeaders

@@ -73,32 +74,6 @@ class TransactionCacheItem:
         self.has_tx.set()


-class SynchronizationMonitor:
-
-    def __init__(self, loop=None):
-        self.done = asyncio.Event()
-        self.tasks = []
-        self.loop = loop or asyncio.get_event_loop()
-
-    def add(self, coro):
-        len(self.tasks) < 1 and self.done.clear()
-        self.loop.create_task(self._monitor(coro))
-
-    def cancel(self):
-        for task in self.tasks:
-            task.cancel()
-
-    async def _monitor(self, coro):
-        task = self.loop.create_task(coro)
-        self.tasks.append(task)
-        log.debug('sync tasks: %s', len(self.tasks))
-        try:
-            await task
-        finally:
-            self.tasks.remove(task)
-            len(self.tasks) < 1 and self.done.set()
-
-
 class BaseLedger(metaclass=LedgerRegistry):

     name: str

@@ -160,7 +135,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         )

         self._tx_cache = {}
-        self.sync = SynchronizationMonitor()
+        self._update_tasks = TaskGroup()
         self._utxo_reservation_lock = asyncio.Lock()
         self._header_processing_lock = asyncio.Lock()
         self._address_update_locks: Dict[str, asyncio.Lock] = {}

@@ -265,11 +240,11 @@ class BaseLedger(metaclass=LedgerRegistry):
         await self.update_headers()
         await self.network.subscribe_headers()
         await self.subscribe_accounts()
-        await self.sync.done.wait()
+        await self._update_tasks.done.wait()

     async def stop(self):
-        self.sync.cancel()
-        await self.sync.done.wait()
+        self._update_tasks.cancel()
+        await self._update_tasks.done.wait()
         await self.network.stop()
         await self.db.close()
         await self.headers.close()

@@ -377,11 +352,11 @@ class BaseLedger(metaclass=LedgerRegistry):

     async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
         remote_status = await self.network.subscribe_address(address)
-        self.sync.add(self.update_history(address, remote_status, address_manager))
+        self._update_tasks.add(self.update_history(address, remote_status, address_manager))

     def process_status_update(self, update):
         address, remote_status = update
-        self.sync.add(self.update_history(address, remote_status))
+        self._update_tasks.add(self.update_history(address, remote_status))

     async def update_history(self, address, remote_status,
                              address_manager: baseaccount.AddressManager = None):

@@ -1,12 +1,10 @@
-from .curio import *
 from .framing import *
 from .jsonrpc import *
 from .socks import *
 from .session import *
 from .util import *

-__all__ = (curio.__all__ +
-           framing.__all__ +
+__all__ = (framing.__all__ +
            jsonrpc.__all__ +
            socks.__all__ +
            session.__all__ +

curio.py (deleted, 200 lines)
@@ -1,200 +0,0 @@
-# The code below is mostly my own but based on the interfaces of the
-# curio library by David Beazley. I'm considering switching to using
-# curio. In the mean-time this is an attempt to provide a similar
-# clean, pure-async interface and move away from direct
-# framework-specific dependencies. As asyncio differs in its design
-# it is not possible to provide identical semantics.
-#
-# The curio library is distributed under the following licence:
-#
-# Copyright (C) 2015-2017
-# David Beazley (Dabeaz LLC)
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright notice,
-#   this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright notice,
-#   this list of conditions and the following disclaimer in the documentation
-#   and/or other materials provided with the distribution.
-# * Neither the name of the David Beazley or Dabeaz LLC may be used to
-#   endorse or promote products derived from this software without
-#   specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import logging
-import asyncio
-from collections import deque
-from contextlib import suppress
-
-
-__all__ = 'TaskGroup',
-
-
-class TaskGroup:
-    '''A class representing a group of executing tasks. tasks is an
-    optional set of existing tasks to put into the group. New tasks
-    can later be added using the spawn() method below. wait specifies
-    the policy used for waiting for tasks. See the join() method
-    below. Each TaskGroup is an independent entity. Task groups do not
-    form a hierarchy or any kind of relationship to other previously
-    created task groups or tasks. Moreover, Tasks created by the top
-    level spawn() function are not placed into any task group. To
-    create a task in a group, it should be created using
-    TaskGroup.spawn() or explicitly added using TaskGroup.add_task().
-
-    completed attribute: the first task that completed with a result
-    in the group. Takes into account the wait option used in the
-    TaskGroup constructor (but not in the join method).
-    '''
-
-    def __init__(self, tasks=(), *, wait=all):
-        if wait not in (any, all, object):
-            raise ValueError('invalid wait argument')
-        self._done = deque()
-        self._pending = set()
-        self._wait = wait
-        self._done_event = asyncio.Event()
-        self._logger = logging.getLogger(self.__class__.__name__)
-        self._closed = False
-        self.completed = None
-        for task in tasks:
-            self._add_task(task)
-
-    def _add_task(self, task):
-        '''Add an already existing task to the task group.'''
-        if hasattr(task, '_task_group'):
-            raise RuntimeError('task is already part of a group')
-        if self._closed:
-            raise RuntimeError('task group is closed')
-        task._task_group = self
-        if task.done():
-            self._done.append(task)
-        else:
-            self._pending.add(task)
-            task.add_done_callback(self._on_done)
-
-    def _on_done(self, task):
-        task._task_group = None
-        self._pending.remove(task)
-        self._done.append(task)
-        self._done_event.set()
-        if self.completed is None:
-            if not task.cancelled() and not task.exception():
-                if self._wait is object and task.result() is None:
-                    pass
-                else:
-                    self.completed = task
-
-    async def spawn(self, coro, *args):
-        '''Create a new task that's part of the group. Returns a Task
-        instance.
-        '''
-        task = asyncio.create_task(coro)
-        self._add_task(task)
-        return task
-
-    async def add_task(self, task):
-        '''Add an already existing task to the task group.'''
-        self._add_task(task)
-
-    async def next_done(self):
-        '''Returns the next completed task. Returns None if no more tasks
-        remain. A TaskGroup may also be used as an asynchronous iterator.
-        '''
-        if not self._done and self._pending:
-            self._done_event.clear()
-            await self._done_event.wait()
-        if self._done:
-            return self._done.popleft()
-        return None
-
-    async def next_result(self):
-        '''Returns the result of the next completed task. If the task failed
-        with an exception, that exception is raised. A RuntimeError
-        exception is raised if this is called when no remaining tasks
-        are available.'''
-        task = await self.next_done()
-        if not task:
-            raise RuntimeError('no tasks remain')
-        return task.result()
-
-    async def join(self):
-        '''Wait for tasks in the group to terminate according to the wait
-        policy for the group.
-
-        If the join() operation itself is cancelled, all remaining
-        tasks in the group are also cancelled.
-
-        If a TaskGroup is used as a context manager, the join() method
-        is called on context-exit.
-
-        Once join() returns, no more tasks may be added to the task
-        group. Tasks can be added while join() is running.
-        '''
-        def errored(task):
-            return not task.cancelled() and task.exception()
-
-        try:
-            if self._wait in (all, object):
-                while True:
-                    task = await self.next_done()
-                    if task is None:
-                        return
-                    if errored(task):
-                        break
-                    if self._wait is object:
-                        if task.cancelled() or task.result() is not None:
-                            return
-            else:  # any
-                task = await self.next_done()
-                if task is None or not errored(task):
-                    return
-        finally:
-            await self.cancel_remaining()

-        if errored(task):
-            raise task.exception()
-
-    async def cancel_remaining(self):
-        '''Cancel all remaining tasks.'''
-        self._closed = True
-        for task in list(self._pending):
-            task.cancel()
-            with suppress(asyncio.CancelledError):
-                await task
-
-    def closed(self):
-        return self._closed
-
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        task = await self.next_done()
-        if task:
-            return task
-        raise StopAsyncIteration
-
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_value, traceback):
-        if exc_type:
-            await self.cancel_remaining()
-        else:
-            await self.join()

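For reference, the class deleted above implemented curio-style wait policies (any, all, object) plus join-on-context-exit, none of which the replacements in this commit reproduce. A small illustration of the wait=object policy against the deleted interface (hypothetical coroutines; this only runs against the pre-commit tree):

    import asyncio
    from torba.rpc.curio import TaskGroup  # gone after this commit

    async def probe(delay, value):
        await asyncio.sleep(delay)
        return value

    async def first_non_none():
        # join() returns once a task finishes with a non-None result;
        # remaining tasks are cancelled and group.completed holds the winner.
        async with TaskGroup(wait=object) as group:
            await group.spawn(probe(0.05, None))   # ignored: result is None
            await group.spawn(probe(0.10, 'hit'))
        return group.completed.result()            # -> 'hit'

    print(asyncio.run(first_non_none()))
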
@@ -34,9 +34,10 @@ import logging
 import time
 from contextlib import suppress

+from torba.tasks import TaskGroup
+
 from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
 from .jsonrpc import RPCError, ProtocolError
-from .curio import TaskGroup
 from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
 from .util import Concurrency

@@ -98,7 +99,7 @@ class SessionBase(asyncio.Protocol):
         self._can_send = Event()
         self._can_send.set()
         self._pm_task = None
-        self._task_group = TaskGroup()
+        self._task_group = TaskGroup(self.loop)
         # Force-close a connection if a send doesn't succeed in this time
         self.max_send_delay = 60
         # Statistics. The RPC object also keeps its own statistics.

@@ -140,20 +141,6 @@ class SessionBase(asyncio.Protocol):
         '''Called when sending or receiving size bytes.'''
         self.bw_charge += size

-    async def _process_messages(self):
-        '''Process incoming messages asynchronously and consume the
-        results.
-        '''
-        async def collect_tasks():
-            next_done = task_group.next_done
-            while True:
-                await next_done()
-
-        task_group = self._task_group
-        async with task_group:
-            await self.spawn(self._receive_messages())
-            await self.spawn(collect_tasks())
-
     async def _limited_wait(self, secs):
         try:
             await asyncio.wait_for(self._can_send.wait(), secs)

@@ -219,7 +206,7 @@ class SessionBase(asyncio.Protocol):
             self._proxy_address = peer_address
         else:
             self._address = peer_address
-        self._pm_task = self.loop.create_task(self._process_messages())
+        self._pm_task = self.loop.create_task(self._receive_messages())

     def connection_lost(self, exc):
         '''Called by asyncio when the connection closes.

@@ -227,6 +214,7 @@ class SessionBase(asyncio.Protocol):
         Tear down things done in connection_made.'''
         self._address = None
         self.transport = None
+        self._task_group.cancel()
         self._pm_task.cancel()
         # Release waiting tasks
         self._can_send.set()

@@ -256,15 +244,6 @@ class SessionBase(asyncio.Protocol):
         else:
             return f'{ip_addr_str}:{port}'

-    async def spawn(self, coro, *args):
-        '''If the session is connected, spawn a task that is cancelled
-        on disconnect, and return it. Otherwise return None.'''
-        group = self._task_group
-        if not group.closed():
-            return await group.spawn(coro, *args)
-        else:
-            return None
-
     def is_closing(self):
         '''Return True if the connection is closing.'''
         return not self.transport or self.transport.is_closing()

@@ -321,7 +300,7 @@ class MessageSession(SessionBase):
             self.recv_count += 1
             if self.recv_count % 10 == 0:
                 await self._update_concurrency()
-            await self.spawn(self._throttled_message(message))
+            await self._task_group.add(self._throttled_message(message))

     async def _throttled_message(self, message):
         '''Process a single request, respecting the concurrency limit.'''

@@ -453,7 +432,7 @@ class RPCSession(SessionBase):
                 self._bump_errors()
             else:
                 for request in requests:
-                    await self.spawn(self._throttled_request(request))
+                    await self._task_group.add(self._throttled_request(request))

     async def _throttled_request(self, request):
         '''Process a single request, respecting the concurrency limit.'''

@@ -28,25 +28,8 @@ __all__ = ()

 import asyncio
 from collections import namedtuple
 from functools import partial
 import inspect
-
-
-def normalize_corofunc(corofunc, args):
-    if asyncio.iscoroutine(corofunc):
-        if args != ():
-            raise ValueError('args cannot be passed with a coroutine')
-        return corofunc
-    return corofunc(*args)
-
-
-def is_async_call(func):
-    '''inspect.iscoroutinefunction that looks through partials.'''
-    while isinstance(func, partial):
-        func = func.func
-    return inspect.iscoroutinefunction(func)
-
-
 # other_params: None means cannot be called with keyword arguments only
 # any means any name is good
 SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '

@@ -110,11 +93,3 @@ class Concurrency(object):
         else:
             for _ in range(-diff):
                 await self.semaphore.acquire()
-
-
-def check_task(logger, task):
-    if not task.cancelled():
-        try:
-            task.result()
-        except Exception:
-            logger.error('task crashed: %r', task, exc_info=True)

@@ -13,8 +13,6 @@ import asyncio
 from struct import pack, unpack
 import time

-from torba.rpc import TaskGroup
-
 import torba
 from torba.server.daemon import DaemonError
 from torba.server.hash import hash_to_hex_str, HASHX_LEN

@@ -651,9 +649,10 @@ class BlockProcessor:
         self._caught_up_event = caught_up_event
         await self._first_open_dbs()
         try:
-            async with TaskGroup() as group:
-                await group.spawn(self.prefetcher.main_loop(self.height))
-                await group.spawn(self._process_prefetched_blocks())
+            await asyncio.wait([
+                self.prefetcher.main_loop(self.height),
+                self._process_prefetched_blocks()
+            ])
         finally:
             # Shut down block processing
             self.logger.info('flushing to DB for a clean shutdown...')

@@ -16,7 +16,6 @@ from collections import defaultdict

 import attr

-from torba.rpc import TaskGroup
 from torba.server.hash import hash_to_hex_str, hex_str_to_hash
 from torba.server.util import class_logger, chunks
 from torba.server.db import UTXO

@@ -235,14 +234,13 @@ class MemPool:
         # Process new transactions
         new_hashes = list(all_hashes.difference(txs))
         if new_hashes:
-            group = TaskGroup()
+            fetches = []
             for hashes in chunks(new_hashes, 200):
-                coro = self._fetch_and_accept(hashes, all_hashes, touched)
-                await group.spawn(coro)
+                fetches.append(self._fetch_and_accept(hashes, all_hashes, touched))
             tx_map = {}
             utxo_map = {}
-            async for task in group:
-                deferred, unspent = task.result()
+            for fetch in asyncio.as_completed(fetches):
+                deferred, unspent = await fetch
                 tx_map.update(deferred)
                 utxo_map.update(unspent)

@@ -306,10 +304,11 @@ class MemPool:

     async def keep_synchronized(self, synchronized_event):
         '''Keep the mempool synchronized with the daemon.'''
-        async with TaskGroup() as group:
-            await group.spawn(self._refresh_hashes(synchronized_event))
-            await group.spawn(self._refresh_histogram(synchronized_event))
-            await group.spawn(self._logging(synchronized_event))
+        await asyncio.wait([
+            self._refresh_hashes(synchronized_event),
+            self._refresh_histogram(synchronized_event),
+            self._logging(synchronized_event)
+        ])

     async def balance_delta(self, hashX):
         '''Return the unconfirmed amount in the mempool for hashX.

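One behavioral note on the asyncio.wait() rewrites in these server hunks: unlike the dropped TaskGroup.join(), asyncio.wait() neither re-raises a failed task's exception nor cancels the sibling tasks; it just returns (done, pending) sets. A brief runnable illustration (not from the diff):

    import asyncio

    async def boom():
        raise RuntimeError('fail')

    async def steady():
        await asyncio.sleep(0.05)
        return 'ok'

    async def main():
        done, _pending = await asyncio.wait([
            asyncio.create_task(boom()),
            asyncio.create_task(steady()),
        ])
        for task in done:
            # exceptions must be inspected by hand, or they surface only
            # as "exception was never retrieved" warnings at shutdown
            if task.exception():
                print('task failed:', task.exception())

    asyncio.run(main())
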
@@ -14,9 +14,10 @@ import ssl
 import time
 from collections import defaultdict, Counter

+from torba.tasks import TaskGroup
 from torba.rpc import (
     Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
-    SOCKSError, RPCError, TaskGroup
+    SOCKSError, RPCError
 )
 from torba.server.peer import Peer
 from torba.server.util import class_logger, protocol_tuple

@@ -179,7 +180,7 @@ class PeerManager:
         self.logger.info(f'accepted new peer {peer} from {source}')
         peer.retry_event = asyncio.Event()
         self.peers.add(peer)
-        await self.group.spawn(self._monitor_peer(peer))
+        await self.group.add(self._monitor_peer(peer))

     async def _monitor_peer(self, peer):
         # Stop monitoring if we were dropped (a duplicate peer)

@@ -292,10 +293,11 @@ class PeerManager:
         peer.features['server_version'] = server_version
         ptuple = protocol_tuple(protocol_version)

-        async with TaskGroup() as g:
-            await g.spawn(self._send_headers_subscribe(session, peer, ptuple))
-            await g.spawn(self._send_server_features(session, peer))
-            await g.spawn(self._send_peers_subscribe(session, peer))
+        await asyncio.wait([
+            self._send_headers_subscribe(session, peer, ptuple),
+            self._send_server_features(session, peer),
+            self._send_peers_subscribe(session, peer)
+        ])

     async def _send_headers_subscribe(self, session, peer, ptuple):
         message = 'blockchain.headers.subscribe'

@@ -387,18 +389,9 @@ class PeerManager:

         self.logger.info(f'beginning peer discovery. Force use of '
                          f'proxy: {self.env.force_proxy}')
-        forever = asyncio.Event()
-        async with self.group as group:
-            await group.spawn(forever.wait())
-            await group.spawn(self._detect_proxy())
-            await group.spawn(self._import_peers())
-            # Consume tasks as they complete, logging unexpected failures
-            async for task in group:
-                if not task.cancelled():
-                    try:
-                        task.result()
-                    except Exception:
-                        self.logger.exception('task failed unexpectedly')
+        self.group.add(self._detect_proxy())
+        self.group.add(self._import_peers())

     def info(self):
         '''The number of peers.'''

@@ -22,7 +22,7 @@ from functools import partial
 import torba
 from torba.rpc import (
     RPCSession, JSONRPCAutoDetect, JSONRPCConnection,
-    TaskGroup, handler_invocation, RPCError, Request
+    handler_invocation, RPCError, Request
 )
 from torba.server import text
 from torba.server import util

@@ -257,9 +257,10 @@ class SessionManager:
                             for session in stale_sessions)
             self.logger.info(f'closing stale connections {text}')
             # Give the sockets some time to close gracefully
-            async with TaskGroup() as group:
-                for session in stale_sessions:
-                    await group.spawn(session.close())
+            if stale_sessions:
+                await asyncio.wait([
+                    session.close() for session in stale_sessions
+                ])

         # Consolidate small groups
         bw_limit = self.env.bandwidth_limit

@@ -499,17 +500,18 @@ class SessionManager:
             server_listening_event.set()
             # Peer discovery should start after the external servers
             # because we connect to ourself
-            async with TaskGroup() as group:
-                await group.spawn(self.peer_mgr.discover_peers())
-                await group.spawn(self._clear_stale_sessions())
-                await group.spawn(self._log_sessions())
-                await group.spawn(self._manage_servers())
+            await asyncio.wait([
+                self.peer_mgr.discover_peers(),
+                self._clear_stale_sessions(),
+                self._log_sessions(),
+                self._manage_servers()
+            ])
         finally:
             # Close servers and sessions
             await self._close_servers(list(self.servers.keys()))
-            async with TaskGroup() as group:
-                for session in list(self.sessions):
-                    await group.spawn(session.close(force_after=1))
+            if self.sessions:
+                await asyncio.wait([
+                    session.close(force_after=1) for session in self.sessions
+                ])

     def session_count(self):
         '''The number of connections that we've sent something to.'''

@@ -562,9 +564,10 @@ class SessionManager:
         for hashX in set(hc).intersection(touched):
             del hc[hashX]

-        async with TaskGroup() as group:
-            for session in self.sessions:
-                await group.spawn(session.notify(touched, height_changed))
+        if self.sessions:
+            await asyncio.wait([
+                session.notify(touched, height_changed) for session in self.sessions
+            ])

     def add_session(self, session):
         self.sessions.add(session)

torba/tasks.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from asyncio import Event, get_event_loop
+
+
+class TaskGroup:
+
+    def __init__(self, loop=None):
+        self._loop = loop or get_event_loop()
+        self._tasks = set()
+        self.done = Event()
+
+    def add(self, coro):
+        task = self._loop.create_task(coro)
+        self._tasks.add(task)
+        self.done.clear()
+        task.add_done_callback(self._remove)
+        return task
+
+    def _remove(self, task):
+        self._tasks.remove(task)
+        len(self._tasks) < 1 and self.done.set()
+
+    def cancel(self):
+        for task in self._tasks:
+            task.cancel()
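
The replacement group is intentionally tiny: add() schedules a coroutine and tracks the resulting task, done is an Event that is set whenever the group drains, and cancel() cancels whatever is still pending. A usage sketch (the coroutine name is illustrative):

    import asyncio
    from torba.tasks import TaskGroup

    async def update_address(address):
        await asyncio.sleep(0.01)      # stand-in for real sync work
        return address

    async def main():
        group = TaskGroup()
        for address in ('addr1', 'addr2'):
            group.add(update_address(address))   # each add() clears group.done
        await group.done.wait()                  # set once the group is empty

    asyncio.run(main())

Unlike the deleted curio-style group, this one never retrieves task results or exceptions, so error handling stays with the code that calls add().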