diff --git a/lbrynet/extras/daemon/Components.py b/lbrynet/extras/daemon/Components.py index fb3a6485f..afd32b078 100644 --- a/lbrynet/extras/daemon/Components.py +++ b/lbrynet/extras/daemon/Components.py @@ -63,7 +63,7 @@ class DatabaseComponent(Component): @staticmethod def get_current_db_revision(): - return 9 + return 10 @property def revision_filename(self): diff --git a/lbrynet/extras/daemon/migrator/migrate9to10.py b/lbrynet/extras/daemon/migrator/migrate9to10.py new file mode 100644 index 000000000..ac08e77bc --- /dev/null +++ b/lbrynet/extras/daemon/migrator/migrate9to10.py @@ -0,0 +1,20 @@ +import sqlite3 +import os + + +def do_migration(conf): + db_path = os.path.join(conf.data_dir, "lbrynet.sqlite") + connection = sqlite3.connect(db_path) + cursor = connection.cursor() + + query = "select stream_hash, sd_hash from main.stream" + for stream_hash, sd_hash in cursor.execute(query): + head_blob_hash = cursor.execute( + "select blob_hash from stream_blob where position = 0 and stream_hash = ?", + (stream_hash,) + ).fetchone() + if not head_blob_hash: + continue + cursor.execute("update blob set should_announce=1 where blob_hash in (?, ?)", (sd_hash, head_blob_hash[0],)) + connection.commit() + connection.close() diff --git a/lbrynet/extras/daemon/storage.py b/lbrynet/extras/daemon/storage.py index 28e8271ba..28ce23acf 100644 --- a/lbrynet/extras/daemon/storage.py +++ b/lbrynet/extras/daemon/storage.py @@ -153,21 +153,12 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Di def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descriptor: 'StreamDescriptor'): - # add the head blob and set it to be announced - transaction.execute( - "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?), (?, ?, ?, ?, ?, ?, ?)", - ( - sd_blob.blob_hash, sd_blob.length, 0, 1, "pending", 0, 0, - descriptor.blobs[0].blob_hash, descriptor.blobs[0].length, 0, 1, "pending", 0, 0 - ) + # add all blobs, except the last one, which is empty + transaction.executemany( + "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)", + [(blob.blob_hash, blob.length, 0, 0, "pending", 0, 0) + for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob]] ) - # add the rest of the blobs with announcement off - if len(descriptor.blobs) > 2: - transaction.executemany( - "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)", - [(blob.blob_hash, blob.length, 0, 0, "pending", 0, 0) - for blob in descriptor.blobs[1:-1]] - ) # associate the blobs to the stream transaction.execute("insert or ignore into stream values (?, ?, ?, ?, ?)", (descriptor.stream_hash, sd_blob.blob_hash, descriptor.key, @@ -179,6 +170,11 @@ def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descripto [(descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv) for blob in descriptor.blobs] ) + # ensure should_announce is set regardless if insert was ignored + transaction.execute( + "update blob set should_announce=1 where blob_hash in (?, ?)", + (sd_blob.blob_hash, descriptor.blobs[0].blob_hash,) + ) def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor'): diff --git a/tests/integration/test_file_commands.py b/tests/integration/test_file_commands.py index 419afd8cf..78cbd2d07 100644 --- a/tests/integration/test_file_commands.py +++ b/tests/integration/test_file_commands.py @@ -28,6 +28,24 @@ class FileCommands(CommandTestCase): await self.daemon.jsonrpc_get('lbry://foo') self.assertEqual(len(self.daemon.jsonrpc_file_list()), 1) + async def test_announces(self): + # announces on publish + self.assertEqual(await self.daemon.storage.get_blobs_to_announce(), []) + await self.stream_create('foo', '0.01') + stream = self.daemon.jsonrpc_file_list()[0] + self.assertSetEqual( + set(await self.daemon.storage.get_blobs_to_announce()), + {stream.sd_hash, stream.descriptor.blobs[0].blob_hash} + ) + self.assertTrue(await self.daemon.jsonrpc_file_delete(delete_all=True)) + # announces on download + self.assertEqual(await self.daemon.storage.get_blobs_to_announce(), []) + stream = await self.daemon.jsonrpc_get('foo') + self.assertSetEqual( + set(await self.daemon.storage.get_blobs_to_announce()), + {stream.sd_hash, stream.descriptor.blobs[0].blob_hash} + ) + async def test_file_list_fields(self): await self.stream_create('foo', '0.01') file_list = self.sout(self.daemon.jsonrpc_file_list()) diff --git a/tests/unit/stream/test_managed_stream.py b/tests/unit/stream/test_managed_stream.py index 7b8d92ff7..372e5766d 100644 --- a/tests/unit/stream/test_managed_stream.py +++ b/tests/unit/stream/test_managed_stream.py @@ -151,26 +151,6 @@ class TestManagedStream(BlobExchangeTestBase): async def test_create_and_decrypt_multi_blob_stream(self): await self.test_create_and_decrypt_one_blob_stream(10) - # async def test_create_managed_stream_announces(self): - # # setup a blob manager - # storage = SQLiteStorage(Config(), ":memory:") - # await storage.open() - # tmp_dir = tempfile.mkdtemp() - # self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - # blob_manager = BlobManager(self.loop, tmp_dir, storage) - # stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None) - # # create the stream - # download_dir = tempfile.mkdtemp() - # self.addCleanup(lambda: shutil.rmtree(download_dir)) - # file_path = os.path.join(download_dir, "test_file") - # with open(file_path, 'wb') as f: - # f.write(b'testtest') - # - # stream = await stream_manager.create_stream(file_path) - # self.assertEqual( - # [stream.sd_hash, stream.descriptor.blobs[0].blob_hash], - # await storage.get_blobs_to_announce()) - # async def test_create_truncate_and_handle_stream(self): # # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated # await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5) diff --git a/tests/unit/stream/test_reflector.py b/tests/unit/stream/test_reflector.py index b5cdf2960..2034b5af9 100644 --- a/tests/unit/stream/test_reflector.py +++ b/tests/unit/stream/test_reflector.py @@ -62,3 +62,8 @@ class TestStreamAssembler(AsyncioTestCase): sent = await self.stream.upload_to_reflector('127.0.0.1', 5566) self.assertListEqual(sent, []) + + async def test_announces(self): + to_announce = await self.storage.get_blobs_to_announce() + self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce") + self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce")