diff --git a/lbrynet/extras/daemon/Downloader.py b/lbrynet/extras/daemon/Downloader.py index 6e3b7a04a..45900e2f4 100644 --- a/lbrynet/extras/daemon/Downloader.py +++ b/lbrynet/extras/daemon/Downloader.py @@ -6,8 +6,7 @@ from lbrynet import conf from lbrynet.schema.fee import Fee from lbrynet.p2p.Error import InsufficientFundsError, KeyFeeAboveMaxAllowed, InvalidStreamDescriptorError -from lbrynet.p2p.Error import DownloadDataTimeout, DownloadCanceledError, DownloadSDTimeout -from lbrynet.utils import safe_start_looping_call, safe_stop_looping_call +from lbrynet.p2p.Error import DownloadDataTimeout, DownloadCanceledError from lbrynet.p2p.StreamDescriptor import download_sd_blob from lbrynet.blob.EncryptedFileDownloader import ManagedEncryptedFileDownloaderFactory from torba.client.constants import COIN @@ -181,7 +180,6 @@ class GetStream: downloader - instance of ManagedEncryptedFileDownloader finished_deferred - deferred callbacked when download is finished """ - self.set_status(INITIALIZING_CODE, name) key_fee = yield self._initialize(stream_info) self.set_status(DOWNLOAD_METADATA_CODE, name) @@ -191,7 +189,11 @@ class GetStream: self.set_status(DOWNLOAD_RUNNING_CODE, name) log.info("Downloading lbry://%s (%s) --> %s", name, self.sd_hash[:6], self.download_path) self.data_downloading_deferred.addTimeout(self.timeout, self.reactor) - yield self.data_downloading_deferred + try: + yield self.data_downloading_deferred + self.wrote_data = True + except defer.TimeoutError: + raise DownloadDataTimeout("data download timed out") except (DownloadDataTimeout, InvalidStreamDescriptorError) as err: raise err diff --git a/tests/unit/lbrynet_daemon/test_Downloader.py b/tests/unit/lbrynet_daemon/test_Downloader.py index 625179518..36a5e97a5 100644 --- a/tests/unit/lbrynet_daemon/test_Downloader.py +++ b/tests/unit/lbrynet_daemon/test_Downloader.py @@ -9,6 +9,7 @@ from lbrynet.p2p.Error import DownloadDataTimeout, DownloadSDTimeout from lbrynet.p2p.StreamDescriptor import StreamDescriptorIdentifier from lbrynet.p2p.BlobManager import DiskBlobManager from lbrynet.p2p.RateLimiter import DummyRateLimiter +from lbrynet.p2p.client.DownloadManager import DownloadManager from lbrynet.extras.daemon import Downloader from lbrynet.extras.daemon import ExchangeRateManager from lbrynet.extras.daemon.storage import SQLiteStorage @@ -24,7 +25,7 @@ class MocDownloader: def __init__(self): self.finish_deferred = defer.Deferred(None) self.stop_called = False - + self.file_name = 'test' self.name = 'test' self.num_completed = 0 self.num_known = 1 @@ -60,8 +61,10 @@ def moc_download(self, sd_blob, name, txid, nout, key_fee, file_name): self.downloader.start() -def moc_pay_key_fee(self, key_fee, name): - self.pay_key_fee_called = True +def moc_pay_key_fee(d): + def _moc_pay_key_fee(self, key_fee, name): + d.callback(True) + return _moc_pay_key_fee class GetStreamTests(unittest.TestCase): @@ -83,10 +86,7 @@ class GetStreamTests(unittest.TestCase): sd_identifier, wallet, exchange_rate_manager, blob_manager, peer_finder, DummyRateLimiter(), prm, storage, max_key_fee, disable_max_key_fee, timeout=3, data_rate=data_rate ) - getstream.pay_key_fee_called = False - - self.clock = task.Clock() - getstream.checker.clock = self.clock + getstream.download_manager = mock.Mock(spec=DownloadManager) return getstream @defer.inlineCallbacks @@ -112,16 +112,18 @@ class GetStreamTests(unittest.TestCase): def download_sd_blob(self): raise DownloadSDTimeout(self) + called_pay_fee = defer.Deferred() + getstream = self.init_getstream_with_mocs() getstream._initialize = types.MethodType(moc_initialize, getstream) getstream._download_sd_blob = types.MethodType(download_sd_blob, getstream) getstream._download = types.MethodType(moc_download, getstream) - getstream.pay_key_fee = types.MethodType(moc_pay_key_fee, getstream) + getstream.pay_key_fee = types.MethodType(moc_pay_key_fee(called_pay_fee), getstream) name = 'test' stream_info = None with self.assertRaises(DownloadSDTimeout): yield getstream.start(stream_info, name, "deadbeef" * 12, 0) - self.assertFalse(getstream.pay_key_fee_called) + self.assertFalse(called_pay_fee.called) @defer.inlineCallbacks def test_timeout(self): @@ -129,20 +131,18 @@ class GetStreamTests(unittest.TestCase): test that timeout (set to 3 here) exception is raised when download times out while downloading first blob, and key fee is paid """ + called_pay_fee = defer.Deferred() + getstream = self.init_getstream_with_mocs() getstream._initialize = types.MethodType(moc_initialize, getstream) getstream._download_sd_blob = types.MethodType(moc_download_sd_blob, getstream) getstream._download = types.MethodType(moc_download, getstream) - getstream.pay_key_fee = types.MethodType(moc_pay_key_fee, getstream) + getstream.pay_key_fee = types.MethodType(moc_pay_key_fee(called_pay_fee), getstream) name = 'test' stream_info = None start = getstream.start(stream_info, name, "deadbeef" * 12, 0) - self.clock.advance(1) - self.clock.advance(1) - self.clock.advance(1) with self.assertRaises(DownloadDataTimeout): yield start - self.assertTrue(getstream.pay_key_fee_called) @defer.inlineCallbacks def test_finish_one_blob(self): @@ -150,20 +150,22 @@ class GetStreamTests(unittest.TestCase): test that if we have 1 completed blob, start() returns and key fee is paid """ + called_pay_fee = defer.Deferred() + getstream = self.init_getstream_with_mocs() getstream._initialize = types.MethodType(moc_initialize, getstream) getstream._download_sd_blob = types.MethodType(moc_download_sd_blob, getstream) getstream._download = types.MethodType(moc_download, getstream) - getstream.pay_key_fee = types.MethodType(moc_pay_key_fee, getstream) + getstream.pay_key_fee = types.MethodType(moc_pay_key_fee(called_pay_fee), getstream) name = 'test' stream_info = None - start = getstream.start(stream_info, name, "deadbeef" * 12, 0) - getstream.downloader.num_completed = 1 - self.clock.advance(1) + self.assertFalse(getstream.wrote_data) + getstream.data_downloading_deferred.callback(None) + yield getstream.start(stream_info, name, "deadbeef" * 12, 0) + self.assertTrue(getstream.wrote_data) - downloader, f_deferred = yield start - self.assertTrue(getstream.pay_key_fee_called) + # self.assertTrue(getstream.pay_key_fee_called) # @defer.inlineCallbacks # def test_finish_stopped_downloader(self):