write lock
This commit is contained in:
parent
a26cfc639c
commit
61603ccfce
1 changed files with 35 additions and 18 deletions
|
@ -60,6 +60,9 @@ class AIOSQLite:
|
|||
self.writer_connection: Optional[sqlite3.Connection] = None
|
||||
self._closing = False
|
||||
self.query_count = 0
|
||||
self.write_lock = asyncio.Lock()
|
||||
self.writers = 0
|
||||
self.read_ready = asyncio.Event()
|
||||
|
||||
@classmethod
|
||||
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
|
||||
|
@ -74,6 +77,7 @@ class AIOSQLite:
|
|||
max_workers=readers, initializer=initializer, initargs=(path, )
|
||||
)
|
||||
await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer)
|
||||
db.read_ready.set()
|
||||
return db
|
||||
|
||||
async def close(self):
|
||||
|
@ -83,6 +87,7 @@ class AIOSQLite:
|
|||
await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
|
||||
self.writer_executor.shutdown(wait=True)
|
||||
self.reader_executor.shutdown(wait=True)
|
||||
self.read_ready.clear()
|
||||
self.writer_connection = None
|
||||
|
||||
def executemany(self, sql: str, params: Iterable):
|
||||
|
@ -93,32 +98,44 @@ class AIOSQLite:
|
|||
def executescript(self, script: str) -> Awaitable:
|
||||
return self.run(lambda conn: conn.executescript(script))
|
||||
|
||||
async def _execute_fetch(self, sql: str, parameters: Iterable = None,
|
||||
read_only: bool = False, fetch_all: bool = False) -> Iterable[sqlite3.Row]:
|
||||
read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone
|
||||
parameters = parameters if parameters is not None else []
|
||||
if read_only:
|
||||
while self.writers:
|
||||
await self.read_ready.wait()
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
self.reader_executor, read_only_fn, sql, parameters
|
||||
)
|
||||
if fetch_all:
|
||||
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
|
||||
return await self.run(lambda conn: conn.execute(sql, parameters).fetchone())
|
||||
|
||||
async def execute_fetchall(self, sql: str, parameters: Iterable = None,
|
||||
read_only: bool = False) -> Iterable[sqlite3.Row]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
if read_only:
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
self.reader_executor, run_read_only_fetchall, sql, parameters
|
||||
)
|
||||
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
|
||||
return await self._execute_fetch(sql, parameters, read_only, fetch_all=True)
|
||||
|
||||
def execute_fetchone(self, sql: str, parameters: Iterable = None,
|
||||
read_only: bool = False) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
if read_only:
|
||||
return asyncio.get_event_loop().run_in_executor(
|
||||
self.reader_executor, run_read_only_fetchone, sql, parameters
|
||||
)
|
||||
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
|
||||
async def execute_fetchone(self, sql: str, parameters: Iterable = None,
|
||||
read_only: bool = False) -> Iterable[sqlite3.Row]:
|
||||
return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)
|
||||
|
||||
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters))
|
||||
|
||||
def run(self, fun, *args, **kwargs) -> Awaitable:
|
||||
return asyncio.get_event_loop().run_in_executor(
|
||||
self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
|
||||
)
|
||||
async def run(self, fun, *args, **kwargs):
|
||||
self.writers += 1
|
||||
self.read_ready.clear()
|
||||
async with self.write_lock:
|
||||
try:
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
|
||||
)
|
||||
finally:
|
||||
self.writers -= 1
|
||||
if not self.writers:
|
||||
self.read_ready.set()
|
||||
|
||||
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
|
||||
self.writer_connection.execute('begin')
|
||||
|
|
Loading…
Reference in a new issue