Updates database routines & fixes up error handling code
This commit is contained in:
parent
9e82282f5e
commit
7e38a87a0b
3 changed files with 132 additions and 132 deletions
|
@ -50,21 +50,20 @@ async def close_comment_scheduler(app):
|
||||||
await app['comment_scheduler'].close()
|
await app['comment_scheduler'].close()
|
||||||
|
|
||||||
|
|
||||||
async def create_database_backup(app):
|
async def database_backup_routine(app):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(app['config']['BACKUP_INT'])
|
await asyncio.sleep(app['config']['BACKUP_INT'])
|
||||||
with obtain_connection(app['db_path']) as conn:
|
with obtain_connection(app['db_path']) as conn:
|
||||||
logger.debug('backing up database')
|
logger.debug('backing up database')
|
||||||
schema.db_helpers.backup_database(conn, app['backup'])
|
schema.db_helpers.backup_database(conn, app['backup'])
|
||||||
|
except asyncio.CancelledError:
|
||||||
except asyncio.CancelledError as e:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def start_background_tasks(app: web.Application):
|
async def start_background_tasks(app: web.Application):
|
||||||
app['reader'] = obtain_connection(app['db_path'], True)
|
app['reader'] = obtain_connection(app['db_path'], True)
|
||||||
app['waitful_backup'] = app.loop.create_task(create_database_backup(app))
|
app['waitful_backup'] = app.loop.create_task(database_backup_routine(app))
|
||||||
app['comment_scheduler'] = await create_comment_scheduler()
|
app['comment_scheduler'] = await create_comment_scheduler()
|
||||||
app['writer'] = DatabaseWriter(app['db_path'])
|
app['writer'] = DatabaseWriter(app['db_path'])
|
||||||
|
|
||||||
|
|
|
@ -23,9 +23,11 @@ def obtain_connection(filepath: str = None, row_factory: bool = True):
|
||||||
|
|
||||||
def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None,
|
def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None,
|
||||||
page: int = 1, page_size: int = 50, top_level=False):
|
page: int = 1, page_size: int = 50, top_level=False):
|
||||||
|
with conn:
|
||||||
if top_level:
|
if top_level:
|
||||||
results = [clean(dict(row)) for row in conn.execute(
|
results = [clean(dict(row)) for row in conn.execute(
|
||||||
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
""" SELECT comment, comment_id, channel_name, channel_id,
|
||||||
|
channel_url, timestamp, signature, parent_id
|
||||||
FROM COMMENTS_ON_CLAIMS
|
FROM COMMENTS_ON_CLAIMS
|
||||||
WHERE claim_id LIKE ? AND parent_id IS NULL
|
WHERE claim_id LIKE ? AND parent_id IS NULL
|
||||||
LIMIT ? OFFSET ? """,
|
LIMIT ? OFFSET ? """,
|
||||||
|
@ -40,7 +42,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
|
||||||
)
|
)
|
||||||
elif parent_id is None:
|
elif parent_id is None:
|
||||||
results = [clean(dict(row)) for row in conn.execute(
|
results = [clean(dict(row)) for row in conn.execute(
|
||||||
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
""" SELECT comment, comment_id, channel_name, channel_id,
|
||||||
|
channel_url, timestamp, signature, parent_id
|
||||||
FROM COMMENTS_ON_CLAIMS
|
FROM COMMENTS_ON_CLAIMS
|
||||||
WHERE claim_id LIKE ?
|
WHERE claim_id LIKE ?
|
||||||
LIMIT ? OFFSET ? """,
|
LIMIT ? OFFSET ? """,
|
||||||
|
@ -55,7 +58,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
results = [clean(dict(row)) for row in conn.execute(
|
results = [clean(dict(row)) for row in conn.execute(
|
||||||
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
""" SELECT comment, comment_id, channel_name, channel_id,
|
||||||
|
channel_url, timestamp, signature, parent_id
|
||||||
FROM COMMENTS_ON_CLAIMS
|
FROM COMMENTS_ON_CLAIMS
|
||||||
WHERE claim_id LIKE ? AND parent_id = ?
|
WHERE claim_id LIKE ? AND parent_id = ?
|
||||||
LIMIT ? OFFSET ? """,
|
LIMIT ? OFFSET ? """,
|
||||||
|
@ -78,20 +82,20 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_input(**kwargs):
|
def validate_channel(channel_id: str, channel_name: str):
|
||||||
assert 0 < len(kwargs['comment']) <= 2000
|
assert channel_id and channel_name
|
||||||
assert re.fullmatch(
|
assert type(channel_id) is str and type(channel_name) is str
|
||||||
'[a-z0-9]{40}:([a-z0-9]{40})?',
|
|
||||||
kwargs['claim_id'] + ':' + kwargs.get('channel_id', '')
|
|
||||||
)
|
|
||||||
if 'channel_name' in kwargs or 'channel_id' in kwargs:
|
|
||||||
assert 'channel_id' in kwargs and 'channel_name' in kwargs
|
|
||||||
assert re.fullmatch(
|
assert re.fullmatch(
|
||||||
'^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26'
|
'^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26'
|
||||||
'\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$',
|
'\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$',
|
||||||
kwargs.get('channel_name', '')
|
channel_name
|
||||||
)
|
)
|
||||||
assert re.fullmatch('[a-z0-9]{40}', kwargs.get('channel_id', ''))
|
assert re.fullmatch('[a-z0-9]{40}', channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_input(comment: str, claim_id: str, **kwargs):
|
||||||
|
assert 0 < len(comment) <= 2000
|
||||||
|
assert re.fullmatch('[a-z0-9]{40}', claim_id)
|
||||||
|
|
||||||
|
|
||||||
def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
||||||
|
@ -102,17 +106,25 @@ def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id):
|
||||||
|
try:
|
||||||
|
validate_channel(channel_id, channel_name)
|
||||||
|
_insert_channel(conn, channel_name, channel_id)
|
||||||
|
except AssertionError as ae:
|
||||||
|
logger.exception('Invalid channel values given: %s', ae)
|
||||||
|
raise ValueError('Received invalid values for channel_id or channel_name')
|
||||||
|
|
||||||
|
|
||||||
def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None,
|
def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None,
|
||||||
channel_id: str = None, signature: str = None, parent_id: str = None) -> str:
|
channel_id: str = None, signature: str = None, parent_id: str = None) -> str:
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
comment_prehash = ':'.join((claim_id, comment, str(timestamp),))
|
prehash = ':'.join((claim_id, comment, str(timestamp),))
|
||||||
comment_prehash = bytes(comment_prehash.encode('utf-8'))
|
prehash = bytes(prehash.encode('utf-8'))
|
||||||
comment_id = nacl.hash.sha256(comment_prehash).decode('utf-8')
|
comment_id = nacl.hash.sha256(prehash).decode('utf-8')
|
||||||
with conn:
|
with conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body,
|
INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body, ParentId, Signature, Timestamp)
|
||||||
ParentId, Signature, Timestamp)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(comment_id, claim_id, channel_id, comment, parent_id, signature, timestamp)
|
(comment_id, claim_id, channel_id, comment, parent_id, signature, timestamp)
|
||||||
|
@ -121,31 +133,8 @@ def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str
|
||||||
return comment_id
|
return comment_id
|
||||||
|
|
||||||
|
|
||||||
def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, **kwargs) -> typing.Union[dict, None]:
|
def get_comment_or_none(conn: sqlite3.Connection, comment_id: str) -> dict:
|
||||||
channel_id = kwargs.pop('channel_id', '')
|
with conn:
|
||||||
channel_name = kwargs.pop('channel_name', '')
|
|
||||||
if channel_id or channel_name:
|
|
||||||
try:
|
|
||||||
validate_input(
|
|
||||||
comment=comment,
|
|
||||||
claim_id=claim_id,
|
|
||||||
channel_id=channel_id,
|
|
||||||
channel_name=channel_name,
|
|
||||||
)
|
|
||||||
_insert_channel(conn, channel_name, channel_id)
|
|
||||||
except AssertionError:
|
|
||||||
logger.exception('Received invalid input')
|
|
||||||
raise TypeError('Invalid params given to input validation')
|
|
||||||
else:
|
|
||||||
channel_id = None
|
|
||||||
try:
|
|
||||||
comment_id = _insert_comment(
|
|
||||||
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, **kwargs
|
|
||||||
)
|
|
||||||
except sqlite3.IntegrityError as ie:
|
|
||||||
logger.exception(ie)
|
|
||||||
return None
|
|
||||||
|
|
||||||
curry = conn.execute(
|
curry = conn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
||||||
|
@ -157,6 +146,22 @@ def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, **kwar
|
||||||
return clean(dict(thing)) if thing else None
|
return clean(dict(thing)) if thing else None
|
||||||
|
|
||||||
|
|
||||||
|
def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str,
|
||||||
|
channel_id: str = None, channel_name: str = None,
|
||||||
|
signature: str = None, parent_id: str = None):
|
||||||
|
if channel_id or channel_name or signature:
|
||||||
|
# do nothing with signature for now
|
||||||
|
insert_channel_or_error(conn, channel_name=channel_name, channel_id=channel_id)
|
||||||
|
try:
|
||||||
|
comment_id = _insert_comment(
|
||||||
|
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id,
|
||||||
|
signature=signature, parent_id=parent_id
|
||||||
|
)
|
||||||
|
return get_comment_or_none(conn, comment_id)
|
||||||
|
except sqlite3.IntegrityError as ie:
|
||||||
|
logger.exception(ie)
|
||||||
|
|
||||||
|
|
||||||
def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50):
|
def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50):
|
||||||
""" Just return a list of the comment IDs that are associated with the given claim_id.
|
""" Just return a list of the comment IDs that are associated with the given claim_id.
|
||||||
If get_all is specified then it returns all the IDs, otherwise only the IDs at that level.
|
If get_all is specified then it returns all the IDs, otherwise only the IDs at that level.
|
||||||
|
@ -165,6 +170,7 @@ def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = No
|
||||||
For pagination the parameters are:
|
For pagination the parameters are:
|
||||||
get_all XOR (page_size + page)
|
get_all XOR (page_size + page)
|
||||||
"""
|
"""
|
||||||
|
with conn:
|
||||||
if parent_id is None:
|
if parent_id is None:
|
||||||
curs = conn.execute("""
|
curs = conn.execute("""
|
||||||
SELECT comment_id FROM COMMENTS_ON_CLAIMS
|
SELECT comment_id FROM COMMENTS_ON_CLAIMS
|
||||||
|
@ -184,12 +190,8 @@ def get_comments_by_id(conn, comment_ids: list) -> typing.Union[list, None]:
|
||||||
""" Returns a list containing the comment data associated with each ID within the list"""
|
""" Returns a list containing the comment data associated with each ID within the list"""
|
||||||
# format the input, under the assumption that the
|
# format the input, under the assumption that the
|
||||||
placeholders = ', '.join('?' for _ in comment_ids)
|
placeholders = ', '.join('?' for _ in comment_ids)
|
||||||
|
with conn:
|
||||||
return [clean(dict(row)) for row in conn.execute(
|
return [clean(dict(row)) for row in conn.execute(
|
||||||
f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})',
|
f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})',
|
||||||
tuple(comment_ids)
|
tuple(comment_ids)
|
||||||
)]
|
)]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
pass
|
|
||||||
# __generate_database_schema(connection, 'comments_ddl.sql')
|
|
||||||
|
|
|
@ -2,17 +2,17 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import aiojobs
|
import aiojobs
|
||||||
from asyncio import coroutine
|
import asyncio
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiojobs.aiohttp import atomic
|
from aiojobs.aiohttp import atomic
|
||||||
|
from asyncio import coroutine
|
||||||
|
|
||||||
from src.writes import DatabaseWriter
|
from src.database import create_comment
|
||||||
from src.database import get_claim_comments
|
from src.database import get_claim_comments
|
||||||
from src.database import get_comments_by_id, get_comment_ids
|
from src.database import get_comments_by_id, get_comment_ids
|
||||||
from src.database import obtain_connection
|
from src.database import obtain_connection
|
||||||
from src.database import create_comment
|
from src.writes import DatabaseWriter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -79,6 +79,9 @@ async def process_json(app, body: dict) -> dict:
|
||||||
except TypeError as te:
|
except TypeError as te:
|
||||||
logger.exception('Got TypeError: %s', te)
|
logger.exception('Got TypeError: %s', te)
|
||||||
response['error'] = ERRORS['INVALID_PARAMS']
|
response['error'] = ERRORS['INVALID_PARAMS']
|
||||||
|
except ValueError as ve:
|
||||||
|
logger.exception('Got ValueError: %s', ve)
|
||||||
|
response['error'] = ERRORS['INVALID_PARAMS']
|
||||||
else:
|
else:
|
||||||
response['error'] = ERRORS['UNKNOWN']
|
response['error'] = ERRORS['UNKNOWN']
|
||||||
return response
|
return response
|
||||||
|
@ -87,8 +90,8 @@ async def process_json(app, body: dict) -> dict:
|
||||||
@atomic
|
@atomic
|
||||||
async def api_endpoint(request: web.Request):
|
async def api_endpoint(request: web.Request):
|
||||||
try:
|
try:
|
||||||
body = await request.json()
|
|
||||||
logger.info('Received POST request from %s', request.remote)
|
logger.info('Received POST request from %s', request.remote)
|
||||||
|
body = await request.json()
|
||||||
if type(body) is list or type(body) is dict:
|
if type(body) is list or type(body) is dict:
|
||||||
if type(body) is list:
|
if type(body) is list:
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
|
@ -104,7 +107,3 @@ async def api_endpoint(request: web.Request):
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'error': {'message': jde.msg, 'code': -1}
|
'error': {'message': jde.msg, 'code': -1}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue