Updates database routines & fixes up error handling code
This commit is contained in:
parent
9e82282f5e
commit
7e38a87a0b
3 changed files with 132 additions and 132 deletions
|
@ -50,21 +50,20 @@ async def close_comment_scheduler(app):
|
|||
await app['comment_scheduler'].close()
|
||||
|
||||
|
||||
async def create_database_backup(app):
|
||||
async def database_backup_routine(app):
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(app['config']['BACKUP_INT'])
|
||||
with obtain_connection(app['db_path']) as conn:
|
||||
logger.debug('backing up database')
|
||||
schema.db_helpers.backup_database(conn, app['backup'])
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
async def start_background_tasks(app: web.Application):
|
||||
app['reader'] = obtain_connection(app['db_path'], True)
|
||||
app['waitful_backup'] = app.loop.create_task(create_database_backup(app))
|
||||
app['waitful_backup'] = app.loop.create_task(database_backup_routine(app))
|
||||
app['comment_scheduler'] = await create_comment_scheduler()
|
||||
app['writer'] = DatabaseWriter(app['db_path'])
|
||||
|
||||
|
|
|
@ -23,9 +23,11 @@ def obtain_connection(filepath: str = None, row_factory: bool = True):
|
|||
|
||||
def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None,
|
||||
page: int = 1, page_size: int = 50, top_level=False):
|
||||
with conn:
|
||||
if top_level:
|
||||
results = [clean(dict(row)) for row in conn.execute(
|
||||
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
||||
""" SELECT comment, comment_id, channel_name, channel_id,
|
||||
channel_url, timestamp, signature, parent_id
|
||||
FROM COMMENTS_ON_CLAIMS
|
||||
WHERE claim_id LIKE ? AND parent_id IS NULL
|
||||
LIMIT ? OFFSET ? """,
|
||||
|
@ -40,7 +42,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
|
|||
)
|
||||
elif parent_id is None:
|
||||
results = [clean(dict(row)) for row in conn.execute(
|
||||
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
||||
""" SELECT comment, comment_id, channel_name, channel_id,
|
||||
channel_url, timestamp, signature, parent_id
|
||||
FROM COMMENTS_ON_CLAIMS
|
||||
WHERE claim_id LIKE ?
|
||||
LIMIT ? OFFSET ? """,
|
||||
|
@ -55,7 +58,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
|
|||
)
|
||||
else:
|
||||
results = [clean(dict(row)) for row in conn.execute(
|
||||
""" SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
||||
""" SELECT comment, comment_id, channel_name, channel_id,
|
||||
channel_url, timestamp, signature, parent_id
|
||||
FROM COMMENTS_ON_CLAIMS
|
||||
WHERE claim_id LIKE ? AND parent_id = ?
|
||||
LIMIT ? OFFSET ? """,
|
||||
|
@ -78,20 +82,20 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
|
|||
}
|
||||
|
||||
|
||||
def validate_input(**kwargs):
|
||||
assert 0 < len(kwargs['comment']) <= 2000
|
||||
assert re.fullmatch(
|
||||
'[a-z0-9]{40}:([a-z0-9]{40})?',
|
||||
kwargs['claim_id'] + ':' + kwargs.get('channel_id', '')
|
||||
)
|
||||
if 'channel_name' in kwargs or 'channel_id' in kwargs:
|
||||
assert 'channel_id' in kwargs and 'channel_name' in kwargs
|
||||
def validate_channel(channel_id: str, channel_name: str):
|
||||
assert channel_id and channel_name
|
||||
assert type(channel_id) is str and type(channel_name) is str
|
||||
assert re.fullmatch(
|
||||
'^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26'
|
||||
'\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$',
|
||||
kwargs.get('channel_name', '')
|
||||
channel_name
|
||||
)
|
||||
assert re.fullmatch('[a-z0-9]{40}', kwargs.get('channel_id', ''))
|
||||
assert re.fullmatch('[a-z0-9]{40}', channel_id)
|
||||
|
||||
|
||||
def validate_input(comment: str, claim_id: str, **kwargs):
|
||||
assert 0 < len(comment) <= 2000
|
||||
assert re.fullmatch('[a-z0-9]{40}', claim_id)
|
||||
|
||||
|
||||
def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
||||
|
@ -102,17 +106,25 @@ def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str
|
|||
)
|
||||
|
||||
|
||||
def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id):
|
||||
try:
|
||||
validate_channel(channel_id, channel_name)
|
||||
_insert_channel(conn, channel_name, channel_id)
|
||||
except AssertionError as ae:
|
||||
logger.exception('Invalid channel values given: %s', ae)
|
||||
raise ValueError('Received invalid values for channel_id or channel_name')
|
||||
|
||||
|
||||
def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None,
|
||||
channel_id: str = None, signature: str = None, parent_id: str = None) -> str:
|
||||
timestamp = int(time.time())
|
||||
comment_prehash = ':'.join((claim_id, comment, str(timestamp),))
|
||||
comment_prehash = bytes(comment_prehash.encode('utf-8'))
|
||||
comment_id = nacl.hash.sha256(comment_prehash).decode('utf-8')
|
||||
prehash = ':'.join((claim_id, comment, str(timestamp),))
|
||||
prehash = bytes(prehash.encode('utf-8'))
|
||||
comment_id = nacl.hash.sha256(prehash).decode('utf-8')
|
||||
with conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body,
|
||||
ParentId, Signature, Timestamp)
|
||||
INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body, ParentId, Signature, Timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(comment_id, claim_id, channel_id, comment, parent_id, signature, timestamp)
|
||||
|
@ -121,31 +133,8 @@ def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str
|
|||
return comment_id
|
||||
|
||||
|
||||
def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, **kwargs) -> typing.Union[dict, None]:
|
||||
channel_id = kwargs.pop('channel_id', '')
|
||||
channel_name = kwargs.pop('channel_name', '')
|
||||
if channel_id or channel_name:
|
||||
try:
|
||||
validate_input(
|
||||
comment=comment,
|
||||
claim_id=claim_id,
|
||||
channel_id=channel_id,
|
||||
channel_name=channel_name,
|
||||
)
|
||||
_insert_channel(conn, channel_name, channel_id)
|
||||
except AssertionError:
|
||||
logger.exception('Received invalid input')
|
||||
raise TypeError('Invalid params given to input validation')
|
||||
else:
|
||||
channel_id = None
|
||||
try:
|
||||
comment_id = _insert_comment(
|
||||
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, **kwargs
|
||||
)
|
||||
except sqlite3.IntegrityError as ie:
|
||||
logger.exception(ie)
|
||||
return None
|
||||
|
||||
def get_comment_or_none(conn: sqlite3.Connection, comment_id: str) -> dict:
|
||||
with conn:
|
||||
curry = conn.execute(
|
||||
"""
|
||||
SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, parent_id
|
||||
|
@ -157,6 +146,22 @@ def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, **kwar
|
|||
return clean(dict(thing)) if thing else None
|
||||
|
||||
|
||||
def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str,
|
||||
channel_id: str = None, channel_name: str = None,
|
||||
signature: str = None, parent_id: str = None):
|
||||
if channel_id or channel_name or signature:
|
||||
# do nothing with signature for now
|
||||
insert_channel_or_error(conn, channel_name=channel_name, channel_id=channel_id)
|
||||
try:
|
||||
comment_id = _insert_comment(
|
||||
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id,
|
||||
signature=signature, parent_id=parent_id
|
||||
)
|
||||
return get_comment_or_none(conn, comment_id)
|
||||
except sqlite3.IntegrityError as ie:
|
||||
logger.exception(ie)
|
||||
|
||||
|
||||
def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50):
|
||||
""" Just return a list of the comment IDs that are associated with the given claim_id.
|
||||
If get_all is specified then it returns all the IDs, otherwise only the IDs at that level.
|
||||
|
@ -165,6 +170,7 @@ def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = No
|
|||
For pagination the parameters are:
|
||||
get_all XOR (page_size + page)
|
||||
"""
|
||||
with conn:
|
||||
if parent_id is None:
|
||||
curs = conn.execute("""
|
||||
SELECT comment_id FROM COMMENTS_ON_CLAIMS
|
||||
|
@ -184,12 +190,8 @@ def get_comments_by_id(conn, comment_ids: list) -> typing.Union[list, None]:
|
|||
""" Returns a list containing the comment data associated with each ID within the list"""
|
||||
# format the input, under the assumption that the
|
||||
placeholders = ', '.join('?' for _ in comment_ids)
|
||||
with conn:
|
||||
return [clean(dict(row)) for row in conn.execute(
|
||||
f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})',
|
||||
tuple(comment_ids)
|
||||
)]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
# __generate_database_schema(connection, 'comments_ddl.sql')
|
||||
|
|
|
@ -2,17 +2,17 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
import asyncio
|
||||
import aiojobs
|
||||
from asyncio import coroutine
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
from aiojobs.aiohttp import atomic
|
||||
from asyncio import coroutine
|
||||
|
||||
from src.writes import DatabaseWriter
|
||||
from src.database import create_comment
|
||||
from src.database import get_claim_comments
|
||||
from src.database import get_comments_by_id, get_comment_ids
|
||||
from src.database import obtain_connection
|
||||
from src.database import create_comment
|
||||
from src.writes import DatabaseWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -79,6 +79,9 @@ async def process_json(app, body: dict) -> dict:
|
|||
except TypeError as te:
|
||||
logger.exception('Got TypeError: %s', te)
|
||||
response['error'] = ERRORS['INVALID_PARAMS']
|
||||
except ValueError as ve:
|
||||
logger.exception('Got ValueError: %s', ve)
|
||||
response['error'] = ERRORS['INVALID_PARAMS']
|
||||
else:
|
||||
response['error'] = ERRORS['UNKNOWN']
|
||||
return response
|
||||
|
@ -87,8 +90,8 @@ async def process_json(app, body: dict) -> dict:
|
|||
@atomic
|
||||
async def api_endpoint(request: web.Request):
|
||||
try:
|
||||
body = await request.json()
|
||||
logger.info('Received POST request from %s', request.remote)
|
||||
body = await request.json()
|
||||
if type(body) is list or type(body) is dict:
|
||||
if type(body) is list:
|
||||
return web.json_response(
|
||||
|
@ -104,7 +107,3 @@ async def api_endpoint(request: web.Request):
|
|||
return web.json_response({
|
||||
'error': {'message': jde.msg, 'code': -1}
|
||||
})
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue