Adds signing + comment deletion framework

This commit is contained in:
Oleg Silkin 2019-07-19 00:33:34 -04:00
parent d28eafab29
commit 3c8cc2520b
4 changed files with 175 additions and 67 deletions

View file

@ -82,17 +82,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
} }
def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str): def insert_comment(conn: sqlite3.Connection, claim_id: str, comment: str, parent_id: str = None,
with conn: channel_id: str = None, signature: str = None, signing_ts: str = None) -> str:
conn.execute(
'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)',
(channel_id, channel_name)
)
def insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None,
channel_id: str = None, signature: str = None, signing_ts: str = None,
parent_id: str = None) -> str:
timestamp = int(time.time()) timestamp = int(time.time())
prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),)) prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),))
comment_id = nacl.hash.sha256(prehash).decode() comment_id = nacl.hash.sha256(prehash).decode()
@ -156,12 +147,42 @@ def get_comments_by_id(conn, comment_ids: list) -> typing.Union[list, None]:
)] )]
def delete_comment_by_id(conn: sqlite3.Connection, comment_id: str): def delete_anonymous_comment_by_id(conn: sqlite3.Connection, comment_id: str):
with conn: with conn:
if conn.execute('SELECT 1 FROM COMMENT WHERE CommentId = ? LIMIT 1', (comment_id,)).fetchone(): curs = conn.execute(
conn.execute("DELETE FROM COMMENT WHERE CommentId = ?", (comment_id,)) "DELETE FROM COMMENT WHERE ChannelId IS NULL AND CommentId = ?",
return True (comment_id,)
return False )
return curs.rowcount
def delete_channel_comment_by_id(conn: sqlite3.Connection, comment_id: str, channel_id: str):
with conn:
curs = conn.execute(
"DELETE FROM COMMENT WHERE ChannelId = ? AND CommentId = ?",
(channel_id, comment_id)
)
return curs.rowcount
def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
with conn:
conn.execute(
'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)',
(channel_id, channel_name)
)
def get_channel_from_comment_id(conn: sqlite3.Connection, comment_id: str):
with conn:
channel = conn.execute("""
SELECT CHN.ClaimId AS channel_id, CHN.Name AS channel_name
FROM CHANNEL AS CHN, COMMENT AS CMT
WHERE CHN.ClaimId = CMT.ChannelId AND CMT.CommentId = ?
LIMIT 1
""", (comment_id,)
).fetchone()
return dict(channel) if channel else dict()
class DatabaseWriter(object): class DatabaseWriter(object):
@ -184,4 +205,4 @@ class DatabaseWriter(object):
@property @property
def connection(self): def connection(self):
return self.conn return self.conn

View file

@ -2,33 +2,33 @@
import json import json
import logging import logging
import aiojobs
import asyncio import asyncio
from aiohttp import web from aiohttp import web
from aiojobs.aiohttp import atomic from aiojobs.aiohttp import atomic
from asyncio import coroutine from asyncio import coroutine
from src.database import DatabaseWriter from misc import clean_input_params, ERRORS
from src.database import get_claim_comments from src.database import get_claim_comments
from src.database import get_comments_by_id, get_comment_ids from src.database import get_comments_by_id, get_comment_ids
from src.database import get_channel_from_comment_id
from src.database import obtain_connection from src.database import obtain_connection
from src.database import delete_channel_comment_by_id
from src.writes import create_comment_or_error from src.writes import create_comment_or_error
from src.misc import is_authentic_delete_signal
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ERRORS = {
'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid parameters'},
'INTERNAL': {'code': -32603, 'message': 'An internal error'},
'UNKNOWN': {'code': -1, 'message': 'An unknown or very miscellaneous error'},
}
ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'} # noinspection PyUnusedLocal
def ping(*args):
def ping(*args, **kwargs):
return 'pong' return 'pong'
def handle_get_channel_from_comment_id(app, kwargs: dict):
with obtain_connection(app['db_path']) as conn:
return get_channel_from_comment_id(conn, **kwargs)
def handle_get_comment_ids(app, kwargs): def handle_get_comment_ids(app, kwargs):
with obtain_connection(app['db_path']) as conn: with obtain_connection(app['db_path']) as conn:
return get_comment_ids(conn, **kwargs) return get_comment_ids(conn, **kwargs)
@ -44,37 +44,40 @@ def handle_get_comments_by_id(app, kwargs):
return get_comments_by_id(conn, **kwargs) return get_comments_by_id(conn, **kwargs)
async def create_comment_scheduler(): async def write_comment(app, comment):
return await aiojobs.create_scheduler(limit=1, pending_limit=0) return await coroutine(create_comment_or_error)(app['writer'], **comment)
async def write_comment(comment): async def handle_create_comment(app, params):
with DatabaseWriter._writer.connection as conn: job = await app['comment_scheduler'].spawn(write_comment(app, params))
return await coroutine(create_comment_or_error)(conn, **comment)
async def handle_create_comment(scheduler, comment):
job = await scheduler.spawn(write_comment(comment))
return await job.wait() return await job.wait()
async def delete_comment_if_authorized(app, comment_id, channel_name, channel_id, signature):
authorized = await is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature)
if not authorized:
return {'deleted': False}
delete_query = delete_channel_comment_by_id(app['writer'], comment_id, channel_id)
job = await app['comment_scheduler'].spawn(delete_query)
return {'deleted': await job.wait()}
async def handle_delete_comment(app, params):
return await delete_comment_if_authorized(app, **params)
METHODS = { METHODS = {
'ping': ping, 'ping': ping,
'get_claim_comments': handle_get_claim_comments, 'get_claim_comments': handle_get_claim_comments,
'get_comment_ids': handle_get_comment_ids, 'get_comment_ids': handle_get_comment_ids,
'get_comments_by_id': handle_get_comments_by_id, 'get_comments_by_id': handle_get_comments_by_id,
'create_comment': handle_create_comment 'get_channel_from_comment_id': handle_get_channel_from_comment_id,
'create_comment': handle_create_comment,
'delete_comment': handle_delete_comment,
} }
def clean_input_params(kwargs: dict):
for k, v in kwargs.items():
if type(v) is str:
kwargs[k] = v.strip()
if k in ID_LIST:
kwargs[k] = v.lower()
async def process_json(app, body: dict) -> dict: async def process_json(app, body: dict) -> dict:
response = {'jsonrpc': '2.0', 'id': body['id']} response = {'jsonrpc': '2.0', 'id': body['id']}
if body['method'] in METHODS: if body['method'] in METHODS:
@ -83,7 +86,7 @@ async def process_json(app, body: dict) -> dict:
clean_input_params(params) clean_input_params(params)
try: try:
if asyncio.iscoroutinefunction(METHODS[method]): if asyncio.iscoroutinefunction(METHODS[method]):
result = await METHODS[method](app['comment_scheduler'], params) result = await METHODS[method](app, params)
else: else:
result = METHODS[method](app, params) result = METHODS[method](app, params)
response['result'] = result response['result'] = result
@ -93,8 +96,11 @@ async def process_json(app, body: dict) -> dict:
except ValueError as ve: except ValueError as ve:
logger.exception('Got ValueError: %s', ve) logger.exception('Got ValueError: %s', ve)
response['error'] = ERRORS['INVALID_PARAMS'] response['error'] = ERRORS['INVALID_PARAMS']
except Exception as e:
logger.exception('Got unknown exception: %s', e)
response['error'] = ERRORS['INTERNAL']
else: else:
response['error'] = ERRORS['UNKNOWN'] response['error'] = ERRORS['METHOD_NOT_FOUND']
return response return response
@ -105,16 +111,21 @@ async def api_endpoint(request: web.Request):
body = await request.json() body = await request.json()
if type(body) is list or type(body) is dict: if type(body) is list or type(body) is dict:
if type(body) is list: if type(body) is list:
# for batching
return web.json_response( return web.json_response(
[await process_json(request.app, part) for part in body] [await process_json(request.app, part) for part in body]
) )
else: else:
return web.json_response(await process_json(request.app, body)) return web.json_response(await process_json(request.app, body))
else: else:
return web.json_response({'error': ERRORS['UNKNOWN']}) logger.warning('Got invalid request from %s: %s', request.remote, body)
return web.json_response({'error': ERRORS['INVALID_REQUEST']})
except json.decoder.JSONDecodeError as jde: except json.decoder.JSONDecodeError as jde:
logger.exception('Received malformed JSON from %s: %s', request.remote, jde.msg) logger.exception('Received malformed JSON from %s: %s', request.remote, jde.msg)
logger.debug('Request headers: %s', request.headers) logger.debug('Request headers: %s', request.headers)
return web.json_response({ return web.json_response({
'error': {'message': jde.msg, 'code': -1} 'error': ERRORS['PARSE_ERROR']
}) })
except Exception as e:
logger.exception('Exception raised by request from %s: %s', request.remote, e)
return web.json_response({'error': ERRORS['INVALID_REQUEST']})

View file

