Merge pull request #3 from lbryio/lbrynet-service-integration

Activates Comment Deletion functionality + adds & fixes a ton of stuff
This commit is contained in:
Oleg Silkin 2019-07-30 01:06:32 -04:00 committed by GitHub
commit d8c96f35d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 209 additions and 184 deletions

View file

@ -1,10 +1,7 @@
{ {
"PATH": { "PATH": {
"SCHEMA": "src/schema/comments_ddl.sql", "SCHEMA": "src/schema/comments_ddl.sql",
"MAIN": "database/comments.db", "DATABASE": "database/comments.db",
"BACKUP": "database/comments.backup.db",
"DEFAULT": "database/default.db",
"TEST": "tests/test.db",
"ERROR_LOG": "logs/error.log", "ERROR_LOG": "logs/error.log",
"DEBUG_LOG": "logs/debug.log", "DEBUG_LOG": "logs/debug.log",
"SERVER_LOG": "logs/server.log" "SERVER_LOG": "logs/server.log"

View file

@ -0,0 +1,34 @@
import unittest
from multiprocessing.pool import Pool
import requests
class ConcurrentWriteTest(unittest.TestCase):
@staticmethod
def make_comment(num):
return {
'jsonrpc': '2.0',
'id': num,
'method': 'create_comment',
'params': {
'comment': f'Comment #{num}',
'claim_id': '6d266af6c25c80fa2ac6cc7662921ad2e90a07e7',
}
}
@staticmethod
def send_comment_to_server(params):
with requests.post(params[0], json=params[1]) as req:
return req.json()
def test01Concurrency(self):
urls = [f'http://localhost:{port}/api' for port in range(5921, 5925)]
comments = [self.make_comment(i) for i in range(1, 5)]
inputs = list(zip(urls, comments))
with Pool(4) as pool:
results = pool.map(self.send_comment_to_server, inputs)
results = list(filter(lambda x: 'comment_id' in x['result'], results))
self.assertIsNotNone(results)
self.assertEqual(len(results), len(inputs))

View file

@ -75,6 +75,7 @@ def main(argv=None):
parser = argparse.ArgumentParser(description='LBRY Comment Server') parser = argparse.ArgumentParser(description='LBRY Comment Server')
parser.add_argument('--port', type=int) parser.add_argument('--port', type=int)
args = parser.parse_args(argv) args = parser.parse_args(argv)
config_logging_from_settings(config)
if args.port: if args.port:
config['PORT'] = args.port config['PORT'] = args.port
config_logging_from_settings(config) config_logging_from_settings(config)

View file

@ -1,4 +1,3 @@
PRAGMA FOREIGN_KEYS = ON; PRAGMA FOREIGN_KEYS = ON;
-- tables -- tables
@ -6,7 +5,8 @@ PRAGMA FOREIGN_KEYS = ON;
-- DROP TABLE IF EXISTS CHANNEL; -- DROP TABLE IF EXISTS CHANNEL;
-- DROP TABLE IF EXISTS COMMENT; -- DROP TABLE IF EXISTS COMMENT;
CREATE TABLE IF NOT EXISTS COMMENT ( CREATE TABLE IF NOT EXISTS COMMENT
(
CommentId TEXT NOT NULL, CommentId TEXT NOT NULL,
LbryClaimId TEXT NOT NULL, LbryClaimId TEXT NOT NULL,
ChannelId TEXT DEFAULT NULL, ChannelId TEXT DEFAULT NULL,
@ -26,7 +26,8 @@ CREATE TABLE IF NOT EXISTS COMMENT (
-- ALTER TABLE COMMENT ADD COLUMN SigningTs TEXT DEFAULT NULL; -- ALTER TABLE COMMENT ADD COLUMN SigningTs TEXT DEFAULT NULL;
-- DROP TABLE IF EXISTS CHANNEL; -- DROP TABLE IF EXISTS CHANNEL;
CREATE TABLE IF NOT EXISTS CHANNEL( CREATE TABLE IF NOT EXISTS CHANNEL
(
ClaimId TEXT NOT NULL, ClaimId TEXT NOT NULL,
Name TEXT NOT NULL, Name TEXT NOT NULL,
CONSTRAINT CHANNEL_PK PRIMARY KEY (ClaimId) CONSTRAINT CHANNEL_PK PRIMARY KEY (ClaimId)
@ -42,8 +43,18 @@ CREATE INDEX IF NOT EXISTS CHANNEL_COMMENT_INDEX ON COMMENT(ChannelId, CommentId
-- VIEWS -- VIEWS
DROP VIEW IF EXISTS COMMENTS_ON_CLAIMS; DROP VIEW IF EXISTS COMMENTS_ON_CLAIMS;
CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS (comment_id, claim_id, timestamp, channel_name, channel_id, channel_url, signature, signing_ts, parent_id, comment) AS CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS (comment_id, claim_id, timestamp, channel_name, channel_id, channel_url,
SELECT C.CommentId, C.LbryClaimId, C.Timestamp, CHAN.Name, CHAN.ClaimId, 'lbry://' || CHAN.Name || '#' || CHAN.ClaimId, C.Signature, C.SigningTs, C.ParentId, C.Body signature, signing_ts, parent_id, comment) AS
SELECT C.CommentId,
C.LbryClaimId,
C.Timestamp,
CHAN.Name,
CHAN.ClaimId,
'lbry://' || CHAN.Name || '#' || CHAN.ClaimId,
C.Signature,
C.SigningTs,
C.ParentId,
C.Body
FROM COMMENT AS C FROM COMMENT AS C
LEFT OUTER JOIN CHANNEL CHAN on C.ChannelId = CHAN.ClaimId LEFT OUTER JOIN CHANNEL CHAN on C.ChannelId = CHAN.ClaimId
ORDER BY C.Timestamp DESC; ORDER BY C.Timestamp DESC;
@ -52,7 +63,8 @@ CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS (comment_id, claim_id, timestamp, c
DROP VIEW IF EXISTS COMMENT_REPLIES; DROP VIEW IF EXISTS COMMENT_REPLIES;
CREATE VIEW IF NOT EXISTS COMMENT_REPLIES (Author, CommentBody, ParentAuthor, ParentCommentBody) AS CREATE VIEW IF NOT EXISTS COMMENT_REPLIES (Author, CommentBody, ParentAuthor, ParentCommentBody) AS
SELECT AUTHOR.Name, OG.Body, PCHAN.Name, PARENT.Body FROM COMMENT AS OG SELECT AUTHOR.Name, OG.Body, PCHAN.Name, PARENT.Body
FROM COMMENT AS OG
JOIN COMMENT AS PARENT JOIN COMMENT AS PARENT
ON OG.ParentId = PARENT.CommentId ON OG.ParentId = PARENT.CommentId
JOIN CHANNEL AS PCHAN ON PARENT.ChannelId = PCHAN.ClaimId JOIN CHANNEL AS PCHAN ON PARENT.ChannelId = PCHAN.ClaimId

View file

@ -1,7 +1,6 @@
# cython: language_level=3 # cython: language_level=3
import logging import logging
import pathlib import pathlib
import re
import signal import signal
import time import time
@ -25,11 +24,6 @@ async def setup_db_schema(app):
logger.info(f'Database already exists in {app["db_path"]}, skipping setup') logger.info(f'Database already exists in {app["db_path"]}, skipping setup')
async def close_comment_scheduler(app):
logger.info('Closing comment_scheduler')
await app['comment_scheduler'].close()
async def database_backup_routine(app): async def database_backup_routine(app):
try: try:
while True: while True:
@ -43,29 +37,41 @@ async def database_backup_routine(app):
async def start_background_tasks(app): async def start_background_tasks(app):
app['reader'] = obtain_connection(app['db_path'], True) app['reader'] = obtain_connection(app['db_path'], True)
app['waitful_backup'] = app.loop.create_task(database_backup_routine(app)) app['waitful_backup'] = asyncio.create_task(database_backup_routine(app))
app['comment_scheduler'] = await aiojobs.create_scheduler(limit=1, pending_limit=0) app['comment_scheduler'] = await aiojobs.create_scheduler(limit=1, pending_limit=0)
app['db_writer'] = DatabaseWriter(app['db_path']) app['db_writer'] = DatabaseWriter(app['db_path'])
app['writer'] = app['db_writer'].connection app['writer'] = app['db_writer'].connection
async def stop_background_tasks(app): async def close_database_connections(app):
logger.info('Ending background backup loop') logger.info('Ending background backup loop')
app['waitful_backup'].cancel() app['waitful_backup'].cancel()
await app['waitful_backup'] await app['waitful_backup']
app['reader'].close() app['reader'].close()
app['writer'].close() app['writer'].close()
app['db_writer'].cleanup()
async def close_comment_scheduler(app):
logger.info('Closing comment_scheduler')
await app['comment_scheduler'].close()
class CommentDaemon: class CommentDaemon:
def __init__(self, config, db_path=None, **kwargs): def __init__(self, config, db_file=None, backup=None, **kwargs):
self.config = config self.config = config
app = web.Application() app = web.Application()
self.insert_to_config(app, config, db_file=db_path) app['config'] = config
if db_file:
app['db_path'] = db_file
app['backup'] = backup
else:
app['db_path'] = config['PATH']['DATABASE']
app['backup'] = backup or (app['db_path'] + '.backup')
app.on_startup.append(setup_db_schema) app.on_startup.append(setup_db_schema)
app.on_startup.append(start_background_tasks) app.on_startup.append(start_background_tasks)
app.on_shutdown.append(stop_background_tasks)
app.on_shutdown.append(close_comment_scheduler) app.on_shutdown.append(close_comment_scheduler)
app.on_cleanup.append(close_database_connections)
aiojobs.aiohttp.setup(app, **kwargs) aiojobs.aiohttp.setup(app, **kwargs)
app.add_routes([ app.add_routes([
web.post('/api', api_endpoint), web.post('/api', api_endpoint),
@ -73,36 +79,28 @@ class CommentDaemon:
web.get('/api', get_api_endpoint) web.get('/api', get_api_endpoint)
]) ])
self.app = app self.app = app
self.app_runner = web.AppRunner(app) self.app_runner = None
self.app_site = None self.app_site = None
async def start(self): async def start(self, host=None, port=None):
self.app['START_TIME'] = time.time() self.app['START_TIME'] = time.time()
self.app_runner = web.AppRunner(self.app)
await self.app_runner.setup() await self.app_runner.setup()
self.app_site = web.TCPSite( self.app_site = web.TCPSite(
runner=self.app_runner, runner=self.app_runner,
host=self.config['HOST'], host=host or self.config['HOST'],
port=self.config['PORT'], port=port or self.config['PORT'],
) )
await self.app_site.start() await self.app_site.start()
logger.info(f'Comment Server is running on {self.config["HOST"]}:{self.config["PORT"]}') logger.info(f'Comment Server is running on {self.config["HOST"]}:{self.config["PORT"]}')
async def stop(self): async def stop(self):
await self.app.shutdown() await self.app_runner.shutdown()
await self.app.cleanup()
await self.app_runner.cleanup() await self.app_runner.cleanup()
@staticmethod
def insert_to_config(app, conf=None, db_file=None):
db_file = db_file if db_file else 'DEFAULT'
app['config'] = conf
app['db_path'] = conf['PATH'][db_file]
app['backup'] = re.sub(r'\.db$', '.backup.db', app['db_path'])
assert app['db_path'] != app['backup']
def run_app(config, db_file=None):
def run_app(config): comment_app = CommentDaemon(config=config, db_file=db_file, close_timeout=5.0)
comment_app = CommentDaemon(config=config, db_path='DEFAULT', close_timeout=5.0)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View file

@ -194,8 +194,8 @@ class DatabaseWriter(object):
def cleanup(self): def cleanup(self):
logging.info('Cleaning up database writer') logging.info('Cleaning up database writer')
DatabaseWriter._writer = None
self.conn.close() self.conn.close()
DatabaseWriter._writer = None
@property @property
def connection(self): def connection(self):

View file

@ -10,7 +10,6 @@ from src.server.misc import clean_input_params
from src.server.database import get_claim_comments from src.server.database import get_claim_comments
from src.server.database import get_comments_by_id, get_comment_ids from src.server.database import get_comments_by_id, get_comment_ids
from src.server.database import get_channel_id_from_comment_id from src.server.database import get_channel_id_from_comment_id
from src.server.database import obtain_connection
from src.server.misc import is_valid_base_comment from src.server.misc import is_valid_base_comment
from src.server.misc import is_valid_credential_input from src.server.misc import is_valid_credential_input
from src.server.misc import make_error from src.server.misc import make_error
@ -63,7 +62,7 @@ METHODS = {
'get_comments_by_id': handle_get_comments_by_id, 'get_comments_by_id': handle_get_comments_by_id,
'get_channel_from_comment_id': handle_get_channel_from_comment_id, 'get_channel_from_comment_id': handle_get_channel_from_comment_id,
'create_comment': handle_create_comment, 'create_comment': handle_create_comment,
# 'delete_comment': handle_delete_comment, 'delete_comment': handle_delete_comment,
# 'abandon_comment': handle_delete_comment, # 'abandon_comment': handle_delete_comment,
} }
@ -75,8 +74,8 @@ async def process_json(app, body: dict) -> dict:
params = body.get('params', {}) params = body.get('params', {})
clean_input_params(params) clean_input_params(params)
logger.debug(f'Received Method {method}, params: {params}') logger.debug(f'Received Method {method}, params: {params}')
try:
start = time.time() start = time.time()
try:
if asyncio.iscoroutinefunction(METHODS[method]): if asyncio.iscoroutinefunction(METHODS[method]):
result = await METHODS[method](app, params) result = await METHODS[method](app, params)
else: else:
@ -99,6 +98,7 @@ async def process_json(app, body: dict) -> dict:
@atomic @atomic
async def api_endpoint(request: web.Request): async def api_endpoint(request: web.Request):
try: try:
web.access_logger.info(f'Forwarded headers: {request.forwarded}')
body = await request.json() body = await request.json()
if type(body) is list or type(body) is dict: if type(body) is list or type(body) is dict:
if type(body) is list: if type(body) is list:
@ -109,8 +109,6 @@ async def api_endpoint(request: web.Request):
else: else:
return web.json_response(await process_json(request.app, body)) return web.json_response(await process_json(request.app, body))
except Exception as e: except Exception as e:
logger.exception(f'Exception raised by request from {request.remote}: {e}')
logger.debug(f'Request headers: {request.headers}')
return make_error('INVALID_REQUEST', e) return make_error('INVALID_REQUEST', e)

View file

@ -7,6 +7,7 @@ import hashlib
import aiohttp import aiohttp
import ecdsa import ecdsa
from aiohttp import ClientConnectorError
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key from cryptography.hazmat.primitives.serialization import load_der_public_key
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
@ -20,7 +21,7 @@ ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'}
ERRORS = { ERRORS = {
'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid Method Parameter(s).'}, 'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid Method Parameter(s).'},
'INTERNAL': {'code': -32603, 'message': 'Internal Server Error.'}, 'INTERNAL': {'code': -32603, 'message': 'Internal Server Error. Please notify a LBRY Administrator.'},
'METHOD_NOT_FOUND': {'code': -32601, 'message': 'The method does not exist / is not available.'}, 'METHOD_NOT_FOUND': {'code': -32601, 'message': 'The method does not exist / is not available.'},
'INVALID_REQUEST': {'code': -32600, 'message': 'The JSON sent is not a valid Request object.'}, 'INVALID_REQUEST': {'code': -32600, 'message': 'The JSON sent is not a valid Request object.'},
'PARSE_ERROR': { 'PARSE_ERROR': {
@ -35,21 +36,15 @@ def make_error(error, exc=None) -> dict:
body = ERRORS[error] if error in ERRORS else ERRORS['INTERNAL'] body = ERRORS[error] if error in ERRORS else ERRORS['INTERNAL']
try: try:
if exc: if exc:
body.update({ body.update({type(exc).__name__: str(exc)})
type(exc).__name__: str(exc)
})
finally: finally:
return body return body
async def resolve_channel_claim(app, channel_id, channel_name): async def resolve_channel_claim(app, channel_id, channel_name):
lbry_url = f'lbry://{channel_name}#{channel_id}' lbry_url = f'lbry://{channel_name}#{channel_id}'
resolve_body = { resolve_body = {'method': 'resolve', 'params': {'urls': [lbry_url]}}
'method': 'resolve', try:
'params': {
'urls': [lbry_url, ]
}
}
async with aiohttp.request('POST', app['config']['LBRYNET'], json=resolve_body) as req: async with aiohttp.request('POST', app['config']['LBRYNET'], json=resolve_body) as req:
try: try:
resp = await req.json() resp = await req.json()
@ -60,6 +55,9 @@ async def resolve_channel_claim(app, channel_id, channel_name):
if 'result' in resp: if 'result' in resp:
return resp['result'].get(lbry_url) return resp['result'].get(lbry_url)
raise ValueError('claim resolution yields error', {'error': resp['error']}) raise ValueError('claim resolution yields error', {'error': resp['error']})
except (ConnectionRefusedError, ClientConnectorError):
logger.critical("Connection to the LBRYnet daemon failed, make sure it's running.")
raise Exception("Server cannot verify delete signature")
def get_encoded_signature(signature): def get_encoded_signature(signature):
@ -85,8 +83,8 @@ def is_signature_valid(encoded_signature, signature_digest, public_key_bytes):
public_key = load_der_public_key(public_key_bytes, default_backend()) public_key = load_der_public_key(public_key_bytes, default_backend())
public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256()))) public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256())))
return True return True
except (ValueError, InvalidSignature) as err: except (ValueError, InvalidSignature):
logger.debug('Signature Valiadation Failed: %s', err) logger.exception('Signature validation failed')
return False return False
@ -112,12 +110,12 @@ def is_valid_credential_input(channel_id=None, channel_name=None, signature=None
return True return True
async def is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature): async def is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature, signing_ts):
claim = await resolve_channel_claim(app, channel_id, channel_name) claim = await resolve_channel_claim(app, channel_id, channel_name)
if claim: if claim:
public_key = claim['value']['public_key'] public_key = claim['value']['public_key']
claim_hash = binascii.unhexlify(claim['claim_id'].encode())[::-1] claim_hash = binascii.unhexlify(claim['claim_id'].encode())[::-1]
pieces_injest = b''.join((comment_id.encode(), claim_hash)) pieces_injest = b''.join((signing_ts.encode(), claim_hash, comment_id.encode()))
return is_signature_valid( return is_signature_valid(
encoded_signature=get_encoded_signature(signature), encoded_signature=get_encoded_signature(signature),
signature_digest=hashlib.sha256(pieces_injest).digest(), signature_digest=hashlib.sha256(pieces_injest).digest(),
@ -132,4 +130,3 @@ def clean_input_params(kwargs: dict):
kwargs[k] = v.strip() kwargs[k] = v.strip()
if k in ID_LIST: if k in ID_LIST:
kwargs[k] = v.lower() kwargs[k] = v.lower()

View file

@ -38,8 +38,8 @@ async def delete_comment(app, comment_id):
return await coroutine(delete_comment_by_id)(app['writer'], comment_id) return await coroutine(delete_comment_by_id)(app['writer'], comment_id)
async def delete_comment_if_authorized(app, comment_id, channel_name, channel_id, signature): async def delete_comment_if_authorized(app, comment_id, **kwargs):
authorized = await is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature) authorized = await is_authentic_delete_signal(app, comment_id, **kwargs)
if not authorized: if not authorized:
return {'deleted': False} return {'deleted': False}

View file

@ -2,7 +2,6 @@
import json import json
import pathlib import pathlib
root_dir = pathlib.Path(__file__).parent.parent root_dir = pathlib.Path(__file__).parent.parent
config_path = root_dir / 'config' / 'conf.json' config_path = root_dir / 'config' / 'conf.json'

View file

@ -1,6 +1,9 @@
import atexit
import os
import unittest import unittest
from multiprocessing.pool import Pool from multiprocessing.pool import Pool
import asyncio
import aiohttp
import requests import requests
import re import re
from itertools import * from itertools import *
@ -10,7 +13,10 @@ from faker.providers import internet
from faker.providers import lorem from faker.providers import lorem
from faker.providers import misc from faker.providers import misc
from settings import config from src.settings import config
from src.server import app
from tests.testcase import AsyncioTestCase
fake = faker.Faker() fake = faker.Faker()
fake.add_provider(internet) fake.add_provider(internet)
@ -22,14 +28,15 @@ def fake_lbryusername():
return '@' + fake.user_name() return '@' + fake.user_name()
def jsonrpc_post(url, method, **params): async def jsonrpc_post(url, method, **params):
json_body = { json_body = {
'jsonrpc': '2.0', 'jsonrpc': '2.0',
'id': None, 'id': None,
'method': method, 'method': method,
'params': params 'params': params
} }
return requests.post(url=url, json=json_body) async with aiohttp.request('POST', url, json=json_body) as request:
return await request.json()
def nothing(): def nothing():
@ -52,19 +59,26 @@ def create_test_comments(values: iter, **default):
for comb in vars_combo] for comb in vars_combo]
class ServerTest(unittest.TestCase): class ServerTest(AsyncioTestCase):
db_file = 'test.db'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.url = 'http://' + config['HOST'] + ':5921/api' self.url = 'http://' + config['HOST'] + ':5921/api'
def post_comment(self, **params): @classmethod
json_body = { def tearDownClass(cls) -> None:
'jsonrpc': '2.0', print('exit reached')
'id': None, os.remove(cls.db_file)
'method': 'create_comment',
'params': params async def asyncSetUp(self):
} await super().asyncSetUp()
return requests.post(url=self.url, json=json_body) self.server = app.CommentDaemon(config, db_file=self.db_file)
await self.server.start()
self.addCleanup(self.server.stop)
async def post_comment(self, **params):
return await jsonrpc_post(self.url, 'create_comment', **params)
def is_valid_message(self, comment=None, claim_id=None, parent_id=None, def is_valid_message(self, comment=None, claim_id=None, parent_id=None,
channel_name=None, channel_id=None, signature=None, signing_ts=None): channel_name=None, channel_id=None, signature=None, signing_ts=None):
@ -89,9 +103,6 @@ class ServerTest(unittest.TestCase):
return False return False
return True return True
def setUp(self) -> None:
self.reply_id = 'ace7800f36e55c74c4aa6a698f97a7ee5f1ccb047b5a0730960df90e58c41dc2'
@staticmethod @staticmethod
def valid_channel_name(channel_name): def valid_channel_name(channel_name):
return re.fullmatch( return re.fullmatch(
@ -100,7 +111,7 @@ class ServerTest(unittest.TestCase):
channel_name channel_name
) )
def test01CreateCommentNoReply(self): async def test01CreateCommentNoReply(self):
anonymous_test = create_test_comments( anonymous_test = create_test_comments(
('claim_id', 'channel_id', 'channel_name', 'comment'), ('claim_id', 'channel_id', 'channel_name', 'comment'),
comment=None, comment=None,
@ -110,15 +121,14 @@ class ServerTest(unittest.TestCase):
) )
for test in anonymous_test: for test in anonymous_test:
with self.subTest(test=test): with self.subTest(test=test):
message = self.post_comment(**test) message = await self.post_comment(**test)
message = message.json()
self.assertTrue('result' in message or 'error' in message) self.assertTrue('result' in message or 'error' in message)
if 'error' in message: if 'error' in message:
self.assertFalse(self.is_valid_message(**test)) self.assertFalse(self.is_valid_message(**test))
else: else:
self.assertTrue(self.is_valid_message(**test)) self.assertTrue(self.is_valid_message(**test))
def test02CreateNamedCommentsNoReply(self): async def test02CreateNamedCommentsNoReply(self):
named_test = create_test_comments( named_test = create_test_comments(
('channel_name', 'channel_id', 'signature'), ('channel_name', 'channel_id', 'signature'),
claim_id='1234567890123456789012345678901234567890', claim_id='1234567890123456789012345678901234567890',
@ -129,37 +139,35 @@ class ServerTest(unittest.TestCase):
) )
for test in named_test: for test in named_test:
with self.subTest(test=test): with self.subTest(test=test):
message = self.post_comment(**test) message = await self.post_comment(**test)
message = message.json()
self.assertTrue('result' in message or 'error' in message) self.assertTrue('result' in message or 'error' in message)
if 'error' in message: if 'error' in message:
self.assertFalse(self.is_valid_message(**test)) self.assertFalse(self.is_valid_message(**test))
else: else:
self.assertTrue(self.is_valid_message(**test)) self.assertTrue(self.is_valid_message(**test))
def test03CreateAllTestComments(self): async def test03CreateAllTestComments(self):
test_all = create_test_comments(replace.keys(), **{ test_all = create_test_comments(replace.keys(), **{
k: None for k in replace.keys() k: None for k in replace.keys()
}) })
for test in test_all: for test in test_all:
with self.subTest(test=test): with self.subTest(test=test):
message = self.post_comment(**test) message = await self.post_comment(**test)
message = message.json()
self.assertTrue('result' in message or 'error' in message) self.assertTrue('result' in message or 'error' in message)
if 'error' in message: if 'error' in message:
self.assertFalse(self.is_valid_message(**test)) self.assertFalse(self.is_valid_message(**test))
else: else:
self.assertTrue(self.is_valid_message(**test)) self.assertTrue(self.is_valid_message(**test))
def test04CreateAllReplies(self): async def test04CreateAllReplies(self):
claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8'
parent_comment = self.post_comment( parent_comment = await self.post_comment(
channel_name='@KevinWalterRabie', channel_name='@KevinWalterRabie',
channel_id=fake.sha1(), channel_id=fake.sha1(),
comment='Hello everybody and welcome back to my chan nel', comment='Hello everybody and welcome back to my chan nel',
claim_id=claim_id, claim_id=claim_id,
) )
parent_id = parent_comment.json()['result']['comment_id'] parent_id = parent_comment['result']['comment_id']
test_all = create_test_comments( test_all = create_test_comments(
('comment', 'channel_name', 'channel_id', 'signature', 'parent_id'), ('comment', 'channel_name', 'channel_id', 'signature', 'parent_id'),
parent_id=parent_id, parent_id=parent_id,
@ -174,8 +182,7 @@ class ServerTest(unittest.TestCase):
if test['parent_id'] != parent_id: if test['parent_id'] != parent_id:
continue continue
else: else:
message = self.post_comment(**test) message = await self.post_comment(**test)
message = message.json()
self.assertTrue('result' in message or 'error' in message) self.assertTrue('result' in message or 'error' in message)
if 'error' in message: if 'error' in message:
self.assertFalse(self.is_valid_message(**test)) self.assertFalse(self.is_valid_message(**test))
@ -183,7 +190,7 @@ class ServerTest(unittest.TestCase):
self.assertTrue(self.is_valid_message(**test)) self.assertTrue(self.is_valid_message(**test))
class ListCommentsTest(unittest.TestCase): class ListCommentsTest(AsyncioTestCase):
replace = { replace = {
'claim_id': fake.sha1, 'claim_id': fake.sha1,
'comment': fake.text, 'comment': fake.text,
@ -192,30 +199,35 @@ class ListCommentsTest(unittest.TestCase):
'signature': nothing, 'signature': nothing,
'parent_id': nothing 'parent_id': nothing
} }
db_file = 'list_test.db'
url = 'http://localhost:5921/api'
comment_ids = None
claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8'
@classmethod @classmethod
def post_comment(cls, **params): async def post_comment(cls, **params):
json_body = { return await jsonrpc_post(cls.url, 'create_comment', **params)
'jsonrpc': '2.0',
'id': None,
'method': 'create_comment',
'params': params
}
return requests.post(url=cls.url, json=json_body)
@classmethod @classmethod
def setUpClass(cls) -> None: def tearDownClass(cls) -> None:
cls.url = 'http://' + config['HOST'] + ':5921/api' print('exit reached')
cls.claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' os.remove(cls.db_file)
cls.comment_list = [{key: cls.replace[key]() for key in cls.replace.keys()} for _ in range(23)]
for comment in cls.comment_list:
comment['claim_id'] = cls.claim_id
cls.comment_ids = [cls.post_comment(**comm).json()['result']['comment_id']
for comm in cls.comment_list]
def testListComments(self): async def asyncSetUp(self):
response_one = jsonrpc_post(self.url, 'get_claim_comments', page_size=20, await super().asyncSetUp()
page=1, top_level=1, claim_id=self.claim_id).json() self.server = app.CommentDaemon(config, db_file=self.db_file)
await self.server.start()
self.addCleanup(self.server.stop)
if self.comment_ids is None:
self.comment_list = [{key: self.replace[key]() for key in self.replace.keys()} for _ in range(23)]
for comment in self.comment_list:
comment['claim_id'] = self.claim_id
self.comment_ids = [(await self.post_comment(**comm))['result']['comment_id']
for comm in self.comment_list]
async def testListComments(self):
response_one = await jsonrpc_post(self.url, 'get_claim_comments', page_size=20,
page=1, top_level=1, claim_id=self.claim_id)
self.assertIsNotNone(response_one) self.assertIsNotNone(response_one)
self.assertIn('result', response_one) self.assertIn('result', response_one)
response_one: dict = response_one['result'] response_one: dict = response_one['result']
@ -224,40 +236,11 @@ class ListCommentsTest(unittest.TestCase):
self.assertIn('items', response_one) self.assertIn('items', response_one)
self.assertGreaterEqual(response_one['total_pages'], response_one['page']) self.assertGreaterEqual(response_one['total_pages'], response_one['page'])
last_page = response_one['total_pages'] last_page = response_one['total_pages']
response = jsonrpc_post(self.url, 'get_claim_comments', page_size=20, response = await jsonrpc_post(self.url, 'get_claim_comments', page_size=20,
page=last_page, top_level=1, claim_id=self.claim_id).json() page=last_page, top_level=1, claim_id=self.claim_id)
self.assertIsNotNone(response) self.assertIsNotNone(response)
self.assertIn('result', response) self.assertIn('result', response)
response: dict = response['result'] response: dict = response['result']
self.assertIs(type(response['items']), list) self.assertIs(type(response['items']), list)
self.assertEqual(response['total_items'], response_one['total_items']) self.assertEqual(response['total_items'], response_one['total_items'])
self.assertEqual(response['total_pages'], response_one['total_pages']) self.assertEqual(response['total_pages'], response_one['total_pages'])
class ConcurrentWriteTest(unittest.TestCase):
@staticmethod
def make_comment(num):
return {
'jsonrpc': '2.0',
'id': num,
'method': 'create_comment',
'params': {
'comment': f'Comment #{num}',
'claim_id': '6d266af6c25c80fa2ac6cc7662921ad2e90a07e7',
}
}
@staticmethod
def send_comment_to_server(params):
with requests.post(params[0], json=params[1]) as req:
return req.json()
def test01Concurrency(self):
urls = [f'http://localhost:{port}/api' for port in range(5921, 5925)]
comments = [self.make_comment(i) for i in range(1, 5)]
inputs = list(zip(urls, comments))
with Pool(4) as pool:
results = pool.map(self.send_comment_to_server, inputs)
results = list(filter(lambda x: 'comment_id' in x['result'], results))
self.assertIsNotNone(results)
self.assertEqual(len(results), len(inputs))

View file

@ -1,3 +1,4 @@
import os
import pathlib import pathlib
import unittest import unittest
from asyncio.runners import _cancel_all_tasks # type: ignore from asyncio.runners import _cancel_all_tasks # type: ignore
@ -119,15 +120,20 @@ class AsyncioTestCase(unittest.TestCase):
class DatabaseTestCase(unittest.TestCase): class DatabaseTestCase(unittest.TestCase):
db_file = 'test.db'
def __init__(self, methodName='DatabaseTest'):
super().__init__(methodName)
if pathlib.Path(self.db_file).exists():
os.remove(self.db_file)
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
if pathlib.Path(config['PATH']['TEST']).exists(): setup_database(self.db_file, config['PATH']['SCHEMA'])
teardown_database(config['PATH']['TEST']) self.conn = obtain_connection(self.db_file)
setup_database(config['PATH']['TEST'], config['PATH']['SCHEMA']) self.addCleanup(self.conn.close)
self.conn = obtain_connection(config['PATH']['TEST']) self.addCleanup(os.remove, self.db_file)
def tearDown(self) -> None: def tearDown(self) -> None:
self.conn.close() self.conn.close()
teardown_database(config['PATH']['TEST'])