restore multi server, improve sync concurrency

Victor Shyba 2019-08-30 19:53:51 -03:00
parent 10b7ccaa92
commit d2cd0ece5f
4 changed files with 93 additions and 69 deletions

File 1 of 4: integration tests (ReconnectTests, ServerPickingTestCase)

@@ -33,7 +33,7 @@ class ReconnectTests(IntegrationTestCase):
         for session in self.ledger.network.session_pool.sessions:
             session.trigger_urgent_reconnect.set()
         await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1)
-        self.assertEqual(2, len(self.ledger.network.session_pool.available_sessions))
+        self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions)))
         self.assertTrue(self.ledger.network.is_connected)
         switch_event = self.ledger.network.on_connected.first
         await node2.stop(True)

@@ -126,4 +126,4 @@ class ServerPickingTestCase(AsyncioTestCase):
         self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
         # ensure we are connected to all of them after a while
         await asyncio.sleep(1)
-        self.assertEqual(len(network.session_pool.available_sessions), 3)
+        self.assertEqual(len(list(network.session_pool.available_sessions)), 3)
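Both assertions change for the same reason: `available_sessions` becomes a generator in this commit (see the SessionPool hunk at the end), and `len()` is not defined for generators, so the tests must materialize the sessions first. A minimal illustration:

    gen = (n for n in range(3))
    # len(gen)                # TypeError: object of type 'generator' has no len()
    assert len(list(gen)) == 3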

File 2 of 4: database layer (SQLiteMixin, BaseDatabase)

@@ -227,16 +227,19 @@ class SQLiteMixin:
         await self.db.close()

     @staticmethod
-    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
+    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
+                    replace: bool = False) -> Tuple[str, List]:
         columns, values = [], []
         for column, value in data.items():
             columns.append(column)
             values.append(value)
-        or_ignore = ""
+        policy = ""
         if ignore_duplicate:
-            or_ignore = " OR IGNORE"
+            policy = " OR IGNORE"
+        if replace:
+            policy = " OR REPLACE"
         sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
-            or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
+            policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
         )
         return sql, values
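`_insert_sql` now supports three conflict policies (plain, OR IGNORE, OR REPLACE), with `replace` winning when both flags are set. A quick sketch of the SQL it emits, using illustrative column values:

    sql, values = SQLiteMixin._insert_sql('tx', {'txid': 'beef', 'height': 5}, replace=True)
    assert sql == "INSERT OR REPLACE INTO tx (txid, height) VALUES (?, ?)"
    assert values == ['beef', 5]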
@@ -348,35 +351,47 @@ class BaseDatabase(SQLiteMixin):
             'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
         }, 'txid = ?', (tx.id,)))

+    def _transaction_io(self, conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
+        conn.execute(*self._insert_sql('tx', {
+            'txid': tx.id,
+            'raw': sqlite3.Binary(tx.raw),
+            'height': tx.height,
+            'position': tx.position,
+            'is_verified': tx.is_verified
+        }, replace=True))
+        for txo in tx.outputs:
+            if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
+                conn.execute(*self._insert_sql(
+                    "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
+                ))
+            elif txo.script.is_pay_script_hash:
+                # TODO: implement script hash payments
+                log.warning('Database.save_transaction_io: pay script hash is not implemented!')
+        for txi in tx.inputs:
+            if txi.txo_ref.txo is not None:
+                txo = txi.txo_ref.txo
+                if txo.get_address(self.ledger) == address:
+                    conn.execute(*self._insert_sql("txi", {
+                        'txid': tx.id,
+                        'txoid': txo.id,
+                        'address': address,
+                    }, ignore_duplicate=True))
+        conn.execute(
+            "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
+            (history, history.count(':') // 2, address)
+        )

     def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
+        return self.db.run(self._transaction_io, tx, address, txhash, history)

-        def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
-            for txo in tx.outputs:
-                if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
-                    conn.execute(*self._insert_sql(
-                        "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
-                    ))
-                elif txo.script.is_pay_script_hash:
-                    # TODO: implement script hash payments
-                    log.warning('Database.save_transaction_io: pay script hash is not implemented!')
-            for txi in tx.inputs:
-                if txi.txo_ref.txo is not None:
-                    txo = txi.txo_ref.txo
-                    if txo.get_address(self.ledger) == address:
-                        conn.execute(*self._insert_sql("txi", {
-                            'txid': tx.id,
-                            'txoid': txo.id,
-                            'address': address,
-                        }, ignore_duplicate=True))
-            conn.execute(
-                "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
-                (history, history.count(':')//2, address)
-            )
-        return self.db.run(_transaction, tx, address, txhash, history)
+    def save_transaction_io_batch(self, txs: Iterable[BaseTransaction], address, txhash, history):
+        def __many(conn):
+            for tx in txs:
+                self._transaction_io(conn, tx, address, txhash, history)
+        return self.db.run(__many)

     async def reserve_outputs(self, txos, is_reserved=True):
         txoids = ((is_reserved, txo.id) for txo in txos)
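The extracted `_transaction_io` helper lets a single `db.run` (one hop to the executor thread, one SQLite transaction) persist an entire address history instead of one round trip per transaction, which is the concurrency win the commit message refers to. A hedged usage sketch, with `db`, `ledger`, `txs`, `address` and `history` assumed from the surrounding sync code:

    # Sketch: persist a batch of freshly synced transactions in one run.
    await db.save_transaction_io_batch(
        txs, address, ledger.address_to_hash160(address), history
    )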

File 3 of 4: ledger sync (BaseLedger)

@@ -10,6 +10,7 @@ from operator import itemgetter
 from collections import namedtuple
 import pylru
+from torba.client.basetransaction import BaseTransaction
 from torba.tasks import TaskGroup
 from torba.client import baseaccount, basenetwork, basetransaction
 from torba.client.basedatabase import BaseDatabase
@@ -251,9 +252,10 @@ class BaseLedger(metaclass=LedgerRegistry):
         self.constraint_account_or_all(constraints)
         return self.db.get_transaction_count(**constraints)

-    async def get_local_status_and_history(self, address):
-        address_details = await self.db.get_address(address=address)
-        history = address_details['history'] or ''
+    async def get_local_status_and_history(self, address, history=None):
+        if not history:
+            address_details = await self.db.get_address(address=address)
+            history = address_details['history'] or ''
         parts = history.split(':')[:-1]
         return (
             hexlify(sha256(history.encode())).decode() if history else None,
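The optional `history` argument lets the sync path verify its status against the history string it just wrote, skipping a redundant database read. A sketch of the new call site (it appears in a hunk further down; `ledger`, `address` and `synced_history` are assumed from context):

    local_status, local_history = await ledger.get_local_status_and_history(
        address, synced_history.getvalue()
    )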
@@ -420,17 +422,23 @@ class BaseLedger(metaclass=LedgerRegistry):
             return True

         remote_history = await self.network.retriable_call(self.network.get_history, address)
+        remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
+        we_need = set(remote_history) - set(local_history)
+        if not we_need:
+            return True

-        cache_tasks = []
+        cache_tasks: List[asyncio.Future[BaseTransaction]] = []
         synced_history = StringIO()
-        for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
+        for i, (txid, remote_height) in enumerate(remote_history):
             if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
                 synced_history.write(f'{txid}:{remote_height}:')
             else:
+                check_local = (txid, remote_height) not in we_need
                 cache_tasks.append(asyncio.ensure_future(
-                    self.cache_transaction(txid, remote_height)
+                    self.cache_transaction(txid, remote_height, check_local=check_local)
                 ))

+        synced_txs = []
         for task in cache_tasks:
             tx = await task
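The `we_need` set is both an early exit and a routing signal: if the remote history adds nothing, sync returns immediately; otherwise, any (txid, height) pair in `we_need` is known to be absent locally, so `cache_transaction` is told to skip the local database lookup (`check_local=False`) and go straight to the network. A small illustration with made-up values:

    local_history = [('aa', 100), ('bb', 101)]                  # illustrative
    remote_history = [('aa', 100), ('bb', 101), ('cc', 102)]
    we_need = set(remote_history) - set(local_history)
    assert we_need == {('cc', 102)}          # only 'cc' must be fetched
    assert ('aa', 100) not in we_need        # 'aa' may still check the local db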
@@ -459,11 +467,13 @@ class BaseLedger(metaclass=LedgerRegistry):
                     txi.txo_ref = referenced_txo.ref

             synced_history.write(f'{tx.id}:{tx.height}:')
+            synced_txs.append(tx)

-            await self.db.save_transaction_io(
-                tx, address, self.address_to_hash160(address), synced_history.getvalue()
-            )
+        await self.db.save_transaction_io_batch(
+            synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
+        )
+        for tx in synced_txs:
             await self._on_transaction_controller.add(TransactionEvent(address, tx))

         if address_manager is None:
@@ -472,9 +482,10 @@ class BaseLedger(metaclass=LedgerRegistry):
         if address_manager is not None:
             await address_manager.ensure_address_gap()

-        local_status, local_history = await self.get_local_status_and_history(address)
+        local_status, local_history = \
+            await self.get_local_status_and_history(address, synced_history.getvalue())

         if local_status != remote_status:
-            if local_history == list(map(itemgetter('tx_hash', 'height'), remote_history)):
+            if local_history == remote_history:
                 return True
             log.warning(
                 "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
@@ -487,7 +498,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         else:
             return True

-    async def cache_transaction(self, txid, remote_height):
+    async def cache_transaction(self, txid, remote_height, check_local=True):
         cache_item = self._tx_cache.get(txid)
         if cache_item is None:
             cache_item = self._tx_cache[txid] = TransactionCacheItem()

@@ -500,7 +511,7 @@ class BaseLedger(metaclass=LedgerRegistry):
             tx = cache_item.tx
-            if tx is None:
+            if tx is None and check_local:
                 # check local db
                 tx = cache_item.tx = await self.db.get_transaction(txid=txid)
@@ -509,19 +520,12 @@ class BaseLedger(metaclass=LedgerRegistry):
                 _raw = await self.network.retriable_call(self.network.get_transaction, txid)
                 if _raw:
                     tx = self.transaction_class(unhexlify(_raw))
-                    await self.maybe_verify_transaction(tx, remote_height)
-                    await self.db.insert_transaction(tx)
                     cache_item.tx = tx  # make sure it's saved before caching it
-                    return tx

             if tx is None:
                 raise ValueError(f'Transaction {txid} was not in database and not on network.')

-            if remote_height > 0 and not tx.is_verified:
-                # tx from cache / db is not up-to-date
-                await self.maybe_verify_transaction(tx, remote_height)
-                await self.db.update_transaction(tx)
+            await self.maybe_verify_transaction(tx, remote_height)

             return tx

     async def maybe_verify_transaction(self, tx, remote_height):
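Net effect on `cache_transaction`: persistence moves out of this method (the ledger now saves everything through the batch write above), and verification runs once at the end instead of being duplicated on the network and cache/db paths. With `check_local=False`, a txid that sync already knows is missing locally skips the database probe entirely:

    # Sketch: sync-path call for a txid known to be absent from the local db.
    tx = await ledger.cache_transaction(txid, remote_height, check_local=False)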

File 4 of 4: network layer (ClientSession, BaseNetwork, SessionPool)

@@ -30,11 +30,11 @@ class ClientSession(BaseClientSession):
         self._on_connect_cb = on_connect_callback or (lambda: None)
         self.trigger_urgent_reconnect = asyncio.Event()
         # one request per second of timeout, conservative default
-        self._semaphore = asyncio.Semaphore(self.timeout)
+        self._semaphore = asyncio.Semaphore(self.timeout * 2)

     @property
     def available(self):
-        return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
+        return not self.is_closing() and self.response_time is not None

     @property
     def server_address_and_port(self) -> Optional[Tuple[str, int]]:
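Two relaxations here: the per-session request cap doubles (two concurrent requests per second of timeout rather than one, despite the older comment above), and `available` no longer requires `_can_send`, so a session that is momentarily mid-write still counts as a routing candidate. A rough sketch of the cap under a hypothetical 30-second timeout:

    import asyncio

    timeout = 30                                 # hypothetical value, not from the diff
    semaphore = asyncio.Semaphore(timeout * 2)   # 60 requests in flight, up from 30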
@@ -195,10 +195,8 @@ class BaseNetwork:
     def is_connected(self):
         return self.client and not self.client.is_closing()

-    def rpc(self, list_or_method, args, session=None):
-        # fixme: use fastest unloaded session, but for now it causes issues with wallet sync
-        # session = session or self.session_pool.fastest_session
-        session = self.client
+    def rpc(self, list_or_method, args, restricted=False):
+        session = self.client if restricted else self.session_pool.fastest_session
         if session and not session.is_closing():
             return session.send_request(list_or_method, args)
         else:
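This is the "restore multi server" half of the commit: the old fixme pinned every call to the master client, whereas now only consistency-sensitive calls do, and everything else load-balances across the fastest pooled session. Illustrative call sites (`network`, `tx_hash` and `address` assumed):

    network.rpc('blockchain.transaction.get', [tx_hash])             # any pooled server
    network.rpc('blockchain.address.get_history', [address], True)   # master only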
@@ -225,28 +223,35 @@ class BaseNetwork:
     def get_transaction(self, tx_hash):
         return self.rpc('blockchain.transaction.get', [tx_hash])

-    def get_transaction_height(self, tx_hash):
-        return self.rpc('blockchain.transaction.get_height', [tx_hash])
+    def get_transaction_height(self, tx_hash, known_height=None):
+        restricted = True  # by default, check master for consistency
+        if known_height:
+            if 0 < known_height < self.remote_height - 10:
+                restricted = False  # we can get it from any server, it's old
+        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)

     def get_merkle(self, tx_hash, height):
-        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
+        restricted = True  # by default, check master for consistency
+        if 0 < height < self.remote_height - 10:
+            restricted = False  # we can get it from any server, it's old
+        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)

     def get_headers(self, height, count=10000):
         return self.rpc('blockchain.block.headers', [height, count])

     # --- Subscribes, history and broadcasts are always aimed towards the master client directly

     def get_history(self, address):
-        return self.rpc('blockchain.address.get_history', [address], session=self.client)
+        return self.rpc('blockchain.address.get_history', [address], True)

     def broadcast(self, raw_transaction):
-        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
+        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)

     def subscribe_headers(self):
-        return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
+        return self.rpc('blockchain.headers.subscribe', [True], True)

     async def subscribe_address(self, address):
         try:
-            return await self.rpc('blockchain.address.subscribe', [address], session=self.client)
+            return await self.rpc('blockchain.address.subscribe', [address], True)
         except asyncio.TimeoutError:
             # abort and cancel, we can't lose a subscription, it will happen again on reconnect
             self.client.abort()
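The hard-coded `- 10` is a reorg-safety margin: a height more than ten blocks below the remote tip is treated as settled, so any pooled server can answer without risking inconsistency near the chain tip. Expressed as a standalone predicate (the name and constant are ours, only the `- 10` comes from the diff):

    REORG_MARGIN = 10  # mirrors the hard-coded `- 10` above

    def is_settled(height: int, remote_height: int) -> bool:
        # heights at or above remote_height - REORG_MARGIN stay pinned to master
        return 0 < height < remote_height - REORG_MARGIN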
@@ -267,11 +272,11 @@ class SessionPool:

     @property
     def available_sessions(self):
-        return [session for session in self.sessions if session.available]
+        return (session for session in self.sessions if session.available)

     @property
     def fastest_session(self):
-        if not self.available_sessions:
+        if not self.online:
             return None
         return min(
             [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
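Turning `available_sessions` into a generator avoids building a list on every access, but it changes truthiness: a generator object is always truthy, even when it yields nothing. That is why the emptiness check switches to `self.online` and why the tests at the top now wrap the property in `list(...)` before calling `len()`:

    empty = (s for s in [] if s)
    assert bool(empty) is True     # `if not self.available_sessions` would never fire
    assert len(list(empty)) == 0   # materialize to count, as the tests now do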