sql __in uses proper escaping instead of putting values directly in SQL

This commit is contained in:
Lex Berezhny 2019-04-06 18:08:33 -04:00
parent 397ebe8428
commit cfb051396a
4 changed files with 56 additions and 29 deletions

View file

@ -76,11 +76,17 @@ class TestQueryBuilder(unittest.TestCase):
def test_in(self):
self.assertEqual(
constraints_to_sql({'txo.age__in': [18, 38]}),
('txo.age IN (18, 38)', {})
('txo.age IN (:txo_age__in0, :txo_age__in1)', {
'txo_age__in0': 18,
'txo_age__in1': 38
})
)
self.assertEqual(
constraints_to_sql({'txo.age__in': ['abc123', 'def456']}),
("txo.age IN ('abc123', 'def456')", {})
constraints_to_sql({'txo.name__in': ('abc123', 'def456')}),
('txo.name IN (:txo_name__in0, :txo_name__in1)', {
'txo_name__in0': 'abc123',
'txo_name__in1': 'def456'
})
)
self.assertEqual(
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
@ -90,7 +96,17 @@ class TestQueryBuilder(unittest.TestCase):
def test_not_in(self):
self.assertEqual(
constraints_to_sql({'txo.age__not_in': [18, 38]}),
('txo.age NOT IN (18, 38)', {})
('txo.age NOT IN (:txo_age__not_in0, :txo_age__not_in1)', {
'txo_age__not_in0': 18,
'txo_age__not_in1': 38
})
)
self.assertEqual(
constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}),
('txo.name NOT IN (:txo_name__not_in0, :txo_name__not_in1)', {
'txo_name__not_in0': 'abc123',
'txo_name__not_in1': 'def456'
})
)
self.assertEqual(
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),

View file

@ -93,6 +93,14 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
continue
elif key.endswith('__not'):
col, op = col[:-len('__not')], '!='
elif key.endswith('__is_null'):
col = col[:-len('__is_null')]
sql.append(f'{col} IS NULL')
continue
elif key.endswith('__is_not_null'):
col = col[:-len('__is_not_null')]
sql.append(f'{col} IS NOT NULL')
continue
elif key.endswith('__lt'):
col, op = col[:-len('__lt')], '<'
elif key.endswith('__lte'):
@ -108,23 +116,24 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
col, op = col[:-len('__in')], 'IN'
else:
col, op = col[:-len('__not_in')], 'NOT IN'
if isinstance(constraint, (list, set)):
items = ', '.join(
"'{}'".format(item) if isinstance(item, str) else str(item)
for item in constraint
)
elif isinstance(constraint, str):
items = constraint
else:
raise ValueError("{} requires a list, set or string as constraint value.".format(col))
sql.append('{} {} ({})'.format(col, op, items))
if constraint:
if isinstance(constraint, (list, set, tuple)):
keys = []
for i, val in enumerate(constraint):
keys.append(f':{key}{i}')
values[f'{key}{i}'] = val
sql.append(f'{col} {op} ({", ".join(keys)})')
elif isinstance(constraint, str):
sql.append(f'{col} {op} ({constraint})')
else:
raise ValueError(f"{col} requires a list, set or string as constraint value.")
continue
elif key.endswith('__any'):
where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_')
sql.append('({})'.format(where))
sql.append(f'({where})')
values.update(subvalues)
continue
sql.append('{} {} :{}'.format(col, op, prepend_key+key))
sql.append(f'{col} {op} :{prepend_key}{key}')
values[prepend_key+key] = constraint
return joiner.join(sql) if sql else '', values
@ -382,12 +391,14 @@ class BaseDatabase(SQLiteMixin):
if not tx_rows:
return []
txids, txs = [], []
txids, txs, txi_txoids = [], [], []
for row in tx_rows:
txids.append(row[0])
txs.append(self.ledger.transaction_class(
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
))
for txi in txs[-1].inputs:
txi_txoids.append(txi.txo_ref.id)
annotated_txos = {
txo.id: txo for txo in
@ -401,7 +412,7 @@ class BaseDatabase(SQLiteMixin):
txo.id: txo for txo in
(await self.get_txos(
my_account=my_account,
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0]
txoid__in=txi_txoids
))
}

View file

@ -114,32 +114,32 @@ class BaseNetwork:
def is_connected(self):
return self.client is not None and not self.client.is_closing()
def rpc(self, list_or_method, *args):
def rpc(self, list_or_method, args):
if self.is_connected:
return self.client.send_request(list_or_method, args)
else:
raise ConnectionError("Attempting to send rpc request when connection is not available.")
def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', __version__, required)
return self.rpc('server.version', [__version__, required])
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', raw_transaction)
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def get_history(self, address):
return self.rpc('blockchain.address.get_history', address)
return self.rpc('blockchain.address.get_history', [address])
def get_transaction(self, tx_hash):
return self.rpc('blockchain.transaction.get', tx_hash)
return self.rpc('blockchain.transaction.get', [tx_hash])
def get_merkle(self, tx_hash, height):
return self.rpc('blockchain.transaction.get_merkle', tx_hash, height)
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', height, count)
return self.rpc('blockchain.block.headers', [height, count])
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', True)
return self.rpc('blockchain.headers.subscribe', [True])
def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', address)
return self.rpc('blockchain.address.subscribe', [address])

View file

@ -2,8 +2,8 @@ import logging
import traceback
import argparse
import importlib
from .env import Env
from .server import Server
from torba.server.env import Env
from torba.server.server import Server
def get_argument_parser():