diff --git a/torba/client/basedatabase.py b/torba/client/basedatabase.py index 28d0901c0..106e3e1d5 100644 --- a/torba/client/basedatabase.py +++ b/torba/client/basedatabase.py @@ -35,6 +35,11 @@ class AIOSQLite: self.connection = None return asyncio.get_event_loop_policy().get_event_loop().call_later(0.01, __close, conn) + def executemany(self, sql: str, params: Iterable): + def __executemany_in_a_transaction(conn: sqlite3.Connection, *args, **kwargs): + return conn.executemany(*args, **kwargs) + return self.run(__executemany_in_a_transaction, sql, params) + def executescript(self, script: str) -> Awaitable: return wrap_future(self.executor.submit(self.connection.executescript, script)) @@ -321,12 +326,8 @@ class BaseDatabase(SQLiteMixin): return self.db.run(_transaction, tx, address, txhash, history) async def reserve_outputs(self, txos, is_reserved=True): - txoids = [txo.id for txo in txos] - await self.db.execute( - "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( - ', '.join(['?']*len(txoids)) - ), [is_reserved]+txoids - ) + txoids = ((is_reserved, txo.id) for txo in txos) + await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids) async def release_outputs(self, txos): await self.reserve_outputs(txos, is_reserved=False) @@ -488,18 +489,12 @@ class BaseDatabase(SQLiteMixin): return addresses[0] async def add_keys(self, account, chain, keys): - sql = ( - "insert into pubkey_address " - "(address, account, chain, position, pubkey) " - "values " - ) + ', '.join(['(?, ?, ?, ?, ?)'] * len(keys)) - values = [] - for position, pubkey in keys: - values.extend(( - pubkey.address, account.public_key.address, chain, position, - sqlite3.Binary(pubkey.pubkey_bytes) - )) - await self.db.execute(sql, values) + sql = "insert into pubkey_address (address, account, chain, position, pubkey) values (?, ?, ?, ?, ?)" + values = ( + (pubkey.address, account.public_key.address, chain, position, sqlite3.Binary(pubkey.pubkey_bytes)) + for position, pubkey in keys + ) + await self.db.executemany(sql, values) async def _set_address_history(self, address, history): await self.db.execute(