forked from LBRYCommunity/lbry-sdk
Merge pull request #2459 from lbryio/test-sqlite-error-handling
Test sqlite error handling
This commit is contained in:
commit
79727f0e97
6 changed files with 183 additions and 93 deletions
|
@ -101,6 +101,9 @@ def _batched_select(transaction, query, parameters, batch_size=900):
|
|||
def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
|
||||
files = []
|
||||
signed_claims = {}
|
||||
stream_hashes = tuple(
|
||||
stream_hash for (stream_hash,) in transaction.execute("select stream_hash from file").fetchall()
|
||||
)
|
||||
for (rowid, stream_hash, file_name, download_dir, data_rate, status, saved_file, raw_content_fee, _,
|
||||
sd_hash, stream_key, stream_name, suggested_file_name, *claim_args) in _batched_select(
|
||||
transaction, "select file.rowid, file.*, stream.*, c.* "
|
||||
|
@ -108,9 +111,7 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Di
|
|||
"inner join content_claim cc on file.stream_hash=cc.stream_hash "
|
||||
"inner join claim c on cc.claim_outpoint=c.claim_outpoint "
|
||||
"where file.stream_hash in {} "
|
||||
"order by c.rowid desc", [
|
||||
stream_hash for (stream_hash,) in transaction.execute("select stream_hash from file")]):
|
||||
|
||||
"order by c.rowid desc", stream_hashes):
|
||||
claim = StoredStreamClaim(stream_hash, *claim_args)
|
||||
if claim.channel_claim_id:
|
||||
if claim.channel_claim_id not in signed_claims:
|
||||
|
@ -137,7 +138,7 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Di
|
|||
)
|
||||
for claim_name, claim_id in _batched_select(
|
||||
transaction, "select c.claim_name, c.claim_id from claim c where c.claim_id in {}",
|
||||
list(signed_claims.keys())):
|
||||
tuple(signed_claims.keys())):
|
||||
for claim in signed_claims[claim_id]:
|
||||
claim.channel_name = claim_name
|
||||
return files
|
||||
|
@ -147,35 +148,35 @@ def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descripto
|
|||
# add all blobs, except the last one, which is empty
|
||||
transaction.executemany(
|
||||
"insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
|
||||
[(blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
|
||||
for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob]]
|
||||
)
|
||||
((blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
|
||||
for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob])
|
||||
).fetchall()
|
||||
# associate the blobs to the stream
|
||||
transaction.execute("insert or ignore into stream values (?, ?, ?, ?, ?)",
|
||||
(descriptor.stream_hash, sd_blob.blob_hash, descriptor.key,
|
||||
binascii.hexlify(descriptor.stream_name.encode()).decode(),
|
||||
binascii.hexlify(descriptor.suggested_file_name.encode()).decode()))
|
||||
binascii.hexlify(descriptor.suggested_file_name.encode()).decode())).fetchall()
|
||||
# add the stream
|
||||
transaction.executemany(
|
||||
"insert or ignore into stream_blob values (?, ?, ?, ?)",
|
||||
[(descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv)
|
||||
for blob in descriptor.blobs]
|
||||
)
|
||||
((descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv)
|
||||
for blob in descriptor.blobs)
|
||||
).fetchall()
|
||||
# ensure should_announce is set regardless if insert was ignored
|
||||
transaction.execute(
|
||||
"update blob set should_announce=1 where blob_hash in (?, ?)",
|
||||
(sd_blob.blob_hash, descriptor.blobs[0].blob_hash,)
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
|
||||
def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor'):
|
||||
blob_hashes = [(blob.blob_hash, ) for blob in descriptor.blobs[:-1]]
|
||||
blob_hashes.append((descriptor.sd_hash, ))
|
||||
transaction.execute("delete from content_claim where stream_hash=? ", (descriptor.stream_hash,))
|
||||
transaction.execute("delete from file where stream_hash=? ", (descriptor.stream_hash,))
|
||||
transaction.execute("delete from stream_blob where stream_hash=?", (descriptor.stream_hash,))
|
||||
transaction.execute("delete from stream where stream_hash=? ", (descriptor.stream_hash,))
|
||||
transaction.executemany("delete from blob where blob_hash=?", blob_hashes)
|
||||
transaction.execute("delete from content_claim where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
|
||||
transaction.execute("delete from file where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
|
||||
transaction.execute("delete from stream_blob where stream_hash=?", (descriptor.stream_hash,)).fetchall()
|
||||
transaction.execute("delete from stream where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
|
||||
transaction.executemany("delete from blob where blob_hash=?", blob_hashes).fetchall()
|
||||
|
||||
|
||||
def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typing.Optional[str],
|
||||
|
@ -191,7 +192,7 @@ def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typ
|
|||
(stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status,
|
||||
1 if (file_name and download_directory and os.path.isfile(os.path.join(download_directory, file_name))) else 0,
|
||||
None if not content_fee else binascii.hexlify(content_fee.raw).decode())
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
return transaction.execute("select rowid from file where stream_hash=?", (stream_hash, )).fetchone()[0]
|
||||
|
||||
|
@ -293,17 +294,17 @@ class SQLiteStorage(SQLiteMixin):
|
|||
def _add_blobs(transaction: sqlite3.Connection):
|
||||
transaction.executemany(
|
||||
"insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
|
||||
[
|
||||
(
|
||||
(blob_hash, length, 0, 0, "pending" if not finished else "finished", 0, 0)
|
||||
for blob_hash, length in blob_hashes_and_lengths
|
||||
]
|
||||
)
|
||||
)
|
||||
).fetchall()
|
||||
if finished:
|
||||
transaction.executemany(
|
||||
"update blob set status='finished' where blob.blob_hash=?", [
|
||||
"update blob set status='finished' where blob.blob_hash=?", (
|
||||
(blob_hash, ) for blob_hash, _ in blob_hashes_and_lengths
|
||||
]
|
||||
)
|
||||
)
|
||||
).fetchall()
|
||||
return await self.db.run(_add_blobs)
|
||||
|
||||
def get_blob_status(self, blob_hash: str):
|
||||
|
@ -317,9 +318,9 @@ class SQLiteStorage(SQLiteMixin):
|
|||
return transaction.executemany(
|
||||
"update blob set next_announce_time=?, last_announced_time=?, single_announce=0 "
|
||||
"where blob_hash=?",
|
||||
[(int(last_announced + (data_expiration / 2)), int(last_announced), blob_hash)
|
||||
for blob_hash in blob_hashes]
|
||||
)
|
||||
((int(last_announced + (data_expiration / 2)), int(last_announced), blob_hash)
|
||||
for blob_hash in blob_hashes)
|
||||
).fetchall()
|
||||
return self.db.run(_update_last_announced_blobs)
|
||||
|
||||
def should_single_announce_blobs(self, blob_hashes, immediate=False):
|
||||
|
@ -330,11 +331,11 @@ class SQLiteStorage(SQLiteMixin):
|
|||
transaction.execute(
|
||||
"update blob set single_announce=1, next_announce_time=? "
|
||||
"where blob_hash=? and status='finished'", (int(now), blob_hash)
|
||||
)
|
||||
).fetchall()
|
||||
else:
|
||||
transaction.execute(
|
||||
"update blob set single_announce=1 where blob_hash=? and status='finished'", (blob_hash,)
|
||||
)
|
||||
).fetchall()
|
||||
return self.db.run(set_single_announce)
|
||||
|
||||
def get_blobs_to_announce(self):
|
||||
|
@ -347,22 +348,22 @@ class SQLiteStorage(SQLiteMixin):
|
|||
"(should_announce=1 or single_announce=1) and next_announce_time<? and status='finished' "
|
||||
"order by next_announce_time asc limit ?",
|
||||
(timestamp, int(self.conf.concurrent_blob_announcers * 10))
|
||||
)
|
||||
).fetchall()
|
||||
else:
|
||||
r = transaction.execute(
|
||||
"select blob_hash from blob where blob_hash is not null "
|
||||
"and next_announce_time<? and status='finished' "
|
||||
"order by next_announce_time asc limit ?",
|
||||
(timestamp, int(self.conf.concurrent_blob_announcers * 10))
|
||||
)
|
||||
return [b[0] for b in r.fetchall()]
|
||||
).fetchall()
|
||||
return [b[0] for b in r]
|
||||
return self.db.run(get_and_update)
|
||||
|
||||
def delete_blobs_from_db(self, blob_hashes):
|
||||
def delete_blobs(transaction):
|
||||
transaction.executemany(
|
||||
"delete from blob where blob_hash=?;", [(blob_hash,) for blob_hash in blob_hashes]
|
||||
)
|
||||
"delete from blob where blob_hash=?;", ((blob_hash,) for blob_hash in blob_hashes)
|
||||
).fetchall()
|
||||
return self.db.run_with_foreign_keys_disabled(delete_blobs)
|
||||
|
||||
def get_all_blob_hashes(self):
|
||||
|
@ -370,22 +371,18 @@ class SQLiteStorage(SQLiteMixin):
|
|||
|
||||
def sync_missing_blobs(self, blob_files: typing.Set[str]) -> typing.Awaitable[typing.Set[str]]:
|
||||
def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]:
|
||||
to_update = [
|
||||
(blob_hash, )
|
||||
for (blob_hash, ) in transaction.execute("select blob_hash from blob where status='finished'")
|
||||
if blob_hash not in blob_files
|
||||
]
|
||||
finished_blob_hashes = tuple(
|
||||
blob_hash for (blob_hash, ) in transaction.execute(
|
||||
"select blob_hash from blob where status='finished'"
|
||||
).fetchall()
|
||||
)
|
||||
finished_blobs_set = set(finished_blob_hashes)
|
||||
to_update_set = finished_blobs_set.difference(blob_files)
|
||||
transaction.executemany(
|
||||
"update blob set status='pending' where blob_hash=?",
|
||||
to_update
|
||||
)
|
||||
return {
|
||||
blob_hash
|
||||
for blob_hash, in _batched_select(
|
||||
transaction, "select blob_hash from blob where status='finished' and blob_hash in {}",
|
||||
list(blob_files)
|
||||
)
|
||||
}
|
||||
((blob_hash, ) for blob_hash in to_update_set)
|
||||
).fetchall()
|
||||
return blob_files.intersection(finished_blobs_set)
|
||||
return self.db.run(_sync_blobs)
|
||||
|
||||
# # # # # # # # # stream functions # # # # # # # # #
|
||||
|
@ -484,7 +481,7 @@ class SQLiteStorage(SQLiteMixin):
|
|||
transaction.executemany(
|
||||
"update file set file_name=null, download_directory=null, saved_file=0 where stream_hash=?",
|
||||
removed
|
||||
)
|
||||
).fetchall()
|
||||
return await self.db.run(update_manually_removed_files)
|
||||
|
||||
def get_all_lbry_files(self) -> typing.Awaitable[typing.List[typing.Dict]]:
|
||||
|
@ -492,7 +489,7 @@ class SQLiteStorage(SQLiteMixin):
|
|||
|
||||
def change_file_status(self, stream_hash: str, new_status: str):
|
||||
log.debug("update file status %s -> %s", stream_hash, new_status)
|
||||
return self.db.execute("update file set status=? where stream_hash=?", (new_status, stream_hash))
|
||||
return self.db.execute_fetchall("update file set status=? where stream_hash=?", (new_status, stream_hash))
|
||||
|
||||
async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str],
|
||||
file_name: typing.Optional[str]):
|
||||
|
@ -501,22 +498,22 @@ class SQLiteStorage(SQLiteMixin):
|
|||
else:
|
||||
encoded_file_name = binascii.hexlify(file_name.encode()).decode()
|
||||
encoded_download_dir = binascii.hexlify(download_dir.encode()).decode()
|
||||
return await self.db.execute("update file set download_directory=?, file_name=? where stream_hash=?", (
|
||||
return await self.db.execute_fetchall("update file set download_directory=?, file_name=? where stream_hash=?", (
|
||||
encoded_download_dir, encoded_file_name, stream_hash,
|
||||
))
|
||||
|
||||
async def save_content_fee(self, stream_hash: str, content_fee: Transaction):
|
||||
return await self.db.execute("update file set content_fee=? where stream_hash=?", (
|
||||
return await self.db.execute_fetchall("update file set content_fee=? where stream_hash=?", (
|
||||
binascii.hexlify(content_fee.raw), stream_hash,
|
||||
))
|
||||
|
||||
async def set_saved_file(self, stream_hash: str):
|
||||
return await self.db.execute("update file set saved_file=1 where stream_hash=?", (
|
||||
return await self.db.execute_fetchall("update file set saved_file=1 where stream_hash=?", (
|
||||
stream_hash,
|
||||
))
|
||||
|
||||
async def clear_saved_file(self, stream_hash: str):
|
||||
return await self.db.execute("update file set saved_file=0 where stream_hash=?", (
|
||||
return await self.db.execute_fetchall("update file set saved_file=0 where stream_hash=?", (
|
||||
stream_hash,
|
||||
))
|
||||
|
||||
|
@ -537,13 +534,13 @@ class SQLiteStorage(SQLiteMixin):
|
|||
transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
|
||||
transaction.executemany(
|
||||
"update file set status='stopped' where stream_hash=?",
|
||||
[(stream_hash, ) for stream_hash in stream_hashes]
|
||||
)
|
||||
((stream_hash, ) for stream_hash in stream_hashes)
|
||||
).fetchall()
|
||||
download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode()
|
||||
transaction.executemany(
|
||||
f"update file set download_directory=? where stream_hash=?",
|
||||
[(download_dir, stream_hash) for stream_hash in stream_hashes]
|
||||
)
|
||||
((download_dir, stream_hash) for stream_hash in stream_hashes)
|
||||
).fetchall()
|
||||
await self.db.run_with_foreign_keys_disabled(_recover)
|
||||
|
||||
def get_all_stream_hashes(self):
|
||||
|
@ -555,14 +552,16 @@ class SQLiteStorage(SQLiteMixin):
|
|||
# TODO: add 'address' to support items returned for a claim from lbrycrdd and lbryum-server
|
||||
def _save_support(transaction):
|
||||
bind = "({})".format(','.join(['?'] * len(claim_id_to_supports)))
|
||||
transaction.execute(f"delete from support where claim_id in {bind}", list(claim_id_to_supports.keys()))
|
||||
transaction.execute(
|
||||
f"delete from support where claim_id in {bind}", tuple(claim_id_to_supports.keys())
|
||||
).fetchall()
|
||||
for claim_id, supports in claim_id_to_supports.items():
|
||||
for support in supports:
|
||||
transaction.execute(
|
||||
"insert into support values (?, ?, ?, ?)",
|
||||
("%s:%i" % (support['txid'], support['nout']), claim_id, lbc_to_dewies(support['amount']),
|
||||
support.get('address', ""))
|
||||
)
|
||||
).fetchall()
|
||||
return self.db.run(_save_support)
|
||||
|
||||
def get_supports(self, *claim_ids):
|
||||
|
@ -581,7 +580,7 @@ class SQLiteStorage(SQLiteMixin):
|
|||
for support_info in _batched_select(
|
||||
transaction,
|
||||
"select * from support where claim_id in {}",
|
||||
tuple(claim_ids)
|
||||
claim_ids
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -612,7 +611,7 @@ class SQLiteStorage(SQLiteMixin):
|
|||
transaction.execute(
|
||||
"insert or replace into claim values (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(outpoint, claim_id, name, amount, height, serialized, certificate_id, address, sequence)
|
||||
)
|
||||
).fetchall()
|
||||
# if this response doesn't have support info don't overwrite the existing
|
||||
# support info
|
||||
if 'supports' in claim_info:
|
||||
|
@ -699,7 +698,9 @@ class SQLiteStorage(SQLiteMixin):
|
|||
)
|
||||
|
||||
# update the claim associated to the file
|
||||
transaction.execute("insert or replace into content_claim values (?, ?)", (stream_hash, claim_outpoint))
|
||||
transaction.execute(
|
||||
"insert or replace into content_claim values (?, ?)", (stream_hash, claim_outpoint)
|
||||
).fetchall()
|
||||
|
||||
async def save_content_claim(self, stream_hash, claim_outpoint):
|
||||
await self.db.run(self._save_content_claim, claim_outpoint, stream_hash)
|
||||
|
@ -722,11 +723,11 @@ class SQLiteStorage(SQLiteMixin):
|
|||
|
||||
def update_reflected_stream(self, sd_hash, reflector_address, success=True):
|
||||
if success:
|
||||
return self.db.execute(
|
||||
return self.db.execute_fetchall(
|
||||
"insert or replace into reflected_stream values (?, ?, ?)",
|
||||
(sd_hash, reflector_address, self.time_getter())
|
||||
)
|
||||
return self.db.execute(
|
||||
return self.db.execute_fetchall(
|
||||
"delete from reflected_stream where sd_hash=? and reflector_address=?",
|
||||
(sd_hash, reflector_address)
|
||||
)
|
||||
|
|
|
@ -134,11 +134,11 @@ class WalletDatabase(BaseDatabase):
|
|||
return self.get_utxo_count(**constraints)
|
||||
|
||||
async def release_all_outputs(self, account):
|
||||
await self.db.execute(
|
||||
await self.db.execute_fetchall(
|
||||
"UPDATE txo SET is_reserved = 0 WHERE"
|
||||
" is_reserved = 1 AND txo.address IN ("
|
||||
" SELECT address from pubkey_address WHERE account = ?"
|
||||
" )", [account.public_key.address]
|
||||
" )", (account.public_key.address, )
|
||||
)
|
||||
|
||||
def get_supports_summary(self, account_id):
|
||||
|
|
|
@ -11,14 +11,6 @@ def extract(d, keys):
|
|||
|
||||
|
||||
class AccountManagement(CommandTestCase):
|
||||
async def test_sqlite_binding_error(self):
|
||||
tasks = [
|
||||
self.loop.create_task(self.daemon.jsonrpc_account_create('second account' + str(x))) for x in range(100)
|
||||
]
|
||||
await asyncio.wait(tasks)
|
||||
for result in tasks:
|
||||
self.assertFalse(isinstance(result.result(), Exception))
|
||||
|
||||
async def test_account_list_set_create_remove_add(self):
|
||||
# check initial account
|
||||
response = await self.daemon.jsonrpc_account_list()
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import sys
|
||||
import os
|
||||
import unittest
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import os
|
||||
import asyncio
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
from torba.client.wallet import Wallet
|
||||
from torba.client.constants import COIN
|
||||
|
@ -431,3 +434,98 @@ class TestUpgrade(AsyncioTestCase):
|
|||
self.assertEqual(self.get_tables(), ['foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_addresses(), []) # all tables got reset
|
||||
await self.ledger.db.close()
|
||||
|
||||
|
||||
class TestSQLiteRace(AsyncioTestCase):
|
||||
max_misuse_attempts = 40000
|
||||
|
||||
def setup_db(self):
|
||||
self.db = sqlite3.connect(":memory:", isolation_level=None)
|
||||
self.db.executescript(
|
||||
"create table test1 (id text primary key not null, val text);\n" +
|
||||
"create table test2 (id text primary key not null, val text);\n" +
|
||||
"\n".join(f"insert into test1 values ({v}, NULL);" for v in range(1000))
|
||||
)
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.executor = ThreadPoolExecutor(1)
|
||||
await self.loop.run_in_executor(self.executor, self.setup_db)
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.loop.run_in_executor(self.executor, self.db.close)
|
||||
self.executor.shutdown()
|
||||
|
||||
async def test_binding_param_0_error(self):
|
||||
# test real param 0 binding errors
|
||||
|
||||
for supported_type in [str, int, bytes]:
|
||||
await self.loop.run_in_executor(
|
||||
self.executor, self.db.executemany, "insert into test2 values (?, NULL)",
|
||||
[(supported_type(1), ), (supported_type(2), )]
|
||||
)
|
||||
await self.loop.run_in_executor(
|
||||
self.executor, self.db.execute, "delete from test2 where id in (1, 2)"
|
||||
)
|
||||
for unsupported_type in [lambda x: (x, ), lambda x: [x], lambda x: {x}]:
|
||||
try:
|
||||
await self.loop.run_in_executor(
|
||||
self.executor, self.db.executemany, "insert into test2 (id, val) values (?, NULL)",
|
||||
[(unsupported_type(1), ), (unsupported_type(2), )]
|
||||
)
|
||||
self.assertTrue(False)
|
||||
except sqlite3.InterfaceError as err:
|
||||
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
|
||||
|
||||
async def test_unhandled_sqlite_misuse(self):
|
||||
# test SQLITE_MISUSE being incorrectly raised as a param 0 binding error
|
||||
attempts = 0
|
||||
python_version = sys.version.split('\n')[0].rstrip(' ')
|
||||
|
||||
try:
|
||||
while attempts < self.max_misuse_attempts:
|
||||
f1 = asyncio.wrap_future(
|
||||
self.loop.run_in_executor(
|
||||
self.executor, self.db.executemany, "update test1 set val='derp' where id=?",
|
||||
((str(i),) for i in range(2))
|
||||
)
|
||||
)
|
||||
f2 = asyncio.wrap_future(
|
||||
self.loop.run_in_executor(
|
||||
self.executor, self.db.executemany, "update test2 set val='derp' where id=?",
|
||||
((str(i),) for i in range(2))
|
||||
)
|
||||
)
|
||||
attempts += 1
|
||||
await asyncio.gather(f1, f2)
|
||||
print(f"\nsqlite3 {sqlite3.version}/python {python_version} "
|
||||
f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition")
|
||||
self.assertTrue(False, 'this test failing means either the sqlite race conditions '
|
||||
'have been fixed in cpython or the test max_attempts needs to be increased')
|
||||
except sqlite3.InterfaceError as err:
|
||||
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
|
||||
print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE "
|
||||
f"after {attempts} attempts of the race condition")
|
||||
|
||||
@unittest.SkipTest
|
||||
async def test_fetchall_prevents_sqlite_misuse(self):
|
||||
# test that calling fetchall sufficiently avoids the race
|
||||
attempts = 0
|
||||
|
||||
def executemany_fetchall(query, params):
|
||||
self.db.executemany(query, params).fetchall()
|
||||
|
||||
while attempts < self.max_misuse_attempts:
|
||||
f1 = asyncio.wrap_future(
|
||||
self.loop.run_in_executor(
|
||||
self.executor, executemany_fetchall, "update test1 set val='derp' where id=?",
|
||||
((str(i),) for i in range(2))
|
||||
)
|
||||
)
|
||||
f2 = asyncio.wrap_future(
|
||||
self.loop.run_in_executor(
|
||||
self.executor, executemany_fetchall, "update test2 set val='derp' where id=?",
|
||||
((str(i),) for i in range(2))
|
||||
)
|
||||
)
|
||||
attempts += 1
|
||||
await asyncio.gather(f1, f2)
|
|
@ -3,7 +3,7 @@ import asyncio
|
|||
from binascii import hexlify
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Optional
|
||||
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
|
||||
|
||||
import sqlite3
|
||||
|
||||
|
@ -19,6 +19,7 @@ class AIOSQLite:
|
|||
# has to be single threaded as there is no mapping of thread:connection
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.connection: sqlite3.Connection = None
|
||||
self._closing = False
|
||||
|
||||
@classmethod
|
||||
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
|
||||
|
@ -29,14 +30,12 @@ class AIOSQLite:
|
|||
return db
|
||||
|
||||
async def close(self):
|
||||
def __close(conn):
|
||||
self.executor.submit(conn.close)
|
||||
self.executor.shutdown(wait=True)
|
||||
conn = self.connection
|
||||
if not conn:
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close)
|
||||
self.executor.shutdown(wait=True)
|
||||
self.connection = None
|
||||
return asyncio.get_event_loop_policy().get_event_loop().call_later(0.01, __close, conn)
|
||||
|
||||
def executemany(self, sql: str, params: Iterable):
|
||||
params = params if params is not None else []
|
||||
|
@ -87,10 +86,10 @@ class AIOSQLite:
|
|||
if not foreign_keys_enabled:
|
||||
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
|
||||
try:
|
||||
self.connection.execute('pragma foreign_keys=off')
|
||||
self.connection.execute('pragma foreign_keys=off').fetchone()
|
||||
return self.__run_transaction(fun, *args, **kwargs)
|
||||
finally:
|
||||
self.connection.execute('pragma foreign_keys=on')
|
||||
self.connection.execute('pragma foreign_keys=on').fetchone()
|
||||
|
||||
|
||||
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
||||
|
@ -160,7 +159,7 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
|||
return joiner.join(sql) if sql else '', values
|
||||
|
||||
|
||||
def query(select, **constraints):
|
||||
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
|
||||
sql = [select]
|
||||
limit = constraints.pop('limit', None)
|
||||
offset = constraints.pop('offset', None)
|
||||
|
@ -377,10 +376,10 @@ class BaseDatabase(SQLiteMixin):
|
|||
}
|
||||
|
||||
async def insert_transaction(self, tx):
|
||||
await self.db.execute(*self._insert_sql('tx', self.tx_to_row(tx)))
|
||||
await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))
|
||||
|
||||
async def update_transaction(self, tx):
|
||||
await self.db.execute(*self._update_sql("tx", {
|
||||
await self.db.execute_fetchall(*self._update_sql("tx", {
|
||||
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
||||
}, 'txid = ?', (tx.id,)))
|
||||
|
||||
|
@ -391,7 +390,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
|
||||
conn.execute(*self._insert_sql(
|
||||
"txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
|
||||
))
|
||||
)).fetchall()
|
||||
elif txo.script.is_pay_script_hash:
|
||||
# TODO: implement script hash payments
|
||||
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
|
||||
|
@ -404,7 +403,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
'txid': tx.id,
|
||||
'txoid': txo.id,
|
||||
'address': address,
|
||||
}, ignore_duplicate=True))
|
||||
}, ignore_duplicate=True)).fetchall()
|
||||
|
||||
conn.execute(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
|
@ -619,7 +618,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
)
|
||||
|
||||
async def _set_address_history(self, address, history):
|
||||
await self.db.execute(
|
||||
await self.db.execute_fetchall(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':')//2, address)
|
||||
)
|
||||
|
|
|
@ -13,6 +13,6 @@ changedir = {toxinidir}/tests
|
|||
setenv =
|
||||
integration: TORBA_LEDGER={envname}
|
||||
commands =
|
||||
unit: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -t . client_tests.unit
|
||||
unit: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -vv -t . client_tests.unit
|
||||
integration: orchstr8 download
|
||||
integration: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -t . client_tests.integration
|
||||
integration: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -vv -t . client_tests.integration
|
||||
|
|
Loading…
Reference in a new issue