diff --git a/torba/tests/client_tests/unit/test_database.py b/torba/tests/client_tests/unit/test_database.py index 456f873bd..89f62c70a 100644 --- a/torba/tests/client_tests/unit/test_database.py +++ b/torba/tests/client_tests/unit/test_database.py @@ -436,6 +436,8 @@ class TestUpgrade(AsyncioTestCase): class TestSQLiteRace(AsyncioTestCase): + max_misuse_attempts = 40000 + def setup_db(self): self.db = sqlite3.connect(":memory:", isolation_level=None) self.db.executescript( @@ -473,11 +475,12 @@ class TestSQLiteRace(AsyncioTestCase): except sqlite3.InterfaceError as err: self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.") - async def test_unhandled_sqlite_misuse(self, max_attempts=100000): + async def test_unhandled_sqlite_misuse(self): # test SQLITE_MISUSE being incorrectly raised as a param 0 binding error attempts = 0 + try: - while attempts < max_attempts: + while attempts < self.max_misuse_attempts: f1 = asyncio.wrap_future( self.loop.run_in_executor( self.executor, self.db.executemany, "update test1 set val='derp' where id=?", @@ -492,6 +495,31 @@ class TestSQLiteRace(AsyncioTestCase): ) attempts += 1 await asyncio.gather(f1, f2) - self.assertTrue(False, f'failed to raise SQLITE_MISUSE within {max_attempts} tries') + self.assertTrue(False, f'failed to raise SQLITE_MISUSE within {self.max_misuse_attempts} tries\n' + f'this test failing means either the sqlite race conditions ' + f'have been fixed in cpython or the test max_attempts needs to be increased') except sqlite3.InterfaceError as err: self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.") + + async def test_fetchall_prevents_sqlite_misuse(self): + # test that calling fetchall sufficiently avoids the race + attempts = 0 + + def executemany_fetchall(query, params): + self.db.executemany(query, params).fetchall() + + while attempts < self.max_misuse_attempts: + f1 = asyncio.wrap_future( + self.loop.run_in_executor( + self.executor, executemany_fetchall, "update test1 set val='derp' where id=?", + ((str(i),) for i in range(2)) + ) + ) + f2 = asyncio.wrap_future( + self.loop.run_in_executor( + self.executor, executemany_fetchall, "update test2 set val='derp' where id=?", + ((str(i),) for i in range(2)) + ) + ) + attempts += 1 + await asyncio.gather(f1, f2) \ No newline at end of file