Merge pull request #2418 from lbryio/no_chdir

sync and connection issues
This commit is contained in:
shyba 2019-09-03 11:53:30 -03:00 committed by GitHub
commit 072f1f112e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 244 additions and 135 deletions

View file

@ -199,6 +199,8 @@ class LbryWalletManager(BaseWalletManager):
if not tx: if not tx:
try: try:
raw = await self.ledger.network.get_transaction(txid) raw = await self.ledger.network.get_transaction(txid)
if not raw:
return {'success': False, 'code': 404, 'message': 'transaction not found'}
height = await self.ledger.network.get_transaction_height(txid) height = await self.ledger.network.get_transaction_height(txid)
except CodeMessageError as e: except CodeMessageError as e:
return {'success': False, 'code': e.code, 'message': e.message} return {'success': False, 'code': e.code, 'message': e.message}

View file

@ -1,3 +1,4 @@
import os
import sqlite3 import sqlite3
from typing import Union, Tuple, Set, List from typing import Union, Tuple, Set, List
from itertools import chain from itertools import chain
@ -705,7 +706,8 @@ class LBRYDB(DB):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.sql = SQLDB(self, 'claims.db') path = os.path.join(self.env.db_dir, 'claims.db')
self.sql = SQLDB(self, path)
def close(self): def close(self):
super().close() super().close()

View file

@ -64,9 +64,10 @@ class LBRYSessionManager(SessionManager):
async def start_other(self): async def start_other(self):
self.running = True self.running = True
path = os.path.join(self.env.db_dir, 'claims.db')
args = dict( args = dict(
initializer=reader.initializer, initializer=reader.initializer,
initargs=(self.logger, 'claims.db', self.env.coin.NET, self.env.database_query_timeout, initargs=(self.logger, path, self.env.coin.NET, self.env.database_query_timeout,
self.env.track_metrics) self.env.track_metrics)
) )
if self.env.max_query_workers is not None and self.env.max_query_workers == 0: if self.env.max_query_workers is not None and self.env.max_query_workers == 0:

View file

@ -20,7 +20,9 @@ class TestSessionBloat(IntegrationTestCase):
await self.ledger.stop() await self.ledger.stop()
self.conductor.spv_node.session_timeout = 1 self.conductor.spv_node.session_timeout = 1
await self.conductor.start_spv() await self.conductor.start_spv()
session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2) session = ClientSession(
network=None, server=(self.conductor.spv_node.hostname, self.conductor.spv_node.port), timeout=0.2
)
await session.create_connection() await session.create_connection()
await session.send_request('server.banner', ()) await session.send_request('server.banner', ())
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1) self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)

View file

