Merge pull request #2459 from lbryio/test-sqlite-error-handling

Test sqlite error handling
Jack Robison 2019-09-18 10:16:31 -04:00 committed by GitHub
commit 79727f0e97
6 changed files with 183 additions and 93 deletions
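The apparent root cause, exercised by the new TestSQLiteRace cases below: when a sqlite3 connection is driven through an executor thread, a cursor returned by execute()/executemany() can be garbage-collected later on the event-loop thread, finalizing its statement while the executor thread is already running the next one. The resulting SQLITE_MISUSE from the C layer surfaces as a misleading "Error binding parameter 0 - probably unsupported type." InterfaceError. The mitigation applied throughout this diff is to drain every cursor with fetchall() on the thread that created it. A minimal sketch of the pattern; the helper name is illustrative, not part of this change:

import sqlite3

def execute_fetchall(conn: sqlite3.Connection, sql: str, params=()):
    # fetchall() steps the statement to completion immediately, leaving
    # nothing for the cursor's destructor to finalize on another thread
    return conn.execute(sql, params).fetchall()

conn = sqlite3.connect(":memory:", isolation_level=None)
execute_fetchall(conn, "create table kv (k text primary key, v text)")
execute_fetchall(conn, "insert into kv values (?, ?)", ("a", "1"))
print(execute_fetchall(conn, "select v from kv where k=?", ("a",)))  # [('1',)]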

View file

@@ -101,6 +101,9 @@ def _batched_select(transaction, query, parameters, batch_size=900):
def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
files = []
signed_claims = {}
stream_hashes = tuple(
stream_hash for (stream_hash,) in transaction.execute("select stream_hash from file").fetchall()
)
for (rowid, stream_hash, file_name, download_dir, data_rate, status, saved_file, raw_content_fee, _,
sd_hash, stream_key, stream_name, suggested_file_name, *claim_args) in _batched_select(
transaction, "select file.rowid, file.*, stream.*, c.* "
@@ -108,9 +111,7 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
"inner join content_claim cc on file.stream_hash=cc.stream_hash "
"inner join claim c on cc.claim_outpoint=c.claim_outpoint "
"where file.stream_hash in {} "
"order by c.rowid desc", [
stream_hash for (stream_hash,) in transaction.execute("select stream_hash from file")]):
"order by c.rowid desc", stream_hashes):
claim = StoredStreamClaim(stream_hash, *claim_args)
if claim.channel_claim_id:
if claim.channel_claim_id not in signed_claims:
@@ -137,7 +138,7 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
)
for claim_name, claim_id in _batched_select(
transaction, "select c.claim_name, c.claim_id from claim c where c.claim_id in {}",
list(signed_claims.keys())):
tuple(signed_claims.keys())):
for claim in signed_claims[claim_id]:
claim.channel_name = claim_name
return files
@@ -147,35 +148,35 @@ def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descriptor: 'StreamDescriptor'):
# add all blobs, except the last one, which is empty
transaction.executemany(
"insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
[(blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob]]
)
((blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob])
).fetchall()
# associate the blobs to the stream
transaction.execute("insert or ignore into stream values (?, ?, ?, ?, ?)",
(descriptor.stream_hash, sd_blob.blob_hash, descriptor.key,
binascii.hexlify(descriptor.stream_name.encode()).decode(),
binascii.hexlify(descriptor.suggested_file_name.encode()).decode()))
binascii.hexlify(descriptor.suggested_file_name.encode()).decode())).fetchall()
# add the stream
transaction.executemany(
"insert or ignore into stream_blob values (?, ?, ?, ?)",
[(descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv)
for blob in descriptor.blobs]
)
((descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv)
for blob in descriptor.blobs)
).fetchall()
# ensure should_announce is set regardless if insert was ignored
transaction.execute(
"update blob set should_announce=1 where blob_hash in (?, ?)",
(sd_blob.blob_hash, descriptor.blobs[0].blob_hash,)
)
).fetchall()
def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor'):
blob_hashes = [(blob.blob_hash, ) for blob in descriptor.blobs[:-1]]
blob_hashes.append((descriptor.sd_hash, ))
transaction.execute("delete from content_claim where stream_hash=? ", (descriptor.stream_hash,))
transaction.execute("delete from file where stream_hash=? ", (descriptor.stream_hash,))
transaction.execute("delete from stream_blob where stream_hash=?", (descriptor.stream_hash,))
transaction.execute("delete from stream where stream_hash=? ", (descriptor.stream_hash,))
transaction.executemany("delete from blob where blob_hash=?", blob_hashes)
transaction.execute("delete from content_claim where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
transaction.execute("delete from file where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
transaction.execute("delete from stream_blob where stream_hash=?", (descriptor.stream_hash,)).fetchall()
transaction.execute("delete from stream where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
transaction.executemany("delete from blob where blob_hash=?", blob_hashes).fetchall()
def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typing.Optional[str],
@@ -191,7 +192,7 @@ def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typing.Optional[str],
(stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status,
1 if (file_name and download_directory and os.path.isfile(os.path.join(download_directory, file_name))) else 0,
None if not content_fee else binascii.hexlify(content_fee.raw).decode())
)
).fetchall()
return transaction.execute("select rowid from file where stream_hash=?", (stream_hash, )).fetchone()[0]
@@ -293,17 +294,17 @@ class SQLiteStorage(SQLiteMixin):
def _add_blobs(transaction: sqlite3.Connection):
transaction.executemany(
"insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
[
(
(blob_hash, length, 0, 0, "pending" if not finished else "finished", 0, 0)
for blob_hash, length in blob_hashes_and_lengths
]
)
)
).fetchall()
if finished:
transaction.executemany(
"update blob set status='finished' where blob.blob_hash=?", [
"update blob set status='finished' where blob.blob_hash=?", (
(blob_hash, ) for blob_hash, _ in blob_hashes_and_lengths
]
)
)
).fetchall()
return await self.db.run(_add_blobs)
def get_blob_status(self, blob_hash: str):
@@ -317,9 +318,9 @@ class SQLiteStorage(SQLiteMixin):
return transaction.executemany(
"update blob set next_announce_time=?, last_announced_time=?, single_announce=0 "
"where blob_hash=?",
[(int(last_announced + (data_expiration / 2)), int(last_announced), blob_hash)
for blob_hash in blob_hashes]
)
((int(last_announced + (data_expiration / 2)), int(last_announced), blob_hash)
for blob_hash in blob_hashes)
).fetchall()
return self.db.run(_update_last_announced_blobs)
def should_single_announce_blobs(self, blob_hashes, immediate=False):
@@ -330,11 +331,11 @@ class SQLiteStorage(SQLiteMixin):
transaction.execute(
"update blob set single_announce=1, next_announce_time=? "
"where blob_hash=? and status='finished'", (int(now), blob_hash)
)
).fetchall()
else:
transaction.execute(
"update blob set single_announce=1 where blob_hash=? and status='finished'", (blob_hash,)
)
).fetchall()
return self.db.run(set_single_announce)
def get_blobs_to_announce(self):
@@ -347,22 +348,22 @@ class SQLiteStorage(SQLiteMixin):
"(should_announce=1 or single_announce=1) and next_announce_time<? and status='finished' "
"order by next_announce_time asc limit ?",
(timestamp, int(self.conf.concurrent_blob_announcers * 10))
)
).fetchall()
else:
r = transaction.execute(
"select blob_hash from blob where blob_hash is not null "
"and next_announce_time<? and status='finished' "
"order by next_announce_time asc limit ?",
(timestamp, int(self.conf.concurrent_blob_announcers * 10))
)
return [b[0] for b in r.fetchall()]
).fetchall()
return [b[0] for b in r]
return self.db.run(get_and_update)
def delete_blobs_from_db(self, blob_hashes):
def delete_blobs(transaction):
transaction.executemany(
"delete from blob where blob_hash=?;", [(blob_hash,) for blob_hash in blob_hashes]
)
"delete from blob where blob_hash=?;", ((blob_hash,) for blob_hash in blob_hashes)
).fetchall()
return self.db.run_with_foreign_keys_disabled(delete_blobs)
def get_all_blob_hashes(self):
@@ -370,22 +371,18 @@ class SQLiteStorage(SQLiteMixin):
def sync_missing_blobs(self, blob_files: typing.Set[str]) -> typing.Awaitable[typing.Set[str]]:
def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]:
to_update = [
(blob_hash, )
for (blob_hash, ) in transaction.execute("select blob_hash from blob where status='finished'")
if blob_hash not in blob_files
]
finished_blob_hashes = tuple(
blob_hash for (blob_hash, ) in transaction.execute(
"select blob_hash from blob where status='finished'"
).fetchall()
)
finished_blobs_set = set(finished_blob_hashes)
to_update_set = finished_blobs_set.difference(blob_files)
transaction.executemany(
"update blob set status='pending' where blob_hash=?",
to_update
)
return {
blob_hash
for blob_hash, in _batched_select(
transaction, "select blob_hash from blob where status='finished' and blob_hash in {}",
list(blob_files)
)
}
((blob_hash, ) for blob_hash in to_update_set)
).fetchall()
return blob_files.intersection(finished_blobs_set)
return self.db.run(_sync_blobs)
# # # # # # # # # stream functions # # # # # # # # #
@@ -484,7 +481,7 @@ class SQLiteStorage(SQLiteMixin):
transaction.executemany(
"update file set file_name=null, download_directory=null, saved_file=0 where stream_hash=?",
removed
)
).fetchall()
return await self.db.run(update_manually_removed_files)
def get_all_lbry_files(self) -> typing.Awaitable[typing.List[typing.Dict]]:
@@ -492,7 +489,7 @@ class SQLiteStorage(SQLiteMixin):
def change_file_status(self, stream_hash: str, new_status: str):
log.debug("update file status %s -> %s", stream_hash, new_status)
return self.db.execute("update file set status=? where stream_hash=?", (new_status, stream_hash))
return self.db.execute_fetchall("update file set status=? where stream_hash=?", (new_status, stream_hash))
async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str],
file_name: typing.Optional[str]):
@@ -501,22 +498,22 @@ class SQLiteStorage(SQLiteMixin):
else:
encoded_file_name = binascii.hexlify(file_name.encode()).decode()
encoded_download_dir = binascii.hexlify(download_dir.encode()).decode()
return await self.db.execute("update file set download_directory=?, file_name=? where stream_hash=?", (
return await self.db.execute_fetchall("update file set download_directory=?, file_name=? where stream_hash=?", (
encoded_download_dir, encoded_file_name, stream_hash,
))
async def save_content_fee(self, stream_hash: str, content_fee: Transaction):
return await self.db.execute("update file set content_fee=? where stream_hash=?", (
return await self.db.execute_fetchall("update file set content_fee=? where stream_hash=?", (
binascii.hexlify(content_fee.raw), stream_hash,
))
async def set_saved_file(self, stream_hash: str):
return await self.db.execute("update file set saved_file=1 where stream_hash=?", (
return await self.db.execute_fetchall("update file set saved_file=1 where stream_hash=?", (
stream_hash,
))
async def clear_saved_file(self, stream_hash: str):
return await self.db.execute("update file set saved_file=0 where stream_hash=?", (
return await self.db.execute_fetchall("update file set saved_file=0 where stream_hash=?", (
stream_hash,
))
@@ -537,13 +534,13 @@ class SQLiteStorage(SQLiteMixin):
transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
transaction.executemany(
"update file set status='stopped' where stream_hash=?",
[(stream_hash, ) for stream_hash in stream_hashes]
)
((stream_hash, ) for stream_hash in stream_hashes)
).fetchall()
download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode()
transaction.executemany(
f"update file set download_directory=? where stream_hash=?",
[(download_dir, stream_hash) for stream_hash in stream_hashes]
)
((download_dir, stream_hash) for stream_hash in stream_hashes)
).fetchall()
await self.db.run_with_foreign_keys_disabled(_recover)
def get_all_stream_hashes(self):
@@ -555,14 +552,16 @@ class SQLiteStorage(SQLiteMixin):
# TODO: add 'address' to support items returned for a claim from lbrycrdd and lbryum-server
def _save_support(transaction):
bind = "({})".format(','.join(['?'] * len(claim_id_to_supports)))
transaction.execute(f"delete from support where claim_id in {bind}", list(claim_id_to_supports.keys()))
transaction.execute(
f"delete from support where claim_id in {bind}", tuple(claim_id_to_supports.keys())
).fetchall()
for claim_id, supports in claim_id_to_supports.items():
for support in supports:
transaction.execute(
"insert into support values (?, ?, ?, ?)",
("%s:%i" % (support['txid'], support['nout']), claim_id, lbc_to_dewies(support['amount']),
support.get('address', ""))
)
).fetchall()
return self.db.run(_save_support)
def get_supports(self, *claim_ids):
@@ -581,7 +580,7 @@ class SQLiteStorage(SQLiteMixin):
for support_info in _batched_select(
transaction,
"select * from support where claim_id in {}",
tuple(claim_ids)
claim_ids
)
]
@@ -612,7 +611,7 @@ class SQLiteStorage(SQLiteMixin):
transaction.execute(
"insert or replace into claim values (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(outpoint, claim_id, name, amount, height, serialized, certificate_id, address, sequence)
)
).fetchall()
# if this response doesn't have support info don't overwrite the existing
# support info
if 'supports' in claim_info:
@@ -699,7 +698,9 @@ class SQLiteStorage(SQLiteMixin):
)
# update the claim associated to the file
transaction.execute("insert or replace into content_claim values (?, ?)", (stream_hash, claim_outpoint))
transaction.execute(
"insert or replace into content_claim values (?, ?)", (stream_hash, claim_outpoint)
).fetchall()
async def save_content_claim(self, stream_hash, claim_outpoint):
await self.db.run(self._save_content_claim, claim_outpoint, stream_hash)
@@ -722,11 +723,11 @@ class SQLiteStorage(SQLiteMixin):
def update_reflected_stream(self, sd_hash, reflector_address, success=True):
if success:
return self.db.execute(
return self.db.execute_fetchall(
"insert or replace into reflected_stream values (?, ?, ?)",
(sd_hash, reflector_address, self.time_getter())
)
return self.db.execute(
return self.db.execute_fetchall(
"delete from reflected_stream where sd_hash=? and reflector_address=?",
(sd_hash, reflector_address)
)
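The storage calls above now await db.execute_fetchall(...) in place of db.execute(...). Its definition is not among the hunks shown here; judging from the call sites, it plausibly executes the statement and drains it inside the connection's single executor thread. A sketch under that assumption, with a hypothetical class name:

import asyncio
import sqlite3
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Iterable, List

class AIOSQLiteSketch:
    # illustrative stand-in for AIOSQLite, not the actual implementation
    def __init__(self, connection: sqlite3.Connection):
        self.connection = connection
        self.executor = ThreadPoolExecutor(max_workers=1)  # one thread owns the connection

    async def execute_fetchall(self, sql: str, parameters: Iterable = ()) -> List[tuple]:
        def __fetchall():
            # execute and drain on the same thread, so the statement is
            # finalized before any later statement touches the connection
            return self.connection.execute(sql, parameters).fetchall()
        return await asyncio.get_event_loop().run_in_executor(self.executor, __fetchall)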

View file

@@ -134,11 +134,11 @@ class WalletDatabase(BaseDatabase):
return self.get_utxo_count(**constraints)
async def release_all_outputs(self, account):
await self.db.execute(
await self.db.execute_fetchall(
"UPDATE txo SET is_reserved = 0 WHERE"
" is_reserved = 1 AND txo.address IN ("
" SELECT address from pubkey_address WHERE account = ?"
" )", [account.public_key.address]
" )", (account.public_key.address, )
)
def get_supports_summary(self, account_id):
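One small change above: the parameter container passed in release_all_outputs switches from a list to a one-element tuple. sqlite3 accepts any sequence of positional parameters, so both forms bind identically; the trailing-comma tuple simply matches the convention used across the rest of this diff. For example:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("create table t (a text)").fetchall()
conn.execute("insert into t values (?)", ["x"]).fetchall()   # a list binds fine
conn.execute("insert into t values (?)", ("y",)).fetchall()  # one-element tuple, as used here
print(conn.execute("select a from t order by a").fetchall())  # [('x',), ('y',)]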

View file

@@ -11,14 +11,6 @@ def extract(d, keys):
class AccountManagement(CommandTestCase):
async def test_sqlite_binding_error(self):
tasks = [
self.loop.create_task(self.daemon.jsonrpc_account_create('second account' + str(x))) for x in range(100)
]
await asyncio.wait(tasks)
for result in tasks:
self.assertFalse(isinstance(result.result(), Exception))
async def test_account_list_set_create_remove_add(self):
# check initial account
response = await self.daemon.jsonrpc_account_list()

View file

@@ -1,7 +1,10 @@
import sys
import os
import unittest
import sqlite3
import tempfile
import os
import asyncio
from concurrent.futures.thread import ThreadPoolExecutor
from torba.client.wallet import Wallet
from torba.client.constants import COIN
@@ -431,3 +434,98 @@ class TestUpgrade(AsyncioTestCase):
self.assertEqual(self.get_tables(), ['foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), []) # all tables got reset
await self.ledger.db.close()
class TestSQLiteRace(AsyncioTestCase):
max_misuse_attempts = 40000
def setup_db(self):
self.db = sqlite3.connect(":memory:", isolation_level=None)
self.db.executescript(
"create table test1 (id text primary key not null, val text);\n" +
"create table test2 (id text primary key not null, val text);\n" +
"\n".join(f"insert into test1 values ({v}, NULL);" for v in range(1000))
)
async def asyncSetUp(self):
self.executor = ThreadPoolExecutor(1)
await self.loop.run_in_executor(self.executor, self.setup_db)
async def asyncTearDown(self):
await self.loop.run_in_executor(self.executor, self.db.close)
self.executor.shutdown()
async def test_binding_param_0_error(self):
# test real param 0 binding errors
for supported_type in [str, int, bytes]:
await self.loop.run_in_executor(
self.executor, self.db.executemany, "insert into test2 values (?, NULL)",
[(supported_type(1), ), (supported_type(2), )]
)
await self.loop.run_in_executor(
self.executor, self.db.execute, "delete from test2 where id in (1, 2)"
)
for unsupported_type in [lambda x: (x, ), lambda x: [x], lambda x: {x}]:
try:
await self.loop.run_in_executor(
self.executor, self.db.executemany, "insert into test2 (id, val) values (?, NULL)",
[(unsupported_type(1), ), (unsupported_type(2), )]
)
self.assertTrue(False)
except sqlite3.InterfaceError as err:
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
async def test_unhandled_sqlite_misuse(self):
# test SQLITE_MISUSE being incorrectly raised as a param 0 binding error
attempts = 0
python_version = sys.version.split('\n')[0].rstrip(' ')
try:
while attempts < self.max_misuse_attempts:
f1 = asyncio.wrap_future(
self.loop.run_in_executor(
self.executor, self.db.executemany, "update test1 set val='derp' where id=?",
((str(i),) for i in range(2))
)
)
f2 = asyncio.wrap_future(
self.loop.run_in_executor(
self.executor, self.db.executemany, "update test2 set val='derp' where id=?",
((str(i),) for i in range(2))
)
)
attempts += 1
await asyncio.gather(f1, f2)
print(f"\nsqlite3 {sqlite3.version}/python {python_version} "
f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition")
self.assertTrue(False, 'this test failing means either the sqlite race conditions '
'have been fixed in cpython or the test max_attempts needs to be increased')
except sqlite3.InterfaceError as err:
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE "
f"after {attempts} attempts of the race condition")
@unittest.SkipTest
async def test_fetchall_prevents_sqlite_misuse(self):
# test that calling fetchall sufficiently avoids the race
attempts = 0
def executemany_fetchall(query, params):
self.db.executemany(query, params).fetchall()
while attempts < self.max_misuse_attempts:
f1 = asyncio.wrap_future(
self.loop.run_in_executor(
self.executor, executemany_fetchall, "update test1 set val='derp' where id=?",
((str(i),) for i in range(2))
)
)
f2 = asyncio.wrap_future(
self.loop.run_in_executor(
self.executor, executemany_fetchall, "update test2 set val='derp' where id=?",
((str(i),) for i in range(2))
)
)
attempts += 1
await asyncio.gather(f1, f2)
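The skipped test above encodes the presumed mechanism: each run_in_executor call hands a live cursor back to the event-loop thread, and if nobody consumes it, its statement is finalized later by garbage collection on that thread, concurrently with the next statement running on the executor thread. executemany_fetchall drains the cursor on the thread that created it, so no sqlite call crosses threads. Reduced to its essentials, with illustrative function names:

def unsafe(db, sql, params):
    return db.executemany(sql, params)  # live cursor escapes to the caller's thread

def safe(db, sql, params):
    db.executemany(sql, params).fetchall()  # cursor drained where it was created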

View file

@@ -3,7 +3,7 @@ import asyncio
from binascii import hexlify
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Optional
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
import sqlite3
@@ -19,6 +19,7 @@ class AIOSQLite:
# has to be single threaded as there is no mapping of thread:connection
self.executor = ThreadPoolExecutor(max_workers=1)
self.connection: sqlite3.Connection = None
self._closing = False
@classmethod
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
@@ -29,14 +30,12 @@ class AIOSQLite:
return db
async def close(self):
def __close(conn):
self.executor.submit(conn.close)
self.executor.shutdown(wait=True)
conn = self.connection
if not conn:
if self._closing:
return
self._closing = True
await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close)
self.executor.shutdown(wait=True)
self.connection = None
return asyncio.get_event_loop_policy().get_event_loop().call_later(0.01, __close, conn)
def executemany(self, sql: str, params: Iterable):
params = params if params is not None else []
@@ -87,10 +86,10 @@ class AIOSQLite:
if not foreign_keys_enabled:
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
try:
self.connection.execute('pragma foreign_keys=off')
self.connection.execute('pragma foreign_keys=off').fetchone()
return self.__run_transaction(fun, *args, **kwargs)
finally:
self.connection.execute('pragma foreign_keys=on')
self.connection.execute('pragma foreign_keys=on').fetchone()
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
@@ -160,7 +159,7 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
return joiner.join(sql) if sql else '', values
def query(select, **constraints):
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
sql = [select]
limit = constraints.pop('limit', None)
offset = constraints.pop('offset', None)
@@ -377,10 +376,10 @@ class BaseDatabase(SQLiteMixin):
}
async def insert_transaction(self, tx):
await self.db.execute(*self._insert_sql('tx', self.tx_to_row(tx)))
await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))
async def update_transaction(self, tx):
await self.db.execute(*self._update_sql("tx", {
await self.db.execute_fetchall(*self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
@@ -391,7 +390,7 @@ class BaseDatabase(SQLiteMixin):
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
conn.execute(*self._insert_sql(
"txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
))
)).fetchall()
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
@@ -404,7 +403,7 @@ class BaseDatabase(SQLiteMixin):
'txid': tx.id,
'txoid': txo.id,
'address': address,
}, ignore_duplicate=True))
}, ignore_duplicate=True)).fetchall()
conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
@@ -619,7 +618,7 @@ class BaseDatabase(SQLiteMixin):
)
async def _set_address_history(self, address, history):
await self.db.execute(
await self.db.execute_fetchall(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address)
)
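The pragma toggles above now drain their cursors with fetchone() as well. For context, storage.py earlier in this diff routes blob deletion through run_with_foreign_keys_disabled; a condensed synchronous sketch of that shape (the real method also wraps the callable in a transaction on the executor thread):

import sqlite3

def run_with_foreign_keys_disabled(connection: sqlite3.Connection, fun, *args):
    if not connection.execute("pragma foreign_keys").fetchone()[0]:
        raise sqlite3.IntegrityError("foreign keys are disabled, use `run` instead")
    try:
        connection.execute("pragma foreign_keys=off").fetchone()
        return fun(connection, *args)
    finally:
        connection.execute("pragma foreign_keys=on").fetchone()

conn = sqlite3.connect(":memory:")
conn.execute("pragma foreign_keys=on").fetchone()
run_with_foreign_keys_disabled(conn, lambda c: c.execute("select 1").fetchall())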

View file

@@ -13,6 +13,6 @@ changedir = {toxinidir}/tests
setenv =
integration: TORBA_LEDGER={envname}
commands =
unit: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -t . client_tests.unit
unit: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -vv -t . client_tests.unit
integration: orchstr8 download
integration: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -t . client_tests.integration
integration: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -vv -t . client_tests.integration