multiple db reader processeses

This commit is contained in:
Jack Robison 2020-02-20 22:11:25 -05:00
parent 7a6b1930bf
commit d1b330028c
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
3 changed files with 92 additions and 34 deletions

View file

@ -434,8 +434,8 @@ class Account:
addresses.extend(new_addresses) addresses.extend(new_addresses)
return addresses return addresses
async def get_addresses(self, **constraints) -> List[str]: async def get_addresses(self, read_only: bool = False, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints) rows = await self.ledger.db.select_addresses('address', read_only=read_only, accounts=[self], **constraints)
return [r[0] for r in rows] return [r[0] for r in rows]
def get_address_records(self, **constraints): def get_address_records(self, **constraints):
@ -591,11 +591,15 @@ class Account:
} if reserved_subtotals else None } if reserved_subtotals else None
} }
def get_transaction_history(self, **constraints): def get_transaction_history(self, read_only: bool = False, **constraints):
return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints) return self.ledger.get_transaction_history(
read_only=read_only, wallet=self.wallet, accounts=[self], **constraints
)
def get_transaction_history_count(self, **constraints): def get_transaction_history_count(self, read_only: bool = False, **constraints):
return self.ledger.get_transaction_history_count(wallet=self.wallet, accounts=[self], **constraints) return self.ledger.get_transaction_history_count(
read_only=read_only, wallet=self.wallet, accounts=[self], **constraints
)
def get_claims(self, **constraints): def get_claims(self, **constraints):
return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints) return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints)

View file

