Merge pull request #2452 from lbryio/fix-sqlite-misuse
Fix SQLITE_MISUSE
This commit is contained in:
commit
b2f0374426
2 changed files with 34 additions and 19 deletions
|
@ -1,5 +1,9 @@
|
|||
import asyncio
|
||||
import sqlite3
|
||||
from lbry.testcase import CommandTestCase
|
||||
from torba.client.basedatabase import SQLiteMixin
|
||||
from lbry.wallet.dewies import dewies_to_lbc
|
||||
from lbry.wallet.account import Account
|
||||
|
||||
|
||||
def extract(d, keys):
|
||||
|
@ -7,6 +11,13 @@ def extract(d, keys):
|
|||
|
||||
|
||||
class AccountManagement(CommandTestCase):
|
||||
async def test_sqlite_binding_error(self):
|
||||
tasks = [
|
||||
self.loop.create_task(self.daemon.jsonrpc_account_create('second account' + str(x))) for x in range(100)
|
||||
]
|
||||
await asyncio.wait(tasks)
|
||||
for result in tasks:
|
||||
self.assertFalse(isinstance(result.result(), Exception))
|
||||
|
||||
async def test_account_list_set_create_remove_add(self):
|
||||
# check initial account
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from asyncio import wrap_future
|
||||
from binascii import hexlify
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
|
@ -23,8 +22,10 @@ class AIOSQLite:
|
|||
|
||||
@classmethod
|
||||
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
|
||||
def _connect():
|
||||
return sqlite3.connect(path, *args, **kwargs)
|
||||
db = cls()
|
||||
db.connection = await wrap_future(db.executor.submit(sqlite3.connect, path, *args, **kwargs))
|
||||
db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect)
|
||||
return db
|
||||
|
||||
async def close(self):
|
||||
|
@ -38,25 +39,25 @@ class AIOSQLite:
|
|||
return asyncio.get_event_loop_policy().get_event_loop().call_later(0.01, __close, conn)
|
||||
|
||||
def executemany(self, sql: str, params: Iterable):
|
||||
def __executemany_in_a_transaction(conn: sqlite3.Connection, *args, **kwargs):
|
||||
return conn.executemany(*args, **kwargs)
|
||||
return self.run(__executemany_in_a_transaction, sql, params)
|
||||
params = params if params is not None else []
|
||||
# this fetchall is needed to prevent SQLITE_MISUSE
|
||||
return self.run(lambda conn: conn.executemany(sql, params).fetchall())
|
||||
|
||||
def executescript(self, script: str) -> Awaitable:
|
||||
return wrap_future(self.executor.submit(self.connection.executescript, script))
|
||||
return self.run(lambda conn: conn.executescript(script))
|
||||
|
||||
def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
def __fetchall(conn: sqlite3.Connection, *args, **kwargs):
|
||||
return conn.execute(*args, **kwargs).fetchall()
|
||||
return wrap_future(self.executor.submit(__fetchall, self.connection, sql, parameters))
|
||||
return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
|
||||
|
||||
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn, sql, parameters: conn.execute(sql, parameters), sql, parameters)
|
||||
return self.run(lambda conn: conn.execute(sql, parameters))
|
||||
|
||||
def run(self, fun, *args, **kwargs) -> Awaitable:
|
||||
return wrap_future(self.executor.submit(self.__run_transaction, fun, *args, **kwargs))
|
||||
return asyncio.get_event_loop().run_in_executor(
|
||||
self.executor, lambda: self.__run_transaction(fun, *args, **kwargs)
|
||||
)
|
||||
|
||||
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
|
||||
self.connection.execute('begin')
|
||||
|
@ -64,19 +65,20 @@ class AIOSQLite:
|
|||
result = fun(self.connection, *args, **kwargs) # type: ignore
|
||||
self.connection.commit()
|
||||
return result
|
||||
except (Exception, OSError): # as e:
|
||||
#log.exception('Error running transaction:', exc_info=e)
|
||||
except (Exception, OSError) as e:
|
||||
log.exception('Error running transaction:', exc_info=e)
|
||||
self.connection.rollback()
|
||||
log.warning("rolled back")
|
||||
raise
|
||||
|
||||
def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
|
||||
return wrap_future(
|
||||
self.executor.submit(self.__run_transaction_with_foreign_keys_disabled, fun, *args, **kwargs)
|
||||
return asyncio.get_event_loop().run_in_executor(
|
||||
self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
|
||||
)
|
||||
|
||||
def __run_transaction_with_foreign_keys_disabled(self,
|
||||
fun: Callable[[sqlite3.Connection, Any, Any], Any],
|
||||
*args, **kwargs):
|
||||
args, kwargs):
|
||||
foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone()
|
||||
if not foreign_keys_enabled:
|
||||
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
|
||||
|
@ -584,9 +586,11 @@ class BaseDatabase(SQLiteMixin):
|
|||
async def add_keys(self, account, chain, keys):
|
||||
await self.db.executemany(
|
||||
"insert into pubkey_address (address, account, chain, position, pubkey) values (?, ?, ?, ?, ?)",
|
||||
((pubkey.address, account.public_key.address, chain,
|
||||
position, sqlite3.Binary(pubkey.pubkey_bytes))
|
||||
for position, pubkey in keys)
|
||||
(
|
||||
(pubkey.address, account.public_key.address, chain, position,
|
||||
sqlite3.Binary(pubkey.pubkey_bytes))
|
||||
for position, pubkey in keys
|
||||
)
|
||||
)
|
||||
|
||||
async def _set_address_history(self, address, history):
|
||||
|
|
Loading…
Reference in a new issue