import os
import time
import unittest
from unittest import mock
import asyncio

from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.conf import Config
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.downloader import StreamDownloader
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase


class TestStreamDownloader(BlobExchangeTestBase):
    """End-to-end tests of StreamDownloader against the in-process blob server."""

    async def setup_stream(self, blob_count: int = 10):
        """Create a random stream on the server and a client-side downloader for it.

        Writes ``blob_count`` chunks of ``MAX_BLOB_SIZE - 1`` random bytes to
        ``<server_dir>/test_file``. It then creates a stream from that file in the
        server's blob directory and stores the resulting sd hash on ``self.sd_hash``.
        Finally it builds ``self.downloader``, a ``StreamDownloader`` that uses the
        client blob manager.

        :param blob_count: number of random chunks to write into the source file.
        """
        # Join once instead of ``+=`` in a loop. Repeated bytes concatenation
        # is quadratic and gets expensive for the 100-blob case.
        self.stream_bytes = b''.join(
            os.urandom(MAX_BLOB_SIZE - 1) for _ in range(blob_count)
        )

        # create the stream
        file_path = os.path.join(self.server_dir, "test_file")
        with open(file_path, 'wb') as f:
            f.write(self.stream_bytes)
        descriptor = await StreamDescriptor.create_stream(
            self.loop, self.server_blob_manager.blob_dir, file_path
        )
        self.sd_hash = descriptor.calculate_sd_hash()
        conf = Config(
            data_dir=self.server_dir,
            wallet_dir=self.server_dir,
            download_dir=self.server_dir,
            reflector_servers=[],
        )
        self.downloader = StreamDownloader(
            self.loop, conf, self.client_blob_manager, self.sd_hash
        )

    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
        """Download a freshly created stream and check the output file matches.

        :param blob_count: forwarded to :meth:`setup_stream`.
        :param mock_accumulate_peers: optional replacement for
            ``Node.accumulate_peers``. By default the fake hands back a queue
            that already holds the test server as its only peer.
        """
        await self.setup_stream(blob_count)
        mock_node = mock.Mock(spec=Node)

        def _mock_accumulate_peers(q1, q2):
            # Mirror the real call's return shape: (peer queue, task). The
            # task does nothing; the peer is pushed onto q2 right away.
            async def _task():
                pass

            q2.put_nowait([self.server_from_client])
            return q2, self.loop.create_task(_task())

        mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers

        self.downloader.download(mock_node)
        await self.downloader.stream_finished_event.wait()
        self.assertTrue(self.downloader.stream_handle.closed)
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        self.downloader.stop()
        self.assertIs(self.downloader.stream_handle, None)
        # The output file must outlive stop() and hold exactly the source bytes.
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        with open(self.downloader.output_path, 'rb') as f:
            self.assertEqual(f.read(), self.stream_bytes)
        # NOTE(review): presumably yields so callbacks scheduled by stop() can
        # run before teardown. Confirm it is still needed.
        await asyncio.sleep(0.01)

    async def test_transfer_stream(self):
        """A 10-blob stream downloads intact from a single good peer."""
        await self._test_transfer_stream(10)

    # ``unittest.SkipTest`` is an exception, not a decorator. Using it as one
    # silently removed this test from collection instead of reporting a skip.
    @unittest.skip("100-blob transfer is disabled (presumably too slow for unit runs)")
    async def test_transfer_hundred_blob_stream(self):
        """A 100-blob stream downloads intact (currently skipped)."""
        await self._test_transfer_stream(100)

    async def test_transfer_stream_bad_first_peer_good_second(self):
        """The download completes when an unreachable peer is offered before a good one."""
        await self.setup_stream(2)
        mock_node = mock.Mock(spec=Node)
        # Nothing in the test listens on port 3334, so this peer should fail.
        bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)

        def _mock_accumulate_peers(q1, q2):
            async def _task():
                pass

            # Offer the bad peer first. The real server follows one second later.
            q2.put_nowait([bad_peer])
            self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
            return q2, self.loop.create_task(_task())

        mock_node.accumulate_peers = _mock_accumulate_peers

        self.downloader.download(mock_node)
        await self.downloader.stream_finished_event.wait()
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        with open(self.downloader.output_path, 'rb') as f:
            self.assertEqual(f.read(), self.stream_bytes)
        # TODO: also assert that the good peer's ``tcp_last_down`` is None and
        # the bad peer's is set, once peer failure tracking is in place.

    async def test_client_chunked_response(self):
        """The client reassembles blob responses the server writes one byte at a time."""
        self.server.stop_server()

        class ChunkedServerProtocol(BlobServerProtocol):
            def send_response(self, responses):
                # Drain ``responses`` (last item first) into one BlobResponse.
                to_send = []
                while responses:
                    to_send.append(responses.pop())
                # Write each serialized byte separately so the client has to
                # handle arbitrarily fragmented reads.
                for byte in BlobResponse(to_send).serialize():
                    self.transport.write(bytes([byte]))

        self.server.server_protocol_class = ChunkedServerProtocol
        self.server.start_server(33333, '127.0.0.1')
        self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
        await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
        # NOTE(review): presumably 10 data blobs plus the sd blob. Confirm.
        self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))