fixes from review

This commit is contained in:
Victor Shyba 2019-08-12 13:32:20 -03:00
parent 9ee2f30df4
commit 4ead92cfbe
4 changed files with 20 additions and 14 deletions

View file

@ -45,8 +45,9 @@ class ReconnectTests(IntegrationTestCase):
# * goes to pick some water outside... * time passes by and another donation comes in # * goes to pick some water outside... * time passes by and another donation comes in
sendtxid = await self.blockchain.send_to_address(address1, 42) sendtxid = await self.blockchain.send_to_address(address1, 42)
await self.blockchain.generate(1) await self.blockchain.generate(1)
# (this is just so the test doesnt hang forever if it doesnt reconnect, also its not instant yet) # (this is just so the test doesnt hang forever if it doesnt reconnect)
await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0) if not self.ledger.network.is_connected:
await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
# omg, the burned cable still works! torba is fire proof! # omg, the burned cable still works! torba is fire proof!
await self.ledger.network.get_transaction(sendtxid) await self.ledger.network.get_transaction(sendtxid)

View file

@ -2,7 +2,7 @@ import logging
import asyncio import asyncio
from operator import itemgetter from operator import itemgetter
from typing import Dict, Optional from typing import Dict, Optional
from time import time from time import perf_counter as time
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -14,7 +14,7 @@ log = logging.getLogger(__name__)
class ClientSession(BaseClientSession): class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, **kwargs): def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network self.network = network
self.server = server self.server = server
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -24,6 +24,7 @@ class ClientSession(BaseClientSession):
self.timeout = timeout self.timeout = timeout
self.max_seconds_idle = timeout * 2 self.max_seconds_idle = timeout * 2
self.latency = 1 << 32 self.latency = 1 << 32
self._on_connect_cb = on_connect_callback or (lambda: None)
@property @property
def available(self): def available(self):
@ -53,6 +54,7 @@ class ClientSession(BaseClientSession):
if self.is_closing(): if self.is_closing():
await self.create_connection(self.timeout) await self.create_connection(self.timeout)
await self.ensure_server_version() await self.ensure_server_version()
self._on_connect_cb()
if (time() - self.last_send) > self.max_seconds_idle or self.latency == 1 << 32: if (time() - self.last_send) > self.max_seconds_idle or self.latency == 1 << 32:
await self.send_request('server.banner') await self.send_request('server.banner')
retry_delay = default_delay retry_delay = default_delay
@ -177,8 +179,9 @@ class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float): def __init__(self, network: BaseNetwork, timeout: float):
self.network = network self.network = network
self.sessions: Dict[ClientSession, asyncio.Task] = dict() self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout self.timeout = timeout
self.new_connection_event = asyncio.Event()
@property @property
def online(self): def online(self):
@ -195,8 +198,11 @@ class SessionPool:
return min([(session.latency, session) for session in self.available_sessions], key=itemgetter(0))[1] return min([(session.latency, session) for session in self.available_sessions], key=itemgetter(0))[1]
def start(self, default_servers): def start(self, default_servers):
callback = self.new_connection_event.set
self.sessions = { self.sessions = {
ClientSession(network=self.network, server=server): None for server in default_servers ClientSession(
network=self.network, server=server, on_connect_callback=callback
): None for server in default_servers
} }
self.ensure_connections() self.ensure_connections()
@ -214,9 +220,7 @@ class SessionPool:
self.sessions[session] = task self.sessions[session] = task
async def wait_for_fastest_session(self): async def wait_for_fastest_session(self):
while True: while not self.fastest_session:
fastest = self.fastest_session self.new_connection_event.clear()
if fastest: await self.new_connection_event.wait()
return fastest return self.fastest_session
else:
await asyncio.sleep(0.5)

View file

@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
import itertools import itertools
import json import json
import typing import typing
import asyncio
from functools import partial from functools import partial
from numbers import Number from numbers import Number
@ -748,7 +749,7 @@ class JSONRPCConnection(object):
def time_out_pending_requests(self): def time_out_pending_requests(self):
"""Times out all pending requests.""" """Times out all pending requests."""
# this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing? # this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing?
exception = TimeoutError() exception = asyncio.TimeoutError()
for request, event in self._requests.values(): for request, event in self._requests.values():
event.result = exception event.result = exception
event.set() event.set()

View file

@ -474,7 +474,7 @@ class RPCSession(SessionBase):
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
"""Send an RPC request over the network.""" """Send an RPC request over the network."""
if self.is_closing(): if self.is_closing():
raise asyncio.TimeoutError() raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
message, event = self.connection.send_request(Request(method, args)) message, event = self.connection.send_request(Request(method, args))
await self._send_message(message) await self._send_message(message)
await event.wait() await event.wait()