completely dropped curio.py

This commit is contained in:
Lex Berezhny 2018-12-06 20:03:22 -05:00
parent 5ed478ed11
commit a6de3a9642
10 changed files with 82 additions and 337 deletions

View file

@ -9,6 +9,7 @@ from typing import Dict, Type, Iterable, List, Optional
from operator import itemgetter
from collections import namedtuple
from torba.tasks import TaskGroup
from torba.client import baseaccount, basenetwork, basetransaction
from torba.client.basedatabase import BaseDatabase
from torba.client.baseheader import BaseHeaders
@ -73,32 +74,6 @@ class TransactionCacheItem:
self.has_tx.set()
class SynchronizationMonitor:
def __init__(self, loop=None):
self.done = asyncio.Event()
self.tasks = []
self.loop = loop or asyncio.get_event_loop()
def add(self, coro):
len(self.tasks) < 1 and self.done.clear()
self.loop.create_task(self._monitor(coro))
def cancel(self):
for task in self.tasks:
task.cancel()
async def _monitor(self, coro):
task = self.loop.create_task(coro)
self.tasks.append(task)
log.debug('sync tasks: %s', len(self.tasks))
try:
await task
finally:
self.tasks.remove(task)
len(self.tasks) < 1 and self.done.set()
class BaseLedger(metaclass=LedgerRegistry):
name: str
@ -160,7 +135,7 @@ class BaseLedger(metaclass=LedgerRegistry):
)
self._tx_cache = {}
self.sync = SynchronizationMonitor()
self._update_tasks = TaskGroup()
self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock()
self._address_update_locks: Dict[str, asyncio.Lock] = {}
@ -265,11 +240,11 @@ class BaseLedger(metaclass=LedgerRegistry):
await self.update_headers()
await self.network.subscribe_headers()
await self.subscribe_accounts()
await self.sync.done.wait()
await self._update_tasks.done.wait()
async def stop(self):
self.sync.cancel()
await self.sync.done.wait()
self._update_tasks.cancel()
await self._update_tasks.done.wait()
await self.network.stop()
await self.db.close()
await self.headers.close()
@ -377,11 +352,11 @@ class BaseLedger(metaclass=LedgerRegistry):
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
remote_status = await self.network.subscribe_address(address)
self.sync.add(self.update_history(address, remote_status, address_manager))
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
def process_status_update(self, update):
address, remote_status = update
self.sync.add(self.update_history(address, remote_status))
self._update_tasks.add(self.update_history(address, remote_status))
async def update_history(self, address, remote_status,
address_manager: baseaccount.AddressManager = None):

View file

