diff --git a/torba/tests/client_tests/unit/test_ledger.py b/torba/tests/client_tests/unit/test_ledger.py index 0e077b441..beb26a5d1 100644 --- a/torba/tests/client_tests/unit/test_ledger.py +++ b/torba/tests/client_tests/unit/test_ledger.py @@ -18,6 +18,9 @@ class MockNetwork: self.get_transaction_called = [] self.is_connected = False + def retriable_call(self, function, *args, **kwargs): + return function(*args, **kwargs) + async def get_history(self, address): self.get_history_called.append(address) self.address = address @@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase): ) -class MocHeaderNetwork: +class MocHeaderNetwork(MockNetwork): def __init__(self, responses): + super().__init__(None, None) self.responses = responses async def get_headers(self, height, blocks): diff --git a/torba/torba/client/baseheader.py b/torba/torba/client/baseheader.py index c48fbdb9f..1df30ba10 100644 --- a/torba/torba/client/baseheader.py +++ b/torba/torba/client/baseheader.py @@ -70,8 +70,10 @@ class BaseHeaders: return True def __getitem__(self, height) -> dict: - assert not isinstance(height, slice), \ - "Slicing of header chain has not been implemented yet." + if isinstance(height, slice): + raise NotImplementedError("Slicing of header chain has not been implemented yet.") + if not 0 <= height <= self.height: + raise IndexError(f"{height} is out of bounds, current height: {self.height}") return self.deserialize(height, self.get_raw_header(height)) def get_raw_header(self, height) -> bytes: diff --git a/torba/torba/client/baseledger.py b/torba/torba/client/baseledger.py index c1226e21f..a5520b9a3 100644 --- a/torba/torba/client/baseledger.py +++ b/torba/torba/client/baseledger.py @@ -287,7 +287,7 @@ class BaseLedger(metaclass=LedgerRegistry): subscription_update = False if not headers: - header_response = await self.network.get_headers(height, 2001) + header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) headers = header_response['hex'] if not headers: @@ -394,7 +394,7 @@ class BaseLedger(metaclass=LedgerRegistry): if local_status == remote_status: return - remote_history = await self.network.get_history(address) + remote_history = await self.network.retriable_call(self.network.get_history, address) cache_tasks = [] synced_history = StringIO() @@ -447,6 +447,17 @@ class BaseLedger(metaclass=LedgerRegistry): if address_manager is not None: await address_manager.ensure_address_gap() + local_status, local_history = await self.get_local_status_and_history(address) + if local_status != remote_status: + log.debug( + "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items", + remote_status, len(remote_history), local_status, len(local_history) + ) + log.debug("local: %s", local_history) + log.debug("remote: %s", remote_history) + else: + log.debug("Sync completed for: %s", address) + async def cache_transaction(self, txid, remote_height): cache_item = self._tx_cache.get(txid) if cache_item is None: @@ -466,7 +477,7 @@ class BaseLedger(metaclass=LedgerRegistry): if tx is None: # fetch from network - _raw = await self.network.get_transaction(txid) + _raw = await self.network.retriable_call(self.network.get_transaction, txid) if _raw: tx = self.transaction_class(unhexlify(_raw)) await self.maybe_verify_transaction(tx, remote_height) @@ -486,8 +497,8 @@ class BaseLedger(metaclass=LedgerRegistry): async def maybe_verify_transaction(self, tx, remote_height): tx.height = remote_height - if 0 < remote_height <= len(self.headers): - merkle = await self.network.get_merkle(tx.id, remote_height) + if 0 < remote_height < len(self.headers): + merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) header = self.headers[remote_height] tx.position = merkle['pos'] @@ -501,6 +512,7 @@ class BaseLedger(metaclass=LedgerRegistry): return None def broadcast(self, tx): + # broadcast cant be a retriable call yet return self.network.broadcast(hexlify(tx.raw).decode()) async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None): @@ -522,4 +534,4 @@ class BaseLedger(metaclass=LedgerRegistry): )) for address_record in records ], timeout=timeout) if pending: - raise TimeoutError('Timed out waiting for transaction.') + raise asyncio.TimeoutError('Timed out waiting for transaction.') diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 25c858dc4..72b158962 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -1,9 +1,8 @@ import logging import asyncio -from asyncio import CancelledError -from time import time -from typing import List -import socket +from operator import itemgetter +from typing import Dict, Optional, Tuple +from time import perf_counter from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError @@ -14,71 +13,146 @@ log = logging.getLogger(__name__) class ClientSession(BaseClientSession): - - def __init__(self, *args, network, server, timeout=30, **kwargs): + def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): self.network = network self.server = server super().__init__(*args, **kwargs) self._on_disconnect_controller = StreamController() self.on_disconnected = self._on_disconnect_controller.stream - self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 + self.framer.max_size = self.max_errors = 1 << 32 + self.bw_limit = -1 self.timeout = timeout self.max_seconds_idle = timeout * 2 - self.ping_task = None + self.response_time: Optional[float] = None + self.connection_latency: Optional[float] = None + self._response_samples = 0 + self.pending_amount = 0 + self._on_connect_cb = on_connect_callback or (lambda: None) + self.trigger_urgent_reconnect = asyncio.Event() + self._semaphore = asyncio.Semaphore(int(self.timeout)) + + @property + def available(self): + return not self.is_closing() and self._can_send.is_set() and self.response_time is not None + + @property + def server_address_and_port(self) -> Optional[Tuple[str, int]]: + if not self.transport: + return None + return self.transport.get_extra_info('peername') + + async def send_timed_server_version_request(self, args=(), timeout=None): + timeout = timeout or self.timeout + log.debug("send version request to %s:%i", *self.server) + start = perf_counter() + result = await asyncio.wait_for( + super().send_request('server.version', args), timeout=timeout + ) + current_response_time = perf_counter() - start + response_sum = (self.response_time or 0) * self._response_samples + current_response_time + self.response_time = response_sum / (self._response_samples + 1) + self._response_samples += 1 + return result async def send_request(self, method, args=()): - try: - return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout) - except RPCError as e: - log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) - raise e - except asyncio.TimeoutError: - self.abort() - raise + self.pending_amount += 1 + async with self._semaphore: + return await self._send_request(method, args) - async def ping_forever(self): + async def _send_request(self, method, args=()): + log.debug("send %s to %s:%i", method, *self.server) + try: + if method == 'server.version': + reply = await self.send_timed_server_version_request(args, self.timeout) + else: + reply = await asyncio.wait_for( + super().send_request(method, args), timeout=self.timeout + ) + log.debug("got reply for %s from %s:%i", method, *self.server) + return reply + except RPCError as e: + log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", + *self.server, *e.args) + raise e + except ConnectionError: + log.warning("connection to %s:%i lost", *self.server) + self.synchronous_close() + raise asyncio.CancelledError(f"connection to {self.server[0]}:{self.server[1]} lost") + except asyncio.TimeoutError: + log.info("timeout sending %s to %s:%i", method, *self.server) + raise + except asyncio.CancelledError: + log.info("cancelled sending %s to %s:%i", method, *self.server) + self.synchronous_close() + raise + finally: + self.pending_amount -= 1 + + async def ensure_session(self): + # Handles reconnecting and maintaining a session alive # TODO: change to 'ping' on newer protocol (above 1.2) - while not self.is_closing(): - if (time() - self.last_send) > self.max_seconds_idle: - try: - await self.send_request('server.banner') - except: - self.abort() - raise - await asyncio.sleep(self.max_seconds_idle//3) + retry_delay = default_delay = 1.0 + while True: + try: + if self.is_closing(): + await self.create_connection(self.timeout) + await self.ensure_server_version() + self._on_connect_cb() + if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None: + await self.ensure_server_version() + retry_delay = default_delay + except (asyncio.TimeoutError, OSError): + await self.close() + retry_delay = min(60, retry_delay * 2) + log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) + try: + await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) + except asyncio.TimeoutError: + pass + finally: + self.trigger_urgent_reconnect.clear() + + async def ensure_server_version(self, required='1.2', timeout=3): + return await asyncio.wait_for( + self.send_request('server.version', [__version__, required]), timeout=timeout + ) async def create_connection(self, timeout=6): connector = Connector(lambda: self, *self.server) + start = perf_counter() await asyncio.wait_for(connector.create_connection(), timeout=timeout) - self.ping_task = asyncio.create_task(self.ping_forever()) + self.connection_latency = perf_counter() - start async def handle_request(self, request): controller = self.network.subscription_controllers[request.method] controller.add(request.args) def connection_lost(self, exc): + log.debug("Connection lost: %s:%d", *self.server) super().connection_lost(exc) + self.response_time = None + self.connection_latency = None + self._response_samples = 0 + self.pending_amount = 0 self._on_disconnect_controller.add(True) - if self.ping_task: - self.ping_task.cancel() class BaseNetwork: def __init__(self, ledger): self.config = ledger.config - self.client: ClientSession = None - self.session_pool: SessionPool = None + self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) + self.client: Optional[ClientSession] = None self.running = False self.remote_height: int = 0 self._on_connected_controller = StreamController() self.on_connected = self._on_connected_controller.stream - self._on_header_controller = StreamController() + self._on_header_controller = StreamController(merge_repeated_events=True) self.on_header = self._on_header_controller.stream - self._on_status_controller = StreamController() + self._on_status_controller = StreamController(merge_repeated_events=True) self.on_status = self._on_status_controller.stream self.subscription_controllers = { @@ -86,79 +160,72 @@ class BaseNetwork: 'blockchain.address.subscribe': self._on_status_controller, } + async def switch_to_fastest(self): + try: + client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30) + except asyncio.TimeoutError: + if self.client: + await self.client.close() + self.client = None + for session in self.session_pool.sessions: + session.synchronous_close() + log.warning("not connected to any wallet servers") + return + current_client = self.client + self.client = client + log.info("Switching to SPV wallet server: %s:%d", *self.client.server) + self._on_connected_controller.add(True) + try: + self._update_remote_height((await self.subscribe_headers(),)) + log.info("Subscribed to headers: %s:%d", *self.client.server) + except asyncio.TimeoutError: + if self.client: + await self.client.close() + self.client = current_client + return + self.session_pool.new_connection_event.clear() + return await self.session_pool.new_connection_event.wait() + async def start(self): self.running = True - connect_timeout = self.config.get('connect_timeout', 6) - self.session_pool = SessionPool(network=self, timeout=connect_timeout) self.session_pool.start(self.config['default_servers']) self.on_header.listen(self._update_remote_height) - while True: - try: - self.client = await self.pick_fastest_session() - if self.is_connected: - await self.ensure_server_version() - self._update_remote_height((await self.subscribe_headers(),)) - log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) - self._on_connected_controller.add(True) - await self.client.on_disconnected.first - except CancelledError: - self.running = False - except asyncio.TimeoutError: - log.warning("Timed out while trying to find a server!") - except Exception: # pylint: disable=broad-except - log.exception("Exception while trying to find a server!") - if not self.running: - return - elif self.client: - await self.client.close() - self.client.connection.cancel_pending_requests() + while self.running: + await self.switch_to_fastest() async def stop(self): self.running = False - if self.session_pool: - self.session_pool.stop() - if self.is_connected: - disconnected = self.client.on_disconnected.first - await self.client.close() - await disconnected + self.session_pool.stop() @property def is_connected(self): - return self.client is not None and not self.client.is_closing() + return self.client and not self.client.is_closing() - def rpc(self, list_or_method, args): - if self.is_connected: - return self.client.send_request(list_or_method, args) + def rpc(self, list_or_method, args, session=None): + session = session or self.session_pool.fastest_session + if session: + return session.send_request(list_or_method, args) else: + self.session_pool.trigger_nodelay_connect() raise ConnectionError("Attempting to send rpc request when connection is not available.") - async def pick_fastest_session(self): - sessions = await self.session_pool.get_online_sessions() - done, pending = await asyncio.wait([ - self.probe_session(session) - for session in sessions if not session.is_closing() - ], return_when='FIRST_COMPLETED') - for task in pending: - task.cancel() - for session in done: - return await session - - async def probe_session(self, session: ClientSession): - await session.send_request('server.banner') - return session + async def retriable_call(self, function, *args, **kwargs): + while self.running: + if not self.is_connected: + log.warning("Wallet server unavailable, waiting for it to come back and retry.") + await self.on_connected.first + await self.session_pool.wait_for_fastest_session() + try: + return await function(*args, **kwargs) + except asyncio.TimeoutError: + log.warning("Wallet server call timed out, retrying.") + except ConnectionError: + pass + raise asyncio.CancelledError() # if we got here, we are shutting down def _update_remote_height(self, header_args): self.remote_height = header_args[0]["height"] - def ensure_server_version(self, required='1.2'): - return self.rpc('server.version', [__version__, required]) - - def broadcast(self, raw_transaction): - return self.rpc('blockchain.transaction.broadcast', [raw_transaction]) - - def get_history(self, address): - return self.rpc('blockchain.address.get_history', [address]) - def get_transaction(self, tx_hash): return self.rpc('blockchain.transaction.get', [tx_hash]) @@ -171,84 +238,111 @@ class BaseNetwork: def get_headers(self, height, count=10000): return self.rpc('blockchain.block.headers', [height, count]) - def subscribe_headers(self): - return self.rpc('blockchain.headers.subscribe', [True]) + # --- Subscribes, history and broadcasts are always aimed towards the master client directly + def get_history(self, address): + return self.rpc('blockchain.address.get_history', [address], session=self.client) - def subscribe_address(self, address): - return self.rpc('blockchain.address.subscribe', [address]) + def broadcast(self, raw_transaction): + return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client) + + def subscribe_headers(self): + return self.rpc('blockchain.headers.subscribe', [True], session=self.client) + + async def subscribe_address(self, address): + try: + return await self.rpc('blockchain.address.subscribe', [address], session=self.client) + except asyncio.TimeoutError: + # abort and cancel, we cant lose a subscription, it will happen again on reconnect + self.client.abort() + raise asyncio.CancelledError() class SessionPool: def __init__(self, network: BaseNetwork, timeout: float): self.network = network - self.sessions: List[ClientSession] = [] - self._dead_servers: List[ClientSession] = [] - self.maintain_connections_task = None + self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() self.timeout = timeout - # triggered when the master server is out, to speed up reconnect - self._lost_master = asyncio.Event() + self.new_connection_event = asyncio.Event() @property def online(self): - for session in self.sessions: - if not session.is_closing(): - return True - return False + return any(not session.is_closing() for session in self.sessions) + + @property + def available_sessions(self): + return [session for session in self.sessions if session.available] + + @property + def fastest_session(self): + if not self.available_sessions: + return None + return min( + [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session) + for session in self.available_sessions], + key=itemgetter(0) + )[1] + + def _get_session_connect_callback(self, session: ClientSession): + loop = asyncio.get_event_loop() + + def callback(): + duplicate_connections = [ + s for s in self.sessions + if s is not session and s.server_address_and_port == session.server_address_and_port + ] + already_connected = None if not duplicate_connections else duplicate_connections[0] + if already_connected: + self.sessions.pop(session).cancel() + session.synchronous_close() + log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour", + session.server[0], already_connected.server[0]) + loop.call_later(3600, self._connect_session, session.server) + return + self.new_connection_event.set() + log.info("connected to %s:%i", *session.server) + + return callback + + def _connect_session(self, server: Tuple[str, int]): + session = None + for s in self.sessions: + if s.server == server: + session = s + break + if not session: + session = ClientSession( + network=self.network, server=server + ) + session._on_connect_cb = self._get_session_connect_callback(session) + task = self.sessions.get(session, None) + if not task or task.done(): + task = asyncio.create_task(session.ensure_session()) + task.add_done_callback(lambda _: self.ensure_connections()) + self.sessions[session] = task def start(self, default_servers): - self.sessions = [ - ClientSession(network=self.network, server=server) - for server in default_servers - ] - self.maintain_connections_task = asyncio.create_task(self.ensure_connections()) + for server in default_servers: + self._connect_session(server) def stop(self): - if self.maintain_connections_task: - self.maintain_connections_task.cancel() + for task in self.sessions.values(): + task.cancel() + self.sessions.clear() + + def ensure_connections(self): for session in self.sessions: - if not session.is_closing(): - session.abort() - self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None + self._connect_session(session.server) - async def ensure_connections(self): - while True: - await asyncio.gather(*[ - self.ensure_connection(session) - for session in self.sessions - ], return_exceptions=True) - try: - await asyncio.wait_for(self._lost_master.wait(), timeout=3) - except asyncio.TimeoutError: - pass - self._lost_master.clear() - if not self.sessions: - self.sessions.extend(self._dead_servers) - self._dead_servers = [] + def trigger_nodelay_connect(self): + # used when other parts of the system sees we might have internet back + # bypasses the retry interval + for session in self.sessions: + session.trigger_urgent_reconnect.set() - async def ensure_connection(self, session): - self._dead_servers.append(session) - self.sessions.remove(session) - try: - if session.is_closing(): - await session.create_connection(self.timeout) - await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout) - self.sessions.append(session) - self._dead_servers.remove(session) - except asyncio.TimeoutError: - log.warning("Timeout connecting to %s:%d", *session.server) - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except socket.gaierror: - log.warning("Could not resolve IP for %s", session.server[0]) - except Exception as err: # pylint: disable=broad-except - if 'Connect call failed' in str(err): - log.warning("Could not connect to %s:%d", *session.server) - else: - log.exception("Connecting to %s:%d raised an exception:", *session.server) - - async def get_online_sessions(self): - while not self.online: - self._lost_master.set() - await asyncio.sleep(0.5) - return self.sessions + async def wait_for_fastest_session(self): + while not self.fastest_session: + self.trigger_nodelay_connect() + self.new_connection_event.clear() + await self.new_connection_event.wait() + return self.fastest_session diff --git a/torba/torba/rpc/jsonrpc.py b/torba/torba/rpc/jsonrpc.py index 5e908cd02..2e8bfa2a7 100644 --- a/torba/torba/rpc/jsonrpc.py +++ b/torba/torba/rpc/jsonrpc.py @@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose', import itertools import json import typing +import asyncio from functools import partial from numbers import Number @@ -745,9 +746,8 @@ class JSONRPCConnection(object): self._protocol = item return self.receive_message(message) - def cancel_pending_requests(self): - """Cancel all pending requests.""" - exception = CancelledError() + def raise_pending_requests(self, exception): + exception = exception or asyncio.TimeoutError() for request, event in self._requests.values(): event.result = exception event.set() diff --git a/torba/torba/rpc/session.py b/torba/torba/rpc/session.py index c8b8c6945..dd9909cfd 100644 --- a/torba/torba/rpc/session.py +++ b/torba/torba/rpc/session.py @@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol): # Force-close a connection if a send doesn't succeed in this time self.max_send_delay = 60 # Statistics. The RPC object also keeps its own statistics. - self.start_time = time.time() + self.start_time = time.perf_counter() self.errors = 0 self.send_count = 0 self.send_size = 0 @@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol): # A non-positive value means not to limit concurrency if self.bw_limit <= 0: return - now = time.time() + now = time.perf_counter() # Reduce the recorded usage in proportion to the elapsed time refund = (now - self.bw_time) * (self.bw_limit / 3600) self.bw_charge = max(0, self.bw_charge - int(refund)) @@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol): await asyncio.wait_for(self._can_send.wait(), secs) except asyncio.TimeoutError: self.abort() - raise asyncio.CancelledError(f'task timed out after {secs}s') + raise asyncio.TimeoutError(f'task timed out after {secs}s') async def _send_message(self, message): if not self._can_send.is_set(): @@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol): self.send_size += len(framed_message) self._using_bandwidth(len(framed_message)) self.send_count += 1 - self.last_send = time.time() + self.last_send = time.perf_counter() if self.verbosity >= 4: self.logger.debug(f'Sending framed message {framed_message}') self.transport.write(framed_message) @@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol): self._address = None self.transport = None self._task_group.cancel() - self._pm_task.cancel() + if self._pm_task: + self._pm_task.cancel() # Release waiting tasks self._can_send.set() @@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol): if self.transport: self.transport.abort() + # TODO: replace with synchronous_close async def close(self, *, force_after=30): """Close the connection and return when closed.""" self._close() @@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol): self.abort() await self._pm_task + def synchronous_close(self): + self._close() + if self._pm_task and not self._pm_task.done(): + self._pm_task.cancel() + class MessageSession(SessionBase): """Session class for protocols where messages are not tied to responses, @@ -296,7 +303,7 @@ class MessageSession(SessionBase): ) self._bump_errors() else: - self.last_recv = time.time() + self.last_recv = time.perf_counter() self.recv_count += 1 if self.recv_count % 10 == 0: await self._update_concurrency() @@ -416,7 +423,7 @@ class RPCSession(SessionBase): self.logger.warning(f'{e!r}') continue - self.last_recv = time.time() + self.last_recv = time.perf_counter() self.recv_count += 1 if self.recv_count % 10 == 0: await self._update_concurrency() @@ -456,7 +463,7 @@ class RPCSession(SessionBase): def connection_lost(self, exc): # Cancel pending requests and message processing - self.connection.cancel_pending_requests() + self.connection.raise_pending_requests(exc) super().connection_lost(exc) # External API @@ -473,6 +480,8 @@ class RPCSession(SessionBase): async def send_request(self, method, args=()): """Send an RPC request over the network.""" + if self.is_closing(): + raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.") message, event = self.connection.send_request(Request(method, args)) await self._send_message(message) await event.wait() diff --git a/torba/torba/server/session.py b/torba/torba/server/session.py index 9c5981d22..db0e6fa85 100644 --- a/torba/torba/server/session.py +++ b/torba/torba/server/session.py @@ -258,7 +258,7 @@ class SessionManager: session_timeout = self.env.session_timeout while True: await sleep(session_timeout // 10) - stale_cutoff = time.time() - session_timeout + stale_cutoff = time.perf_counter() - session_timeout stale_sessions = [session for session in self.sessions if session.last_recv < stale_cutoff] if stale_sessions: diff --git a/torba/torba/stream.py b/torba/torba/stream.py index 40589ade0..04b008688 100644 --- a/torba/torba/stream.py +++ b/torba/torba/stream.py @@ -45,10 +45,12 @@ class BroadcastSubscription: class StreamController: - def __init__(self): + def __init__(self, merge_repeated_events=False): self.stream = Stream(self) self._first_subscription = None self._last_subscription = None + self._last_event = None + self._merge_repeated = merge_repeated_events @property def has_listener(self): @@ -76,8 +78,10 @@ class StreamController: return f def add(self, event): + skip = self._merge_repeated and event == self._last_event + self._last_event = event return self._notify_and_ensure_future( - lambda subscription: subscription._add(event) + lambda subscription: None if skip else subscription._add(event) ) def add_error(self, exception): @@ -141,8 +145,8 @@ class Stream: def first(self): future = asyncio.get_event_loop().create_future() subscription = self.listen( - lambda value: self._cancel_and_callback(subscription, future, value), - lambda exception: self._cancel_and_error(subscription, future, exception) + lambda value: not future.done() and self._cancel_and_callback(subscription, future, value), + lambda exception: not future.done() and self._cancel_and_error(subscription, future, exception) ) return future