@ -3,6 +3,7 @@ import asyncio
from unittest.mock import Mock from unittest.mock import Mock
from torba.client.basenetwork import BaseNetwork from torba.client.basenetwork import BaseNetwork
from torba.orchstr8.node import SPVNode
from torba.rpc import RPCSession from torba.rpc import RPCSession
from torba.testcase import IntegrationTestCase, AsyncioTestCase from torba.testcase import IntegrationTestCase, AsyncioTestCase
@ -20,6 +21,26 @@ class ReconnectTests(IntegrationTestCase):
VERBOSITY = logging.WARN VERBOSITY = logging.WARN
async def test_multiple_servers(self):
# we have a secondary node that connects later, so
node2 = SPVNode(self.conductor.spv_module, node_number=2)
self.ledger.network.config['default_servers'].append((node2.hostname, node2.port))
await asyncio.wait_for(self.ledger.stop(), timeout=1)
await asyncio.wait_for(self.ledger.start(), timeout=1)
self.ledger.network.session_pool.new_connection_event.clear()
await node2.start(self.blockchain)
# this is only to speed up the test as retrying would take 4+ seconds
for session in self.ledger.network.session_pool.sessions:
session.trigger_urgent_reconnect.set()
await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1)
self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions)))
self.assertTrue(self.ledger.network.is_connected)
switch_event = self.ledger.network.on_connected.first
await node2.stop(True)
# secondary down, but primary is ok, do not switch! (switches trigger new on_connected events)
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(switch_event, timeout=1)
async def test_connection_drop_still_receives_events_after_reconnected(self): async def test_connection_drop_still_receives_events_after_reconnected(self):
address1 = await self.account.receiving.get_or_create_usable_address() address1 = await self.account.receiving.get_or_create_usable_address()
# disconnect and send a new tx, should reconnect and get it # disconnect and send a new tx, should reconnect and get it
@ -35,10 +56,11 @@ class ReconnectTests(IntegrationTestCase):
# is it real? are we rich!? let me see this tx... # is it real? are we rich!? let me see this tx...
d = self.ledger.network.get_transaction(sendtxid) d = self.ledger.network.get_transaction(sendtxid)
# what's that smoke on my ethernet cable? oh no! # what's that smoke on my ethernet cable? oh no!
master_client = self.ledger.network.client
self.ledger.network.client.connection_lost(Exception()) self.ledger.network.client.connection_lost(Exception())
with self.assertRaises(asyncio.TimeoutError): with self.assertRaises(asyncio.TimeoutError):
await d await d
self.assertIsNone(self.ledger.network.client.response_time) # response time unknown as it failed self.assertIsNone(master_client.response_time) # response time unknown as it failed
# rich but offline? no way, no water, let's retry # rich but offline? no way, no water, let's retry
with self.assertRaisesRegex(ConnectionError, 'connection is not available'): with self.assertRaisesRegex(ConnectionError, 'connection is not available'):
await self.ledger.network.get_transaction(sendtxid) await self.ledger.network.get_transaction(sendtxid)
@ -104,4 +126,4 @@ class ServerPickingTestCase(AsyncioTestCase):
self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions])) self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
# ensure we are connected to all of them after a while # ensure we are connected to all of them after a while
await asyncio.sleep(1) await asyncio.sleep(1)
self.assertEqual(len(network.session_pool.available_sessions), 3) self.assertEqual(len(list(network.session_pool.available_sessions)), 3)

View file

@ -1,6 +1,9 @@
import logging import logging
import asyncio import asyncio
import random
from itertools import chain from itertools import chain
from random import shuffle
from torba.testcase import IntegrationTestCase from torba.testcase import IntegrationTestCase
from torba.client.util import satoshis_to_coins, coins_to_satoshis from torba.client.util import satoshis_to_coins, coins_to_satoshis
@ -129,3 +132,48 @@ class BasicTransactionTests(IntegrationTestCase):
self.assertEqual(tx.outputs[0].get_address(self.ledger), address2) self.assertEqual(tx.outputs[0].get_address(self.ledger), address2)
self.assertEqual(tx.outputs[0].is_change, False) self.assertEqual(tx.outputs[0].is_change, False)
self.assertEqual(tx.outputs[1].is_change, True) self.assertEqual(tx.outputs[1].is_change, True)
async def test_history_edge_cases(self):
await self.assertBalance(self.account, '0.0')
address = await self.account.receiving.get_or_create_usable_address()
# evil trick: mempool is unsorted on real life, but same order between python instances. reproduce it
original_summary = self.conductor.spv_node.server.mempool.transaction_summaries
async def random_summary(*args, **kwargs):
summary = await original_summary(*args, **kwargs)
if summary and len(summary) > 2:
ordered = summary.copy()
while summary == ordered:
random.shuffle(summary)
return summary
self.conductor.spv_node.server.mempool.transaction_summaries = random_summary
# 10 unconfirmed txs, all from blockchain wallet
sends = list(self.blockchain.send_to_address(address, 10) for _ in range(10))
# use batching to reduce issues with send_to_address on cli
for batch in range(0, len(sends), 10):
txids = await asyncio.gather(*sends[batch:batch + 10])
await asyncio.wait([self.on_transaction_id(txid) for txid in txids])
remote_status = await self.ledger.network.subscribe_address(address)
self.assertTrue(await self.ledger.update_history(address, remote_status))
# 20 unconfirmed txs, 10 from blockchain, 10 from local to local
utxos = await self.account.get_utxos()
txs = []
for utxo in utxos:
tx = await self.ledger.transaction_class.create(
[self.ledger.transaction_class.input_class.spend(utxo)],
[],
[self.account], self.account
)
await self.broadcast(tx)
txs.append(tx)
await asyncio.wait([self.on_transaction_address(tx, address) for tx in txs], timeout=1)
remote_status = await self.ledger.network.subscribe_address(address)
self.assertTrue(await self.ledger.update_history(address, remote_status))
# server history grows unordered
txid = await self.blockchain.send_to_address(address, 1)
await self.on_transaction_id(txid)
self.assertTrue(await self.ledger.update_history(address, remote_status))
self.assertEqual(21, len((await self.ledger.get_local_status_and_history(address))[1]))
self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync))
# should be another test, but it would be too much to setup just for that and it affects sync
self.assertIsNone(await self.ledger.network.retriable_call(self.ledger.network.get_transaction, '1'*64))

