sql __in uses proper escaping instead of putting values directly in SQL

This commit is contained in:
Lex Berezhny 2019-04-06 18:08:33 -04:00
parent 397ebe8428
commit cfb051396a
4 changed files with 56 additions and 29 deletions

View file

@ -76,11 +76,17 @@ class TestQueryBuilder(unittest.TestCase):
def test_in(self): def test_in(self):
self.assertEqual( self.assertEqual(
constraints_to_sql({'txo.age__in': [18, 38]}), constraints_to_sql({'txo.age__in': [18, 38]}),
('txo.age IN (18, 38)', {}) ('txo.age IN (:txo_age__in0, :txo_age__in1)', {
'txo_age__in0': 18,
'txo_age__in1': 38
})
) )
self.assertEqual( self.assertEqual(
constraints_to_sql({'txo.age__in': ['abc123', 'def456']}), constraints_to_sql({'txo.name__in': ('abc123', 'def456')}),
("txo.age IN ('abc123', 'def456')", {}) ('txo.name IN (:txo_name__in0, :txo_name__in1)', {
'txo_name__in0': 'abc123',
'txo_name__in1': 'def456'
})
) )
self.assertEqual( self.assertEqual(
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}), constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
@ -90,7 +96,17 @@ class TestQueryBuilder(unittest.TestCase):
def test_not_in(self): def test_not_in(self):
self.assertEqual( self.assertEqual(
constraints_to_sql({'txo.age__not_in': [18, 38]}), constraints_to_sql({'txo.age__not_in': [18, 38]}),
('txo.age NOT IN (18, 38)', {}) ('txo.age NOT IN (:txo_age__not_in0, :txo_age__not_in1)', {
'txo_age__not_in0': 18,
'txo_age__not_in1': 38
})
)
self.assertEqual(
constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}),
('txo.name NOT IN (:txo_name__not_in0, :txo_name__not_in1)', {
'txo_name__not_in0': 'abc123',
'txo_name__not_in1': 'def456'
})
) )
self.assertEqual( self.assertEqual(
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}), constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),

View file

@ -93,6 +93,14 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
continue continue
elif key.endswith('__not'): elif key.endswith('__not'):
col, op = col[:-len('__not')], '!=' col, op = col[:-len('__not')], '!='
elif key.endswith('__is_null'):
col = col[:-len('__is_null')]
sql.append(f'{col} IS NULL')
continue
elif key.endswith('__is_not_null'):
col = col[:-len('__is_not_null')]
sql.append(f'{col} IS NOT NULL')
continue
elif key.endswith('__lt'): elif key.endswith('__lt'):
col, op = col[:-len('__lt')], '<' col, op = col[:-len('__lt')], '<'
elif key.endswith('__lte'): elif key.endswith('__lte'):
@ -108,23 +116,24 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
col, op = col[:-len('__in')], 'IN' col, op = col[:-len('__in')], 'IN'
else: else:
col, op = col[:-len('__not_in')], 'NOT IN' col, op = col[:-len('__not_in')], 'NOT IN'
if isinstance(constraint, (list, set)): if constraint:
items = ', '.join( if isinstance(constraint, (list, set, tuple)):
"'{}'".format(item) if isinstance(item, str) else str(item) keys = []
for item in constraint for i, val in enumerate(constraint):
) keys.append(f':{key}{i}')
elif isinstance(constraint, str): values[f'{key}{i}'] = val
items = constraint sql.append(f'{col} {op} ({", ".join(keys)})')
else: elif isinstance(constraint, str):
raise ValueError("{} requires a list, set or string as constraint value.".format(col)) sql.append(f'{col} {op} ({constraint})')
sql.append('{} {} ({})'.format(col, op, items)) else:
raise ValueError(f"{col} requires a list, set or string as constraint value.")
continue continue
elif key.endswith('__any'): elif key.endswith('__any'):
where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_') where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_')
sql.append('({})'.format(where)) sql.append(f'({where})')
values.update(subvalues) values.update(subvalues)
continue continue
sql.append('{} {} :{}'.format(col, op, prepend_key+key)) sql.append(f'{col} {op} :{prepend_key}{key}')
values[prepend_key+key] = constraint values[prepend_key+key] = constraint
return joiner.join(sql) if sql else '', values return joiner.join(sql) if sql else '', values
@ -382,12 +391,14 @@ class BaseDatabase(SQLiteMixin):
if not tx_rows: if not tx_rows:
return [] return []
txids, txs = [], [] txids, txs, txi_txoids = [], [], []
for row in tx_rows: for row in tx_rows:
txids.append(row[0]) txids.append(row[0])
txs.append(self.ledger.transaction_class( txs.append(self.ledger.transaction_class(
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4]) raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
)) ))
for txi in txs[-1].inputs:
txi_txoids.append(txi.txo_ref.id)
annotated_txos = { annotated_txos = {
txo.id: txo for txo in txo.id: txo for txo in
@ -401,7 +412,7 @@ class BaseDatabase(SQLiteMixin):
txo.id: txo for txo in txo.id: txo for txo in
(await self.get_txos( (await self.get_txos(
my_account=my_account, my_account=my_account,
txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0] txoid__in=txi_txoids
)) ))
} }

View file

@ -114,32 +114,32 @@ class BaseNetwork:
def is_connected(self): def is_connected(self):
return self.client is not None and not self.client.is_closing() return self.client is not None and not self.client.is_closing()
def rpc(self, list_or_method, *args): def rpc(self, list_or_method, args):
if self.is_connected: if self.is_connected:
return self.client.send_request(list_or_method, args) return self.client.send_request(list_or_method, args)
else: else:
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")
def ensure_server_version(self, required='1.2'): def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', __version__, required) return self.rpc('server.version', [__version__, required])
def broadcast(self, raw_transaction): def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', raw_transaction) return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def get_history(self, address): def get_history(self, address):
return self.rpc('blockchain.address.get_history', address) return self.rpc('blockchain.address.get_history', [address])
def get_transaction(self, tx_hash): def get_transaction(self, tx_hash):
return self.rpc('blockchain.transaction.get', tx_hash) return self.rpc('blockchain.transaction.get', [tx_hash])
def get_merkle(self, tx_hash, height): def get_merkle(self, tx_hash, height):
return self.rpc('blockchain.transaction.get_merkle', tx_hash, height) return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
def get_headers(self, height, count=10000): def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', height, count) return self.rpc('blockchain.block.headers', [height, count])
def subscribe_headers(self): def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', True) return self.rpc('blockchain.headers.subscribe', [True])
def subscribe_address(self, address): def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', address) return self.rpc('blockchain.address.subscribe', [address])

View file

@ -2,8 +2,8 @@ import logging
import traceback import traceback
import argparse import argparse
import importlib import importlib
from .env import Env from torba.server.env import Env
from .server import Server from torba.server.server import Server
def get_argument_parser(): def get_argument_parser():