query interpolation
This commit is contained in:
parent
2a9e6473eb
commit
71551058ce
2 changed files with 35 additions and 2 deletions
|
@ -1,11 +1,10 @@
|
||||||
import unittest
|
import unittest
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
from torba.client.wallet import Wallet
|
from torba.client.wallet import Wallet
|
||||||
from torba.client.constants import COIN
|
from torba.client.constants import COIN
|
||||||
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
from torba.client.basedatabase import query, constraints_to_sql, AIOSQLite
|
from torba.client.basedatabase import query, interpolate, constraints_to_sql, AIOSQLite
|
||||||
from torba.client.hash import sha256
|
from torba.client.hash import sha256
|
||||||
|
|
||||||
from torba.testcase import AsyncioTestCase
|
from torba.testcase import AsyncioTestCase
|
||||||
|
@ -167,6 +166,26 @@ class TestQueryBuilder(unittest.TestCase):
|
||||||
("select * from foo LIMIT 20 OFFSET 10", {})
|
("select * from foo LIMIT 20 OFFSET 10", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_query_interpolation(self):
|
||||||
|
self.maxDiff = None
|
||||||
|
# tests that interpolation replaces longer keys first
|
||||||
|
self.assertEqual(
|
||||||
|
interpolate(*query(
|
||||||
|
"select * from foo",
|
||||||
|
a__not='b', b__in='select * from blah where c=:$c',
|
||||||
|
d__any={'one__like': 'o', 'two': 2},
|
||||||
|
a0=3, a00=1, a00a=2, a00aa=4, # <-- breaks without correct interpolation key order
|
||||||
|
ahash=memoryview(sha256(b'hello world')),
|
||||||
|
limit=10, order_by='b', **{'$c': 3})
|
||||||
|
),
|
||||||
|
"select * from foo WHERE a != 'b' AND "
|
||||||
|
"b IN (select * from blah where c=3) AND "
|
||||||
|
"(one LIKE 'o' OR two = 2) AND "
|
||||||
|
"a0 = 3 AND a00 = 1 AND a00a = 2 AND a00aa = 4 "
|
||||||
|
"AND ahash = e9cdefe2acf78890ee80537ae3ef84c4faab7ddad7522ea5083e4d93b9274db9 "
|
||||||
|
"ORDER BY b LIMIT 10",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestQueries(AsyncioTestCase):
|
class TestQueries(AsyncioTestCase):
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from binascii import hexlify
|
||||||
from asyncio import wrap_future
|
from asyncio import wrap_future
|
||||||
from concurrent.futures.thread import ThreadPoolExecutor
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
|
|
||||||
|
@ -191,6 +192,19 @@ def query(select, **constraints):
|
||||||
return ' '.join(sql), values
|
return ' '.join(sql), values
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate(sql, values):
|
||||||
|
for k in sorted(values.keys(), reverse=True):
|
||||||
|
value = values[k]
|
||||||
|
if isinstance(value, memoryview):
|
||||||
|
value = hexlify(bytes(value)[::-1]).decode()
|
||||||
|
elif isinstance(value, str):
|
||||||
|
value = f"'{value}'"
|
||||||
|
else:
|
||||||
|
value = str(value)
|
||||||
|
sql = sql.replace(f":{k}", value)
|
||||||
|
return sql
|
||||||
|
|
||||||
|
|
||||||
def rows_to_dict(rows, fields):
|
def rows_to_dict(rows, fields):
|
||||||
if rows:
|
if rows:
|
||||||
return [dict(zip(fields, r)) for r in rows]
|
return [dict(zip(fields, r)) for r in rows]
|
||||||
|
|
Loading…
Add table
Reference in a new issue