View file

@ -29,7 +29,7 @@ class MockNetwork:
async def get_merkle(self, txid, height): async def get_merkle(self, txid, height):
return {'merkle': ['abcd01'], 'pos': 1} return {'merkle': ['abcd01'], 'pos': 1}
async def get_transaction(self, tx_hash): async def get_transaction(self, tx_hash, _=None):
self.get_transaction_called.append(tx_hash) self.get_transaction_called.append(tx_hash)
return self.transaction[tx_hash] return self.transaction[tx_hash]

View file

@ -220,23 +220,26 @@ class SQLiteMixin:
async def open(self): async def open(self):
log.info("connecting to database: %s", self._db_path) log.info("connecting to database: %s", self._db_path)
self.db = await AIOSQLite.connect(self._db_path) self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
await self.db.executescript(self.CREATE_TABLES_QUERY) await self.db.executescript(self.CREATE_TABLES_QUERY)
async def close(self): async def close(self):
await self.db.close() await self.db.close()
@staticmethod @staticmethod
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]: def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
replace: bool = False) -> Tuple[str, List]:
columns, values = [], [] columns, values = [], []
for column, value in data.items(): for column, value in data.items():
columns.append(column) columns.append(column)
values.append(value) values.append(value)
or_ignore = "" policy = ""
if ignore_duplicate: if ignore_duplicate:
or_ignore = " OR IGNORE" policy = " OR IGNORE"
if replace:
policy = " OR REPLACE"
sql = "INSERT{} INTO {} ({}) VALUES ({})".format( sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values)) policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
) )
return sql, values return sql, values
@ -348,9 +351,14 @@ class BaseDatabase(SQLiteMixin):
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,))) }, 'txid = ?', (tx.id,)))
def save_transaction_io(self, tx: BaseTransaction, address, txhash, history): def _transaction_io(self, conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
conn.execute(*self._insert_sql('tx', {
def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history): 'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified
}, replace=True))
for txo in tx.outputs: for txo in tx.outputs:
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash: if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
@ -373,10 +381,17 @@ class BaseDatabase(SQLiteMixin):
conn.execute( conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address) (history, history.count(':') // 2, address)
) )
return self.db.run(_transaction, tx, address, txhash, history) def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
return self.db.run(self._transaction_io, tx, address, txhash, history)
def save_transaction_io_batch(self, txs: Iterable[BaseTransaction], address, txhash, history):
def __many(conn):
for tx in txs:
self._transaction_io(conn, tx, address, txhash, history)
return self.db.run(__many)
async def reserve_outputs(self, txos, is_reserved=True): async def reserve_outputs(self, txos, is_reserved=True):
txoids = ((is_reserved, txo.id) for txo in txos) txoids = ((is_reserved, txo.id) for txo in txos)

View file

@ -10,6 +10,7 @@ from operator import itemgetter
from collections import namedtuple from collections import namedtuple
import pylru import pylru
from torba.client.basetransaction import BaseTransaction
from torba.tasks import TaskGroup from torba.tasks import TaskGroup
from torba.client import baseaccount, basenetwork, basetransaction from torba.client import baseaccount, basenetwork, basetransaction
from torba.client.basedatabase import BaseDatabase from torba.client.basedatabase import BaseDatabase
@ -142,6 +143,7 @@ class BaseLedger(metaclass=LedgerRegistry):
self._address_update_locks: Dict[str, asyncio.Lock] = {} self._address_update_locks: Dict[str, asyncio.Lock] = {}
self.coin_selection_strategy = None self.coin_selection_strategy = None
self._known_addresses_out_of_sync = set()
@classmethod @classmethod
def get_id(cls): def get_id(cls):
@ -250,7 +252,8 @@ class BaseLedger(metaclass=LedgerRegistry):
self.constraint_account_or_all(constraints) self.constraint_account_or_all(constraints)
return self.db.get_transaction_count(**constraints) return self.db.get_transaction_count(**constraints)
async def get_local_status_and_history(self, address): async def get_local_status_and_history(self, address, history=None):
if not history:
address_details = await self.db.get_address(address=address) address_details = await self.db.get_address(address=address)
history = address_details['history'] or '' history = address_details['history'] or ''
parts = history.split(':')[:-1] parts = history.split(':')[:-1]
@ -284,7 +287,7 @@ class BaseLedger(metaclass=LedgerRegistry):
await self.join_network() await self.join_network()
self.network.on_connected.listen(self.join_network) self.network.on_connected.listen(self.join_network)
async def join_network(self, *args): async def join_network(self, *_):
log.info("Subscribing and updating accounts.") log.info("Subscribing and updating accounts.")
async with self._header_processing_lock: async with self._header_processing_lock:
await self.update_headers() await self.update_headers()
@ -411,24 +414,31 @@ class BaseLedger(metaclass=LedgerRegistry):
address_manager: baseaccount.AddressManager = None): address_manager: baseaccount.AddressManager = None):
async with self._address_update_locks.setdefault(address, asyncio.Lock()): async with self._address_update_locks.setdefault(address, asyncio.Lock()):
self._known_addresses_out_of_sync.discard(address)
local_status, local_history = await self.get_local_status_and_history(address) local_status, local_history = await self.get_local_status_and_history(address)
if local_status == remote_status: if local_status == remote_status:
return return True
remote_history = await self.network.retriable_call(self.network.get_history, address) remote_history = await self.network.retriable_call(self.network.get_history, address)
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history)
if not we_need:
return True
cache_tasks = [] cache_tasks: List[asyncio.Future[BaseTransaction]] = []
synced_history = StringIO() synced_history = StringIO()
for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): for i, (txid, remote_height) in enumerate(remote_history):
if i < len(local_history) and local_history[i] == (txid, remote_height): if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
synced_history.write(f'{txid}:{remote_height}:') synced_history.write(f'{txid}:{remote_height}:')
else: else:
check_local = (txid, remote_height) not in we_need
cache_tasks.append(asyncio.ensure_future( cache_tasks.append(asyncio.ensure_future(
self.cache_transaction(txid, remote_height) self.cache_transaction(txid, remote_height, check_local=check_local)
)) ))
synced_txs = []
for task in cache_tasks: for task in cache_tasks:
tx = await task tx = await task
@ -457,12 +467,15 @@ class BaseLedger(metaclass=LedgerRegistry):
txi.txo_ref = referenced_txo.ref txi.txo_ref = referenced_txo.ref
synced_history.write(f'{tx.id}:{tx.height}:') synced_history.write(f'{tx.id}:{tx.height}:')
synced_txs.append(tx)
await self.db.save_transaction_io( await self.db.save_transaction_io_batch(
tx, address, self.address_to_hash160(address), synced_history.getvalue() synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
) )
await asyncio.wait([
await self._on_transaction_controller.add(TransactionEvent(address, tx)) self._on_transaction_controller.add(TransactionEvent(address, tx))
for tx in synced_txs
])
if address_manager is None: if address_manager is None:
address_manager = await self.get_address_manager_for_address(address) address_manager = await self.get_address_manager_for_address(address)
@ -470,18 +483,23 @@ class BaseLedger(metaclass=LedgerRegistry):
if address_manager is not None: if address_manager is not None:
await address_manager.ensure_address_gap() await address_manager.ensure_address_gap()
local_status, local_history = await self.get_local_status_and_history(address) local_status, local_history = \
await self.get_local_status_and_history(address, synced_history.getvalue())
if local_status != remote_status: if local_status != remote_status:
log.debug( if local_history == remote_history:
return True
log.warning(
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items", "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
remote_status, len(remote_history), local_status, len(local_history) remote_status, len(remote_history), local_status, len(local_history)
) )
log.debug("local: %s", local_history) log.warning("local: %s", local_history)
log.debug("remote: %s", remote_history) log.warning("remote: %s", remote_history)
self._known_addresses_out_of_sync.add(address)
return False
else: else:
log.debug("Sync completed for: %s", address) return True
async def cache_transaction(self, txid, remote_height): async def cache_transaction(self, txid, remote_height, check_local=True):
cache_item = self._tx_cache.get(txid) cache_item = self._tx_cache.get(txid)
if cache_item is None: if cache_item is None:
cache_item = self._tx_cache[txid] = TransactionCacheItem() cache_item = self._tx_cache[txid] = TransactionCacheItem()
@ -494,28 +512,21 @@ class BaseLedger(metaclass=LedgerRegistry):
tx = cache_item.tx tx = cache_item.tx
if tx is None: if tx is None and check_local:
# check local db # check local db
tx = cache_item.tx = await self.db.get_transaction(txid=txid) tx = cache_item.tx = await self.db.get_transaction(txid=txid)
if tx is None: if tx is None:
# fetch from network # fetch from network
_raw = await self.network.retriable_call(self.network.get_transaction, txid) _raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height)
if _raw: if _raw:
tx = self.transaction_class(unhexlify(_raw)) tx = self.transaction_class(unhexlify(_raw))
await self.maybe_verify_transaction(tx, remote_height)
await self.db.insert_transaction(tx)
cache_item.tx = tx # make sure it's saved before caching it cache_item.tx = tx # make sure it's saved before caching it
return tx
if tx is None: if tx is None:
raise ValueError(f'Transaction {txid} was not in database and not on network.') raise ValueError(f'Transaction {txid} was not in database and not on network.')
if remote_height > 0 and not tx.is_verified:
# tx from cache / db is not up-to-date
await self.maybe_verify_transaction(tx, remote_height) await self.maybe_verify_transaction(tx, remote_height)
await self.db.update_transaction(tx)
return tx return tx
async def maybe_verify_transaction(self, tx, remote_height): async def maybe_verify_transaction(self, tx, remote_height):

View file

@ -4,7 +4,7 @@ from operator import itemgetter
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from time import perf_counter from time import perf_counter
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from torba import __version__ from torba import __version__
from torba.stream import StreamController from torba.stream import StreamController
@ -30,11 +30,11 @@ class ClientSession(BaseClientSession):
self._on_connect_cb = on_connect_callback or (lambda: None) self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event() self.trigger_urgent_reconnect = asyncio.Event()
# one request per second of timeout, conservative default # one request per second of timeout, conservative default
self._semaphore = asyncio.Semaphore(self.timeout) self._semaphore = asyncio.Semaphore(self.timeout * 2)
@property @property
def available(self): def available(self):
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None return not self.is_closing() and self.response_time is not None
@property @property
def server_address_and_port(self) -> Optional[Tuple[str, int]]: def server_address_and_port(self) -> Optional[Tuple[str, int]]:
@ -71,7 +71,10 @@ class ClientSession(BaseClientSession):
) )
log.debug("got reply for %s from %s:%i", method, *self.server) log.debug("got reply for %s from %s:%i", method, *self.server)
return reply return reply
except RPCError as e: except (RPCError, ProtocolError) as e:
if str(e).find('.*no such .*transaction.*'):
# shouldnt the server return none instead?
return None
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args) *self.server, *e.args)
raise e raise e
@ -144,6 +147,7 @@ class BaseNetwork:
self.config = ledger.config self.config = ledger.config
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None self.client: Optional[ClientSession] = None
self._switch_task: Optional[asyncio.Task] = None
self.running = False self.running = False
self.remote_height: int = 0 self.remote_height: int = 0
@ -161,51 +165,41 @@ class BaseNetwork:
'blockchain.address.subscribe': self._on_status_controller, 'blockchain.address.subscribe': self._on_status_controller,
} }
async def switch_to_fastest(self): async def switch_forever(self):
try: while self.running:
client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30) if self.is_connected:
except asyncio.TimeoutError: await self.client.on_disconnected.first
if self.client:
await self.client.close()
self.client = None self.client = None
for session in self.session_pool.sessions: continue
session.synchronous_close() self.client = await self.session_pool.wait_for_fastest_session()
log.warning("not connected to any wallet servers")
return
current_client = self.client
self.client = client
log.info("Switching to SPV wallet server: %s:%d", *self.client.server) log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True) self._on_connected_controller.add(True)
try: try:
self._update_remote_height((await self.subscribe_headers(),)) self._update_remote_height((await self.subscribe_headers(),))
log.info("Subscribed to headers: %s:%d", *self.client.server) log.info("Subscribed to headers: %s:%d", *self.client.server)
except asyncio.TimeoutError: except asyncio.TimeoutError:
if self.client: log.info("Switching to %s:%d timed out, closing and retrying.")
await self.client.close() self.client.synchronous_close()
self.client = current_client self.client = None
return
self.session_pool.new_connection_event.clear()
return await self.session_pool.new_connection_event.wait()
async def start(self): async def start(self):
self.running = True self.running = True
self._switch_task = asyncio.ensure_future(self.switch_forever())
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) self.on_header.listen(self._update_remote_height)
while self.running:
await self.switch_to_fastest()
async def stop(self): async def stop(self):
if self.running:
self.running = False self.running = False
self._switch_task.cancel()
self.session_pool.stop() self.session_pool.stop()
@property @property
def is_connected(self): def is_connected(self):
return self.client and not self.client.is_closing() return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args, session=None): def rpc(self, list_or_method, args, restricted=True):
# fixme: use fastest unloaded session, but for now it causes issues with wallet sync session = self.client if restricted else self.session_pool.fastest_session
# session = session or self.session_pool.fastest_session
session = self.client
if session and not session.is_closing(): if session and not session.is_closing():
return session.send_request(list_or_method, args) return session.send_request(list_or_method, args)
else: else:
@ -229,31 +223,35 @@ class BaseNetwork:
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
def get_transaction(self, tx_hash): def get_transaction(self, tx_hash, known_height=None):
return self.rpc('blockchain.transaction.get', [tx_hash]) # use any server if its old, otherwise restrict to who gave us the history
restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
def get_transaction_height(self, tx_hash): def get_transaction_height(self, tx_hash, known_height=None):
return self.rpc('blockchain.transaction.get_height', [tx_hash]) restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
def get_merkle(self, tx_hash, height): def get_merkle(self, tx_hash, height):
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height]) restricted = 0 > height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
def get_headers(self, height, count=10000): def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', [height, count]) return self.rpc('blockchain.block.headers', [height, count])
# --- Subscribes, history and broadcasts are always aimed towards the master client directly # --- Subscribes, history and broadcasts are always aimed towards the master client directly
def get_history(self, address): def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address], session=self.client) return self.rpc('blockchain.address.get_history', [address], True)
def broadcast(self, raw_transaction): def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client) return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
def subscribe_headers(self): def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], session=self.client) return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address): async def subscribe_address(self, address):
try: try:
return await self.rpc('blockchain.address.subscribe', [address], session=self.client) return await self.rpc('blockchain.address.subscribe', [address], True)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# abort and cancel, we cant lose a subscription, it will happen again on reconnect # abort and cancel, we cant lose a subscription, it will happen again on reconnect
self.client.abort() self.client.abort()
@ -274,11 +272,11 @@ class SessionPool:
@property @property
def available_sessions(self): def available_sessions(self):
return [session for session in self.sessions if session.available] return (session for session in self.sessions if session.available)
@property @property
def fastest_session(self): def fastest_session(self):
if not self.available_sessions: if not self.online:
return None return None
return min( return min(
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session) [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
@ -329,8 +327,9 @@ class SessionPool:
self._connect_session(server) self._connect_session(server)
def stop(self): def stop(self):
for task in self.sessions.values(): for session, task in self.sessions.items():
task.cancel() task.cancel()
session.synchronous_close()
self.sessions.clear() self.sessions.clear()
def ensure_connections(self): def ensure_connections(self):

View file

@ -190,14 +190,15 @@ class WalletNode:
class SPVNode: class SPVNode:
def __init__(self, coin_class): def __init__(self, coin_class, node_number=1):
self.coin_class = coin_class self.coin_class = coin_class
self.controller = None self.controller = None
self.data_path = None self.data_path = None
self.server = None self.server = None
self.hostname = 'localhost' self.hostname = 'localhost'
self.port = 50001 + 1 # avoid conflict with default daemon self.port = 50001 + node_number # avoid conflict with default daemon
self.session_timeout = 600 self.session_timeout = 600
self.rpc_port = '0' # disabled by default
async def start(self, blockchain_node: 'BlockchainNode'): async def start(self, blockchain_node: 'BlockchainNode'):
self.data_path = tempfile.mkdtemp() self.data_path = tempfile.mkdtemp()
@ -210,6 +211,7 @@ class SPVNode:
'SESSION_TIMEOUT': str(self.session_timeout), 'SESSION_TIMEOUT': str(self.session_timeout),
'MAX_QUERY_WORKERS': '0', 'MAX_QUERY_WORKERS': '0',
'INDIVIDUAL_TAG_INDEXES': '', 'INDIVIDUAL_TAG_INDEXES': '',
'RPC_PORT': self.rpc_port
} }
# TODO: don't use os.environ # TODO: don't use os.environ
os.environ.update(conf) os.environ.update(conf)

