diff --git a/lbry/db/database.py b/lbry/db/database.py index 97a89c6c4..c27f056bf 100644 --- a/lbry/db/database.py +++ b/lbry/db/database.py @@ -1,6 +1,6 @@ import os import asyncio -from typing import List, Optional, Tuple, Iterable +from typing import List, Optional, Tuple, Iterable, TYPE_CHECKING from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor from functools import partial @@ -8,12 +8,15 @@ from sqlalchemy import create_engine, text from lbry.crypto.bip32 import PubKey from lbry.schema.result import Censor -from lbry.blockchain.ledger import Ledger from lbry.blockchain.transaction import Transaction, Output from .constants import TXO_TYPES from . import queries as q +if TYPE_CHECKING: + from lbry.blockchain.ledger import Ledger + + def clean_wallet_account_ids(constraints): wallet = constraints.pop('wallet', None) account = constraints.pop('account', None) @@ -28,12 +31,12 @@ def clean_wallet_account_ids(constraints): constraints['account_ids'] = [account.id for account in accounts] -def add_channel_keys_to_txo_results(accounts: List, txos: Iterable[Output]): +async def add_channel_keys_to_txo_results(accounts: List, txos: Iterable[Output]): sub_channels = set() for txo in txos: if txo.claim.is_channel: for account in accounts: - private_key = account.get_channel_private_key( + private_key = await account.get_channel_private_key( txo.claim.channel.public_key_bytes ) if private_key: @@ -42,17 +45,21 @@ def add_channel_keys_to_txo_results(accounts: List, txos: Iterable[Output]): if txo.channel is not None: sub_channels.add(txo.channel) if sub_channels: - add_channel_keys_to_txo_results(accounts, sub_channels) + await add_channel_keys_to_txo_results(accounts, sub_channels) class Database: - def __init__(self, ledger: Ledger, url: str, multiprocess=False): + def __init__(self, ledger: 'Ledger', url: str, multiprocess=False): self.url = url self.ledger = ledger self.multiprocess = multiprocess self.executor: Optional[Executor] = None + @classmethod + def from_memory(cls, ledger): + return cls(ledger, 'sqlite:///:memory:') + def sync_create(self, name): engine = create_engine(self.url) db = engine.connect() @@ -145,8 +152,8 @@ class Database: async def get_balance(self, **constraints): return await self.run_in_executor(q.get_balance, **constraints) - async def get_supports_summary(self, **constraints): - return await self.run_in_executor(self.get_supports_summary, **constraints) + async def get_report(self, accounts): + return await self.run_in_executor(q.get_report, accounts=accounts) async def get_addresses(self, **constraints) -> Tuple[List[dict], Optional[int]]: addresses, count = await self.run_in_executor(q.get_addresses, **constraints) @@ -193,10 +200,10 @@ class Database: return await self.run_in_executor(q.get_txo_plot, **constraints) async def get_txos(self, **constraints) -> Tuple[List[Output], Optional[int]]: - txos = await self.run_in_executor(q.get_txos, **constraints) + txos, count = await self.run_in_executor(q.get_txos, **constraints) if 'wallet' in constraints: - add_channel_keys_to_txo_results(constraints['wallet'], txos) - return txos + await add_channel_keys_to_txo_results(constraints['wallet'].accounts, txos) + return txos, count async def get_utxos(self, **constraints) -> Tuple[List[Output], Optional[int]]: return await self.get_txos(is_spent=False, **constraints) @@ -207,7 +214,7 @@ class Database: async def get_claims(self, **constraints) -> Tuple[List[Output], Optional[int]]: txos, count = await self.run_in_executor(q.get_claims, **constraints) if 'wallet' in constraints: - add_channel_keys_to_txo_results(constraints['wallet'].accounts, txos) + await add_channel_keys_to_txo_results(constraints['wallet'].accounts, txos) return txos, count async def get_streams(self, **constraints) -> Tuple[List[Output], Optional[int]]: diff --git a/lbry/db/queries.py b/lbry/db/queries.py index 81b2beda5..da44289eb 100644 --- a/lbry/db/queries.py +++ b/lbry/db/queries.py @@ -737,6 +737,14 @@ def get_txo_sum(**constraints): return result[0]['total'] or 0 +def get_balance(**constraints): + return get_txo_sum(is_spent=False, **constraints) + + +def get_report(account_ids): + return + + def get_txo_plot(start_day=None, days_back=0, end_day=None, days_after=None, **constraints): _clean_txo_constraints_for_aggregation(constraints) if start_day is None: @@ -771,13 +779,6 @@ def get_purchases(**constraints) -> Tuple[List[Output], Optional[int]]: return [tx.outputs[0] for tx in txs], count -def get_balance(**constraints): - balance = select_txos( - [func.sum(TXO.c.amount).label('total')], is_spent=False, **constraints - ) - return balance[0]['total'] or 0 - - def select_addresses(cols, **constraints): return ctx().fetchall(query( [AccountAddress, PubkeyAddress],