@ -1,7 +1,37 @@
import binascii
import logging
import re import re
from json import JSONDecodeError
from nacl.hash import sha256
import aiohttp
import ecdsa
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
from cryptography.exceptions import InvalidSignature
logger = logging.getLogger(__name__)
ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'}
ERRORS = {
'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid Method Parameter(s).'},
'INTERNAL': {'code': -32603, 'message': 'Internal Server Error.'},
'METHOD_NOT_FOUND': {'code': -32601, 'message': 'The method does not exist / is not available.'},
'INVALID_REQUEST': {'code': -32600, 'message': 'The JSON sent is not a valid Request object.'},
'PARSE_ERROR': {
'code': -32700,
'message': 'Invalid JSON was received by the server.\n'
'An error occurred on the server while parsing the JSON text.'
}
}
def validate_channel(channel_id: str, channel_name: str): def channel_matches_pattern(channel_id: str, channel_name: str):
assert channel_id and channel_name assert channel_id and channel_name
assert type(channel_id) is str and type(channel_name) is str assert type(channel_id) is str and type(channel_name) is str
assert re.fullmatch( assert re.fullmatch(
@ -12,11 +42,61 @@ def validate_channel(channel_id: str, channel_name: str):
assert re.fullmatch('[a-z0-9]{40}', channel_id) assert re.fullmatch('[a-z0-9]{40}', channel_id)
async def resolve_channel_claim(app: dict, channel_id: str, channel_name: str):
lbry_url = f'lbry://{channel_name}#{channel_id}'
resolve_body = {
'method': 'resolve',
'params': {
'urls': [lbry_url, ]
}
}
async with aiohttp.request('POST', app['config']['LBRYNET'], json=resolve_body) as req:
try:
resp = await req.json()
return resp.get(lbry_url)
except JSONDecodeError as jde:
logger.exception(jde.msg)
raise Exception(jde)
def is_signature_valid(encoded_signature, signature_digest, public_key_bytes):
try:
public_key = load_der_public_key(public_key_bytes, default_backend())
public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256())))
return True
except (ValueError, InvalidSignature) as err:
logger.debug('Signature Valiadation Failed: %s', err)
return False
def get_encoded_signature(signature):
signature = signature.encode() if type(signature) is str else signature
r = int(signature[:int(len(signature) / 2)], 16)
s = int(signature[int(len(signature) / 2):], 16)
return ecdsa.util.sigencode_der(r, s, len(signature) * 4)
def validate_base_comment(comment: str, claim_id: str, **kwargs): def validate_base_comment(comment: str, claim_id: str, **kwargs):
assert 0 < len(comment) <= 2000 assert 0 < len(comment) <= 2000
assert re.fullmatch('[a-z0-9]{40}', claim_id) assert re.fullmatch('[a-z0-9]{40}', claim_id)
async def validate_signature(*args, **kwargs): async def is_authentic_delete_signal(app, comment_id: str, channel_name: str, channel_id: str, signature: str):
pass claim = await resolve_channel_claim(app, channel_id, channel_name)
if claim:
public_key = claim['value']['public_key']
claim_hash = binascii.unhexlify(claim['claim_id'].encode())[::-1]
return is_signature_valid(
encoded_signature=get_encoded_signature(signature),
signature_digest=sha256(b''.join([comment_id.encode(), claim_hash])),
public_key_bytes=binascii.unhexlify(public_key.encode())
)
return False
def clean_input_params(kwargs: dict):
for k, v in kwargs.items():
if type(v) is str:
kwargs[k] = v.strip()
if k in ID_LIST:
kwargs[k] = v.lower()

View file

@ -4,30 +4,26 @@ import sqlite3
from src.database import get_comment_or_none from src.database import get_comment_or_none
from src.database import insert_comment from src.database import insert_comment
from src.database import insert_channel from src.database import insert_channel
from src.misc import validate_channel from src.misc import channel_matches_pattern
from src.misc import validate_signature
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_comment_or_error(conn: sqlite3.Connection, comment: str, claim_id: str, channel_id: str = None, def create_comment_or_error(conn, comment, claim_id, channel_id=None, channel_name=None,
channel_name: str = None, signature: str = None, signing_ts: str = None, parent_id: str = None): signature=None, signing_ts=None, parent_id=None) -> dict:
if channel_id or channel_name or signature or signing_ts: if channel_id or channel_name or signature or signing_ts:
validate_signature(signature, signing_ts, comment, channel_name, channel_id) # validate_signature(signature, signing_ts, comment, channel_name, channel_id)
insert_channel_or_error(conn, channel_name, channel_id) insert_channel_or_error(conn, channel_name, channel_id)
try: comment_id = insert_comment(
comment_id = insert_comment( conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id,
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, signature=signature, parent_id=parent_id, signing_ts=signing_ts
signature=signature, parent_id=parent_id, signing_ts=signing_ts )
) return get_comment_or_none(conn, comment_id)
return get_comment_or_none(conn, comment_id)
except sqlite3.IntegrityError as ie:
logger.exception(ie)
def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id: str): def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id: str):
try: try:
validate_channel(channel_id, channel_name) channel_matches_pattern(channel_id, channel_name)
insert_channel(conn, channel_name, channel_id) insert_channel(conn, channel_name, channel_id)
except AssertionError as ae: except AssertionError as ae:
logger.exception('Invalid channel values given: %s', ae) logger.exception('Invalid channel values given: %s', ae)