From 3c8cc2520b4ab9dcea8a5e9a1f0d5f23166555c2 Mon Sep 17 00:00:00 2001 From: Oleg Silkin Date: Fri, 19 Jul 2019 00:33:34 -0400 Subject: [PATCH] Adds signing + comment deletion framework --- src/database.py | 55 +++++++++++++++++++++---------- src/handles.py | 77 ++++++++++++++++++++++++------------------- src/misc.py | 86 +++++++++++++++++++++++++++++++++++++++++++++++-- src/writes.py | 24 ++++++-------- 4 files changed, 175 insertions(+), 67 deletions(-) diff --git a/src/database.py b/src/database.py index 6b5f4e8..50edca3 100644 --- a/src/database.py +++ b/src/database.py @@ -82,17 +82,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = } -def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str): - with conn: - conn.execute( - 'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)', - (channel_id, channel_name) - ) - - -def insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None, - channel_id: str = None, signature: str = None, signing_ts: str = None, - parent_id: str = None) -> str: +def insert_comment(conn: sqlite3.Connection, claim_id: str, comment: str, parent_id: str = None, + channel_id: str = None, signature: str = None, signing_ts: str = None) -> str: timestamp = int(time.time()) prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),)) comment_id = nacl.hash.sha256(prehash).decode() @@ -156,12 +147,42 @@ def get_comments_by_id(conn, comment_ids: list) -> typing.Union[list, None]: )] -def delete_comment_by_id(conn: sqlite3.Connection, comment_id: str): +def delete_anonymous_comment_by_id(conn: sqlite3.Connection, comment_id: str): with conn: - if conn.execute('SELECT 1 FROM COMMENT WHERE CommentId = ? LIMIT 1', (comment_id,)).fetchone(): - conn.execute("DELETE FROM COMMENT WHERE CommentId = ?", (comment_id,)) - return True - return False + curs = conn.execute( + "DELETE FROM COMMENT WHERE ChannelId IS NULL AND CommentId = ?", + (comment_id,) + ) + return curs.rowcount + + +def delete_channel_comment_by_id(conn: sqlite3.Connection, comment_id: str, channel_id: str): + with conn: + curs = conn.execute( + "DELETE FROM COMMENT WHERE ChannelId = ? AND CommentId = ?", + (channel_id, comment_id) + ) + return curs.rowcount + + +def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str): + with conn: + conn.execute( + 'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)', + (channel_id, channel_name) + ) + + +def get_channel_from_comment_id(conn: sqlite3.Connection, comment_id: str): + with conn: + channel = conn.execute(""" + SELECT CHN.ClaimId AS channel_id, CHN.Name AS channel_name + FROM CHANNEL AS CHN, COMMENT AS CMT + WHERE CHN.ClaimId = CMT.ChannelId AND CMT.CommentId = ? + LIMIT 1 + """, (comment_id,) + ).fetchone() + return dict(channel) if channel else dict() class DatabaseWriter(object): @@ -184,4 +205,4 @@ class DatabaseWriter(object): @property def connection(self): - return self.conn \ No newline at end of file + return self.conn diff --git a/src/handles.py b/src/handles.py index 71f2e53..5c202f0 100644 --- a/src/handles.py +++ b/src/handles.py @@ -2,33 +2,33 @@ import json import logging -import aiojobs import asyncio from aiohttp import web from aiojobs.aiohttp import atomic from asyncio import coroutine -from src.database import DatabaseWriter +from misc import clean_input_params, ERRORS from src.database import get_claim_comments from src.database import get_comments_by_id, get_comment_ids +from src.database import get_channel_from_comment_id from src.database import obtain_connection +from src.database import delete_channel_comment_by_id from src.writes import create_comment_or_error +from src.misc import is_authentic_delete_signal logger = logging.getLogger(__name__) -ERRORS = { - 'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid parameters'}, - 'INTERNAL': {'code': -32603, 'message': 'An internal error'}, - 'UNKNOWN': {'code': -1, 'message': 'An unknown or very miscellaneous error'}, -} -ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'} - - -def ping(*args, **kwargs): +# noinspection PyUnusedLocal +def ping(*args): return 'pong' +def handle_get_channel_from_comment_id(app, kwargs: dict): + with obtain_connection(app['db_path']) as conn: + return get_channel_from_comment_id(conn, **kwargs) + + def handle_get_comment_ids(app, kwargs): with obtain_connection(app['db_path']) as conn: return get_comment_ids(conn, **kwargs) @@ -44,37 +44,40 @@ def handle_get_comments_by_id(app, kwargs): return get_comments_by_id(conn, **kwargs) -async def create_comment_scheduler(): - return await aiojobs.create_scheduler(limit=1, pending_limit=0) +async def write_comment(app, comment): + return await coroutine(create_comment_or_error)(app['writer'], **comment) -async def write_comment(comment): - with DatabaseWriter._writer.connection as conn: - return await coroutine(create_comment_or_error)(conn, **comment) - - -async def handle_create_comment(scheduler, comment): - job = await scheduler.spawn(write_comment(comment)) +async def handle_create_comment(app, params): + job = await app['comment_scheduler'].spawn(write_comment(app, params)) return await job.wait() +async def delete_comment_if_authorized(app, comment_id, channel_name, channel_id, signature): + authorized = await is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature) + if not authorized: + return {'deleted': False} + + delete_query = delete_channel_comment_by_id(app['writer'], comment_id, channel_id) + job = await app['comment_scheduler'].spawn(delete_query) + return {'deleted': await job.wait()} + + +async def handle_delete_comment(app, params): + return await delete_comment_if_authorized(app, **params) + + METHODS = { 'ping': ping, 'get_claim_comments': handle_get_claim_comments, 'get_comment_ids': handle_get_comment_ids, 'get_comments_by_id': handle_get_comments_by_id, - 'create_comment': handle_create_comment + 'get_channel_from_comment_id': handle_get_channel_from_comment_id, + 'create_comment': handle_create_comment, + 'delete_comment': handle_delete_comment, } -def clean_input_params(kwargs: dict): - for k, v in kwargs.items(): - if type(v) is str: - kwargs[k] = v.strip() - if k in ID_LIST: - kwargs[k] = v.lower() - - async def process_json(app, body: dict) -> dict: response = {'jsonrpc': '2.0', 'id': body['id']} if body['method'] in METHODS: @@ -83,7 +86,7 @@ async def process_json(app, body: dict) -> dict: clean_input_params(params) try: if asyncio.iscoroutinefunction(METHODS[method]): - result = await METHODS[method](app['comment_scheduler'], params) + result = await METHODS[method](app, params) else: result = METHODS[method](app, params) response['result'] = result @@ -93,8 +96,11 @@ async def process_json(app, body: dict) -> dict: except ValueError as ve: logger.exception('Got ValueError: %s', ve) response['error'] = ERRORS['INVALID_PARAMS'] + except Exception as e: + logger.exception('Got unknown exception: %s', e) + response['error'] = ERRORS['INTERNAL'] else: - response['error'] = ERRORS['UNKNOWN'] + response['error'] = ERRORS['METHOD_NOT_FOUND'] return response @@ -105,16 +111,21 @@ async def api_endpoint(request: web.Request): body = await request.json() if type(body) is list or type(body) is dict: if type(body) is list: + # for batching return web.json_response( [await process_json(request.app, part) for part in body] ) else: return web.json_response(await process_json(request.app, body)) else: - return web.json_response({'error': ERRORS['UNKNOWN']}) + logger.warning('Got invalid request from %s: %s', request.remote, body) + return web.json_response({'error': ERRORS['INVALID_REQUEST']}) except json.decoder.JSONDecodeError as jde: logger.exception('Received malformed JSON from %s: %s', request.remote, jde.msg) logger.debug('Request headers: %s', request.headers) return web.json_response({ - 'error': {'message': jde.msg, 'code': -1} + 'error': ERRORS['PARSE_ERROR'] }) + except Exception as e: + logger.exception('Exception raised by request from %s: %s', request.remote, e) + return web.json_response({'error': ERRORS['INVALID_REQUEST']}) diff --git a/src/misc.py b/src/misc.py index c333587..32ce4ea 100644 --- a/src/misc.py +++ b/src/misc.py @@ -1,7 +1,37 @@ +import binascii +import logging import re +from json import JSONDecodeError + +from nacl.hash import sha256 +import aiohttp + +import ecdsa +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import load_der_public_key +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import Prehashed +from cryptography.exceptions import InvalidSignature + +logger = logging.getLogger(__name__) + +ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'} + +ERRORS = { + 'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid Method Parameter(s).'}, + 'INTERNAL': {'code': -32603, 'message': 'Internal Server Error.'}, + 'METHOD_NOT_FOUND': {'code': -32601, 'message': 'The method does not exist / is not available.'}, + 'INVALID_REQUEST': {'code': -32600, 'message': 'The JSON sent is not a valid Request object.'}, + 'PARSE_ERROR': { + 'code': -32700, + 'message': 'Invalid JSON was received by the server.\n' + 'An error occurred on the server while parsing the JSON text.' + } +} -def validate_channel(channel_id: str, channel_name: str): +def channel_matches_pattern(channel_id: str, channel_name: str): assert channel_id and channel_name assert type(channel_id) is str and type(channel_name) is str assert re.fullmatch( @@ -12,11 +42,61 @@ def validate_channel(channel_id: str, channel_name: str): assert re.fullmatch('[a-z0-9]{40}', channel_id) +async def resolve_channel_claim(app: dict, channel_id: str, channel_name: str): + lbry_url = f'lbry://{channel_name}#{channel_id}' + resolve_body = { + 'method': 'resolve', + 'params': { + 'urls': [lbry_url, ] + } + } + async with aiohttp.request('POST', app['config']['LBRYNET'], json=resolve_body) as req: + try: + resp = await req.json() + return resp.get(lbry_url) + except JSONDecodeError as jde: + logger.exception(jde.msg) + raise Exception(jde) + + +def is_signature_valid(encoded_signature, signature_digest, public_key_bytes): + try: + public_key = load_der_public_key(public_key_bytes, default_backend()) + public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256()))) + return True + except (ValueError, InvalidSignature) as err: + logger.debug('Signature Valiadation Failed: %s', err) + return False + + +def get_encoded_signature(signature): + signature = signature.encode() if type(signature) is str else signature + r = int(signature[:int(len(signature) / 2)], 16) + s = int(signature[int(len(signature) / 2):], 16) + return ecdsa.util.sigencode_der(r, s, len(signature) * 4) + def validate_base_comment(comment: str, claim_id: str, **kwargs): assert 0 < len(comment) <= 2000 assert re.fullmatch('[a-z0-9]{40}', claim_id) -async def validate_signature(*args, **kwargs): - pass +async def is_authentic_delete_signal(app, comment_id: str, channel_name: str, channel_id: str, signature: str): + claim = await resolve_channel_claim(app, channel_id, channel_name) + if claim: + public_key = claim['value']['public_key'] + claim_hash = binascii.unhexlify(claim['claim_id'].encode())[::-1] + return is_signature_valid( + encoded_signature=get_encoded_signature(signature), + signature_digest=sha256(b''.join([comment_id.encode(), claim_hash])), + public_key_bytes=binascii.unhexlify(public_key.encode()) + ) + return False + + +def clean_input_params(kwargs: dict): + for k, v in kwargs.items(): + if type(v) is str: + kwargs[k] = v.strip() + if k in ID_LIST: + kwargs[k] = v.lower() diff --git a/src/writes.py b/src/writes.py index 83cc0e3..9375e98 100644 --- a/src/writes.py +++ b/src/writes.py @@ -4,30 +4,26 @@ import sqlite3 from src.database import get_comment_or_none from src.database import insert_comment from src.database import insert_channel -from src.misc import validate_channel -from src.misc import validate_signature +from src.misc import channel_matches_pattern logger = logging.getLogger(__name__) -def create_comment_or_error(conn: sqlite3.Connection, comment: str, claim_id: str, channel_id: str = None, - channel_name: str = None, signature: str = None, signing_ts: str = None, parent_id: str = None): +def create_comment_or_error(conn, comment, claim_id, channel_id=None, channel_name=None, + signature=None, signing_ts=None, parent_id=None) -> dict: if channel_id or channel_name or signature or signing_ts: - validate_signature(signature, signing_ts, comment, channel_name, channel_id) + # validate_signature(signature, signing_ts, comment, channel_name, channel_id) insert_channel_or_error(conn, channel_name, channel_id) - try: - comment_id = insert_comment( - conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, - signature=signature, parent_id=parent_id, signing_ts=signing_ts - ) - return get_comment_or_none(conn, comment_id) - except sqlite3.IntegrityError as ie: - logger.exception(ie) + comment_id = insert_comment( + conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, + signature=signature, parent_id=parent_id, signing_ts=signing_ts + ) + return get_comment_or_none(conn, comment_id) def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id: str): try: - validate_channel(channel_id, channel_name) + channel_matches_pattern(channel_id, channel_name) insert_channel(conn, channel_name, channel_id) except AssertionError as ae: logger.exception('Invalid channel values given: %s', ae)