diff --git a/src/database.py b/src/database.py index a16e9af..d1bc994 100644 --- a/src/database.py +++ b/src/database.py @@ -101,12 +101,17 @@ def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, **kwar _insert_channel(conn, channel_name, channel_id) except AssertionError: logger.exception('Received invalid input') - return None + raise TypeError('Invalid params given to input validation') else: channel_id = config['ANONYMOUS']['CHANNEL_ID'] - comment_id = _insert_comment( - conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, **kwargs - ) + try: + comment_id = _insert_comment( + conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id, **kwargs + ) + except sqlite3.IntegrityError as ie: + logger.exception(ie) + return None # silently return none + curry = conn.execute( 'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?', (comment_id,) ) @@ -186,7 +191,7 @@ async def create_comment_async(db_file: str, comment: str, claim_id: str, **kwar ) await _insert_channel_async(db_file, channel_name, channel_id) except AssertionError: - return None + raise TypeError('Invalid parameters given to input validation') else: channel_id = config['ANONYMOUS']['CHANNEL_ID'] comment_id = await _insert_comment_async( diff --git a/src/main.py b/src/main.py index 9ef240a..b9c3da3 100644 --- a/src/main.py +++ b/src/main.py @@ -3,6 +3,7 @@ import logging import aiojobs.aiohttp import asyncio from aiohttp import web +import re import schema.db_helpers from src.database import obtain_connection @@ -10,6 +11,7 @@ from src.handles import api_endpoint from src.settings import config from src.writes import create_comment_scheduler, DatabaseWriter + logger = logging.getLogger(__name__) # logger.setLevel(logging.DEBUG) @@ -46,16 +48,25 @@ async def create_database_backup(app): while True: await asyncio.sleep(app['config']['BACKUP_INT']) with obtain_connection(app['db_path']) as conn: - logger.debug('%s backing up database') + logger.debug('backing up database') schema.db_helpers.backup_database(conn, app['backup']) except asyncio.CancelledError as e: pass + async def start_background_tasks(app: web.Application): app['waitful_backup'] = app.loop.create_task(create_database_backup(app)) app['comment_scheduler'] = await create_comment_scheduler() - app['writer'] = DatabaseWriter(config['PATH']['DEFAULT']) + app['writer'] = DatabaseWriter(app['db_path']) + + +def insert_to_config(app, conf=None, db_file=None): + db_file = db_file if db_file else 'DEFAULT' + app['config'] = conf if conf else config + app['db_path'] = conf['PATH'][db_file] + app['backup'] = re.sub(r'\.db$', '.backup.db', app['db_path']) + assert app['db_path'] != app['backup'] async def cleanup_background_tasks(app): @@ -66,11 +77,9 @@ async def cleanup_background_tasks(app): app['writer'].close() -def create_app(**kwargs): +def create_app(conf, db_path='DEFAULT', **kwargs): app = web.Application() - app['config'] = config - app['db_path'] = config['PATH']['DEFAULT'] - app['backup'] = config['PATH']['BACKUP'] + insert_to_config(app, conf, db_path) app.on_startup.append(setup_db_schema) app.on_startup.append(start_background_tasks) app['reader'] = obtain_connection(app['db_path'], True) @@ -102,11 +111,15 @@ async def run_app(app): await stop_app(runner) -if __name__ == '__main__': - appl = create_app(close_timeout=5.0) +def __run_app(): + appl = create_app(conf=config, db_path='TEST', close_timeout=5.0) try: asyncio.run(web.run_app(appl, access_log=logger, host=config['HOST'], port=config['PORT'])) except asyncio.CancelledError: pass except ValueError: pass + + +if __name__ == '__main__': + __run_app()