forked from LBRYCommunity/lbry-sdk

Merge pull request #2418 from lbryio/no_chdir

sync and connection issues

Commit 072f1f112e
13 changed files with 244 additions and 135 deletions

@@ -199,6 +199,8 @@ class LbryWalletManager(BaseWalletManager):
         if not tx:
             try:
                 raw = await self.ledger.network.get_transaction(txid)
+                if not raw:
+                    return {'success': False, 'code': 404, 'message': 'transaction not found'}
                 height = await self.ledger.network.get_transaction_height(txid)
             except CodeMessageError as e:
                 return {'success': False, 'code': e.code, 'message': e.message}

@@ -1,3 +1,4 @@
+import os
 import sqlite3
 from typing import Union, Tuple, Set, List
 from itertools import chain

@@ -705,7 +706,8 @@ class LBRYDB(DB):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.sql = SQLDB(self, 'claims.db')
+        path = os.path.join(self.env.db_dir, 'claims.db')
+        self.sql = SQLDB(self, path)

     def close(self):
         super().close()

@@ -64,9 +64,10 @@ class LBRYSessionManager(SessionManager):

     async def start_other(self):
         self.running = True
+        path = os.path.join(self.env.db_dir, 'claims.db')
         args = dict(
             initializer=reader.initializer,
-            initargs=(self.logger, 'claims.db', self.env.coin.NET, self.env.database_query_timeout,
+            initargs=(self.logger, path, self.env.coin.NET, self.env.database_query_timeout,
                       self.env.track_metrics)
         )
         if self.env.max_query_workers is not None and self.env.max_query_workers == 0:

@@ -20,7 +20,9 @@ class TestSessionBloat(IntegrationTestCase):
         await self.ledger.stop()
         self.conductor.spv_node.session_timeout = 1
         await self.conductor.start_spv()
-        session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
+        session = ClientSession(
+            network=None, server=(self.conductor.spv_node.hostname, self.conductor.spv_node.port), timeout=0.2
+        )
         await session.create_connection()
         await session.send_request('server.banner', ())
         self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)

@@ -3,6 +3,7 @@ import asyncio
 from unittest.mock import Mock

 from torba.client.basenetwork import BaseNetwork
+from torba.orchstr8.node import SPVNode
 from torba.rpc import RPCSession
 from torba.testcase import IntegrationTestCase, AsyncioTestCase

@@ -20,6 +21,26 @@ class ReconnectTests(IntegrationTestCase):

     VERBOSITY = logging.WARN

+    async def test_multiple_servers(self):
+        # we have a secondary node that connects later, so
+        node2 = SPVNode(self.conductor.spv_module, node_number=2)
+        self.ledger.network.config['default_servers'].append((node2.hostname, node2.port))
+        await asyncio.wait_for(self.ledger.stop(), timeout=1)
+        await asyncio.wait_for(self.ledger.start(), timeout=1)
+        self.ledger.network.session_pool.new_connection_event.clear()
+        await node2.start(self.blockchain)
+        # this is only to speed up the test as retrying would take 4+ seconds
+        for session in self.ledger.network.session_pool.sessions:
+            session.trigger_urgent_reconnect.set()
+        await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1)
+        self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions)))
+        self.assertTrue(self.ledger.network.is_connected)
+        switch_event = self.ledger.network.on_connected.first
+        await node2.stop(True)
+        # secondary down, but primary is ok, do not switch! (switches trigger new on_connected events)
+        with self.assertRaises(asyncio.TimeoutError):
+            await asyncio.wait_for(switch_event, timeout=1)
+
     async def test_connection_drop_still_receives_events_after_reconnected(self):
         address1 = await self.account.receiving.get_or_create_usable_address()
         # disconnect and send a new tx, should reconnect and get it

@@ -35,10 +56,11 @@ class ReconnectTests(IntegrationTestCase):
         # is it real? are we rich!? let me see this tx...
         d = self.ledger.network.get_transaction(sendtxid)
         # what's that smoke on my ethernet cable? oh no!
+        master_client = self.ledger.network.client
         self.ledger.network.client.connection_lost(Exception())
         with self.assertRaises(asyncio.TimeoutError):
             await d
-        self.assertIsNone(self.ledger.network.client.response_time)  # response time unknown as it failed
+        self.assertIsNone(master_client.response_time)  # response time unknown as it failed
         # rich but offline? no way, no water, let's retry
         with self.assertRaisesRegex(ConnectionError, 'connection is not available'):
             await self.ledger.network.get_transaction(sendtxid)

@@ -104,4 +126,4 @@ class ServerPickingTestCase(AsyncioTestCase):
         self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
         # ensure we are connected to all of them after a while
         await asyncio.sleep(1)
-        self.assertEqual(len(network.session_pool.available_sessions), 3)
+        self.assertEqual(len(list(network.session_pool.available_sessions)), 3)

@@ -1,6 +1,9 @@
 import logging
 import asyncio
+import random
 from itertools import chain
+from random import shuffle

 from torba.testcase import IntegrationTestCase
 from torba.client.util import satoshis_to_coins, coins_to_satoshis

@@ -129,3 +132,48 @@ class BasicTransactionTests(IntegrationTestCase):
         self.assertEqual(tx.outputs[0].get_address(self.ledger), address2)
         self.assertEqual(tx.outputs[0].is_change, False)
         self.assertEqual(tx.outputs[1].is_change, True)
+
+    async def test_history_edge_cases(self):
+        await self.assertBalance(self.account, '0.0')
+        address = await self.account.receiving.get_or_create_usable_address()
+        # evil trick: mempool is unsorted on real life, but same order between python instances. reproduce it
+        original_summary = self.conductor.spv_node.server.mempool.transaction_summaries
+
+        async def random_summary(*args, **kwargs):
+            summary = await original_summary(*args, **kwargs)
+            if summary and len(summary) > 2:
+                ordered = summary.copy()
+                while summary == ordered:
+                    random.shuffle(summary)
+            return summary
+        self.conductor.spv_node.server.mempool.transaction_summaries = random_summary
+        # 10 unconfirmed txs, all from blockchain wallet
+        sends = list(self.blockchain.send_to_address(address, 10) for _ in range(10))
+        # use batching to reduce issues with send_to_address on cli
+        for batch in range(0, len(sends), 10):
+            txids = await asyncio.gather(*sends[batch:batch + 10])
+            await asyncio.wait([self.on_transaction_id(txid) for txid in txids])
+        remote_status = await self.ledger.network.subscribe_address(address)
+        self.assertTrue(await self.ledger.update_history(address, remote_status))
+        # 20 unconfirmed txs, 10 from blockchain, 10 from local to local
+        utxos = await self.account.get_utxos()
+        txs = []
+        for utxo in utxos:
+            tx = await self.ledger.transaction_class.create(
+                [self.ledger.transaction_class.input_class.spend(utxo)],
+                [],
+                [self.account], self.account
+            )
+            await self.broadcast(tx)
+            txs.append(tx)
+        await asyncio.wait([self.on_transaction_address(tx, address) for tx in txs], timeout=1)
+        remote_status = await self.ledger.network.subscribe_address(address)
+        self.assertTrue(await self.ledger.update_history(address, remote_status))
+        # server history grows unordered
+        txid = await self.blockchain.send_to_address(address, 1)
+        await self.on_transaction_id(txid)
+        self.assertTrue(await self.ledger.update_history(address, remote_status))
+        self.assertEqual(21, len((await self.ledger.get_local_status_and_history(address))[1]))
+        self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync))
+        # should be another test, but it would be too much to setup just for that and it affects sync
+        self.assertIsNone(await self.ledger.network.retriable_call(self.ledger.network.get_transaction, '1'*64))

@@ -29,7 +29,7 @@ class MockNetwork:
     async def get_merkle(self, txid, height):
         return {'merkle': ['abcd01'], 'pos': 1}

-    async def get_transaction(self, tx_hash):
+    async def get_transaction(self, tx_hash, _=None):
         self.get_transaction_called.append(tx_hash)
         return self.transaction[tx_hash]

@@ -220,23 +220,26 @@ class SQLiteMixin:

     async def open(self):
         log.info("connecting to database: %s", self._db_path)
-        self.db = await AIOSQLite.connect(self._db_path)
+        self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
         await self.db.executescript(self.CREATE_TABLES_QUERY)

     async def close(self):
         await self.db.close()

     @staticmethod
-    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
+    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
+                    replace: bool = False) -> Tuple[str, List]:
         columns, values = [], []
         for column, value in data.items():
             columns.append(column)
             values.append(value)
-        or_ignore = ""
+        policy = ""
         if ignore_duplicate:
-            or_ignore = " OR IGNORE"
+            policy = " OR IGNORE"
+        if replace:
+            policy = " OR REPLACE"
         sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
-            or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
+            policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
         )
         return sql, values

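
A note on the two `SQLiteMixin` changes above: `isolation_level=None` opens the sqlite3 connection in autocommit mode, leaving transaction boundaries to the caller rather than the driver's implicit BEGIN, and the new `replace` flag switches the generated statement to SQLite's `INSERT OR REPLACE` conflict policy. A standalone sketch of what `_insert_sql` now produces:

    from typing import Tuple, List

    def insert_sql(table: str, data: dict, ignore_duplicate=False, replace=False) -> Tuple[str, List]:
        # standalone copy of SQLiteMixin._insert_sql from the hunk above, for illustration
        columns, values = list(data.keys()), list(data.values())
        policy = ""
        if ignore_duplicate:
            policy = " OR IGNORE"
        if replace:
            policy = " OR REPLACE"
        sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
            policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
        )
        return sql, values

    print(insert_sql('tx', {'txid': 'beef', 'height': 5}, replace=True))
    # ('INSERT OR REPLACE INTO tx (txid, height) VALUES (?, ?)', ['beef', 5])
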
@@ -348,9 +351,14 @@ class BaseDatabase(SQLiteMixin):
             'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
         }, 'txid = ?', (tx.id,)))

-    def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
-
-        def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
+    def _transaction_io(self, conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
+        conn.execute(*self._insert_sql('tx', {
+            'txid': tx.id,
+            'raw': sqlite3.Binary(tx.raw),
+            'height': tx.height,
+            'position': tx.position,
+            'is_verified': tx.is_verified
+        }, replace=True))

         for txo in tx.outputs:
             if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:

@@ -373,10 +381,17 @@ class BaseDatabase(SQLiteMixin):

         conn.execute(
             "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
-            (history, history.count(':')//2, address)
+            (history, history.count(':') // 2, address)
         )

-        return self.db.run(_transaction, tx, address, txhash, history)
+    def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
+        return self.db.run(self._transaction_io, tx, address, txhash, history)
+
+    def save_transaction_io_batch(self, txs: Iterable[BaseTransaction], address, txhash, history):
+        def __many(conn):
+            for tx in txs:
+                self._transaction_io(conn, tx, address, txhash, history)
+        return self.db.run(__many)

     async def reserve_outputs(self, txos, is_reserved=True):
         txoids = ((is_reserved, txo.id) for txo in txos)

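
`save_transaction_io_batch` exists so the ledger can persist every transaction of an address in a single `db.run` round-trip rather than one call per transaction. A minimal sketch of the batching pattern, assuming a `run` helper that executes a callable against the sqlite3 connection in a worker thread (my stand-in, not torba's actual AIOSQLite API):

    import asyncio
    import sqlite3

    class Database:
        def __init__(self, path):
            self.conn = sqlite3.connect(path, isolation_level=None, check_same_thread=False)

        def run(self, fun, *args):
            # execute fun(conn, *args) off the event loop thread
            return asyncio.get_running_loop().run_in_executor(None, fun, self.conn, *args)

    def write_many(conn, rows):
        # one BEGIN/COMMIT around the whole batch instead of per-row autocommit
        conn.execute('begin')
        try:
            conn.executemany("INSERT OR REPLACE INTO tx (txid, height) VALUES (?, ?)", rows)
            conn.execute('commit')
        except Exception:
            conn.execute('rollback')
            raise

    async def main():
        db = Database(':memory:')
        db.conn.execute("CREATE TABLE tx (txid TEXT PRIMARY KEY, height INTEGER)")
        await db.run(write_many, [('a' * 64, 1), ('b' * 64, 2)])

    asyncio.run(main())
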
@@ -10,6 +10,7 @@ from operator import itemgetter
 from collections import namedtuple

 import pylru
+from torba.client.basetransaction import BaseTransaction
 from torba.tasks import TaskGroup
 from torba.client import baseaccount, basenetwork, basetransaction
 from torba.client.basedatabase import BaseDatabase

@@ -142,6 +143,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         self._address_update_locks: Dict[str, asyncio.Lock] = {}

         self.coin_selection_strategy = None
+        self._known_addresses_out_of_sync = set()

     @classmethod
     def get_id(cls):

@@ -250,7 +252,8 @@ class BaseLedger(metaclass=LedgerRegistry):
         self.constraint_account_or_all(constraints)
         return self.db.get_transaction_count(**constraints)

-    async def get_local_status_and_history(self, address):
-        address_details = await self.db.get_address(address=address)
-        history = address_details['history'] or ''
+    async def get_local_status_and_history(self, address, history=None):
+        if not history:
+            address_details = await self.db.get_address(address=address)
+            history = address_details['history'] or ''
         parts = history.split(':')[:-1]

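
For background on the status/history pair this method returns: in the Electrum protocol, an address status is the sha256 hex digest of the concatenated `txid:height:` entries for that address, which is why the ledger can later recompute a status directly from the `synced_history` string it just built instead of re-reading the database. A small sketch of the computation (the helper name is mine, not torba's):

    import hashlib

    def status_from_history(history: str):
        # history looks like 'txid1:height1:txid2:height2:'
        if not history:
            return None, []
        parts = history.split(':')[:-1]
        pairs = list(zip(parts[0::2], map(int, parts[1::2])))
        status = hashlib.sha256(history.encode()).hexdigest()
        return status, pairs

    status, pairs = status_from_history('aa' * 32 + ':5:' + 'bb' * 32 + ':6:')
    print(len(status), pairs[0][1])  # 64 (a hex digest), 5
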
@@ -284,7 +287,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         await self.join_network()
         self.network.on_connected.listen(self.join_network)

-    async def join_network(self, *args):
+    async def join_network(self, *_):
         log.info("Subscribing and updating accounts.")
         async with self._header_processing_lock:
             await self.update_headers()

@@ -411,24 +414,31 @@ class BaseLedger(metaclass=LedgerRegistry):
                          address_manager: baseaccount.AddressManager = None):

         async with self._address_update_locks.setdefault(address, asyncio.Lock()):
+            self._known_addresses_out_of_sync.discard(address)
+
             local_status, local_history = await self.get_local_status_and_history(address)

             if local_status == remote_status:
-                return
+                return True

             remote_history = await self.network.retriable_call(self.network.get_history, address)
+            remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
+            we_need = set(remote_history) - set(local_history)
+            if not we_need:
+                return True

-            cache_tasks = []
+            cache_tasks: List[asyncio.Future[BaseTransaction]] = []
             synced_history = StringIO()
-            for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
-                if i < len(local_history) and local_history[i] == (txid, remote_height):
+            for i, (txid, remote_height) in enumerate(remote_history):
+                if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
                     synced_history.write(f'{txid}:{remote_height}:')
                 else:
+                    check_local = (txid, remote_height) not in we_need
                     cache_tasks.append(asyncio.ensure_future(
-                        self.cache_transaction(txid, remote_height)
+                        self.cache_transaction(txid, remote_height, check_local=check_local)
                     ))

+            synced_txs = []
             for task in cache_tasks:
                 tx = await task

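
The new `we_need` set is the heart of this sync fix: the remote history is normalized to `(txid, height)` tuples once, and only entries missing from local history are fetched, while everything already known may be served from the local database (`check_local=True`). A worked example of the set arithmetic:

    from operator import itemgetter

    # as returned by blockchain.address.get_history, then normalized
    remote = [{'tx_hash': 'tx1', 'height': 5}, {'tx_hash': 'tx2', 'height': 6}, {'tx_hash': 'tx3', 'height': -1}]
    remote_history = list(map(itemgetter('tx_hash', 'height'), remote))
    local_history = [('tx1', 5), ('tx2', 4)]  # tx2 got confirmed since the last sync

    we_need = set(remote_history) - set(local_history)
    print(sorted(we_need))  # [('tx2', 6), ('tx3', -1)]

    check_local = ('tx1', 5) not in we_need
    print(check_local)  # True: unchanged entry, the local db copy is fine
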
@@ -457,12 +467,15 @@ class BaseLedger(metaclass=LedgerRegistry):
                             txi.txo_ref = referenced_txo.ref

                 synced_history.write(f'{tx.id}:{tx.height}:')
+                synced_txs.append(tx)

-                await self.db.save_transaction_io(
-                    tx, address, self.address_to_hash160(address), synced_history.getvalue()
-                )
-                await self._on_transaction_controller.add(TransactionEvent(address, tx))
+            await self.db.save_transaction_io_batch(
+                synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
+            )
+            await asyncio.wait([
+                self._on_transaction_controller.add(TransactionEvent(address, tx))
+                for tx in synced_txs
+            ])

             if address_manager is None:
                 address_manager = await self.get_address_manager_for_address(address)

@@ -470,18 +483,23 @@ class BaseLedger(metaclass=LedgerRegistry):
             if address_manager is not None:
                 await address_manager.ensure_address_gap()

-            local_status, local_history = await self.get_local_status_and_history(address)
+            local_status, local_history = \
+                await self.get_local_status_and_history(address, synced_history.getvalue())
             if local_status != remote_status:
-                log.debug(
+                if local_history == remote_history:
+                    return True
+                log.warning(
                     "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
                     remote_status, len(remote_history), local_status, len(local_history)
                 )
-                log.debug("local: %s", local_history)
-                log.debug("remote: %s", remote_history)
+                log.warning("local: %s", local_history)
+                log.warning("remote: %s", remote_history)
+                self._known_addresses_out_of_sync.add(address)
+                return False
             else:
-                log.debug("Sync completed for: %s", address)
+                return True

-    async def cache_transaction(self, txid, remote_height):
+    async def cache_transaction(self, txid, remote_height, check_local=True):
         cache_item = self._tx_cache.get(txid)
         if cache_item is None:
             cache_item = self._tx_cache[txid] = TransactionCacheItem()

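
With this hunk `update_history` finally reports an outcome: `True` when local and remote state agree after syncing (or differ only by ordering), `False` when the address gets recorded in `_known_addresses_out_of_sync` for later inspection. The new integration test earlier in this diff asserts on exactly that contract:

    remote_status = await self.ledger.network.subscribe_address(address)
    self.assertTrue(await self.ledger.update_history(address, remote_status))
    self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync))
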
@@ -494,28 +512,21 @@ class BaseLedger(metaclass=LedgerRegistry):

             tx = cache_item.tx

-            if tx is None:
+            if tx is None and check_local:
                 # check local db
                 tx = cache_item.tx = await self.db.get_transaction(txid=txid)

             if tx is None:
                 # fetch from network
-                _raw = await self.network.retriable_call(self.network.get_transaction, txid)
+                _raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height)
                 if _raw:
                     tx = self.transaction_class(unhexlify(_raw))
-                    await self.maybe_verify_transaction(tx, remote_height)
-                    await self.db.insert_transaction(tx)
                     cache_item.tx = tx  # make sure it's saved before caching it
-                    return tx

             if tx is None:
                 raise ValueError(f'Transaction {txid} was not in database and not on network.')

-            if remote_height > 0 and not tx.is_verified:
-                # tx from cache / db is not up-to-date
-                await self.maybe_verify_transaction(tx, remote_height)
-                await self.db.update_transaction(tx)
-
+            await self.maybe_verify_transaction(tx, remote_height)
             return tx

     async def maybe_verify_transaction(self, tx, remote_height):

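
`cache_transaction` now falls through three layers: the in-memory cache, the local database (skipped with `check_local=False` when `we_need` already proved the transaction is absent locally), and finally the network, with verification applied uniformly at the end. A toy model of that fall-through, under those assumptions (not torba's actual classes):

    import asyncio

    class TxCache:
        def __init__(self, local_db, network):
            self.mem, self.local_db, self.network = {}, local_db, network

        async def get(self, txid, check_local=True):
            tx = self.mem.get(txid)              # 1. in-memory cache
            if tx is None and check_local:
                tx = self.local_db.get(txid)     # 2. local db, unless known missing
            if tx is None:
                tx = await self.network(txid)    # 3. last resort: a wallet server
            if tx is None:
                raise ValueError(f'Transaction {txid} was not in database and not on network.')
            self.mem[txid] = tx
            return tx

    async def main():
        calls = []

        async def network(txid):
            calls.append(txid)
            return f'raw:{txid}'

        cache = TxCache({'aa': 'raw:aa'}, network)
        await cache.get('aa')                     # served locally, no network call
        await cache.get('bb', check_local=False)  # known missing: straight to network
        print(calls)  # ['bb']

    asyncio.run(main())
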
@@ -4,7 +4,7 @@ from operator import itemgetter
 from typing import Dict, Optional, Tuple
 from time import perf_counter

-from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
+from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError

 from torba import __version__
 from torba.stream import StreamController

@@ -30,11 +30,11 @@ class ClientSession(BaseClientSession):
         self._on_connect_cb = on_connect_callback or (lambda: None)
         self.trigger_urgent_reconnect = asyncio.Event()
         # one request per second of timeout, conservative default
-        self._semaphore = asyncio.Semaphore(self.timeout)
+        self._semaphore = asyncio.Semaphore(self.timeout * 2)

     @property
     def available(self):
-        return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
+        return not self.is_closing() and self.response_time is not None

     @property
     def server_address_and_port(self) -> Optional[Tuple[str, int]]:

@@ -71,7 +71,10 @@ class ClientSession(BaseClientSession):
             )
             log.debug("got reply for %s from %s:%i", method, *self.server)
             return reply
-        except RPCError as e:
+        except (RPCError, ProtocolError) as e:
+            if str(e).find('.*no such .*transaction.*'):
+                # shouldnt the server return none instead?
+                return None
             log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
                         *self.server, *e.args)
             raise e

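
A review note on the error-swallowing branch just added: `str.find` matches a literal substring, not a regex, and returns -1 (truthy) when absent and 0 (falsy) only when the match starts at position zero. As written, `str(e).find('.*no such .*transaction.*')` is therefore truthy for virtually every error, so almost any `RPCError` or `ProtocolError` makes `send_request` return `None`. If the intent was "treat missing-transaction errors as None", a regex search would express it (my reading of the intent, not something this commit does):

    import re

    NO_SUCH_TX = re.compile(r'no such .*transaction', re.IGNORECASE)

    def is_missing_tx(error_message: str) -> bool:
        # what the str.find() call above presumably meant to check
        return NO_SUCH_TX.search(error_message) is not None

    print(is_missing_tx('daemon error: No such mempool or blockchain transaction.'))  # True
    print(is_missing_tx('request timed out'))  # False
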
@@ -144,6 +147,7 @@ class BaseNetwork:
         self.config = ledger.config
         self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
         self.client: Optional[ClientSession] = None
+        self._switch_task: Optional[asyncio.Task] = None
         self.running = False
         self.remote_height: int = 0

@@ -161,51 +165,41 @@ class BaseNetwork:
             'blockchain.address.subscribe': self._on_status_controller,
         }

-    async def switch_to_fastest(self):
-        try:
-            client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30)
-        except asyncio.TimeoutError:
-            if self.client:
-                await self.client.close()
-            self.client = None
-            for session in self.session_pool.sessions:
-                session.synchronous_close()
-            log.warning("not connected to any wallet servers")
-            return
-        current_client = self.client
-        self.client = client
-        log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
-        self._on_connected_controller.add(True)
-        try:
-            self._update_remote_height((await self.subscribe_headers(),))
-            log.info("Subscribed to headers: %s:%d", *self.client.server)
-        except asyncio.TimeoutError:
-            if self.client:
-                await self.client.close()
-            self.client = current_client
-            return
-        self.session_pool.new_connection_event.clear()
-        return await self.session_pool.new_connection_event.wait()
+    async def switch_forever(self):
+        while self.running:
+            if self.is_connected:
+                await self.client.on_disconnected.first
+                self.client = None
+                continue
+            self.client = await self.session_pool.wait_for_fastest_session()
+            log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
+            self._on_connected_controller.add(True)
+            try:
+                self._update_remote_height((await self.subscribe_headers(),))
+                log.info("Subscribed to headers: %s:%d", *self.client.server)
+            except asyncio.TimeoutError:
+                log.info("Switching to %s:%d timed out, closing and retrying.")
+                self.client.synchronous_close()
+                self.client = None

     async def start(self):
         self.running = True
+        self._switch_task = asyncio.ensure_future(self.switch_forever())
         self.session_pool.start(self.config['default_servers'])
         self.on_header.listen(self._update_remote_height)
-        while self.running:
-            await self.switch_to_fastest()

     async def stop(self):
-        self.running = False
-        self.session_pool.stop()
+        if self.running:
+            self.running = False
+            self._switch_task.cancel()
+            self.session_pool.stop()

     @property
     def is_connected(self):
         return self.client and not self.client.is_closing()

-    def rpc(self, list_or_method, args, session=None):
-        # fixme: use fastest unloaded session, but for now it causes issues with wallet sync
-        # session = session or self.session_pool.fastest_session
-        session = self.client
+    def rpc(self, list_or_method, args, restricted=True):
+        session = self.client if restricted else self.session_pool.fastest_session
         if session and not session.is_closing():
             return session.send_request(list_or_method, args)
         else:

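
The structural change in this hunk: `start()` no longer drives the server switching loop inline; `switch_forever` becomes a long-lived background task created on `start()` and cancelled on `stop()`. While connected it parks on the client's disconnect event; otherwise it adopts the fastest session from the pool. A self-contained sketch of that task lifecycle (the waiting/picking helpers are placeholders, not torba APIs):

    import asyncio

    class Network:
        def __init__(self):
            self.running = False
            self.client = None
            self._switch_task = None

        async def _wait_disconnected(self):   # placeholder for client.on_disconnected.first
            await asyncio.sleep(0.01)

        async def _pick_fastest(self):        # placeholder for session_pool.wait_for_fastest_session()
            await asyncio.sleep(0.01)
            return object()

        async def switch_forever(self):
            while self.running:
                if self.client is not None:
                    await self._wait_disconnected()  # park until the current server drops
                    self.client = None
                    continue
                self.client = await self._pick_fastest()

        async def start(self):
            self.running = True
            self._switch_task = asyncio.ensure_future(self.switch_forever())

        async def stop(self):
            if self.running:
                self.running = False
                self._switch_task.cancel()

    async def main():
        net = Network()
        await net.start()
        await asyncio.sleep(0.05)
        await net.stop()

    asyncio.run(main())
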
@@ -229,31 +223,35 @@ class BaseNetwork:
     def _update_remote_height(self, header_args):
         self.remote_height = header_args[0]["height"]

-    def get_transaction(self, tx_hash):
-        return self.rpc('blockchain.transaction.get', [tx_hash])
+    def get_transaction(self, tx_hash, known_height=None):
+        # use any server if its old, otherwise restrict to who gave us the history
+        restricted = not known_height or 0 > known_height > self.remote_height - 10
+        return self.rpc('blockchain.transaction.get', [tx_hash], restricted)

-    def get_transaction_height(self, tx_hash):
-        return self.rpc('blockchain.transaction.get_height', [tx_hash])
+    def get_transaction_height(self, tx_hash, known_height=None):
+        restricted = not known_height or 0 > known_height > self.remote_height - 10
+        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)

     def get_merkle(self, tx_hash, height):
-        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
+        restricted = 0 > height > self.remote_height - 10
+        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)

     def get_headers(self, height, count=10000):
         return self.rpc('blockchain.block.headers', [height, count])

     # --- Subscribes, history and broadcasts are always aimed towards the master client directly
     def get_history(self, address):
-        return self.rpc('blockchain.address.get_history', [address], session=self.client)
+        return self.rpc('blockchain.address.get_history', [address], True)

     def broadcast(self, raw_transaction):
-        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
+        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)

     def subscribe_headers(self):
-        return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
+        return self.rpc('blockchain.headers.subscribe', [True], True)

     async def subscribe_address(self, address):
         try:
-            return await self.rpc('blockchain.address.subscribe', [address], session=self.client)
+            return await self.rpc('blockchain.address.subscribe', [address], True)
         except asyncio.TimeoutError:
             # abort and cancel, we cant lose a subscription, it will happen again on reconnect
             self.client.abort()

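
One subtlety in the new `restricted` predicates: `0 > known_height > self.remote_height - 10` is a chained comparison, i.e. `known_height < 0 and known_height > remote_height - 10`, which can only hold while the chain tip is under about ten blocks. In practice a request is therefore restricted to the master client mainly when `known_height` is `None` or `0`. If the intended meaning was "unconfirmed, or too close to the tip to be on every server", the predicate would read differently; a comparison of the two (my reading, not a change this commit makes):

    remote_height = 500

    def restricted_as_written(known_height):
        # chained: (0 > known_height) and (known_height > remote_height - 10)
        return not known_height or 0 > known_height > remote_height - 10

    def restricted_as_probably_intended(known_height):
        return not known_height or known_height < 0 or known_height > remote_height - 10

    for h in (None, 0, -1, 495, 100):
        print(h, restricted_as_written(h), restricted_as_probably_intended(h))
    # None and 0 are restricted either way; -1 and 495 differ between the two readings
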
@@ -274,11 +272,11 @@ class SessionPool:

     @property
     def available_sessions(self):
-        return [session for session in self.sessions if session.available]
+        return (session for session in self.sessions if session.available)

     @property
     def fastest_session(self):
-        if not self.available_sessions:
+        if not self.online:
             return None
         return min(
             [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)

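
`available_sessions` turning into a generator expression is why two assertions elsewhere in this diff gained `list(...)` wrappers, and why `fastest_session` now checks `self.online` instead: a generator has no `len()` and is always truthy, so `if not self.available_sessions` could never fire again. A quick demonstration:

    sessions = []
    gen = (s for s in sessions)  # like the new available_sessions property
    print(bool(gen), bool([s for s in sessions]))  # True False
    print(len(list(gen)))  # 0 -- callers must materialize before len()
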
@@ -329,8 +327,9 @@ class SessionPool:
             self._connect_session(server)

     def stop(self):
-        for task in self.sessions.values():
+        for session, task in self.sessions.items():
             task.cancel()
+            session.synchronous_close()
         self.sessions.clear()

     def ensure_connections(self):

@@ -190,14 +190,15 @@ class WalletNode:

 class SPVNode:

-    def __init__(self, coin_class):
+    def __init__(self, coin_class, node_number=1):
         self.coin_class = coin_class
         self.controller = None
         self.data_path = None
         self.server = None
         self.hostname = 'localhost'
-        self.port = 50001 + 1  # avoid conflict with default daemon
+        self.port = 50001 + node_number  # avoid conflict with default daemon
         self.session_timeout = 600
+        self.rpc_port = '0'  # disabled by default

     async def start(self, blockchain_node: 'BlockchainNode'):
         self.data_path = tempfile.mkdtemp()

@@ -210,6 +211,7 @@ class SPVNode:
             'SESSION_TIMEOUT': str(self.session_timeout),
             'MAX_QUERY_WORKERS': '0',
             'INDIVIDUAL_TAG_INDEXES': '',
+            'RPC_PORT': self.rpc_port
         }
         # TODO: don't use os.environ
         os.environ.update(conf)

@@ -17,6 +17,7 @@ import time
 from asyncio import sleep
 from bisect import bisect_right
 from collections import namedtuple
+from functools import partial
 from glob import glob
 from struct import pack, unpack

@@ -72,9 +73,8 @@ class DB:
         self.header_len = self.dynamic_header_len

         self.logger.info(f'switching current directory to {env.db_dir}')
-        os.chdir(env.db_dir)

-        self.db_class = db_class(self.env.db_engine)
+        self.db_class = db_class(env.db_dir, self.env.db_engine)
         self.history = History()
         self.utxo_db = None
         self.tx_counts = None

@@ -86,12 +86,13 @@ class DB:
         self.merkle = Merkle()
         self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)

-        self.headers_file = util.LogicalFile('meta/headers', 2, 16000000)
-        self.tx_counts_file = util.LogicalFile('meta/txcounts', 2, 2000000)
-        self.hashes_file = util.LogicalFile('meta/hashes', 4, 16000000)
+        path = partial(os.path.join, self.env.db_dir)
+        self.headers_file = util.LogicalFile(path('meta/headers'), 2, 16000000)
+        self.tx_counts_file = util.LogicalFile(path('meta/txcounts'), 2, 2000000)
+        self.hashes_file = util.LogicalFile(path('meta/hashes'), 4, 16000000)
         if not self.coin.STATIC_BLOCK_HEADERS:
             self.headers_offsets_file = util.LogicalFile(
-                'meta/headers_offsets', 2, 16000000)
+                path('meta/headers_offsets'), 2, 16000000)

     async def _read_tx_counts(self):
         if self.tx_counts is not None:

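
This hunk and the storage changes below carry the branch's `no_chdir` theme: instead of `os.chdir(env.db_dir)` and process-wide relative paths, every file name is joined against the database directory explicitly, with `functools.partial` providing a one-argument path helper. The idiom in isolation (the directory is an example value):

    import os
    from functools import partial

    db_dir = '/var/lib/wallet-server'
    path = partial(os.path.join, db_dir)

    print(path('meta/headers'))   # /var/lib/wallet-server/meta/headers
    print(path('meta/txcounts'))  # /var/lib/wallet-server/meta/txcounts
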
@@ -115,8 +116,9 @@ class DB:
         if self.utxo_db.is_new:
             self.logger.info('created new database')
             self.logger.info('creating metadata directory')
-            os.mkdir('meta')
-            with util.open_file('COIN', create=True) as f:
+            os.mkdir(os.path.join(self.env.db_dir, 'meta'))
+            coin_path = os.path.join(self.env.db_dir, 'meta', 'COIN')
+            with util.open_file(coin_path, create=True) as f:
                 f.write(f'ElectrumX databases and metadata for '
                         f'{self.coin.NAME} {self.coin.NET}'.encode())
         if not self.coin.STATIC_BLOCK_HEADERS:

@@ -474,7 +476,7 @@ class DB:
         return 'meta/block'

     def raw_block_path(self, height):
-        return f'{self.raw_block_prefix()}{height:d}'
+        return os.path.join(self.env.db_dir, f'{self.raw_block_prefix()}{height:d}')

     def read_raw_block(self, height):
         """Returns a raw block read from disk. Raises FileNotFoundError

@@ -13,20 +13,21 @@ from functools import partial
 from torba.server import util


-def db_class(name):
+def db_class(db_dir, name):
     """Returns a DB engine class."""
     for db_class in util.subclasses(Storage):
         if db_class.__name__.lower() == name.lower():
             db_class.import_module()
-            return db_class
+            return partial(db_class, db_dir)
     raise RuntimeError('unrecognised DB engine "{}"'.format(name))


 class Storage:
     """Abstract base class of the DB backend abstraction."""

-    def __init__(self, name, for_sync):
-        self.is_new = not os.path.exists(name)
+    def __init__(self, db_dir, name, for_sync):
+        self.db_dir = db_dir
+        self.is_new = not os.path.exists(os.path.join(db_dir, name))
         self.for_sync = for_sync or self.is_new
         self.open(name, create=self.is_new)

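
Returning `partial(db_class, db_dir)` instead of the class keeps every call site source-compatible: `DB.__init__` above does `self.db_class = db_class(env.db_dir, self.env.db_engine)` and later code can still call `self.db_class(name, for_sync)` with two arguments, while the backend's `__init__` quietly gains `db_dir` in front. A compact sketch of the mechanism with toy classes:

    import os
    from functools import partial

    class Storage:
        def __init__(self, db_dir, name, for_sync):
            self.db_dir = db_dir
            self.is_new = not os.path.exists(os.path.join(db_dir, name))
            self.for_sync = for_sync or self.is_new

    class LevelDB(Storage):
        pass

    def db_class(db_dir, name):
        for cls in (LevelDB,):  # stand-in for util.subclasses(Storage)
            if cls.__name__.lower() == name.lower():
                return partial(cls, db_dir)
        raise RuntimeError('unrecognised DB engine "{}"'.format(name))

    factory = db_class('/tmp', 'leveldb')
    db = factory('utxo', for_sync=False)  # caller never mentions db_dir again
    print(type(db).__name__, db.db_dir)   # LevelDB /tmp
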
@@ -78,8 +79,9 @@ class LevelDB(Storage):

     def open(self, name, create):
         mof = 512 if self.for_sync else 128
+        path = os.path.join(self.db_dir, name)
         # Use snappy compression (the default)
-        self.db = self.module.DB(name, create_if_missing=create,
+        self.db = self.module.DB(path, create_if_missing=create,
                                  max_open_files=mof)
         self.close = self.db.close
         self.get = self.db.get

@@ -99,12 +101,13 @@ class RocksDB(Storage):

     def open(self, name, create):
         mof = 512 if self.for_sync else 128
+        path = os.path.join(self.db_dir, name)
         # Use snappy compression (the default)
         options = self.module.Options(create_if_missing=create,
                                       use_fsync=True,
                                       target_file_size_base=33554432,
                                       max_open_files=mof)
-        self.db = self.module.DB(name, options)
+        self.db = self.module.DB(path, options)
         self.get = self.db.get
         self.put = self.db.put