query interpolation

This commit is contained in:
Lex Berezhny 2019-07-17 21:46:38 -04:00
parent 2a9e6473eb
commit 71551058ce
2 changed files with 35 additions and 2 deletions

View file

@ -1,11 +1,10 @@
import unittest import unittest
import sqlite3 import sqlite3
from functools import wraps
from torba.client.wallet import Wallet from torba.client.wallet import Wallet
from torba.client.constants import COIN from torba.client.constants import COIN
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.client.basedatabase import query, constraints_to_sql, AIOSQLite from torba.client.basedatabase import query, interpolate, constraints_to_sql, AIOSQLite
from torba.client.hash import sha256 from torba.client.hash import sha256
from torba.testcase import AsyncioTestCase from torba.testcase import AsyncioTestCase
@ -167,6 +166,26 @@ class TestQueryBuilder(unittest.TestCase):
("select * from foo LIMIT 20 OFFSET 10", {}) ("select * from foo LIMIT 20 OFFSET 10", {})
) )
def test_query_interpolation(self):
self.maxDiff = None
# tests that interpolation replaces longer keys first
self.assertEqual(
interpolate(*query(
"select * from foo",
a__not='b', b__in='select * from blah where c=:$c',
d__any={'one__like': 'o', 'two': 2},
a0=3, a00=1, a00a=2, a00aa=4, # <-- breaks without correct interpolation key order
ahash=memoryview(sha256(b'hello world')),
limit=10, order_by='b', **{'$c': 3})
),
"select * from foo WHERE a != 'b' AND "
"b IN (select * from blah where c=3) AND "
"(one LIKE 'o' OR two = 2) AND "
"a0 = 3 AND a00 = 1 AND a00a = 2 AND a00aa = 4 "
"AND ahash = e9cdefe2acf78890ee80537ae3ef84c4faab7ddad7522ea5083e4d93b9274db9 "
"ORDER BY b LIMIT 10",
)
class TestQueries(AsyncioTestCase): class TestQueries(AsyncioTestCase):

View file

@ -1,5 +1,6 @@
import logging import logging
import asyncio import asyncio
from binascii import hexlify
from asyncio import wrap_future from asyncio import wrap_future
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
@ -191,6 +192,19 @@ def query(select, **constraints):
return ' '.join(sql), values return ' '.join(sql), values
def interpolate(sql, values):
for k in sorted(values.keys(), reverse=True):
value = values[k]
if isinstance(value, memoryview):
value = hexlify(bytes(value)[::-1]).decode()
elif isinstance(value, str):
value = f"'{value}'"
else:
value = str(value)
sql = sql.replace(f":{k}", value)
return sql
def rows_to_dict(rows, fields): def rows_to_dict(rows, fields):
if rows: if rows:
return [dict(zip(fields, r)) for r in rows] return [dict(zip(fields, r)) for r in rows]