diff --git a/config/conf.json b/config/conf.json index 7375879..6857d62 100644 --- a/config/conf.json +++ b/config/conf.json @@ -1,15 +1,15 @@ { - "path": { - "schema": "schema/comments_ddl.sql", - "main": "database/comments.db", - "backup": "database/comments.backup.db", - "default": "database/default.db", - "test": "tests/test.db" + "PATH": { + "SCHEMA": "schema/comments_ddl.sql", + "MAIN": "database/comments.db", + "BACKUP": "database/comments.backup.db", + "DEFAULT": "database/default.db", + "TEST": "tests/test.db" }, - "anonymous": { - "channel_id": "9cb713f01bf247a0e03170b5ed00d5161340c486", - "channel_name": "@Anonymous" + "ANONYMOUS": { + "CHANNEL_ID": "9cb713f01bf247a0e03170b5ed00d5161340c486", + "CHANNEL_NAME": "@Anonymous" }, - "host": "localhost", - "port": "2903" + "HOST": "localhost", + "PORT": 2903 } \ No newline at end of file diff --git a/lbry_comment_server/__init__.py b/lbry_comment_server/__init__.py index 7df6fbc..216173b 100644 --- a/lbry_comment_server/__init__.py +++ b/lbry_comment_server/__init__.py @@ -1,9 +1,6 @@ -from lbry_comment_server.settings import config from lbry_comment_server.database import obtain_connection, validate_input, get_claim_comments from lbry_comment_server.database import get_comments_by_id, get_comment_ids, create_comment from lbry_comment_server.handles import api_endpoint -SCHEMA = config['path']['SCHEMA'] -DATABASE = config['path']['dev'] -BACKUP = config['path']['BACKUP'] -ANONYMOUS = config['ANONYMOUS'] +from lbry_comment_server.settings import config + diff --git a/lbry_comment_server/database.py b/lbry_comment_server/database.py index 687f357..57b7a6e 100644 --- a/lbry_comment_server/database.py +++ b/lbry_comment_server/database.py @@ -4,11 +4,11 @@ import typing import re import nacl.hash import time -from lbry_comment_server import ANONYMOUS, DATABASE +from lbry_comment_server.settings import config def obtain_connection(filepath: str = None, row_factory: bool = True): - filepath = filepath if filepath else DATABASE + filepath = filepath if filepath else config['PATH']['DATABASE'] connection = sqlite3.connect(filepath) if row_factory: connection.row_factory = sqlite3.Row @@ -64,8 +64,6 @@ def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str ) - - def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None, channel_id: str = None, signature: str = None, parent_id: str = None) -> str: timestamp = time.time_ns() @@ -99,7 +97,7 @@ def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, **kwar except AssertionError: return None else: - channel_id = ANONYMOUS['channel_id'] + channel_id = config['ANONYMOUS']['CHANNEL_ID'] comment_id = _insert_comment( conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, **kwargs ) @@ -169,7 +167,6 @@ async def _insert_comment_async(db_file: str, claim_id: str = None, comment: str return comment_id - async def create_comment_async(db_file: str, comment: str, claim_id: str, **kwargs): channel_id = kwargs.pop('channel_id', '') channel_name = kwargs.pop('channel_name', '') @@ -185,7 +182,7 @@ async def create_comment_async(db_file: str, comment: str, claim_id: str, **kwar except AssertionError: return None else: - channel_id = ANONYMOUS['channel_id'] + channel_id = config['ANONYMOUS']['CHANNEL_ID'] comment_id = await _insert_comment_async( db_file=db_file, comment=comment, claim_id=claim_id, channel_id=channel_id, **kwargs ) diff --git a/lbry_comment_server/main.py b/lbry_comment_server/main.py index 0cc9d50..1ab621c 100644 --- a/lbry_comment_server/main.py +++ b/lbry_comment_server/main.py @@ -6,8 +6,12 @@ from aiohttp import web from settings import config from lbry_comment_server import DATABASE from lbry_comment_server.database import obtain_connection +from lbry_comment_server import api_endpoint +def add_routes(app: web.Application): + app.add_routes([web.post('/api', api_endpoint)]) + class CommentServer: def __init__(self, port=2903): diff --git a/lbry_comment_server/settings.py b/lbry_comment_server/settings.py index b515a8c..9da8a2a 100644 --- a/lbry_comment_server/settings.py +++ b/lbry_comment_server/settings.py @@ -8,8 +8,8 @@ config_path = root_dir / 'config' / 'conf.json' def get_config(filepath): with open(filepath, 'r') as cfile: conf = json.load(cfile) - for key, path in conf['path'].items(): - conf['path'][key] = str(root_dir / path) + for key, path in conf['PATH'].items(): + conf['PATH'][key] = str(root_dir / path) return conf diff --git a/tests/database_test.py b/tests/database_test.py index 86580ca..39a4379 100644 --- a/tests/database_test.py +++ b/tests/database_test.py @@ -3,45 +3,17 @@ import unittest from faker.providers import internet from faker.providers import lorem from faker.providers import misc - -import schema.db_helpers as schema -from lbry_comment_server.settings import config import lbry_comment_server.database as db import faker from random import randint + fake = faker.Faker() fake.add_provider(internet) fake.add_provider(lorem) fake.add_provider(misc) -class DatabaseTestCase(unittest.TestCase): - def setUp(self) -> None: - super().setUp() - schema.setup_database(config['path']['test']) - self.conn = db.obtain_connection(config['path']['test']) - - def tearDown(self) -> None: - curs = self.conn.execute('SELECT * FROM COMMENT') - results = {'COMMENT': [dict(r) for r in curs.fetchall()]} - curs = self.conn.execute('SELECT * FROM CHANNEL') - results['CHANNEL'] = [dict(r) for r in curs.fetchall()] - curs = self.conn.execute('SELECT * FROM COMMENTS_ON_CLAIMS') - results['COMMENTS_ON_CLAIMS'] = [dict(r) for r in curs.fetchall()] - curs = self.conn.execute('SELECT * FROM COMMENT_REPLIES') - results['COMMENT_REPLIES'] = [dict(r) for r in curs.fetchall()] - # print(json.dumps(results, indent=4)) - with self.conn: - self.conn.executescript(""" - DROP TABLE IF EXISTS COMMENT; - DROP TABLE IF EXISTS CHANNEL; - DROP VIEW IF EXISTS COMMENTS_ON_CLAIMS; - DROP VIEW IF EXISTS COMMENT_REPLIES; - """) - self.conn.close() - - class TestCommentCreation(DatabaseTestCase): def setUp(self) -> None: super().setUp() @@ -259,6 +231,10 @@ class ListDatabaseTest(DatabaseTestCase): self.assertEqual(len(matching_comments), len(comment_ids)) +class AsyncDatabaseTestCase(unittest.TestCase): + async def asyncSetup + + def generate_replies(top_comments): return [{ 'claim_id': comment['claim_id'], diff --git a/tests/testcase.py b/tests/testcase.py new file mode 100644 index 0000000..16c913f --- /dev/null +++ b/tests/testcase.py @@ -0,0 +1,143 @@ +import asyncio +from asyncio.runners import _cancel_all_tasks # type: ignore +import unittest +from unittest.case import _Outcome +import lbry_comment_server.database as db + +from lbry_comment_server import config +import schema.db_helpers as schema + +class AsyncioTestCase(unittest.TestCase): + # Implementation inspired by discussion: + # https://bugs.python.org/issue32972 + + maxDiff = None + + async def asyncSetUp(self): # pylint: disable=C0103 + pass + + async def asyncTearDown(self): # pylint: disable=C0103 + pass + + def run(self, result=None): # pylint: disable=R0915 + orig_result = result + if result is None: + result = self.defaultTestResult() + startTestRun = getattr(result, 'startTestRun', None) # pylint: disable=C0103 + if startTestRun is not None: + startTestRun() + + result.startTest(self) + + testMethod = getattr(self, self._testMethodName) # pylint: disable=C0103 + if (getattr(self.__class__, "__unittest_skip__", False) or + getattr(testMethod, "__unittest_skip__", False)): + # If the class or method was skipped. + try: + skip_why = (getattr(self.__class__, '__unittest_skip_why__', '') + or getattr(testMethod, '__unittest_skip_why__', '')) + self._addSkip(result, self, skip_why) + finally: + result.stopTest(self) + return + expecting_failure_method = getattr(testMethod, + "__unittest_expecting_failure__", False) + expecting_failure_class = getattr(self, + "__unittest_expecting_failure__", False) + expecting_failure = expecting_failure_class or expecting_failure_method + outcome = _Outcome(result) + + self.loop = asyncio.new_event_loop() # pylint: disable=W0201 + asyncio.set_event_loop(self.loop) + self.loop.set_debug(True) + + try: + self._outcome = outcome + + with outcome.testPartExecutor(self): + self.setUp() + self.loop.run_until_complete(self.asyncSetUp()) + if outcome.success: + outcome.expecting_failure = expecting_failure + with outcome.testPartExecutor(self, isTest=True): + maybe_coroutine = testMethod() + if asyncio.iscoroutine(maybe_coroutine): + self.loop.run_until_complete(maybe_coroutine) + outcome.expecting_failure = False + with outcome.testPartExecutor(self): + self.loop.run_until_complete(self.asyncTearDown()) + self.tearDown() + + self.doAsyncCleanups() + + try: + _cancel_all_tasks(self.loop) + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + self.loop.close() + + for test, reason in outcome.skipped: + self._addSkip(result, test, reason) + self._feedErrorsToResult(result, outcome.errors) + if outcome.success: + if expecting_failure: + if outcome.expectedFailure: + self._addExpectedFailure(result, outcome.expectedFailure) + else: + self._addUnexpectedSuccess(result) + else: + result.addSuccess(self) + return result + finally: + result.stopTest(self) + if orig_result is None: + stopTestRun = getattr(result, 'stopTestRun', None) # pylint: disable=C0103 + if stopTestRun is not None: + stopTestRun() # pylint: disable=E1102 + + # explicitly break reference cycles: + # outcome.errors -> frame -> outcome -> outcome.errors + # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure + outcome.errors.clear() + outcome.expectedFailure = None + + # clear the outcome, no more needed + self._outcome = None + + def doAsyncCleanups(self): # pylint: disable=C0103 + outcome = self._outcome or _Outcome() + while self._cleanups: + function, args, kwargs = self._cleanups.pop() + with outcome.testPartExecutor(self): + maybe_coroutine = function(*args, **kwargs) + if asyncio.iscoroutine(maybe_coroutine): + self.loop.run_until_complete(maybe_coroutine) + + + +class DatabaseTestCase(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + schema.setup_database(config['PATH']['TEST']) + self.conn = db.obtain_connection(config['PATH']['TEST']) + + def tearDown(self) -> None: + curs = self.conn.execute('SELECT * FROM COMMENT') + results = {'COMMENT': [dict(r) for r in curs.fetchall()]} + curs = self.conn.execute('SELECT * FROM CHANNEL') + results['CHANNEL'] = [dict(r) for r in curs.fetchall()] + curs = self.conn.execute('SELECT * FROM COMMENTS_ON_CLAIMS') + results['COMMENTS_ON_CLAIMS'] = [dict(r) for r in curs.fetchall()] + curs = self.conn.execute('SELECT * FROM COMMENT_REPLIES') + results['COMMENT_REPLIES'] = [dict(r) for r in curs.fetchall()] + # print(json.dumps(results, indent=4)) + with self.conn: + self.conn.executescript(""" + DROP TABLE IF EXISTS COMMENT; + DROP TABLE IF EXISTS CHANNEL; + DROP VIEW IF EXISTS COMMENTS_ON_CLAIMS; + DROP VIEW IF EXISTS COMMENT_REPLIES; + """) + self.conn.close() +