Compare commits

...

7 commits

Author SHA1 Message Date
Victor Shyba 550458b2f1 WAL_CHECKPOINT on close 2020-11-18 13:38:57 -03:00
Jack Robison 0ed067cbe5
test 2020-11-18 10:51:36 -05:00
Jack Robison bf336e9905
fix test 2020-11-17 17:29:06 -05:00
Jack Robison 69ded96516
pylint 2020-11-17 17:02:58 -05:00
Jack Robison 0caec7e629
fix 2020-11-17 17:02:58 -05:00
Jack Robison 73c40cef60
tests 2020-11-17 17:02:58 -05:00
Jack Robison f992e86675
one spv connection per loaded wallet 2020-11-17 17:02:58 -05:00
11 changed files with 227 additions and 99 deletions

View file

@ -1303,8 +1303,9 @@ class Daemon(metaclass=JSONRPCServerType):
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
}
)
if self.ledger.network.is_connected:
await self.ledger.subscribe_account(account)
await self.ledger._subscribe_accounts([account])
elif wallet.accounts:
await self.ledger._subscribe_accounts(wallet.accounts)
wallet.save()
if not skip_on_startup:
with self.conf.update_config() as c:
@ -1331,9 +1332,9 @@ class Daemon(metaclass=JSONRPCServerType):
if not os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' was not found.")
wallet = self.wallet_manager.import_wallet(wallet_path)
if self.ledger.network.is_connected:
for account in wallet.accounts:
await self.ledger.subscribe_account(account)
if not self.ledger.network.is_connected(wallet.id):
await self.ledger._subscribe_accounts(wallet.accounts)
return wallet
@requires("wallet")
@ -1619,7 +1620,7 @@ class Daemon(metaclass=JSONRPCServerType):
}
)
wallet.save()
if self.ledger.network.is_connected:
if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account)
return account
@ -1647,7 +1648,7 @@ class Daemon(metaclass=JSONRPCServerType):
}
)
wallet.save()
if self.ledger.network.is_connected:
if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account)
return account
@ -1863,7 +1864,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet_changed = False
if data is not None:
added_accounts = wallet.merge(self.wallet_manager, password, data)
if added_accounts and self.ledger.network.is_connected:
if added_accounts and self.ledger.network.is_connected(wallet.id):
if blocking:
await asyncio.wait([
a.ledger.subscribe_account(a) for a in added_accounts
@ -2957,7 +2958,7 @@ class Daemon(metaclass=JSONRPCServerType):
'public_key': data['holding_public_key'],
'address_generator': {'name': 'single-address'}
})
if self.ledger.network.is_connected:
if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account)
await self.ledger._update_tasks.done.wait()
# Case 3: the holding address has changed and we can't create or find an account for it

View file

@ -564,7 +564,7 @@ class Account:
self.change.gap = new_change_gap
gap_changed = True
if gap_changed:
self.wallet.save()
await asyncio.get_event_loop().run_in_executor(None, self.wallet.save)
async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False, read_only=False):
tips_balance, supports_balance, claims_balance = 0, 0, 0

View file

@ -121,7 +121,12 @@ class AIOSQLite:
if self._closing:
return
self._closing = True
await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
def __checkpoint_and_close(conn: sqlite3.Connection):
conn.execute("PRAGMA WAL_CHECKPOINT(FULL);")
conn.close()
await asyncio.get_event_loop().run_in_executor(
self.writer_executor, __checkpoint_and_close, self.writer_connection)
self.writer_executor.shutdown(wait=True)
self.reader_executor.shutdown(wait=True)
self.read_ready.clear()

View file

@ -122,7 +122,7 @@ class Ledger(metaclass=LedgerRegistry):
self.headers.checkpoints = self.checkpoints
self.network: Network = self.config.get('network') or Network(self)
self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status_update)
# self.network.on_status.listen(self.process_status_update)
self.network.on_connected.listen(self.join_network)
self.accounts = []
@ -324,12 +324,15 @@ class Ledger(metaclass=LedgerRegistry):
self.db.open(),
self.headers.open()
])
fully_synced = self.on_ready.first
asyncio.create_task(self.network.start())
await self.network.on_connected.first
async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync())
await fully_synced
await self.db.release_all_outputs()
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
@ -448,10 +451,16 @@ class Ledger(metaclass=LedgerRegistry):
)
async def subscribe_accounts(self):
if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts))
return await self._subscribe_accounts(self.accounts)
async def _subscribe_accounts(self, accounts):
wallet_ids = {account.wallet.id for account in accounts}
await self.network.connect_wallets(*wallet_ids)
# if self.network.is_connected and self.accounts:
if accounts:
log.info("Subscribe to %i accounts", len(accounts))
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
self.subscribe_account(a) for a in accounts
])
async def subscribe_account(self, account: Account):
@ -460,8 +469,10 @@ class Ledger(metaclass=LedgerRegistry):
await account.ensure_address_gap()
async def unsubscribe_account(self, account: Account):
session = self.network.get_wallet_session(account.wallet)
for address in await account.get_addresses():
await self.network.unsubscribe_address(address)
await self.network.unsubscribe_address(address, session=session)
await self.network.close_wallet_session(account.wallet)
async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]):
await self.subscribe_addresses(address_manager, addresses)
@ -470,28 +481,29 @@ class Ledger(metaclass=LedgerRegistry):
)
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
if self.network.is_connected and addresses:
if self.network.is_connected(address_manager.account.wallet.id) and addresses:
session = self.network.get_wallet_session(address_manager.account.wallet)
addresses_remaining = list(addresses)
while addresses_remaining:
batch = addresses_remaining[:batch_size]
results = await self.network.subscribe_address(*batch)
results = await self.network.subscribe_address(session, *batch)
for address, remote_status in zip(batch, results):
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
self._update_tasks.add(self.update_history(session, address, remote_status, address_manager))
addresses_remaining = addresses_remaining[batch_size:]
if self.network.client and self.network.client.server_address_and_port:
if session and session.server_address_and_port:
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port)
if self.network.client and self.network.client.server_address_and_port:
len(addresses), *session.server_address_and_port)
if session and session.server_address_and_port:
log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port
*session.server_address_and_port
)
def process_status_update(self, update):
def process_status_update(self, client, update):
address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status))
self._update_tasks.add(self.update_history(client, address, remote_status))
async def update_history(self, address, remote_status, address_manager: AddressManager = None,
async def update_history(self, client, address, remote_status, address_manager: AddressManager = None,
reattempt_update: bool = True):
async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address)
@ -500,7 +512,9 @@ class Ledger(metaclass=LedgerRegistry):
if local_status == remote_status:
return True
remote_history = await self.network.retriable_call(self.network.get_history, address)
# remote_history = await self.network.retriable_call(self.network.get_history, address, client)
remote_history = await self.network.get_history(address, session=client)
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history)
if not we_need:
@ -563,7 +577,7 @@ class Ledger(metaclass=LedgerRegistry):
"request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs),
len(remote_history), address
)
requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address)
requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address, client)
for tx in requested_txes:
pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:"
synced_txs.append(tx)
@ -603,7 +617,7 @@ class Ledger(metaclass=LedgerRegistry):
if self._tx_cache.get(txid) is not cache_item:
log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update)
if reattempt_update:
return await self.update_history(address, remote_status, address_manager, False)
return await self.update_history(client, address, remote_status, address_manager, False)
return False
local_status, local_history = \
@ -741,7 +755,7 @@ class Ledger(metaclass=LedgerRegistry):
await _single_batch(batch)
return transactions
async def _request_transaction_batch(self, to_request, remote_history_size, address):
async def _request_transaction_batch(self, to_request, remote_history_size, address, session):
header_cache = {}
batches = [[]]
remote_heights = {}
@ -766,7 +780,8 @@ class Ledger(metaclass=LedgerRegistry):
async def _single_batch(batch):
this_batch_synced = []
batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
# batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch, session)
batch_result = await self.network.get_transaction_batch(batch, session)
for txid, (raw, merkle) in batch_result.items():
remote_height = remote_heights[txid]
merkle_height = merkle['block_height']

View file