@ -1,9 +1,12 @@
import os
import logging import logging
import asyncio import asyncio
import sqlite3 import sqlite3
from binascii import hexlify from binascii import hexlify
from dataclasses import dataclass
from contextvars import ContextVar
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
from .bip32 import PubKey from .bip32 import PubKey
@ -15,7 +18,39 @@ log = logging.getLogger(__name__)
sqlite3.enable_callback_tracebacks(True) sqlite3.enable_callback_tracebacks(True)
@dataclass
class ReaderProcessState:
cursor: sqlite3.Cursor
reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context')
def initializer(path):
reader = ReaderProcessState(sqlite3.connect(path).cursor())
reader_context.set(reader)
def run_read_only_fetchall(sql, params):
cursor = reader_context.get().cursor
try:
return cursor.execute(sql, params).fetchall()
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
raise
def run_read_only_fetchone(sql, params):
cursor = reader_context.get().cursor
try:
return cursor.execute(sql, params).fetchone()
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
raise
class AIOSQLite: class AIOSQLite:
reader_executor: ProcessPoolExecutor
def __init__(self): def __init__(self):
# has to be single threaded as there is no mapping of thread:connection # has to be single threaded as there is no mapping of thread:connection
@ -31,6 +66,11 @@ class AIOSQLite:
def _connect_writer(): def _connect_writer():
db.writer_connection = sqlite3.connect(path, *args, **kwargs) db.writer_connection = sqlite3.connect(path, *args, **kwargs)
readers = max(os.cpu_count() - 2, 2)
db.reader_executor = ProcessPoolExecutor(
max_workers=readers, initializer=initializer, initargs=(path, )
)
await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer) await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer)
return db return db
@ -40,6 +80,7 @@ class AIOSQLite:
self._closing = True self._closing = True
await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close) await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
self.writer_executor.shutdown(wait=True) self.writer_executor.shutdown(wait=True)
self.reader_executor.shutdown(wait=True)
self.writer_connection = None self.writer_connection = None
def executemany(self, sql: str, params: Iterable): def executemany(self, sql: str, params: Iterable):
@ -50,12 +91,22 @@ class AIOSQLite:
def executescript(self, script: str) -> Awaitable: def executescript(self, script: str) -> Awaitable:
return self.run(lambda conn: conn.executescript(script)) return self.run(lambda conn: conn.executescript(script))
def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: async def execute_fetchall(self, sql: str, parameters: Iterable = None,
read_only: bool = False) -> Iterable[sqlite3.Row]:
parameters = parameters if parameters is not None else [] parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters).fetchall()) if read_only:
return await asyncio.get_event_loop().run_in_executor(
self.reader_executor, run_read_only_fetchall, sql, parameters
)
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: def execute_fetchone(self, sql: str, parameters: Iterable = None,
read_only: bool = False) -> Awaitable[Iterable[sqlite3.Row]]:
parameters = parameters if parameters is not None else [] parameters = parameters if parameters is not None else []
if read_only:
return asyncio.get_event_loop().run_in_executor(
self.reader_executor, run_read_only_fetchone, sql, parameters
)
return self.run(lambda conn: conn.execute(sql, parameters).fetchone()) return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
@ -488,7 +539,7 @@ class Database(SQLiteMixin):
# 2. update address histories removing deleted TXs # 2. update address histories removing deleted TXs
return True return True
async def select_transactions(self, cols, accounts=None, **constraints): async def select_transactions(self, cols, accounts=None, read_only: bool = False, **constraints):
if not {'txid', 'txid__in'}.intersection(constraints): if not {'txid', 'txid__in'}.intersection(constraints):
assert accounts, "'accounts' argument required when no 'txid' constraint is present" assert accounts, "'accounts' argument required when no 'txid' constraint is present"
where, values = constraints_to_sql({ where, values = constraints_to_sql({
@ -501,7 +552,7 @@ class Database(SQLiteMixin):
""" """
constraints.update(values) constraints.update(values)
return await self.db.execute_fetchall( return await self.db.execute_fetchall(
*query(f"SELECT {cols} FROM tx", **constraints) *query(f"SELECT {cols} FROM tx", **constraints), read_only=read_only
) )
TXO_NOT_MINE = Output(None, None, is_my_account=False) TXO_NOT_MINE = Output(None, None, is_my_account=False)
@ -578,7 +629,7 @@ class Database(SQLiteMixin):
if txs: if txs:
return txs[0] return txs[0]
async def select_txos(self, cols, wallet=None, include_is_received=False, **constraints): async def select_txos(self, cols, wallet=None, include_is_received=False, read_only: bool = False, **constraints):
if include_is_received: if include_is_received:
assert wallet is not None, 'cannot use is_recieved filter without wallet argument' assert wallet is not None, 'cannot use is_recieved filter without wallet argument'
account_in_wallet, values = constraints_to_sql({ account_in_wallet, values = constraints_to_sql({
@ -594,7 +645,7 @@ class Database(SQLiteMixin):
sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)" sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)"
if 'accounts' in constraints: if 'accounts' in constraints:
sql += " JOIN account_address USING (address)" sql += " JOIN account_address USING (address)"
return await self.db.execute_fetchall(*query(sql, **constraints)) return await self.db.execute_fetchall(*query(sql, **constraints), read_only=read_only)
@staticmethod @staticmethod
def constrain_unspent(constraints): def constrain_unspent(constraints):
@ -700,18 +751,18 @@ class Database(SQLiteMixin):
balance = await self.select_txos('SUM(amount)', **constraints) balance = await self.select_txos('SUM(amount)', **constraints)
return balance[0][0] or 0 return balance[0][0] or 0
async def select_addresses(self, cols, **constraints): async def select_addresses(self, cols, read_only: bool = False, **constraints):
return await self.db.execute_fetchall(*query( return await self.db.execute_fetchall(*query(
f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)", f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)",
**constraints **constraints
)) ), read_only=read_only)
async def get_addresses(self, cols=None, **constraints): async def get_addresses(self, cols=None, read_only: bool = False, **constraints):
cols = cols or ( cols = cols or (
'address', 'account', 'chain', 'history', 'used_times', 'address', 'account', 'chain', 'history', 'used_times',
'pubkey', 'chain_code', 'n', 'depth' 'pubkey', 'chain_code', 'n', 'depth'
) )
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols) addresses = rows_to_dict(await self.select_addresses(', '.join(cols), read_only=read_only, **constraints), cols)
if 'pubkey' in cols: if 'pubkey' in cols:
for address in addresses: for address in addresses:
address['pubkey'] = PubKey( address['pubkey'] = PubKey(
@ -720,8 +771,8 @@ class Database(SQLiteMixin):
) )
return addresses return addresses
async def get_address_count(self, cols=None, **constraints): async def get_address_count(self, cols=None, read_only: bool = False, **constraints):
count = await self.select_addresses('count(*)', **constraints) count = await self.select_addresses('count(*)', read_only=read_only, **constraints)
return count[0][0] return count[0][0]
async def get_address(self, **constraints): async def get_address(self, **constraints):
@ -853,7 +904,7 @@ class Database(SQLiteMixin):
" )", (account.public_key.address, ) " )", (account.public_key.address, )
) )
def get_supports_summary(self, account_id): def get_supports_summary(self, account_id, read_only: bool = False):
return self.db.execute_fetchall(f""" return self.db.execute_fetchall(f"""
select txo.amount, exists(select * from txi where txi.txoid=txo.txoid) as spent, select txo.amount, exists(select * from txi where txi.txoid=txo.txoid) as spent,
(txo.txid in (txo.txid in
@ -862,4 +913,4 @@ class Database(SQLiteMixin):
(txo.address in (select address from account_address where account=?)) as to_me, (txo.address in (select address from account_address where account=?)) as to_me,
tx.height tx.height
from txo join tx using (txid) where txo_type={TXO_TYPES['support']} from txo join tx using (txid) where txo_type={TXO_TYPES['support']}
""", (account_id, account_id)) """, (account_id, account_id), read_only=read_only)

View file

@ -7,9 +7,9 @@ from io import StringIO
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from operator import itemgetter from operator import itemgetter
from collections import namedtuple, defaultdict from collections import defaultdict
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict, NamedTuple
import pylru import pylru
from lbry.schema.result import Outputs, INVALID, NOT_FOUND from lbry.schema.result import Outputs, INVALID, NOT_FOUND
@ -53,16 +53,19 @@ class LedgerRegistry(type):
return mcs.ledgers[ledger_id] return mcs.ledgers[ledger_id]
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))): class TransactionEvent(NamedTuple):
pass address: str
tx: Transaction
class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))): class AddressesGeneratedEvent(NamedTuple):
pass address_manager: AddressManager
addresses: List[str]
class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))): class BlockHeightEvent(NamedTuple):
pass height: int
change: int
class TransactionCacheItem: class TransactionCacheItem:
@ -822,8 +825,8 @@ class Ledger(metaclass=LedgerRegistry):
def get_support_count(self, **constraints): def get_support_count(self, **constraints):
return self.db.get_support_count(**constraints) return self.db.get_support_count(**constraints)
async def get_transaction_history(self, **constraints): async def get_transaction_history(self, read_only: bool = False, **constraints):
txs: List[Transaction] = await self.db.get_transactions(**constraints) txs: List[Transaction] = await self.db.get_transactions(read_only=read_only, **constraints)
headers = self.headers headers = self.headers
history = [] history = []
for tx in txs: # pylint: disable=too-many-nested-blocks for tx in txs: # pylint: disable=too-many-nested-blocks
@ -932,8 +935,8 @@ class Ledger(metaclass=LedgerRegistry):
history.append(item) history.append(item)
return history return history
def get_transaction_history_count(self, **constraints): def get_transaction_history_count(self, read_only: bool = False, **constraints):
return self.db.get_transaction_count(**constraints) return self.db.get_transaction_count(read_only=read_only, **constraints)
async def get_detailed_balance(self, accounts, confirmations=0): async def get_detailed_balance(self, accounts, confirmations=0):
result = { result = {