diff --git a/config/conf.json b/config/conf.json index 71e9ebd..7784ab2 100644 --- a/config/conf.json +++ b/config/conf.json @@ -1,10 +1,7 @@ { "PATH": { "SCHEMA": "src/schema/comments_ddl.sql", - "MAIN": "database/comments.db", - "BACKUP": "database/comments.backup.db", - "DEFAULT": "database/default.db", - "TEST": "tests/test.db", + "DATABASE": "database/comments.db", "ERROR_LOG": "logs/error.log", "DEBUG_LOG": "logs/debug.log", "SERVER_LOG": "logs/server.log" diff --git a/scripts/concurrency_test.py b/scripts/concurrency_test.py new file mode 100644 index 0000000..b7bf927 --- /dev/null +++ b/scripts/concurrency_test.py @@ -0,0 +1,34 @@ +import unittest +from multiprocessing.pool import Pool + +import requests + + +class ConcurrentWriteTest(unittest.TestCase): + + @staticmethod + def make_comment(num): + return { + 'jsonrpc': '2.0', + 'id': num, + 'method': 'create_comment', + 'params': { + 'comment': f'Comment #{num}', + 'claim_id': '6d266af6c25c80fa2ac6cc7662921ad2e90a07e7', + } + } + + @staticmethod + def send_comment_to_server(params): + with requests.post(params[0], json=params[1]) as req: + return req.json() + + def test01Concurrency(self): + urls = [f'http://localhost:{port}/api' for port in range(5921, 5925)] + comments = [self.make_comment(i) for i in range(1, 5)] + inputs = list(zip(urls, comments)) + with Pool(4) as pool: + results = pool.map(self.send_comment_to_server, inputs) + results = list(filter(lambda x: 'comment_id' in x['result'], results)) + self.assertIsNotNone(results) + self.assertEqual(len(results), len(inputs)) diff --git a/src/main.py b/src/main.py index 0a3a9cd..71f4cf7 100644 --- a/src/main.py +++ b/src/main.py @@ -17,7 +17,7 @@ def config_logging_from_settings(conf): "datefmt": conf['LOGGING']['DATEFMT'] }, "aiohttp": { - "format": conf['LOGGING']['AIOHTTP_FORMAT'], + "format": conf['LOGGING']['AIOHTTP_FORMAT'], "datefmt": conf['LOGGING']['DATEFMT'] } }, @@ -75,6 +75,7 @@ def main(argv=None): parser = argparse.ArgumentParser(description='LBRY Comment Server') parser.add_argument('--port', type=int) args = parser.parse_args(argv) + config_logging_from_settings(config) if args.port: config['PORT'] = args.port config_logging_from_settings(config) diff --git a/src/schema/comments_ddl.sql b/src/schema/comments_ddl.sql index c7c5d15..fa537ea 100644 --- a/src/schema/comments_ddl.sql +++ b/src/schema/comments_ddl.sql @@ -1,4 +1,3 @@ - PRAGMA FOREIGN_KEYS = ON; -- tables @@ -6,29 +5,31 @@ PRAGMA FOREIGN_KEYS = ON; -- DROP TABLE IF EXISTS CHANNEL; -- DROP TABLE IF EXISTS COMMENT; -CREATE TABLE IF NOT EXISTS COMMENT ( +CREATE TABLE IF NOT EXISTS COMMENT +( CommentId TEXT NOT NULL, LbryClaimId TEXT NOT NULL, - ChannelId TEXT DEFAULT NULL, + ChannelId TEXT DEFAULT NULL, Body TEXT NOT NULL, - ParentId TEXT DEFAULT NULL, - Signature TEXT DEFAULT NULL, + ParentId TEXT DEFAULT NULL, + Signature TEXT DEFAULT NULL, Timestamp INTEGER NOT NULL, - SigningTs TEXT DEFAULT NULL, + SigningTs TEXT DEFAULT NULL, CONSTRAINT COMMENT_PRIMARY_KEY PRIMARY KEY (CommentId) ON CONFLICT IGNORE, CONSTRAINT COMMENT_SIGNATURE_SK UNIQUE (Signature) ON CONFLICT ABORT, - CONSTRAINT COMMENT_CHANNEL_FK FOREIGN KEY (ChannelId) REFERENCES CHANNEL(ClaimId) + CONSTRAINT COMMENT_CHANNEL_FK FOREIGN KEY (ChannelId) REFERENCES CHANNEL (ClaimId) ON DELETE NO ACTION ON UPDATE NO ACTION, - CONSTRAINT COMMENT_PARENT_FK FOREIGN KEY (ParentId) REFERENCES COMMENT(CommentId) - ON UPDATE CASCADE ON DELETE NO ACTION -- setting null implies comment is top level + CONSTRAINT COMMENT_PARENT_FK FOREIGN KEY (ParentId) REFERENCES COMMENT (CommentId) + ON UPDATE CASCADE ON DELETE NO ACTION -- setting null implies comment is top level ); -- ALTER TABLE COMMENT ADD COLUMN SigningTs TEXT DEFAULT NULL; -- DROP TABLE IF EXISTS CHANNEL; -CREATE TABLE IF NOT EXISTS CHANNEL( - ClaimId TEXT NOT NULL, - Name TEXT NOT NULL, +CREATE TABLE IF NOT EXISTS CHANNEL +( + ClaimId TEXT NOT NULL, + Name TEXT NOT NULL, CONSTRAINT CHANNEL_PK PRIMARY KEY (ClaimId) ON CONFLICT IGNORE ); @@ -38,26 +39,37 @@ CREATE TABLE IF NOT EXISTS CHANNEL( -- DROP INDEX IF EXISTS COMMENT_CLAIM_INDEX; CREATE INDEX IF NOT EXISTS CLAIM_COMMENT_INDEX ON COMMENT (LbryClaimId, CommentId); -CREATE INDEX IF NOT EXISTS CHANNEL_COMMENT_INDEX ON COMMENT(ChannelId, CommentId); +CREATE INDEX IF NOT EXISTS CHANNEL_COMMENT_INDEX ON COMMENT (ChannelId, CommentId); -- VIEWS DROP VIEW IF EXISTS COMMENTS_ON_CLAIMS; -CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS (comment_id, claim_id, timestamp, channel_name, channel_id, channel_url, signature, signing_ts, parent_id, comment) AS - SELECT C.CommentId, C.LbryClaimId, C.Timestamp, CHAN.Name, CHAN.ClaimId, 'lbry://' || CHAN.Name || '#' || CHAN.ClaimId, C.Signature, C.SigningTs, C.ParentId, C.Body - FROM COMMENT AS C - LEFT OUTER JOIN CHANNEL CHAN on C.ChannelId = CHAN.ClaimId - ORDER BY C.Timestamp DESC; +CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS (comment_id, claim_id, timestamp, channel_name, channel_id, channel_url, + signature, signing_ts, parent_id, comment) AS +SELECT C.CommentId, + C.LbryClaimId, + C.Timestamp, + CHAN.Name, + CHAN.ClaimId, + 'lbry://' || CHAN.Name || '#' || CHAN.ClaimId, + C.Signature, + C.SigningTs, + C.ParentId, + C.Body +FROM COMMENT AS C + LEFT OUTER JOIN CHANNEL CHAN on C.ChannelId = CHAN.ClaimId +ORDER BY C.Timestamp DESC; DROP VIEW IF EXISTS COMMENT_REPLIES; CREATE VIEW IF NOT EXISTS COMMENT_REPLIES (Author, CommentBody, ParentAuthor, ParentCommentBody) AS - SELECT AUTHOR.Name, OG.Body, PCHAN.Name, PARENT.Body FROM COMMENT AS OG - JOIN COMMENT AS PARENT - ON OG.ParentId = PARENT.CommentId - JOIN CHANNEL AS PCHAN ON PARENT.ChannelId = PCHAN.ClaimId - JOIN CHANNEL AS AUTHOR ON OG.ChannelId = AUTHOR.ClaimId - ORDER BY OG.Timestamp; +SELECT AUTHOR.Name, OG.Body, PCHAN.Name, PARENT.Body +FROM COMMENT AS OG + JOIN COMMENT AS PARENT + ON OG.ParentId = PARENT.CommentId + JOIN CHANNEL AS PCHAN ON PARENT.ChannelId = PCHAN.ClaimId + JOIN CHANNEL AS AUTHOR ON OG.ChannelId = AUTHOR.ClaimId +ORDER BY OG.Timestamp; -- this is the default channel for anyone who wants to publish anonymously -- INSERT INTO CHANNEL diff --git a/src/server/app.py b/src/server/app.py index 7a54e86..84ed22e 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -1,7 +1,6 @@ # cython: language_level=3 import logging import pathlib -import re import signal import time @@ -25,11 +24,6 @@ async def setup_db_schema(app): logger.info(f'Database already exists in {app["db_path"]}, skipping setup') -async def close_comment_scheduler(app): - logger.info('Closing comment_scheduler') - await app['comment_scheduler'].close() - - async def database_backup_routine(app): try: while True: @@ -43,29 +37,41 @@ async def database_backup_routine(app): async def start_background_tasks(app): app['reader'] = obtain_connection(app['db_path'], True) - app['waitful_backup'] = app.loop.create_task(database_backup_routine(app)) + app['waitful_backup'] = asyncio.create_task(database_backup_routine(app)) app['comment_scheduler'] = await aiojobs.create_scheduler(limit=1, pending_limit=0) app['db_writer'] = DatabaseWriter(app['db_path']) app['writer'] = app['db_writer'].connection -async def stop_background_tasks(app): +async def close_database_connections(app): logger.info('Ending background backup loop') app['waitful_backup'].cancel() await app['waitful_backup'] app['reader'].close() app['writer'].close() + app['db_writer'].cleanup() + + +async def close_comment_scheduler(app): + logger.info('Closing comment_scheduler') + await app['comment_scheduler'].close() class CommentDaemon: - def __init__(self, config, db_path=None, **kwargs): + def __init__(self, config, db_file=None, backup=None, **kwargs): self.config = config app = web.Application() - self.insert_to_config(app, config, db_file=db_path) + app['config'] = config + if db_file: + app['db_path'] = db_file + app['backup'] = backup + else: + app['db_path'] = config['PATH']['DATABASE'] + app['backup'] = backup or (app['db_path'] + '.backup') app.on_startup.append(setup_db_schema) app.on_startup.append(start_background_tasks) - app.on_shutdown.append(stop_background_tasks) app.on_shutdown.append(close_comment_scheduler) + app.on_cleanup.append(close_database_connections) aiojobs.aiohttp.setup(app, **kwargs) app.add_routes([ web.post('/api', api_endpoint), @@ -73,38 +79,30 @@ class CommentDaemon: web.get('/api', get_api_endpoint) ]) self.app = app - self.app_runner = web.AppRunner(app) + self.app_runner = None self.app_site = None - async def start(self): + async def start(self, host=None, port=None): self.app['START_TIME'] = time.time() + self.app_runner = web.AppRunner(self.app) await self.app_runner.setup() self.app_site = web.TCPSite( runner=self.app_runner, - host=self.config['HOST'], - port=self.config['PORT'], + host=host or self.config['HOST'], + port=port or self.config['PORT'], ) await self.app_site.start() logger.info(f'Comment Server is running on {self.config["HOST"]}:{self.config["PORT"]}') async def stop(self): - await self.app.shutdown() - await self.app.cleanup() + await self.app_runner.shutdown() await self.app_runner.cleanup() - @staticmethod - def insert_to_config(app, conf=None, db_file=None): - db_file = db_file if db_file else 'DEFAULT' - app['config'] = conf - app['db_path'] = conf['PATH'][db_file] - app['backup'] = re.sub(r'\.db$', '.backup.db', app['db_path']) - assert app['db_path'] != app['backup'] +def run_app(config, db_file=None): + comment_app = CommentDaemon(config=config, db_file=db_file, close_timeout=5.0) -def run_app(config): - comment_app = CommentDaemon(config=config, db_path='DEFAULT', close_timeout=5.0) - - loop = asyncio.get_event_loop() + loop = asyncio.get_event_loop() def __exit(): raise web.GracefulExit() @@ -118,4 +116,4 @@ def run_app(config): except (web.GracefulExit, KeyboardInterrupt, asyncio.CancelledError, ValueError): logging.warning('Server going down, asyncio loop raised cancelled error:') finally: - loop.run_until_complete(comment_app.stop()) \ No newline at end of file + loop.run_until_complete(comment_app.stop()) diff --git a/src/server/database.py b/src/server/database.py index 9bbbd79..5b56ae1 100644 --- a/src/server/database.py +++ b/src/server/database.py @@ -31,14 +31,14 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ? """, - (claim_id, page_size, page_size*(page - 1)) + (claim_id, page_size, page_size * (page - 1)) )] count = conn.execute( """ SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id IS NULL - """, (claim_id, ) + """, (claim_id,) ) elif parent_id is None: results = [clean(dict(row)) for row in conn.execute( @@ -47,7 +47,7 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? LIMIT ? OFFSET ? """, - (claim_id, page_size, page_size*(page - 1)) + (claim_id, page_size, page_size * (page - 1)) )] count = conn.execute( """ @@ -63,7 +63,7 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ? """, - (claim_id, parent_id, page_size, page_size*(page - 1)) + (claim_id, parent_id, page_size, page_size * (page - 1)) )] count = conn.execute( """ @@ -77,7 +77,7 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = 'items': results, 'page': page, 'page_size': page_size, - 'total_pages': math.ceil(count/page_size), + 'total_pages': math.ceil(count / page_size), 'total_items': count } @@ -194,8 +194,8 @@ class DatabaseWriter(object): def cleanup(self): logging.info('Cleaning up database writer') - DatabaseWriter._writer = None self.conn.close() + DatabaseWriter._writer = None @property def connection(self): diff --git a/src/server/handles.py b/src/server/handles.py index a97b21e..b28ab02 100644 --- a/src/server/handles.py +++ b/src/server/handles.py @@ -10,7 +10,6 @@ from src.server.misc import clean_input_params from src.server.database import get_claim_comments from src.server.database import get_comments_by_id, get_comment_ids from src.server.database import get_channel_id_from_comment_id -from src.server.database import obtain_connection from src.server.misc import is_valid_base_comment from src.server.misc import is_valid_credential_input from src.server.misc import make_error @@ -63,7 +62,7 @@ METHODS = { 'get_comments_by_id': handle_get_comments_by_id, 'get_channel_from_comment_id': handle_get_channel_from_comment_id, 'create_comment': handle_create_comment, - # 'delete_comment': handle_delete_comment, + 'delete_comment': handle_delete_comment, # 'abandon_comment': handle_delete_comment, } @@ -75,8 +74,8 @@ async def process_json(app, body: dict) -> dict: params = body.get('params', {}) clean_input_params(params) logger.debug(f'Received Method {method}, params: {params}') + start = time.time() try: - start = time.time() if asyncio.iscoroutinefunction(METHODS[method]): result = await METHODS[method](app, params) else: @@ -99,6 +98,7 @@ async def process_json(app, body: dict) -> dict: @atomic async def api_endpoint(request: web.Request): try: + web.access_logger.info(f'Forwarded headers: {request.forwarded}') body = await request.json() if type(body) is list or type(body) is dict: if type(body) is list: @@ -109,8 +109,6 @@ async def api_endpoint(request: web.Request): else: return web.json_response(await process_json(request.app, body)) except Exception as e: - logger.exception(f'Exception raised by request from {request.remote}: {e}') - logger.debug(f'Request headers: {request.headers}') return make_error('INVALID_REQUEST', e) diff --git a/src/server/misc.py b/src/server/misc.py index 0fb6981..d0efd66 100644 --- a/src/server/misc.py +++ b/src/server/misc.py @@ -7,6 +7,7 @@ import hashlib import aiohttp import ecdsa +from aiohttp import ClientConnectorError from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_der_public_key from cryptography.hazmat.primitives import hashes @@ -20,7 +21,7 @@ ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'} ERRORS = { 'INVALID_PARAMS': {'code': -32602, 'message': 'Invalid Method Parameter(s).'}, - 'INTERNAL': {'code': -32603, 'message': 'Internal Server Error.'}, + 'INTERNAL': {'code': -32603, 'message': 'Internal Server Error. Please notify a LBRY Administrator.'}, 'METHOD_NOT_FOUND': {'code': -32601, 'message': 'The method does not exist / is not available.'}, 'INVALID_REQUEST': {'code': -32600, 'message': 'The JSON sent is not a valid Request object.'}, 'PARSE_ERROR': { @@ -35,31 +36,28 @@ def make_error(error, exc=None) -> dict: body = ERRORS[error] if error in ERRORS else ERRORS['INTERNAL'] try: if exc: - body.update({ - type(exc).__name__: str(exc) - }) + body.update({type(exc).__name__: str(exc)}) finally: return body async def resolve_channel_claim(app, channel_id, channel_name): lbry_url = f'lbry://{channel_name}#{channel_id}' - resolve_body = { - 'method': 'resolve', - 'params': { - 'urls': [lbry_url, ] - } - } - async with aiohttp.request('POST', app['config']['LBRYNET'], json=resolve_body) as req: - try: - resp = await req.json() - except JSONDecodeError as jde: - logger.exception(jde.msg) - raise Exception('JSON Decode Error in Claim Resolution') - finally: - if 'result' in resp: - return resp['result'].get(lbry_url) - raise ValueError('claim resolution yields error', {'error': resp['error']}) + resolve_body = {'method': 'resolve', 'params': {'urls': [lbry_url]}} + try: + async with aiohttp.request('POST', app['config']['LBRYNET'], json=resolve_body) as req: + try: + resp = await req.json() + except JSONDecodeError as jde: + logger.exception(jde.msg) + raise Exception('JSON Decode Error in Claim Resolution') + finally: + if 'result' in resp: + return resp['result'].get(lbry_url) + raise ValueError('claim resolution yields error', {'error': resp['error']}) + except (ConnectionRefusedError, ClientConnectorError): + logger.critical("Connection to the LBRYnet daemon failed, make sure it's running.") + raise Exception("Server cannot verify delete signature") def get_encoded_signature(signature): @@ -85,8 +83,8 @@ def is_signature_valid(encoded_signature, signature_digest, public_key_bytes): public_key = load_der_public_key(public_key_bytes, default_backend()) public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256()))) return True - except (ValueError, InvalidSignature) as err: - logger.debug('Signature Valiadation Failed: %s', err) + except (ValueError, InvalidSignature): + logger.exception('Signature validation failed') return False @@ -112,12 +110,12 @@ def is_valid_credential_input(channel_id=None, channel_name=None, signature=None return True -async def is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature): +async def is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature, signing_ts): claim = await resolve_channel_claim(app, channel_id, channel_name) if claim: public_key = claim['value']['public_key'] claim_hash = binascii.unhexlify(claim['claim_id'].encode())[::-1] - pieces_injest = b''.join((comment_id.encode(), claim_hash)) + pieces_injest = b''.join((signing_ts.encode(), claim_hash, comment_id.encode())) return is_signature_valid( encoded_signature=get_encoded_signature(signature), signature_digest=hashlib.sha256(pieces_injest).digest(), @@ -132,4 +130,3 @@ def clean_input_params(kwargs: dict): kwargs[k] = v.strip() if k in ID_LIST: kwargs[k] = v.lower() - diff --git a/src/server/writes.py b/src/server/writes.py index 2feb928..902548d 100644 --- a/src/server/writes.py +++ b/src/server/writes.py @@ -38,8 +38,8 @@ async def delete_comment(app, comment_id): return await coroutine(delete_comment_by_id)(app['writer'], comment_id) -async def delete_comment_if_authorized(app, comment_id, channel_name, channel_id, signature): - authorized = await is_authentic_delete_signal(app, comment_id, channel_name, channel_id, signature) +async def delete_comment_if_authorized(app, comment_id, **kwargs): + authorized = await is_authentic_delete_signal(app, comment_id, **kwargs) if not authorized: return {'deleted': False} diff --git a/src/settings.py b/src/settings.py index 110e2b8..4506d01 100644 --- a/src/settings.py +++ b/src/settings.py @@ -2,7 +2,6 @@ import json import pathlib - root_dir = pathlib.Path(__file__).parent.parent config_path = root_dir / 'config' / 'conf.json' diff --git a/tests/server_test.py b/tests/server_test.py index 1a5ab0e..38cb3d6 100644 --- a/tests/server_test.py +++ b/tests/server_test.py @@ -1,6 +1,9 @@ +import atexit +import os import unittest from multiprocessing.pool import Pool - +import asyncio +import aiohttp import requests import re from itertools import * @@ -10,7 +13,10 @@ from faker.providers import internet from faker.providers import lorem from faker.providers import misc -from settings import config +from src.settings import config +from src.server import app +from tests.testcase import AsyncioTestCase + fake = faker.Faker() fake.add_provider(internet) @@ -22,14 +28,15 @@ def fake_lbryusername(): return '@' + fake.user_name() -def jsonrpc_post(url, method, **params): +async def jsonrpc_post(url, method, **params): json_body = { 'jsonrpc': '2.0', 'id': None, 'method': method, 'params': params } - return requests.post(url=url, json=json_body) + async with aiohttp.request('POST', url, json=json_body) as request: + return await request.json() def nothing(): @@ -52,19 +59,26 @@ def create_test_comments(values: iter, **default): for comb in vars_combo] -class ServerTest(unittest.TestCase): +class ServerTest(AsyncioTestCase): + db_file = 'test.db' + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.url = 'http://' + config['HOST'] + ':5921/api' - def post_comment(self, **params): - json_body = { - 'jsonrpc': '2.0', - 'id': None, - 'method': 'create_comment', - 'params': params - } - return requests.post(url=self.url, json=json_body) + @classmethod + def tearDownClass(cls) -> None: + print('exit reached') + os.remove(cls.db_file) + + async def asyncSetUp(self): + await super().asyncSetUp() + self.server = app.CommentDaemon(config, db_file=self.db_file) + await self.server.start() + self.addCleanup(self.server.stop) + + async def post_comment(self, **params): + return await jsonrpc_post(self.url, 'create_comment', **params) def is_valid_message(self, comment=None, claim_id=None, parent_id=None, channel_name=None, channel_id=None, signature=None, signing_ts=None): @@ -89,9 +103,6 @@ class ServerTest(unittest.TestCase): return False return True - def setUp(self) -> None: - self.reply_id = 'ace7800f36e55c74c4aa6a698f97a7ee5f1ccb047b5a0730960df90e58c41dc2' - @staticmethod def valid_channel_name(channel_name): return re.fullmatch( @@ -100,7 +111,7 @@ class ServerTest(unittest.TestCase): channel_name ) - def test01CreateCommentNoReply(self): + async def test01CreateCommentNoReply(self): anonymous_test = create_test_comments( ('claim_id', 'channel_id', 'channel_name', 'comment'), comment=None, @@ -110,15 +121,14 @@ class ServerTest(unittest.TestCase): ) for test in anonymous_test: with self.subTest(test=test): - message = self.post_comment(**test) - message = message.json() + message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: self.assertFalse(self.is_valid_message(**test)) else: self.assertTrue(self.is_valid_message(**test)) - def test02CreateNamedCommentsNoReply(self): + async def test02CreateNamedCommentsNoReply(self): named_test = create_test_comments( ('channel_name', 'channel_id', 'signature'), claim_id='1234567890123456789012345678901234567890', @@ -129,37 +139,35 @@ class ServerTest(unittest.TestCase): ) for test in named_test: with self.subTest(test=test): - message = self.post_comment(**test) - message = message.json() + message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: self.assertFalse(self.is_valid_message(**test)) else: self.assertTrue(self.is_valid_message(**test)) - def test03CreateAllTestComments(self): + async def test03CreateAllTestComments(self): test_all = create_test_comments(replace.keys(), **{ k: None for k in replace.keys() }) for test in test_all: with self.subTest(test=test): - message = self.post_comment(**test) - message = message.json() + message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: self.assertFalse(self.is_valid_message(**test)) else: self.assertTrue(self.is_valid_message(**test)) - def test04CreateAllReplies(self): + async def test04CreateAllReplies(self): claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' - parent_comment = self.post_comment( + parent_comment = await self.post_comment( channel_name='@KevinWalterRabie', channel_id=fake.sha1(), comment='Hello everybody and welcome back to my chan nel', claim_id=claim_id, ) - parent_id = parent_comment.json()['result']['comment_id'] + parent_id = parent_comment['result']['comment_id'] test_all = create_test_comments( ('comment', 'channel_name', 'channel_id', 'signature', 'parent_id'), parent_id=parent_id, @@ -174,8 +182,7 @@ class ServerTest(unittest.TestCase): if test['parent_id'] != parent_id: continue else: - message = self.post_comment(**test) - message = message.json() + message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: self.assertFalse(self.is_valid_message(**test)) @@ -183,7 +190,7 @@ class ServerTest(unittest.TestCase): self.assertTrue(self.is_valid_message(**test)) -class ListCommentsTest(unittest.TestCase): +class ListCommentsTest(AsyncioTestCase): replace = { 'claim_id': fake.sha1, 'comment': fake.text, @@ -192,30 +199,35 @@ class ListCommentsTest(unittest.TestCase): 'signature': nothing, 'parent_id': nothing } + db_file = 'list_test.db' + url = 'http://localhost:5921/api' + comment_ids = None + claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' @classmethod - def post_comment(cls, **params): - json_body = { - 'jsonrpc': '2.0', - 'id': None, - 'method': 'create_comment', - 'params': params - } - return requests.post(url=cls.url, json=json_body) + async def post_comment(cls, **params): + return await jsonrpc_post(cls.url, 'create_comment', **params) @classmethod - def setUpClass(cls) -> None: - cls.url = 'http://' + config['HOST'] + ':5921/api' - cls.claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' - cls.comment_list = [{key: cls.replace[key]() for key in cls.replace.keys()} for _ in range(23)] - for comment in cls.comment_list: - comment['claim_id'] = cls.claim_id - cls.comment_ids = [cls.post_comment(**comm).json()['result']['comment_id'] - for comm in cls.comment_list] + def tearDownClass(cls) -> None: + print('exit reached') + os.remove(cls.db_file) - def testListComments(self): - response_one = jsonrpc_post(self.url, 'get_claim_comments', page_size=20, - page=1, top_level=1, claim_id=self.claim_id).json() + async def asyncSetUp(self): + await super().asyncSetUp() + self.server = app.CommentDaemon(config, db_file=self.db_file) + await self.server.start() + self.addCleanup(self.server.stop) + if self.comment_ids is None: + self.comment_list = [{key: self.replace[key]() for key in self.replace.keys()} for _ in range(23)] + for comment in self.comment_list: + comment['claim_id'] = self.claim_id + self.comment_ids = [(await self.post_comment(**comm))['result']['comment_id'] + for comm in self.comment_list] + + async def testListComments(self): + response_one = await jsonrpc_post(self.url, 'get_claim_comments', page_size=20, + page=1, top_level=1, claim_id=self.claim_id) self.assertIsNotNone(response_one) self.assertIn('result', response_one) response_one: dict = response_one['result'] @@ -224,40 +236,11 @@ class ListCommentsTest(unittest.TestCase): self.assertIn('items', response_one) self.assertGreaterEqual(response_one['total_pages'], response_one['page']) last_page = response_one['total_pages'] - response = jsonrpc_post(self.url, 'get_claim_comments', page_size=20, - page=last_page, top_level=1, claim_id=self.claim_id).json() + response = await jsonrpc_post(self.url, 'get_claim_comments', page_size=20, + page=last_page, top_level=1, claim_id=self.claim_id) self.assertIsNotNone(response) self.assertIn('result', response) response: dict = response['result'] self.assertIs(type(response['items']), list) self.assertEqual(response['total_items'], response_one['total_items']) self.assertEqual(response['total_pages'], response_one['total_pages']) - - -class ConcurrentWriteTest(unittest.TestCase): - @staticmethod - def make_comment(num): - return { - 'jsonrpc': '2.0', - 'id': num, - 'method': 'create_comment', - 'params': { - 'comment': f'Comment #{num}', - 'claim_id': '6d266af6c25c80fa2ac6cc7662921ad2e90a07e7', - } - } - - @staticmethod - def send_comment_to_server(params): - with requests.post(params[0], json=params[1]) as req: - return req.json() - - def test01Concurrency(self): - urls = [f'http://localhost:{port}/api' for port in range(5921, 5925)] - comments = [self.make_comment(i) for i in range(1, 5)] - inputs = list(zip(urls, comments)) - with Pool(4) as pool: - results = pool.map(self.send_comment_to_server, inputs) - results = list(filter(lambda x: 'comment_id' in x['result'], results)) - self.assertIsNotNone(results) - self.assertEqual(len(results), len(inputs)) diff --git a/tests/testcase.py b/tests/testcase.py index 6bf1ca6..4ab4324 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -1,3 +1,4 @@ +import os import pathlib import unittest from asyncio.runners import _cancel_all_tasks # type: ignore @@ -119,15 +120,20 @@ class AsyncioTestCase(unittest.TestCase): class DatabaseTestCase(unittest.TestCase): + db_file = 'test.db' + + def __init__(self, methodName='DatabaseTest'): + super().__init__(methodName) + if pathlib.Path(self.db_file).exists(): + os.remove(self.db_file) + def setUp(self) -> None: super().setUp() - if pathlib.Path(config['PATH']['TEST']).exists(): - teardown_database(config['PATH']['TEST']) - setup_database(config['PATH']['TEST'], config['PATH']['SCHEMA']) - self.conn = obtain_connection(config['PATH']['TEST']) + setup_database(self.db_file, config['PATH']['SCHEMA']) + self.conn = obtain_connection(self.db_file) + self.addCleanup(self.conn.close) + self.addCleanup(os.remove, self.db_file) def tearDown(self) -> None: self.conn.close() - teardown_database(config['PATH']['TEST']) -