Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
Victor Shyba
593422f70c niko 2019-08-23 21:50:11 -03:00
8 changed files with 300 additions and 175 deletions

View file

@ -18,6 +18,9 @@ class MockNetwork:
self.get_transaction_called = [] self.get_transaction_called = []
self.is_connected = False self.is_connected = False
def retriable_call(self, function, *args, **kwargs):
return function(*args, **kwargs)
async def get_history(self, address): async def get_history(self, address):
self.get_history_called.append(address) self.get_history_called.append(address)
self.address = address self.address = address
@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
) )
class MocHeaderNetwork: class MocHeaderNetwork(MockNetwork):
def __init__(self, responses): def __init__(self, responses):
super().__init__(None, None)
self.responses = responses self.responses = responses
async def get_headers(self, height, blocks): async def get_headers(self, height, blocks):

View file

@ -70,8 +70,10 @@ class BaseHeaders:
return True return True
def __getitem__(self, height) -> dict: def __getitem__(self, height) -> dict:
assert not isinstance(height, slice), \ if isinstance(height, slice):
"Slicing of header chain has not been implemented yet." raise NotImplementedError("Slicing of header chain has not been implemented yet.")
if not 0 <= height <= self.height:
raise IndexError(f"{height} is out of bounds, current height: {self.height}")
return self.deserialize(height, self.get_raw_header(height)) return self.deserialize(height, self.get_raw_header(height))
def get_raw_header(self, height) -> bytes: def get_raw_header(self, height) -> bytes:

View file

@ -287,7 +287,7 @@ class BaseLedger(metaclass=LedgerRegistry):
subscription_update = False subscription_update = False
if not headers: if not headers:
header_response = await self.network.get_headers(height, 2001) header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
headers = header_response['hex'] headers = header_response['hex']
if not headers: if not headers:
@ -394,7 +394,7 @@ class BaseLedger(metaclass=LedgerRegistry):
if local_status == remote_status: if local_status == remote_status:
return return
remote_history = await self.network.get_history(address) remote_history = await self.network.retriable_call(self.network.get_history, address)
cache_tasks = [] cache_tasks = []
synced_history = StringIO() synced_history = StringIO()
@ -447,6 +447,17 @@ class BaseLedger(metaclass=LedgerRegistry):
if address_manager is not None: if address_manager is not None:
await address_manager.ensure_address_gap() await address_manager.ensure_address_gap()
local_status, local_history = await self.get_local_status_and_history(address)
if local_status != remote_status:
log.debug(
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
remote_status, len(remote_history), local_status, len(local_history)
)
log.debug("local: %s", local_history)
log.debug("remote: %s", remote_history)
else:
log.debug("Sync completed for: %s", address)
async def cache_transaction(self, txid, remote_height): async def cache_transaction(self, txid, remote_height):
cache_item = self._tx_cache.get(txid) cache_item = self._tx_cache.get(txid)
if cache_item is None: if cache_item is None:
@ -466,7 +477,7 @@ class BaseLedger(metaclass=LedgerRegistry):
if tx is None: if tx is None:
# fetch from network # fetch from network
_raw = await self.network.get_transaction(txid) _raw = await self.network.retriable_call(self.network.get_transaction, txid)
if _raw: if _raw:
tx = self.transaction_class(unhexlify(_raw)) tx = self.transaction_class(unhexlify(_raw))
await self.maybe_verify_transaction(tx, remote_height) await self.maybe_verify_transaction(tx, remote_height)
@ -486,8 +497,8 @@ class BaseLedger(metaclass=LedgerRegistry):
async def maybe_verify_transaction(self, tx, remote_height): async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height tx.height = remote_height
if 0 < remote_height <= len(self.headers): if 0 < remote_height < len(self.headers):
merkle = await self.network.get_merkle(tx.id, remote_height) merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[remote_height] header = self.headers[remote_height]
tx.position = merkle['pos'] tx.position = merkle['pos']
@ -501,6 +512,7 @@ class BaseLedger(metaclass=LedgerRegistry):
return None return None
def broadcast(self, tx): def broadcast(self, tx):
# broadcast cant be a retriable call yet
return self.network.broadcast(hexlify(tx.raw).decode()) return self.network.broadcast(hexlify(tx.raw).decode())
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None): async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
@ -522,4 +534,4 @@ class BaseLedger(metaclass=LedgerRegistry):
)) for address_record in records )) for address_record in records
], timeout=timeout) ], timeout=timeout)
if pending: if pending:
raise TimeoutError('Timed out waiting for transaction.') raise asyncio.TimeoutError('Timed out waiting for transaction.')

View file

@ -1,9 +1,8 @@
import logging import logging
import asyncio import asyncio
from asyncio import CancelledError from operator import itemgetter
from time import time from typing import Dict, Optional, Tuple
from typing import List from time import perf_counter
import socket
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -14,71 +13,146 @@ log = logging.getLogger(__name__)
class ClientSession(BaseClientSession): class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
def __init__(self, *args, network, server, timeout=30, **kwargs):
self.network = network self.network = network
self.server = server self.server = server
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController() self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream self.on_disconnected = self._on_disconnect_controller.stream
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 self.framer.max_size = self.max_errors = 1 << 32
self.bw_limit = -1
self.timeout = timeout self.timeout = timeout
self.max_seconds_idle = timeout * 2 self.max_seconds_idle = timeout * 2
self.ping_task = None self.response_time: Optional[float] = None
self.connection_latency: Optional[float] = None
self._response_samples = 0
self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
self._semaphore = asyncio.Semaphore(int(self.timeout))
@property
def available(self):
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
@property
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport:
return None
return self.transport.get_extra_info('peername')
async def send_timed_server_version_request(self, args=(), timeout=None):
timeout = timeout or self.timeout
log.debug("send version request to %s:%i", *self.server)
start = perf_counter()
result = await asyncio.wait_for(
super().send_request('server.version', args), timeout=timeout
)
current_response_time = perf_counter() - start
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
self.response_time = response_sum / (self._response_samples + 1)
self._response_samples += 1
return result
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
try: self.pending_amount += 1
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout) async with self._semaphore:
except RPCError as e: return await self._send_request(method, args)
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
raise e
except asyncio.TimeoutError:
self.abort()
raise
async def ping_forever(self): async def _send_request(self, method, args=()):
log.debug("send %s to %s:%i", method, *self.server)
try:
if method == 'server.version':
reply = await self.send_timed_server_version_request(args, self.timeout)
else:
reply = await asyncio.wait_for(
super().send_request(method, args), timeout=self.timeout
)
log.debug("got reply for %s from %s:%i", method, *self.server)
return reply
except RPCError as e:
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args)
raise e
except ConnectionError:
log.warning("connection to %s:%i lost", *self.server)
self.synchronous_close()
raise asyncio.CancelledError(f"connection to {self.server[0]}:{self.server[1]} lost")
except asyncio.TimeoutError:
log.info("timeout sending %s to %s:%i", method, *self.server)
raise
except asyncio.CancelledError:
log.info("cancelled sending %s to %s:%i", method, *self.server)
self.synchronous_close()
raise
finally:
self.pending_amount -= 1
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2) # TODO: change to 'ping' on newer protocol (above 1.2)
while not self.is_closing(): retry_delay = default_delay = 1.0
if (time() - self.last_send) > self.max_seconds_idle: while True:
try: try:
await self.send_request('server.banner') if self.is_closing():
except: await self.create_connection(self.timeout)
self.abort() await self.ensure_server_version()
raise self._on_connect_cb()
await asyncio.sleep(self.max_seconds_idle//3) if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.ensure_server_version()
retry_delay = default_delay
except (asyncio.TimeoutError, OSError):
await self.close()
retry_delay = min(60, retry_delay * 2)
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
try:
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
finally:
self.trigger_urgent_reconnect.clear()
async def ensure_server_version(self, required='1.2', timeout=3):
return await asyncio.wait_for(
self.send_request('server.version', [__version__, required]), timeout=timeout
)
async def create_connection(self, timeout=6): async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server) connector = Connector(lambda: self, *self.server)
start = perf_counter()
await asyncio.wait_for(connector.create_connection(), timeout=timeout) await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.ping_task = asyncio.create_task(self.ping_forever()) self.connection_latency = perf_counter() - start
async def handle_request(self, request): async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method] controller = self.network.subscription_controllers[request.method]
controller.add(request.args) controller.add(request.args)
def connection_lost(self, exc): def connection_lost(self, exc):
log.debug("Connection lost: %s:%d", *self.server)
super().connection_lost(exc) super().connection_lost(exc)
self.response_time = None
self.connection_latency = None
self._response_samples = 0
self.pending_amount = 0
self._on_disconnect_controller.add(True) self._on_disconnect_controller.add(True)
if self.ping_task:
self.ping_task.cancel()
class BaseNetwork: class BaseNetwork:
def __init__(self, ledger): def __init__(self, ledger):
self.config = ledger.config self.config = ledger.config
self.client: ClientSession = None self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.session_pool: SessionPool = None self.client: Optional[ClientSession] = None
self.running = False self.running = False
self.remote_height: int = 0 self.remote_height: int = 0
self._on_connected_controller = StreamController() self._on_connected_controller = StreamController()
self.on_connected = self._on_connected_controller.stream self.on_connected = self._on_connected_controller.stream
self._on_header_controller = StreamController() self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController() self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream self.on_status = self._on_status_controller.stream
self.subscription_controllers = { self.subscription_controllers = {
@ -86,79 +160,72 @@ class BaseNetwork:
'blockchain.address.subscribe': self._on_status_controller, 'blockchain.address.subscribe': self._on_status_controller,
} }
async def switch_to_fastest(self):
try:
client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30)
except asyncio.TimeoutError:
if self.client:
await self.client.close()
self.client = None
for session in self.session_pool.sessions:
session.synchronous_close()
log.warning("not connected to any wallet servers")
return
current_client = self.client
self.client = client
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True)
try:
self._update_remote_height((await self.subscribe_headers(),))
log.info("Subscribed to headers: %s:%d", *self.client.server)
except asyncio.TimeoutError:
if self.client:
await self.client.close()
self.client = current_client
return
self.session_pool.new_connection_event.clear()
return await self.session_pool.new_connection_event.wait()
async def start(self): async def start(self):
self.running = True self.running = True
connect_timeout = self.config.get('connect_timeout', 6)
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) self.on_header.listen(self._update_remote_height)
while True: while self.running:
try: await self.switch_to_fastest()
self.client = await self.pick_fastest_session()
if self.is_connected:
await self.ensure_server_version()
self._update_remote_height((await self.subscribe_headers(),))
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True)
await self.client.on_disconnected.first
except CancelledError:
self.running = False
except asyncio.TimeoutError:
log.warning("Timed out while trying to find a server!")
except Exception: # pylint: disable=broad-except
log.exception("Exception while trying to find a server!")
if not self.running:
return
elif self.client:
await self.client.close()
self.client.connection.cancel_pending_requests()
async def stop(self): async def stop(self):
self.running = False self.running = False
if self.session_pool: self.session_pool.stop()
self.session_pool.stop()
if self.is_connected:
disconnected = self.client.on_disconnected.first
await self.client.close()
await disconnected
@property @property
def is_connected(self): def is_connected(self):
return self.client is not None and not self.client.is_closing() return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args): def rpc(self, list_or_method, args, session=None):
if self.is_connected: session = session or self.session_pool.fastest_session
return self.client.send_request(list_or_method, args) if session:
return session.send_request(list_or_method, args)
else: else:
self.session_pool.trigger_nodelay_connect()
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def pick_fastest_session(self): async def retriable_call(self, function, *args, **kwargs):
sessions = await self.session_pool.get_online_sessions() while self.running:
done, pending = await asyncio.wait([ if not self.is_connected:
self.probe_session(session) log.warning("Wallet server unavailable, waiting for it to come back and retry.")
for session in sessions if not session.is_closing() await self.on_connected.first
], return_when='FIRST_COMPLETED') await self.session_pool.wait_for_fastest_session()
for task in pending: try:
task.cancel() return await function(*args, **kwargs)
for session in done: except asyncio.TimeoutError:
return await session log.warning("Wallet server call timed out, retrying.")
except ConnectionError:
async def probe_session(self, session: ClientSession): pass
await session.send_request('server.banner') raise asyncio.CancelledError() # if we got here, we are shutting down
return session
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required])
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address])
def get_transaction(self, tx_hash): def get_transaction(self, tx_hash):
return self.rpc('blockchain.transaction.get', [tx_hash]) return self.rpc('blockchain.transaction.get', [tx_hash])
@ -171,84 +238,111 @@ class BaseNetwork:
def get_headers(self, height, count=10000): def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', [height, count]) return self.rpc('blockchain.block.headers', [height, count])
def subscribe_headers(self): # --- Subscribes, history and broadcasts are always aimed towards the master client directly
return self.rpc('blockchain.headers.subscribe', [True]) def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address], session=self.client)
def subscribe_address(self, address): def broadcast(self, raw_transaction):
return self.rpc('blockchain.address.subscribe', [address]) return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
async def subscribe_address(self, address):
try:
return await self.rpc('blockchain.address.subscribe', [address], session=self.client)
except asyncio.TimeoutError:
# abort and cancel, we cant lose a subscription, it will happen again on reconnect
self.client.abort()
raise asyncio.CancelledError()
class SessionPool: class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float): def __init__(self, network: BaseNetwork, timeout: float):
self.network = network self.network = network
self.sessions: List[ClientSession] = [] self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self._dead_servers: List[ClientSession] = []
self.maintain_connections_task = None
self.timeout = timeout self.timeout = timeout
# triggered when the master server is out, to speed up reconnect self.new_connection_event = asyncio.Event()
self._lost_master = asyncio.Event()
@property @property
def online(self): def online(self):
for session in self.sessions: return any(not session.is_closing() for session in self.sessions)
if not session.is_closing():
return True @property
return False def available_sessions(self):
return [session for session in self.sessions if session.available]
@property
def fastest_session(self):
if not self.available_sessions:
return None
return min(
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
for session in self.available_sessions],
key=itemgetter(0)
)[1]
def _get_session_connect_callback(self, session: ClientSession):
loop = asyncio.get_event_loop()
def callback():
duplicate_connections = [
s for s in self.sessions
if s is not session and s.server_address_and_port == session.server_address_and_port
]
already_connected = None if not duplicate_connections else duplicate_connections[0]
if already_connected:
self.sessions.pop(session).cancel()
session.synchronous_close()
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
session.server[0], already_connected.server[0])
loop.call_later(3600, self._connect_session, session.server)
return
self.new_connection_event.set()
log.info("connected to %s:%i", *session.server)
return callback
def _connect_session(self, server: Tuple[str, int]):
session = None
for s in self.sessions:
if s.server == server:
session = s
break
if not session:
session = ClientSession(
network=self.network, server=server
)
session._on_connect_cb = self._get_session_connect_callback(session)
task = self.sessions.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def start(self, default_servers): def start(self, default_servers):
self.sessions = [ for server in default_servers:
ClientSession(network=self.network, server=server) self._connect_session(server)
for server in default_servers
]
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
def stop(self): def stop(self):
if self.maintain_connections_task: for task in self.sessions.values():
self.maintain_connections_task.cancel() task.cancel()
self.sessions.clear()
def ensure_connections(self):
for session in self.sessions: for session in self.sessions:
if not session.is_closing(): self._connect_session(session.server)
session.abort()
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
async def ensure_connections(self): def trigger_nodelay_connect(self):
while True: # used when other parts of the system sees we might have internet back
await asyncio.gather(*[ # bypasses the retry interval
self.ensure_connection(session) for session in self.sessions:
for session in self.sessions session.trigger_urgent_reconnect.set()
], return_exceptions=True)
try:
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
except asyncio.TimeoutError:
pass
self._lost_master.clear()
if not self.sessions:
self.sessions.extend(self._dead_servers)
self._dead_servers = []
async def ensure_connection(self, session): async def wait_for_fastest_session(self):
self._dead_servers.append(session) while not self.fastest_session:
self.sessions.remove(session) self.trigger_nodelay_connect()
try: self.new_connection_event.clear()
if session.is_closing(): await self.new_connection_event.wait()
await session.create_connection(self.timeout) return self.fastest_session
await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout)
self.sessions.append(session)
self._dead_servers.remove(session)
except asyncio.TimeoutError:
log.warning("Timeout connecting to %s:%d", *session.server)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except socket.gaierror:
log.warning("Could not resolve IP for %s", session.server[0])
except Exception as err: # pylint: disable=broad-except
if 'Connect call failed' in str(err):
log.warning("Could not connect to %s:%d", *session.server)
else:
log.exception("Connecting to %s:%d raised an exception:", *session.server)
async def get_online_sessions(self):
while not self.online:
self._lost_master.set()
await asyncio.sleep(0.5)
return self.sessions

View file

@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
import itertools import itertools
import json import json
import typing import typing
import asyncio
from functools import partial from functools import partial
from numbers import Number from numbers import Number
@ -745,9 +746,8 @@ class JSONRPCConnection(object):
self._protocol = item self._protocol = item
return self.receive_message(message) return self.receive_message(message)
def cancel_pending_requests(self): def raise_pending_requests(self, exception):
"""Cancel all pending requests.""" exception = exception or asyncio.TimeoutError()
exception = CancelledError()
for request, event in self._requests.values(): for request, event in self._requests.values():
event.result = exception event.result = exception
event.set() event.set()

View file

@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol):
# Force-close a connection if a send doesn't succeed in this time # Force-close a connection if a send doesn't succeed in this time
self.max_send_delay = 60 self.max_send_delay = 60
# Statistics. The RPC object also keeps its own statistics. # Statistics. The RPC object also keeps its own statistics.
self.start_time = time.time() self.start_time = time.perf_counter()
self.errors = 0 self.errors = 0
self.send_count = 0 self.send_count = 0
self.send_size = 0 self.send_size = 0
@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol):
# A non-positive value means not to limit concurrency # A non-positive value means not to limit concurrency
if self.bw_limit <= 0: if self.bw_limit <= 0:
return return
now = time.time() now = time.perf_counter()
# Reduce the recorded usage in proportion to the elapsed time # Reduce the recorded usage in proportion to the elapsed time
refund = (now - self.bw_time) * (self.bw_limit / 3600) refund = (now - self.bw_time) * (self.bw_limit / 3600)
self.bw_charge = max(0, self.bw_charge - int(refund)) self.bw_charge = max(0, self.bw_charge - int(refund))
@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
await asyncio.wait_for(self._can_send.wait(), secs) await asyncio.wait_for(self._can_send.wait(), secs)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.abort() self.abort()
raise asyncio.CancelledError(f'task timed out after {secs}s') raise asyncio.TimeoutError(f'task timed out after {secs}s')
async def _send_message(self, message): async def _send_message(self, message):
if not self._can_send.is_set(): if not self._can_send.is_set():
@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol):
self.send_size += len(framed_message) self.send_size += len(framed_message)
self._using_bandwidth(len(framed_message)) self._using_bandwidth(len(framed_message))
self.send_count += 1 self.send_count += 1
self.last_send = time.time() self.last_send = time.perf_counter()
if self.verbosity >= 4: if self.verbosity >= 4:
self.logger.debug(f'Sending framed message {framed_message}') self.logger.debug(f'Sending framed message {framed_message}')
self.transport.write(framed_message) self.transport.write(framed_message)
@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol):
self._address = None self._address = None
self.transport = None self.transport = None
self._task_group.cancel() self._task_group.cancel()
self._pm_task.cancel() if self._pm_task:
self._pm_task.cancel()
# Release waiting tasks # Release waiting tasks
self._can_send.set() self._can_send.set()
@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol):
if self.transport: if self.transport:
self.transport.abort() self.transport.abort()
# TODO: replace with synchronous_close
async def close(self, *, force_after=30): async def close(self, *, force_after=30):
"""Close the connection and return when closed.""" """Close the connection and return when closed."""
self._close() self._close()
@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol):
self.abort() self.abort()
await self._pm_task await self._pm_task
def synchronous_close(self):
self._close()
if self._pm_task and not self._pm_task.done():
self._pm_task.cancel()
class MessageSession(SessionBase): class MessageSession(SessionBase):
"""Session class for protocols where messages are not tied to responses, """Session class for protocols where messages are not tied to responses,
@ -296,7 +303,7 @@ class MessageSession(SessionBase):
) )
self._bump_errors() self._bump_errors()
else: else:
self.last_recv = time.time() self.last_recv = time.perf_counter()
self.recv_count += 1 self.recv_count += 1
if self.recv_count % 10 == 0: if self.recv_count % 10 == 0:
await self._update_concurrency() await self._update_concurrency()
@ -416,7 +423,7 @@ class RPCSession(SessionBase):
self.logger.warning(f'{e!r}') self.logger.warning(f'{e!r}')
continue continue
self.last_recv = time.time() self.last_recv = time.perf_counter()
self.recv_count += 1 self.recv_count += 1
if self.recv_count % 10 == 0: if self.recv_count % 10 == 0:
await self._update_concurrency() await self._update_concurrency()
@ -456,7 +463,7 @@ class RPCSession(SessionBase):
def connection_lost(self, exc): def connection_lost(self, exc):
# Cancel pending requests and message processing # Cancel pending requests and message processing
self.connection.cancel_pending_requests() self.connection.raise_pending_requests(exc)
super().connection_lost(exc) super().connection_lost(exc)
# External API # External API
@ -473,6 +480,8 @@ class RPCSession(SessionBase):
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
"""Send an RPC request over the network.""" """Send an RPC request over the network."""
if self.is_closing():
raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
message, event = self.connection.send_request(Request(method, args)) message, event = self.connection.send_request(Request(method, args))
await self._send_message(message) await self._send_message(message)
await event.wait() await event.wait()

