write lock

This commit is contained in:
Jack Robison 2020-02-25 14:15:27 -05:00
parent a26cfc639c
commit 61603ccfce
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2

View file

@ -60,6 +60,9 @@ class AIOSQLite:
self.writer_connection: Optional[sqlite3.Connection] = None
self._closing = False
self.query_count = 0
self.write_lock = asyncio.Lock()
self.writers = 0
self.read_ready = asyncio.Event()
@classmethod
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
@ -74,6 +77,7 @@ class AIOSQLite:
max_workers=readers, initializer=initializer, initargs=(path, )
)
await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer)
db.read_ready.set()
return db
async def close(self):
@ -83,6 +87,7 @@ class AIOSQLite:
await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
self.writer_executor.shutdown(wait=True)
self.reader_executor.shutdown(wait=True)
self.read_ready.clear()
self.writer_connection = None
def executemany(self, sql: str, params: Iterable):
@ -93,32 +98,44 @@ class AIOSQLite:
def executescript(self, script: str) -> Awaitable:
return self.run(lambda conn: conn.executescript(script))
async def _execute_fetch(self, sql: str, parameters: Iterable = None,
read_only: bool = False, fetch_all: bool = False) -> Iterable[sqlite3.Row]:
read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone
parameters = parameters if parameters is not None else []
if read_only:
while self.writers:
await self.read_ready.wait()
return await asyncio.get_event_loop().run_in_executor(
self.reader_executor, read_only_fn, sql, parameters
)
if fetch_all:
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
return await self.run(lambda conn: conn.execute(sql, parameters).fetchone())
async def execute_fetchall(self, sql: str, parameters: Iterable = None,
read_only: bool = False) -> Iterable[sqlite3.Row]:
parameters = parameters if parameters is not None else []
if read_only:
return await asyncio.get_event_loop().run_in_executor(
self.reader_executor, run_read_only_fetchall, sql, parameters
)
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
return await self._execute_fetch(sql, parameters, read_only, fetch_all=True)
def execute_fetchone(self, sql: str, parameters: Iterable = None,
read_only: bool = False) -> Awaitable[Iterable[sqlite3.Row]]:
parameters = parameters if parameters is not None else []
if read_only:
return asyncio.get_event_loop().run_in_executor(
self.reader_executor, run_read_only_fetchone, sql, parameters
)
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
async def execute_fetchone(self, sql: str, parameters: Iterable = None,
read_only: bool = False) -> Iterable[sqlite3.Row]:
return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters))
def run(self, fun, *args, **kwargs) -> Awaitable:
return asyncio.get_event_loop().run_in_executor(
self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
)
async def run(self, fun, *args, **kwargs):
self.writers += 1
self.read_ready.clear()
async with self.write_lock:
try:
return await asyncio.get_event_loop().run_in_executor(
self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
)
finally:
self.writers -= 1
if not self.writers:
self.read_ready.set()
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
self.writer_connection.execute('begin')