diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index 76ab7bb5a..831d45e96 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -434,8 +434,8 @@ class Account: addresses.extend(new_addresses) return addresses - async def get_addresses(self, **constraints) -> List[str]: - rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints) + async def get_addresses(self, read_only: bool = False, **constraints) -> List[str]: + rows = await self.ledger.db.select_addresses('address', read_only=read_only, accounts=[self], **constraints) return [r[0] for r in rows] def get_address_records(self, **constraints): @@ -591,11 +591,15 @@ class Account: } if reserved_subtotals else None } - def get_transaction_history(self, **constraints): - return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints) + def get_transaction_history(self, read_only: bool = False, **constraints): + return self.ledger.get_transaction_history( + read_only=read_only, wallet=self.wallet, accounts=[self], **constraints + ) - def get_transaction_history_count(self, **constraints): - return self.ledger.get_transaction_history_count(wallet=self.wallet, accounts=[self], **constraints) + def get_transaction_history_count(self, read_only: bool = False, **constraints): + return self.ledger.get_transaction_history_count( + read_only=read_only, wallet=self.wallet, accounts=[self], **constraints + ) def get_claims(self, **constraints): return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints) diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py index 10c42fd0a..49edaac79 100644 --- a/lbry/wallet/database.py +++ b/lbry/wallet/database.py @@ -1,9 +1,12 @@ +import os import logging import asyncio import sqlite3 - from binascii import hexlify +from dataclasses import dataclass +from contextvars import ContextVar from concurrent.futures.thread import ThreadPoolExecutor +from concurrent.futures.process import ProcessPoolExecutor from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional from .bip32 import PubKey @@ -15,7 +18,39 @@ log = logging.getLogger(__name__) sqlite3.enable_callback_tracebacks(True) +@dataclass +class ReaderProcessState: + cursor: sqlite3.Cursor + + +reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context') + + +def initializer(path): + reader = ReaderProcessState(sqlite3.connect(path).cursor()) + reader_context.set(reader) + + +def run_read_only_fetchall(sql, params): + cursor = reader_context.get().cursor + try: + return cursor.execute(sql, params).fetchall() + except (Exception, OSError) as e: + log.exception('Error running transaction:', exc_info=e) + raise + + +def run_read_only_fetchone(sql, params): + cursor = reader_context.get().cursor + try: + return cursor.execute(sql, params).fetchone() + except (Exception, OSError) as e: + log.exception('Error running transaction:', exc_info=e) + raise + + class AIOSQLite: + reader_executor: ProcessPoolExecutor def __init__(self): # has to be single threaded as there is no mapping of thread:connection @@ -31,6 +66,11 @@ class AIOSQLite: def _connect_writer(): db.writer_connection = sqlite3.connect(path, *args, **kwargs) + + readers = max(os.cpu_count() - 2, 2) + db.reader_executor = ProcessPoolExecutor( + max_workers=readers, initializer=initializer, initargs=(path, ) + ) await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer) return db @@ -40,6 +80,7 @@ class AIOSQLite: self._closing = True await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close) self.writer_executor.shutdown(wait=True) + self.reader_executor.shutdown(wait=True) self.writer_connection = None def executemany(self, sql: str, params: Iterable): @@ -50,12 +91,22 @@ class AIOSQLite: def executescript(self, script: str) -> Awaitable: return self.run(lambda conn: conn.executescript(script)) - def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: + async def execute_fetchall(self, sql: str, parameters: Iterable = None, + read_only: bool = False) -> Iterable[sqlite3.Row]: parameters = parameters if parameters is not None else [] - return self.run(lambda conn: conn.execute(sql, parameters).fetchall()) + if read_only: + return await asyncio.get_event_loop().run_in_executor( + self.reader_executor, run_read_only_fetchall, sql, parameters + ) + return await self.run(lambda conn: conn.execute(sql, parameters).fetchall()) - def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: + def execute_fetchone(self, sql: str, parameters: Iterable = None, + read_only: bool = False) -> Awaitable[Iterable[sqlite3.Row]]: parameters = parameters if parameters is not None else [] + if read_only: + return asyncio.get_event_loop().run_in_executor( + self.reader_executor, run_read_only_fetchone, sql, parameters + ) return self.run(lambda conn: conn.execute(sql, parameters).fetchone()) def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: @@ -488,7 +539,7 @@ class Database(SQLiteMixin): # 2. update address histories removing deleted TXs return True - async def select_transactions(self, cols, accounts=None, **constraints): + async def select_transactions(self, cols, accounts=None, read_only: bool = False, **constraints): if not {'txid', 'txid__in'}.intersection(constraints): assert accounts, "'accounts' argument required when no 'txid' constraint is present" where, values = constraints_to_sql({ @@ -501,7 +552,7 @@ class Database(SQLiteMixin): """ constraints.update(values) return await self.db.execute_fetchall( - *query(f"SELECT {cols} FROM tx", **constraints) + *query(f"SELECT {cols} FROM tx", **constraints), read_only=read_only ) TXO_NOT_MINE = Output(None, None, is_my_account=False) @@ -578,7 +629,7 @@ class Database(SQLiteMixin): if txs: return txs[0] - async def select_txos(self, cols, wallet=None, include_is_received=False, **constraints): + async def select_txos(self, cols, wallet=None, include_is_received=False, read_only: bool = False, **constraints): if include_is_received: assert wallet is not None, 'cannot use is_recieved filter without wallet argument' account_in_wallet, values = constraints_to_sql({ @@ -594,7 +645,7 @@ class Database(SQLiteMixin): sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)" if 'accounts' in constraints: sql += " JOIN account_address USING (address)" - return await self.db.execute_fetchall(*query(sql, **constraints)) + return await self.db.execute_fetchall(*query(sql, **constraints), read_only=read_only) @staticmethod def constrain_unspent(constraints): @@ -700,18 +751,18 @@ class Database(SQLiteMixin): balance = await self.select_txos('SUM(amount)', **constraints) return balance[0][0] or 0 - async def select_addresses(self, cols, **constraints): + async def select_addresses(self, cols, read_only: bool = False, **constraints): return await self.db.execute_fetchall(*query( f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)", **constraints - )) + ), read_only=read_only) - async def get_addresses(self, cols=None, **constraints): + async def get_addresses(self, cols=None, read_only: bool = False, **constraints): cols = cols or ( 'address', 'account', 'chain', 'history', 'used_times', 'pubkey', 'chain_code', 'n', 'depth' ) - addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols) + addresses = rows_to_dict(await self.select_addresses(', '.join(cols), read_only=read_only, **constraints), cols) if 'pubkey' in cols: for address in addresses: address['pubkey'] = PubKey( @@ -720,8 +771,8 @@ class Database(SQLiteMixin): ) return addresses - async def get_address_count(self, cols=None, **constraints): - count = await self.select_addresses('count(*)', **constraints) + async def get_address_count(self, cols=None, read_only: bool = False, **constraints): + count = await self.select_addresses('count(*)', read_only=read_only, **constraints) return count[0][0] async def get_address(self, **constraints): @@ -853,7 +904,7 @@ class Database(SQLiteMixin): " )", (account.public_key.address, ) ) - def get_supports_summary(self, account_id): + def get_supports_summary(self, account_id, read_only: bool = False): return self.db.execute_fetchall(f""" select txo.amount, exists(select * from txi where txi.txoid=txo.txoid) as spent, (txo.txid in @@ -862,4 +913,4 @@ class Database(SQLiteMixin): (txo.address in (select address from account_address where account=?)) as to_me, tx.height from txo join tx using (txid) where txo_type={TXO_TYPES['support']} - """, (account_id, account_id)) + """, (account_id, account_id), read_only=read_only) diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index feef913b5..ba2256dec 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -7,9 +7,9 @@ from io import StringIO from datetime import datetime from functools import partial from operator import itemgetter -from collections import namedtuple, defaultdict +from collections import defaultdict from binascii import hexlify, unhexlify -from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict +from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict, NamedTuple import pylru from lbry.schema.result import Outputs, INVALID, NOT_FOUND @@ -53,16 +53,19 @@ class LedgerRegistry(type): return mcs.ledgers[ledger_id] -class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))): - pass +class TransactionEvent(NamedTuple): + address: str + tx: Transaction -class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))): - pass +class AddressesGeneratedEvent(NamedTuple): + address_manager: AddressManager + addresses: List[str] -class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))): - pass +class BlockHeightEvent(NamedTuple): + height: int + change: int class TransactionCacheItem: @@ -822,8 +825,8 @@ class Ledger(metaclass=LedgerRegistry): def get_support_count(self, **constraints): return self.db.get_support_count(**constraints) - async def get_transaction_history(self, **constraints): - txs: List[Transaction] = await self.db.get_transactions(**constraints) + async def get_transaction_history(self, read_only: bool = False, **constraints): + txs: List[Transaction] = await self.db.get_transactions(read_only=read_only, **constraints) headers = self.headers history = [] for tx in txs: # pylint: disable=too-many-nested-blocks @@ -932,8 +935,8 @@ class Ledger(metaclass=LedgerRegistry): history.append(item) return history - def get_transaction_history_count(self, **constraints): - return self.db.get_transaction_count(**constraints) + def get_transaction_history_count(self, read_only: bool = False, **constraints): + return self.db.get_transaction_count(read_only=read_only, **constraints) async def get_detailed_balance(self, accounts, confirmations=0): result = {