removing dependence on curio abstraction

Lex Berezhny 2018-12-06 14:27:38 -05:00
parent f4ec20a2e2
commit b93f9d4c94
7 changed files with 41 additions and 90 deletions

View file

@@ -38,10 +38,6 @@
import logging
import asyncio
from asyncio import (
CancelledError, get_event_loop, Queue, Event, Lock, Semaphore,
sleep, Task
)
from collections import deque
from contextlib import suppress
from functools import partial
@@ -50,32 +46,22 @@ from .util import normalize_corofunc, check_task
__all__ = (
'Queue', 'Event', 'Lock', 'Semaphore', 'sleep', 'CancelledError',
'run_in_thread', 'spawn', 'spawn_sync', 'TaskGroup',
'spawn_sync', 'TaskGroup',
'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
'timeout_after', 'timeout_at', 'ignore_after', 'ignore_at',
'timeout_after', 'ignore_after'
)
async def run_in_thread(func, *args):
'''Run a function in a separate thread, and await its completion.'''
return await get_event_loop().run_in_executor(None, func, *args)
async def spawn(coro, *args, loop=None, report_crash=True):
return spawn_sync(coro, *args, loop=loop, report_crash=report_crash)
def spawn_sync(coro, *args, loop=None, report_crash=True):
coro = normalize_corofunc(coro, args)
loop = loop or get_event_loop()
loop = loop or asyncio.get_event_loop()
task = loop.create_task(coro)
if report_crash:
task.add_done_callback(partial(check_task, logging))
return task
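The removed run_in_thread() was a thin wrapper over the loop's default executor, and the awaitable spawn() only deferred to spawn_sync(). A minimal sketch of the replacements callers now use directly (cpu_bound and demo are illustrative names, not part of the commit):

import asyncio

def cpu_bound(n):
    # stand-in for blocking work formerly handed to run_in_thread()
    return sum(i * i for i in range(n))

async def demo():
    loop = asyncio.get_event_loop()
    # run_in_thread(func, *args)  ->  loop.run_in_executor(None, func, *args)
    result = await loop.run_in_executor(None, cpu_bound, 10_000)
    # await spawn(coro)  ->  loop.create_task(coro), which spawn_sync() wraps
    task = loop.create_task(asyncio.sleep(0.1))
    await task
    return result

asyncio.get_event_loop().run_until_complete(demo())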
class TaskGroup(object):
class TaskGroup:
'''A class representing a group of executing tasks. tasks is an
optional set of existing tasks to put into the group. New tasks
can later be added using the spawn() method below. wait specifies
@@ -98,7 +84,7 @@ class TaskGroup(object):
self._done = deque()
self._pending = set()
self._wait = wait
self._done_event = Event()
self._done_event = asyncio.Event()
self._logger = logging.getLogger(self.__class__.__name__)
self._closed = False
self.completed = None
@@ -134,7 +120,7 @@ class TaskGroup(object):
'''Create a new task that's part of the group. Returns a Task
instance.
'''
task = await spawn(coro, *args, report_crash=False)
task = spawn_sync(coro, *args, report_crash=False)
self._add_task(task)
return task
@@ -205,7 +191,7 @@ class TaskGroup(object):
self._closed = True
for task in list(self._pending):
task.cancel()
with suppress(CancelledError):
with suppress(asyncio.CancelledError):
await task
def closed(self):
@@ -230,7 +216,7 @@ class TaskGroup(object):
await self.join()
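TaskGroup's public surface is untouched by this commit; only its internals now name asyncio explicitly. A usage sketch under that assumption (torba.rpc re-exports TaskGroup, as the server modules below import it):

import asyncio
from torba.rpc import TaskGroup

async def demo():
    async with TaskGroup() as group:
        # spawn() is a coroutine and returns an asyncio.Task in the group
        await group.spawn(asyncio.sleep, 0.1)
        await group.spawn(asyncio.sleep, 0.2)
    # leaving the block joins the group according to its wait policy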
class TaskTimeout(CancelledError):
class TaskTimeout(asyncio.CancelledError):
def __init__(self, secs):
self.secs = secs
@@ -239,7 +225,7 @@ class TaskTimeout(CancelledError):
return f'task timed out after {self.args[0]}s'
class TimeoutCancellationError(CancelledError):
class TimeoutCancellationError(asyncio.CancelledError):
pass
@@ -302,7 +288,7 @@ class TimeoutAfter(object):
async def __aexit__(self, exc_type, exc_value, traceback):
timed_out_deadline, uncaught = _unset_task_deadline(self._task)
if exc_type not in (CancelledError, TaskTimeout,
if exc_type not in (asyncio.CancelledError, TaskTimeout,
TimeoutCancellationError):
return False
if timed_out_deadline == self._deadline:
@@ -311,7 +297,7 @@ class TimeoutAfter(object):
return True
raise TaskTimeout(self._secs) from None
if timed_out_deadline is None:
assert exc_type is CancelledError
assert exc_type is asyncio.CancelledError
return False
if uncaught:
raise UncaughtTimeoutError('uncaught timeout received')
@@ -348,28 +334,6 @@ def timeout_after(seconds, coro=None, *args):
return TimeoutAfter(seconds)
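Both calling conventions of timeout_after() survive; only the absolute-clock variants below are dropped. A hedged sketch of the two surviving forms (slow and demo are illustrative):

import asyncio
from torba.rpc import timeout_after, TaskTimeout

async def slow():
    await asyncio.sleep(5)

async def demo():
    # functional form: run slow() under a 1-second deadline
    try:
        await timeout_after(1, slow)
    except TaskTimeout:
        pass
    # context-manager form: apply the deadline to a block of statements
    try:
        async with timeout_after(1):
            await slow()
    except TaskTimeout:
        pass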
def timeout_at(clock, coro=None, *args):
'''Execute the specified coroutine and return its result. However,
issue a cancellation request to the calling task once the absolute
clock value is reached. When this happens, a TaskTimeout exception
is raised. If coro is None, the result of this function serves
as an asynchronous context manager that applies a timeout to a
block of statements.
timeout_at() may be composed with other timeout operations
(i.e., nested timeouts). If an outer timeout expires first, then
TimeoutCancellationError is raised instead of TaskTimeout. If an
inner timeout expires and fails to properly raise TaskTimeout, an
UncaughtTimeoutError is raised in the outer timeout.
'''
if coro:
return _timeout_after_func(clock, True, coro, args)
return TimeoutAfter(clock, absolute=True)
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
coro = normalize_corofunc(coro, args)
async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
@@ -398,14 +362,3 @@ def ignore_after(seconds, coro=None, *args, timeout_result=None):
return _ignore_after_func(seconds, False, coro, args, timeout_result)
return TimeoutAfter(seconds, ignore=True)
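ignore_after() swallows the timeout instead of raising, yielding timeout_result. A minimal sketch:

import asyncio
from torba.rpc import ignore_after

async def demo():
    # asyncio.sleep(5) outlives the 1s limit, so timeout_result (None) returns
    result = await ignore_after(1, asyncio.sleep, 5)
    assert result is None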
def ignore_at(clock, coro=None, *args, timeout_result=None):
'''
Stop the enclosed task or block of code at an absolute
clock value. Same usage as ignore_after().
'''
if coro:
return _ignore_after_func(clock, True, coro, args, timeout_result)
return TimeoutAfter(clock, absolute=True, ignore=True)

View file

@@ -9,13 +9,11 @@
'''Block prefetcher and chain processor.'''
import array
import asyncio
from struct import pack, unpack
import time
from functools import partial
from torba.rpc import TaskGroup, run_in_thread
from torba.rpc import TaskGroup
import torba
from torba.server.daemon import DaemonError
@@ -187,7 +185,7 @@ class BlockProcessor:
# consistent and not being updated elsewhere.
async def run_in_thread_locked():
async with self.state_lock:
return await run_in_thread(func, *args)
return await asyncio.get_event_loop().run_in_executor(None, func, *args)
return await asyncio.shield(run_in_thread_locked())
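The asyncio.shield() around run_in_thread_locked() is what keeps this change safe: a caller's cancellation cannot land mid-flush while the worker thread holds the state lock. A generic sketch of the pattern (locked_executor_call is an illustrative helper, not part of the commit):

import asyncio

async def locked_executor_call(lock, func, *args):
    async def inner():
        async with lock:
            # run the blocking call in a thread while holding the async lock
            return await asyncio.get_event_loop().run_in_executor(
                None, func, *args)
    # shield so outer cancellation cannot interrupt the locked section
    return await asyncio.shield(inner())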
async def check_and_advance_blocks(self, raw_blocks):

View file

@@ -9,6 +9,7 @@
'''Interface to the blockchain database.'''
import asyncio
import array
import ast
import os
@@ -20,7 +21,6 @@ from struct import pack, unpack
import attr
from torba.rpc import run_in_thread, sleep
from torba.server import util
from torba.server.hash import hash_to_hex_str, HASHX_LEN
from torba.server.merkle import Merkle, MerkleCache
@@ -403,7 +403,7 @@ class DB:
return self.headers_file.read(offset, size), disk_count
return b'', 0
return await run_in_thread(read_headers)
return await asyncio.get_event_loop().run_in_executor(None, read_headers)
def fs_tx_hash(self, tx_num):
'''Return a pair (tx_hash, tx_height) for the given tx number.
@@ -443,12 +443,12 @@ class DB:
return [fs_tx_hash(tx_num) for tx_num in tx_nums]
while True:
history = await run_in_thread(read_history)
history = await asyncio.get_event_loop().run_in_executor(None, read_history)
if all(hash is not None for hash, height in history):
return history
self.logger.warning(f'limited_history: tx hash '
f'not found (reorg?), retrying...')
await sleep(0.25)
await asyncio.sleep(0.25)
# -- Undo information
@@ -612,12 +612,12 @@ class DB:
return utxos
while True:
utxos = await run_in_thread(read_utxos)
utxos = await asyncio.get_event_loop().run_in_executor(None, read_utxos)
if all(utxo.tx_hash is not None for utxo in utxos):
return utxos
self.logger.warning(f'all_utxos: tx hash not '
f'found (reorg?), retrying...')
await sleep(0.25)
await asyncio.sleep(0.25)
async def lookup_utxos(self, prevouts):
'''For each prevout, look it up in the DB and return a (hashX,
@@ -665,5 +665,5 @@ class DB:
return hashX, value
return [lookup_utxo(*hashX_pair) for hashX_pair in hashX_pairs]
hashX_pairs = await run_in_thread(lookup_hashXs)
return await run_in_thread(lookup_utxos, hashX_pairs)
hashX_pairs = await asyncio.get_event_loop().run_in_executor(None, lookup_hashXs)
return await asyncio.get_event_loop().run_in_executor(None, lookup_utxos, hashX_pairs)
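Every call site in this file applies the same mechanical rewrite; the removed torba.rpc helper was equivalent to this one-liner, reproduced only to make the substitution explicit:

import asyncio

async def run_in_thread(func, *args):
    # body of the removed helper: run func in the loop's default executor
    return await asyncio.get_event_loop().run_in_executor(None, func, *args)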

View file

@@ -7,6 +7,7 @@
'''Mempool handling.'''
import asyncio
import itertools
import time
from abc import ABC, abstractmethod
@@ -15,7 +16,7 @@ from collections import defaultdict
import attr
from torba.rpc import TaskGroup, run_in_thread, sleep
from torba.rpc import TaskGroup
from torba.server.hash import hash_to_hex_str, hex_str_to_hash
from torba.server.util import class_logger, chunks
from torba.server.db import UTXO
@@ -117,7 +118,7 @@ class MemPool:
while True:
self.logger.info(f'{len(self.txs):,d} txs '
f'touching {len(self.hashXs):,d} addresses')
await sleep(self.log_status_secs)
await asyncio.sleep(self.log_status_secs)
await synchronized_event.wait()
async def _refresh_histogram(self, synchronized_event):
@@ -125,8 +126,8 @@ class MemPool:
await synchronized_event.wait()
async with self.lock:
# Threaded as can be expensive
await run_in_thread(self._update_histogram, 100_000)
await sleep(self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS)
await asyncio.get_event_loop().run_in_executor(None, self._update_histogram, 100_000)
await asyncio.sleep(self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS)
def _update_histogram(self, bin_size):
# Build a histogram by fee rate
@@ -212,7 +213,7 @@ class MemPool:
synchronized_event.set()
synchronized_event.clear()
await self.api.on_mempool(touched, height)
await sleep(self.refresh_secs)
await asyncio.sleep(self.refresh_secs)
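The set()-then-clear() pair in the refresh loop wakes every coroutine currently blocked in synchronized_event.wait() exactly once per cycle; clear() re-arms the event without un-waking them. A minimal sketch of that handshake (producer and consumer are illustrative names):

import asyncio

async def producer(event):
    while True:
        await asyncio.sleep(1)  # one sync cycle
        event.set()             # release everyone currently waiting
        event.clear()           # re-arm for the next cycle

async def consumer(event):
    while True:
        await event.wait()      # returns once per completed cycle
        ...                     # react to the fresh state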
async def _process_mempool(self, all_hashes):
# Re-sync with the new set of hashes
@@ -284,7 +285,7 @@ class MemPool:
return txs
# Thread this potentially slow operation so as not to block
tx_map = await run_in_thread(deserialize_txs)
tx_map = await asyncio.get_event_loop().run_in_executor(None, deserialize_txs)
# Determine all prevouts not in the mempool, and fetch the
# UTXO information from the database. Failed prevout lookups

View file

@@ -26,9 +26,9 @@
'''Merkle trees, branches, proofs and roots.'''
import asyncio
from math import ceil, log
from torba.rpc import Event
from torba.server.hash import double_sha256
@@ -169,7 +169,7 @@ class MerkleCache:
self.source_func = source_func
self.length = 0
self.depth_higher = 0
self.initialized = Event()
self.initialized = asyncio.Event()
def _segment_length(self):
return 1 << self.depth_higher
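initialized is the usual one-shot readiness gate: it is set() once after the cache is populated and every reader awaits it first. A hedged sketch of the gate (Cache and its method names are illustrative, not MerkleCache's actual API):

import asyncio

class Cache:
    def __init__(self):
        self.initialized = asyncio.Event()

    async def populate(self):
        ...                            # fill in state
        self.initialized.set()         # one-shot: stays set afterwards

    async def read(self):
        await self.initialized.wait()  # blocks until populate() has run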

View file

@@ -14,10 +14,10 @@ import ssl
import time
from collections import defaultdict, Counter
from torba.rpc import (Connector, RPCSession, SOCKSProxy,
Notification, handler_invocation,
SOCKSError, RPCError, TaskTimeout, TaskGroup, Event,
sleep, run_in_thread, ignore_after, timeout_after)
from torba.rpc import (
Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
SOCKSError, RPCError, TaskTimeout, TaskGroup, ignore_after, timeout_after
)
from torba.server.peer import Peer
from torba.server.util import class_logger, protocol_tuple
@@ -149,7 +149,7 @@ class PeerManager:
self.logger.info(f'detected {proxy}')
return
self.logger.info('no proxy detected, will try later')
await sleep(900)
await asyncio.sleep(900)
async def _note_peers(self, peers, limit=2, check_ports=False,
source=None):
@@ -177,7 +177,7 @@ class PeerManager:
use_peers = new_peers
for peer in use_peers:
self.logger.info(f'accepted new peer {peer} from {source}')
peer.retry_event = Event()
peer.retry_event = asyncio.Event()
self.peers.add(peer)
await self.group.spawn(self._monitor_peer(peer))
@@ -385,7 +385,7 @@ class PeerManager:
self.logger.info(f'beginning peer discovery. Force use of '
f'proxy: {self.env.force_proxy}')
forever = Event()
forever = asyncio.Event()
async with self.group as group:
await group.spawn(forever.wait())
await group.spawn(self._detect_proxy())
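forever is intentionally never set(), so forever.wait() never completes; spawning it pins the group open and discovery runs until the surrounding task is cancelled. A sketch of the idiom, bounded with ignore_after so the demo terminates (assumes the torba.rpc exports shown earlier in this commit):

import asyncio
from torba.rpc import TaskGroup, ignore_after

async def demo():
    forever = asyncio.Event()  # intentionally never set()
    async with ignore_after(2):
        async with TaskGroup() as group:
            await group.spawn(forever.wait())      # pins the group open
            await group.spawn(asyncio.sleep(0.5))  # stand-in for real work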

View file

@@ -22,8 +22,7 @@ from functools import partial
import torba
from torba.rpc import (
RPCSession, JSONRPCAutoDetect, JSONRPCConnection,
TaskGroup, handler_invocation, RPCError, Request, ignore_after, sleep,
Event
TaskGroup, handler_invocation, RPCError, Request
)
from torba.server import text
from torba.server import util
@@ -131,7 +130,7 @@ class SessionManager:
self.mn_cache_height = 0
self.mn_cache = []
self.session_event = Event()
self.session_event = asyncio.Event()
# Set up the RPC request handlers
cmds = ('add_peer daemon_url disconnect getinfo groups log peers '
@@ -207,7 +206,7 @@ class SessionManager:
log_interval = self.env.log_sessions
if log_interval:
while True:
await sleep(log_interval)
await asyncio.sleep(log_interval)
data = self._session_data(for_log=True)
for line in text.sessions_lines(data):
self.logger.info(line)
@@ -249,7 +248,7 @@ class SessionManager:
async def _clear_stale_sessions(self):
'''Cut off sessions that haven't done anything for 10 minutes.'''
while True:
await sleep(60)
await asyncio.sleep(60)
stale_cutoff = time.time() - self.env.session_timeout
stale_sessions = [session for session in self.sessions
if session.last_recv < stale_cutoff]
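The sweep wakes once a minute and compares each session's last_recv to the cutoff; a generic hedged sketch of the pattern (close_stale and session.close() are illustrative, not the session API):

import asyncio
import time

async def close_stale(sessions, timeout_secs):
    while True:
        await asyncio.sleep(60)
        cutoff = time.time() - timeout_secs
        for session in [s for s in sessions if s.last_recv < cutoff]:
            await session.close()  # assumed close coroutine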