View file

@ -17,6 +17,7 @@ import time
from asyncio import sleep from asyncio import sleep
from bisect import bisect_right from bisect import bisect_right
from collections import namedtuple from collections import namedtuple
from functools import partial
from glob import glob from glob import glob
from struct import pack, unpack from struct import pack, unpack
@ -72,9 +73,8 @@ class DB:
self.header_len = self.dynamic_header_len self.header_len = self.dynamic_header_len
self.logger.info(f'switching current directory to {env.db_dir}') self.logger.info(f'switching current directory to {env.db_dir}')
os.chdir(env.db_dir)
self.db_class = db_class(self.env.db_engine) self.db_class = db_class(env.db_dir, self.env.db_engine)
self.history = History() self.history = History()
self.utxo_db = None self.utxo_db = None
self.tx_counts = None self.tx_counts = None
@ -86,12 +86,13 @@ class DB:
self.merkle = Merkle() self.merkle = Merkle()
self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes) self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)
self.headers_file = util.LogicalFile('meta/headers', 2, 16000000) path = partial(os.path.join, self.env.db_dir)
self.tx_counts_file = util.LogicalFile('meta/txcounts', 2, 2000000) self.headers_file = util.LogicalFile(path('meta/headers'), 2, 16000000)
self.hashes_file = util.LogicalFile('meta/hashes', 4, 16000000) self.tx_counts_file = util.LogicalFile(path('meta/txcounts'), 2, 2000000)
self.hashes_file = util.LogicalFile(path('meta/hashes'), 4, 16000000)
if not self.coin.STATIC_BLOCK_HEADERS: if not self.coin.STATIC_BLOCK_HEADERS:
self.headers_offsets_file = util.LogicalFile( self.headers_offsets_file = util.LogicalFile(
'meta/headers_offsets', 2, 16000000) path('meta/headers_offsets'), 2, 16000000)
async def _read_tx_counts(self): async def _read_tx_counts(self):
if self.tx_counts is not None: if self.tx_counts is not None:
@ -115,8 +116,9 @@ class DB:
if self.utxo_db.is_new: if self.utxo_db.is_new:
self.logger.info('created new database') self.logger.info('created new database')
self.logger.info('creating metadata directory') self.logger.info('creating metadata directory')
os.mkdir('meta') os.mkdir(os.path.join(self.env.db_dir, 'meta'))
with util.open_file('COIN', create=True) as f: coin_path = os.path.join(self.env.db_dir, 'meta', 'COIN')
with util.open_file(coin_path, create=True) as f:
f.write(f'ElectrumX databases and metadata for ' f.write(f'ElectrumX databases and metadata for '
f'{self.coin.NAME} {self.coin.NET}'.encode()) f'{self.coin.NAME} {self.coin.NET}'.encode())
if not self.coin.STATIC_BLOCK_HEADERS: if not self.coin.STATIC_BLOCK_HEADERS:
@ -474,7 +476,7 @@ class DB:
return 'meta/block' return 'meta/block'
def raw_block_path(self, height): def raw_block_path(self, height):
return f'{self.raw_block_prefix()}{height:d}' return os.path.join(self.env.db_dir, f'{self.raw_block_prefix()}{height:d}')
def read_raw_block(self, height): def read_raw_block(self, height):
"""Returns a raw block read from disk. Raises FileNotFoundError """Returns a raw block read from disk. Raises FileNotFoundError