@ -1,12 +1,10 @@
from .curio import *
from .framing import *
from .jsonrpc import *
from .socks import *
from .session import *
from .util import *
__all__ = (curio.__all__ +
framing.__all__ +
__all__ = (framing.__all__ +
jsonrpc.__all__ +
socks.__all__ +
session.__all__ +

View file

@ -1,200 +0,0 @@
# The code below is mostly my own but based on the interfaces of the
# curio library by David Beazley. I'm considering switching to using
# curio. In the mean-time this is an attempt to provide a similar
# clean, pure-async interface and move away from direct
# framework-specific dependencies. As asyncio differs in its design
# it is not possible to provide identical semantics.
#
# The curio library is distributed under the following licence:
#
# Copyright (C) 2015-2017
# David Beazley (Dabeaz LLC)
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the David Beazley or Dabeaz LLC may be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
import asyncio
from collections import deque
from contextlib import suppress
__all__ = 'TaskGroup',
class TaskGroup:
'''A class representing a group of executing tasks. tasks is an
optional set of existing tasks to put into the group. New tasks
can later be added using the spawn() method below. wait specifies
the policy used for waiting for tasks. See the join() method
below. Each TaskGroup is an independent entity. Task groups do not
form a hierarchy or any kind of relationship to other previously
created task groups or tasks. Moreover, Tasks created by the top
level spawn() function are not placed into any task group. To
create a task in a group, it should be created using
TaskGroup.spawn() or explicitly added using TaskGroup.add_task().
completed attribute: the first task that completed with a result
in the group. Takes into account the wait option used in the
TaskGroup constructor (but not in the join method)`.
'''
def __init__(self, tasks=(), *, wait=all):
if wait not in (any, all, object):
raise ValueError('invalid wait argument')
self._done = deque()
self._pending = set()
self._wait = wait
self._done_event = asyncio.Event()
self._logger = logging.getLogger(self.__class__.__name__)
self._closed = False
self.completed = None
for task in tasks:
self._add_task(task)
def _add_task(self, task):
'''Add an already existing task to the task group.'''
if hasattr(task, '_task_group'):
raise RuntimeError('task is already part of a group')
if self._closed:
raise RuntimeError('task group is closed')
task._task_group = self
if task.done():
self._done.append(task)
else:
self._pending.add(task)
task.add_done_callback(self._on_done)
def _on_done(self, task):
task._task_group = None
self._pending.remove(task)
self._done.append(task)
self._done_event.set()
if self.completed is None:
if not task.cancelled() and not task.exception():
if self._wait is object and task.result() is None:
pass
else:
self.completed = task
async def spawn(self, coro, *args):
'''Create a new task thats part of the group. Returns a Task
instance.
'''
task = asyncio.create_task(coro)
self._add_task(task)
return task
async def add_task(self, task):
'''Add an already existing task to the task group.'''
self._add_task(task)
async def next_done(self):
'''Returns the next completed task. Returns None if no more tasks
remain. A TaskGroup may also be used as an asynchronous iterator.
'''
if not self._done and self._pending:
self._done_event.clear()
await self._done_event.wait()
if self._done:
return self._done.popleft()
return None
async def next_result(self):
'''Returns the result of the next completed task. If the task failed
with an exception, that exception is raised. A RuntimeError
exception is raised if this is called when no remaining tasks
are available.'''
task = await self.next_done()
if not task:
raise RuntimeError('no tasks remain')
return task.result()
async def join(self):
'''Wait for tasks in the group to terminate according to the wait
policy for the group.
If the join() operation itself is cancelled, all remaining
tasks in the group are also cancelled.
If a TaskGroup is used as a context manager, the join() method
is called on context-exit.
Once join() returns, no more tasks may be added to the task
group. Tasks can be added while join() is running.
'''
def errored(task):
return not task.cancelled() and task.exception()
try:
if self._wait in (all, object):
while True:
task = await self.next_done()
if task is None:
return
if errored(task):
break
if self._wait is object:
if task.cancelled() or task.result() is not None:
return
else: # any
task = await self.next_done()
if task is None or not errored(task):
return
finally:
await self.cancel_remaining()
if errored(task):
raise task.exception()
async def cancel_remaining(self):
'''Cancel all remaining tasks.'''
self._closed = True
for task in list(self._pending):
task.cancel()
with suppress(asyncio.CancelledError):
await task
def closed(self):
return self._closed
def __aiter__(self):
return self
async def __anext__(self):
task = await self.next_done()
if task:
return task
raise StopAsyncIteration
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if exc_type:
await self.cancel_remaining()
else:
await self.join()

View file

@ -34,9 +34,10 @@ import logging
import time
from contextlib import suppress
from torba.tasks import TaskGroup
from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
from .jsonrpc import RPCError, ProtocolError
from .curio import TaskGroup
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
from .util import Concurrency
@ -98,7 +99,7 @@ class SessionBase(asyncio.Protocol):
self._can_send = Event()
self._can_send.set()
self._pm_task = None
self._task_group = TaskGroup()
self._task_group = TaskGroup(self.loop)
# Force-close a connection if a send doesn't succeed in this time
self.max_send_delay = 60
# Statistics. The RPC object also keeps its own statistics.
@ -140,20 +141,6 @@ class SessionBase(asyncio.Protocol):
'''Called when sending or receiving size bytes.'''
self.bw_charge += size
async def _process_messages(self):
'''Process incoming messages asynchronously and consume the
results.
'''
async def collect_tasks():
next_done = task_group.next_done
while True:
await next_done()
task_group = self._task_group
async with task_group:
await self.spawn(self._receive_messages())
await self.spawn(collect_tasks())
async def _limited_wait(self, secs):
try:
await asyncio.wait_for(self._can_send.wait(), secs)
@ -219,7 +206,7 @@ class SessionBase(asyncio.Protocol):
self._proxy_address = peer_address
else:
self._address = peer_address
self._pm_task = self.loop.create_task(self._process_messages())
self._pm_task = self.loop.create_task(self._receive_messages())
def connection_lost(self, exc):
'''Called by asyncio when the connection closes.
@ -227,6 +214,7 @@ class SessionBase(asyncio.Protocol):
Tear down things done in connection_made.'''
self._address = None
self.transport = None
self._task_group.cancel()
self._pm_task.cancel()
# Release waiting tasks
self._can_send.set()
@ -256,15 +244,6 @@ class SessionBase(asyncio.Protocol):
else:
return f'{ip_addr_str}:{port}'
async def spawn(self, coro, *args):
'''If the session is connected, spawn a task that is cancelled
on disconnect, and return it. Otherwise return None.'''
group = self._task_group
if not group.closed():
return await group.spawn(coro, *args)
else:
return None
def is_closing(self):
'''Return True if the connection is closing.'''
return not self.transport or self.transport.is_closing()
@ -321,7 +300,7 @@ class MessageSession(SessionBase):
self.recv_count += 1
if self.recv_count % 10 == 0:
await self._update_concurrency()
await self.spawn(self._throttled_message(message))
await self._task_group.add(self._throttled_message(message))
async def _throttled_message(self, message):
'''Process a single request, respecting the concurrency limit.'''
@ -453,7 +432,7 @@ class RPCSession(SessionBase):
self._bump_errors()
else:
for request in requests:
await self.spawn(self._throttled_request(request))
await self._task_group.add(self._throttled_request(request))
async def _throttled_request(self, request):
'''Process a single request, respecting the concurrency limit.'''

View file

@ -28,25 +28,8 @@ __all__ = ()
import asyncio
from collections import namedtuple
from functools import partial
import inspect
def normalize_corofunc(corofunc, args):
if asyncio.iscoroutine(corofunc):
if args != ():
raise ValueError('args cannot be passed with a coroutine')
return corofunc
return corofunc(*args)
def is_async_call(func):
'''inspect.iscoroutinefunction that looks through partials.'''
while isinstance(func, partial):
func = func.func
return inspect.iscoroutinefunction(func)
# other_params: None means cannot be called with keyword arguments only
# any means any name is good
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
@ -110,11 +93,3 @@ class Concurrency(object):
else:
for _ in range(-diff):
await self.semaphore.acquire()
def check_task(logger, task):
if not task.cancelled():
try:
task.result()
except Exception:
logger.error('task crashed: %r', task, exc_info=True)

View file

@ -13,8 +13,6 @@ import asyncio
from struct import pack, unpack
import time
from torba.rpc import TaskGroup
import torba
from torba.server.daemon import DaemonError
from torba.server.hash import hash_to_hex_str, HASHX_LEN
@ -651,9 +649,10 @@ class BlockProcessor:
self._caught_up_event = caught_up_event
await self._first_open_dbs()
try:
async with TaskGroup() as group:
await group.spawn(self.prefetcher.main_loop(self.height))
await group.spawn(self._process_prefetched_blocks())
await asyncio.wait([
self.prefetcher.main_loop(self.height),
self._process_prefetched_blocks()
])
finally:
# Shut down block processing
self.logger.info('flushing to DB for a clean shutdown...')

View file

@ -16,7 +16,6 @@ from collections import defaultdict
import attr
from torba.rpc import TaskGroup
from torba.server.hash import hash_to_hex_str, hex_str_to_hash
from torba.server.util import class_logger, chunks
from torba.server.db import UTXO
@ -235,14 +234,13 @@ class MemPool:
# Process new transactions
new_hashes = list(all_hashes.difference(txs))
if new_hashes:
group = TaskGroup()
fetches = []
for hashes in chunks(new_hashes, 200):
coro = self._fetch_and_accept(hashes, all_hashes, touched)
await group.spawn(coro)
fetches.append(self._fetch_and_accept(hashes, all_hashes, touched))
tx_map = {}
utxo_map = {}
async for task in group:
deferred, unspent = task.result()
for fetch in asyncio.as_completed(fetches):
deferred, unspent = await fetch
tx_map.update(deferred)
utxo_map.update(unspent)
@ -306,10 +304,11 @@ class MemPool:
async def keep_synchronized(self, synchronized_event):
'''Keep the mempool synchronized with the daemon.'''
async with TaskGroup() as group:
await group.spawn(self._refresh_hashes(synchronized_event))
await group.spawn(self._refresh_histogram(synchronized_event))
await group.spawn(self._logging(synchronized_event))
await asyncio.wait([
self._refresh_hashes(synchronized_event),
self._refresh_histogram(synchronized_event),
self._logging(synchronized_event)
])
async def balance_delta(self, hashX):
'''Return the unconfirmed amount in the mempool for hashX.

View file

@ -14,9 +14,10 @@ import ssl
import time
from collections import defaultdict, Counter
from torba.tasks import TaskGroup
from torba.rpc import (
Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
SOCKSError, RPCError, TaskGroup
SOCKSError, RPCError
)
from torba.server.peer import Peer
from torba.server.util import class_logger, protocol_tuple
@ -179,7 +180,7 @@ class PeerManager:
self.logger.info(f'accepted new peer {peer} from {source}')
peer.retry_event = asyncio.Event()
self.peers.add(peer)
await self.group.spawn(self._monitor_peer(peer))
await self.group.add(self._monitor_peer(peer))
async def _monitor_peer(self, peer):
# Stop monitoring if we were dropped (a duplicate peer)
@ -292,10 +293,11 @@ class PeerManager:
peer.features['server_version'] = server_version
ptuple = protocol_tuple(protocol_version)
async with TaskGroup() as g:
await g.spawn(self._send_headers_subscribe(session, peer, ptuple))
await g.spawn(self._send_server_features(session, peer))
await g.spawn(self._send_peers_subscribe(session, peer))
await asyncio.wait([
self._send_headers_subscribe(session, peer, ptuple),
self._send_server_features(session, peer),
self._send_peers_subscribe(session, peer)
])
async def _send_headers_subscribe(self, session, peer, ptuple):
message = 'blockchain.headers.subscribe'
@ -387,18 +389,9 @@ class PeerManager:
self.logger.info(f'beginning peer discovery. Force use of '
f'proxy: {self.env.force_proxy}')
forever = asyncio.Event()
async with self.group as group:
await group.spawn(forever.wait())
await group.spawn(self._detect_proxy())
await group.spawn(self._import_peers())
# Consume tasks as they complete, logging unexpected failures
async for task in group:
if not task.cancelled():
try:
task.result()
except Exception:
self.logger.exception('task failed unexpectedly')
self.group.add(self._detect_proxy())
self.group.add(self._import_peers())
def info(self):
'''The number of peers.'''

View file

@ -22,7 +22,7 @@ from functools import partial
import torba
from torba.rpc import (
RPCSession, JSONRPCAutoDetect, JSONRPCConnection,
TaskGroup, handler_invocation, RPCError, Request
handler_invocation, RPCError, Request
)
from torba.server import text
from torba.server import util
@ -257,9 +257,10 @@ class SessionManager:
for session in stale_sessions)
self.logger.info(f'closing stale connections {text}')
# Give the sockets some time to close gracefully
async with TaskGroup() as group:
for session in stale_sessions:
await group.spawn(session.close())
if stale_sessions:
await asyncio.wait([
session.close() for session in stale_sessions
])
# Consolidate small groups
bw_limit = self.env.bandwidth_limit
@ -499,17 +500,18 @@ class SessionManager:
server_listening_event.set()
# Peer discovery should start after the external servers
# because we connect to ourself
async with TaskGroup() as group:
await group.spawn(self.peer_mgr.discover_peers())
await group.spawn(self._clear_stale_sessions())
await group.spawn(self._log_sessions())
await group.spawn(self._manage_servers())
await asyncio.wait([
self.peer_mgr.discover_peers(),
self._clear_stale_sessions(),
self._log_sessions(),
self._manage_servers()
])
finally:
# Close servers and sessions
await self._close_servers(list(self.servers.keys()))
async with TaskGroup() as group:
for session in list(self.sessions):
await group.spawn(session.close(force_after=1))
if self.sessions:
await asyncio.wait([
session.close(force_after=1) for session in self.sessions
])
def session_count(self):
'''The number of connections that we've sent something to.'''
@ -562,9 +564,10 @@ class SessionManager:
for hashX in set(hc).intersection(touched):
del hc[hashX]
async with TaskGroup() as group:
for session in self.sessions:
await group.spawn(session.notify(touched, height_changed))
if self.sessions:
await asyncio.wait([
session.notify(touched, height_changed) for session in self.sessions
])
def add_session(self, session):
self.sessions.add(session)

24
torba/tasks.py Normal file
View file

@ -0,0 +1,24 @@
from asyncio import Event, get_event_loop
class TaskGroup:
def __init__(self, loop=None):
self._loop = loop or get_event_loop()
self._tasks = set()
self.done = Event()
def add(self, coro):
task = self._loop.create_task(coro)
self._tasks.add(task)
self.done.clear()
task.add_done_callback(self._remove)
return task
def _remove(self, task):
self._tasks.remove(task)
len(self._tasks) < 1 and self.done.set()
def cancel(self):
for task in self._tasks:
task.cancel()