@ -6,7 +6,6 @@ from operator import itemgetter
from contextlib import asynccontextmanager
from functools import partial
from typing import Dict, Optional, Tuple
import aiohttp
from lbry import __version__
@ -14,6 +13,7 @@ from lbry.error import IncompatibleWalletServerError
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController
log = logging.getLogger(__name__)
@ -33,11 +33,26 @@ class ClientSession(BaseClientSession):
self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
self._connected = asyncio.Event()
self._disconnected = asyncio.Event()
self._disconnected.set()
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.on_status.listen(partial(self.network.ledger.process_status_update, self))
self.subscription_controllers = {
'blockchain.headers.subscribe': self.network._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
@property
def available(self):
return not self.is_closing() and self.response_time is not None
@property
def is_connected(self) -> bool:
return self._connected.is_set()
@property
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport:
@ -98,7 +113,6 @@ class ClientSession(BaseClientSession):
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = default_delay = 1.0
while True:
try:
@ -144,7 +158,7 @@ class ClientSession(BaseClientSession):
self.connection_latency = perf_counter() - start
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
controller = self.subscription_controllers[request.method]
controller.add(request.args)
def connection_lost(self, exc):
@ -154,6 +168,13 @@ class ClientSession(BaseClientSession):
self.connection_latency = None
self._response_samples = 0
self._on_disconnect_controller.add(True)
self._connected.clear()
self._disconnected.set()
def connection_made(self, transport):
super(ClientSession, self).connection_made(transport)
self._disconnected.clear()
self._connected.set()
class Network:
@ -165,6 +186,7 @@ class Network:
self.ledger = ledger
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None
self.clients: Dict[str, ClientSession] = {}
self.server_features = None
self._switch_task: Optional[asyncio.Task] = None
self.running = False
@ -177,27 +199,34 @@ class Network:
self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
self.aiohttp_session: Optional[aiohttp.ClientSession] = None
def get_wallet_session(self, wallet):
return self.clients[wallet.id]
async def close_wallet_session(self, wallet):
session = self.clients.pop(wallet.id)
ensure_connect_task = self.session_pool.wallet_session_tasks.pop(session)
if ensure_connect_task and not ensure_connect_task.done():
ensure_connect_task.cancel()
await session.close()
@property
def config(self):
return self.ledger.config
async def switch_forever(self):
while self.running:
if self.is_connected:
if self.is_connected():
await self.client.on_disconnected.first
self.server_features = None
self.client = None
continue
self.client = await self.session_pool.wait_for_fastest_session()
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
try:
@ -220,6 +249,11 @@ class Network:
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height)
await self.ledger.subscribe_accounts()
async def connect_wallets(self, *wallet_ids):
if wallet_ids:
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
async def stop(self):
if self.running:
@ -228,9 +262,19 @@ class Network:
self._switch_task.cancel()
self.session_pool.stop()
@property
def is_connected(self):
return self.client and not self.client.is_closing()
def is_connected(self, wallet_id: str = None):
if wallet_id is None:
if not self.client:
return False
return self.client.is_connected
if wallet_id not in self.clients:
return False
return self.clients[wallet_id].is_connected
def connect_wallet_client(self, wallet_id: str, server: Tuple[str, int]):
client = ClientSession(network=self, server=server)
self.clients[wallet_id] = client
return client
def rpc(self, list_or_method, args, restricted=True, session=None):
session = session or (self.client if restricted else self.session_pool.fastest_session)
@ -256,7 +300,8 @@ class Network:
raise asyncio.CancelledError() # if we got here, we are shutting down
@asynccontextmanager
async def single_call_context(self, function, *args, **kwargs):
async def fastest_connection_context(self):
if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
@ -264,6 +309,38 @@ class Network:
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
async def call_with_reconnect(function, *args, **kwargs):
nonlocal session
while self.running:
if not session.is_connected:
try:
await session.create_connection()
except asyncio.TimeoutError:
if not session.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
try:
return await partial(function, *args, session_override=session, **kwargs)(*args, **kwargs)
except asyncio.TimeoutError:
log.warning("'%s' failed, retrying", function.__name__)
try:
yield call_with_reconnect
finally:
await session.close()
@asynccontextmanager
async def single_call_context(self, function, *args, **kwargs):
if not self.is_connected():
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
async def call_with_reconnect(*a, **kw):
while self.running:
if not session.available:
@ -280,71 +357,71 @@ class Network:
def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"]
def get_transaction(self, tx_hash, known_height=None):
def get_transaction(self, tx_hash, known_height=None, session=None):
# use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
return self.rpc('blockchain.transaction.get', [tx_hash], restricted, session=session)
def get_transaction_batch(self, txids, restricted=True, session=None):
# use any server if its old, otherwise restrict to who gave us the history
return self.rpc('blockchain.transaction.get_batch', txids, restricted, session)
return self.rpc('blockchain.transaction.get_batch', txids, restricted, session=session)
def get_transaction_and_merkle(self, tx_hash, known_height=None):
def get_transaction_and_merkle(self, tx_hash, known_height=None, session=None):
# use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.info', [tx_hash], restricted)
return self.rpc('blockchain.transaction.info', [tx_hash], restricted, session=session)
def get_transaction_height(self, tx_hash, known_height=None):
def get_transaction_height(self, tx_hash, known_height=None, session=None):
restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted, session=session)
def get_merkle(self, tx_hash, height):
def get_merkle(self, tx_hash, height, session=None):
restricted = 0 > height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted, session=session)
def get_headers(self, height, count=10000, b64=False):
restricted = height >= self.remote_height - 100
return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address], True)
def get_history(self, address, session=None):
return self.rpc('blockchain.address.get_history', [address], True, session=session)
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
def broadcast(self, raw_transaction, session=None):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True, session=session)
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address, *addresses):
async def subscribe_address(self, session, address, *addresses):
addresses = list((address, ) + addresses)
server_addr_and_port = self.client.server_address_and_port # on disconnect client will be None
server_addr_and_port = session.server_address_and_port # on disconnect client will be None
try:
return await self.rpc('blockchain.address.subscribe', addresses, True)
return await self.rpc('blockchain.address.subscribe', addresses, True, session=session)
except asyncio.TimeoutError:
log.warning(
"timed out subscribing to addresses from %s:%i",
*server_addr_and_port
)
# abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client:
self.client.abort()
if session:
session.abort()
raise asyncio.CancelledError()
def unsubscribe_address(self, address):
return self.rpc('blockchain.address.unsubscribe', [address], True)
def unsubscribe_address(self, address, session=None):
return self.rpc('blockchain.address.unsubscribe', [address], True, session=session)
def get_server_features(self):
return self.rpc('server.features', (), restricted=True)
def get_server_features(self, session=None):
return self.rpc('server.features', (), restricted=True, session=session)
def get_claims_by_ids(self, claim_ids):
return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
def resolve(self, urls, session_override=None):
return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override)
return self.rpc('blockchain.claimtrie.resolve', urls, False, session=session_override)
def claim_search(self, session_override=None, **kwargs):
return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override)
return self.rpc('blockchain.claimtrie.search', kwargs, False, session=session_override)
async def new_resolve(self, server, urls):
message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}}
@ -371,6 +448,7 @@ class SessionPool:
def __init__(self, network: Network, timeout: float):
self.network = network
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.wallet_session_tasks: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout
self.new_connection_event = asyncio.Event()
@ -430,6 +508,23 @@ class SessionPool:
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
async def connect_wallet(self, wallet_id: str):
fastest = await self.wait_for_fastest_session()
connected = asyncio.Event()
if not self.network.is_connected(wallet_id):
session = self.network.connect_wallet_client(wallet_id, fastest.server)
session._on_connect_cb = connected.set
# session._on_connect_cb = self._get_session_connect_callback(session)
else:
connected.set()
session = self.network.clients[wallet_id]
task = self.wallet_session_tasks.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
# task.add_done_callback(lambda _: self.ensure_connections())
self.wallet_session_tasks[session] = task
await connected.wait()
def start(self, default_servers):
for server in default_servers:
self._connect_session(server)

