sql __in uses proper escaping instead of putting values directly in SQL
This commit is contained in:
parent
397ebe8428
commit
cfb051396a
4 changed files with 56 additions and 29 deletions
|
@ -76,11 +76,17 @@ class TestQueryBuilder(unittest.TestCase):
|
||||||
def test_in(self):
|
def test_in(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
constraints_to_sql({'txo.age__in': [18, 38]}),
|
constraints_to_sql({'txo.age__in': [18, 38]}),
|
||||||
('txo.age IN (18, 38)', {})
|
('txo.age IN (:txo_age__in0, :txo_age__in1)', {
|
||||||
|
'txo_age__in0': 18,
|
||||||
|
'txo_age__in1': 38
|
||||||
|
})
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
constraints_to_sql({'txo.age__in': ['abc123', 'def456']}),
|
constraints_to_sql({'txo.name__in': ('abc123', 'def456')}),
|
||||||
("txo.age IN ('abc123', 'def456')", {})
|
('txo.name IN (:txo_name__in0, :txo_name__in1)', {
|
||||||
|
'txo_name__in0': 'abc123',
|
||||||
|
'txo_name__in1': 'def456'
|
||||||
|
})
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
|
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
|
||||||
|
@ -90,7 +96,17 @@ class TestQueryBuilder(unittest.TestCase):
|
||||||
def test_not_in(self):
|
def test_not_in(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
constraints_to_sql({'txo.age__not_in': [18, 38]}),
|
constraints_to_sql({'txo.age__not_in': [18, 38]}),
|
||||||
('txo.age NOT IN (18, 38)', {})
|
('txo.age NOT IN (:txo_age__not_in0, :txo_age__not_in1)', {
|
||||||
|
'txo_age__not_in0': 18,
|
||||||
|
'txo_age__not_in1': 38
|
||||||
|
})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}),
|
||||||
|
('txo.name NOT IN (:txo_name__not_in0, :txo_name__not_in1)', {
|
||||||
|
'txo_name__not_in0': 'abc123',
|
||||||
|
'txo_name__not_in1': 'def456'
|
||||||
|
})
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),
|
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),
|
||||||
|
|
|
@ -93,6 +93,14 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
||||||
continue
|
continue
|
||||||
elif key.endswith('__not'):
|
elif key.endswith('__not'):
|
||||||
col, op = col[:-len('__not')], '!='
|
col, op = col[:-len('__not')], '!='
|
||||||
|
elif key.endswith('__is_null'):
|
||||||
|
col = col[:-len('__is_null')]
|
||||||
|
sql.append(f'{col} IS NULL')
|
||||||
|
continue
|
||||||
|
elif key.endswith('__is_not_null'):
|
||||||
|
col = col[:-len('__is_not_null')]
|
||||||
|
sql.append(f'{col} IS NOT NULL')
|
||||||
|
continue
|
||||||
elif key.endswith('__lt'):
|
elif key.endswith('__lt'):
|
||||||
col, op = col[:-len('__lt')], '<'
|
col, op = col[:-len('__lt')], '<'
|
||||||
elif key.endswith('__lte'):
|
elif key.endswith('__lte'):
|
||||||
|
@ -108,23 +116,24 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
||||||
col, op = col[:-len('__in')], 'IN'
|
col, op = col[:-len('__in')], 'IN'
|
||||||
else:
|
else:
|
||||||
col, op = col[:-len('__not_in')], 'NOT IN'
|
col, op = col[:-len('__not_in')], 'NOT IN'
|
||||||
if isinstance(constraint, (list, set)):
|
if constraint:
|
||||||
items = ', '.join(
|
if isinstance(constraint, (list, set, tuple)):
|
||||||
"'{}'".format(item) if isinstance(item, str) else str(item)
|
keys = []
|
||||||
for item in constraint
|
for i, val in enumerate(constraint):
|
||||||
)
|
keys.append(f':{key}{i}')
|
||||||
elif isinstance(constraint, str):
|
values[f'{key}{i}'] = val
|
||||||
items = constraint
|
sql.append(f'{col} {op} ({", ".join(keys)})')
|
||||||
else:
|
elif isinstance(constraint, str):
|
||||||
raise ValueError("{} requires a list, set or string as constraint value.".format(col))
|
sql.append(f'{col} {op} ({constraint})')
|
||||||
sql.append('{} {} ({})'.format(col, op, items))
|
else:
|
||||||
|
raise ValueError(f"{col} requires a list, set or string as constraint value.")
|
||||||
continue
|
continue
|
||||||
elif key.endswith('__any'):
|
elif key.endswith('__any'):
|
||||||
where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_')
|
where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_')
|
||||||
sql.append('({})'.format(where))
|
sql.append(f'({where})')
|
||||||
values.update(subvalues)
|
values.update(subvalues)
|
||||||
continue
|
continue
|
||||||
sql.append('{} {} :{}'.format(col, op, prepend_key+key))
|
sql.append(f'{col} {op} :{prepend_key}{key}')
|
||||||
values[prepend_key+key] = constraint
|
values[prepend_key+key] = constraint
|
||||||
return joiner.join(sql) if sql else '', values
|
return joiner.join(sql) if sql else '', values
|
||||||
|
|
||||||
|
@ -382,12 +391,14 @@ class BaseDatabase(SQLiteMixin):
|
||||||
if not tx_rows:
|
if not tx_rows:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
txids, txs = [], []
|
txids, txs, txi_txoids = [], [], []
|
||||||
for row in tx_rows:
|
for row in tx_rows:
|
||||||
txids.append(row[0])
|
txids.append(row[0])
|
||||||
txs.append(self.ledger.transaction_class(
|
txs.append(self.ledger.transaction_class(
|
||||||
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
|
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
|
||||||
))
|
))
|
||||||
|
for txi in txs[-1].inputs:
|
||||||
|
txi_txoids.append(txi.txo_ref.id)
|
||||||
|
|
||||||
annotated_txos = {
|
annotated_txos = {
|
||||||
txo.id: txo for txo in
|
txo.id: txo for txo in
|
||||||
|
@ -401,7 +412,7 @@ class BaseDatabase(SQLiteMixin):
|
||||||
txo.id: txo for txo in
|
txo.id: txo for txo in
|
||||||
(await self.get_txos(
|
(await self.get_txos(
|
||||||
my_account=my_account,
|
my_account=my_account,
|
||||||
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0]
|
txoid__in=txi_txoids
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -114,32 +114,32 @@ class BaseNetwork:
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self.client is not None and not self.client.is_closing()
|
return self.client is not None and not self.client.is_closing()
|
||||||
|
|
||||||
def rpc(self, list_or_method, *args):
|
def rpc(self, list_or_method, args):
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
return self.client.send_request(list_or_method, args)
|
return self.client.send_request(list_or_method, args)
|
||||||
else:
|
else:
|
||||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||||
|
|
||||||
def ensure_server_version(self, required='1.2'):
|
def ensure_server_version(self, required='1.2'):
|
||||||
return self.rpc('server.version', __version__, required)
|
return self.rpc('server.version', [__version__, required])
|
||||||
|
|
||||||
def broadcast(self, raw_transaction):
|
def broadcast(self, raw_transaction):
|
||||||
return self.rpc('blockchain.transaction.broadcast', raw_transaction)
|
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
||||||
|
|
||||||
def get_history(self, address):
|
def get_history(self, address):
|
||||||
return self.rpc('blockchain.address.get_history', address)
|
return self.rpc('blockchain.address.get_history', [address])
|
||||||
|
|
||||||
def get_transaction(self, tx_hash):
|
def get_transaction(self, tx_hash):
|
||||||
return self.rpc('blockchain.transaction.get', tx_hash)
|
return self.rpc('blockchain.transaction.get', [tx_hash])
|
||||||
|
|
||||||
def get_merkle(self, tx_hash, height):
|
def get_merkle(self, tx_hash, height):
|
||||||
return self.rpc('blockchain.transaction.get_merkle', tx_hash, height)
|
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
|
||||||
|
|
||||||
def get_headers(self, height, count=10000):
|
def get_headers(self, height, count=10000):
|
||||||
return self.rpc('blockchain.block.headers', height, count)
|
return self.rpc('blockchain.block.headers', [height, count])
|
||||||
|
|
||||||
def subscribe_headers(self):
|
def subscribe_headers(self):
|
||||||
return self.rpc('blockchain.headers.subscribe', True)
|
return self.rpc('blockchain.headers.subscribe', [True])
|
||||||
|
|
||||||
def subscribe_address(self, address):
|
def subscribe_address(self, address):
|
||||||
return self.rpc('blockchain.address.subscribe', address)
|
return self.rpc('blockchain.address.subscribe', [address])
|
||||||
|
|
|
@ -2,8 +2,8 @@ import logging
|
||||||
import traceback
|
import traceback
|
||||||
import argparse
|
import argparse
|
||||||
import importlib
|
import importlib
|
||||||
from .env import Env
|
from torba.server.env import Env
|
||||||
from .server import Server
|
from torba.server.server import Server
|
||||||
|
|
||||||
|
|
||||||
def get_argument_parser():
|
def get_argument_parser():
|
||||||
|
|
Loading…
Reference in a new issue