sql __in uses proper escaping instead of putting values directly in SQL
This commit is contained in:
parent
397ebe8428
commit
cfb051396a
4 changed files with 56 additions and 29 deletions
|
@ -76,11 +76,17 @@ class TestQueryBuilder(unittest.TestCase):
|
|||
def test_in(self):
|
||||
self.assertEqual(
|
||||
constraints_to_sql({'txo.age__in': [18, 38]}),
|
||||
('txo.age IN (18, 38)', {})
|
||||
('txo.age IN (:txo_age__in0, :txo_age__in1)', {
|
||||
'txo_age__in0': 18,
|
||||
'txo_age__in1': 38
|
||||
})
|
||||
)
|
||||
self.assertEqual(
|
||||
constraints_to_sql({'txo.age__in': ['abc123', 'def456']}),
|
||||
("txo.age IN ('abc123', 'def456')", {})
|
||||
constraints_to_sql({'txo.name__in': ('abc123', 'def456')}),
|
||||
('txo.name IN (:txo_name__in0, :txo_name__in1)', {
|
||||
'txo_name__in0': 'abc123',
|
||||
'txo_name__in1': 'def456'
|
||||
})
|
||||
)
|
||||
self.assertEqual(
|
||||
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
|
||||
|
@ -90,7 +96,17 @@ class TestQueryBuilder(unittest.TestCase):
|
|||
def test_not_in(self):
|
||||
self.assertEqual(
|
||||
constraints_to_sql({'txo.age__not_in': [18, 38]}),
|
||||
('txo.age NOT IN (18, 38)', {})
|
||||
('txo.age NOT IN (:txo_age__not_in0, :txo_age__not_in1)', {
|
||||
'txo_age__not_in0': 18,
|
||||
'txo_age__not_in1': 38
|
||||
})
|
||||
)
|
||||
self.assertEqual(
|
||||
constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}),
|
||||
('txo.name NOT IN (:txo_name__not_in0, :txo_name__not_in1)', {
|
||||
'txo_name__not_in0': 'abc123',
|
||||
'txo_name__not_in1': 'def456'
|
||||
})
|
||||
)
|
||||
self.assertEqual(
|
||||
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),
|
||||
|
|
|
@ -93,6 +93,14 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
|||
continue
|
||||
elif key.endswith('__not'):
|
||||
col, op = col[:-len('__not')], '!='
|
||||
elif key.endswith('__is_null'):
|
||||
col = col[:-len('__is_null')]
|
||||
sql.append(f'{col} IS NULL')
|
||||
continue
|
||||
elif key.endswith('__is_not_null'):
|
||||
col = col[:-len('__is_not_null')]
|
||||
sql.append(f'{col} IS NOT NULL')
|
||||
continue
|
||||
elif key.endswith('__lt'):
|
||||
col, op = col[:-len('__lt')], '<'
|
||||
elif key.endswith('__lte'):
|
||||
|
@ -108,23 +116,24 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
|||
col, op = col[:-len('__in')], 'IN'
|
||||
else:
|
||||
col, op = col[:-len('__not_in')], 'NOT IN'
|
||||
if isinstance(constraint, (list, set)):
|
||||
items = ', '.join(
|
||||
"'{}'".format(item) if isinstance(item, str) else str(item)
|
||||
for item in constraint
|
||||
)
|
||||
elif isinstance(constraint, str):
|
||||
items = constraint
|
||||
else:
|
||||
raise ValueError("{} requires a list, set or string as constraint value.".format(col))
|
||||
sql.append('{} {} ({})'.format(col, op, items))
|
||||
if constraint:
|
||||
if isinstance(constraint, (list, set, tuple)):
|
||||
keys = []
|
||||
for i, val in enumerate(constraint):
|
||||
keys.append(f':{key}{i}')
|
||||
values[f'{key}{i}'] = val
|
||||
sql.append(f'{col} {op} ({", ".join(keys)})')
|
||||
elif isinstance(constraint, str):
|
||||
sql.append(f'{col} {op} ({constraint})')
|
||||
else:
|
||||
raise ValueError(f"{col} requires a list, set or string as constraint value.")
|
||||
continue
|
||||
elif key.endswith('__any'):
|
||||
where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_')
|
||||
sql.append('({})'.format(where))
|
||||
sql.append(f'({where})')
|
||||
values.update(subvalues)
|
||||
continue
|
||||
sql.append('{} {} :{}'.format(col, op, prepend_key+key))
|
||||
sql.append(f'{col} {op} :{prepend_key}{key}')
|
||||
values[prepend_key+key] = constraint
|
||||
return joiner.join(sql) if sql else '', values
|
||||
|
||||
|
@ -382,12 +391,14 @@ class BaseDatabase(SQLiteMixin):
|
|||
if not tx_rows:
|
||||
return []
|
||||
|
||||
txids, txs = [], []
|
||||
txids, txs, txi_txoids = [], [], []
|
||||
for row in tx_rows:
|
||||
txids.append(row[0])
|
||||
txs.append(self.ledger.transaction_class(
|
||||
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
|
||||
))
|
||||
for txi in txs[-1].inputs:
|
||||
txi_txoids.append(txi.txo_ref.id)
|
||||
|
||||
annotated_txos = {
|
||||
txo.id: txo for txo in
|
||||
|
@ -401,7 +412,7 @@ class BaseDatabase(SQLiteMixin):
|
|||
txo.id: txo for txo in
|
||||
(await self.get_txos(
|
||||
my_account=my_account,
|
||||
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0]
|
||||
txoid__in=txi_txoids
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
@ -114,32 +114,32 @@ class BaseNetwork:
|
|||
def is_connected(self):
|
||||
return self.client is not None and not self.client.is_closing()
|
||||
|
||||
def rpc(self, list_or_method, *args):
|
||||
def rpc(self, list_or_method, args):
|
||||
if self.is_connected:
|
||||
return self.client.send_request(list_or_method, args)
|
||||
else:
|
||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||
|
||||
def ensure_server_version(self, required='1.2'):
|
||||
return self.rpc('server.version', __version__, required)
|
||||
return self.rpc('server.version', [__version__, required])
|
||||
|
||||
def broadcast(self, raw_transaction):
|
||||
return self.rpc('blockchain.transaction.broadcast', raw_transaction)
|
||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
||||
|
||||
def get_history(self, address):
|
||||
return self.rpc('blockchain.address.get_history', address)
|
||||
return self.rpc('blockchain.address.get_history', [address])
|
||||
|
||||
def get_transaction(self, tx_hash):
|
||||
return self.rpc('blockchain.transaction.get', tx_hash)
|
||||
return self.rpc('blockchain.transaction.get', [tx_hash])
|
||||
|
||||
def get_merkle(self, tx_hash, height):
|
||||
return self.rpc('blockchain.transaction.get_merkle', tx_hash, height)
|
||||
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
|
||||
|
||||
def get_headers(self, height, count=10000):
|
||||
return self.rpc('blockchain.block.headers', height, count)
|
||||
return self.rpc('blockchain.block.headers', [height, count])
|
||||
|
||||
def subscribe_headers(self):
|
||||
return self.rpc('blockchain.headers.subscribe', True)
|
||||
return self.rpc('blockchain.headers.subscribe', [True])
|
||||
|
||||
def subscribe_address(self, address):
|
||||
return self.rpc('blockchain.address.subscribe', address)
|
||||
return self.rpc('blockchain.address.subscribe', [address])
|
||||
|
|
|
@ -2,8 +2,8 @@ import logging
|
|||
import traceback
|
||||
import argparse
|
||||
import importlib
|
||||
from .env import Env
|
||||
from .server import Server
|
||||
from torba.server.env import Env
|
||||
from torba.server.server import Server
|
||||
|
||||
|
||||
def get_argument_parser():
|
||||
|
|
Loading…
Reference in a new issue