View file

@ -1227,10 +1227,7 @@ class LBRYElectrumX(SessionBase):
return await self.address_status(hashX)
async def hashX_unsubscribe(self, hashX, alias):
try:
del self.hashX_subs[hashX]
except ValueError:
pass
self.hashX_subs.pop(hashX, None)
def address_to_hashX(self, address):
try:

View file

@ -75,7 +75,7 @@ class ReconnectTests(IntegrationTestCase):
session.trigger_urgent_reconnect.set()
await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1)
self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions)))
self.assertTrue(self.ledger.network.is_connected)
self.assertTrue(self.ledger.network.is_connected())
switch_event = self.ledger.network.on_connected.first
await node2.stop(True)
# secondary down, but primary is ok, do not switch! (switches trigger new on_connected events)
@ -99,7 +99,7 @@ class ReconnectTests(IntegrationTestCase):
address1 = await self.account.receiving.get_or_create_usable_address()
# disconnect and send a new tx, should reconnect and get it
self.ledger.network.client.connection_lost(Exception())
self.assertFalse(self.ledger.network.is_connected)
self.assertFalse(self.ledger.network.is_connected())
sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool
await self.blockchain.generate(1)
@ -122,7 +122,7 @@ class ReconnectTests(IntegrationTestCase):
sendtxid = await self.blockchain.send_to_address(address1, 42)
await self.blockchain.generate(1)
# (this is just so the test doesn't hang forever if it doesn't reconnect)
if not self.ledger.network.is_connected:
if not self.ledger.network.is_connected():
await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
# omg, the burned cable still works! torba is fire proof!
await self.ledger.network.get_transaction(sendtxid)
@ -130,11 +130,11 @@ class ReconnectTests(IntegrationTestCase):
async def test_timeout_then_reconnect(self):
# tests that it connects back after some failed attempts
await self.conductor.spv_node.stop()
self.assertFalse(self.ledger.network.is_connected)
self.assertFalse(self.ledger.network.is_connected())
await asyncio.sleep(0.2) # let it retry and fail once
await self.conductor.spv_node.start(self.conductor.blockchain_node)
await self.ledger.network.on_connected.first
self.assertTrue(self.ledger.network.is_connected)
self.assertTrue(self.ledger.network.is_connected())
async def test_online_but_still_unavailable(self):
# Edge case. See issue #2445 for context
@ -179,12 +179,18 @@ class ServerPickingTestCase(AsyncioTestCase):
],
'connect_timeout': 3
})
ledger.accounts = []
async def subscribe_accounts():
pass
ledger.subscribe_accounts = subscribe_accounts
network = Network(ledger)
self.addCleanup(network.stop)
asyncio.ensure_future(network.start())
await asyncio.wait_for(network.on_connected.first, timeout=1)
self.assertTrue(network.is_connected)
self.assertTrue(network.is_connected())
self.assertTupleEqual(network.client.server, ('127.0.0.1', 1337))
self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
# ensure we are connected to all of them after a while

