Updates database routines & fixes up error handling code

This commit is contained in:
Oleg Silkin 2019-06-04 08:14:12 -04:00
parent 9e82282f5e
commit 7e38a87a0b
3 changed files with 132 additions and 132 deletions

View file

@ -50,21 +50,20 @@ async def close_comment_scheduler(app):
await app['comment_scheduler'].close() await app['comment_scheduler'].close()
async def create_database_backup(app): async def database_backup_routine(app):
try: try:
while True: while True:
await asyncio.sleep(app['config']['BACKUP_INT']) await asyncio.sleep(app['config']['BACKUP_INT'])
with obtain_connection(app['db_path']) as conn: with obtain_connection(app['db_path']) as conn:
logger.debug('backing up database') logger.debug('backing up database')
schema.db_helpers.backup_database(conn, app['backup']) schema.db_helpers.backup_database(conn, app['backup'])
except asyncio.CancelledError:
except asyncio.CancelledError as e:
pass pass
async def start_background_tasks(app: web.Application): async def start_background_tasks(app: web.Application):
app['reader'] = obtain_connection(app['db_path'], True) app['reader'] = obtain_connection(app['db_path'], True)
app['waitful_backup'] = app.loop.create_task(create_database_backup(app)) app['waitful_backup'] = app.loop.create_task(database_backup_routine(app))
app['comment_scheduler'] = await create_comment_scheduler() app['comment_scheduler'] = await create_comment_scheduler()
app['writer'] = DatabaseWriter(app['db_path']) app['writer'] = DatabaseWriter(app['db_path'])

View file

@ -23,9 +23,11 @@ def obtain_connection(filepath: str = None, row_factory: bool = True):
def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None,
page: int = 1, page_size: int = 50, top_level=False): page: int = 1, page_size: int = 50, top_level=False):
with conn:
if top_level: if top_level:
results = [clean(dict(row)) for row in conn.execute( results = [clean(dict(row)) for row in conn.execute(
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id """ SELECT comment, comment_id, channel_name, channel_id,
channel_url, timestamp, signature, parent_id
FROM COMMENTS_ON_CLAIMS FROM COMMENTS_ON_CLAIMS
WHERE claim_id LIKE ? AND parent_id IS NULL WHERE claim_id LIKE ? AND parent_id IS NULL
LIMIT ? OFFSET ? """, LIMIT ? OFFSET ? """,
@ -40,7 +42,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
) )
elif parent_id is None: elif parent_id is None:
results = [clean(dict(row)) for row in conn.execute( results = [clean(dict(row)) for row in conn.execute(
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id """ SELECT comment, comment_id, channel_name, channel_id,
channel_url, timestamp, signature, parent_id
FROM COMMENTS_ON_CLAIMS FROM COMMENTS_ON_CLAIMS
WHERE claim_id LIKE ? WHERE claim_id LIKE ?
LIMIT ? OFFSET ? """, LIMIT ? OFFSET ? """,
@ -55,7 +58,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
) )
else: else:
results = [clean(dict(row)) for row in conn.execute( results = [clean(dict(row)) for row in conn.execute(
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id """ SELECT comment, comment_id, channel_name, channel_id,
channel_url, timestamp, signature, parent_id
FROM COMMENTS_ON_CLAIMS FROM COMMENTS_ON_CLAIMS
WHERE claim_id LIKE ? AND parent_id = ? WHERE claim_id LIKE ? AND parent_id = ?
LIMIT ? OFFSET ? """, LIMIT ? OFFSET ? """,
@ -78,20 +82,20 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
} }
def validate_input(**kwargs): def validate_channel(channel_id: str, channel_name: str):
assert 0 < len(kwargs['comment']) <= 2000 assert channel_id and channel_name
assert re.fullmatch( assert type(channel_id) is str and type(channel_name) is str
'[a-z0-9]{40}:([a-z0-9]{40})?',
kwargs['claim_id'] + ':' + kwargs.get('channel_id', '')
)
if 'channel_name' in kwargs or 'channel_id' in kwargs:
assert 'channel_id' in kwargs and 'channel_name' in kwargs
assert re.fullmatch( assert re.fullmatch(
'^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26' '^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26'
'\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$', '\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$',
kwargs.get('channel_name', '') channel_name
) )
assert re.fullmatch('[a-z0-9]{40}', kwargs.get('channel_id', '')) assert re.fullmatch('[a-z0-9]{40}', channel_id)
def validate_input(comment: str, claim_id: str, **kwargs):
assert 0 < len(comment) <= 2000
assert re.fullmatch('[a-z0-9]{40}', claim_id)
def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str): def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
@ -102,17 +106,25 @@ def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str
) )
def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id):
try:
validate_channel(channel_id, channel_name)
_insert_channel(conn, channel_name, channel_id)
except AssertionError as ae:
logger.exception('Invalid channel values given: %s', ae)
raise ValueError('Received invalid values for channel_id or channel_name')
def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None, def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None,
channel_id: str = None, signature: str = None, parent_id: str = None) -> str: channel_id: str = None, signature: str = None, parent_id: str = None) -> str:
timestamp = int(time.time()) timestamp = int(time.time())
comment_prehash = ':'.join((claim_id, comment, str(timestamp),)) prehash = ':'.join((claim_id, comment, str(timestamp),))
comment_prehash = bytes(comment_prehash.encode('utf-8')) prehash = bytes(prehash.encode('utf-8'))
comment_id = nacl.hash.sha256(comment_prehash).decode('utf-8') comment_id = nacl.hash.sha256(prehash).decode('utf-8')
with conn: with conn:
conn.execute( conn.execute(
""" """
INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body, INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body, ParentId, Signature, Timestamp)
ParentId, Signature, Timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
""", """,
(comment_id, claim_id, channel_id, comment, parent_id, signature, timestamp) (comment_id, claim_id, channel_id, comment, parent_id, signature, timestamp)
@ -121,31 +133,8 @@ def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str
return comment_id return comment_id
def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, **kwargs) -> typing.Union[dict, None]: def get_comment_or_none(conn: sqlite3.Connection, comment_id: str) -> dict:
channel_id = kwargs.pop('channel_id', '') with conn:
channel_name = kwargs.pop('channel_name', '')
if channel_id or channel_name:
try:
validate_input(
comment=comment,
claim_id=claim_id,
channel_id=channel_id,
channel_name=channel_name,
)
_insert_channel(conn, channel_name, channel_id)
except AssertionError:
logger.exception('Received invalid input')
raise TypeError('Invalid params given to input validation')
else:
channel_id = None
try:
comment_id = _insert_comment(
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, **kwargs
)
except sqlite3.IntegrityError as ie:
logger.exception(ie)
return None
curry = conn.execute( curry = conn.execute(
""" """
SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
@ -157,6 +146,22 @@ def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, **kwar
return clean(dict(thing)) if thing else None return clean(dict(thing)) if thing else None
def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str,
channel_id: str = None, channel_name: str = None,
signature: str = None, parent_id: str = None):
if channel_id or channel_name or signature:
# do nothing with signature for now
insert_channel_or_error(conn, channel_name=channel_name, channel_id=channel_id)
try:
comment_id = _insert_comment(
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id,
signature=signature, parent_id=parent_id
)
return get_comment_or_none(conn, comment_id)
except sqlite3.IntegrityError as ie:
logger.exception(ie)
def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50): def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50):
""" Just return a list of the comment IDs that are associated with the given claim_id. """ Just return a list of the comment IDs that are associated with the given claim_id.
If get_all is specified then it returns all the IDs, otherwise only the IDs at that level. If get_all is specified then it returns all the IDs, otherwise only the IDs at that level.
@ -165,6 +170,7 @@ def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = No
For pagination the parameters are: For pagination the parameters are:
get_all XOR (page_size + page) get_all XOR (page_size + page)
""" """
with conn:
if parent_id is None: if parent_id is None:
curs = conn.execute(""" curs = conn.execute("""
SELECT comment_id FROM COMMENTS_ON_CLAIMS SELECT comment_id FROM COMMENTS_ON_CLAIMS
@ -184,12 +190,8 @@ def get_comments_by_id(conn, comment_ids: list) -> typing.Union[list, None]:
""" Returns a list containing the comment data associated with each ID within the list""" """ Returns a list containing the comment data associated with each ID within the list"""
# format the input, under the assumption that the # format the input, under the assumption that the
placeholders = ', '.join('?' for _ in comment_ids) placeholders = ', '.join('?' for _ in comment_ids)
with conn:
return [clean(dict(row)) for row in conn.execute( return [clean(dict(row)) for row in conn.execute(
f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})', f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})',
tuple(comment_ids) tuple(comment_ids)
)] )]
if __name__ == '__main__':
pass
# __generate_database_schema(connection, 'comments_ddl.sql')

View file

@ -2,17 +2,17 @@
import json import json
import logging import logging
import asyncio
import aiojobs import aiojobs
from asyncio import coroutine import asyncio
from aiohttp import web from aiohttp import web
from aiojobs.aiohttp import atomic from aiojobs.aiohttp import atomic
from asyncio import coroutine
from src.writes import DatabaseWriter from src.database import create_comment
from src.database import get_claim_comments from src.database import get_claim_comments
from src.database import get_comments_by_id, get_comment_ids from src.database import get_comments_by_id, get_comment_ids
from src.database import obtain_connection from src.database import obtain_connection
from src.database import create_comment from src.writes import DatabaseWriter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -79,6 +79,9 @@ async def process_json(app, body: dict) -> dict:
except TypeError as te: except TypeError as te:
logger.exception('Got TypeError: %s', te) logger.exception('Got TypeError: %s', te)
response['error'] = ERRORS['INVALID_PARAMS'] response['error'] = ERRORS['INVALID_PARAMS']
except ValueError as ve:
logger.exception('Got ValueError: %s', ve)
response['error'] = ERRORS['INVALID_PARAMS']
else: else:
response['error'] = ERRORS['UNKNOWN'] response['error'] = ERRORS['UNKNOWN']
return response return response
@ -87,8 +90,8 @@ async def process_json(app, body: dict) -> dict:
@atomic @atomic
async def api_endpoint(request: web.Request): async def api_endpoint(request: web.Request):
try: try:
body = await request.json()
logger.info('Received POST request from %s', request.remote) logger.info('Received POST request from %s', request.remote)
body = await request.json()
if type(body) is list or type(body) is dict: if type(body) is list or type(body) is dict:
if type(body) is list: if type(body) is list:
return web.json_response( return web.json_response(
@ -104,7 +107,3 @@ async def api_endpoint(request: web.Request):
return web.json_response({ return web.json_response({
'error': {'message': jde.msg, 'code': -1} 'error': {'message': jde.msg, 'code': -1}
}) })