import os
import asyncio
import tempfile
import multiprocessing as mp
from typing import List, Optional, Iterable, Iterator, TypeVar, Generic, TYPE_CHECKING, Dict
from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from functools import partial

from sqlalchemy import create_engine, text

from lbry.event import EventController
from lbry.crypto.bip32 import PubKey
from lbry.blockchain.transaction import Transaction, Output
from .constants import TXO_TYPES, CLAIM_TYPE_CODES
from .query_context import initialize, ProgressPublisher
from . import queries as q
from . import sync

if TYPE_CHECKING:
    from lbry.blockchain.ledger import Ledger


def clean_wallet_account_ids(constraints):
    """Replace wallet/account objects in ``constraints`` with id lists, in place.

    Pops the ``wallet``, ``account`` and ``accounts`` keys. When a wallet is
    given, ``wallet_account_ids`` is set to the ids of all of its accounts.
    ``account_ids`` is set to the ids of the explicitly requested accounts
    (``accounts``, or ``[account]`` when only ``account`` was passed), falling
    back to every account of the wallet when none were requested.
    """
    wallet = constraints.pop('wallet', None)
    account = constraints.pop('account', None)
    accounts = constraints.pop('accounts', [])
    if account and not accounts:
        accounts = [account]
    if wallet:
        constraints['wallet_account_ids'] = [acct.id for acct in wallet.accounts]
        if not accounts:
            accounts = wallet.accounts
    if accounts:
        constraints['account_ids'] = [acct.id for acct in accounts]


async def add_channel_keys_to_txo_results(accounts: List, txos: Iterable[Output]):
    """Attach channel private keys to channel outputs found in ``txos``.

    For every output whose claim is a channel, each account is asked for the
    private key matching the channel's public key; the first non-empty key is
    stored on ``txo.private_key``. Channels referenced through ``txo.channel``
    are collected and processed recursively in a second pass.
    """
    sub_channels = set()
    for txo in txos:
        if txo.claim.is_channel:
            for account in accounts:
                private_key = await account.get_channel_private_key(
                    txo.claim.channel.public_key_bytes
                )
                if private_key:
                    txo.private_key = private_key
                    break
        if txo.channel is not None:
            sub_channels.add(txo.channel)
    if sub_channels:
        await add_channel_keys_to_txo_results(accounts, sub_channels)


ResultType = TypeVar('ResultType')


class Result(Generic[ResultType]):
    """A page of query rows plus the query's ``total`` and optional ``censor`` info.

    Behaves like a read-only sequence over ``rows`` (indexing, iteration, len).
    """

    __slots__ = 'rows', 'total', 'censor'

    def __init__(self, rows: List[ResultType], total, censor=None):
        self.rows = rows
        self.total = total
        self.censor = censor

    def __getitem__(self, item: int) -> ResultType:
        return self.rows[item]

    def __iter__(self) -> Iterator[ResultType]:
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)

    def __repr__(self):
        return repr(self.rows)


class Database:
    """Async facade that runs query functions from ``queries``/``sync`` on a worker pool.

    Progress messages from workers travel over ``message_queue`` and are
    re-published on the ``on_progress`` stream by a ``ProgressPublisher``.
    """

    def __init__(self, ledger: 'Ledger', processes=-1):
        self.url = ledger.conf.db_url_or_default
        self.ledger = ledger
        self.processes = self._normalize_processes(processes)
        self.executor: Optional[Executor] = None
        self.message_queue = mp.Queue()
        self.stop_event = mp.Event()
        self._on_progress_controller = EventController()
        self.on_progress = self._on_progress_controller.stream
        self.progress_publisher = ProgressPublisher(
            self.message_queue, self._on_progress_controller
        )

    @staticmethod
    def _normalize_processes(processes):
        """Map the ``processes`` argument to a positive worker count.

        ``0`` means "one per CPU", a positive value is used as-is, and anything
        else (e.g. the default ``-1``) means a single worker.
        """
        if processes == 0:
            # os.cpu_count() returns None when the count cannot be determined;
            # fall back to one worker instead of storing None.
            return os.cpu_count() or 1
        if processes > 0:
            return processes
        return 1

    @classmethod
    def temp_sqlite_regtest(cls, lbrycrd_dir=None):
        """Create a regtest Database whose config lives in a fresh temp directory."""
        from lbry import Config, RegTestLedger  # pylint: disable=import-outside-toplevel
        directory = tempfile.mkdtemp()
        conf = Config.with_same_dir(directory)
        if lbrycrd_dir is not None:
            conf.lbrycrd_dir = lbrycrd_dir
        ledger = RegTestLedger(conf)
        return cls(ledger)

    @classmethod
    def temp_sqlite(cls):
        """Create a Database whose config lives in a fresh temp directory."""
        from lbry import Config, Ledger  # pylint: disable=import-outside-toplevel
        conf = Config.with_same_dir(tempfile.mkdtemp())
        return cls(Ledger(conf))

    @classmethod
    def in_memory(cls):
        """Create a Database pointed at ``sqlite:///:memory:``."""
        from lbry import Config, Ledger  # pylint: disable=import-outside-toplevel
        conf = Config.with_same_dir('/dev/null')
        conf.db_url = 'sqlite:///:memory:'
        return cls(Ledger(conf))

    def _execute_outside_transaction(self, sql):
        """Run a single statement on a short-lived connection to ``self.url``.

        A ``COMMIT`` is issued first to end the implicit transaction, presumably
        because servers such as PostgreSQL refuse CREATE/DROP DATABASE inside a
        transaction block. The connection and engine are always released.
        """
        engine = create_engine(self.url)
        try:
            db = engine.connect()
            try:
                db.execute(text("COMMIT"))
                db.execute(text(sql))
            finally:
                db.close()
        finally:
            engine.dispose()

    def sync_create(self, name):
        """Blocking ``CREATE DATABASE name``.

        ``name`` is interpolated into the SQL verbatim (identifiers cannot be
        bound as parameters), so it must come from a trusted source.
        """
        self._execute_outside_transaction(f"CREATE DATABASE {name}")

    async def create(self, name):
        return await asyncio.get_event_loop().run_in_executor(None, self.sync_create, name)

    def sync_drop(self, name):
        """Blocking ``DROP DATABASE IF EXISTS name`` (``name`` must be trusted)."""
        self._execute_outside_transaction(f"DROP DATABASE IF EXISTS {name}")

    async def drop(self, name):
        return await asyncio.get_event_loop().run_in_executor(None, self.sync_drop, name)

    async def open(self):
        """Start progress publishing and the worker pool, then check/create tables.

        Uses a ``ProcessPoolExecutor`` when more than one process is configured,
        otherwise a single-thread ``ThreadPoolExecutor``. Every worker runs
        ``initialize(ledger, message_queue, stop_event)`` first.
        Returns whatever ``q.check_version_and_create_tables`` returns.
        """
        assert self.executor is None, "Database already open."
        self.progress_publisher.start()
        kwargs = {
            "initializer": initialize,
            "initargs": (
                self.ledger,
                self.message_queue, self.stop_event
            )
        }
        if self.processes > 1:
            self.executor = ProcessPoolExecutor(max_workers=self.processes, **kwargs)
        else:
            self.executor = ThreadPoolExecutor(max_workers=1, **kwargs)
        return await self.run_in_executor(q.check_version_and_create_tables)

    async def close(self):
        """Stop progress publishing and shut down the worker pool if it is open."""
        self.progress_publisher.stop()
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

    async def run_in_executor(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` on the worker pool and await its result.

        Wallet/account keyword arguments are first replaced by id lists via
        ``clean_wallet_account_ids`` (presumably so only plain data is sent to
        the workers).
        """
        if kwargs:
            clean_wallet_account_ids(kwargs)
        return await asyncio.get_event_loop().run_in_executor(
            self.executor, partial(func, *args, **kwargs)
        )

    async def fetch_result(self, func, *args, **kwargs) -> Result:
        """Run a query returning ``(rows, total)`` and wrap it in a ``Result``."""
        rows, total = await self.run_in_executor(func, *args, **kwargs)
        return Result(rows, total)

    async def execute(self, sql):
        return await self.run_in_executor(q.execute, sql)

    async def execute_fetchall(self, sql):
        return await self.run_in_executor(q.execute_fetchall, sql)

    async def process_inputs_outputs(self):
        return await self.run_in_executor(sync.process_inputs_outputs)

    async def process_all_things_after_sync(self):
        return await self.run_in_executor(sync.process_all_things_after_sync)

    async def needs_initial_sync(self) -> bool:
        """True when ``get_best_tx_height`` reports ``-1``."""
        return (await self.get_best_tx_height()) == -1

    async def get_best_tx_height(self) -> int:
        return await self.run_in_executor(q.get_best_tx_height)

    async def get_best_block_height_for_file(self, file_number) -> int:
        return await self.run_in_executor(q.get_best_block_height_for_file, file_number)

    async def get_blocks_without_filters(self):
        return await self.run_in_executor(q.get_blocks_without_filters)

    async def get_transactions_without_filters(self):
        return await self.run_in_executor(q.get_transactions_without_filters)

    async def get_block_tx_addresses(self, block_hash=None, tx_hash=None):
        return await self.run_in_executor(q.get_block_tx_addresses, block_hash, tx_hash)

    async def get_block_address_filters(self):
        return await self.run_in_executor(q.get_block_address_filters)

    async def get_transaction_address_filters(self, block_hash):
        return await self.run_in_executor(q.get_transaction_address_filters, block_hash)

    async def insert_block(self, block):
        return await self.run_in_executor(q.insert_block, block)

    async def insert_transaction(self, block_hash, tx):
        return await self.run_in_executor(q.insert_transaction, block_hash, tx)

    async def update_address_used_times(self, addresses):
        return await self.run_in_executor(q.update_address_used_times, addresses)

    async def reserve_outputs(self, txos, is_reserved=True):
        """Set the reserved flag on ``txos``; a no-op returning None when empty."""
        txo_hashes = [txo.hash for txo in txos]
        if not txo_hashes:
            return None
        return await self.run_in_executor(
            q.reserve_outputs, txo_hashes, is_reserved
        )

    async def release_outputs(self, txos):
        return await self.reserve_outputs(txos, is_reserved=False)

    async def release_tx(self, tx):
        """Release the outputs spent by ``tx``'s inputs."""
        return await self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])

    async def release_all_outputs(self, account):
        return await self.run_in_executor(q.release_all_outputs, account.id)

    async def get_balance(self, **constraints):
        return await self.run_in_executor(q.get_balance, **constraints)

    async def get_report(self, accounts):
        return await self.run_in_executor(q.get_report, accounts=accounts)

    async def get_addresses(self, **constraints) -> Result[dict]:
        """Fetch address rows; when key columns are present, fold them into a ``PubKey``.

        The ``pubkey``, ``chain_code``, ``n`` and ``depth`` columns are popped
        from each row and replaced by a single ``pubkey`` entry.
        """
        addresses = await self.fetch_result(q.get_addresses, **constraints)
        if addresses and 'pubkey' in addresses[0]:
            for address in addresses:
                address['pubkey'] = PubKey(
                    self.ledger, bytes(address.pop('pubkey')), bytes(address.pop('chain_code')),
                    address.pop('n'), address.pop('depth')
                )
        return addresses

    async def get_all_addresses(self):
        return await self.run_in_executor(q.get_all_addresses)

    async def get_address(self, **constraints):
        """Return the first matching address row, or None."""
        for address in await self.get_addresses(limit=1, **constraints):
            return address
        return None

    async def add_keys(self, account, chain, pubkeys):
        return await self.run_in_executor(q.add_keys, account, chain, pubkeys)

    async def get_transactions(self, **constraints) -> Result[Transaction]:
        return await self.fetch_result(q.get_transactions, **constraints)

    async def get_transaction(self, **constraints) -> Optional[Transaction]:
        """Return the first matching transaction, or None."""
        txs = await self.get_transactions(limit=1, **constraints)
        if txs:
            return txs[0]
        return None

    async def get_purchases(self, **constraints) -> Result[Output]:
        return await self.fetch_result(q.get_purchases, **constraints)

    async def search_claims(self, **constraints) -> Result[Output]:
        """Run ``q.search_claims`` and wrap ``(claims, total, censor)`` in a ``Result``."""
        # NOTE(review): argument validation is disabled; SEARCH_PARAMS is not
        # imported in this module.
        #assert set(constraints).issubset(SEARCH_PARAMS), \
        #    f"Search query contains invalid arguments: {set(constraints).difference(SEARCH_PARAMS)}"
        claims, total, censor = await self.run_in_executor(q.search_claims, **constraints)
        return Result(claims, total, censor)

    async def search_supports(self, **constraints) -> Result[Output]:
        return await self.fetch_result(q.search_supports, **constraints)

    async def resolve(self, *urls) -> Dict[str, Output]:
        return await self.run_in_executor(q.resolve, *urls)

    async def get_txo_sum(self, **constraints) -> int:
        return await self.run_in_executor(q.get_txo_sum, **constraints)

    async def get_txo_plot(self, **constraints) -> List[dict]:
        return await self.run_in_executor(q.get_txo_plot, **constraints)

    async def get_txos(self, **constraints) -> Result[Output]:
        """Fetch outputs; when a ``wallet`` is given, attach its channel private keys."""
        txos = await self.fetch_result(q.get_txos, **constraints)
        if 'wallet' in constraints:
            await add_channel_keys_to_txo_results(constraints['wallet'].accounts, txos)
        return txos

    async def get_utxos(self, **constraints) -> Result[Output]:
        return await self.get_txos(is_spent=False, **constraints)

    async def get_supports(self, **constraints) -> Result[Output]:
        return await self.get_utxos(txo_type=TXO_TYPES['support'], **constraints)

    async def get_claims(self, **constraints) -> Result[Output]:
        """Like ``get_txos`` but limited to claim types unless ``txo_type`` is given."""
        if 'txo_type' not in constraints:
            constraints['txo_type__in'] = CLAIM_TYPE_CODES
        return await self.get_txos(**constraints)

    async def get_streams(self, **constraints) -> Result[Output]:
        return await self.get_claims(txo_type=TXO_TYPES['stream'], **constraints)

    async def get_channels(self, **constraints) -> Result[Output]:
        return await self.get_claims(txo_type=TXO_TYPES['channel'], **constraints)

    async def get_collections(self, **constraints) -> Result[Output]:
        return await self.get_claims(txo_type=TXO_TYPES['collection'], **constraints)