View file

@ -152,8 +152,11 @@ class BasicTransactionTests(IntegrationTestCase):
for batch in range(0, len(sends), 10):
txids = await asyncio.gather(*sends[batch:batch + 10])
await asyncio.wait([self.on_transaction_id(txid) for txid in txids])
remote_status = await self.ledger.network.subscribe_address(address)
self.assertTrue(await self.ledger.update_history(address, remote_status))
client = self.ledger.network.get_wallet_session(self.account.wallet)
remote_status = await self.ledger.network.subscribe_address(client, address)
self.assertTrue(await self.ledger.update_history(client, address, remote_status))
# 20 unconfirmed txs, 10 from blockchain, 10 from local to local
utxos = await self.account.get_utxos()
txs = []
@ -166,12 +169,12 @@ class BasicTransactionTests(IntegrationTestCase):
await self.broadcast(tx)
txs.append(tx)
await asyncio.wait([self.on_transaction_address(tx, address) for tx in txs], timeout=1)
remote_status = await self.ledger.network.subscribe_address(address)
self.assertTrue(await self.ledger.update_history(address, remote_status))
remote_status = await self.ledger.network.subscribe_address(client, address)
self.assertTrue(await self.ledger.update_history(client, address, remote_status))
# server history grows unordered
txid = await self.blockchain.send_to_address(address, 1)
await self.on_transaction_id(txid)
self.assertTrue(await self.ledger.update_history(address, remote_status))
self.assertTrue(await self.ledger.update_history(client, address, remote_status))
self.assertEqual(21, len((await self.ledger.get_local_status_and_history(address))[1]))
self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync))

View file

