diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 9e0e3d658..51507a526 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -1492,7 +1492,7 @@ class Daemon(metaclass=JSONRPCServerType): wallet = self.wallet_manager.get_wallet_or_default(wallet_id) account = wallet.get_account_or_default(account_id) balance = await account.get_detailed_balance( - confirmations=confirmations, reserved_subtotals=True + confirmations=confirmations, reserved_subtotals=True, read_only=True ) return dict_values_to_lbc(balance) @@ -1817,7 +1817,7 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) account = wallet.get_account_or_default(account_id) - match = await self.ledger.db.get_address(address=address, accounts=[account]) + match = await self.ledger.db.get_address(read_only=True, address=address, accounts=[account]) if match is not None: return True return False @@ -1853,7 +1853,7 @@ class Daemon(metaclass=JSONRPCServerType): return paginate_rows( self.ledger.get_addresses, self.ledger.get_address_count, - page, page_size, **constraints + page, page_size, read_only=True, **constraints ) @requires(WALLET_COMPONENT) @@ -4089,7 +4089,7 @@ class Daemon(metaclass=JSONRPCServerType): self.ledger.get_transaction_history, wallet=wallet, accounts=wallet.accounts) transaction_count = partial( self.ledger.get_transaction_history_count, wallet=wallet, accounts=wallet.accounts) - return paginate_rows(transactions, transaction_count, page, page_size) + return paginate_rows(transactions, transaction_count, page, page_size, read_only=True) @requires(WALLET_COMPONENT) def jsonrpc_transaction_show(self, txid): @@ -4153,8 +4153,8 @@ class Daemon(metaclass=JSONRPCServerType): claims = account.get_txos claim_count = account.get_txo_count else: - claims = partial(self.ledger.get_txos, wallet=wallet, accounts=wallet.accounts) - claim_count = partial(self.ledger.get_txo_count, wallet=wallet, accounts=wallet.accounts) + claims = partial(self.ledger.get_txos, wallet=wallet, accounts=wallet.accounts, read_only=True) + claim_count = partial(self.ledger.get_txo_count, wallet=wallet, accounts=wallet.accounts, read_only=True) constraints = {'resolve': resolve, 'unspent': unspent, 'include_is_received': include_is_received} if is_received is True: constraints['is_received'] = True @@ -4305,7 +4305,7 @@ class Daemon(metaclass=JSONRPCServerType): search_bottom_out_limit = 4 peers = [] peer_q = asyncio.Queue(loop=self.component_manager.loop) - await self.dht_node._value_producer(blob_hash, peer_q) + await self.dht_node._peers_for_value_producer(blob_hash, peer_q) while not peer_q.empty(): peers.extend(peer_q.get_nowait()) results = [ diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index 76ab7bb5a..b7d4e29c5 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -71,6 +71,7 @@ class AddressManager: def _query_addresses(self, **constraints): return self.account.ledger.db.get_addresses( + read_only=constraints.pop("read_only", False), accounts=[self.account], chain=self.chain_number, **constraints @@ -434,8 +435,8 @@ class Account: addresses.extend(new_addresses) return addresses - async def get_addresses(self, **constraints) -> List[str]: - rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints) + async def get_addresses(self, read_only=False, **constraints) -> List[str]: + rows = await self.ledger.db.select_addresses('address', read_only=read_only, accounts=[self], **constraints) return [r[0] for r in rows] def get_address_records(self, **constraints): @@ -451,13 +452,13 @@ class Account: def get_public_key(self, chain: int, index: int) -> PubKey: return self.address_managers[chain].get_public_key(index) - def get_balance(self, confirmations: int = 0, include_claims=False, **constraints): + def get_balance(self, confirmations=0, include_claims=False, read_only=False, **constraints): if not include_claims: constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])}) if confirmations > 0: height = self.ledger.headers.height - (confirmations-1) constraints.update({'height__lte': height, 'height__gt': 0}) - return self.ledger.db.get_balance(accounts=[self], **constraints) + return self.ledger.db.get_balance(accounts=[self], read_only=read_only, **constraints) async def get_max_gap(self): change_gap = await self.change.get_max_gap() @@ -561,9 +562,10 @@ class Account: if gap_changed: self.wallet.save() - async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False): + async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False, read_only=False): tips_balance, supports_balance, claims_balance = 0, 0, 0 - get_total_balance = partial(self.get_balance, confirmations=confirmations, include_claims=True) + get_total_balance = partial(self.get_balance, read_only=read_only, confirmations=confirmations, + include_claims=True) total = await get_total_balance() if reserved_subtotals: claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPES) @@ -591,11 +593,15 @@ class Account: } if reserved_subtotals else None } - def get_transaction_history(self, **constraints): - return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints) + def get_transaction_history(self, read_only=False, **constraints): + return self.ledger.get_transaction_history( + read_only=read_only, wallet=self.wallet, accounts=[self], **constraints + ) - def get_transaction_history_count(self, **constraints): - return self.ledger.get_transaction_history_count(wallet=self.wallet, accounts=[self], **constraints) + def get_transaction_history_count(self, read_only=False, **constraints): + return self.ledger.get_transaction_history_count( + read_only=read_only, wallet=self.wallet, accounts=[self], **constraints + ) def get_claims(self, **constraints): return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints) diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py index 95e0e82c2..4c733fa14 100644 --- a/lbry/wallet/database.py +++ b/lbry/wallet/database.py @@ -1,9 +1,13 @@ +import os import logging import asyncio import sqlite3 - +import platform from binascii import hexlify +from dataclasses import dataclass +from contextvars import ContextVar from concurrent.futures.thread import ThreadPoolExecutor +from concurrent.futures.process import ProcessPoolExecutor from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional from .bip32 import PubKey @@ -15,31 +19,83 @@ log = logging.getLogger(__name__) sqlite3.enable_callback_tracebacks(True) +@dataclass +class ReaderProcessState: + cursor: sqlite3.Cursor + + +reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context') + + +def initializer(path): + db = sqlite3.connect(path) + db.executescript("pragma journal_mode=WAL;") + reader = ReaderProcessState(db.cursor()) + reader_context.set(reader) + + +def run_read_only_fetchall(sql, params): + cursor = reader_context.get().cursor + try: + return cursor.execute(sql, params).fetchall() + except (Exception, OSError) as e: + log.exception('Error running transaction:', exc_info=e) + raise + + +def run_read_only_fetchone(sql, params): + cursor = reader_context.get().cursor + try: + return cursor.execute(sql, params).fetchone() + except (Exception, OSError) as e: + log.exception('Error running transaction:', exc_info=e) + raise + + +if platform.system() == 'Windows' or 'ANDROID_ARGUMENT' in os.environ: + ReaderExecutorClass = ThreadPoolExecutor +else: + ReaderExecutorClass = ProcessPoolExecutor + + class AIOSQLite: + reader_executor: ReaderExecutorClass def __init__(self): # has to be single threaded as there is no mapping of thread:connection - self.executor = ThreadPoolExecutor(max_workers=1) - self.connection: sqlite3.Connection = None + self.writer_executor = ThreadPoolExecutor(max_workers=1) + self.writer_connection: Optional[sqlite3.Connection] = None self._closing = False self.query_count = 0 + self.write_lock = asyncio.Lock() + self.writers = 0 + self.read_ready = asyncio.Event() @classmethod async def connect(cls, path: Union[bytes, str], *args, **kwargs): sqlite3.enable_callback_tracebacks(True) - def _connect(): - return sqlite3.connect(path, *args, **kwargs) db = cls() - db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect) + + def _connect_writer(): + db.writer_connection = sqlite3.connect(path, *args, **kwargs) + + readers = max(os.cpu_count() - 2, 2) + db.reader_executor = ReaderExecutorClass( + max_workers=readers, initializer=initializer, initargs=(path, ) + ) + await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer) + db.read_ready.set() return db async def close(self): if self._closing: return self._closing = True - await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close) - self.executor.shutdown(wait=True) - self.connection = None + await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close) + self.writer_executor.shutdown(wait=True) + self.reader_executor.shutdown(wait=True) + self.read_ready.clear() + self.writer_connection = None def executemany(self, sql: str, params: Iterable): params = params if params is not None else [] @@ -49,52 +105,74 @@ class AIOSQLite: def executescript(self, script: str) -> Awaitable: return self.run(lambda conn: conn.executescript(script)) - def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: + async def _execute_fetch(self, sql: str, parameters: Iterable = None, + read_only=False, fetch_all: bool = False) -> Iterable[sqlite3.Row]: + read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone parameters = parameters if parameters is not None else [] - return self.run(lambda conn: conn.execute(sql, parameters).fetchall()) + if read_only: + while self.writers: + await self.read_ready.wait() + return await asyncio.get_event_loop().run_in_executor( + self.reader_executor, read_only_fn, sql, parameters + ) + if fetch_all: + return await self.run(lambda conn: conn.execute(sql, parameters).fetchall()) + return await self.run(lambda conn: conn.execute(sql, parameters).fetchone()) - def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: - parameters = parameters if parameters is not None else [] - return self.run(lambda conn: conn.execute(sql, parameters).fetchone()) + async def execute_fetchall(self, sql: str, parameters: Iterable = None, + read_only=False) -> Iterable[sqlite3.Row]: + return await self._execute_fetch(sql, parameters, read_only, fetch_all=True) + + async def execute_fetchone(self, sql: str, parameters: Iterable = None, + read_only=False) -> Iterable[sqlite3.Row]: + return await self._execute_fetch(sql, parameters, read_only, fetch_all=False) def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: parameters = parameters if parameters is not None else [] return self.run(lambda conn: conn.execute(sql, parameters)) - def run(self, fun, *args, **kwargs) -> Awaitable: - return asyncio.get_event_loop().run_in_executor( - self.executor, lambda: self.__run_transaction(fun, *args, **kwargs) - ) + async def run(self, fun, *args, **kwargs): + self.writers += 1 + self.read_ready.clear() + async with self.write_lock: + try: + return await asyncio.get_event_loop().run_in_executor( + self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs) + ) + finally: + self.writers -= 1 + if not self.writers: + self.read_ready.set() def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): - self.connection.execute('begin') + self.writer_connection.execute('begin') try: self.query_count += 1 - result = fun(self.connection, *args, **kwargs) # type: ignore - self.connection.commit() + result = fun(self.writer_connection, *args, **kwargs) # type: ignore + self.writer_connection.commit() return result except (Exception, OSError) as e: log.exception('Error running transaction:', exc_info=e) - self.connection.rollback() + self.writer_connection.rollback() log.warning("rolled back") raise def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable: return asyncio.get_event_loop().run_in_executor( - self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs + self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs ) def __run_transaction_with_foreign_keys_disabled(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], args, kwargs): - foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone() + foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone() if not foreign_keys_enabled: raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead") try: - self.connection.execute('pragma foreign_keys=off').fetchone() + self.writer_connection.execute('pragma foreign_keys=off').fetchone() return self.__run_transaction(fun, *args, **kwargs) finally: - self.connection.execute('pragma foreign_keys=on').fetchone() + self.writer_connection.execute('pragma foreign_keys=on').fetchone() def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): @@ -487,7 +565,7 @@ class Database(SQLiteMixin): # 2. update address histories removing deleted TXs return True - async def select_transactions(self, cols, accounts=None, **constraints): + async def select_transactions(self, cols, accounts=None, read_only=False, **constraints): if not {'txid', 'txid__in'}.intersection(constraints): assert accounts, "'accounts' argument required when no 'txid' constraint is present" where, values = constraints_to_sql({ @@ -500,7 +578,7 @@ class Database(SQLiteMixin): """ constraints.update(values) return await self.db.execute_fetchall( - *query(f"SELECT {cols} FROM tx", **constraints) + *query(f"SELECT {cols} FROM tx", **constraints), read_only=read_only ) TXO_NOT_MINE = Output(None, None, is_my_account=False) @@ -577,7 +655,7 @@ class Database(SQLiteMixin): if txs: return txs[0] - async def select_txos(self, cols, wallet=None, include_is_received=False, **constraints): + async def select_txos(self, cols, wallet=None, include_is_received=False, read_only=False, **constraints): if include_is_received: assert wallet is not None, 'cannot use is_recieved filter without wallet argument' account_in_wallet, values = constraints_to_sql({ @@ -593,14 +671,15 @@ class Database(SQLiteMixin): sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)" if 'accounts' in constraints: sql += " JOIN account_address USING (address)" - return await self.db.execute_fetchall(*query(sql, **constraints)) + return await self.db.execute_fetchall(*query(sql, **constraints), read_only=read_only) @staticmethod def constrain_unspent(constraints): constraints['is_reserved'] = False constraints['txoid__not_in'] = "SELECT txoid FROM txi" - async def get_txos(self, wallet=None, no_tx=False, unspent=False, include_is_received=False, **constraints): + async def get_txos(self, wallet=None, no_tx=False, unspent=False, include_is_received=False, + read_only=False, **constraints): include_is_received = include_is_received or 'is_received' in constraints if unspent: self.constrain_unspent(constraints) @@ -616,7 +695,7 @@ class Database(SQLiteMixin): where account_address.address=txo.address ), exists(select 1 from txi where txi.txoid=txo.txoid) """, - wallet=wallet, include_is_received=include_is_received, **constraints + wallet=wallet, include_is_received=include_is_received, read_only=read_only, **constraints ) txos = [] txs = {} @@ -665,7 +744,8 @@ class Database(SQLiteMixin): txo.claim_id: txo for txo in (await self.get_channels( wallet=wallet, - claim_id__in=channel_ids + claim_id__in=channel_ids, + read_only=read_only )) } for txo in txos: @@ -685,32 +765,32 @@ class Database(SQLiteMixin): count = await self.select_txos('count(*)', **constraints) return count[0][0] - def get_utxos(self, **constraints): - return self.get_txos(unspent=True, **constraints) + def get_utxos(self, read_only=False, **constraints): + return self.get_txos(unspent=True, read_only=read_only, **constraints) def get_utxo_count(self, **constraints): return self.get_txo_count(unspent=True, **constraints) - async def get_balance(self, wallet=None, accounts=None, **constraints): + async def get_balance(self, wallet=None, accounts=None, read_only=False, **constraints): assert wallet or accounts, \ "'wallet' or 'accounts' constraints required to calculate balance" constraints['accounts'] = accounts or wallet.accounts self.constrain_unspent(constraints) - balance = await self.select_txos('SUM(amount)', **constraints) + balance = await self.select_txos('SUM(amount)', read_only=read_only, **constraints) return balance[0][0] or 0 - async def select_addresses(self, cols, **constraints): + async def select_addresses(self, cols, read_only=False, **constraints): return await self.db.execute_fetchall(*query( f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)", **constraints - )) + ), read_only=read_only) - async def get_addresses(self, cols=None, **constraints): + async def get_addresses(self, cols=None, read_only=False, **constraints): cols = cols or ( 'address', 'account', 'chain', 'history', 'used_times', 'pubkey', 'chain_code', 'n', 'depth' ) - addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols) + addresses = rows_to_dict(await self.select_addresses(', '.join(cols), read_only=read_only, **constraints), cols) if 'pubkey' in cols: for address in addresses: address['pubkey'] = PubKey( @@ -719,12 +799,12 @@ class Database(SQLiteMixin): ) return addresses - async def get_address_count(self, cols=None, **constraints): - count = await self.select_addresses('count(*)', **constraints) + async def get_address_count(self, cols=None, read_only=False, **constraints): + count = await self.select_addresses('count(*)', read_only=read_only, **constraints) return count[0][0] - async def get_address(self, **constraints): - addresses = await self.get_addresses(limit=1, **constraints) + async def get_address(self, read_only=False, **constraints): + addresses = await self.get_addresses(read_only=read_only, limit=1, **constraints) if addresses: return addresses[0] @@ -788,9 +868,9 @@ class Database(SQLiteMixin): else: constraints['txo_type__in'] = CLAIM_TYPES - async def get_claims(self, **constraints) -> List[Output]: + async def get_claims(self, read_only=False, **constraints) -> List[Output]: self.constrain_claims(constraints) - return await self.get_utxos(**constraints) + return await self.get_utxos(read_only=read_only, **constraints) def get_claim_count(self, **constraints): self.constrain_claims(constraints) @@ -800,9 +880,9 @@ class Database(SQLiteMixin): def constrain_streams(constraints): constraints['txo_type'] = TXO_TYPES['stream'] - def get_streams(self, **constraints): + def get_streams(self, read_only=False, **constraints): self.constrain_streams(constraints) - return self.get_claims(**constraints) + return self.get_claims(read_only=read_only, **constraints) def get_stream_count(self, **constraints): self.constrain_streams(constraints) @@ -852,7 +932,7 @@ class Database(SQLiteMixin): " )", (account.public_key.address, ) ) - def get_supports_summary(self, account_id): + def get_supports_summary(self, account_id, read_only=False): return self.db.execute_fetchall(f""" select txo.amount, exists(select * from txi where txi.txoid=txo.txoid) as spent, (txo.txid in @@ -861,4 +941,4 @@ class Database(SQLiteMixin): (txo.address in (select address from account_address where account=?)) as to_me, tx.height from txo join tx using (txid) where txo_type={TXO_TYPES['support']} - """, (account_id, account_id)) + """, (account_id, account_id), read_only=read_only) diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index feef913b5..64ac744c0 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -7,9 +7,9 @@ from io import StringIO from datetime import datetime from functools import partial from operator import itemgetter -from collections import namedtuple, defaultdict +from collections import defaultdict from binascii import hexlify, unhexlify -from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict +from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict, NamedTuple import pylru from lbry.schema.result import Outputs, INVALID, NOT_FOUND @@ -53,16 +53,19 @@ class LedgerRegistry(type): return mcs.ledgers[ledger_id] -class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))): - pass +class TransactionEvent(NamedTuple): + address: str + tx: Transaction -class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))): - pass +class AddressesGeneratedEvent(NamedTuple): + address_manager: AddressManager + addresses: List[str] -class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))): - pass +class BlockHeightEvent(NamedTuple): + height: int + change: int class TransactionCacheItem: @@ -822,8 +825,8 @@ class Ledger(metaclass=LedgerRegistry): def get_support_count(self, **constraints): return self.db.get_support_count(**constraints) - async def get_transaction_history(self, **constraints): - txs: List[Transaction] = await self.db.get_transactions(**constraints) + async def get_transaction_history(self, read_only=False, **constraints): + txs: List[Transaction] = await self.db.get_transactions(read_only=read_only, **constraints) headers = self.headers history = [] for tx in txs: # pylint: disable=too-many-nested-blocks @@ -932,8 +935,8 @@ class Ledger(metaclass=LedgerRegistry): history.append(item) return history - def get_transaction_history_count(self, **constraints): - return self.db.get_transaction_count(**constraints) + def get_transaction_history_count(self, read_only=False, **constraints): + return self.db.get_transaction_count(read_only=read_only, **constraints) async def get_detailed_balance(self, accounts, confirmations=0): result = { diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py index 29ef58b46..50a40ad91 100644 --- a/tests/unit/wallet/test_database.py +++ b/tests/unit/wallet/test_database.py @@ -266,9 +266,10 @@ class TestQueries(AsyncioTestCase): # SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html # This can be removed when there is a better way. See: https://github.com/lbryio/lbry-sdk/issues/2281 fetchall = self.ledger.db.db.execute_fetchall - def check_parameters_length(sql, parameters): + + def check_parameters_length(sql, parameters, read_only=False): self.assertLess(len(parameters or []), 999) - return fetchall(sql, parameters) + return fetchall(sql, parameters, read_only) self.ledger.db.db.execute_fetchall = check_parameters_length account = await self.create_account()