niko
This commit is contained in:
parent
de12e25450
commit
593422f70c
8 changed files with 300 additions and 175 deletions
|
@ -18,6 +18,9 @@ class MockNetwork:
|
||||||
self.get_transaction_called = []
|
self.get_transaction_called = []
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
|
|
||||||
|
def retriable_call(self, function, *args, **kwargs):
|
||||||
|
return function(*args, **kwargs)
|
||||||
|
|
||||||
async def get_history(self, address):
|
async def get_history(self, address):
|
||||||
self.get_history_called.append(address)
|
self.get_history_called.append(address)
|
||||||
self.address = address
|
self.address = address
|
||||||
|
@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MocHeaderNetwork:
|
class MocHeaderNetwork(MockNetwork):
|
||||||
def __init__(self, responses):
|
def __init__(self, responses):
|
||||||
|
super().__init__(None, None)
|
||||||
self.responses = responses
|
self.responses = responses
|
||||||
|
|
||||||
async def get_headers(self, height, blocks):
|
async def get_headers(self, height, blocks):
|
||||||
|
|
|
@ -70,8 +70,10 @@ class BaseHeaders:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __getitem__(self, height) -> dict:
|
def __getitem__(self, height) -> dict:
|
||||||
assert not isinstance(height, slice), \
|
if isinstance(height, slice):
|
||||||
"Slicing of header chain has not been implemented yet."
|
raise NotImplementedError("Slicing of header chain has not been implemented yet.")
|
||||||
|
if not 0 <= height <= self.height:
|
||||||
|
raise IndexError(f"{height} is out of bounds, current height: {self.height}")
|
||||||
return self.deserialize(height, self.get_raw_header(height))
|
return self.deserialize(height, self.get_raw_header(height))
|
||||||
|
|
||||||
def get_raw_header(self, height) -> bytes:
|
def get_raw_header(self, height) -> bytes:
|
||||||
|
|
|
@ -287,7 +287,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
subscription_update = False
|
subscription_update = False
|
||||||
|
|
||||||
if not headers:
|
if not headers:
|
||||||
header_response = await self.network.get_headers(height, 2001)
|
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
||||||
headers = header_response['hex']
|
headers = header_response['hex']
|
||||||
|
|
||||||
if not headers:
|
if not headers:
|
||||||
|
@ -394,7 +394,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
if local_status == remote_status:
|
if local_status == remote_status:
|
||||||
return
|
return
|
||||||
|
|
||||||
remote_history = await self.network.get_history(address)
|
remote_history = await self.network.retriable_call(self.network.get_history, address)
|
||||||
|
|
||||||
cache_tasks = []
|
cache_tasks = []
|
||||||
synced_history = StringIO()
|
synced_history = StringIO()
|
||||||
|
@ -447,6 +447,17 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
if address_manager is not None:
|
if address_manager is not None:
|
||||||
await address_manager.ensure_address_gap()
|
await address_manager.ensure_address_gap()
|
||||||
|
|
||||||
|
local_status, local_history = await self.get_local_status_and_history(address)
|
||||||
|
if local_status != remote_status:
|
||||||
|
log.debug(
|
||||||
|
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
|
||||||
|
remote_status, len(remote_history), local_status, len(local_history)
|
||||||
|
)
|
||||||
|
log.debug("local: %s", local_history)
|
||||||
|
log.debug("remote: %s", remote_history)
|
||||||
|
else:
|
||||||
|
log.debug("Sync completed for: %s", address)
|
||||||
|
|
||||||
async def cache_transaction(self, txid, remote_height):
|
async def cache_transaction(self, txid, remote_height):
|
||||||
cache_item = self._tx_cache.get(txid)
|
cache_item = self._tx_cache.get(txid)
|
||||||
if cache_item is None:
|
if cache_item is None:
|
||||||
|
@ -466,7 +477,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
if tx is None:
|
if tx is None:
|
||||||
# fetch from network
|
# fetch from network
|
||||||
_raw = await self.network.get_transaction(txid)
|
_raw = await self.network.retriable_call(self.network.get_transaction, txid)
|
||||||
if _raw:
|
if _raw:
|
||||||
tx = self.transaction_class(unhexlify(_raw))
|
tx = self.transaction_class(unhexlify(_raw))
|
||||||
await self.maybe_verify_transaction(tx, remote_height)
|
await self.maybe_verify_transaction(tx, remote_height)
|
||||||
|
@ -486,8 +497,8 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
async def maybe_verify_transaction(self, tx, remote_height):
|
async def maybe_verify_transaction(self, tx, remote_height):
|
||||||
tx.height = remote_height
|
tx.height = remote_height
|
||||||
if 0 < remote_height <= len(self.headers):
|
if 0 < remote_height < len(self.headers):
|
||||||
merkle = await self.network.get_merkle(tx.id, remote_height)
|
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
|
||||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||||
header = self.headers[remote_height]
|
header = self.headers[remote_height]
|
||||||
tx.position = merkle['pos']
|
tx.position = merkle['pos']
|
||||||
|
@ -501,6 +512,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def broadcast(self, tx):
|
def broadcast(self, tx):
|
||||||
|
# broadcast cant be a retriable call yet
|
||||||
return self.network.broadcast(hexlify(tx.raw).decode())
|
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||||
|
|
||||||
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
|
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
|
||||||
|
@ -522,4 +534,4 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
)) for address_record in records
|
)) for address_record in records
|
||||||
], timeout=timeout)
|
], timeout=timeout)
|
||||||
if pending:
|
if pending:
|
||||||
raise TimeoutError('Timed out waiting for transaction.')
|
raise asyncio.TimeoutError('Timed out waiting for transaction.')
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import CancelledError
|
from operator import itemgetter
|
||||||
from time import time
|
from typing import Dict, Optional, Tuple
|
||||||
from typing import List
|
from time import perf_counter
|
||||||
import socket
|
|
||||||
|
|
||||||
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
||||||
|
|
||||||
|
@ -14,71 +13,146 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ClientSession(BaseClientSession):
|
class ClientSession(BaseClientSession):
|
||||||
|
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
|
||||||
def __init__(self, *args, network, server, timeout=30, **kwargs):
|
|
||||||
self.network = network
|
self.network = network
|
||||||
self.server = server
|
self.server = server
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._on_disconnect_controller = StreamController()
|
self._on_disconnect_controller = StreamController()
|
||||||
self.on_disconnected = self._on_disconnect_controller.stream
|
self.on_disconnected = self._on_disconnect_controller.stream
|
||||||
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
|
self.framer.max_size = self.max_errors = 1 << 32
|
||||||
|
self.bw_limit = -1
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.max_seconds_idle = timeout * 2
|
self.max_seconds_idle = timeout * 2
|
||||||
self.ping_task = None
|
self.response_time: Optional[float] = None
|
||||||
|
self.connection_latency: Optional[float] = None
|
||||||
|
self._response_samples = 0
|
||||||
|
self.pending_amount = 0
|
||||||
|
self._on_connect_cb = on_connect_callback or (lambda: None)
|
||||||
|
self.trigger_urgent_reconnect = asyncio.Event()
|
||||||
|
self._semaphore = asyncio.Semaphore(int(self.timeout))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available(self):
|
||||||
|
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
|
||||||
|
if not self.transport:
|
||||||
|
return None
|
||||||
|
return self.transport.get_extra_info('peername')
|
||||||
|
|
||||||
|
async def send_timed_server_version_request(self, args=(), timeout=None):
|
||||||
|
timeout = timeout or self.timeout
|
||||||
|
log.debug("send version request to %s:%i", *self.server)
|
||||||
|
start = perf_counter()
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
super().send_request('server.version', args), timeout=timeout
|
||||||
|
)
|
||||||
|
current_response_time = perf_counter() - start
|
||||||
|
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
|
||||||
|
self.response_time = response_sum / (self._response_samples + 1)
|
||||||
|
self._response_samples += 1
|
||||||
|
return result
|
||||||
|
|
||||||
async def send_request(self, method, args=()):
|
async def send_request(self, method, args=()):
|
||||||
try:
|
self.pending_amount += 1
|
||||||
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout)
|
async with self._semaphore:
|
||||||
except RPCError as e:
|
return await self._send_request(method, args)
|
||||||
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
|
||||||
raise e
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self.abort()
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def ping_forever(self):
|
async def _send_request(self, method, args=()):
|
||||||
|
log.debug("send %s to %s:%i", method, *self.server)
|
||||||
|
try:
|
||||||
|
if method == 'server.version':
|
||||||
|
reply = await self.send_timed_server_version_request(args, self.timeout)
|
||||||
|
else:
|
||||||
|
reply = await asyncio.wait_for(
|
||||||
|
super().send_request(method, args), timeout=self.timeout
|
||||||
|
)
|
||||||
|
log.debug("got reply for %s from %s:%i", method, *self.server)
|
||||||
|
return reply
|
||||||
|
except RPCError as e:
|
||||||
|
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
|
||||||
|
*self.server, *e.args)
|
||||||
|
raise e
|
||||||
|
except ConnectionError:
|
||||||
|
log.warning("connection to %s:%i lost", *self.server)
|
||||||
|
self.synchronous_close()
|
||||||
|
raise asyncio.CancelledError(f"connection to {self.server[0]}:{self.server[1]} lost")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log.info("timeout sending %s to %s:%i", method, *self.server)
|
||||||
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
log.info("cancelled sending %s to %s:%i", method, *self.server)
|
||||||
|
self.synchronous_close()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.pending_amount -= 1
|
||||||
|
|
||||||
|
async def ensure_session(self):
|
||||||
|
# Handles reconnecting and maintaining a session alive
|
||||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||||
while not self.is_closing():
|
retry_delay = default_delay = 1.0
|
||||||
if (time() - self.last_send) > self.max_seconds_idle:
|
while True:
|
||||||
try:
|
try:
|
||||||
await self.send_request('server.banner')
|
if self.is_closing():
|
||||||
except:
|
await self.create_connection(self.timeout)
|
||||||
self.abort()
|
await self.ensure_server_version()
|
||||||
raise
|
self._on_connect_cb()
|
||||||
await asyncio.sleep(self.max_seconds_idle//3)
|
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
|
||||||
|
await self.ensure_server_version()
|
||||||
|
retry_delay = default_delay
|
||||||
|
except (asyncio.TimeoutError, OSError):
|
||||||
|
await self.close()
|
||||||
|
retry_delay = min(60, retry_delay * 2)
|
||||||
|
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self.trigger_urgent_reconnect.clear()
|
||||||
|
|
||||||
|
async def ensure_server_version(self, required='1.2', timeout=3):
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
self.send_request('server.version', [__version__, required]), timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
async def create_connection(self, timeout=6):
|
async def create_connection(self, timeout=6):
|
||||||
connector = Connector(lambda: self, *self.server)
|
connector = Connector(lambda: self, *self.server)
|
||||||
|
start = perf_counter()
|
||||||
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
||||||
self.ping_task = asyncio.create_task(self.ping_forever())
|
self.connection_latency = perf_counter() - start
|
||||||
|
|
||||||
async def handle_request(self, request):
|
async def handle_request(self, request):
|
||||||
controller = self.network.subscription_controllers[request.method]
|
controller = self.network.subscription_controllers[request.method]
|
||||||
controller.add(request.args)
|
controller.add(request.args)
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
|
log.debug("Connection lost: %s:%d", *self.server)
|
||||||
super().connection_lost(exc)
|
super().connection_lost(exc)
|
||||||
|
self.response_time = None
|
||||||
|
self.connection_latency = None
|
||||||
|
self._response_samples = 0
|
||||||
|
self.pending_amount = 0
|
||||||
self._on_disconnect_controller.add(True)
|
self._on_disconnect_controller.add(True)
|
||||||
if self.ping_task:
|
|
||||||
self.ping_task.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseNetwork:
|
class BaseNetwork:
|
||||||
|
|
||||||
def __init__(self, ledger):
|
def __init__(self, ledger):
|
||||||
self.config = ledger.config
|
self.config = ledger.config
|
||||||
self.client: ClientSession = None
|
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
|
||||||
self.session_pool: SessionPool = None
|
self.client: Optional[ClientSession] = None
|
||||||
self.running = False
|
self.running = False
|
||||||
self.remote_height: int = 0
|
self.remote_height: int = 0
|
||||||
|
|
||||||
self._on_connected_controller = StreamController()
|
self._on_connected_controller = StreamController()
|
||||||
self.on_connected = self._on_connected_controller.stream
|
self.on_connected = self._on_connected_controller.stream
|
||||||
|
|
||||||
self._on_header_controller = StreamController()
|
self._on_header_controller = StreamController(merge_repeated_events=True)
|
||||||
self.on_header = self._on_header_controller.stream
|
self.on_header = self._on_header_controller.stream
|
||||||
|
|
||||||
self._on_status_controller = StreamController()
|
self._on_status_controller = StreamController(merge_repeated_events=True)
|
||||||
self.on_status = self._on_status_controller.stream
|
self.on_status = self._on_status_controller.stream
|
||||||
|
|
||||||
self.subscription_controllers = {
|
self.subscription_controllers = {
|
||||||
|
@ -86,79 +160,72 @@ class BaseNetwork:
|
||||||
'blockchain.address.subscribe': self._on_status_controller,
|
'blockchain.address.subscribe': self._on_status_controller,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def switch_to_fastest(self):
|
||||||
|
try:
|
||||||
|
client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if self.client:
|
||||||
|
await self.client.close()
|
||||||
|
self.client = None
|
||||||
|
for session in self.session_pool.sessions:
|
||||||
|
session.synchronous_close()
|
||||||
|
log.warning("not connected to any wallet servers")
|
||||||
|
return
|
||||||
|
current_client = self.client
|
||||||
|
self.client = client
|
||||||
|
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
|
||||||
|
self._on_connected_controller.add(True)
|
||||||
|
try:
|
||||||
|
self._update_remote_height((await self.subscribe_headers(),))
|
||||||
|
log.info("Subscribed to headers: %s:%d", *self.client.server)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if self.client:
|
||||||
|
await self.client.close()
|
||||||
|
self.client = current_client
|
||||||
|
return
|
||||||
|
self.session_pool.new_connection_event.clear()
|
||||||
|
return await self.session_pool.new_connection_event.wait()
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.running = True
|
self.running = True
|
||||||
connect_timeout = self.config.get('connect_timeout', 6)
|
|
||||||
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
|
|
||||||
self.session_pool.start(self.config['default_servers'])
|
self.session_pool.start(self.config['default_servers'])
|
||||||
self.on_header.listen(self._update_remote_height)
|
self.on_header.listen(self._update_remote_height)
|
||||||
while True:
|
while self.running:
|
||||||
try:
|
await self.switch_to_fastest()
|
||||||
self.client = await self.pick_fastest_session()
|
|
||||||
if self.is_connected:
|
|
||||||
await self.ensure_server_version()
|
|
||||||
self._update_remote_height((await self.subscribe_headers(),))
|
|
||||||
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
|
|
||||||
self._on_connected_controller.add(True)
|
|
||||||
await self.client.on_disconnected.first
|
|
||||||
except CancelledError:
|
|
||||||
self.running = False
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
log.warning("Timed out while trying to find a server!")
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
log.exception("Exception while trying to find a server!")
|
|
||||||
if not self.running:
|
|
||||||
return
|
|
||||||
elif self.client:
|
|
||||||
await self.client.close()
|
|
||||||
self.client.connection.cancel_pending_requests()
|
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
if self.session_pool:
|
self.session_pool.stop()
|
||||||
self.session_pool.stop()
|
|
||||||
if self.is_connected:
|
|
||||||
disconnected = self.client.on_disconnected.first
|
|
||||||
await self.client.close()
|
|
||||||
await disconnected
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self.client is not None and not self.client.is_closing()
|
return self.client and not self.client.is_closing()
|
||||||
|
|
||||||
def rpc(self, list_or_method, args):
|
def rpc(self, list_or_method, args, session=None):
|
||||||
if self.is_connected:
|
session = session or self.session_pool.fastest_session
|
||||||
return self.client.send_request(list_or_method, args)
|
if session:
|
||||||
|
return session.send_request(list_or_method, args)
|
||||||
else:
|
else:
|
||||||
|
self.session_pool.trigger_nodelay_connect()
|
||||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||||
|
|
||||||
async def pick_fastest_session(self):
|
async def retriable_call(self, function, *args, **kwargs):
|
||||||
sessions = await self.session_pool.get_online_sessions()
|
while self.running:
|
||||||
done, pending = await asyncio.wait([
|
if not self.is_connected:
|
||||||
self.probe_session(session)
|
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
|
||||||
for session in sessions if not session.is_closing()
|
await self.on_connected.first
|
||||||
], return_when='FIRST_COMPLETED')
|
await self.session_pool.wait_for_fastest_session()
|
||||||
for task in pending:
|
try:
|
||||||
task.cancel()
|
return await function(*args, **kwargs)
|
||||||
for session in done:
|
except asyncio.TimeoutError:
|
||||||
return await session
|
log.warning("Wallet server call timed out, retrying.")
|
||||||
|
except ConnectionError:
|
||||||
async def probe_session(self, session: ClientSession):
|
pass
|
||||||
await session.send_request('server.banner')
|
raise asyncio.CancelledError() # if we got here, we are shutting down
|
||||||
return session
|
|
||||||
|
|
||||||
def _update_remote_height(self, header_args):
|
def _update_remote_height(self, header_args):
|
||||||
self.remote_height = header_args[0]["height"]
|
self.remote_height = header_args[0]["height"]
|
||||||
|
|
||||||
def ensure_server_version(self, required='1.2'):
|
|
||||||
return self.rpc('server.version', [__version__, required])
|
|
||||||
|
|
||||||
def broadcast(self, raw_transaction):
|
|
||||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
|
||||||
|
|
||||||
def get_history(self, address):
|
|
||||||
return self.rpc('blockchain.address.get_history', [address])
|
|
||||||
|
|
||||||
def get_transaction(self, tx_hash):
|
def get_transaction(self, tx_hash):
|
||||||
return self.rpc('blockchain.transaction.get', [tx_hash])
|
return self.rpc('blockchain.transaction.get', [tx_hash])
|
||||||
|
|
||||||
|
@ -171,84 +238,111 @@ class BaseNetwork:
|
||||||
def get_headers(self, height, count=10000):
|
def get_headers(self, height, count=10000):
|
||||||
return self.rpc('blockchain.block.headers', [height, count])
|
return self.rpc('blockchain.block.headers', [height, count])
|
||||||
|
|
||||||
def subscribe_headers(self):
|
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
|
||||||
return self.rpc('blockchain.headers.subscribe', [True])
|
def get_history(self, address):
|
||||||
|
return self.rpc('blockchain.address.get_history', [address], session=self.client)
|
||||||
|
|
||||||
def subscribe_address(self, address):
|
def broadcast(self, raw_transaction):
|
||||||
return self.rpc('blockchain.address.subscribe', [address])
|
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
|
||||||
|
|
||||||
|
def subscribe_headers(self):
|
||||||
|
return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
|
||||||
|
|
||||||
|
async def subscribe_address(self, address):
|
||||||
|
try:
|
||||||
|
return await self.rpc('blockchain.address.subscribe', [address], session=self.client)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# abort and cancel, we cant lose a subscription, it will happen again on reconnect
|
||||||
|
self.client.abort()
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
|
||||||
class SessionPool:
|
class SessionPool:
|
||||||
|
|
||||||
def __init__(self, network: BaseNetwork, timeout: float):
|
def __init__(self, network: BaseNetwork, timeout: float):
|
||||||
self.network = network
|
self.network = network
|
||||||
self.sessions: List[ClientSession] = []
|
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
|
||||||
self._dead_servers: List[ClientSession] = []
|
|
||||||
self.maintain_connections_task = None
|
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
# triggered when the master server is out, to speed up reconnect
|
self.new_connection_event = asyncio.Event()
|
||||||
self._lost_master = asyncio.Event()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def online(self):
|
def online(self):
|
||||||
for session in self.sessions:
|
return any(not session.is_closing() for session in self.sessions)
|
||||||
if not session.is_closing():
|
|
||||||
return True
|
@property
|
||||||
return False
|
def available_sessions(self):
|
||||||
|
return [session for session in self.sessions if session.available]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fastest_session(self):
|
||||||
|
if not self.available_sessions:
|
||||||
|
return None
|
||||||
|
return min(
|
||||||
|
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
|
||||||
|
for session in self.available_sessions],
|
||||||
|
key=itemgetter(0)
|
||||||
|
)[1]
|
||||||
|
|
||||||
|
def _get_session_connect_callback(self, session: ClientSession):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
def callback():
|
||||||
|
duplicate_connections = [
|
||||||
|
s for s in self.sessions
|
||||||
|
if s is not session and s.server_address_and_port == session.server_address_and_port
|
||||||
|
]
|
||||||
|
already_connected = None if not duplicate_connections else duplicate_connections[0]
|
||||||
|
if already_connected:
|
||||||
|
self.sessions.pop(session).cancel()
|
||||||
|
session.synchronous_close()
|
||||||
|
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
|
||||||
|
session.server[0], already_connected.server[0])
|
||||||
|
loop.call_later(3600, self._connect_session, session.server)
|
||||||
|
return
|
||||||
|
self.new_connection_event.set()
|
||||||
|
log.info("connected to %s:%i", *session.server)
|
||||||
|
|
||||||
|
return callback
|
||||||
|
|
||||||
|
def _connect_session(self, server: Tuple[str, int]):
|
||||||
|
session = None
|
||||||
|
for s in self.sessions:
|
||||||
|
if s.server == server:
|
||||||
|
session = s
|
||||||
|
break
|
||||||
|
if not session:
|
||||||
|
session = ClientSession(
|
||||||
|
network=self.network, server=server
|
||||||
|
)
|
||||||
|
session._on_connect_cb = self._get_session_connect_callback(session)
|
||||||
|
task = self.sessions.get(session, None)
|
||||||
|
if not task or task.done():
|
||||||
|
task = asyncio.create_task(session.ensure_session())
|
||||||
|
task.add_done_callback(lambda _: self.ensure_connections())
|
||||||
|
self.sessions[session] = task
|
||||||
|
|
||||||
def start(self, default_servers):
|
def start(self, default_servers):
|
||||||
self.sessions = [
|
for server in default_servers:
|
||||||
ClientSession(network=self.network, server=server)
|
self._connect_session(server)
|
||||||
for server in default_servers
|
|
||||||
]
|
|
||||||
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self.maintain_connections_task:
|
for task in self.sessions.values():
|
||||||
self.maintain_connections_task.cancel()
|
task.cancel()
|
||||||
|
self.sessions.clear()
|
||||||
|
|
||||||
|
def ensure_connections(self):
|
||||||
for session in self.sessions:
|
for session in self.sessions:
|
||||||
if not session.is_closing():
|
self._connect_session(session.server)
|
||||||
session.abort()
|
|
||||||
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
|
|
||||||
|
|
||||||
async def ensure_connections(self):
|
def trigger_nodelay_connect(self):
|
||||||
while True:
|
# used when other parts of the system sees we might have internet back
|
||||||
await asyncio.gather(*[
|
# bypasses the retry interval
|
||||||
self.ensure_connection(session)
|
for session in self.sessions:
|
||||||
for session in self.sessions
|
session.trigger_urgent_reconnect.set()
|
||||||
], return_exceptions=True)
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
pass
|
|
||||||
self._lost_master.clear()
|
|
||||||
if not self.sessions:
|
|
||||||
self.sessions.extend(self._dead_servers)
|
|
||||||
self._dead_servers = []
|
|
||||||
|
|
||||||
async def ensure_connection(self, session):
|
async def wait_for_fastest_session(self):
|
||||||
self._dead_servers.append(session)
|
while not self.fastest_session:
|
||||||
self.sessions.remove(session)
|
self.trigger_nodelay_connect()
|
||||||
try:
|
self.new_connection_event.clear()
|
||||||
if session.is_closing():
|
await self.new_connection_event.wait()
|
||||||
await session.create_connection(self.timeout)
|
return self.fastest_session
|
||||||
await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout)
|
|
||||||
self.sessions.append(session)
|
|
||||||
self._dead_servers.remove(session)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
log.warning("Timeout connecting to %s:%d", *session.server)
|
|
||||||
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
|
||||||
raise
|
|
||||||
except socket.gaierror:
|
|
||||||
log.warning("Could not resolve IP for %s", session.server[0])
|
|
||||||
except Exception as err: # pylint: disable=broad-except
|
|
||||||
if 'Connect call failed' in str(err):
|
|
||||||
log.warning("Could not connect to %s:%d", *session.server)
|
|
||||||
else:
|
|
||||||
log.exception("Connecting to %s:%d raised an exception:", *session.server)
|
|
||||||
|
|
||||||
async def get_online_sessions(self):
|
|
||||||
while not self.online:
|
|
||||||
self._lost_master.set()
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
return self.sessions
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import typing
|
import typing
|
||||||
|
import asyncio
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
|
|
||||||
|
@ -745,9 +746,8 @@ class JSONRPCConnection(object):
|
||||||
self._protocol = item
|
self._protocol = item
|
||||||
return self.receive_message(message)
|
return self.receive_message(message)
|
||||||
|
|
||||||
def cancel_pending_requests(self):
|
def raise_pending_requests(self, exception):
|
||||||
"""Cancel all pending requests."""
|
exception = exception or asyncio.TimeoutError()
|
||||||
exception = CancelledError()
|
|
||||||
for request, event in self._requests.values():
|
for request, event in self._requests.values():
|
||||||
event.result = exception
|
event.result = exception
|
||||||
event.set()
|
event.set()
|
||||||
|
|
|
@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol):
|
||||||
# Force-close a connection if a send doesn't succeed in this time
|
# Force-close a connection if a send doesn't succeed in this time
|
||||||
self.max_send_delay = 60
|
self.max_send_delay = 60
|
||||||
# Statistics. The RPC object also keeps its own statistics.
|
# Statistics. The RPC object also keeps its own statistics.
|
||||||
self.start_time = time.time()
|
self.start_time = time.perf_counter()
|
||||||
self.errors = 0
|
self.errors = 0
|
||||||
self.send_count = 0
|
self.send_count = 0
|
||||||
self.send_size = 0
|
self.send_size = 0
|
||||||
|
@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol):
|
||||||
# A non-positive value means not to limit concurrency
|
# A non-positive value means not to limit concurrency
|
||||||
if self.bw_limit <= 0:
|
if self.bw_limit <= 0:
|
||||||
return
|
return
|
||||||
now = time.time()
|
now = time.perf_counter()
|
||||||
# Reduce the recorded usage in proportion to the elapsed time
|
# Reduce the recorded usage in proportion to the elapsed time
|
||||||
refund = (now - self.bw_time) * (self.bw_limit / 3600)
|
refund = (now - self.bw_time) * (self.bw_limit / 3600)
|
||||||
self.bw_charge = max(0, self.bw_charge - int(refund))
|
self.bw_charge = max(0, self.bw_charge - int(refund))
|
||||||
|
@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
|
||||||
await asyncio.wait_for(self._can_send.wait(), secs)
|
await asyncio.wait_for(self._can_send.wait(), secs)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self.abort()
|
self.abort()
|
||||||
raise asyncio.CancelledError(f'task timed out after {secs}s')
|
raise asyncio.TimeoutError(f'task timed out after {secs}s')
|
||||||
|
|
||||||
async def _send_message(self, message):
|
async def _send_message(self, message):
|
||||||
if not self._can_send.is_set():
|
if not self._can_send.is_set():
|
||||||
|
@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol):
|
||||||
self.send_size += len(framed_message)
|
self.send_size += len(framed_message)
|
||||||
self._using_bandwidth(len(framed_message))
|
self._using_bandwidth(len(framed_message))
|
||||||
self.send_count += 1
|
self.send_count += 1
|
||||||
self.last_send = time.time()
|
self.last_send = time.perf_counter()
|
||||||
if self.verbosity >= 4:
|
if self.verbosity >= 4:
|
||||||
self.logger.debug(f'Sending framed message {framed_message}')
|
self.logger.debug(f'Sending framed message {framed_message}')
|
||||||
self.transport.write(framed_message)
|
self.transport.write(framed_message)
|
||||||
|
@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol):
|
||||||
self._address = None
|
self._address = None
|
||||||
self.transport = None
|
self.transport = None
|
||||||
self._task_group.cancel()
|
self._task_group.cancel()
|
||||||
self._pm_task.cancel()
|
if self._pm_task:
|
||||||
|
self._pm_task.cancel()
|
||||||
# Release waiting tasks
|
# Release waiting tasks
|
||||||
self._can_send.set()
|
self._can_send.set()
|
||||||
|
|
||||||
|
@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol):
|
||||||
if self.transport:
|
if self.transport:
|
||||||
self.transport.abort()
|
self.transport.abort()
|
||||||
|
|
||||||
|
# TODO: replace with synchronous_close
|
||||||
async def close(self, *, force_after=30):
|
async def close(self, *, force_after=30):
|
||||||
"""Close the connection and return when closed."""
|
"""Close the connection and return when closed."""
|
||||||
self._close()
|
self._close()
|
||||||
|
@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol):
|
||||||
self.abort()
|
self.abort()
|
||||||
await self._pm_task
|
await self._pm_task
|
||||||
|
|
||||||
|
def synchronous_close(self):
|
||||||
|
self._close()
|
||||||
|
if self._pm_task and not self._pm_task.done():
|
||||||
|
self._pm_task.cancel()
|
||||||
|
|
||||||
|
|
||||||
class MessageSession(SessionBase):
|
class MessageSession(SessionBase):
|
||||||
"""Session class for protocols where messages are not tied to responses,
|
"""Session class for protocols where messages are not tied to responses,
|
||||||
|
@ -296,7 +303,7 @@ class MessageSession(SessionBase):
|
||||||
)
|
)
|
||||||
self._bump_errors()
|
self._bump_errors()
|
||||||
else:
|
else:
|
||||||
self.last_recv = time.time()
|
self.last_recv = time.perf_counter()
|
||||||
self.recv_count += 1
|
self.recv_count += 1
|
||||||
if self.recv_count % 10 == 0:
|
if self.recv_count % 10 == 0:
|
||||||
await self._update_concurrency()
|
await self._update_concurrency()
|
||||||
|
@ -416,7 +423,7 @@ class RPCSession(SessionBase):
|
||||||
self.logger.warning(f'{e!r}')
|
self.logger.warning(f'{e!r}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.last_recv = time.time()
|
self.last_recv = time.perf_counter()
|
||||||
self.recv_count += 1
|
self.recv_count += 1
|
||||||
if self.recv_count % 10 == 0:
|
if self.recv_count % 10 == 0:
|
||||||
await self._update_concurrency()
|
await self._update_concurrency()
|
||||||
|
@ -456,7 +463,7 @@ class RPCSession(SessionBase):
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
# Cancel pending requests and message processing
|
# Cancel pending requests and message processing
|
||||||
self.connection.cancel_pending_requests()
|
self.connection.raise_pending_requests(exc)
|
||||||
super().connection_lost(exc)
|
super().connection_lost(exc)
|
||||||
|
|
||||||
# External API
|
# External API
|
||||||
|
@ -473,6 +480,8 @@ class RPCSession(SessionBase):
|
||||||
|
|
||||||
async def send_request(self, method, args=()):
|
async def send_request(self, method, args=()):
|
||||||
"""Send an RPC request over the network."""
|
"""Send an RPC request over the network."""
|
||||||
|
if self.is_closing():
|
||||||
|
raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
|
||||||
message, event = self.connection.send_request(Request(method, args))
|
message, event = self.connection.send_request(Request(method, args))
|
||||||
await self._send_message(message)
|
await self._send_message(message)
|
||||||
await event.wait()
|
await event.wait()
|
||||||
|
|
|
@ -258,7 +258,7 @@ class SessionManager:
|
||||||
session_timeout = self.env.session_timeout
|
session_timeout = self.env.session_timeout
|
||||||
while True:
|
while True:
|
||||||
await sleep(session_timeout // 10)
|
await sleep(session_timeout // 10)
|
||||||
stale_cutoff = time.time() - session_timeout
|
stale_cutoff = time.perf_counter() - session_timeout
|
||||||
stale_sessions = [session for session in self.sessions
|
stale_sessions = [session for session in self.sessions
|
||||||
if session.last_recv < stale_cutoff]
|
if session.last_recv < stale_cutoff]
|
||||||
if stale_sessions:
|
if stale_sessions:
|
||||||
|
|
|
@ -45,10 +45,12 @@ class BroadcastSubscription:
|
||||||
|
|
||||||
class StreamController:
|
class StreamController:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, merge_repeated_events=False):
|
||||||
self.stream = Stream(self)
|
self.stream = Stream(self)
|
||||||
self._first_subscription = None
|
self._first_subscription = None
|
||||||
self._last_subscription = None
|
self._last_subscription = None
|
||||||
|
self._last_event = None
|
||||||
|
self._merge_repeated = merge_repeated_events
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_listener(self):
|
def has_listener(self):
|
||||||
|
@ -76,8 +78,10 @@ class StreamController:
|
||||||
return f
|
return f
|
||||||
|
|
||||||
def add(self, event):
|
def add(self, event):
|
||||||
|
skip = self._merge_repeated and event == self._last_event
|
||||||
|
self._last_event = event
|
||||||
return self._notify_and_ensure_future(
|
return self._notify_and_ensure_future(
|
||||||
lambda subscription: subscription._add(event)
|
lambda subscription: None if skip else subscription._add(event)
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_error(self, exception):
|
def add_error(self, exception):
|
||||||
|
@ -141,8 +145,8 @@ class Stream:
|
||||||
def first(self):
|
def first(self):
|
||||||
future = asyncio.get_event_loop().create_future()
|
future = asyncio.get_event_loop().create_future()
|
||||||
subscription = self.listen(
|
subscription = self.listen(
|
||||||
lambda value: self._cancel_and_callback(subscription, future, value),
|
lambda value: not future.done() and self._cancel_and_callback(subscription, future, value),
|
||||||
lambda exception: self._cancel_and_error(subscription, future, exception)
|
lambda exception: not future.done() and self._cancel_and_error(subscription, future, exception)
|
||||||
)
|
)
|
||||||
return future
|
return future
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue