# lbry-sdk/lbry/file/source_manager.py
# (web-view listing: 139 lines, 4.8 KiB, Python)

import os
import asyncio
import logging
import typing
from typing import Optional
from lbry.file.source import ManagedDownloadSource
if typing.TYPE_CHECKING:
from lbry.conf import Config
from lbry.extras.daemon.analytics import AnalyticsManager
from lbry.extras.daemon.storage import SQLiteStorage
log = logging.getLogger(__name__)
# 2020-01-29 02:37:52 +01:00 (timestamp left over from the web view)
# Short operator name -> two-argument predicate. The predicate is called as
# predicate(attribute_value, requested_value) when filtering sources.
COMPARISON_OPERATORS = dict(
    eq=lambda a, b: a == b,  # equal
    ne=lambda a, b: a != b,  # not equal
    g=lambda a, b: a > b,    # strictly greater
    l=lambda a, b: a < b,    # strictly less
    ge=lambda a, b: a >= b,  # greater or equal
    le=lambda a, b: a <= b,  # less or equal
)
class SourceManager:
    """
    In-memory registry of ManagedDownloadSource objects, keyed by their identifier.

    initialize_from_database() and create() raise NotImplementedError in this
    class; a usable manager has to provide them.
    """

    # attribute names accepted by get_filtered() for sorting and filtering
    filter_fields = {
        'rowid',
        'status',
        'file_name',
        'added_on',
        'download_path',
        'claim_name',
        'claim_height',
        'claim_id',
        'outpoint',
        'txid',
        'nout',
        'channel_claim_id',
        'channel_name',
        'completed'
    }

    # name of a list-valued filter -> source attribute tested for membership in that list
    set_filter_fields = {
        "claim_ids": "claim_id",
        "channel_claim_ids": "channel_claim_id",
        "outpoints": "outpoint"
    }

    source_class = ManagedDownloadSource

    def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', storage: 'SQLiteStorage',
                 analytics_manager: Optional['AnalyticsManager'] = None):
        self.loop = loop
        self.config = config
        self.storage = storage
        self.analytics_manager = analytics_manager
        self._sources: typing.Dict[str, ManagedDownloadSource] = {}
        # set by start() after initialize_from_database() returns, cleared by stop()
        self.started = asyncio.Event()

    def add(self, source: ManagedDownloadSource):
        """
        Register a source under its identifier, replacing any source already stored under it
        """
        self._sources[source.identifier] = source

    def remove(self, source: ManagedDownloadSource):
        """
        Unregister a source and stop its tasks; does nothing if its identifier is not registered
        """
        if source.identifier not in self._sources:
            return
        self._sources.pop(source.identifier)
        source.stop_tasks()

    async def initialize_from_database(self):
        """
        Hook awaited by start(); not implemented in this class
        """
        raise NotImplementedError()

    async def start(self):
        """
        Run initialize_from_database() and then set the `started` event
        """
        await self.initialize_from_database()
        self.started.set()

    def stop(self):
        """
        Unregister every source, stopping its tasks, and clear the `started` event
        """
        while self._sources:
            _, source = self._sources.popitem()
            source.stop_tasks()
        self.started.clear()

    async def create(self, file_path: str, key: Optional[bytes] = None,
                     iv_generator: Optional[typing.Generator[bytes, None, None]] = None) -> ManagedDownloadSource:
        """
        Not implemented in this class
        """
        raise NotImplementedError()

    async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
        """
        Unregister a source and optionally remove its output file

        :param source: the source to remove from the manager
        :param delete_file: if true and source.output_file_exists, delete source.full_path from disk
        """
        self.remove(source)
        if delete_file and source.output_file_exists:
            os.remove(source.full_path)

    def get_filtered(self, sort_by: Optional[str] = None, reverse: Optional[bool] = False,
                     comparison: Optional[str] = None, **search_by) -> typing.List[ManagedDownloadSource]:
        """
        Get a list of filtered and sorted ManagedDownloadSource objects

        :param sort_by: field to sort by, must be one of `filter_fields`; sources whose value
                        for the field is None are placed first
        :param reverse: reverse the sorted result
        :param comparison: key of COMPARISON_OPERATORS used for filtering, 'eq' when omitted
        :param search_by: fields and values to filter by; a list given for claim_id, outpoint or
                          channel_claim_id matches sources whose attribute is contained in the
                          list (whatever `comparison` is), 'full_status' is ignored
        :return: the matching sources, in registration order unless sort_by is given
        :raises ValueError: for an unknown sort field, comparison or search field
        """
        if sort_by and sort_by not in self.filter_fields:
            raise ValueError(f"'{sort_by}' is not a valid field to sort by")
        if comparison and comparison not in COMPARISON_OPERATORS:
            raise ValueError(f"'{comparison}' is not a valid comparison")
        # 'full_status' is accepted but never used as a filter
        search_by.pop('full_status', None)

        for search in search_by:
            if search not in self.filter_fields:
                raise ValueError(f"'{search}' is not a valid search operation")

        # list-valued filters are moved out of search_by and handled as membership tests
        compare_sets = {}
        if isinstance(search_by.get('claim_id'), list):
            compare_sets['claim_ids'] = search_by.pop('claim_id')
        if isinstance(search_by.get('outpoint'), list):
            compare_sets['outpoints'] = search_by.pop('outpoint')
        if isinstance(search_by.get('channel_claim_id'), list):
            compare_sets['channel_claim_ids'] = search_by.pop('channel_claim_id')

        if search_by or compare_sets:
            comparison = comparison or 'eq'
            compare = COMPARISON_OPERATORS[comparison]
            streams = []
            for stream in self._sources.values():
                if compare_sets and not all(
                        getattr(stream, self.set_filter_fields[set_search]) in val
                        for set_search, val in compare_sets.items()):
                    continue
                if search_by and not all(
                        compare(getattr(stream, search), val)
                        for search, val in search_by.items()):
                    continue
                streams.append(stream)
        else:
            streams = list(self._sources.values())

        if sort_by:
            def sort_key(source):
                # Substituting "" for missing values (the previous behaviour) raised TypeError
                # on numeric and boolean fields; ordering on "is not None" first keeps None
                # entries at the front without ever comparing None to a real value.
                value = getattr(source, sort_by)
                return value is not None, value

            streams.sort(key=sort_key)
            # NOTE(review): nesting reconstructed from a listing without indentation - confirm
            # that `reverse` is only meant to apply when `sort_by` is given.
            if reverse:
                streams.reverse()
        return streams