diff --git a/lbrynet/stream/assembler.py b/lbrynet/stream/assembler.py index 356a8af76..6b325d248 100644 --- a/lbrynet/stream/assembler.py +++ b/lbrynet/stream/assembler.py @@ -88,7 +88,10 @@ class StreamAssembler: await self.blob_manager.blob_completed(self.sd_blob) with open(self.output_path, 'wb') as stream_handle: self.stream_handle = stream_handle - for blob_info in self.descriptor.blobs[:-1]: + for i, blob_info in enumerate(self.descriptor.blobs[:-1]): + if blob_info.blob_num != i: + log.error("sd blob %s is invalid, cannot assemble stream", self.descriptor.sd_hash) + return while not stream_handle.closed: try: blob = await self.get_blob(blob_info.blob_hash, blob_info.length) diff --git a/lbrynet/stream/descriptor.py b/lbrynet/stream/descriptor.py index 453852d28..103f42e18 100644 --- a/lbrynet/stream/descriptor.py +++ b/lbrynet/stream/descriptor.py @@ -102,6 +102,8 @@ class StreamDescriptor: raise InvalidStreamDescriptorError("Contains zero-length data blob") if 'blob_hash' in decoded['blobs'][-1]: raise InvalidStreamDescriptorError("Stream terminator blob should not have a hash") + if any([i != blob_info['blob_num'] for i, blob_info in enumerate(decoded['blobs'])]): + raise InvalidStreamDescriptorError("Stream contains out of order or skipped blobs") descriptor = cls( loop, blob_dir, binascii.unhexlify(decoded['stream_name']).decode(), diff --git a/tests/unit/stream/test_stream_descriptor.py b/tests/unit/stream/test_stream_descriptor.py new file mode 100644 index 000000000..f634e9972 --- /dev/null +++ b/tests/unit/stream/test_stream_descriptor.py @@ -0,0 +1,77 @@ +import os +import asyncio +import tempfile +import shutil +import json + +from torba.testcase import AsyncioTestCase +from lbrynet.conf import Config +from lbrynet.error import InvalidStreamDescriptorError +from lbrynet.extras.daemon.storage import SQLiteStorage +from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.stream.descriptor import StreamDescriptor + + +class TestStreamDescriptor(AsyncioTestCase): + async def asyncSetUp(self): + self.loop = asyncio.get_event_loop() + self.key = b'deadbeef' * 4 + self.cleartext = os.urandom(20000000) + self.tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(self.tmp_dir)) + self.storage = SQLiteStorage(Config(), ":memory:") + await self.storage.open() + self.blob_manager = BlobFileManager(self.loop, self.tmp_dir, self.storage) + + self.file_path = os.path.join(self.tmp_dir, "test_file") + with open(self.file_path, 'wb') as f: + f.write(self.cleartext) + + self.descriptor = await StreamDescriptor.create_stream(self.loop, self.tmp_dir, self.file_path, key=self.key) + self.sd_hash = self.descriptor.calculate_sd_hash() + self.sd_dict = json.loads(self.descriptor.as_json()) + + def _write_sd(self): + with open(os.path.join(self.tmp_dir, self.sd_hash), 'wb') as f: + f.write(json.dumps(self.sd_dict, sort_keys=True).encode()) + + async def _test_invalid_sd(self): + self._write_sd() + with self.assertRaises(InvalidStreamDescriptorError): + await self.blob_manager.get_stream_descriptor(self.sd_hash) + + async def test_load_sd_blob(self): + self._write_sd() + descriptor = await self.blob_manager.get_stream_descriptor(self.sd_hash) + self.assertEqual(descriptor.calculate_sd_hash(), self.sd_hash) + + async def test_missing_terminator(self): + self.sd_dict['blobs'].pop() + await self._test_invalid_sd() + + async def test_terminator_not_at_end(self): + terminator = self.sd_dict['blobs'].pop() + self.sd_dict['blobs'] = [terminator] + self.sd_dict['blobs'] + await self._test_invalid_sd() + + async def test_terminator_has_blob_hash(self): + self.sd_dict['blobs'][-1]['blob_hash'] = '1' * 96 + await self._test_invalid_sd() + + async def test_blob_order(self): + terminator = self.sd_dict['blobs'].pop() + self.sd_dict['blobs'].reverse() + self.sd_dict['blobs'].append(terminator) + await self._test_invalid_sd() + + async def test_skip_blobs(self): + self.sd_dict['blobs'][-2]['blob_num'] = self.sd_dict['blobs'][-2]['blob_num'] + 1 + await self._test_invalid_sd() + + async def test_invalid_stream_hash(self): + self.sd_dict['blobs'][-2]['blob_hash'] = '1' * 96 + await self._test_invalid_sd() + + async def test_zero_length_blob(self): + self.sd_dict['blobs'][-2]['length'] = 0 + await self._test_invalid_sd()