Adds signing + comment deletion framework
This commit is contained in:
parent
d28eafab29
commit
3c8cc2520b
4 changed files with 175 additions and 67 deletions
|
@ -82,17 +82,8 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
def insert_comment(conn: sqlite3.Connection, claim_id: str, comment: str, parent_id: str = None,
|
||||||
with conn:
|
channel_id: str = None, signature: str = None, signing_ts: str = None) -> str:
|
||||||
conn.execute(
|
|
||||||
'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)',
|
|
||||||
(channel_id, channel_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None,
|
|
||||||
channel_id: str = None, signature: str = None, signing_ts: str = None,
|
|
||||||
parent_id: str = None) -> str:
|
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),))
|
prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),))
|
||||||
comment_id = nacl.hash.sha256(prehash).decode()
|
comment_id = nacl.hash.sha256(prehash).decode()
|
||||||
|
@ -156,12 +147,42 @@ def get_comments_by_id(conn, comment_ids: list) -> typing.Union[list, None]:
|
||||||
)]
|
)]
|
||||||
|
|
||||||
|
|
||||||
def delete_comment_by_id(conn: sqlite3.Connection, comment_id: str):
|
def delete_anonymous_comment_by_id(conn: sqlite3.Connection, comment_id: str):
|
||||||
with conn:
|
with conn:
|
||||||
if conn.execute('SELECT 1 FROM COMMENT WHERE CommentId = ? LIMIT 1', (comment_id,)).fetchone():
|
curs = conn.execute(
|
||||||
conn.execute("DELETE FROM COMMENT WHERE CommentId = ?", (comment_id,))
|
"DELETE FROM COMMENT WHERE ChannelId IS NULL AND CommentId = ?",
|
||||||
return True
|
(comment_id,)
|
||||||
return False
|
)
|
||||||
|
return curs.rowcount
|
||||||
|
|
||||||
|
|
||||||
|
def delete_channel_comment_by_id(conn: sqlite3.Connection, comment_id: str, channel_id: str):
|
||||||
|
with conn:
|
||||||
|
curs = conn.execute(
|
||||||
|
"DELETE FROM COMMENT WHERE ChannelId = ? AND CommentId = ?",
|
||||||
|
(channel_id, comment_id)
|
||||||
|
)
|
||||||
|
return curs.rowcount
|
||||||
|
|
||||||
|
|
||||||
|
def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
||||||
|
with conn:
|
||||||
|
conn.execute(
|
||||||
|
'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)',
|
||||||
|
(channel_id, channel_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_channel_from_comment_id(conn: sqlite3.Connection, comment_id: str):
|
||||||
|
with conn:
|
||||||
|
channel = conn.execute("""
|
||||||
|
SELECT CHN.ClaimId AS channel_id, CHN.Name AS channel_name
|
||||||
|
FROM CHANNEL AS CHN, COMMENT AS CMT
|
||||||
|
WHERE CHN.ClaimId = CMT.ChannelId AND CMT.CommentId = ?
|
||||||
|
LIMIT 1
|
||||||
|
""", (comment_id,)
|
||||||
|
).fetchone()
|
||||||
|
return dict(channel) if channel else dict()
|
||||||
|
|
||||||
|
|
||||||
class DatabaseWriter(object):
|
class DatabaseWriter(object):
|
||||||
|
@ -184,4 +205,4 @@ class DatabaseWriter(object):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connection(self):
|
def connection(self):
|
||||||
return self.conn
|
return self.conn
|
||||||
|
|
|
@ -2,33 +2,33 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import aiojobs
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiojobs.aiohttp import atomic
|
from aiojobs.aiohttp import atomic
|
||||||
from asyncio import coroutine
|
from asyncio import coroutine
|
||||||
|
|
||||||
from src.database import DatabaseWriter
|
from misc import clean_input_params, ERRORS
|
||||||
from src.database import get_claim_comments
|
from src.database import get_claim_comments
|
||||||
from src.database import get_comments_by_id, get_comment_ids
|
from src.database import get_comments_by_id, get_comment_ids
|
||||||
|
from src.database import get_channel_from_comment_id
|
||||||
from src.database import obtain_connection
|
from src.database import obtain_connection
|
||||||
|
from src.database import delete_channel_comment_by_id
|
||||||
from src.writes import create_comment_or_error
|
from src.writes import create_comment_or_error
|
||||||
|
from src.misc import is_authentic_delete_signal
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ERRORS = {
|
|
||||||
'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid parameters'},
|
|
||||||
'INTERNAL': {'code': -32603, 'message': 'An internal error'},
|
|
||||||
'UNKNOWN': {'code': -1, 'message': 'An unknown or very miscellaneous error'},
|
|
||||||
}
|
|
||||||
|
|
||||||
ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'}
|
# noinspection PyUnusedLocal
|
||||||
|
def ping(*args):
|
||||||
|
|
||||||
def ping(*args, **kwargs):
|
|
||||||
return 'pong'
|
return 'pong'
|
||||||
|
|
||||||
|
|
||||||
|
def handle_get_channel_from_comment_id(app, kwargs: dict):
|
||||||
|
with obtain_connection(app['db_path']) as conn:
|
||||||
|
return get_channel_from_comment_id(conn, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def handle_get_comment_ids(app, kwargs):
|
def handle_get_comment_ids(app, kwargs):
|
||||||
with obtain_connection(app['db_path']) as conn:
|
with obtain_connection(app['db_path']) as conn:
|
||||||
return get_comment_ids(conn, **kwargs)
|
return get_comment_ids(conn, **kwargs)
|
||||||
|
@ -44,37 +44,40 @@ def handle_get_comments_by_id(app, kwargs):
|
||||||
return get_comments_by_id(conn, **kwargs)
|
return get_comments_by_id(conn, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def create_comment_scheduler():
|
async def write_comment(app, comment):
|
||||||
return await aiojobs.create_scheduler(limit=1, pending_limit=0)
|
return await coroutine(create_comment_or_error)(app['writer'], **comment)
|
||||||
|
|
||||||
|
|
||||||
async def write_comment(comment):
|
async def handle_create_comment(app, params):
|
||||||
with DatabaseWriter._writer.connection as conn:
|
job = await app['comment_scheduler'].spawn(write_comment(app, params))
|
||||||
return await coroutine(create_comment_or_error)(conn, **comment)
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_create_comment(scheduler, comment):
|
|
||||||
job = await scheduler.spawn(write_comment(comment))
|
|
||||||
return await job.wait()
|
return await job.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_comment_if_authorized(app, comment_id, channel_name, channel_id, signature):
|
||||||
|
authorized = await is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature)
|
||||||
|
if not authorized:
|
||||||
|
return {'deleted': False}
|
||||||
|
|
||||||
|
delete_query = delete_channel_comment_by_id(app['writer'], comment_id, channel_id)
|
||||||
|
job = await app['comment_scheduler'].spawn(delete_query)
|
||||||
|
return {'deleted': await job.wait()}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_delete_comment(app, params):
|
||||||
|
return await delete_comment_if_authorized(app, **params)
|
||||||
|
|
||||||
|
|
||||||
METHODS = {
|
METHODS = {
|
||||||
'ping': ping,
|
'ping': ping,
|
||||||
'get_claim_comments': handle_get_claim_comments,
|
'get_claim_comments': handle_get_claim_comments,
|
||||||
'get_comment_ids': handle_get_comment_ids,
|
'get_comment_ids': handle_get_comment_ids,
|
||||||
'get_comments_by_id': handle_get_comments_by_id,
|
'get_comments_by_id': handle_get_comments_by_id,
|
||||||
'create_comment': handle_create_comment
|
'get_channel_from_comment_id': handle_get_channel_from_comment_id,
|
||||||
|
'create_comment': handle_create_comment,
|
||||||
|
'delete_comment': handle_delete_comment,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def clean_input_params(kwargs: dict):
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if type(v) is str:
|
|
||||||
kwargs[k] = v.strip()
|
|
||||||
if k in ID_LIST:
|
|
||||||
kwargs[k] = v.lower()
|
|
||||||
|
|
||||||
|
|
||||||
async def process_json(app, body: dict) -> dict:
|
async def process_json(app, body: dict) -> dict:
|
||||||
response = {'jsonrpc': '2.0', 'id': body['id']}
|
response = {'jsonrpc': '2.0', 'id': body['id']}
|
||||||
if body['method'] in METHODS:
|
if body['method'] in METHODS:
|
||||||
|
@ -83,7 +86,7 @@ async def process_json(app, body: dict) -> dict:
|
||||||
clean_input_params(params)
|
clean_input_params(params)
|
||||||
try:
|
try:
|
||||||
if asyncio.iscoroutinefunction(METHODS[method]):
|
if asyncio.iscoroutinefunction(METHODS[method]):
|
||||||
result = await METHODS[method](app['comment_scheduler'], params)
|
result = await METHODS[method](app, params)
|
||||||
else:
|
else:
|
||||||
result = METHODS[method](app, params)
|
result = METHODS[method](app, params)
|
||||||
response['result'] = result
|
response['result'] = result
|
||||||
|
@ -93,8 +96,11 @@ async def process_json(app, body: dict) -> dict:
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
logger.exception('Got ValueError: %s', ve)
|
logger.exception('Got ValueError: %s', ve)
|
||||||
response['error'] = ERRORS['INVALID_PARAMS']
|
response['error'] = ERRORS['INVALID_PARAMS']
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception('Got unknown exception: %s', e)
|
||||||
|
response['error'] = ERRORS['INTERNAL']
|
||||||
else:
|
else:
|
||||||
response['error'] = ERRORS['UNKNOWN']
|
response['error'] = ERRORS['METHOD_NOT_FOUND']
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,16 +111,21 @@ async def api_endpoint(request: web.Request):
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
if type(body) is list or type(body) is dict:
|
if type(body) is list or type(body) is dict:
|
||||||
if type(body) is list:
|
if type(body) is list:
|
||||||
|
# for batching
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
[await process_json(request.app, part) for part in body]
|
[await process_json(request.app, part) for part in body]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return web.json_response(await process_json(request.app, body))
|
return web.json_response(await process_json(request.app, body))
|
||||||
else:
|
else:
|
||||||
return web.json_response({'error': ERRORS['UNKNOWN']})
|
logger.warning('Got invalid request from %s: %s', request.remote, body)
|
||||||
|
return web.json_response({'error': ERRORS['INVALID_REQUEST']})
|
||||||
except json.decoder.JSONDecodeError as jde:
|
except json.decoder.JSONDecodeError as jde:
|
||||||
logger.exception('Received malformed JSON from %s: %s', request.remote, jde.msg)
|
logger.exception('Received malformed JSON from %s: %s', request.remote, jde.msg)
|
||||||
logger.debug('Request headers: %s', request.headers)
|
logger.debug('Request headers: %s', request.headers)
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'error': {'message': jde.msg, 'code': -1}
|
'error': ERRORS['PARSE_ERROR']
|
||||||
})
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception('Exception raised by request from %s: %s', request.remote, e)
|
||||||
|
return web.json_response({'error': ERRORS['INVALID_REQUEST']})
|
||||||
|
|
86
src/misc.py
86
src/misc.py
|
@ -1,7 +1,37 @@
|
||||||
|
import binascii
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from json import JSONDecodeError
|
||||||
|
|
||||||
|
from nacl.hash import sha256
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
import ecdsa
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ec
|
||||||
|
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
|
||||||
|
from cryptography.exceptions import InvalidSignature
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'}
|
||||||
|
|
||||||
|
ERRORS = {
|
||||||
|
'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid Method Parameter(s).'},
|
||||||
|
'INTERNAL': {'code': -32603, 'message': 'Internal Server Error.'},
|
||||||
|
'METHOD_NOT_FOUND': {'code': -32601, 'message': 'The method does not exist / is not available.'},
|
||||||
|
'INVALID_REQUEST': {'code': -32600, 'message': 'The JSON sent is not a valid Request object.'},
|
||||||
|
'PARSE_ERROR': {
|
||||||
|
'code': -32700,
|
||||||
|
'message': 'Invalid JSON was received by the server.\n'
|
||||||
|
'An error occurred on the server while parsing the JSON text.'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_channel(channel_id: str, channel_name: str):
|
def channel_matches_pattern(channel_id: str, channel_name: str):
|
||||||
assert channel_id and channel_name
|
assert channel_id and channel_name
|
||||||
assert type(channel_id) is str and type(channel_name) is str
|
assert type(channel_id) is str and type(channel_name) is str
|
||||||
assert re.fullmatch(
|
assert re.fullmatch(
|
||||||
|
@ -12,11 +42,61 @@ def validate_channel(channel_id: str, channel_name: str):
|
||||||
assert re.fullmatch('[a-z0-9]{40}', channel_id)
|
assert re.fullmatch('[a-z0-9]{40}', channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_channel_claim(app: dict, channel_id: str, channel_name: str):
|
||||||
|
lbry_url = f'lbry://{channel_name}#{channel_id}'
|
||||||
|
resolve_body = {
|
||||||
|
'method': 'resolve',
|
||||||
|
'params': {
|
||||||
|
'urls': [lbry_url, ]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async with aiohttp.request('POST', app['config']['LBRYNET'], json=resolve_body) as req:
|
||||||
|
try:
|
||||||
|
resp = await req.json()
|
||||||
|
return resp.get(lbry_url)
|
||||||
|
except JSONDecodeError as jde:
|
||||||
|
logger.exception(jde.msg)
|
||||||
|
raise Exception(jde)
|
||||||
|
|
||||||
|
|
||||||
|
def is_signature_valid(encoded_signature, signature_digest, public_key_bytes):
|
||||||
|
try:
|
||||||
|
public_key = load_der_public_key(public_key_bytes, default_backend())
|
||||||
|
public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256())))
|
||||||
|
return True
|
||||||
|
except (ValueError, InvalidSignature) as err:
|
||||||
|
logger.debug('Signature Valiadation Failed: %s', err)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_encoded_signature(signature):
|
||||||
|
signature = signature.encode() if type(signature) is str else signature
|
||||||
|
r = int(signature[:int(len(signature) / 2)], 16)
|
||||||
|
s = int(signature[int(len(signature) / 2):], 16)
|
||||||
|
return ecdsa.util.sigencode_der(r, s, len(signature) * 4)
|
||||||
|
|
||||||
def validate_base_comment(comment: str, claim_id: str, **kwargs):
|
def validate_base_comment(comment: str, claim_id: str, **kwargs):
|
||||||
assert 0 < len(comment) <= 2000
|
assert 0 < len(comment) <= 2000
|
||||||
assert re.fullmatch('[a-z0-9]{40}', claim_id)
|
assert re.fullmatch('[a-z0-9]{40}', claim_id)
|
||||||
|
|
||||||
|
|
||||||
async def validate_signature(*args, **kwargs):
|
async def is_authentic_delete_signal(app, comment_id: str, channel_name: str, channel_id: str, signature: str):
|
||||||
pass
|
claim = await resolve_channel_claim(app, channel_id, channel_name)
|
||||||
|
if claim:
|
||||||
|
public_key = claim['value']['public_key']
|
||||||
|
claim_hash = binascii.unhexlify(claim['claim_id'].encode())[::-1]
|
||||||
|
return is_signature_valid(
|
||||||
|
encoded_signature=get_encoded_signature(signature),
|
||||||
|
signature_digest=sha256(b''.join([comment_id.encode(), claim_hash])),
|
||||||
|
public_key_bytes=binascii.unhexlify(public_key.encode())
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def clean_input_params(kwargs: dict):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if type(v) is str:
|
||||||
|
kwargs[k] = v.strip()
|
||||||
|
if k in ID_LIST:
|
||||||
|
kwargs[k] = v.lower()
|
||||||
|
|
||||||
|
|
|
@ -4,30 +4,26 @@ import sqlite3
|
||||||
from src.database import get_comment_or_none
|
from src.database import get_comment_or_none
|
||||||
from src.database import insert_comment
|
from src.database import insert_comment
|
||||||
from src.database import insert_channel
|
from src.database import insert_channel
|
||||||
from src.misc import validate_channel
|
from src.misc import channel_matches_pattern
|
||||||
from src.misc import validate_signature
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_comment_or_error(conn: sqlite3.Connection, comment: str, claim_id: str, channel_id: str = None,
|
def create_comment_or_error(conn, comment, claim_id, channel_id=None, channel_name=None,
|
||||||
channel_name: str = None, signature: str = None, signing_ts: str = None, parent_id: str = None):
|
signature=None, signing_ts=None, parent_id=None) -> dict:
|
||||||
if channel_id or channel_name or signature or signing_ts:
|
if channel_id or channel_name or signature or signing_ts:
|
||||||
validate_signature(signature, signing_ts, comment, channel_name, channel_id)
|
# validate_signature(signature, signing_ts, comment, channel_name, channel_id)
|
||||||
insert_channel_or_error(conn, channel_name, channel_id)
|
insert_channel_or_error(conn, channel_name, channel_id)
|
||||||
try:
|
comment_id = insert_comment(
|
||||||
comment_id = insert_comment(
|
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id,
|
||||||
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id,
|
signature=signature, parent_id=parent_id, signing_ts=signing_ts
|
||||||
signature=signature, parent_id=parent_id, signing_ts=signing_ts
|
)
|
||||||
)
|
return get_comment_or_none(conn, comment_id)
|
||||||
return get_comment_or_none(conn, comment_id)
|
|
||||||
except sqlite3.IntegrityError as ie:
|
|
||||||
logger.exception(ie)
|
|
||||||
|
|
||||||
|
|
||||||
def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
||||||
try:
|
try:
|
||||||
validate_channel(channel_id, channel_name)
|
channel_matches_pattern(channel_id, channel_name)
|
||||||
insert_channel(conn, channel_name, channel_id)
|
insert_channel(conn, channel_name, channel_id)
|
||||||
except AssertionError as ae:
|
except AssertionError as ae:
|
||||||
logger.exception('Invalid channel values given: %s', ae)
|
logger.exception('Invalid channel values given: %s', ae)
|
||||||
|
|
Loading…
Reference in a new issue