query interpolation

This commit is contained in:
Lex Berezhny 2019-07-17 21:46:38 -04:00
parent 2a9e6473eb
commit 71551058ce
2 changed files with 35 additions and 2 deletions

View file

@ -1,11 +1,10 @@
import unittest
import sqlite3
from functools import wraps
from torba.client.wallet import Wallet
from torba.client.constants import COIN
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.client.basedatabase import query, constraints_to_sql, AIOSQLite
from torba.client.basedatabase import query, interpolate, constraints_to_sql, AIOSQLite
from torba.client.hash import sha256
from torba.testcase import AsyncioTestCase
@ -167,6 +166,26 @@ class TestQueryBuilder(unittest.TestCase):
("select * from foo LIMIT 20 OFFSET 10", {})
)
def test_query_interpolation(self):
self.maxDiff = None
# tests that interpolation replaces longer keys first
self.assertEqual(
interpolate(*query(
"select * from foo",
a__not='b', b__in='select * from blah where c=:$c',
d__any={'one__like': 'o', 'two': 2},
a0=3, a00=1, a00a=2, a00aa=4, # <-- breaks without correct interpolation key order
ahash=memoryview(sha256(b'hello world')),
limit=10, order_by='b', **{'$c': 3})
),
"select * from foo WHERE a != 'b' AND "
"b IN (select * from blah where c=3) AND "
"(one LIKE 'o' OR two = 2) AND "
"a0 = 3 AND a00 = 1 AND a00a = 2 AND a00aa = 4 "
"AND ahash = e9cdefe2acf78890ee80537ae3ef84c4faab7ddad7522ea5083e4d93b9274db9 "
"ORDER BY b LIMIT 10",
)
class TestQueries(AsyncioTestCase):

View file

@ -1,5 +1,6 @@
import logging
import asyncio
from binascii import hexlify
from asyncio import wrap_future
from concurrent.futures.thread import ThreadPoolExecutor
@ -191,6 +192,19 @@ def query(select, **constraints):
return ' '.join(sql), values
def interpolate(sql, values):
for k in sorted(values.keys(), reverse=True):
value = values[k]
if isinstance(value, memoryview):
value = hexlify(bytes(value)[::-1]).decode()
elif isinstance(value, str):
value = f"'{value}'"
else:
value = str(value)
sql = sql.replace(f":{k}", value)
return sql
def rows_to_dict(rows, fields):
if rows:
return [dict(zip(fields, r)) for r in rows]