diff --git a/tests/client_tests/unit/test_database.py b/tests/client_tests/unit/test_database.py index e9d873d77..29aaa684b 100644 --- a/tests/client_tests/unit/test_database.py +++ b/tests/client_tests/unit/test_database.py @@ -1,15 +1,56 @@ import unittest +import sqlite3 from torba.client.wallet import Wallet from torba.client.constants import COIN from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class -from torba.client.basedatabase import query, constraints_to_sql +from torba.client.basedatabase import query, constraints_to_sql, AIOSQLite from torba.testcase import AsyncioTestCase from client_tests.unit.test_transaction import get_output, NULL_HASH +class TestAIOSQLite(AsyncioTestCase): + async def asyncSetUp(self): + self.db = await AIOSQLite.connect(':memory:') + await self.db.executescript(""" + pragma foreign_keys=on; + create table parent (id integer primary key, name); + create table child (id integer primary key, parent_id references parent); + """) + await self.db.execute("insert into parent values (1, 'test')") + await self.db.execute("insert into child values (2, 1)") + + @staticmethod + def delete_item(transaction): + transaction.execute('delete from parent where id=1') + + async def test_foreign_keys_integrity_error(self): + self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) + + with self.assertRaises(sqlite3.IntegrityError): + await self.db.run(self.delete_item) + self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) + + await self.db.executescript("pragma foreign_keys=off;") + + await self.db.run(self.delete_item) + self.assertListEqual([], await self.db.execute_fetchall("select * from parent")) + + async def test_run_without_foreign_keys(self): + self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) + await self.db.run_with_foreign_keys_disabled(self.delete_item) + self.assertListEqual([], await self.db.execute_fetchall("select * from parent")) + + async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self): + await self.db.executescript("pragma foreign_keys=off;") + self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) + with self.assertRaises(sqlite3.IntegrityError): + await self.db.run_with_foreign_keys_disabled(self.delete_item) + self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) + + class TestQueryBuilder(unittest.TestCase): def test_dot(self): diff --git a/torba/client/basedatabase.py b/torba/client/basedatabase.py index ccbffb7e5..77a2bb647 100644 --- a/torba/client/basedatabase.py +++ b/torba/client/basedatabase.py @@ -67,21 +67,20 @@ class AIOSQLite: self.connection.rollback() raise - def run_nofk(self, fun, *args, **kwargs) -> Awaitable: - return wrap_future(self.executor.submit(self.__run_transaction_no_fk, fun, *args, **kwargs)) + def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable: + return wrap_future( + self.executor.submit(self.__run_transaction_with_foreign_keys_disabled, fun, *args, **kwargs) + ) - def __run_transaction_no_fk(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): + def __run_transaction_with_foreign_keys_disabled(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, + **kwargs): + foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone() + if not foreign_keys_enabled: + raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead") try: self.connection.execute('pragma foreign_keys=off') self.connection.commit() - try: - self.connection.execute('begin') - result = fun(self.connection, *args, **kwargs) # type: ignore - self.connection.commit() - return result - except (Exception, OSError): # as e: - self.connection.rollback() - raise + return self.__run_transaction(fun, *args, **kwargs) finally: self.connection.execute('pragma foreign_keys=on') self.connection.commit()