View file

@ -258,7 +258,7 @@ class SessionManager:
session_timeout = self.env.session_timeout session_timeout = self.env.session_timeout
while True: while True:
await sleep(session_timeout // 10) await sleep(session_timeout // 10)
stale_cutoff = time.time() - session_timeout stale_cutoff = time.perf_counter() - session_timeout
stale_sessions = [session for session in self.sessions stale_sessions = [session for session in self.sessions
if session.last_recv < stale_cutoff] if session.last_recv < stale_cutoff]
if stale_sessions: if stale_sessions:

View file

@ -45,10 +45,12 @@ class BroadcastSubscription:
class StreamController: class StreamController:
def __init__(self): def __init__(self, merge_repeated_events=False):
self.stream = Stream(self) self.stream = Stream(self)
self._first_subscription = None self._first_subscription = None
self._last_subscription = None self._last_subscription = None
self._last_event = None
self._merge_repeated = merge_repeated_events
@property @property
def has_listener(self): def has_listener(self):
@ -76,8 +78,10 @@ class StreamController:
return f return f
def add(self, event): def add(self, event):
skip = self._merge_repeated and event == self._last_event
self._last_event = event
return self._notify_and_ensure_future( return self._notify_and_ensure_future(
lambda subscription: subscription._add(event) lambda subscription: None if skip else subscription._add(event)
) )
def add_error(self, exception): def add_error(self, exception):
@ -141,8 +145,8 @@ class Stream:
def first(self): def first(self):
future = asyncio.get_event_loop().create_future() future = asyncio.get_event_loop().create_future()
subscription = self.listen( subscription = self.listen(
lambda value: self._cancel_and_callback(subscription, future, value), lambda value: not future.done() and self._cancel_and_callback(subscription, future, value),
lambda exception: self._cancel_and_error(subscription, future, exception) lambda exception: not future.done() and self._cancel_and_error(subscription, future, exception)
) )
return future return future