Merge pull request #2095 from lbryio/add-content-fee-and-saved-file-to-db

add saved_file and content_fee columns to file table
commit 8620a0e424
Jack Robison, 2019-05-08 20:37:56 -04:00 (committed by GitHub)
11 changed files with 279 additions and 93 deletions

View file

@@ -24,9 +24,12 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.blob: typing.Optional['AbstractBlob'] = None
         self._blob_bytes_received = 0
-        self._response_fut: asyncio.Future = None
+        self._response_fut: typing.Optional[asyncio.Future] = None
         self.buf = b''
+        # this is here to handle the race when the downloader is closed right as response_fut gets a result
+        self.closed = asyncio.Event(loop=self.loop)

     def data_received(self, data: bytes):
         log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
                   self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
@@ -90,6 +93,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         :return: download success (bool), keep connection (bool)
         """
         request = BlobRequest.make_request_for_blob_hash(self.blob.blob_hash)
+        blob_hash = self.blob.blob_hash
         try:
             msg = request.serialize()
             log.debug("send request to %s:%i -> %s", self.peer_address, self.peer_port, msg.decode())
@@ -98,6 +102,10 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
             availability_response = response.get_availability_response()
             price_response = response.get_price_response()
             blob_response = response.get_blob_response()
+            if self.closed.is_set():
+                msg = f"cancelled blob request for {blob_hash} immediately after we got a response"
+                log.warning(msg)
+                raise asyncio.CancelledError(msg)
             if (not blob_response or blob_response.error) and\
                     (not availability_response or not availability_response.available_blobs):
                 log.warning("%s not in availability response from %s:%i", self.blob.blob_hash, self.peer_address,
@@ -136,6 +144,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
             return self._blob_bytes_received, self.close()

     def close(self):
+        self.closed.set()
         if self._response_fut and not self._response_fut.done():
             self._response_fut.cancel()
         if self.writer and not self.writer.closed():
@@ -149,6 +158,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.buf = b''

     async def download_blob(self, blob: 'AbstractBlob') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
+        self.closed.clear()
         blob_hash = blob.blob_hash
         if blob.get_is_verified() or not blob.is_writeable():
             return 0, self.transport
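The new closed event plugs a narrow race: close() can run in the window between _response_fut receiving its result and the awaiting coroutine resuming, so the awaiter re-checks the event after the await and raises CancelledError itself. A minimal sketch of the pattern, with illustrative names rather than the lbrynet API:

import asyncio

class RaceGuardedClient:
    def __init__(self):
        self.closed = asyncio.Event()
        self.response_fut = asyncio.get_event_loop().create_future()

    def close(self):
        self.closed.set()  # set before cancelling so awaiters can tell why they woke up
        if not self.response_fut.done():
            self.response_fut.cancel()

    async def request(self):
        response = await self.response_fut
        if self.closed.is_set():
            # the future resolved, but close() already ran: treat it as cancelled
            raise asyncio.CancelledError("closed right after the response arrived")
        return response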

View file

@@ -63,7 +63,7 @@ class DatabaseComponent(Component):
     @staticmethod
     def get_current_db_revision():
-        return 10
+        return 11

     @property
     def revision_filename(self):

View file

@@ -24,6 +24,8 @@ def migrate_db(conf, start, end):
         from .migrate8to9 import do_migration
     elif current == 9:
         from .migrate9to10 import do_migration
+    elif current == 10:
+        from .migrate10to11 import do_migration
     else:
         raise Exception("DB migration of version {} to {} is not available".format(current,
                                                                                    current+1))

View file

@@ -0,0 +1,54 @@
+import sqlite3
+import os
+import binascii
+
+
+def do_migration(conf):
+    db_path = os.path.join(conf.data_dir, "lbrynet.sqlite")
+    connection = sqlite3.connect(db_path)
+    cursor = connection.cursor()
+
+    current_columns = []
+    for col_info in cursor.execute("pragma table_info('file');").fetchall():
+        current_columns.append(col_info[1])
+    if 'content_fee' in current_columns or 'saved_file' in current_columns:
+        connection.close()
+        print("already migrated")
+        return
+
+    cursor.execute(
+        "pragma foreign_keys=off;"
+    )
+
+    cursor.execute("""
+        create table if not exists new_file (
+            stream_hash text primary key not null references stream,
+            file_name text,
+            download_directory text,
+            blob_data_rate real not null,
+            status text not null,
+            saved_file integer not null,
+            content_fee text
+        );
+    """)
+    for (stream_hash, file_name, download_dir, data_rate, status) in cursor.execute("select * from file").fetchall():
+        saved_file = 0
+        if download_dir != '{stream}' and file_name != '{stream}':
+            try:
+                if os.path.isfile(os.path.join(binascii.unhexlify(download_dir).decode(),
+                                               binascii.unhexlify(file_name).decode())):
+                    saved_file = 1
+                else:
+                    download_dir, file_name = None, None
+            except Exception:
+                download_dir, file_name = None, None
+        else:
+            download_dir, file_name = None, None
+        cursor.execute(
+            "insert into new_file values (?, ?, ?, ?, ?, ?, NULL)",
+            (stream_hash, file_name, download_dir, data_rate, status, saved_file)
+        )
+
+    cursor.execute("drop table file")
+    cursor.execute("alter table new_file rename to file")
+
+    connection.commit()
+    connection.close()
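ALTER TABLE in SQLite cannot relax file_name/download_directory to nullable or backfill the new NOT NULL saved_file column per row, so the migration uses the rebuild pattern: create new_file, copy rows with computed values, drop the old table, rename. A throwaway harness for exercising do_migration against a scratch database (hypothetical helper, not part of the PR):

import os
import sqlite3
import tempfile
import types

def smoke_test_migration(do_migration):
    # build a revision-10 style schema in a scratch data dir
    data_dir = tempfile.mkdtemp()
    db = sqlite3.connect(os.path.join(data_dir, "lbrynet.sqlite"))
    db.executescript("""
        create table stream (stream_hash text primary key not null);
        create table file (
            stream_hash text primary key not null references stream,
            file_name text not null,
            download_directory text not null,
            blob_data_rate real not null,
            status text not null
        );
    """)
    db.commit()
    db.close()
    do_migration(types.SimpleNamespace(data_dir=data_dir))
    db = sqlite3.connect(os.path.join(data_dir, "lbrynet.sqlite"))
    columns = [row[1] for row in db.execute("pragma table_info('file');")]
    assert 'saved_file' in columns and 'content_fee' in columns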

View file

@@ -8,6 +8,7 @@ import time
 from torba.client.basedatabase import SQLiteMixin
 from lbrynet.conf import Config
 from lbrynet.wallet.dewies import dewies_to_lbc, lbc_to_dewies
+from lbrynet.wallet.transaction import Transaction
 from lbrynet.schema.claim import Claim
 from lbrynet.dht.constants import data_expiration
 from lbrynet.blob.blob_info import BlobInfo
@@ -114,8 +115,8 @@ def _batched_select(transaction, query, parameters, batch_size=900):
 def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
     files = []
     signed_claims = {}
-    for (rowid, stream_hash, file_name, download_dir, data_rate, status, _, sd_hash, stream_key,
-         stream_name, suggested_file_name, *claim_args) in _batched_select(
+    for (rowid, stream_hash, file_name, download_dir, data_rate, status, saved_file, raw_content_fee, _,
+         sd_hash, stream_key, stream_name, suggested_file_name, *claim_args) in _batched_select(
             transaction, "select file.rowid, file.*, stream.*, c.* "
                          "from file inner join stream on file.stream_hash=stream.stream_hash "
                          "inner join content_claim cc on file.stream_hash=cc.stream_hash "
@@ -141,7 +142,11 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
                 "key": stream_key,
                 "stream_name": stream_name,  # hex
                 "suggested_file_name": suggested_file_name,  # hex
-                "claim": claim
+                "claim": claim,
+                "saved_file": bool(saved_file),
+                "content_fee": None if not raw_content_fee else Transaction(
+                    binascii.unhexlify(raw_content_fee)
+                )
             }
         )
     for claim_name, claim_id in _batched_select(
@@ -188,16 +193,20 @@ def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor'):
 def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typing.Optional[str],
-               download_directory: typing.Optional[str], data_payment_rate: float, status: str) -> int:
+               download_directory: typing.Optional[str], data_payment_rate: float, status: str,
+               content_fee: typing.Optional[Transaction]) -> int:
     if not file_name and not download_directory:
-        encoded_file_name, encoded_download_dir = "{stream}", "{stream}"
+        encoded_file_name, encoded_download_dir = None, None
     else:
         encoded_file_name = binascii.hexlify(file_name.encode()).decode()
         encoded_download_dir = binascii.hexlify(download_directory.encode()).decode()
     transaction.execute(
-        "insert or replace into file values (?, ?, ?, ?, ?)",
-        (stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status)
+        "insert or replace into file values (?, ?, ?, ?, ?, ?, ?)",
+        (stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status,
+         1 if (file_name and download_directory and os.path.isfile(os.path.join(download_directory, file_name))) else 0,
+         None if not content_fee else binascii.hexlify(content_fee.raw).decode())
     )
     return transaction.execute("select rowid from file where stream_hash=?", (stream_hash, )).fetchone()[0]
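The file table stores file_name and download_directory hex-encoded, with NULL (rather than the old '{stream}' sentinel) meaning the stream was never saved to disk. A quick round trip of the convention, mirroring store_file here and path_or_none in the stream manager below:

import binascii

def encode_path(p):
    return None if p is None else binascii.hexlify(p.encode()).decode()

def decode_path(p):
    return None if not p else binascii.unhexlify(p).decode()

assert decode_path(encode_path("video.mp4")) == "video.mp4"
assert decode_path(encode_path(None)) is None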
@@ -246,10 +255,12 @@ class SQLiteStorage(SQLiteMixin):
             create table if not exists file (
                 stream_hash text primary key not null references stream,
-                file_name text not null,
-                download_directory text not null,
+                file_name text,
+                download_directory text,
                 blob_data_rate real not null,
-                status text not null
+                status text not null,
+                saved_file integer not null,
+                content_fee text
             );

             create table if not exists content_claim (
@@ -430,7 +441,7 @@ class SQLiteStorage(SQLiteMixin):
     def set_files_as_streaming(self, stream_hashes: typing.List[str]):
         def _set_streaming(transaction: sqlite3.Connection):
             transaction.executemany(
-                "update file set file_name='{stream}', download_directory='{stream}' where stream_hash=?",
+                "update file set file_name=null, download_directory=null where stream_hash=?",
                 [(stream_hash, ) for stream_hash in stream_hashes]
             )
@@ -509,16 +520,42 @@ class SQLiteStorage(SQLiteMixin):

     # # # # # # # # # file stuff # # # # # # # # #

-    def save_downloaded_file(self, stream_hash, file_name, download_directory,
-                             data_payment_rate) -> typing.Awaitable[int]:
+    def save_downloaded_file(self, stream_hash: str, file_name: typing.Optional[str],
+                             download_directory: typing.Optional[str], data_payment_rate: float,
+                             content_fee: typing.Optional[Transaction] = None) -> typing.Awaitable[int]:
         return self.save_published_file(
-            stream_hash, file_name, download_directory, data_payment_rate, status="running"
+            stream_hash, file_name, download_directory, data_payment_rate, status="running",
+            content_fee=content_fee
         )

     def save_published_file(self, stream_hash: str, file_name: typing.Optional[str],
                             download_directory: typing.Optional[str], data_payment_rate: float,
-                            status="finished") -> typing.Awaitable[int]:
-        return self.db.run(store_file, stream_hash, file_name, download_directory, data_payment_rate, status)
+                            status: str = "finished",
+                            content_fee: typing.Optional[Transaction] = None) -> typing.Awaitable[int]:
+        return self.db.run(store_file, stream_hash, file_name, download_directory, data_payment_rate, status,
+                           content_fee)
+
+    async def update_manually_removed_files_since_last_run(self):
+        """
+        Update files that have been removed from the downloads directory since the last run
+        """
+        def update_manually_removed_files(transaction: sqlite3.Connection):
+            removed = []
+            for (stream_hash, download_directory, file_name) in transaction.execute(
+                    "select stream_hash, download_directory, file_name from file where saved_file=1"
+            ).fetchall():
+                if download_directory and file_name and os.path.isfile(
+                        os.path.join(binascii.unhexlify(download_directory.encode()).decode(),
+                                     binascii.unhexlify(file_name.encode()).decode())):
+                    continue
+                else:
+                    removed.append((stream_hash,))
+            if removed:
+                transaction.executemany(
+                    "update file set file_name=null, download_directory=null, saved_file=0 where stream_hash=?",
+                    removed
+                )
+        return await self.db.run(update_manually_removed_files)

     def get_all_lbry_files(self) -> typing.Awaitable[typing.List[typing.Dict]]:
         return self.db.run(get_all_lbry_files)
@@ -530,7 +567,7 @@ class SQLiteStorage(SQLiteMixin):
     async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str],
                                                      file_name: typing.Optional[str]):
         if not file_name or not download_dir:
-            encoded_file_name, encoded_download_dir = "{stream}", "{stream}"
+            encoded_file_name, encoded_download_dir = None, None
         else:
             encoded_file_name = binascii.hexlify(file_name.encode()).decode()
             encoded_download_dir = binascii.hexlify(download_dir.encode()).decode()
@@ -538,18 +575,34 @@ class SQLiteStorage(SQLiteMixin):
             encoded_download_dir, encoded_file_name, stream_hash,
         ))

+    async def save_content_fee(self, stream_hash: str, content_fee: Transaction):
+        return await self.db.execute("update file set content_fee=? where stream_hash=?", (
+            binascii.hexlify(content_fee.raw), stream_hash,
+        ))
+
+    async def set_saved_file(self, stream_hash: str):
+        return await self.db.execute("update file set saved_file=1 where stream_hash=?", (
+            stream_hash,
+        ))
+
+    async def clear_saved_file(self, stream_hash: str):
+        return await self.db.execute("update file set saved_file=0 where stream_hash=?", (
+            stream_hash,
+        ))
+
-    async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile']],
+    async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile',
+                                                                                  typing.Optional[Transaction]]],
                               download_directory: str):
         def _recover(transaction: sqlite3.Connection):
-            stream_hashes = [d.stream_hash for d, s in descriptors_and_sds]
-            for descriptor, sd_blob in descriptors_and_sds:
+            stream_hashes = [x[0].stream_hash for x in descriptors_and_sds]
+            for descriptor, sd_blob, content_fee in descriptors_and_sds:
                 content_claim = transaction.execute(
                     "select * from content_claim where stream_hash=?", (descriptor.stream_hash, )
                 ).fetchone()
                 delete_stream(transaction, descriptor)  # this will also delete the content claim
                 store_stream(transaction, sd_blob, descriptor)
                 store_file(transaction, descriptor.stream_hash, os.path.basename(descriptor.suggested_file_name),
-                           download_directory, 0.0, 'stopped')
+                           download_directory, 0.0, 'stopped', content_fee=content_fee)
                 if content_claim:
                     transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
             transaction.executemany(
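Content fees round-trip through the new column as hex: save_content_fee writes binascii.hexlify(content_fee.raw) and get_all_lbry_files rebuilds the object with Transaction(binascii.unhexlify(raw_content_fee)). A sketch of the invariant using a stand-in for the wallet's Transaction class:

import binascii

class FakeTransaction:  # stand-in for lbrynet.wallet.transaction.Transaction
    def __init__(self, raw: bytes):
        self.raw = raw

stored = binascii.hexlify(FakeTransaction(b'\x01\x00\x00\x00').raw).decode()
restored = FakeTransaction(binascii.unhexlify(stored))
assert restored.raw == b'\x01\x00\x00\x00'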

View file

@@ -215,8 +215,6 @@ class Stream(BaseClaim):
         if 'sd_hash' in kwargs:
             self.source.sd_hash = kwargs.pop('sd_hash')
-        if 'file_size' in kwargs:
-            self.source.size = kwargs.pop('file_size')
         if 'file_name' in kwargs:
             self.source.name = kwargs.pop('file_name')
         if 'file_hash' in kwargs:
@@ -230,6 +228,9 @@ class Stream(BaseClaim):
         elif self.source.media_type:
             stream_type = guess_stream_type(self.source.media_type)

+        if 'file_size' in kwargs:
+            self.source.size = kwargs.pop('file_size')
+
         if stream_type in ('image', 'video', 'audio'):
             media = getattr(self, stream_type)
             media_args = {'file_metadata': None}
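Moving the file_size handling below the stream-type detection means an explicit file_size kwarg is applied on every update, which the new RangeRequests test below depends on when it passes file_size through stream_create. A hedged usage sketch, assuming this kwargs handling lives in Stream.update() as the pop() pattern suggests:

from lbrynet.schema.claim import Claim

claim = Claim()
claim.stream.update(file_size=8388590)  # set the size explicitly, no file on disk needed
assert int(claim.stream.source.size) == 8388590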

View file

@@ -198,21 +198,31 @@ class ManagedStream:
         return guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0]

     def as_dict(self) -> typing.Dict:
-        if not self.written_bytes and self.output_file_exists:
-            written_bytes = os.stat(self.full_path).st_size
+        full_path = self.full_path
+        file_name = self.file_name
+        download_directory = self.download_directory
+        if self.full_path and self.output_file_exists:
+            if self.written_bytes:
+                written_bytes = self.written_bytes
+            else:
+                written_bytes = os.stat(self.full_path).st_size
         else:
-            written_bytes = self.written_bytes
+            full_path = None
+            file_name = None
+            download_directory = None
+            written_bytes = None
         return {
-            'completed': self.output_file_exists and self.status in ('stopped', 'finished'),
-            'file_name': self.file_name,
-            'download_directory': self.download_directory,
+            'completed': (self.output_file_exists and self.status in ('stopped', 'finished')) or all(
+                self.blob_manager.is_blob_verified(b.blob_hash) for b in self.descriptor.blobs[:-1]),
+            'file_name': file_name,
+            'download_directory': download_directory,
             'points_paid': 0.0,
             'stopped': not self.running,
             'stream_hash': self.stream_hash,
             'stream_name': self.descriptor.stream_name,
             'suggested_file_name': self.descriptor.suggested_file_name,
             'sd_hash': self.descriptor.sd_hash,
-            'download_path': self.full_path,
+            'download_path': full_path,
             'mime_type': self.mime_type,
             'key': self.descriptor.key,
             'total_bytes_lower_bound': self.descriptor.lower_bound_decrypted_length(),
@@ -231,7 +241,7 @@ class ManagedStream:
             'channel_claim_id': self.channel_claim_id,
             'channel_name': self.channel_name,
             'claim_name': self.claim_name,
-            'content_fee': self.content_fee  # TODO: this isn't in the database
+            'content_fee': self.content_fee
         }

     @classmethod
@@ -328,6 +338,11 @@ class ManagedStream:
         if not self.streaming_responses:
             self.streaming.clear()

+    @staticmethod
+    def _write_decrypted_blob(handle: typing.IO, data: bytes):
+        handle.write(data)
+        handle.flush()
+
     async def _save_file(self, output_path: str):
         log.info("save file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id, self.sd_hash[:6],
                  output_path)
@@ -338,8 +353,7 @@ class ManagedStream:
             with open(output_path, 'wb') as file_write_handle:
                 async for blob_info, decrypted in self._aiter_read_stream(connection_id=1):
                     log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
-                    file_write_handle.write(decrypted)
-                    file_write_handle.flush()
+                    await self.loop.run_in_executor(None, self._write_decrypted_blob, file_write_handle, decrypted)
                     self.written_bytes += len(decrypted)
                     if not self.started_writing.is_set():
                         self.started_writing.set()
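The write-and-flush now runs on the default thread-pool executor so a slow disk cannot stall the event loop between blobs. The pattern in isolation, a sketch with illustrative names:

import asyncio

async def write_off_loop(loop: asyncio.AbstractEventLoop, handle, data: bytes):
    def _write():
        handle.write(data)
        handle.flush()
    await loop.run_in_executor(None, _write)  # blocking I/O off the loop thread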
@@ -351,6 +365,7 @@ class ManagedStream:
             self.finished_writing.set()
             log.info("finished saving file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id,
                      self.sd_hash[:6], self.full_path)
+            await self.blob_manager.storage.set_saved_file(self.stream_hash)
         except Exception as err:
             if os.path.isfile(output_path):
                 log.warning("removing incomplete download %s for %s", output_path, self.sd_hash)
@@ -376,7 +391,7 @@ class ManagedStream:
                 os.mkdir(self.download_directory)
             self._file_name = await get_next_available_file_name(
                 self.loop, self.download_directory,
-                file_name or self._file_name or self.descriptor.suggested_file_name
+                file_name or self.descriptor.suggested_file_name
             )
             await self.blob_manager.storage.change_file_download_dir_and_file_name(
                 self.stream_hash, self.download_directory, self.file_name
@@ -461,8 +476,18 @@ class ManagedStream:
             get_range = get_range.split('=')[1]
             start, end = get_range.split('-')
             size = 0
             for blob in self.descriptor.blobs[:-1]:
                 size += blob.length - 1
+            if self.stream_claim_info and self.stream_claim_info.claim.stream.source.size:
+                size_from_claim = int(self.stream_claim_info.claim.stream.source.size)
+                if not size_from_claim <= size <= size_from_claim + 16:
+                    raise ValueError("claim contains implausible stream size")
+                log.debug("using stream size from claim")
+                size = size_from_claim
+            elif self.stream_claim_info:
+                log.debug("estimating stream size")
+
             start = int(start)
             end = int(end) if end else size - 1
             skip_blobs = start // 2097150
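The plausibility window works because summing blob.length - 1 over the data blobs can only overstate the plaintext size by the final block's padding. A worked example matching test_range_requests_no_padding_size_from_claim below (MAX_BLOB_SIZE assumed to be the 2 MiB encrypted blob cap, consistent with the test's expected ranges):

MAX_BLOB_SIZE = 2097152

estimated = 4 * (MAX_BLOB_SIZE - 1)     # 8388604, sum of blob.length - 1 over 4 full blobs
claimed = (MAX_BLOB_SIZE - 1) * 4 - 14  # 8388590, the file_size carried in the claim
assert claimed <= estimated <= claimed + 16  # inside the window, so the claim size wins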

View file

@@ -21,6 +21,7 @@ if typing.TYPE_CHECKING:
     from lbrynet.extras.daemon.analytics import AnalyticsManager
     from lbrynet.extras.daemon.storage import SQLiteStorage, StoredStreamClaim
     from lbrynet.wallet import LbryWalletManager
+    from lbrynet.wallet.transaction import Transaction
     from lbrynet.extras.daemon.exchange_rate_manager import ExchangeRateManager

 log = logging.getLogger(__name__)
@@ -55,7 +56,9 @@ comparison_operators = {

 def path_or_none(p) -> typing.Optional[str]:
-    return None if p == '{stream}' else binascii.unhexlify(p).decode()
+    if not p:
+        return
+    return binascii.unhexlify(p).decode()


 class StreamManager:
@@ -70,8 +73,8 @@ class StreamManager:
         self.node = node
         self.analytics_manager = analytics_manager
         self.streams: typing.Dict[str, ManagedStream] = {}
-        self.resume_downloading_task: asyncio.Task = None
-        self.re_reflect_task: asyncio.Task = None
+        self.resume_saving_task: typing.Optional[asyncio.Task] = None
+        self.re_reflect_task: typing.Optional[asyncio.Task] = None
         self.update_stream_finished_futs: typing.List[asyncio.Future] = []
         self.running_reflector_uploads: typing.List[asyncio.Task] = []
         self.started = asyncio.Event(loop=self.loop)
@@ -84,7 +87,8 @@ class StreamManager:
         to_restore = []

         async def recover_stream(sd_hash: str, stream_hash: str, stream_name: str,
-                                 suggested_file_name: str, key: str) -> typing.Optional[StreamDescriptor]:
+                                 suggested_file_name: str, key: str,
+                                 content_fee: typing.Optional['Transaction']) -> typing.Optional[StreamDescriptor]:
             sd_blob = self.blob_manager.get_blob(sd_hash)
             blobs = await self.storage.get_blobs_for_stream(stream_hash)
             descriptor = await StreamDescriptor.recover(
@@ -92,12 +96,13 @@ class StreamManager:
             )
             if not descriptor:
                 return
-            to_restore.append((descriptor, sd_blob))
+            to_restore.append((descriptor, sd_blob, content_fee))

         await asyncio.gather(*[
             recover_stream(
                 file_info['sd_hash'], file_info['stream_hash'], binascii.unhexlify(file_info['stream_name']).decode(),
-                binascii.unhexlify(file_info['suggested_file_name']).decode(), file_info['key']
+                binascii.unhexlify(file_info['suggested_file_name']).decode(), file_info['key'],
+                file_info['content_fee']
             ) for file_info in file_infos
         ])
@@ -109,7 +114,7 @@ class StreamManager:
     async def add_stream(self, rowid: int, sd_hash: str, file_name: typing.Optional[str],
                          download_directory: typing.Optional[str], status: str,
-                         claim: typing.Optional['StoredStreamClaim']):
+                         claim: typing.Optional['StoredStreamClaim'], content_fee: typing.Optional['Transaction']):
         try:
             descriptor = await self.blob_manager.get_stream_descriptor(sd_hash)
         except InvalidStreamDescriptorError as err:
@@ -117,16 +122,18 @@ class StreamManager:
             return
         stream = ManagedStream(
             self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status,
-            claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager
+            claim, content_fee=content_fee, rowid=rowid, descriptor=descriptor,
+            analytics_manager=self.analytics_manager
         )
         self.streams[sd_hash] = stream
+        self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)

-    async def load_streams_from_database(self):
+    async def load_and_resume_streams_from_database(self):
         to_recover = []
         to_start = []
-        # this will set streams marked as finished and are missing blobs as being stopped
-        # await self.storage.sync_files_to_blobs()
+
+        await self.storage.update_manually_removed_files_since_last_run()
         for file_info in await self.storage.get_all_lbry_files():
             # if the sd blob is not verified, try to reconstruct it from the database
             # this could either be because the blob files were deleted manually or save_blobs was not true when
@@ -138,29 +145,33 @@ class StreamManager:
         await self.recover_streams(to_recover)

         log.info("Initializing %i files", len(to_start))
-        if to_start:
-            await asyncio.gather(*[
-                self.add_stream(
-                    file_info['rowid'], file_info['sd_hash'], path_or_none(file_info['file_name']),
-                    path_or_none(file_info['download_directory']), file_info['status'],
-                    file_info['claim']
-                ) for file_info in to_start
-            ])
+        to_resume_saving = []
+        add_stream_tasks = []
+        for file_info in to_start:
+            file_name = path_or_none(file_info['file_name'])
+            download_directory = path_or_none(file_info['download_directory'])
+            if file_name and download_directory and not file_info['saved_file'] and file_info['status'] == 'running':
+                to_resume_saving.append((file_name, download_directory, file_info['sd_hash']))
+            add_stream_tasks.append(self.loop.create_task(self.add_stream(
+                file_info['rowid'], file_info['sd_hash'], file_name,
+                download_directory, file_info['status'],
+                file_info['claim'], file_info['content_fee']
+            )))
+        if add_stream_tasks:
+            await asyncio.gather(*add_stream_tasks, loop=self.loop)
         log.info("Started stream manager with %i files", len(self.streams))
-
-    async def resume(self):
         if not self.node:
             log.warning("no DHT node given, resuming downloads trusting that we can contact reflector")
-        t = [
-            self.loop.create_task(
-                stream.start(node=self.node, save_now=(stream.full_path is not None))
-                if not stream.full_path else
-                stream.save_file(node=self.node)
-            ) for stream in self.streams.values() if stream.running
-        ]
-        if t:
-            log.info("resuming %i downloads", len(t))
-            await asyncio.gather(*t, loop=self.loop)
+        if to_resume_saving:
+            self.resume_saving_task = self.loop.create_task(self.resume(to_resume_saving))
+
+    async def resume(self, to_resume_saving):
+        log.info("Resuming saving %i files", len(to_resume_saving))
+        await asyncio.gather(
+            *(self.streams[sd_hash].save_file(file_name, download_directory, node=self.node)
+              for (file_name, download_directory, sd_hash) in to_resume_saving),
+            loop=self.loop
+        )

     async def reflect_streams(self):
         while True:
@@ -182,14 +193,13 @@ class StreamManager:
             await asyncio.sleep(300, loop=self.loop)

     async def start(self):
-        await self.load_streams_from_database()
-        self.resume_downloading_task = self.loop.create_task(self.resume())
+        await self.load_and_resume_streams_from_database()
         self.re_reflect_task = self.loop.create_task(self.reflect_streams())
         self.started.set()

     def stop(self):
-        if self.resume_downloading_task and not self.resume_downloading_task.done():
-            self.resume_downloading_task.cancel()
+        if self.resume_saving_task and not self.resume_saving_task.done():
+            self.resume_saving_task.cancel()
         if self.re_reflect_task and not self.re_reflect_task.done():
             self.re_reflect_task.cancel()
         while self.streams:
@@ -387,6 +397,7 @@ class StreamManager:
                 lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
             )
             log.info("paid fee of %s for %s", fee_amount, uri)
+            await self.storage.save_content_fee(stream.stream_hash, stream.content_fee)

         self.streams[stream.sd_hash] = stream
         self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
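With resume folded into startup, a stream resumes saving only when it was being written to a known path, is not already saved, and was running when the daemon stopped. The rule as a standalone predicate, a sketch of the condition used in load_and_resume_streams_from_database:

def should_resume_saving(file_info: dict, file_name, download_directory) -> bool:
    return bool(
        file_name and download_directory
        and not file_info['saved_file']
        and file_info['status'] == 'running'
    )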

View file

@@ -172,27 +172,43 @@ class CommandTestCase(IntegrationTestCase):
         return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result']

     async def stream_create(self, name='hovercraft', bid='1.0', data=b'hi!', confirm=True, **kwargs):
-        with tempfile.NamedTemporaryFile() as file:
-            file.write(data)
-            file.flush()
-            claim = await self.out(
-                self.daemon.jsonrpc_stream_create(name, bid, file_path=file.name, **kwargs)
-            )
-            self.assertEqual(claim['outputs'][0]['name'], name)
-            if confirm:
-                await self.on_transaction_dict(claim)
-                await self.generate(1)
-                await self.on_transaction_dict(claim)
-            return claim
+        file = tempfile.NamedTemporaryFile()
+
+        def cleanup():
+            try:
+                file.close()
+            except FileNotFoundError:
+                pass
+
+        self.addCleanup(cleanup)
+        file.write(data)
+        file.flush()
+        claim = await self.out(
+            self.daemon.jsonrpc_stream_create(name, bid, file_path=file.name, **kwargs)
+        )
+        self.assertEqual(claim['outputs'][0]['name'], name)
+        if confirm:
+            await self.on_transaction_dict(claim)
+            await self.generate(1)
+            await self.on_transaction_dict(claim)
+        return claim

     async def stream_update(self, claim_id, data=None, confirm=True, **kwargs):
         if data:
-            with tempfile.NamedTemporaryFile() as file:
-                file.write(data)
-                file.flush()
-                claim = await self.out(
-                    self.daemon.jsonrpc_stream_update(claim_id, file_path=file.name, **kwargs)
-                )
+            file = tempfile.NamedTemporaryFile()
+            file.write(data)
+            file.flush()
+
+            def cleanup():
+                try:
+                    file.close()
+                except FileNotFoundError:
+                    pass
+
+            self.addCleanup(cleanup)
+            claim = await self.out(
+                self.daemon.jsonrpc_stream_update(claim_id, file_path=file.name, **kwargs)
+            )
         else:
             claim = await self.out(self.daemon.jsonrpc_stream_update(claim_id, **kwargs))
         self.assertIsNotNone(claim['outputs'][0]['name'])
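The with-block version deleted the temporary file as soon as the block exited; the rewrite keeps the file alive for the remainder of the test and defers deletion to unittest's addCleanup, tolerating the file already being gone. The pattern in isolation:

import os
import tempfile
import unittest

class TempFileOutlivesBlock(unittest.TestCase):
    def test_file_survives_creating_scope(self):
        file = tempfile.NamedTemporaryFile()

        def cleanup():
            try:
                file.close()  # NamedTemporaryFile unlinks the file on close
            except FileNotFoundError:
                pass  # already removed elsewhere; nothing left to do

        self.addCleanup(cleanup)
        file.write(b'data')
        file.flush()
        self.assertTrue(os.path.isfile(file.name))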

View file

@@ -239,6 +239,7 @@ class FileCommands(CommandTestCase):
         await self.daemon.jsonrpc_file_delete(claim_name='icanpay')
         await self.assertBalance(self.account, '9.925679')
         response = await self.daemon.jsonrpc_get('lbry://icanpay')
+        raw_content_fee = response.content_fee.raw
         await self.ledger.wait(response.content_fee)
         await self.assertBalance(self.account, '8.925555')
         self.assertEqual(len(self.daemon.jsonrpc_file_list()), 1)
@@ -252,3 +253,10 @@ class FileCommands(CommandTestCase):
         self.assertEqual(
             await self.blockchain.get_balance(), starting_balance + block_reward_and_claim_fee
         )
+
+        # restart the daemon and make sure the fee is still there
+        self.daemon.stream_manager.stop()
+        await self.daemon.stream_manager.start()
+        self.assertEqual(len(self.daemon.jsonrpc_file_list()), 1)
+        self.assertEqual(self.daemon.jsonrpc_file_list()[0].content_fee.raw, raw_content_fee)

View file

@@ -24,11 +24,11 @@ class RangeRequests(CommandTestCase):
             await self.daemon.stream_manager.start()
             return

-    async def _setup_stream(self, data: bytes, save_blobs: bool = True, save_files: bool = False):
+    async def _setup_stream(self, data: bytes, save_blobs: bool = True, save_files: bool = False, file_size=0):
         self.daemon.conf.save_blobs = save_blobs
         self.daemon.conf.save_files = save_files
         self.data = data
-        await self.stream_create('foo', '0.01', data=self.data)
+        await self.stream_create('foo', '0.01', data=self.data, file_size=file_size)
         if save_blobs:
             self.assertTrue(len(os.listdir(self.daemon.blob_manager.blob_dir)) > 1)
         await self.daemon.jsonrpc_file_list()[0].fully_reflected.wait()
@@ -70,9 +70,10 @@ class RangeRequests(CommandTestCase):
         self.assertEqual('bytes 0-14/15', content_range)

     async def test_range_requests_0_padded_bytes(self, size: int = (MAX_BLOB_SIZE - 1) * 4,
-                                                 expected_range: str = 'bytes 0-8388603/8388604', padding=b''):
+                                                 expected_range: str = 'bytes 0-8388603/8388604', padding=b'',
+                                                 file_size=0):
         self.data = get_random_bytes(size)
-        await self._setup_stream(self.data)
+        await self._setup_stream(self.data, file_size=file_size)
         streamed, content_range, content_length = await self._test_range_requests()
         self.assertEqual(len(self.data + padding), content_length)
         self.assertEqual(streamed, self.data + padding)
@@ -93,6 +94,11 @@ class RangeRequests(CommandTestCase):
             ((MAX_BLOB_SIZE - 1) * 4) - 14, padding=b'\x00' * 14
         )

+    async def test_range_requests_no_padding_size_from_claim(self):
+        size = ((MAX_BLOB_SIZE - 1) * 4) - 14
+        await self.test_range_requests_0_padded_bytes(size, padding=b'', file_size=size,
+                                                      expected_range=f"bytes 0-{size-1}/{size}")
+
     async def test_range_requests_15_padded_bytes(self):
         await self.test_range_requests_0_padded_bytes(
             ((MAX_BLOB_SIZE - 1) * 4) - 15, padding=b'\x00' * 15