diff --git a/scripts/aggregate_comments_by_claim.py b/scripts/aggregate_comments_by_claim.py index cf7baa2..8ffc7cd 100644 --- a/scripts/aggregate_comments_by_claim.py +++ b/scripts/aggregate_comments_by_claim.py @@ -56,6 +56,7 @@ async def main(): i += 1 return claims + if __name__ == '__main__': loop = asyncio.get_event_loop() claims = loop.run_until_complete(main()) diff --git a/scripts/valid_signatures.py b/scripts/valid_signatures.py index 99ef3ce..19fc878 100644 --- a/scripts/valid_signatures.py +++ b/scripts/valid_signatures.py @@ -1,12 +1,14 @@ import binascii +import logging import hashlib import json import sqlite3 import asyncio - import aiohttp -from src.server.misc import is_signature_valid, get_encoded_signature +from server.validation import is_signature_valid, get_encoded_signature + +logger = logging.getLogger(__name__) async def request_lbrynet(url, method, **params): diff --git a/src/database/comments_ddl.sql b/src/database/comments_ddl.sql index 4eebe13..2dfc19d 100644 --- a/src/database/comments_ddl.sql +++ b/src/database/comments_ddl.sql @@ -9,13 +9,13 @@ CREATE TABLE IF NOT EXISTS COMMENT ( CommentId TEXT NOT NULL, LbryClaimId TEXT NOT NULL, - ChannelId TEXT DEFAULT NULL, + ChannelId TEXT DEFAULT NULL, Body TEXT NOT NULL, - ParentId TEXT DEFAULT NULL, - Signature TEXT DEFAULT NULL, + ParentId TEXT DEFAULT NULL, + Signature TEXT DEFAULT NULL, Timestamp INTEGER NOT NULL, - SigningTs TEXT DEFAULT NULL, - IsHidden BOOLEAN NOT NULL DEFAULT FALSE, + SigningTs TEXT DEFAULT NULL, + IsHidden BOOLEAN NOT NULL DEFAULT FALSE, CONSTRAINT COMMENT_PRIMARY_KEY PRIMARY KEY (CommentId) ON CONFLICT IGNORE, CONSTRAINT COMMENT_SIGNATURE_SK UNIQUE (Signature) ON CONFLICT ABORT, CONSTRAINT COMMENT_CHANNEL_FK FOREIGN KEY (ChannelId) REFERENCES CHANNEL (ClaimId) @@ -44,21 +44,21 @@ CREATE TABLE IF NOT EXISTS CHANNEL -- CREATE INDEX IF NOT EXISTS CHANNEL_COMMENT_INDEX ON COMMENT (ChannelId, CommentId); -- VIEWS -CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS AS SELECT - C.CommentId AS comment_id, - C.Body AS comment, - C.LbryClaimId AS claim_id, - C.Timestamp AS timestamp, - CHAN.Name AS channel_name, - CHAN.ClaimId AS channel_id, - ('lbry://' || CHAN.Name || '#' || CHAN.ClaimId) AS channel_url, - C.Signature AS signature, - C.SigningTs AS signing_ts, - C.ParentId AS parent_id, - C.IsHidden AS is_hidden - FROM COMMENT AS C - LEFT OUTER JOIN CHANNEL CHAN ON C.ChannelId = CHAN.ClaimId - ORDER BY C.Timestamp DESC; +CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS AS +SELECT C.CommentId AS comment_id, + C.Body AS comment, + C.LbryClaimId AS claim_id, + C.Timestamp AS timestamp, + CHAN.Name AS channel_name, + CHAN.ClaimId AS channel_id, + ('lbry://' || CHAN.Name || '#' || CHAN.ClaimId) AS channel_url, + C.Signature AS signature, + C.SigningTs AS signing_ts, + C.ParentId AS parent_id, + C.IsHidden AS is_hidden +FROM COMMENT AS C + LEFT OUTER JOIN CHANNEL CHAN ON C.ChannelId = CHAN.ClaimId +ORDER BY C.Timestamp DESC; DROP VIEW IF EXISTS COMMENT_REPLIES; diff --git a/src/database/queries.py b/src/database/queries.py index 4d16f80..4264a57 100644 --- a/src/database/queries.py +++ b/src/database/queries.py @@ -1,17 +1,16 @@ import atexit import logging +import math import sqlite3 import time import typing -import math import nacl.hash from src.database.schema import CREATE_TABLES_QUERY logger = logging.getLogger(__name__) - SELECT_COMMENTS_ON_CLAIMS = """ SELECT comment, comment_id, channel_name, channel_id, channel_url, timestamp, signature, signing_ts, parent_id, is_hidden @@ -95,7 +94,7 @@ def get_claim_hidden_comments(conn: sqlite3.Connection, claim_id: str, hidden=Tr 'items': results, 'page': page, 'page_size': page_size, - 'total_pages': math.ceil(count/page_size), + 'total_pages': math.ceil(count / page_size), 'total_items': count, 'has_hidden_comments': claim_has_hidden_comments(conn, claim_id) } @@ -157,14 +156,14 @@ def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = No curs = conn.execute(""" SELECT comment_id FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ? - """, (claim_id, page_size, page_size*abs(page - 1),) - ) + """, (claim_id, page_size, page_size * abs(page - 1),) + ) else: curs = conn.execute(""" SELECT comment_id FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ? """, (claim_id, parent_id, page_size, page_size * abs(page - 1),) - ) + ) return [tuple(row)[0] for row in curs.fetchall()] @@ -219,6 +218,25 @@ def hide_comments_by_id(conn: sqlite3.Connection, comment_ids: list) -> bool: return bool(curs.rowcount) +def edit_comment_by_id(conn: sqlite3.Connection, comment_id: str, comment: str, + signature: str, signing_ts: str) -> bool: + with conn: + curs = conn.execute( + """ + UPDATE COMMENT + SET Body = :comment, Signature = :signature, SigningTs = :signing_ts + WHERE CommentId = :comment_id + """, + { + 'comment': comment, + 'signature': signature, + 'signing_ts': signing_ts, + 'comment_id': comment_id + }) + logger.info("updated comment with `comment_id`: %s", comment_id) + return bool(curs.rowcount) + + class DatabaseWriter(object): _writer = None diff --git a/src/database/schema.py b/src/database/schema.py index d56ad1f..c75681b 100644 --- a/src/database/schema.py +++ b/src/database/schema.py @@ -67,11 +67,10 @@ ORDER BY OG.Timestamp; """ CREATE_TABLES_QUERY = ( - PRAGMAS + - CREATE_COMMENT_TABLE + - CREATE_COMMENT_INDEXES + - CREATE_CHANNEL_TABLE + - CREATE_COMMENTS_ON_CLAIMS_VIEW + - CREATE_COMMENT_REPLIES_VIEW + PRAGMAS + + CREATE_COMMENT_TABLE + + CREATE_COMMENT_INDEXES + + CREATE_CHANNEL_TABLE + + CREATE_COMMENTS_ON_CLAIMS_VIEW + + CREATE_COMMENT_REPLIES_VIEW ) - diff --git a/src/database/writes.py b/src/database/writes.py index 611874a..eb189a6 100644 --- a/src/database/writes.py +++ b/src/database/writes.py @@ -1,21 +1,17 @@ import logging import sqlite3 - from asyncio import coroutine -from src.database.queries import delete_comment_by_id, get_comments_by_id -from src.database.queries import get_claim_ids_from_comment_ids -from src.database.queries import get_comment_or_none -from src.database.queries import hide_comments_by_id -from src.database.queries import insert_channel -from src.database.queries import insert_comment -from src.server.misc import channel_matches_pattern_or_error, create_notification_batch -from src.server.misc import is_valid_base_comment -from src.server.misc import is_valid_credential_input -from src.server.misc import send_notification -from src.server.misc import send_notifications +from src.server.validation import is_valid_channel +from src.server.validation import is_valid_base_comment +from src.server.validation import is_valid_credential_input +from src.server.validation import validate_signature_from_claim +from src.server.validation import body_is_valid from src.server.misc import get_claim_from_id -from src.server.misc import validate_signature_from_claim +from src.server.external import send_notifications +from src.server.external import send_notification +import src.database.queries as db + logger = logging.getLogger(__name__) @@ -24,7 +20,7 @@ def create_comment_or_error(conn, comment, claim_id, channel_id=None, channel_na signature=None, signing_ts=None, parent_id=None) -> dict: if channel_id or channel_name or signature or signing_ts: insert_channel_or_error(conn, channel_name, channel_id) - comment_id = insert_comment( + comment_id = db.insert_comment( conn=conn, comment=comment, claim_id=claim_id, @@ -33,13 +29,13 @@ def create_comment_or_error(conn, comment, claim_id, channel_id=None, channel_na parent_id=parent_id, signing_ts=signing_ts ) - return get_comment_or_none(conn, comment_id) + return db.get_comment_or_none(conn, comment_id) def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id: str): try: - channel_matches_pattern_or_error(channel_id, channel_name) - insert_channel(conn, channel_name, channel_id) + is_valid_channel(channel_id, channel_name) + db.insert_channel(conn, channel_name, channel_id) except AssertionError: logger.exception('Invalid channel values given') raise ValueError('Received invalid values for channel_id or channel_name') @@ -48,35 +44,28 @@ def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel """ COROUTINE WRAPPERS """ -async def write_comment(app, params): # CREATE +async def _create_comment(app, params): # CREATE return await coroutine(create_comment_or_error)(app['writer'], **params) -async def hide_comments(app, comment_ids): # UPDATE - return await coroutine(hide_comments_by_id)(app['writer'], comment_ids) +async def _hide_comments(app, comment_ids): # UPDATE + return await coroutine(db.hide_comments_by_id)(app['writer'], comment_ids) -async def abandon_comment(app, comment_id): # DELETE - return await coroutine(delete_comment_by_id)(app['writer'], comment_id) +async def _edit_comment(**params): + return await coroutine(db.edit_comment_by_id)(**params) + + +async def _abandon_comment(app, comment_id): # DELETE + return await coroutine(db.delete_comment_by_id)(app['writer'], comment_id) """ Core Functions called by request handlers """ -async def abandon_comment_if_authorized(app, comment_id, channel_id, signature, signing_ts, **kwargs): - claim = await get_claim_from_id(app, channel_id) - if not validate_signature_from_claim(claim, signature, signing_ts, comment_id): - return False - - comment = get_comment_or_none(app['reader'], comment_id) - job = await app['comment_scheduler'].spawn(abandon_comment(app, comment_id)) - await app['webhooks'].spawn(send_notification(app, 'DELETE', comment)) - return await job.wait() - - async def create_comment(app, params): if is_valid_base_comment(**params) and is_valid_credential_input(**params): - job = await app['comment_scheduler'].spawn(write_comment(app, params)) + job = await app['comment_scheduler'].spawn(_create_comment(app, params)) comment = await job.wait() if comment: await app['webhooks'].spawn( @@ -87,8 +76,8 @@ async def create_comment(app, params): raise ValueError('base comment is invalid') -async def hide_comments_where_authorized(app, pieces: list) -> list: - comment_cids = get_claim_ids_from_comment_ids( +async def hide_comments(app, pieces: list) -> list: + comment_cids = db.get_claim_ids_from_comment_ids( conn=app['reader'], comment_ids=[p['comment_id'] for p in pieces] ) @@ -104,12 +93,51 @@ async def hide_comments_where_authorized(app, pieces: list) -> list: comments_to_hide.append(p) comment_ids = [c['comment_id'] for c in comments_to_hide] - job = await app['comment_scheduler'].spawn(hide_comments(app, comment_ids)) + job = await app['comment_scheduler'].spawn(_hide_comments(app, comment_ids)) await app['webhooks'].spawn( send_notifications( - app, 'UPDATE', get_comments_by_id(app['reader'], comment_ids) + app, 'UPDATE', db.get_comments_by_id(app['reader'], comment_ids) ) ) await job.wait() return comment_ids + + +async def edit_comment(app, comment_id: str, comment: str, channel_id: str, + channel_name: str, signature: str, signing_ts: str): + if not(is_valid_credential_input(channel_id, channel_name, signature, signing_ts) + and body_is_valid(comment)): + logging.error('Invalid argument values, check input and try again') + return + + cmnt = db.get_comment_or_none(app['reader'], comment_id) + if not(cmnt and 'channel_id' in cmnt and cmnt['channel_id'] == channel_id.lower()): + logging.error("comment doesnt exist") + return + + channel = await get_claim_from_id(app, channel_id) + if not validate_signature_from_claim(channel, signature, signing_ts, comment): + logging.error("signature could not be validated") + return + + job = await app['comment_scheduler'].spawn(_edit_comment( + conn=app['writer'], + comment_id=comment_id, + signature=signature, + signing_ts=signing_ts, + comment=comment + )) + + return await job.wait() + + +async def abandon_comment(app, comment_id, channel_id, signature, signing_ts, **kwargs): + channel = await get_claim_from_id(app, channel_id) + if not validate_signature_from_claim(channel, signature, signing_ts, comment_id): + return False + + comment = db.get_comment_or_none(app['reader'], comment_id) + job = await app['comment_scheduler'].spawn(_abandon_comment(app, comment_id)) + await app['webhooks'].spawn(send_notification(app, 'DELETE', comment)) + return await job.wait() diff --git a/src/main.py b/src/main.py index 859e224..b31bcea 100644 --- a/src/main.py +++ b/src/main.py @@ -1,10 +1,10 @@ -import logging.config -import logging import argparse +import logging +import logging.config import sys -from src.settings import config from src.server.app import run_app +from src.settings import config def config_logging_from_settings(conf): diff --git a/src/server/app.py b/src/server/app.py index 79b4cbd..a25787b 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -1,19 +1,16 @@ # cython: language_level=3 +import asyncio import logging import pathlib import signal import time -import queue - import aiojobs import aiojobs.aiohttp -import asyncio -import aiohttp from aiohttp import web -from src.database.queries import setup_database, backup_database from src.database.queries import obtain_connection, DatabaseWriter +from src.database.queries import setup_database, backup_database from src.server.handles import api_endpoint, get_api_endpoint logger = logging.getLogger(__name__) diff --git a/src/server/errors.py b/src/server/errors.py new file mode 100644 index 0000000..c0b8af2 --- /dev/null +++ b/src/server/errors.py @@ -0,0 +1,45 @@ +import logging +import aiohttp + + +logger = logging.getLogger(__name__) + + +ERRORS = { + 'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid Method Parameter(s).'}, + 'INTERNAL': {'code': -32603, 'message': 'Internal Server Error. Please notify a LBRY Administrator.'}, + 'METHOD_NOT_FOUND': {'code': -32601, 'message': 'The method does not exist / is not available.'}, + 'INVALID_REQUEST': {'code': -32600, 'message': 'The JSON sent is not a valid Request object.'}, + 'PARSE_ERROR': { + 'code': -32700, + 'message': 'Invalid JSON was received by the server.\n' + 'An error occurred on the server while parsing the JSON text.' + } +} + + +def make_error(error, exc=None) -> dict: + body = ERRORS[error] if error in ERRORS else ERRORS['INTERNAL'] + try: + if exc: + exc_name = type(exc).__name__ + body.update({exc_name: str(exc)}) + + finally: + return body + + +async def report_error(app, exc, msg=''): + try: + if 'slack_webhook' in app['config']: + if msg: + msg = f'"{msg}"' + body = { + "text": f"Got `{type(exc).__name__}`: ```\n{str(exc)}```\n{msg}" + } + async with aiohttp.ClientSession() as sesh: + async with sesh.post(app['config']['slack_webhook'], json=body) as resp: + await resp.wait_for_close() + + except Exception: + logger.critical('Error while logging to slack webhook') \ No newline at end of file diff --git a/src/server/external.py b/src/server/external.py new file mode 100644 index 0000000..aa0e7ee --- /dev/null +++ b/src/server/external.py @@ -0,0 +1,59 @@ +import logging +from json import JSONDecodeError +from typing import List + +import aiohttp +from aiohttp import ClientConnectorError + + +logger = logging.getLogger(__name__) + + +async def send_notifications(app, action: str, comments: List[dict]): + events = create_notification_batch(action, comments) + async with aiohttp.ClientSession() as session: + for event in events: + event.update(auth_token=app['config']['notifications']['auth_token']) + try: + async with session.get(app['config']['notifications']['url'], params=event) as resp: + logger.debug(f'Completed Notification: {await resp.text()}, HTTP Status: {resp.status}') + except Exception: + logger.exception(f'Error requesting internal API, Status {resp.status}: {resp.text()}, ' + f'comment_id: {event["comment_id"]}') + + +async def send_notification(app, action: str, comment: dict): + await send_notifications(app, action, [comment]) + + +def create_notification_batch(action: str, comments: List[dict]) -> List[dict]: + action_type = action[0].capitalize() # to turn Create -> C, edit -> U, delete -> D + events = [] + for comment in comments: + event = { + 'action_type': action_type, + 'comment_id': comment['comment_id'], + 'claim_id': comment['claim_id'] + } + if comment.get('channel_id'): + event['channel_id'] = comment['channel_id'] + events.append(event) + return events + + +async def request_lbrynet(app, method, **params): + body = {'method': method, 'params': {**params}} + try: + async with aiohttp.request('POST', app['config']['lbrynet'], json=body) as req: + try: + resp = await req.json() + except JSONDecodeError as jde: + logger.exception(jde.msg) + raise Exception('JSON Decode Error In lbrynet request') + finally: + if 'result' in resp: + return resp['result'] + raise ValueError('LBRYNET Request Error', {'error': resp['error']}) + except (ConnectionRefusedError, ClientConnectorError): + logger.critical("Connection to the LBRYnet daemon failed, make sure it's running.") + raise Exception("Server cannot verify delete signature") \ No newline at end of file diff --git a/src/server/handles.py b/src/server/handles.py index 222c69a..63503f3 100644 --- a/src/server/handles.py +++ b/src/server/handles.py @@ -1,19 +1,16 @@ +import asyncio import logging import time -import asyncio from aiohttp import web from aiojobs.aiohttp import atomic -from src.server.misc import clean_input_params, report_error -from src.database.queries import get_claim_comments -from src.database.queries import get_comments_by_id, get_comment_ids -from src.database.queries import get_channel_id_from_comment_id -from src.database.queries import get_claim_hidden_comments -from src.server.misc import make_error -from src.database.writes import abandon_comment_if_authorized, create_comment -from src.database.writes import hide_comments_where_authorized - +import src.database.queries as db +from src.database.writes import abandon_comment, create_comment +from src.database.writes import hide_comments +from src.database.writes import edit_comment +from src.server.misc import clean_input_params +from src.server.errors import make_error, report_error logger = logging.getLogger(__name__) @@ -24,35 +21,36 @@ def ping(*args): def handle_get_channel_from_comment_id(app, kwargs: dict): - return get_channel_id_from_comment_id(app['reader'], **kwargs) + return db.get_channel_id_from_comment_id(app['reader'], **kwargs) def handle_get_comment_ids(app, kwargs): - return get_comment_ids(app['reader'], **kwargs) + return db.get_comment_ids(app['reader'], **kwargs) def handle_get_claim_comments(app, kwargs): - return get_claim_comments(app['reader'], **kwargs) + return db.get_claim_comments(app['reader'], **kwargs) def handle_get_comments_by_id(app, kwargs): - return get_comments_by_id(app['reader'], **kwargs) + return db.get_comments_by_id(app['reader'], **kwargs) def handle_get_claim_hidden_comments(app, kwargs): - return get_claim_hidden_comments(app['reader'], **kwargs) - - -async def handle_create_comment(app, params): - return await create_comment(app, params) + return db.get_claim_hidden_comments(app['reader'], **kwargs) async def handle_abandon_comment(app, params): - return {'abandoned': await abandon_comment_if_authorized(app, **params)} + return {'abandoned': await abandon_comment(app, **params)} async def handle_hide_comments(app, params): - return {'hidden': await hide_comments_where_authorized(app, **params)} + return {'hidden': await hide_comments(app, **params)} + + +async def handle_edit_comment(app, params): + if await edit_comment(app, **params): + return db.get_comment_or_none(app['reader'], params['comment_id']) METHODS = { @@ -62,10 +60,11 @@ METHODS = { 'get_comment_ids': handle_get_comment_ids, 'get_comments_by_id': handle_get_comments_by_id, 'get_channel_from_comment_id': handle_get_channel_from_comment_id, - 'create_comment': handle_create_comment, + 'create_comment': create_comment, 'delete_comment': handle_abandon_comment, 'abandon_comment': handle_abandon_comment, - 'hide_comments': handle_hide_comments + 'hide_comments': handle_hide_comments, + 'edit_comment': handle_edit_comment } diff --git a/src/server/misc.py b/src/server/misc.py index 76a3bd2..593e61a 100644 --- a/src/server/misc.py +++ b/src/server/misc.py @@ -1,183 +1,16 @@ -import binascii -import hashlib import logging -import re -from json import JSONDecodeError -from typing import List -import aiohttp -import ecdsa -from aiohttp import ClientConnectorError -from cryptography.exceptions import InvalidSignature -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric.utils import Prehashed -from cryptography.hazmat.primitives.serialization import load_der_public_key +from src.server.external import request_lbrynet logger = logging.getLogger(__name__) ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'} -ERRORS = { - 'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid Method Parameter(s).'}, - 'INTERNAL': {'code': -32603, 'message': 'Internal Server Error. Please notify a LBRY Administrator.'}, - 'METHOD_NOT_FOUND': {'code': -32601, 'message': 'The method does not exist / is not available.'}, - 'INVALID_REQUEST': {'code': -32600, 'message': 'The JSON sent is not a valid Request object.'}, - 'PARSE_ERROR': { - 'code': -32700, - 'message': 'Invalid JSON was received by the server.\n' - 'An error occurred on the server while parsing the JSON text.' - } -} - - -def make_error(error, exc=None) -> dict: - body = ERRORS[error] if error in ERRORS else ERRORS['INTERNAL'] - try: - if exc: - exc_name = type(exc).__name__ - body.update({exc_name: str(exc)}) - - finally: - return body - - -async def report_error(app, exc, msg=''): - try: - if 'slack_webhook' in app['config']: - if msg: - msg = f'"{msg}"' - body = { - "text": f"Got `{type(exc).__name__}`: ```\n{str(exc)}```\n{msg}" - } - async with aiohttp.ClientSession() as sesh: - async with sesh.post(app['config']['slack_webhook'], json=body) as resp: - await resp.wait_for_close() - - except Exception: - logger.critical('Error while logging to slack webhook') - - -async def send_notifications(app, action: str, comments: List[dict]): - events = create_notification_batch(action, comments) - async with aiohttp.ClientSession() as session: - for event in events: - event.update(auth_token=app['config']['notifications']['auth_token']) - try: - async with session.get(app['config']['notifications']['url'], params=event) as resp: - logger.debug(f'Completed Notification: {await resp.text()}, HTTP Status: {resp.status}') - except Exception: - logger.exception(f'Error requesting internal API, Status {resp.status}: {resp.text()}, ' - f'comment_id: {event["comment_id"]}') - - -async def send_notification(app, action: str, comment: dict): - await send_notifications(app, action, [comment]) - - -def create_notification_batch(action: str, comments: List[dict]) -> List[dict]: - action_type = action[0].capitalize() # to turn Create -> C, edit -> U, delete -> D - events = [] - for comment in comments: - event = { - 'action_type': action_type, - 'comment_id': comment['comment_id'], - 'claim_id': comment['claim_id'] - } - if comment.get('channel_id'): - event['channel_id'] = comment['channel_id'] - events.append(event) - return events - - -async def request_lbrynet(app, method, **params): - body = {'method': method, 'params': {**params}} - try: - async with aiohttp.request('POST', app['config']['lbrynet'], json=body) as req: - try: - resp = await req.json() - except JSONDecodeError as jde: - logger.exception(jde.msg) - raise Exception('JSON Decode Error In lbrynet request') - finally: - if 'result' in resp: - return resp['result'] - raise ValueError('LBRYNET Request Error', {'error': resp['error']}) - except (ConnectionRefusedError, ClientConnectorError): - logger.critical("Connection to the LBRYnet daemon failed, make sure it's running.") - raise Exception("Server cannot verify delete signature") - async def get_claim_from_id(app, claim_id, **kwargs): return (await request_lbrynet(app, 'claim_search', claim_id=claim_id, **kwargs))['items'][0] -def get_encoded_signature(signature): - signature = signature.encode() if type(signature) is str else signature - r = int(signature[:int(len(signature) / 2)], 16) - s = int(signature[int(len(signature) / 2):], 16) - return ecdsa.util.sigencode_der(r, s, len(signature) * 4) - - -def channel_matches_pattern_or_error(channel_id, channel_name): - assert channel_id and channel_name - assert re.fullmatch( - '^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26' - '\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$', - channel_name - ) - assert re.fullmatch('([a-f0-9]|[A-F0-9]){40}', channel_id) - return True - - -def is_signature_valid(encoded_signature, signature_digest, public_key_bytes): - try: - public_key = load_der_public_key(public_key_bytes, default_backend()) - public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256()))) - return True - except (ValueError, InvalidSignature): - logger.exception('Signature validation failed') - return False - - -def is_valid_base_comment(comment, claim_id, parent_id=None, **kwargs): - try: - assert 0 < len(comment) <= 2000 - assert (parent_id is None) or (0 < len(parent_id) <= 2000) - assert re.fullmatch('[a-z0-9]{40}', claim_id) - except Exception: - return False - return True - - -def is_valid_credential_input(channel_id=None, channel_name=None, signature=None, signing_ts=None, **kwargs): - if channel_name or channel_name or signature or signing_ts: - try: - assert channel_matches_pattern_or_error(channel_id, channel_name) - if signature or signing_ts: - assert len(signature) == 128 - assert signing_ts.isalnum() - except Exception: - return False - return True - - -def validate_signature_from_claim(claim, signature, signing_ts, data: str): - try: - if claim: - public_key = claim['value']['public_key'] - claim_hash = binascii.unhexlify(claim['claim_id'].encode())[::-1] - injest = b''.join((signing_ts.encode(), claim_hash, data.encode())) - return is_signature_valid( - encoded_signature=get_encoded_signature(signature), - signature_digest=hashlib.sha256(injest).digest(), - public_key_bytes=binascii.unhexlify(public_key.encode()) - ) - except: - return False - - def clean_input_params(kwargs: dict): for k, v in kwargs.items(): if type(v) is str and k is not 'comment': diff --git a/src/server/validation.py b/src/server/validation.py new file mode 100644 index 0000000..02f19b7 --- /dev/null +++ b/src/server/validation.py @@ -0,0 +1,93 @@ +import logging +import binascii +import hashlib +import re + +import ecdsa +import typing +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import Prehashed +from cryptography.hazmat.primitives.serialization import load_der_public_key + +logger = logging.getLogger(__name__) + + +def is_valid_channel(channel_id: str, channel_name: str) -> bool: + return channel_id and claim_id_is_valid(channel_id) and \ + channel_name and channel_name_is_valid(channel_name) + + +def is_signature_valid(encoded_signature, signature_digest, public_key_bytes) -> bool: + try: + public_key = load_der_public_key(public_key_bytes, default_backend()) + public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256()))) + return True + except (ValueError, InvalidSignature): + logger.exception('Signature validation failed') + return False + + +def channel_name_is_valid(channel_name: str) -> bool: + return re.fullmatch( + '@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26' + '\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}', + channel_name + ) is not None + + +def body_is_valid(comment: str) -> bool: + return 0 < len(comment) <= 2000 + + +def comment_id_is_valid(comment_id: str) -> bool: + return re.fullmatch('([a-z0-9]{64}|[A-Z0-9]{64})', comment_id) is not None + + +def claim_id_is_valid(claim_id: str) -> bool: + return re.fullmatch('([a-z0-9]{40}|[A-Z0-9]{40})', claim_id) is not None + + +def is_valid_base_comment(comment: str, claim_id: str, parent_id: str = None, **kwargs) -> bool: + return comment is not None and body_is_valid(comment) and \ + claim_id is not None and claim_id_is_valid(claim_id) and \ + (parent_id is None or comment_id_is_valid(parent_id)) + + +def is_valid_credential_input(channel_id: str = None, channel_name: str = None, + signature: str = None, signing_ts: str = None, **kwargs) -> bool: + if channel_id or channel_name or signature or signing_ts: + try: + assert channel_id and channel_name and signature and signing_ts + assert is_valid_channel(channel_id, channel_name) + assert len(signature) == 128 + assert signing_ts.isalnum() + + except Exception: + return False + return True + + +def validate_signature_from_claim(claim: dict, signature: typing.Union[str, bytes], + signing_ts: str, data: str) -> bool: + try: + if claim: + public_key = claim['value']['public_key'] + claim_hash = binascii.unhexlify(claim['claim_id'].encode())[::-1] + injest = b''.join((signing_ts.encode(), claim_hash, data.encode())) + return is_signature_valid( + encoded_signature=get_encoded_signature(signature), + signature_digest=hashlib.sha256(injest).digest(), + public_key_bytes=binascii.unhexlify(public_key.encode()) + ) + except: + return False + + +def get_encoded_signature(signature: typing.Union[str, bytes]) -> bytes: + signature = signature.encode() if type(signature) is str else signature + r = int(signature[:int(len(signature) / 2)], 16) + s = int(signature[int(len(signature) / 2):], 16) + return ecdsa.util.sigencode_der(r, s, len(signature) * 4) \ No newline at end of file diff --git a/test/test_server.py b/test/test_server.py index 72c794e..3c659d9 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -1,6 +1,7 @@ import os +import random + import aiohttp -import re from itertools import * import faker @@ -10,6 +11,9 @@ from faker.providers import misc from src.settings import config from src.server import app +from src.server.validation import is_valid_channel +from src.server.validation import is_valid_base_comment + from test.testcase import AsyncioTestCase @@ -23,10 +27,22 @@ fake.add_provider(lorem) fake.add_provider(misc) -def fake_lbryusername(): +def fake_lbryusername() -> str: return '@' + fake.user_name() +def nothing(): + pass + + +def fake_signature() -> str: + return fake.sha256() + fake.sha256() + + +def fake_signing_ts() -> str: + return str(random.randint(1, 2**32 - 1)) + + async def jsonrpc_post(url, method, **params): json_body = { 'jsonrpc': '2.0', @@ -38,17 +54,14 @@ async def jsonrpc_post(url, method, **params): return await request.json() -def nothing(): - pass - - replace = { 'claim_id': fake.sha1, 'comment': fake.text, 'channel_id': fake.sha1, 'channel_name': fake_lbryusername, - 'signature': fake.uuid4, - 'parent_id': fake.sha256 + 'signature': fake_signature, + 'signing_ts': fake_signing_ts, + 'parent_id': fake.sha256, } @@ -84,37 +97,22 @@ class ServerTest(AsyncioTestCase): async def post_comment(self, **params): return await jsonrpc_post(self.url, 'create_comment', **params) - def is_valid_message(self, comment=None, claim_id=None, parent_id=None, + @staticmethod + def is_valid_message(comment=None, claim_id=None, parent_id=None, channel_name=None, channel_id=None, signature=None, signing_ts=None): try: - assert comment is not None and claim_id is not None - assert re.fullmatch('([a-f0-9]|[A-F0-9]){40}', claim_id) - assert 0 < len(comment) <= 2000 - if parent_id is not None: - assert re.fullmatch('([a-f0-9]){64}', parent_id) + assert is_valid_base_comment(comment, claim_id, parent_id) if channel_name or channel_id or signature or signing_ts: - assert channel_id is not None and channel_name is not None - assert re.fullmatch('([a-f0-9]|[A-F0-9]){40}', channel_id) - assert self.valid_channel_name(channel_name) - assert (signature is None and signing_ts is None) or \ - (signature is not None and signing_ts is not None) - if signature: - assert len(signature) == 128 - if parent_id: - assert parent_id.isalnum() + assert channel_id and channel_name and signature and signing_ts + assert is_valid_channel(channel_id, channel_name) + assert len(signature) == 128 + assert signing_ts.isalnum() + except Exception: return False return True - @staticmethod - def valid_channel_name(channel_name): - return re.fullmatch( - '^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26' - '\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$', - channel_name - ) - async def test01CreateCommentNoReply(self): anonymous_test = create_test_comments( ('claim_id', 'channel_id', 'channel_name', 'comment'), @@ -170,6 +168,8 @@ class ServerTest(AsyncioTestCase): channel_id=fake.sha1(), comment='Hello everybody and welcome back to my chan nel', claim_id=claim_id, + signing_ts='1234', + signature='_'*128 ) parent_id = parent_comment['result']['comment_id'] test_all = create_test_comments( @@ -200,7 +200,8 @@ class ListCommentsTest(AsyncioTestCase): 'comment': fake.text, 'channel_id': fake.sha1, 'channel_name': fake_lbryusername, - 'signature': nothing, + 'signature': fake_signature, + 'signing_ts': fake_signing_ts, 'parent_id': nothing }