Merge pull request #2823 from lbryio/multiple-db-readers

Use multiple processes for querying the db for api calls
Jack Robison, 2020-03-19 19:19:53 -04:00, committed by GitHub
commit c13aab3ffc
5 changed files with 172 additions and 82 deletions
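The change, in short: AIOSQLite keeps a single dedicated writer connection on a one-thread executor, adds a pool of read-only reader connections (separate processes, or threads on Windows/Android), and lets API calls route their queries to that pool by passing read_only=True. A minimal usage sketch of the new call shape follows; the module path lbry.wallet.database and the wallet.db path are assumptions, while the table and column names are taken from the diff below.

import asyncio
from lbry.wallet.database import AIOSQLite  # assumed module path

async def main(path):
    db = await AIOSQLite.connect(path)
    # Served by one of the read-only reader workers; it waits only if a write
    # batch is currently in flight, and never blocks the writer connection.
    rows = await db.execute_fetchall(
        "SELECT address, used_times FROM pubkey_address LIMIT ?", (10,),
        read_only=True
    )
    print(rows)
    # Anything that mutates state keeps going through the single writer
    # executor, e.g. await db.execute(...) or await db.run(some_callable).
    await db.close()

asyncio.run(main("wallet.db"))  # hypothetical path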

@@ -1492,7 +1492,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
account = wallet.get_account_or_default(account_id)
balance = await account.get_detailed_balance(
confirmations=confirmations, reserved_subtotals=True
confirmations=confirmations, reserved_subtotals=True, read_only=True
)
return dict_values_to_lbc(balance)
@@ -1817,7 +1817,7 @@ class Daemon(metaclass=JSONRPCServerType):
"""
wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
account = wallet.get_account_or_default(account_id)
match = await self.ledger.db.get_address(address=address, accounts=[account])
match = await self.ledger.db.get_address(read_only=True, address=address, accounts=[account])
if match is not None:
return True
return False
@@ -1853,7 +1853,7 @@ class Daemon(metaclass=JSONRPCServerType):
return paginate_rows(
self.ledger.get_addresses,
self.ledger.get_address_count,
page, page_size, **constraints
page, page_size, read_only=True, **constraints
)
@requires(WALLET_COMPONENT)
@@ -4089,7 +4089,7 @@ class Daemon(metaclass=JSONRPCServerType):
self.ledger.get_transaction_history, wallet=wallet, accounts=wallet.accounts)
transaction_count = partial(
self.ledger.get_transaction_history_count, wallet=wallet, accounts=wallet.accounts)
return paginate_rows(transactions, transaction_count, page, page_size)
return paginate_rows(transactions, transaction_count, page, page_size, read_only=True)
@requires(WALLET_COMPONENT)
def jsonrpc_transaction_show(self, txid):
@@ -4153,8 +4153,8 @@ class Daemon(metaclass=JSONRPCServerType):
claims = account.get_txos
claim_count = account.get_txo_count
else:
claims = partial(self.ledger.get_txos, wallet=wallet, accounts=wallet.accounts)
claim_count = partial(self.ledger.get_txo_count, wallet=wallet, accounts=wallet.accounts)
claims = partial(self.ledger.get_txos, wallet=wallet, accounts=wallet.accounts, read_only=True)
claim_count = partial(self.ledger.get_txo_count, wallet=wallet, accounts=wallet.accounts, read_only=True)
constraints = {'resolve': resolve, 'unspent': unspent, 'include_is_received': include_is_received}
if is_received is True:
constraints['is_received'] = True
@@ -4305,7 +4305,7 @@ class Daemon(metaclass=JSONRPCServerType):
search_bottom_out_limit = 4
peers = []
peer_q = asyncio.Queue(loop=self.component_manager.loop)
await self.dht_node._value_producer(blob_hash, peer_q)
await self.dht_node._peers_for_value_producer(blob_hash, peer_q)
while not peer_q.empty():
peers.extend(peer_q.get_nowait())
results = [
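Stepping back from the daemon.py hunks above: the handlers do not talk to SQLite directly, they hand a getter and a count callable to paginate_rows, and the new read_only=True simply rides along to both callables as an extra keyword argument. A rough sketch of that call shape (the SDK's real helper may differ in its defaults and naming; this only shows how the flag propagates):

async def paginate_rows_sketch(get_rows, get_count, page, page_size, **constraints):
    # Everything beyond page/page_size (accounts, wallet, read_only=True, ...)
    # is forwarded untouched to both callables.
    offset, limit = (page - 1) * page_size, page_size
    items = await get_rows(offset=offset, limit=limit, **constraints)
    total = await get_count(**constraints)
    return {
        "items": items,
        "page": page,
        "page_size": page_size,
        "total_pages": (total + page_size - 1) // page_size,
        "total_items": total,
    }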

@@ -71,6 +71,7 @@ class AddressManager:
def _query_addresses(self, **constraints):
return self.account.ledger.db.get_addresses(
read_only=constraints.pop("read_only", False),
accounts=[self.account],
chain=self.chain_number,
**constraints
@@ -434,8 +435,8 @@ class Account:
addresses.extend(new_addresses)
return addresses
async def get_addresses(self, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints)
async def get_addresses(self, read_only=False, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses('address', read_only=read_only, accounts=[self], **constraints)
return [r[0] for r in rows]
def get_address_records(self, **constraints):
@@ -451,13 +452,13 @@ class Account:
def get_public_key(self, chain: int, index: int) -> PubKey:
return self.address_managers[chain].get_public_key(index)
def get_balance(self, confirmations: int = 0, include_claims=False, **constraints):
def get_balance(self, confirmations=0, include_claims=False, read_only=False, **constraints):
if not include_claims:
constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])})
if confirmations > 0:
height = self.ledger.headers.height - (confirmations-1)
constraints.update({'height__lte': height, 'height__gt': 0})
return self.ledger.db.get_balance(accounts=[self], **constraints)
return self.ledger.db.get_balance(accounts=[self], read_only=read_only, **constraints)
async def get_max_gap(self):
change_gap = await self.change.get_max_gap()
@@ -561,9 +562,10 @@ class Account:
if gap_changed:
self.wallet.save()
async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False):
async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False, read_only=False):
tips_balance, supports_balance, claims_balance = 0, 0, 0
get_total_balance = partial(self.get_balance, confirmations=confirmations, include_claims=True)
get_total_balance = partial(self.get_balance, read_only=read_only, confirmations=confirmations,
include_claims=True)
total = await get_total_balance()
if reserved_subtotals:
claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPES)
@@ -591,11 +593,15 @@ class Account:
} if reserved_subtotals else None
}
def get_transaction_history(self, **constraints):
return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints)
def get_transaction_history(self, read_only=False, **constraints):
return self.ledger.get_transaction_history(
read_only=read_only, wallet=self.wallet, accounts=[self], **constraints
)
def get_transaction_history_count(self, **constraints):
return self.ledger.get_transaction_history_count(wallet=self.wallet, accounts=[self], **constraints)
def get_transaction_history_count(self, read_only=False, **constraints):
return self.ledger.get_transaction_history_count(
read_only=read_only, wallet=self.wallet, accounts=[self], **constraints
)
def get_claims(self, **constraints):
return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints)
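One detail worth calling out in the account.py hunks: read_only is a routing flag, not a SQL constraint, which is why AddressManager._query_addresses pops it out of the constraints dict before the remaining keys are turned into a WHERE clause. A toy illustration of that separation (not SDK code):

def build_where(**constraints):
    read_only = constraints.pop("read_only", False)  # routing flag, never a column
    where = " AND ".join(f"{key} = :{key}" for key in constraints)
    return read_only, where, constraints

read_only, where, params = build_where(read_only=True, account="abc", chain=0)
assert read_only is True
assert where == "account = :account AND chain = :chain"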

@@ -1,9 +1,13 @@
import os
import logging
import asyncio
import sqlite3
import platform
from binascii import hexlify
from dataclasses import dataclass
from contextvars import ContextVar
from concurrent.futures.thread import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
from .bip32 import PubKey
@@ -15,31 +19,83 @@ log = logging.getLogger(__name__)
sqlite3.enable_callback_tracebacks(True)
@dataclass
class ReaderProcessState:
cursor: sqlite3.Cursor
reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context')
def initializer(path):
db = sqlite3.connect(path)
db.executescript("pragma journal_mode=WAL;")
reader = ReaderProcessState(db.cursor())
reader_context.set(reader)
def run_read_only_fetchall(sql, params):
cursor = reader_context.get().cursor
try:
return cursor.execute(sql, params).fetchall()
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
raise
def run_read_only_fetchone(sql, params):
cursor = reader_context.get().cursor
try:
return cursor.execute(sql, params).fetchone()
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
raise
if platform.system() == 'Windows' or 'ANDROID_ARGUMENT' in os.environ:
ReaderExecutorClass = ThreadPoolExecutor
else:
ReaderExecutorClass = ProcessPoolExecutor
class AIOSQLite:
reader_executor: ReaderExecutorClass
def __init__(self):
# has to be single threaded as there is no mapping of thread:connection
self.executor = ThreadPoolExecutor(max_workers=1)
self.connection: sqlite3.Connection = None
self.writer_executor = ThreadPoolExecutor(max_workers=1)
self.writer_connection: Optional[sqlite3.Connection] = None
self._closing = False
self.query_count = 0
self.write_lock = asyncio.Lock()
self.writers = 0
self.read_ready = asyncio.Event()
@classmethod
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
sqlite3.enable_callback_tracebacks(True)
def _connect():
return sqlite3.connect(path, *args, **kwargs)
db = cls()
db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect)
def _connect_writer():
db.writer_connection = sqlite3.connect(path, *args, **kwargs)
readers = max(os.cpu_count() - 2, 2)
db.reader_executor = ReaderExecutorClass(
max_workers=readers, initializer=initializer, initargs=(path, )
)
await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer)
db.read_ready.set()
return db
async def close(self):
if self._closing:
return
self._closing = True
await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close)
self.executor.shutdown(wait=True)
self.connection = None
await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
self.writer_executor.shutdown(wait=True)
self.reader_executor.shutdown(wait=True)
self.read_ready.clear()
self.writer_connection = None
def executemany(self, sql: str, params: Iterable):
params = params if params is not None else []
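The module-level helpers above are the core of the reader pool: the executor's initializer runs once per worker, opens that worker's own SQLite connection in WAL mode, and parks it in a ContextVar so that the plain, picklable top-level functions submitted to the pool can find it again; on Windows and Android the diff swaps in ThreadPoolExecutor, where the same pattern works per thread. A standalone sketch of the idea outside the SDK (hypothetical example.db path):

import sqlite3
from contextvars import ContextVar
from concurrent.futures import ProcessPoolExecutor

_conn: ContextVar[sqlite3.Connection] = ContextVar("_conn")

def _init(path):
    db = sqlite3.connect(path)
    db.executescript("pragma journal_mode=WAL;")  # readers and the writer don't block each other
    _conn.set(db)

def _fetchall(sql, params):
    return _conn.get().execute(sql, params).fetchall()

if __name__ == "__main__":
    pool = ProcessPoolExecutor(max_workers=4, initializer=_init, initargs=("example.db",))
    print(pool.submit(_fetchall, "SELECT 1 WHERE 1 = ?", (1,)).result())  # [(1,)]
    pool.shutdown()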
@@ -49,52 +105,74 @@ class AIOSQLite:
def executescript(self, script: str) -> Awaitable:
return self.run(lambda conn: conn.executescript(script))
def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
async def _execute_fetch(self, sql: str, parameters: Iterable = None,
read_only=False, fetch_all: bool = False) -> Iterable[sqlite3.Row]:
read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
if read_only:
while self.writers:
await self.read_ready.wait()
return await asyncio.get_event_loop().run_in_executor(
self.reader_executor, read_only_fn, sql, parameters
)
if fetch_all:
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
return await self.run(lambda conn: conn.execute(sql, parameters).fetchone())
def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
async def execute_fetchall(self, sql: str, parameters: Iterable = None,
read_only=False) -> Iterable[sqlite3.Row]:
return await self._execute_fetch(sql, parameters, read_only, fetch_all=True)
async def execute_fetchone(self, sql: str, parameters: Iterable = None,
read_only=False) -> Iterable[sqlite3.Row]:
return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters))
def run(self, fun, *args, **kwargs) -> Awaitable:
return asyncio.get_event_loop().run_in_executor(
self.executor, lambda: self.__run_transaction(fun, *args, **kwargs)
)
async def run(self, fun, *args, **kwargs):
self.writers += 1
self.read_ready.clear()
async with self.write_lock:
try:
return await asyncio.get_event_loop().run_in_executor(
self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
)
finally:
self.writers -= 1
if not self.writers:
self.read_ready.set()
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
self.connection.execute('begin')
self.writer_connection.execute('begin')
try:
self.query_count += 1
result = fun(self.connection, *args, **kwargs) # type: ignore
self.connection.commit()
result = fun(self.writer_connection, *args, **kwargs) # type: ignore
self.writer_connection.commit()
return result
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
self.connection.rollback()
self.writer_connection.rollback()
log.warning("rolled back")
raise
def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
return asyncio.get_event_loop().run_in_executor(
self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
)
def __run_transaction_with_foreign_keys_disabled(self,
fun: Callable[[sqlite3.Connection, Any, Any], Any],
args, kwargs):
foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone()
foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone()
if not foreign_keys_enabled:
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
try:
self.connection.execute('pragma foreign_keys=off').fetchone()
self.writer_connection.execute('pragma foreign_keys=off').fetchone()
return self.__run_transaction(fun, *args, **kwargs)
finally:
self.connection.execute('pragma foreign_keys=on').fetchone()
self.writer_connection.execute('pragma foreign_keys=on').fetchone()
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
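The other half of the change is coordinating the two sides: run() counts in-flight writers and clears read_ready, and _execute_fetch only dispatches a read_only query to the reader pool once that counter drains back to zero, so a read issued while writes are queued waits for them to commit rather than returning stale rows. A toy model of just that gate (not SDK code):

import asyncio

class WriteGate:
    def __init__(self):
        self.writers = 0
        self.read_ready = asyncio.Event()
        self.read_ready.set()

    async def write(self, coro):
        self.writers += 1
        self.read_ready.clear()
        try:
            return await coro
        finally:
            self.writers -= 1
            if not self.writers:
                self.read_ready.set()  # last writer out lets readers through

    async def read(self, coro):
        while self.writers:            # pending writes take priority
            await self.read_ready.wait()
        return await coro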
@@ -487,7 +565,7 @@ class Database(SQLiteMixin):
# 2. update address histories removing deleted TXs
return True
async def select_transactions(self, cols, accounts=None, **constraints):
async def select_transactions(self, cols, accounts=None, read_only=False, **constraints):
if not {'txid', 'txid__in'}.intersection(constraints):
assert accounts, "'accounts' argument required when no 'txid' constraint is present"
where, values = constraints_to_sql({
@@ -500,7 +578,7 @@ class Database(SQLiteMixin):
"""
constraints.update(values)
return await self.db.execute_fetchall(
*query(f"SELECT {cols} FROM tx", **constraints)
*query(f"SELECT {cols} FROM tx", **constraints), read_only=read_only
)
TXO_NOT_MINE = Output(None, None, is_my_account=False)
@@ -577,7 +655,7 @@ class Database(SQLiteMixin):
if txs:
return txs[0]
async def select_txos(self, cols, wallet=None, include_is_received=False, **constraints):
async def select_txos(self, cols, wallet=None, include_is_received=False, read_only=False, **constraints):
if include_is_received:
assert wallet is not None, 'cannot use is_recieved filter without wallet argument'
account_in_wallet, values = constraints_to_sql({
@@ -593,14 +671,15 @@ class Database(SQLiteMixin):
sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)"
if 'accounts' in constraints:
sql += " JOIN account_address USING (address)"
return await self.db.execute_fetchall(*query(sql, **constraints))
return await self.db.execute_fetchall(*query(sql, **constraints), read_only=read_only)
@staticmethod
def constrain_unspent(constraints):
constraints['is_reserved'] = False
constraints['txoid__not_in'] = "SELECT txoid FROM txi"
async def get_txos(self, wallet=None, no_tx=False, unspent=False, include_is_received=False, **constraints):
async def get_txos(self, wallet=None, no_tx=False, unspent=False, include_is_received=False,
read_only=False, **constraints):
include_is_received = include_is_received or 'is_received' in constraints
if unspent:
self.constrain_unspent(constraints)
@@ -616,7 +695,7 @@ class Database(SQLiteMixin):
where account_address.address=txo.address
), exists(select 1 from txi where txi.txoid=txo.txoid)
""",
wallet=wallet, include_is_received=include_is_received, **constraints
wallet=wallet, include_is_received=include_is_received, read_only=read_only, **constraints
)
txos = []
txs = {}
@@ -665,7 +744,8 @@ class Database(SQLiteMixin):
txo.claim_id: txo for txo in
(await self.get_channels(
wallet=wallet,
claim_id__in=channel_ids
claim_id__in=channel_ids,
read_only=read_only
))
}
for txo in txos:
@@ -685,32 +765,32 @@ class Database(SQLiteMixin):
count = await self.select_txos('count(*)', **constraints)
return count[0][0]
def get_utxos(self, **constraints):
return self.get_txos(unspent=True, **constraints)
def get_utxos(self, read_only=False, **constraints):
return self.get_txos(unspent=True, read_only=read_only, **constraints)
def get_utxo_count(self, **constraints):
return self.get_txo_count(unspent=True, **constraints)
async def get_balance(self, wallet=None, accounts=None, **constraints):
async def get_balance(self, wallet=None, accounts=None, read_only=False, **constraints):
assert wallet or accounts, \
"'wallet' or 'accounts' constraints required to calculate balance"
constraints['accounts'] = accounts or wallet.accounts
self.constrain_unspent(constraints)
balance = await self.select_txos('SUM(amount)', **constraints)
balance = await self.select_txos('SUM(amount)', read_only=read_only, **constraints)
return balance[0][0] or 0
async def select_addresses(self, cols, **constraints):
async def select_addresses(self, cols, read_only=False, **constraints):
return await self.db.execute_fetchall(*query(
f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)",
**constraints
))
), read_only=read_only)
async def get_addresses(self, cols=None, **constraints):
async def get_addresses(self, cols=None, read_only=False, **constraints):
cols = cols or (
'address', 'account', 'chain', 'history', 'used_times',
'pubkey', 'chain_code', 'n', 'depth'
)
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols)
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), read_only=read_only, **constraints), cols)
if 'pubkey' in cols:
for address in addresses:
address['pubkey'] = PubKey(
@@ -719,12 +799,12 @@ class Database(SQLiteMixin):
)
return addresses
async def get_address_count(self, cols=None, **constraints):
count = await self.select_addresses('count(*)', **constraints)
async def get_address_count(self, cols=None, read_only=False, **constraints):
count = await self.select_addresses('count(*)', read_only=read_only, **constraints)
return count[0][0]
async def get_address(self, **constraints):
addresses = await self.get_addresses(limit=1, **constraints)
async def get_address(self, read_only=False, **constraints):
addresses = await self.get_addresses(read_only=read_only, limit=1, **constraints)
if addresses:
return addresses[0]
@@ -788,9 +868,9 @@ class Database(SQLiteMixin):
else:
constraints['txo_type__in'] = CLAIM_TYPES
async def get_claims(self, **constraints) -> List[Output]:
async def get_claims(self, read_only=False, **constraints) -> List[Output]:
self.constrain_claims(constraints)
return await self.get_utxos(**constraints)
return await self.get_utxos(read_only=read_only, **constraints)
def get_claim_count(self, **constraints):
self.constrain_claims(constraints)
@@ -800,9 +880,9 @@ class Database(SQLiteMixin):
def constrain_streams(constraints):
constraints['txo_type'] = TXO_TYPES['stream']
def get_streams(self, **constraints):
def get_streams(self, read_only=False, **constraints):
self.constrain_streams(constraints)
return self.get_claims(**constraints)
return self.get_claims(read_only=read_only, **constraints)
def get_stream_count(self, **constraints):
self.constrain_streams(constraints)
@@ -852,7 +932,7 @@ class Database(SQLiteMixin):
" )", (account.public_key.address, )
)
def get_supports_summary(self, account_id):
def get_supports_summary(self, account_id, read_only=False):
return self.db.execute_fetchall(f"""
select txo.amount, exists(select * from txi where txi.txoid=txo.txoid) as spent,
(txo.txid in
@@ -861,4 +941,4 @@ class Database(SQLiteMixin):
(txo.address in (select address from account_address where account=?)) as to_me,
tx.height
from txo join tx using (txid) where txo_type={TXO_TYPES['support']}
""", (account_id, account_id))
""", (account_id, account_id), read_only=read_only)

@@ -7,9 +7,9 @@ from io import StringIO
from datetime import datetime
from functools import partial
from operator import itemgetter
from collections import namedtuple, defaultdict
from collections import defaultdict
from binascii import hexlify, unhexlify
from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict
from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict, NamedTuple
import pylru
from lbry.schema.result import Outputs, INVALID, NOT_FOUND
@@ -53,16 +53,19 @@ class LedgerRegistry(type):
return mcs.ledgers[ledger_id]
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
pass
class TransactionEvent(NamedTuple):
address: str
tx: Transaction
class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))):
pass
class AddressesGeneratedEvent(NamedTuple):
address_manager: AddressManager
addresses: List[str]
class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
pass
class BlockHeightEvent(NamedTuple):
height: int
change: int
class TransactionCacheItem:
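The ledger.py hunk above is a mechanical refactor from collections.namedtuple subclasses to typing.NamedTuple, which behaves the same at runtime but carries field type annotations and drops the empty subclass body. For example (toy fields, not SDK types):

from collections import namedtuple
from typing import NamedTuple

class OldStyle(namedtuple('OldStyle', ('height', 'change'))):
    pass

class NewStyle(NamedTuple):
    height: int
    change: int

assert OldStyle(5, 1) == NewStyle(5, 1) == (5, 1)
assert OldStyle._fields == NewStyle._fields == ('height', 'change')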
@@ -822,8 +825,8 @@ class Ledger(metaclass=LedgerRegistry):
def get_support_count(self, **constraints):
return self.db.get_support_count(**constraints)
async def get_transaction_history(self, **constraints):
txs: List[Transaction] = await self.db.get_transactions(**constraints)
async def get_transaction_history(self, read_only=False, **constraints):
txs: List[Transaction] = await self.db.get_transactions(read_only=read_only, **constraints)
headers = self.headers
history = []
for tx in txs: # pylint: disable=too-many-nested-blocks
@@ -932,8 +935,8 @@ class Ledger(metaclass=LedgerRegistry):
history.append(item)
return history
def get_transaction_history_count(self, **constraints):
return self.db.get_transaction_count(**constraints)
def get_transaction_history_count(self, read_only=False, **constraints):
return self.db.get_transaction_count(read_only=read_only, **constraints)
async def get_detailed_balance(self, accounts, confirmations=0):
result = {

@@ -266,9 +266,10 @@ class TestQueries(AsyncioTestCase):
# SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html
# This can be removed when there is a better way. See: https://github.com/lbryio/lbry-sdk/issues/2281
fetchall = self.ledger.db.db.execute_fetchall
def check_parameters_length(sql, parameters):
def check_parameters_length(sql, parameters, read_only=False):
self.assertLess(len(parameters or []), 999)
return fetchall(sql, parameters)
return fetchall(sql, parameters, read_only)
self.ledger.db.db.execute_fetchall = check_parameters_length
account = await self.create_account()
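The test override above exists because of SQLite's bound-variable ceiling mentioned in the comment (commonly 999); it now has to accept and forward the extra read_only argument so wrapped read-only calls still reach the reader pool. On the caller side, the usual way to stay under that ceiling is to chunk long IN (...) parameter lists, roughly like this (toy code, not from the SDK):

def chunked(values, size=998):
    for i in range(0, len(values), size):
        yield values[i:i + size]

def in_clause_queries(txids):
    # Yields (sql, params) pairs, each safely under the variable limit.
    for chunk in chunked(list(txids)):
        placeholders = ", ".join("?" for _ in chunk)
        yield f"SELECT * FROM tx WHERE txid IN ({placeholders})", tuple(chunk)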