lbry-sdk/lbry/db/utils.py

174 lines
5.6 KiB
Python
Raw Normal View History

2020-05-01 09:29:44 -04:00
from itertools import islice
from typing import List, Union
from sqlalchemy import text, and_
from sqlalchemy.sql.expression import Select, FunctionElement
from sqlalchemy.types import Numeric
from sqlalchemy.ext.compiler import compiles
2020-05-01 09:29:44 -04:00
try:
2020-06-05 00:35:22 -04:00
from sqlalchemy.dialects.postgresql import insert as pg_insert # pylint: disable=unused-import
2020-05-01 09:29:44 -04:00
except ImportError:
pg_insert = None
from .tables import AccountAddress
2020-06-22 11:02:59 -04:00
class greatest(FunctionElement):  # pylint: disable=invalid-name
    """Numeric SQL function element named ``greatest``.

    How it is rendered is decided by the ``@compiles`` hooks registered for
    this class: ``greatest(...)`` by default and ``max(...)`` on SQLite.
    """

    name = 'greatest'
    type = Numeric()
@compiles(greatest)
def default_greatest(element, compiler, **kw):
    """Render ``greatest`` for dialects that have no dedicated compiler.

    The element's clauses are compiled with the active ``compiler`` and
    wrapped in a ``greatest(...)`` call.
    """
    rendered_args = compiler.process(element.clauses, **kw)
    return "greatest(%s)" % rendered_args
@compiles(greatest, 'sqlite')
def sqlite_greatest(element, compiler, **kw):
    """Render ``greatest`` on the SQLite dialect as a ``max(...)`` call."""
    rendered_args = compiler.process(element.clauses, **kw)
    return "max(%s)" % rendered_args
2020-06-22 20:30:05 -04:00
class least(FunctionElement):  # pylint: disable=invalid-name
    """Numeric SQL function element named ``least``.

    How it is rendered is decided by the ``@compiles`` hooks registered for
    this class: ``least(...)`` by default and ``min(...)`` on SQLite.
    """

    name = 'least'
    type = Numeric()
@compiles(least)
def default_least(element, compiler, **kw):
    """Render ``least`` for dialects that have no dedicated compiler.

    The element's clauses are compiled with the active ``compiler`` and
    wrapped in a ``least(...)`` call.
    """
    rendered_args = compiler.process(element.clauses, **kw)
    return "least(%s)" % rendered_args
@compiles(least, 'sqlite')
def sqlite_least(element, compiler, **kw):
    """Render ``least`` on the SQLite dialect as a ``min(...)`` call."""
    rendered_args = compiler.process(element.clauses, **kw)
    return "min(%s)" % rendered_args
2020-05-01 09:29:44 -04:00
def chunk(rows, step):
    """Yield successive lists of at most ``step`` items taken from ``rows``.

    ``rows`` may be any iterable, including generators and other iterables
    that do not support ``len()``. Every yielded list holds exactly ``step``
    items except possibly the last one; an empty ``rows`` yields nothing.

    Args:
        rows: iterable to split into batches.
        step: maximum number of items per batch; must be at least 1.

    Raises:
        ValueError: if ``step`` is smaller than 1. As this is a generator,
            the error surfaces when iteration starts.
    """
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step!r}")
    it = iter(rows)
    while True:
        # islice() keeps consuming the same iterator, so each pass picks up
        # where the previous batch stopped.
        batch = list(islice(it, step))
        if not batch:
            return
        yield batch
def constrain_single_or_list(constraints, column, value, convert=lambda x: x):
    """Add an equality or ``__in`` constraint for ``column`` to ``constraints``.

    * ``None`` leaves ``constraints`` untouched.
    * A non-list value is stored under ``column`` after ``convert``.
    * A list has ``convert`` applied to each item; one item is stored under
      ``column``, several are stored under ``"<column>__in"`` and an empty
      list adds nothing.

    The ``constraints`` dict is modified in place and also returned.
    """
    if value is None:
        return constraints
    if not isinstance(value, list):
        constraints[column] = convert(value)
        return constraints
    converted = list(map(convert, value))
    if len(converted) > 1:
        constraints[f"{column}__in"] = converted
    elif converted:
        constraints[column] = converted[0]
    return constraints
def in_account_ids(account_ids: Union[List[str], str]):
    """Build a condition on ``AccountAddress.c.account`` for the given id(s).

    A list with more than one id produces an ``IN`` condition; a one-item
    list or a bare value produces an equality comparison. An empty list is
    not supported: indexing its first item raises ``IndexError``.
    """
    account_column = AccountAddress.c.account
    if not isinstance(account_ids, list):
        return account_column == account_ids
    if len(account_ids) > 1:
        return account_column.in_(account_ids)
    return account_column == account_ids[0]
def query(table, s: Select, **constraints) -> Select:
    """Apply paging, ordering, grouping and filter constraints to ``s``.

    The keywords ``limit``, ``offset``, ``order_by``, ``group_by`` and
    ``account_ids`` are consumed here; whatever remains is turned into a
    WHERE clause by ``constraints_to_clause(table, ...)``.

    ``order_by`` may be a string or a list of strings (joined with ", ") and
    ``group_by`` a string; both are passed to ``text()`` as raw SQL.

    Raises:
        ValueError: if a truthy ``order_by`` is neither a string nor a list.
    """
    limit = constraints.pop('limit', None)
    offset = constraints.pop('offset', None)
    order_by = constraints.pop('order_by', None)
    group_by = constraints.pop('group_by', None)
    account_ids = constraints.pop('account_ids', [])

    if limit is not None:
        s = s.limit(limit)
    if offset is not None:
        s = s.offset(offset)

    if order_by:
        if isinstance(order_by, str):
            ordering = order_by
        elif isinstance(order_by, list):
            ordering = ', '.join(order_by)
        else:
            raise ValueError("order_by must be string or list")
        s = s.order_by(text(ordering))

    if group_by is not None:
        s = s.group_by(text(group_by))

    if account_ids:
        s = s.where(in_account_ids(account_ids))

    # Only the keys not consumed above are left at this point.
    if constraints:
        s = s.where(constraints_to_clause(table, constraints))
    return s
def _resolve_column(tables, col):
    """Return the column object that ``col`` names among ``tables``.

    ``col`` is either a bare column name, looked up on each table in order
    until one has it, or ``"<table>.<column>"``, looked up only on the table
    whose ``name`` equals the lower-cased table part.

    Raises:
        ValueError: if the named table is not in ``tables`` or no column with
            that name exists.
    """
    attr = None
    if '.' in col:
        table_name, _, col = col.partition('.')
        wanted = table_name.lower()
        for table in tables:
            if table.name == wanted:
                # Default to None so an unknown column is reported through the
                # ValueError below, same as for unqualified names, instead of
                # leaking a bare AttributeError.
                attr = getattr(table.c, col, None)
                break
        else:
            raise ValueError(f"Table '{table_name}' not available: {', '.join([t.name for t in tables])}.")
    else:
        for table in tables:
            attr = getattr(table.c, col, None)
            if attr is not None:
                break
    if attr is None:
        raise ValueError(f"Attribute '{col}' not found on tables: {', '.join([t.name for t in tables])}.")
    return attr


def constraints_to_clause(tables, constraints):
    """Translate a dict of keyword-style constraints into one AND-ed clause.

    Each key is a column name, optionally qualified as ``"table.column"``,
    optionally followed by an operator suffix:

    * no suffix -- ``==``
    * ``__not`` -- ``!=``
    * ``__is_null`` / ``__is_not_null`` -- ``== None`` / ``!= None``; the
      supplied value is ignored
    * ``__lt``, ``__lte``, ``__gt``, ``__gte`` -- ordering comparisons
    * ``__like`` / ``__not_like`` -- ``like()`` / ``notlike()``
    * ``__in`` / ``__not_in`` -- ``in_()`` / ``notin_()``. The value may be a
      ``Select``, a list/set/tuple (a single item collapses to ``==`` or
      ``!=``) or a string, which is wrapped in ``text()`` as raw SQL. Any
      other falsy value causes the constraint to be skipped entirely.

    Args:
        tables: iterable of table objects exposing ``name`` and ``c``.
        constraints: mapping of constraint key to value.

    Returns:
        ``and_()`` applied to all generated conditions.

    Raises:
        ValueError: for an unsupported ``__in``/``__not_in`` value, an unknown
            table qualifier, or an unknown column.
    """
    clause = []
    for key, constraint in constraints.items():
        if key.endswith('__not'):
            col, op = key[:-len('__not')], '__ne__'
        elif key.endswith('__is_null'):
            col, op = key[:-len('__is_null')], '__eq__'
            constraint = None
        elif key.endswith('__is_not_null'):
            col, op = key[:-len('__is_not_null')], '__ne__'
            constraint = None
        elif key.endswith('__lt'):
            col, op = key[:-len('__lt')], '__lt__'
        elif key.endswith('__lte'):
            col, op = key[:-len('__lte')], '__le__'
        elif key.endswith('__gt'):
            col, op = key[:-len('__gt')], '__gt__'
        elif key.endswith('__gte'):
            col, op = key[:-len('__gte')], '__ge__'
        elif key.endswith('__like'):
            col, op = key[:-len('__like')], 'like'
        elif key.endswith('__not_like'):
            col, op = key[:-len('__not_like')], 'notlike'
        elif key.endswith('__in') or key.endswith('__not_in'):
            if key.endswith('__in'):
                col, op, one_val_op = key[:-len('__in')], 'in_', '__eq__'
            else:
                col, op, one_val_op = key[:-len('__not_in')], 'notin_', '__ne__'
            # The Select check comes before the truthiness test so a subquery
            # is never evaluated in a boolean context.
            if isinstance(constraint, Select):
                pass
            elif constraint:
                if isinstance(constraint, (list, set, tuple)):
                    if len(constraint) == 1:
                        # A single value is expressed as a plain comparison.
                        op = one_val_op
                        constraint = next(iter(constraint))
                elif isinstance(constraint, str):
                    constraint = text(constraint)
                else:
                    raise ValueError(f"{col} requires a list, set or string as constraint value.")
            else:
                # Nothing to match against: drop this constraint.
                continue
        else:
            col, op = key, '__eq__'
        attr = _resolve_column(tables, col)
        clause.append(getattr(attr, op)(constraint))
    return and_(*clause)