View file

@ -13,20 +13,21 @@ from functools import partial
from torba.server import util from torba.server import util
def db_class(name): def db_class(db_dir, name):
"""Returns a DB engine class.""" """Returns a DB engine class."""
for db_class in util.subclasses(Storage): for db_class in util.subclasses(Storage):
if db_class.__name__.lower() == name.lower(): if db_class.__name__.lower() == name.lower():
db_class.import_module() db_class.import_module()
return db_class return partial(db_class, db_dir)
raise RuntimeError('unrecognised DB engine "{}"'.format(name)) raise RuntimeError('unrecognised DB engine "{}"'.format(name))
class Storage: class Storage:
"""Abstract base class of the DB backend abstraction.""" """Abstract base class of the DB backend abstraction."""
def __init__(self, name, for_sync): def __init__(self, db_dir, name, for_sync):
self.is_new = not os.path.exists(name) self.db_dir = db_dir
self.is_new = not os.path.exists(os.path.join(db_dir, name))
self.for_sync = for_sync or self.is_new self.for_sync = for_sync or self.is_new
self.open(name, create=self.is_new) self.open(name, create=self.is_new)
@ -78,8 +79,9 @@ class LevelDB(Storage):
def open(self, name, create): def open(self, name, create):
mof = 512 if self.for_sync else 128 mof = 512 if self.for_sync else 128
path = os.path.join(self.db_dir, name)
# Use snappy compression (the default) # Use snappy compression (the default)
self.db = self.module.DB(name, create_if_missing=create, self.db = self.module.DB(path, create_if_missing=create,
max_open_files=mof) max_open_files=mof)
self.close = self.db.close self.close = self.db.close
self.get = self.db.get self.get = self.db.get
@ -99,12 +101,13 @@ class RocksDB(Storage):
def open(self, name, create): def open(self, name, create):
mof = 512 if self.for_sync else 128 mof = 512 if self.for_sync else 128
path = os.path.join(self.db_dir, name)
# Use snappy compression (the default) # Use snappy compression (the default)
options = self.module.Options(create_if_missing=create, options = self.module.Options(create_if_missing=create,
use_fsync=True, use_fsync=True,
target_file_size_base=33554432, target_file_size_base=33554432,
max_open_files=mof) max_open_files=mof)
self.db = self.module.DB(name, options) self.db = self.module.DB(path, options)
self.get = self.db.get self.get = self.db.get
self.put = self.db.put self.put = self.db.put