query interpolation
This commit is contained in:
parent
2a9e6473eb
commit
71551058ce
2 changed files with 35 additions and 2 deletions
|
@ -1,11 +1,10 @@
|
|||
import unittest
|
||||
import sqlite3
|
||||
from functools import wraps
|
||||
|
||||
from torba.client.wallet import Wallet
|
||||
from torba.client.constants import COIN
|
||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||
from torba.client.basedatabase import query, constraints_to_sql, AIOSQLite
|
||||
from torba.client.basedatabase import query, interpolate, constraints_to_sql, AIOSQLite
|
||||
from torba.client.hash import sha256
|
||||
|
||||
from torba.testcase import AsyncioTestCase
|
||||
|
@ -167,6 +166,26 @@ class TestQueryBuilder(unittest.TestCase):
|
|||
("select * from foo LIMIT 20 OFFSET 10", {})
|
||||
)
|
||||
|
||||
def test_query_interpolation(self):
|
||||
self.maxDiff = None
|
||||
# tests that interpolation replaces longer keys first
|
||||
self.assertEqual(
|
||||
interpolate(*query(
|
||||
"select * from foo",
|
||||
a__not='b', b__in='select * from blah where c=:$c',
|
||||
d__any={'one__like': 'o', 'two': 2},
|
||||
a0=3, a00=1, a00a=2, a00aa=4, # <-- breaks without correct interpolation key order
|
||||
ahash=memoryview(sha256(b'hello world')),
|
||||
limit=10, order_by='b', **{'$c': 3})
|
||||
),
|
||||
"select * from foo WHERE a != 'b' AND "
|
||||
"b IN (select * from blah where c=3) AND "
|
||||
"(one LIKE 'o' OR two = 2) AND "
|
||||
"a0 = 3 AND a00 = 1 AND a00a = 2 AND a00aa = 4 "
|
||||
"AND ahash = e9cdefe2acf78890ee80537ae3ef84c4faab7ddad7522ea5083e4d93b9274db9 "
|
||||
"ORDER BY b LIMIT 10",
|
||||
)
|
||||
|
||||
|
||||
class TestQueries(AsyncioTestCase):
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from binascii import hexlify
|
||||
from asyncio import wrap_future
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
|
@ -191,6 +192,19 @@ def query(select, **constraints):
|
|||
return ' '.join(sql), values
|
||||
|
||||
|
||||
def interpolate(sql, values):
|
||||
for k in sorted(values.keys(), reverse=True):
|
||||
value = values[k]
|
||||
if isinstance(value, memoryview):
|
||||
value = hexlify(bytes(value)[::-1]).decode()
|
||||
elif isinstance(value, str):
|
||||
value = f"'{value}'"
|
||||
else:
|
||||
value = str(value)
|
||||
sql = sql.replace(f":{k}", value)
|
||||
return sql
|
||||
|
||||
|
||||
def rows_to_dict(rows, fields):
|
||||
if rows:
|
||||
return [dict(zip(fields, r)) for r in rows]
|
||||
|
|
Loading…
Reference in a new issue