@ -8,16 +8,14 @@ from lbry.wallet.dewies import dict_values_to_lbc
class WalletCommands(CommandTestCase):
async def test_wallet_create_and_add_subscribe(self):
session = next(iter(self.conductor.spv_node.server.session_mgr.sessions))
self.assertEqual(len(session.hashX_subs), 27)
self.assertSetEqual({0, 27}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
wallet = await self.daemon.jsonrpc_wallet_create('foo', create_account=True, single_key=True)
self.assertEqual(len(session.hashX_subs), 28)
self.assertSetEqual({0, 27, 1}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
await self.daemon.jsonrpc_wallet_remove(wallet.id)
self.assertEqual(len(session.hashX_subs), 27)
self.assertSetEqual({0, 27}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
await self.daemon.jsonrpc_wallet_add(wallet.id)
self.assertEqual(len(session.hashX_subs), 28)
self.assertSetEqual({0, 27, 1}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
async def test_wallet_syncing_status(self):
address = await self.daemon.jsonrpc_address_unused()

View file

@ -2,26 +2,33 @@ import asyncio
import lbry
import lbry.wallet
import unittest
from lbry.error import ServerPaymentFeeAboveMaxAllowedError
from lbry.wallet.network import ClientSession
from lbry.testcase import IntegrationTestCase, CommandTestCase
from lbry.testcase import CommandTestCase
from lbry.wallet.orchstr8.node import SPVNode
class TestSessions(IntegrationTestCase):
class MockNetwork:
def __init__(self, ledger):
self.ledger = ledger
self._on_header_controller = None
class TestSessions(CommandTestCase):
"""
Tests that server cleans up stale connections after session timeout and client times out too.
"""
LEDGER = lbry.wallet
async def test_session_bloat_from_socket_timeout(self):
await self.conductor.stop_spv()
await self.ledger.stop()
self.conductor.spv_node.session_timeout = 1
await self.conductor.start_spv()
session = ClientSession(
network=None, server=(self.conductor.spv_node.hostname, self.conductor.spv_node.port), timeout=0.2
network=MockNetwork(self.ledger), server=(self.conductor.spv_node.hostname, self.conductor.spv_node.port),
timeout=0.2
)
await session.create_connection()
await session.send_request('server.banner', ())
@ -46,6 +53,7 @@ class TestSessions(IntegrationTestCase):
class TestUsagePayment(CommandTestCase):
@unittest.skip('freezes ci')
async def test_single_server_payment(self):
wallet_pay_service = self.daemon.component_manager.get_component('wallet_server_payments')
wallet_pay_service.payment_period = 1

View file

@ -16,12 +16,12 @@ class MockNetwork:
self.address = None
self.get_history_called = []
self.get_transaction_called = []
self.is_connected = False
self.is_connected = lambda _: False
def retriable_call(self, function, *args, **kwargs):
return function(*args, **kwargs)
async def get_history(self, address):
async def get_history(self, address, session=None):
self.get_history_called.append(address)
self.address = address
return self.history
@ -40,7 +40,7 @@ class MockNetwork:
merkle = await self.get_merkle(tx_hash, known_height)
return tx, merkle
async def get_transaction_batch(self, txids):
async def get_transaction_batch(self, txids, session=None):
return {
txid: await self.get_transaction_and_merkle(txid)
for txid in txids
@ -111,7 +111,7 @@ class TestSynchronization(LedgerTestCase):
txid2: hexlify(get_transaction(get_output(2)).raw),
txid3: hexlify(get_transaction(get_output(3)).raw),
})
await self.ledger.update_history(address, '')
await self.ledger.update_history(None, address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [txid1, txid2, txid3])
@ -129,7 +129,7 @@ class TestSynchronization(LedgerTestCase):
self.assertFalse(self.ledger._tx_cache[txid1].tx.is_verified)
self.assertFalse(self.ledger._tx_cache[txid2].tx.is_verified)
self.assertFalse(self.ledger._tx_cache[txid3].tx.is_verified)
await self.ledger.update_history(address, '')
await self.ledger.update_history(None, address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [])
@ -137,7 +137,7 @@ class TestSynchronization(LedgerTestCase):
self.ledger.network.transaction[txid4] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
await self.ledger.update_history(address, '')
await self.ledger.update_history(None, address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [txid4])
address_details = await self.ledger.db.get_address(address=address)