diff --git a/.gitignore b/.gitignore index c7e09d7..54e4315 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -config/conf.json +config/conf.yml +docker-compose.yml + diff --git a/.travis.yml b/.travis.yml index ac2c250..fe81b23 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,28 @@ sudo: required language: python dist: xenial -python: 3.7 +python: 3.8 + +# for docker-compose +services: + - docker + +# to avoid "well it works on my computer" moments +env: + - DOCKER_COMPOSE_VERSION=1.25.4 + +before_install: + # ensure docker-compose version is as specified above + - sudo rm /usr/local/bin/docker-compose + - curl -L https://github.com/docker/compose/releases/download/${DOCKER_COMPOSE_VERSION}/docker-compose-`uname -s`-`uname -m` > docker-compose + - chmod +x docker-compose + - sudo mv docker-compose /usr/local/bin + # refresh docker images + - sudo apt-get update + +before_script: + - docker-compose up -d + jobs: include: diff --git a/config/conf.yml b/config/conf.yml new file mode 100644 index 0000000..6d90446 --- /dev/null +++ b/config/conf.yml @@ -0,0 +1,30 @@ +--- +# for running local-tests without using MySQL for now +testing: + database: sqlite + file: comments.db + pragmas: + journal_mode: wal + cache_size: 64000 + foreign_keys: 0 + ignore_check_constraints: 1 + synchronous: 0 + +# actual database should be running MySQL +production: + database: mysql + name: lbry + user: lbry + password: lbry + host: localhost + port: 3306 + +mode: production +logging: + format: "%(asctime)s | %(levelname)s | %(name)s | %(module)s.%(funcName)s:%(lineno)d + | %(message)s" + aiohttp_format: "%(asctime)s | %(levelname)s | %(name)s | %(message)s" + datefmt: "%Y-%m-%d %H:%M:%S" +host: localhost +port: 5921 +lbrynet: http://localhost:5279 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..3faf5f2 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,24 @@ +version: "3.7" +services: + ########### + ## MySQL ## + ########### + mysql: + image: mysql/mysql-server:5.7.27 + restart: "no" + ports: + - "3306:3306" + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - MYSQL_DATABASE=lbry + - MYSQL_USER=lbry + - MYSQL_PASSWORD=lbry + - MYSQL_LOG_CONSOLE=true + ############# + ## Adminer ## + ############# + adminer: + image: adminer + restart: always + ports: + - 8080:8080 diff --git a/setup.py b/setup.py index 923e82a..f7a397f 100644 --- a/setup.py +++ b/setup.py @@ -14,15 +14,17 @@ setup( data_files=[('config', ['config/conf.json',])], include_package_data=True, install_requires=[ + 'pymysql', + 'pyyaml', 'Faker>=1.0.7', - 'asyncio>=3.4.3', - 'aiohttp==3.5.4', - 'aiojobs==0.2.2', + 'asyncio', + 'aiohttp', + 'aiojobs', 'ecdsa>=0.13.3', 'cryptography==2.5', - 'aiosqlite==0.10.0', 'PyNaCl>=1.3.0', 'requests', 'cython', + 'peewee' ] ) diff --git a/src/database/models.py b/src/database/models.py new file mode 100644 index 0000000..9c40586 --- /dev/null +++ b/src/database/models.py @@ -0,0 +1,217 @@ +import json +import time + +import logging +import math +import typing + +from peewee import * +import nacl.hash + +from src.server.validation import is_valid_base_comment +from src.misc import clean + + +class Channel(Model): + claim_id = CharField(column_name='ClaimId', primary_key=True, max_length=40) + name = CharField(column_name='Name', max_length=256) + + class Meta: + table_name = 'CHANNEL' + + +class Comment(Model): + comment = CharField(column_name='Body', max_length=2000) + channel = ForeignKeyField( + backref='comments', + column_name='ChannelId', + field='claim_id', + model=Channel, + null=True + ) + comment_id = CharField(column_name='CommentId', primary_key=True, max_length=64) + is_hidden = BooleanField(column_name='IsHidden', constraints=[SQL("DEFAULT 0")]) + claim_id = CharField(max_length=40, column_name='LbryClaimId') + parent = ForeignKeyField( + column_name='ParentId', + field='comment_id', + model='self', + null=True, + backref='replies' + ) + signature = CharField(max_length=128, column_name='Signature', null=True, unique=True) + signing_ts = TextField(column_name='SigningTs', null=True) + timestamp = IntegerField(column_name='Timestamp') + + class Meta: + table_name = 'COMMENT' + indexes = ( + (('channel', 'comment_id'), False), + (('claim_id', 'comment_id'), False), + ) + + +FIELDS = { + 'comment': Comment.comment, + 'comment_id': Comment.comment_id, + 'claim_id': Comment.claim_id, + 'timestamp': Comment.timestamp, + 'signature': Comment.signature, + 'signing_ts': Comment.signing_ts, + 'is_hidden': Comment.is_hidden, + 'parent_id': Comment.parent.alias('parent_id'), + 'channel_id': Channel.claim_id.alias('channel_id'), + 'channel_name': Channel.name.alias('channel_name'), + 'channel_url': ('lbry://' + Channel.name + '#' + Channel.claim_id).alias('channel_url') +} + + +def comment_list(claim_id: str = None, parent_id: str = None, + top_level: bool = False, exclude_mode: str = None, + page: int = 1, page_size: int = 50, expressions=None, + select_fields: list = None, exclude_fields: list = None) -> dict: + fields = FIELDS.keys() + if exclude_fields: + fields -= set(exclude_fields) + if select_fields: + fields &= set(select_fields) + attributes = [FIELDS[field] for field in fields] + query = Comment.select(*attributes) + + # todo: allow this process to be more automated, so it can just be an expression + if claim_id: + query = query.where(Comment.claim_id == claim_id) + if top_level: + query = query.where(Comment.parent.is_null()) + + if parent_id: + query = query.where(Comment.ParentId == parent_id) + + if exclude_mode: + show_hidden = exclude_mode.lower() == 'hidden' + query = query.where((Comment.is_hidden == show_hidden)) + + if expressions: + query = query.where(expressions) + + total = query.count() + query = (query + .join(Channel, JOIN.LEFT_OUTER) + .order_by(Comment.timestamp.desc()) + .paginate(page, page_size)) + items = [clean(item) for item in query.dicts()] + # has_hidden_comments is deprecated + data = { + 'page': page, + 'page_size': page_size, + 'total_pages': math.ceil(total / page_size), + 'total_items': total, + 'items': items, + 'has_hidden_comments': exclude_mode is not None and exclude_mode == 'hidden', + } + return data + + +def get_comment(comment_id: str) -> dict: + try: + comment = comment_list(expressions=(Comment.comment_id == comment_id), page_size=1).get('items').pop() + except IndexError: + raise ValueError(f'Comment does not exist with id {comment_id}') + else: + return comment + + +def create_comment_id(comment: str, channel_id: str, timestamp: int): + # We convert the timestamp from seconds into minutes + # to prevent spammers from commenting the same BS everywhere. + nearest_minute = str(math.floor(timestamp / 60)) + + # don't use claim_id for the comment_id anymore so comments + # are not unique to just one claim + prehash = b':'.join([ + comment.encode(), + channel_id.encode(), + nearest_minute.encode() + ]) + return nacl.hash.sha256(prehash).decode() + + +def create_comment(comment: str = None, claim_id: str = None, + parent_id: str = None, channel_id: str = None, + channel_name: str = None, signature: str = None, + signing_ts: str = None) -> dict: + if not is_valid_base_comment( + comment=comment, + claim_id=claim_id, + parent_id=parent_id, + channel_id=channel_id, + channel_name=channel_name, + signature=signature, + signing_ts=signing_ts + ): + raise ValueError('Invalid Parameters given for comment') + + channel, _ = Channel.get_or_create(name=channel_name, claim_id=channel_id) + if parent_id and not claim_id: + parent: Comment = Comment.get_by_id(parent_id) + claim_id = parent.claim_id + + timestamp = int(time.time()) + comment_id = create_comment_id(comment, channel_id, timestamp) + new_comment = Comment.create( + claim_id=claim_id, + comment_id=comment_id, + comment=comment, + parent=parent_id, + channel=channel, + signature=signature, + signing_ts=signing_ts, + timestamp=timestamp + ) + return get_comment(new_comment.comment_id) + + +def delete_comment(comment_id: str) -> bool: + try: + comment: Comment = Comment.get_by_id(comment_id) + except DoesNotExist as e: + raise ValueError from e + else: + return 0 < comment.delete_instance(True, delete_nullable=True) + + +def edit_comment(comment_id: str, new_comment: str, new_sig: str, new_ts: str) -> bool: + try: + comment: Comment = Comment.get_by_id(comment_id) + except DoesNotExist as e: + raise ValueError from e + else: + comment.comment = new_comment + comment.signature = new_sig + comment.signing_ts = new_ts + + # todo: add a 'last-modified' timestamp + comment.timestamp = int(time.time()) + return comment.save() > 0 + + +def set_hidden_flag(comment_ids: typing.List[str], hidden=True) -> bool: + # sets `is_hidden` flag for all `comment_ids` to the `hidden` param + update = (Comment + .update(is_hidden=hidden) + .where(Comment.comment_id.in_(comment_ids))) + return update.execute() > 0 + + +if __name__ == '__main__': + logger = logging.getLogger('peewee') + logger.addHandler(logging.StreamHandler()) + logger.setLevel(logging.DEBUG) + + comments = comment_list( + page_size=20, + expressions=((Comment.timestamp < 1583272089) & + (Comment.claim_id ** '420%')) + ) + + print(json.dumps(comments, indent=4)) diff --git a/src/database/queries.py b/src/database/queries.py deleted file mode 100644 index 03815aa..0000000 --- a/src/database/queries.py +++ /dev/null @@ -1,290 +0,0 @@ -import atexit -import logging -import math -import sqlite3 -import time -import typing - -import nacl.hash - -from src.database.schema import CREATE_TABLES_QUERY - -logger = logging.getLogger(__name__) - -SELECT_COMMENTS_ON_CLAIMS = """ - SELECT comment, comment_id, claim_id, timestamp, is_hidden, parent_id, - channel_name, channel_id, channel_url, signature, signing_ts - FROM COMMENTS_ON_CLAIMS -""" - - -def clean(thing: dict) -> dict: - if 'is_hidden' in thing: - thing.update({'is_hidden': bool(thing['is_hidden'])}) - return {k: v for k, v in thing.items() if v is not None} - - -def obtain_connection(filepath: str = None, row_factory: bool = True): - connection = sqlite3.connect(filepath) - if row_factory: - connection.row_factory = sqlite3.Row - return connection - - -def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, - page: int = 1, page_size: int = 50, top_level=False): - with conn: - if top_level: - results = [clean(dict(row)) for row in conn.execute( - SELECT_COMMENTS_ON_CLAIMS + " WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ?", - (claim_id, page_size, page_size * (page - 1)) - )] - count = conn.execute( - "SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id IS NULL", - (claim_id,) - ) - elif parent_id is None: - results = [clean(dict(row)) for row in conn.execute( - SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? LIMIT ? OFFSET ? ", - (claim_id, page_size, page_size * (page - 1)) - )] - count = conn.execute( - "SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ?", - (claim_id,) - ) - else: - results = [clean(dict(row)) for row in conn.execute( - SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ? ", - (claim_id, parent_id, page_size, page_size * (page - 1)) - )] - count = conn.execute( - "SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id = ?", - (claim_id, parent_id) - ) - count = tuple(count.fetchone())[0] - return { - 'items': results, - 'page': page, - 'page_size': page_size, - 'total_pages': math.ceil(count / page_size), - 'total_items': count, - 'has_hidden_comments': claim_has_hidden_comments(conn, claim_id) - } - - -def get_claim_hidden_comments(conn: sqlite3.Connection, claim_id: str, hidden=True, page=1, page_size=50): - with conn: - results = conn.execute( - SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? AND is_hidden IS ? LIMIT ? OFFSET ?", - (claim_id, hidden, page_size, page_size * (page - 1)) - ) - count = conn.execute( - "SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND is_hidden IS ?", (claim_id, hidden) - ) - results = [clean(dict(row)) for row in results.fetchall()] - count = tuple(count.fetchone())[0] - - return { - 'items': results, - 'page': page, - 'page_size': page_size, - 'total_pages': math.ceil(count / page_size), - 'total_items': count, - 'has_hidden_comments': claim_has_hidden_comments(conn, claim_id) - } - - -def claim_has_hidden_comments(conn, claim_id): - with conn: - result = conn.execute( - "SELECT COUNT(DISTINCT is_hidden) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND is_hidden IS 1", - (claim_id,) - ) - return bool(tuple(result.fetchone())[0]) - - -def insert_comment(conn: sqlite3.Connection, claim_id: str, comment: str, - channel_id: str = None, signature: str = None, signing_ts: str = None, **extra) -> str: - timestamp = int(time.time()) - prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),)) - comment_id = nacl.hash.sha256(prehash).decode() - with conn: - curs = conn.execute( - """ - INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body, ParentId, - Timestamp, Signature, SigningTs, IsHidden) - VALUES (:comment_id, :claim_id, :channel_id, :comment, NULL, - :timestamp, :signature, :signing_ts, 0) """, - { - 'comment_id': comment_id, - 'claim_id': claim_id, - 'channel_id': channel_id, - 'comment': comment, - 'timestamp': timestamp, - 'signature': signature, - 'signing_ts': signing_ts - } - ) - logging.info('attempted to insert comment with comment_id [%s] | %d rows affected', comment_id, curs.rowcount) - return comment_id - - -def insert_reply(conn: sqlite3.Connection, comment: str, parent_id: str, - channel_id: str = None, signature: str = None, - signing_ts: str = None, **extra) -> str: - timestamp = int(time.time()) - prehash = b':'.join((parent_id.encode(), comment.encode(), str(timestamp).encode(),)) - comment_id = nacl.hash.sha256(prehash).decode() - with conn: - curs = conn.execute( - """ - INSERT INTO COMMENT - (CommentId, LbryClaimId, ChannelId, Body, ParentId, Signature, Timestamp, SigningTs, IsHidden) - SELECT :comment_id, LbryClaimId, :channel_id, :comment, :parent_id, :signature, :timestamp, :signing_ts, 0 - FROM COMMENT WHERE CommentId = :parent_id - """, { - 'comment_id': comment_id, - 'parent_id': parent_id, - 'timestamp': timestamp, - 'comment': comment, - 'channel_id': channel_id, - 'signature': signature, - 'signing_ts': signing_ts - } - ) - logging.info('attempted to insert reply with comment_id [%s] | %d rows affected', comment_id, curs.rowcount) - return comment_id - - -def get_comment_or_none(conn: sqlite3.Connection, comment_id: str) -> dict: - with conn: - curry = conn.execute(SELECT_COMMENTS_ON_CLAIMS + "WHERE comment_id = ?", (comment_id,)) - thing = curry.fetchone() - return clean(dict(thing)) if thing else None - - -def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50): - """ Just return a list of the comment IDs that are associated with the given claim_id. - If get_all is specified then it returns all the IDs, otherwise only the IDs at that level. - if parent_id is left null then it only returns the top level comments. - - For pagination the parameters are: - get_all XOR (page_size + page) - """ - with conn: - if parent_id is None: - curs = conn.execute(""" - SELECT comment_id FROM COMMENTS_ON_CLAIMS - WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ? - """, (claim_id, page_size, page_size * abs(page - 1),) - ) - else: - curs = conn.execute(""" - SELECT comment_id FROM COMMENTS_ON_CLAIMS - WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ? - """, (claim_id, parent_id, page_size, page_size * abs(page - 1),) - ) - return [tuple(row)[0] for row in curs.fetchall()] - - -def get_comments_by_id(conn, comment_ids: typing.Union[list, tuple]) -> typing.Union[list, None]: - """ Returns a list containing the comment data associated with each ID within the list""" - # format the input, under the assumption that the - placeholders = ', '.join('?' for _ in comment_ids) - with conn: - return [clean(dict(row)) for row in conn.execute( - SELECT_COMMENTS_ON_CLAIMS + f'WHERE comment_id IN ({placeholders})', - tuple(comment_ids) - )] - - -def delete_comment_by_id(conn: sqlite3.Connection, comment_id: str) -> bool: - with conn: - curs = conn.execute("DELETE FROM COMMENT WHERE CommentId = ?", (comment_id,)) - return bool(curs.rowcount) - - -def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str): - with conn: - curs = conn.execute('INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)', (channel_id, channel_name)) - return bool(curs.rowcount) - - -def get_channel_id_from_comment_id(conn: sqlite3.Connection, comment_id: str): - with conn: - channel = conn.execute( - "SELECT channel_id, channel_name FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?", (comment_id,) - ).fetchone() - return dict(channel) if channel else {} - - -def get_claim_ids_from_comment_ids(conn: sqlite3.Connection, comment_ids: list): - with conn: - cids = conn.execute( - f""" SELECT CommentId as comment_id, LbryClaimId AS claim_id FROM COMMENT - WHERE CommentId IN ({', '.join('?' for _ in comment_ids)}) """, - tuple(comment_ids) - ) - return {row['comment_id']: row['claim_id'] for row in cids.fetchall()} - - -def hide_comments_by_id(conn: sqlite3.Connection, comment_ids: list) -> bool: - with conn: - curs = conn.cursor() - curs.executemany( - "UPDATE COMMENT SET IsHidden = 1 WHERE CommentId = ?", - [[c] for c in comment_ids] - ) - return bool(curs.rowcount) - - -def edit_comment_by_id(conn: sqlite3.Connection, comment_id: str, comment: str, - signature: str, signing_ts: str) -> bool: - with conn: - curs = conn.execute( - """ - UPDATE COMMENT - SET Body = :comment, Signature = :signature, SigningTs = :signing_ts - WHERE CommentId = :comment_id - """, - { - 'comment': comment, - 'signature': signature, - 'signing_ts': signing_ts, - 'comment_id': comment_id - }) - logger.info("updated comment with `comment_id`: %s", comment_id) - return bool(curs.rowcount) - - -class DatabaseWriter(object): - _writer = None - - def __init__(self, db_file): - if not DatabaseWriter._writer: - self.conn = obtain_connection(db_file) - DatabaseWriter._writer = self - atexit.register(self.cleanup) - logging.info('Database writer has been created at %s', repr(self)) - else: - logging.warning('Someone attempted to insantiate DatabaseWriter') - raise TypeError('Database Writer already exists!') - - def cleanup(self): - logging.info('Cleaning up database writer') - self.conn.close() - DatabaseWriter._writer = None - - @property - def connection(self): - return self.conn - - -def setup_database(db_path): - with sqlite3.connect(db_path) as conn: - conn.executescript(CREATE_TABLES_QUERY) - - -def backup_database(conn: sqlite3.Connection, back_fp): - with sqlite3.connect(back_fp) as back: - conn.backup(back) diff --git a/src/database/writes.py b/src/database/writes.py index b9d4b98..61caa78 100644 --- a/src/database/writes.py +++ b/src/database/writes.py @@ -7,7 +7,7 @@ from src.server.validation import is_valid_base_comment from src.server.validation import is_valid_credential_input from src.server.validation import validate_signature_from_claim from src.server.validation import body_is_valid -from src.server.misc import get_claim_from_id +from src.misc import get_claim_from_id from src.server.external import send_notifications from src.server.external import send_notification import src.database.queries as db @@ -84,13 +84,17 @@ async def hide_comments(app, pieces: list) -> list: # TODO: Amortize this process claims = {} comments_to_hide = [] + # go through a list of dict objects for p in pieces: + # maps the comment_id from the piece to a claim_id claim_id = comment_cids[p['comment_id']] + # resolve the claim from its id if claim_id not in claims: claim = await get_claim_from_id(app, claim_id) if claim: claims[claim_id] = claim + # get the claim's signing channel, then use it to validate the hidden comment channel = claims[claim_id].get('signing_channel') if validate_signature_from_claim(channel, p['signature'], p['signing_ts'], p['comment_id']): comments_to_hide.append(p) diff --git a/src/definitions.py b/src/definitions.py new file mode 100644 index 0000000..51654e1 --- /dev/null +++ b/src/definitions.py @@ -0,0 +1,7 @@ +import os + +SRC_DIR = os.path.dirname(os.path.abspath(__file__)) +ROOT_DIR = os.path.dirname(SRC_DIR) +CONFIG_FILE = os.path.join(ROOT_DIR, 'config', 'conf.yml') +LOGGING_DIR = os.path.join(ROOT_DIR, 'logs') +DATABASE_DIR = os.path.join(ROOT_DIR, 'database') diff --git a/src/main.py b/src/main.py index b31bcea..817dac4 100644 --- a/src/main.py +++ b/src/main.py @@ -1,13 +1,20 @@ import argparse +import json +import yaml import logging import logging.config +import os import sys from src.server.app import run_app -from src.settings import config +from src.definitions import LOGGING_DIR, CONFIG_FILE, DATABASE_DIR -def config_logging_from_settings(conf): +def setup_logging_from_config(conf: dict): + # set the logging directory here from the settings file + if not os.path.exists(LOGGING_DIR): + os.mkdir(LOGGING_DIR) + _config = { "version": 1, "disable_existing_loggers": False, @@ -32,7 +39,7 @@ def config_logging_from_settings(conf): "level": "DEBUG", "formatter": "standard", "class": "logging.handlers.RotatingFileHandler", - "filename": conf['path']['debug_log'], + "filename": os.path.join(LOGGING_DIR, 'debug.log'), "maxBytes": 10485760, "backupCount": 5 }, @@ -40,7 +47,7 @@ def config_logging_from_settings(conf): "level": "ERROR", "formatter": "standard", "class": "logging.handlers.RotatingFileHandler", - "filename": conf['path']['error_log'], + "filename": os.path.join(LOGGING_DIR, 'error.log'), "maxBytes": 10485760, "backupCount": 5 }, @@ -48,7 +55,7 @@ def config_logging_from_settings(conf): "level": "NOTSET", "formatter": "aiohttp", "class": "logging.handlers.RotatingFileHandler", - "filename": conf['path']['server_log'], + "filename": os.path.join(LOGGING_DIR, 'server.log'), "maxBytes": 10485760, "backupCount": 5 } @@ -70,15 +77,42 @@ def config_logging_from_settings(conf): logging.config.dictConfig(_config) +def get_config(filepath): + with open(filepath, 'r') as cfile: + config = yaml.load(cfile, Loader=yaml.Loader) + return config + + +def setup_db_from_config(config: dict): + mode = config['mode'] + if config[mode]['database'] == 'sqlite': + if not os.path.exists(DATABASE_DIR): + os.mkdir(DATABASE_DIR) + + config[mode]['db_file'] = os.path.join( + DATABASE_DIR, config[mode]['name'] + ) + + def main(argv=None): argv = argv or sys.argv[1:] parser = argparse.ArgumentParser(description='LBRY Comment Server') parser.add_argument('--port', type=int) + parser.add_argument('--config', type=str) + parser.add_argument('--mode', type=str) args = parser.parse_args(argv) - config_logging_from_settings(config) + + config = get_config(CONFIG_FILE) if not args.config else args.config + setup_logging_from_config(config) + + if args.mode: + config['mode'] = args.mode + + setup_db_from_config(config) + if args.port: config['port'] = args.port - config_logging_from_settings(config) + run_app(config) diff --git a/src/server/misc.py b/src/misc.py similarity index 86% rename from src/server/misc.py rename to src/misc.py index 4f620e6..e9d5be8 100644 --- a/src/server/misc.py +++ b/src/misc.py @@ -20,3 +20,7 @@ def clean_input_params(kwargs: dict): kwargs[k] = v.strip() if k in ID_LIST: kwargs[k] = v.lower() + + +def clean(thing: dict) -> dict: + return {k: v for k, v in thing.items() if v is not None} \ No newline at end of file diff --git a/src/server/app.py b/src/server/app.py index a25787b..389083e 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -1,7 +1,6 @@ # cython: language_level=3 import asyncio import logging -import pathlib import signal import time @@ -9,81 +8,67 @@ import aiojobs import aiojobs.aiohttp from aiohttp import web -from src.database.queries import obtain_connection, DatabaseWriter -from src.database.queries import setup_database, backup_database +from peewee import * from src.server.handles import api_endpoint, get_api_endpoint +from src.database.models import Comment, Channel +MODELS = [Comment, Channel] logger = logging.getLogger(__name__) -async def setup_db_schema(app): - if not pathlib.Path(app['db_path']).exists(): - logger.info(f'Setting up schema in {app["db_path"]}') - setup_database(app['db_path']) - else: - logger.info(f'Database already exists in {app["db_path"]}, skipping setup') +def setup_database(app): + config = app['config'] + mode = config['mode'] + # switch between Database objects + if config[mode]['database'] == 'mysql': + app['db'] = MySQLDatabase( + database=config[mode]['name'], + user=config[mode]['user'], + host=config[mode]['host'], + password=config[mode]['password'], + port=config[mode]['port'], + ) + elif config[mode]['database'] == 'sqlite': + app['db'] = SqliteDatabase( + config[mode]['file'], + pragmas=config[mode]['pragmas'] + ) -async def database_backup_routine(app): - try: - while True: - await asyncio.sleep(app['config']['backup_int']) - with app['reader'] as conn: - logger.debug('backing up database') - backup_database(conn, app['backup']) - except asyncio.CancelledError: - pass + # bind the Model list to the database + app['db'].bind(MODELS, bind_refs=False, bind_backrefs=False) async def start_background_tasks(app): - # Reading the DB - app['reader'] = obtain_connection(app['db_path'], True) - app['waitful_backup'] = asyncio.create_task(database_backup_routine(app)) - - # Scheduler to prevent multiple threads from writing to DB simulataneously - app['comment_scheduler'] = await aiojobs.create_scheduler(limit=1, pending_limit=0) - app['db_writer'] = DatabaseWriter(app['db_path']) - app['writer'] = app['db_writer'].connection + app['db'].connect() + app['db'].create_tables(MODELS) # for requesting to external and internal APIs app['webhooks'] = await aiojobs.create_scheduler(pending_limit=0) async def close_database_connections(app): - logger.info('Ending background backup loop') - app['waitful_backup'].cancel() - await app['waitful_backup'] - app['reader'].close() - app['writer'].close() - app['db_writer'].cleanup() + app['db'].close() async def close_schedulers(app): - logger.info('Closing comment_scheduler') - await app['comment_scheduler'].close() - logger.info('Closing scheduler for webhook requests') await app['webhooks'].close() class CommentDaemon: - def __init__(self, config, db_file=None, backup=None, **kwargs): + def __init__(self, config, **kwargs): app = web.Application() + app['config'] = config # configure the config - app['config'] = config - self.config = app['config'] + self.config = config + self.host = config['host'] + self.port = config['port'] - # configure the db file - if db_file: - app['db_path'] = db_file - app['backup'] = backup - else: - app['db_path'] = config['path']['database'] - app['backup'] = backup or (app['db_path'] + '.backup') + setup_database(app) # configure the order of tasks to run during app lifetime - app.on_startup.append(setup_db_schema) app.on_startup.append(start_background_tasks) app.on_shutdown.append(close_schedulers) app.on_cleanup.append(close_database_connections) @@ -105,20 +90,19 @@ class CommentDaemon: await self.app_runner.setup() self.app_site = web.TCPSite( runner=self.app_runner, - host=host or self.config['host'], - port=port or self.config['port'], + host=host or self.host, + port=port or self.port, ) await self.app_site.start() - logger.info(f'Comment Server is running on {self.config["host"]}:{self.config["port"]}') + logger.info(f'Comment Server is running on {self.host}:{self.port}') async def stop(self): await self.app_runner.shutdown() await self.app_runner.cleanup() -def run_app(config, db_file=None): - comment_app = CommentDaemon(config=config, db_file=db_file, close_timeout=5.0) - +def run_app(config): + comment_app = CommentDaemon(config=config) loop = asyncio.get_event_loop() def __exit(): diff --git a/src/server/handles.py b/src/server/handles.py index 6f629a0..1c56663 100644 --- a/src/server/handles.py +++ b/src/server/handles.py @@ -1,16 +1,23 @@ import asyncio import logging import time +import typing from aiohttp import web from aiojobs.aiohttp import atomic +from peewee import DoesNotExist -import src.database.queries as db -from src.database.writes import abandon_comment, create_comment -from src.database.writes import hide_comments -from src.database.writes import edit_comment -from src.server.misc import clean_input_params +from src.server.validation import validate_signature_from_claim +from src.misc import clean_input_params, get_claim_from_id from src.server.errors import make_error, report_error +from src.database.models import Comment, Channel +from src.database.models import get_comment +from src.database.models import comment_list +from src.database.models import create_comment +from src.database.models import edit_comment +from src.database.models import delete_comment +from src.database.models import set_hidden_flag + logger = logging.getLogger(__name__) @@ -20,51 +27,194 @@ def ping(*args): return 'pong' -def handle_get_channel_from_comment_id(app, kwargs: dict): - return db.get_channel_id_from_comment_id(app['reader'], **kwargs) +def handle_get_channel_from_comment_id(app: web.Application, comment_id: str) -> dict: + comment = get_comment(comment_id) + return { + 'channel_id': comment['channel_id'], + 'channel_name': comment['channel_name'] + } -def handle_get_comment_ids(app, kwargs): - return db.get_comment_ids(app['reader'], **kwargs) +def handle_get_comment_ids( + app: web.Application, + claim_id: str, + parent_id: str = None, + page: int = 1, + page_size: int = 50, + flattened=False +) -> dict: + results = comment_list( + claim_id=claim_id, + parent_id=parent_id, + top_level=(parent_id is None), + page=page, + page_size=page_size, + select_fields=['comment_id', 'parent_id'] + ) + if flattened: + results.update({ + 'items': [item['comment_id'] for item in results['items']], + 'replies': [(item['comment_id'], item.get('parent_id')) + for item in results['items']] + }) + return results -def handle_get_claim_comments(app, kwargs): - return db.get_claim_comments(app['reader'], **kwargs) +def handle_get_comments_by_id( + app: web.Application, + comment_ids: typing.Union[list, tuple] +) -> dict: + expression = Comment.comment_id.in_(comment_ids) + return comment_list(expressions=expression, page_size=len(comment_ids)) -def handle_get_comments_by_id(app, kwargs): - return db.get_comments_by_id(app['reader'], **kwargs) +def handle_get_claim_comments( + app: web.Application, + claim_id: str, + parent_id: str = None, + page: int = 1, + page_size: int = 50, + top_level: bool = False +) -> dict: + return comment_list( + claim_id=claim_id, + parent_id=parent_id, + page=page, + page_size=page_size, + top_level=top_level + ) -def handle_get_claim_hidden_comments(app, kwargs): - return db.get_claim_hidden_comments(app['reader'], **kwargs) +def handle_get_claim_hidden_comments( + app: web.Application, + claim_id: str, + hidden: bool, + page: int = 1, + page_size: int = 50, +) -> dict: + exclude = 'hidden' if hidden else 'visible' + return comment_list( + claim_id=claim_id, + exclude_mode=exclude, + page=page, + page_size=page_size + ) -async def handle_abandon_comment(app, params): - return {'abandoned': await abandon_comment(app, **params)} +async def handle_abandon_comment( + app: web.Application, + comment_id: str, + signature: str, + signing_ts: str, + **kwargs, +) -> dict: + comment = get_comment(comment_id) + try: + channel = await get_claim_from_id(app, comment['channel_id']) + except DoesNotExist: + raise ValueError('Could not find a channel associated with the given comment') + else: + if not validate_signature_from_claim(channel, signature, signing_ts, comment_id): + raise ValueError('Abandon signature could not be validated') + + with app['db'].atomic(): + return { + 'abandoned': delete_comment(comment_id) + } -async def handle_hide_comments(app, params): - return {'hidden': await hide_comments(app, **params)} +async def handle_hide_comments(app: web.Application, pieces: list, hide: bool = True) -> dict: + # let's get all the distinct claim_ids from the list of comment_ids + pieces_by_id = {p['comment_id']: p for p in pieces} + comment_ids = list(pieces_by_id.keys()) + comments = (Comment + .select(Comment.comment_id, Comment.claim_id) + .where(Comment.comment_id.in_(comment_ids)) + .tuples()) + + # resolve the claims and map them to their corresponding comment_ids + claims = {} + for comment_id, claim_id in comments: + try: + # try and resolve the claim, if fails then we mark it as null + # and remove the associated comment from the pieces + if claim_id not in claims: + claims[claim_id] = await get_claim_from_id(app, claim_id) + + # try to get a public key to validate + if claims[claim_id] is None or 'signing_channel' not in claims[claim_id]: + raise ValueError(f'could not get signing channel from claim_id: {claim_id}') + + # try to validate signature + else: + channel = claims[claim_id]['signing_channel'] + piece = pieces_by_id[comment_id] + is_valid_signature = validate_signature_from_claim( + claim=channel, + signature=piece['signature'], + signing_ts=piece['signing_ts'], + data=piece['comment_id'] + ) + if not is_valid_signature: + raise ValueError(f'could not validate signature on comment_id: {comment_id}') + + except ValueError: + # remove the piece from being hidden + pieces_by_id.pop(comment_id) + + # remaining items in pieces_by_id have been able to successfully validate + with app['db'].atomic(): + set_hidden_flag(list(pieces_by_id.keys()), hidden=hide) + + query = Comment.select().where(Comment.comment_id.in_(comment_ids)).objects() + result = { + 'hidden': [c.comment_id for c in query if c.is_hidden], + 'visible': [c.comment_id for c in query if not c.is_hidden], + } + return result -async def handle_edit_comment(app, params): - if await edit_comment(app, **params): - return db.get_comment_or_none(app['reader'], params['comment_id']) +async def handle_edit_comment(app, comment: str = None, comment_id: str = None, + signature: str = None, signing_ts: str = None, **params) -> dict: + current = get_comment(comment_id) + channel_claim = await get_claim_from_id(app, current['channel_id']) + if not validate_signature_from_claim(channel_claim, signature, signing_ts, comment): + raise ValueError('Signature could not be validated') + + with app['db'].atomic(): + if not edit_comment(comment_id, comment, signature, signing_ts): + raise ValueError('Comment could not be edited') + return get_comment(comment_id) + + +# TODO: retrieve stake amounts for each channel & store in db +def handle_create_comment(app, comment: str = None, claim_id: str = None, + parent_id: str = None, channel_id: str = None, channel_name: str = None, + signature: str = None, signing_ts: str = None) -> dict: + with app['db'].atomic(): + return create_comment( + comment=comment, + claim_id=claim_id, + parent_id=parent_id, + channel_id=channel_id, + channel_name=channel_name, + signature=signature, + signing_ts=signing_ts + ) METHODS = { 'ping': ping, - 'get_claim_comments': handle_get_claim_comments, - 'get_claim_hidden_comments': handle_get_claim_hidden_comments, + 'get_claim_comments': handle_get_claim_comments, # this gets used + 'get_claim_hidden_comments': handle_get_claim_hidden_comments, # this gets used 'get_comment_ids': handle_get_comment_ids, - 'get_comments_by_id': handle_get_comments_by_id, - 'get_channel_from_comment_id': handle_get_channel_from_comment_id, - 'create_comment': create_comment, + 'get_comments_by_id': handle_get_comments_by_id, # this gets used + 'get_channel_from_comment_id': handle_get_channel_from_comment_id, # this gets used + 'create_comment': handle_create_comment, # this gets used 'delete_comment': handle_abandon_comment, - 'abandon_comment': handle_abandon_comment, - 'hide_comments': handle_hide_comments, - 'edit_comment': handle_edit_comment + 'abandon_comment': handle_abandon_comment, # this gets used + 'hide_comments': handle_hide_comments, # this gets used + 'edit_comment': handle_edit_comment # this gets used } @@ -78,17 +228,19 @@ async def process_json(app, body: dict) -> dict: start = time.time() try: if asyncio.iscoroutinefunction(METHODS[method]): - result = await METHODS[method](app, params) + result = await METHODS[method](app, **params) else: - result = METHODS[method](app, params) - response['result'] = result + result = METHODS[method](app, **params) + except Exception as err: - logger.exception(f'Got {type(err).__name__}:') + logger.exception(f'Got {type(err).__name__}:\n{err}') if type(err) in (ValueError, TypeError): # param error, not too important response['error'] = make_error('INVALID_PARAMS', err) else: response['error'] = make_error('INTERNAL', err) - await app['webhooks'].spawn(report_error(app, err, body)) + await app['webhooks'].spawn(report_error(app, err, body)) + else: + response['result'] = result finally: end = time.time() diff --git a/src/server/validation.py b/src/server/validation.py index ef1c841..3b251ba 100644 --- a/src/server/validation.py +++ b/src/server/validation.py @@ -51,11 +51,31 @@ def claim_id_is_valid(claim_id: str) -> bool: # default to None so params can be treated as kwargs; param count becomes more manageable -def is_valid_base_comment(comment: str = None, claim_id: str = None, parent_id: str = None, **kwargs) -> bool: - return comment and body_is_valid(comment) and \ - ((claim_id and claim_id_is_valid(claim_id)) or # parentid is used in place of claimid in replies - (parent_id and comment_id_is_valid(parent_id))) \ - and is_valid_credential_input(**kwargs) +def is_valid_base_comment( + comment: str = None, + claim_id: str = None, + parent_id: str = None, + strict: bool = False, + **kwargs, +) -> bool: + try: + assert comment and body_is_valid(comment) + # strict mode assumes that the parent_id might not exist + if strict: + assert claim_id and claim_id_is_valid(claim_id) + assert parent_id is None or comment_id_is_valid(parent_id) + # non-strict removes reference restrictions + else: + assert claim_id or parent_id + if claim_id: + assert claim_id_is_valid(claim_id) + else: + assert comment_id_is_valid(parent_id) + + except AssertionError: + return False + else: + return is_valid_credential_input(**kwargs) def is_valid_credential_input(channel_id: str = None, channel_name: str = None, diff --git a/src/settings.py b/src/settings.py deleted file mode 100644 index 2d720c2..0000000 --- a/src/settings.py +++ /dev/null @@ -1,17 +0,0 @@ -# cython: language_level=3 -import json -import pathlib - -root_dir = pathlib.Path(__file__).parent.parent -config_path = root_dir / 'config' / 'conf.json' - - -def get_config(filepath): - with open(filepath, 'r') as cfile: - conf = json.load(cfile) - for key, path in conf['path'].items(): - conf['path'][key] = str(root_dir / path) - return conf - - -config = get_config(config_path) diff --git a/test/test_database.py b/test/test_database.py index c43bb0a..c698ad1 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -1,18 +1,13 @@ -import sqlite3 - from random import randint import faker from faker.providers import internet from faker.providers import lorem from faker.providers import misc -from src.database.queries import get_comments_by_id -from src.database.queries import get_comment_ids -from src.database.queries import get_claim_comments -from src.database.queries import get_claim_hidden_comments -from src.database.writes import create_comment_or_error -from src.database.queries import hide_comments_by_id -from src.database.queries import delete_comment_by_id +from src.database.models import create_comment +from src.database.models import delete_comment +from src.database.models import comment_list, get_comment +from src.database.models import set_hidden_flag from test.testcase import DatabaseTestCase fake = faker.Faker() @@ -27,26 +22,25 @@ class TestDatabaseOperations(DatabaseTestCase): self.claimId = '529357c3422c6046d3fec76be2358004ba22e340' def test01NamedComments(self): - comment = create_comment_or_error( - conn=self.conn, + comment = create_comment( claim_id=self.claimId, comment='This is a named comment', channel_name='@username', channel_id='529357c3422c6046d3fec76be2358004ba22abcd', - signature=fake.uuid4(), + signature='22'*64, signing_ts='aaa' ) self.assertIsNotNone(comment) self.assertNotIn('parent_in', comment) + previous_id = comment['comment_id'] - reply = create_comment_or_error( - conn=self.conn, + reply = create_comment( claim_id=self.claimId, comment='This is a named response', channel_name='@another_username', channel_id='529357c3422c6046d3fec76be2358004ba224bcd', parent_id=previous_id, - signature=fake.uuid4(), + signature='11'*64, signing_ts='aaa' ) self.assertIsNotNone(reply) @@ -54,34 +48,32 @@ class TestDatabaseOperations(DatabaseTestCase): def test02AnonymousComments(self): self.assertRaises( - sqlite3.IntegrityError, - create_comment_or_error, - conn=self.conn, + ValueError, + create_comment, claim_id=self.claimId, comment='This is an ANONYMOUS comment' ) def test03SignedComments(self): - comment = create_comment_or_error( - conn=self.conn, + comment = create_comment( claim_id=self.claimId, comment='I like big butts and i cannot lie', channel_name='@sirmixalot', channel_id='529357c3422c6046d3fec76be2358005ba22abcd', - signature=fake.uuid4(), + signature='24'*64, signing_ts='asdasd' ) self.assertIsNotNone(comment) self.assertIn('signing_ts', comment) + previous_id = comment['comment_id'] - reply = create_comment_or_error( - conn=self.conn, + reply = create_comment( claim_id=self.claimId, comment='This is a LBRY verified response', channel_name='@LBRY', channel_id='529357c3422c6046d3fec76be2358001ba224bcd', parent_id=previous_id, - signature=fake.uuid4(), + signature='12'*64, signing_ts='sfdfdfds' ) self.assertIsNotNone(reply) @@ -90,75 +82,109 @@ class TestDatabaseOperations(DatabaseTestCase): def test04UsernameVariations(self): self.assertRaises( - AssertionError, - callable=create_comment_or_error, - conn=self.conn, + ValueError, + create_comment, claim_id=self.claimId, channel_name='$#(@#$@#$', channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this is an invalid username' + comment='this is an invalid username', + signature='1' * 128, + signing_ts='123' ) - valid_username = create_comment_or_error( - conn=self.conn, + + valid_username = create_comment( claim_id=self.claimId, channel_name='@' + 'a' * 255, channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this is a valid username' + comment='this is a valid username', + signature='1'*128, + signing_ts='123' ) self.assertIsNotNone(valid_username) - self.assertRaises(AssertionError, - callable=create_comment_or_error, - conn=self.conn, - claim_id=self.claimId, - channel_name='@' + 'a' * 256, - channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this username is too long' - ) self.assertRaises( - AssertionError, - callable=create_comment_or_error, - conn=self.conn, + ValueError, + create_comment, + claim_id=self.claimId, + channel_name='@' + 'a' * 256, + channel_id='529357c3422c6046d3fec76be2358001ba224b23', + comment='this username is too long', + signature='2' * 128, + signing_ts='123' + ) + + self.assertRaises( + ValueError, + create_comment, claim_id=self.claimId, channel_name='', channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this username should not default to ANONYMOUS' + comment='this username should not default to ANONYMOUS', + signature='3' * 128, + signing_ts='123' ) + self.assertRaises( - AssertionError, - callable=create_comment_or_error, - conn=self.conn, + ValueError, + create_comment, claim_id=self.claimId, channel_name='@', channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this username is too short' + comment='this username is too short', + signature='3' * 128, + signing_ts='123' ) def test05HideComments(self): - comm = create_comment_or_error(self.conn, 'Comment #1', self.claimId, '1'*40, '@Doge123', 'a'*128, '123') - comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop() + comm = create_comment( + comment='Comment #1', + claim_id=self.claimId, + channel_id='1'*40, + channel_name='@Doge123', + signature='a'*128, + signing_ts='123' + ) + comment = get_comment(comm['comment_id']) self.assertFalse(comment['is_hidden']) - success = hide_comments_by_id(self.conn, [comm['comment_id']]) + + success = set_hidden_flag([comm['comment_id']]) self.assertTrue(success) - comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop() + + comment = get_comment(comm['comment_id']) self.assertTrue(comment['is_hidden']) - success = hide_comments_by_id(self.conn, [comm['comment_id']]) + + success = set_hidden_flag([comm['comment_id']]) self.assertTrue(success) - comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop() + + comment = get_comment(comm['comment_id']) self.assertTrue(comment['is_hidden']) def test06DeleteComments(self): - comm = create_comment_or_error(self.conn, 'Comment #1', self.claimId, '1'*40, '@Doge123', 'a'*128, '123') - comments = get_claim_comments(self.conn, self.claimId) - match = list(filter(lambda x: comm['comment_id'] == x['comment_id'], comments['items'])) - self.assertTrue(match) - deleted = delete_comment_by_id(self.conn, comm['comment_id']) + # make sure that the comment was created + comm = create_comment( + comment='Comment #1', + claim_id=self.claimId, + channel_id='1'*40, + channel_name='@Doge123', + signature='a'*128, + signing_ts='123' + ) + comments = comment_list(self.claimId) + match = [x for x in comments['items'] if x['comment_id'] == comm['comment_id']] + self.assertTrue(len(match) > 0) + + deleted = delete_comment(comm['comment_id']) self.assertTrue(deleted) - comments = get_claim_comments(self.conn, self.claimId) - match = list(filter(lambda x: comm['comment_id'] == x['comment_id'], comments['items'])) + + # make sure that we can't find the comment here + comments = comment_list(self.claimId) + match = [x for x in comments['items'] if x['comment_id'] == comm['comment_id']] self.assertFalse(match) - deleted = delete_comment_by_id(self.conn, comm['comment_id']) - self.assertFalse(deleted) + self.assertRaises( + ValueError, + delete_comment, + comment_id=comm['comment_id'], + ) class ListDatabaseTest(DatabaseTestCase): @@ -169,61 +195,75 @@ class ListDatabaseTest(DatabaseTestCase): def testLists(self): for claim_id in self.claim_ids: with self.subTest(claim_id=claim_id): - comments = get_claim_comments(self.conn, claim_id) + comments = comment_list(claim_id) self.assertIsNotNone(comments) self.assertGreater(comments['page_size'], 0) self.assertIn('has_hidden_comments', comments) self.assertFalse(comments['has_hidden_comments']) - top_comments = get_claim_comments(self.conn, claim_id, top_level=True, page=1, page_size=50) + top_comments = comment_list(claim_id, top_level=True, page=1, page_size=50) self.assertIsNotNone(top_comments) self.assertEqual(top_comments['page_size'], 50) self.assertEqual(top_comments['page'], 1) self.assertGreaterEqual(top_comments['total_pages'], 0) self.assertGreaterEqual(top_comments['total_items'], 0) - comment_ids = get_comment_ids(self.conn, claim_id, page_size=50, page=1) + comment_ids = comment_list(claim_id, page_size=50, page=1) with self.subTest(comment_ids=comment_ids): self.assertIsNotNone(comment_ids) self.assertLessEqual(len(comment_ids), 50) - matching_comments = get_comments_by_id(self.conn, comment_ids) + matching_comments = (comment_ids) self.assertIsNotNone(matching_comments) self.assertEqual(len(matching_comments), len(comment_ids)) def testHiddenCommentLists(self): claim_id = 'a'*40 - comm1 = create_comment_or_error(self.conn, 'Comment #1', claim_id, '1'*40, '@Doge123', 'a'*128, '123') - comm2 = create_comment_or_error(self.conn, 'Comment #2', claim_id, '1'*40, '@Doge123', 'b'*128, '123') - comm3 = create_comment_or_error(self.conn, 'Comment #3', claim_id, '1'*40, '@Doge123', 'c'*128, '123') + comm1 = create_comment( + 'Comment #1', + claim_id, + channel_id='1'*40, + channel_name='@Doge123', + signature='a'*128, + signing_ts='123' + ) + comm2 = create_comment( + 'Comment #2', claim_id, + channel_id='1'*40, + channel_name='@Doge123', + signature='b'*128, + signing_ts='123' + ) + comm3 = create_comment( + 'Comment #3', claim_id, + channel_id='1'*40, + channel_name='@Doge123', + signature='c'*128, + signing_ts='123' + ) comments = [comm1, comm2, comm3] - comment_list = get_claim_comments(self.conn, claim_id) - self.assertIn('items', comment_list) - self.assertIn('has_hidden_comments', comment_list) - self.assertEqual(len(comments), comment_list['total_items']) - self.assertIn('has_hidden_comments', comment_list) - self.assertFalse(comment_list['has_hidden_comments']) - hide_comments_by_id(self.conn, [comm2['comment_id']]) + listed_comments = comment_list(claim_id) + self.assertEqual(len(comments), listed_comments['total_items']) + self.assertFalse(listed_comments['has_hidden_comments']) - default_comments = get_claim_hidden_comments(self.conn, claim_id) - self.assertIn('has_hidden_comments', default_comments) + set_hidden_flag([comm2['comment_id']]) + hidden = comment_list(claim_id, exclude_mode='hidden') - hidden_comments = get_claim_hidden_comments(self.conn, claim_id, hidden=True) - self.assertIn('has_hidden_comments', hidden_comments) - self.assertEqual(default_comments, hidden_comments) + self.assertTrue(hidden['has_hidden_comments']) + self.assertGreater(len(hidden['items']), 0) - hidden_comment = hidden_comments['items'][0] + visible = comment_list(claim_id, exclude_mode='visible') + self.assertFalse(visible['has_hidden_comments']) + self.assertNotEqual(listed_comments['items'], visible['items']) + + # make sure the hidden comment is the one we marked as hidden + hidden_comment = hidden['items'][0] self.assertEqual(hidden_comment['comment_id'], comm2['comment_id']) - visible_comments = get_claim_hidden_comments(self.conn, claim_id, hidden=False) - self.assertIn('has_hidden_comments', visible_comments) - self.assertNotIn(hidden_comment, visible_comments['items']) - - hidden_ids = [c['comment_id'] for c in hidden_comments['items']] - visible_ids = [c['comment_id'] for c in visible_comments['items']] + hidden_ids = [c['comment_id'] for c in hidden['items']] + visible_ids = [c['comment_id'] for c in visible['items']] composite_ids = hidden_ids + visible_ids + listed_comments = comment_list(claim_id) + all_ids = [c['comment_id'] for c in listed_comments['items']] composite_ids.sort() - - comment_list = get_claim_comments(self.conn, claim_id) - all_ids = [c['comment_id'] for c in comment_list['items']] all_ids.sort() self.assertEqual(composite_ids, all_ids) diff --git a/test/test_server.py b/test/test_server.py index 5cd9ea6..56bccfb 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -9,13 +9,18 @@ from faker.providers import internet from faker.providers import lorem from faker.providers import misc -from src.settings import config +from src.main import get_config, CONFIG_FILE from src.server import app from src.server.validation import is_valid_base_comment from test.testcase import AsyncioTestCase +config = get_config(CONFIG_FILE) +config['mode'] = 'testing' +config['testing']['file'] = ':memory:' + + if 'slack_webhook' in config: config.pop('slack_webhook') @@ -71,10 +76,10 @@ def create_test_comments(values: iter, **default): class ServerTest(AsyncioTestCase): - db_file = 'test.db' - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + config['mode'] = 'testing' + config['testing']['file'] = ':memory:' self.host = 'localhost' self.port = 5931 @@ -85,11 +90,10 @@ class ServerTest(AsyncioTestCase): @classmethod def tearDownClass(cls) -> None: print('exit reached') - os.remove(cls.db_file) async def asyncSetUp(self): await super().asyncSetUp() - self.server = app.CommentDaemon(config, db_file=self.db_file) + self.server = app.CommentDaemon(config) await self.server.start(host=self.host, port=self.port) self.addCleanup(self.server.stop) @@ -135,14 +139,16 @@ class ServerTest(AsyncioTestCase): test_all = create_test_comments(replace.keys(), **{ k: None for k in replace.keys() }) + test_all.reverse() for test in test_all: - with self.subTest(test=test): + nulls = 'null fields: ' + ', '.join(k for k, v in test.items() if not v) + with self.subTest(test=nulls): message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: - self.assertFalse(is_valid_base_comment(**test)) + self.assertFalse(is_valid_base_comment(**test, strict=True)) else: - self.assertTrue(is_valid_base_comment(**test)) + self.assertTrue(is_valid_base_comment(**test, strict=True)) async def test04CreateAllReplies(self): claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' @@ -220,7 +226,8 @@ class ListCommentsTest(AsyncioTestCase): super().__init__(*args, **kwargs) self.host = 'localhost' self.port = 5931 - self.db_file = 'list_test.db' + config['mode'] = 'testing' + config['testing']['file'] = ':memory:' self.claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' self.comment_ids = None @@ -231,10 +238,6 @@ class ListCommentsTest(AsyncioTestCase): async def post_comment(self, **params): return await jsonrpc_post(self.url, 'create_comment', **params) - def tearDown(self) -> None: - print('exit reached') - os.remove(self.db_file) - async def create_lots_of_comments(self, n=23): self.comment_list = [{key: self.replace[key]() for key in self.replace.keys()} for _ in range(23)] for comment in self.comment_list: @@ -244,7 +247,7 @@ class ListCommentsTest(AsyncioTestCase): async def asyncSetUp(self): await super().asyncSetUp() - self.server = app.CommentDaemon(config, db_file=self.db_file) + self.server = app.CommentDaemon(config) await self.server.start(self.host, self.port) self.addCleanup(self.server.stop) diff --git a/test/testcase.py b/test/testcase.py index ddd8b44..415041d 100644 --- a/test/testcase.py +++ b/test/testcase.py @@ -1,12 +1,38 @@ import os import pathlib import unittest -from asyncio.runners import _cancel_all_tasks # type: ignore from unittest.case import _Outcome import asyncio +from asyncio.runners import _cancel_all_tasks # type: ignore +from peewee import * -from src.database.queries import obtain_connection, setup_database +from src.database.models import Channel, Comment + + +test_db = SqliteDatabase(':memory:') + + +MODELS = [Channel, Comment] + + +class DatabaseTestCase(unittest.TestCase): + def __init__(self, methodName='DatabaseTest'): + super().__init__(methodName) + + def setUp(self) -> None: + super().setUp() + test_db.bind(MODELS, bind_refs=False, bind_backrefs=False) + + test_db.connect() + test_db.create_tables(MODELS) + + def tearDown(self) -> None: + # drop tables for next test + test_db.drop_tables(MODELS) + + # close connection + test_db.close() class AsyncioTestCase(unittest.TestCase): @@ -117,21 +143,3 @@ class AsyncioTestCase(unittest.TestCase): self.loop.run_until_complete(maybe_coroutine) -class DatabaseTestCase(unittest.TestCase): - db_file = 'test.db' - - def __init__(self, methodName='DatabaseTest'): - super().__init__(methodName) - if pathlib.Path(self.db_file).exists(): - os.remove(self.db_file) - - def setUp(self) -> None: - super().setUp() - setup_database(self.db_file) - self.conn = obtain_connection(self.db_file) - self.addCleanup(self.conn.close) - self.addCleanup(os.remove, self.db_file) - - def tearDown(self) -> None: - self.conn.close() -