diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 54e4315..0000000 --- a/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -config/conf.yml -docker-compose.yml - diff --git a/.travis.yml b/.travis.yml index fe9a4de..ac2c250 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,28 +1,7 @@ sudo: required language: python dist: xenial -python: 3.8 - -# for docker-compose -services: - - docker - -# to avoid "well it works on my computer" moments -env: - - DOCKER_COMPOSE_VERSION=1.25.4 - -before_install: - # ensure docker-compose version is as specified above - - sudo rm /usr/local/bin/docker-compose - - curl -L https://github.com/docker/compose/releases/download/${DOCKER_COMPOSE_VERSION}/docker-compose-`uname -s`-`uname -m` > docker-compose - - chmod +x docker-compose - - sudo mv docker-compose /usr/local/bin - # refresh docker images - - sudo apt-get update - -before_script: - - docker-compose up -d - +python: 3.7 jobs: include: @@ -30,5 +9,6 @@ jobs: name: "Unit Tests" install: - pip install -e . + - mkdir database script: - python -m unittest diff --git a/README.md b/README.md index cdeef1d..471f552 100644 --- a/README.md +++ b/README.md @@ -3,36 +3,34 @@ [](https://travis-ci.com/lbryio/comment-server) [](https://codeclimate.com/github/lbryio/comment-server/maintainability) +This is the code for the LBRY Comment Server. +Fork it, run it, set it on fire. Up to you. + ## Before Installing -Install the [`lbry-sdk`](https://github.com/lbryio/lbry-sdk) +Comment Deletion requires having the [`lbry-sdk`](https://github.com/lbryio/lbry-sdk) in order to validate & properly delete comments. - + + ## Installation #### Installing the server: ```bash -$ git clone https://github.com/lbryio/comment-server +$ git clone https://github.com/osilkin98/comment-server $ cd comment-server # create a virtual environment -$ virtualenv --python=python3.8 venv +$ virtualenv --python=python3 venv # Enter the virtual environment $ source venv/bin/activate -# Install required dependencies -(venv) $ pip install -e . - -# Run the server -(venv) $ python src/main.py \ - --port=5921 \ # use a different port besides the default - --config=conf.yml \ # provide a custom config file - & \ # detach and run the service in the background +# install the Server as a Executable Target +(venv) $ python setup.py develop ``` ### Installing the systemd Service Monitor @@ -72,11 +70,16 @@ To Test the database, simply run: There are basic tests to run against the server, though they require that there is a server instance running, though the database - chosen may have to be edited in `config/conf.yml`. + chosen may have to be edited in `config/conf.json`. Additionally there are HTTP requests that can be send with whatever software you choose to test the integrity of the comment server. +## Schema + + + + ## Contributing Contributions are welcome, verbosity is encouraged. Please be considerate diff --git a/config/conf.yml b/config/conf.yml deleted file mode 100644 index fc22322..0000000 --- a/config/conf.yml +++ /dev/null @@ -1,31 +0,0 @@ ---- -# for running local-tests without using MySQL for now -testing: - database: sqlite - file: comments.db - pragmas: - journal_mode: wal - cache_size: 64000 - foreign_keys: 0 - ignore_check_constraints: 1 - synchronous: 0 - -# actual database should be running MySQL -production: - charset: utf8mb4 - database: mysql - name: social - user: lbry - password: lbry - host: localhost - port: 3306 - -mode: production -logging: - format: "%(asctime)s | %(levelname)s | %(name)s | %(module)s.%(funcName)s:%(lineno)d - | %(message)s" - aiohttp_format: "%(asctime)s | %(levelname)s | %(name)s | %(message)s" - datefmt: "%Y-%m-%d %H:%M:%S" -host: localhost -port: 5921 -lbrynet: http://localhost:5279 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index e5f9099..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,25 +0,0 @@ -version: "3.7" -services: - ########### - ## MySQL ## - ########### - mysql: - image: mysql/mysql-server:5.7.27 - restart: "no" - command: --character_set_server=utf8mb4 --max_allowed_packet=1073741824 - ports: - - "3306:3306" - environment: - - MYSQL_ALLOW_EMPTY_PASSWORD=true - - MYSQL_DATABASE=social - - MYSQL_USER=lbry - - MYSQL_PASSWORD=lbry - - MYSQL_LOG_CONSOLE=true - ############# - ## Adminer ## - ############# - adminer: - image: adminer - restart: always - ports: - - 8080:8080 diff --git a/scripts/stress_test.py b/scripts/stress_test.py new file mode 100644 index 0000000..6022396 --- /dev/null +++ b/scripts/stress_test.py @@ -0,0 +1,84 @@ +import sqlite3 +import time + +import faker +from faker.providers import misc + +fake = faker.Faker() +fake.add_provider(misc) + + +if __name__ == '__main__': + song_time = """One, two, three! +My baby don't mess around +'Cause she loves me so +This I know fo sho! +But does she really wanna +But can't stand to see me walk out tha door +Don't try to fight the feeling +Because the thought alone is killin' me right now +Thank God for Mom and Dad +For sticking to together +Like we don't know how +Hey ya! Hey ya! +Hey ya! Hey ya! +Hey ya! Hey ya! +Hey ya! Hey ya! +You think you've got it +Oh, you think you've got it +But got it just don't get it when there's nothin' at all +We get together +Oh, we get together +But separate's always better when there's feelings involved +Know what they say -its +Nothing lasts forever! +Then what makes it, then what makes it +Then what makes it, then what makes it +Then what makes love the exception? +So why, oh, why, oh +Why, oh, why, oh, why, oh +Are we still in denial when we know we're not happy here +Hey ya! (y'all don't want to here me, ya just want to dance) Hey ya! +Don't want to meet your daddy (oh ohh), just want you in my caddy (oh ohh) +Hey ya! (oh, oh!) Hey ya! (oh, oh!) +Don't want to meet your momma, just want to make you cum-a (oh, oh!) +I'm (oh, oh) I'm (oh, oh) I'm just being honest! (oh, oh) +I'm just being honest! +Hey! alright now! alright now, fellas! +Yea? +Now, what cooler than being cool? +Ice cold! +I can't hear ya! I say what's, what's cooler than being cool? +Ice cold! +Alright alright alright alright alright alright alright alright alright alright alright alright alright alright alright alright! +Okay, now ladies! +Yea? +Now we gonna break this thang down for just a few seconds +Now don't have me break this thang down for nothin' +I want to see you on your badest behavior! +Lend me some sugar, I am your neighbor! +Ah! Here we go now, +Shake it, shake it, shake it, shake it, shake it +Shake it, shake it, shake it, shake it +Shake it like a Polaroid picture! Hey ya! +Shake it, shake it, shake it, shake it, shake it +Shake it, shake it, shake it, suga! +Shake it like a Polaroid picture! +Now all the Beyonce's, and Lucy Lu's, and baby dolls +Get on tha floor get on tha floor! +Shake it like a Polaroid picture! +Oh, you! oh, you! +Hey ya!(oh, oh) Hey ya!(oh, oh) +Hey ya!(oh, oh) Hey ya!(oh, oh) +Hey ya!(oh, oh) Hey ya!(oh, oh)""" + + song = song_time.split('\n') + claim_id = '2aa106927b733e2602ffb565efaccc78c2ed89df' + run_len = [(fake.sha256(), song_time, claim_id, str(int(time.time()))) for k in range(5000)] + + conn = sqlite3.connect('database/default_test.db') + with conn: + curs = conn.executemany(""" + INSERT INTO COMMENT(CommentId, Body, LbryClaimId, Timestamp) VALUES (?, ?, ?, ?) + """, run_len) + print(f'rows changed: {curs.rowcount}') diff --git a/scripts/valid_signatures.py b/scripts/valid_signatures.py index 4b7c641..19fc878 100644 --- a/scripts/valid_signatures.py +++ b/scripts/valid_signatures.py @@ -2,12 +2,11 @@ import binascii import logging import hashlib import json -# todo: remove sqlite3 as a dependency import sqlite3 import asyncio import aiohttp -from src.server.validation import is_signature_valid, get_encoded_signature +from server.validation import is_signature_valid, get_encoded_signature logger = logging.getLogger(__name__) diff --git a/setup.py b/setup.py index f7a397f..923e82a 100644 --- a/setup.py +++ b/setup.py @@ -14,17 +14,15 @@ setup( data_files=[('config', ['config/conf.json',])], include_package_data=True, install_requires=[ - 'pymysql', - 'pyyaml', 'Faker>=1.0.7', - 'asyncio', - 'aiohttp', - 'aiojobs', + 'asyncio>=3.4.3', + 'aiohttp==3.5.4', + 'aiojobs==0.2.2', 'ecdsa>=0.13.3', 'cryptography==2.5', + 'aiosqlite==0.10.0', 'PyNaCl>=1.3.0', 'requests', 'cython', - 'peewee' ] ) diff --git a/src/database/comments_ddl.sql b/src/database/comments_ddl.sql index f488406..2dfc19d 100644 --- a/src/database/comments_ddl.sql +++ b/src/database/comments_ddl.sql @@ -1,50 +1,76 @@ -USE `social`; -ALTER DATABASE `social` - DEFAULT CHARACTER SET utf8mb4 - DEFAULT COLLATE utf8mb4_unicode_ci; +PRAGMA FOREIGN_KEYS = ON; -DROP TABLE IF EXISTS `CHANNEL`; -CREATE TABLE `CHANNEL` ( - `claimid` VARCHAR(40) NOT NULL, - `name` CHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, - CONSTRAINT `channel_pk` PRIMARY KEY (`claimid`) - ) -CHARACTER SET utf8mb4 -COLLATE utf8mb4_unicode_ci; +-- Although I know this file is unnecessary, I like keeping it around. -DROP TABLE IF EXISTS `COMMENT`; -CREATE TABLE `COMMENT` ( - -- should be changed to CHAR(64) - `commentid` CHAR(64) NOT NULL, - -- should be changed to CHAR(40) - `lbryclaimid` CHAR(40) NOT NULL, - -- can be null, so idk if this should be char(40) - `channelid` CHAR(40) DEFAULT NULL, - `body` TEXT - CHARACTER SET utf8mb4 - COLLATE utf8mb4_unicode_ci - NOT NULL, - `parentid` CHAR(64) DEFAULT NULL, - `signature` CHAR(128) DEFAULT NULL, - -- 22 chars long is prolly enough - `signingts` VARCHAR(22) DEFAULT NULL, +-- I'm not gonna remove it. - `timestamp` INTEGER NOT NULL, - -- there's no way that the timestamp will ever reach 22 characters - `ishidden` BOOLEAN DEFAULT FALSE, - CONSTRAINT `COMMENT_PRIMARY_KEY` PRIMARY KEY (`commentid`) - -- setting null implies comment is top level - ) -CHARACTER SET utf8mb4 -COLLATE utf8mb4_unicode_ci; +-- tables +CREATE TABLE IF NOT EXISTS COMMENT +( + CommentId TEXT NOT NULL, + LbryClaimId TEXT NOT NULL, + ChannelId TEXT DEFAULT NULL, + Body TEXT NOT NULL, + ParentId TEXT DEFAULT NULL, + Signature TEXT DEFAULT NULL, + Timestamp INTEGER NOT NULL, + SigningTs TEXT DEFAULT NULL, + IsHidden BOOLEAN NOT NULL DEFAULT FALSE, + CONSTRAINT COMMENT_PRIMARY_KEY PRIMARY KEY (CommentId) ON CONFLICT IGNORE, + CONSTRAINT COMMENT_SIGNATURE_SK UNIQUE (Signature) ON CONFLICT ABORT, + CONSTRAINT COMMENT_CHANNEL_FK FOREIGN KEY (ChannelId) REFERENCES CHANNEL (ClaimId) + ON DELETE NO ACTION ON UPDATE NO ACTION, + CONSTRAINT COMMENT_PARENT_FK FOREIGN KEY (ParentId) REFERENCES COMMENT (CommentId) + ON UPDATE CASCADE ON DELETE NO ACTION -- setting null implies comment is top level +); + +-- ALTER TABLE COMMENT ADD COLUMN IsHidden BOOLEAN DEFAULT (FALSE); +-- ALTER TABLE COMMENT ADD COLUMN SigningTs TEXT DEFAULT NULL; + +-- DROP TABLE IF EXISTS CHANNEL; +CREATE TABLE IF NOT EXISTS CHANNEL +( + ClaimId TEXT NOT NULL, + Name TEXT NOT NULL, + CONSTRAINT CHANNEL_PK PRIMARY KEY (ClaimId) + ON CONFLICT IGNORE +); -ALTER TABLE COMMENT - ADD CONSTRAINT `comment_channel_fk` FOREIGN KEY (`channelid`) REFERENCES `CHANNEL` (`claimid`) - ON DELETE CASCADE ON UPDATE CASCADE, - ADD CONSTRAINT `comment_parent_fk` FOREIGN KEY (`parentid`) REFERENCES `COMMENT` (`commentid`) - ON UPDATE CASCADE ON DELETE CASCADE -; +-- indexes +-- DROP INDEX IF EXISTS COMMENT_CLAIM_INDEX; +-- CREATE INDEX IF NOT EXISTS CLAIM_COMMENT_INDEX ON COMMENT (LbryClaimId, CommentId); -CREATE INDEX `claim_comment_index` ON `COMMENT` (`lbryclaimid`, `commentid`); -CREATE INDEX `channel_comment_index` ON `COMMENT` (`channelid`, `commentid`); +-- CREATE INDEX IF NOT EXISTS CHANNEL_COMMENT_INDEX ON COMMENT (ChannelId, CommentId); + +-- VIEWS +CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS AS +SELECT C.CommentId AS comment_id, + C.Body AS comment, + C.LbryClaimId AS claim_id, + C.Timestamp AS timestamp, + CHAN.Name AS channel_name, + CHAN.ClaimId AS channel_id, + ('lbry://' || CHAN.Name || '#' || CHAN.ClaimId) AS channel_url, + C.Signature AS signature, + C.SigningTs AS signing_ts, + C.ParentId AS parent_id, + C.IsHidden AS is_hidden +FROM COMMENT AS C + LEFT OUTER JOIN CHANNEL CHAN ON C.ChannelId = CHAN.ClaimId +ORDER BY C.Timestamp DESC; + + +DROP VIEW IF EXISTS COMMENT_REPLIES; +CREATE VIEW IF NOT EXISTS COMMENT_REPLIES (Author, CommentBody, ParentAuthor, ParentCommentBody) AS +SELECT AUTHOR.Name, OG.Body, PCHAN.Name, PARENT.Body +FROM COMMENT AS OG + JOIN COMMENT AS PARENT + ON OG.ParentId = PARENT.CommentId + JOIN CHANNEL AS PCHAN ON PARENT.ChannelId = PCHAN.ClaimId + JOIN CHANNEL AS AUTHOR ON OG.ChannelId = AUTHOR.ClaimId +ORDER BY OG.Timestamp; + +-- this is the default channel for anyone who wants to publish anonymously +-- INSERT INTO CHANNEL +-- VALUES ('9cb713f01bf247a0e03170b5ed00d5161340c486', '@Anonymous'); diff --git a/src/database/models.py b/src/database/models.py deleted file mode 100644 index b35f122..0000000 --- a/src/database/models.py +++ /dev/null @@ -1,217 +0,0 @@ -import json -import time - -import logging -import math -import typing - -from peewee import * -import nacl.hash - -from src.server.validation import is_valid_base_comment -from src.misc import clean - - -class Channel(Model): - claim_id = FixedCharField(column_name='claimid', primary_key=True, max_length=40) - name = CharField(column_name='name', max_length=255) - - class Meta: - table_name = 'CHANNEL' - - -class Comment(Model): - comment = TextField(column_name='body') - channel = ForeignKeyField( - backref='comments', - column_name='channelid', - field='claim_id', - model=Channel, - null=True - ) - comment_id = FixedCharField(column_name='commentid', primary_key=True, max_length=64) - is_hidden = BooleanField(column_name='ishidden', constraints=[SQL("DEFAULT 0")]) - claim_id = FixedCharField(max_length=40, column_name='lbryclaimid') - parent = ForeignKeyField( - column_name='ParentId', - field='comment_id', - model='self', - null=True, - backref='replies' - ) - signature = FixedCharField(max_length=128, column_name='signature', null=True, unique=True) - signing_ts = TextField(column_name='signingts', null=True) - timestamp = IntegerField(column_name='timestamp') - - class Meta: - table_name = 'COMMENT' - indexes = ( - (('channel', 'comment_id'), False), - (('claim_id', 'comment_id'), False), - ) - - -FIELDS = { - 'comment': Comment.comment, - 'comment_id': Comment.comment_id, - 'claim_id': Comment.claim_id, - 'timestamp': Comment.timestamp, - 'signature': Comment.signature, - 'signing_ts': Comment.signing_ts, - 'is_hidden': Comment.is_hidden, - 'parent_id': Comment.parent.alias('parent_id'), - 'channel_id': Channel.claim_id.alias('channel_id'), - 'channel_name': Channel.name.alias('channel_name'), - 'channel_url': ('lbry://' + Channel.name + '#' + Channel.claim_id).alias('channel_url') -} - - -def comment_list(claim_id: str = None, parent_id: str = None, - top_level: bool = False, exclude_mode: str = None, - page: int = 1, page_size: int = 50, expressions=None, - select_fields: list = None, exclude_fields: list = None) -> dict: - fields = FIELDS.keys() - if exclude_fields: - fields -= set(exclude_fields) - if select_fields: - fields &= set(select_fields) - attributes = [FIELDS[field] for field in fields] - query = Comment.select(*attributes) - - # todo: allow this process to be more automated, so it can just be an expression - if claim_id: - query = query.where(Comment.claim_id == claim_id) - if top_level: - query = query.where(Comment.parent.is_null()) - - if parent_id: - query = query.where(Comment.ParentId == parent_id) - - if exclude_mode: - show_hidden = exclude_mode.lower() == 'hidden' - query = query.where((Comment.is_hidden == show_hidden)) - - if expressions: - query = query.where(expressions) - - total = query.count() - query = (query - .join(Channel, JOIN.LEFT_OUTER) - .order_by(Comment.timestamp.desc()) - .paginate(page, page_size)) - items = [clean(item) for item in query.dicts()] - # has_hidden_comments is deprecated - data = { - 'page': page, - 'page_size': page_size, - 'total_pages': math.ceil(total / page_size), - 'total_items': total, - 'items': items, - 'has_hidden_comments': exclude_mode is not None and exclude_mode == 'hidden', - } - return data - - -def get_comment(comment_id: str) -> dict: - try: - comment = comment_list(expressions=(Comment.comment_id == comment_id), page_size=1).get('items').pop() - except IndexError: - raise ValueError(f'Comment does not exist with id {comment_id}') - else: - return comment - - -def create_comment_id(comment: str, channel_id: str, timestamp: int): - # We convert the timestamp from seconds into minutes - # to prevent spammers from commenting the same BS everywhere. - nearest_minute = str(math.floor(timestamp / 60)) - - # don't use claim_id for the comment_id anymore so comments - # are not unique to just one claim - prehash = b':'.join([ - comment.encode(), - channel_id.encode(), - nearest_minute.encode() - ]) - return nacl.hash.sha256(prehash).decode() - - -def create_comment(comment: str = None, claim_id: str = None, - parent_id: str = None, channel_id: str = None, - channel_name: str = None, signature: str = None, - signing_ts: str = None) -> dict: - if not is_valid_base_comment( - comment=comment, - claim_id=claim_id, - parent_id=parent_id, - channel_id=channel_id, - channel_name=channel_name, - signature=signature, - signing_ts=signing_ts - ): - raise ValueError('Invalid Parameters given for comment') - - channel, _ = Channel.get_or_create(name=channel_name, claim_id=channel_id) - if parent_id and not claim_id: - parent: Comment = Comment.get_by_id(parent_id) - claim_id = parent.claim_id - - timestamp = int(time.time()) - comment_id = create_comment_id(comment, channel_id, timestamp) - new_comment = Comment.create( - claim_id=claim_id, - comment_id=comment_id, - comment=comment, - parent=parent_id, - channel=channel, - signature=signature, - signing_ts=signing_ts, - timestamp=timestamp - ) - return get_comment(new_comment.comment_id) - - -def delete_comment(comment_id: str) -> bool: - try: - comment: Comment = Comment.get_by_id(comment_id) - except DoesNotExist as e: - raise ValueError from e - else: - return 0 < comment.delete_instance(True, delete_nullable=True) - - -def edit_comment(comment_id: str, new_comment: str, new_sig: str, new_ts: str) -> bool: - try: - comment: Comment = Comment.get_by_id(comment_id) - except DoesNotExist as e: - raise ValueError from e - else: - comment.comment = new_comment - comment.signature = new_sig - comment.signing_ts = new_ts - - # todo: add a 'last-modified' timestamp - comment.timestamp = int(time.time()) - return comment.save() > 0 - - -def set_hidden_flag(comment_ids: typing.List[str], hidden=True) -> bool: - # sets `is_hidden` flag for all `comment_ids` to the `hidden` param - update = (Comment - .update(is_hidden=hidden) - .where(Comment.comment_id.in_(comment_ids))) - return update.execute() > 0 - - -if __name__ == '__main__': - logger = logging.getLogger('peewee') - logger.addHandler(logging.StreamHandler()) - logger.setLevel(logging.DEBUG) - - comments = comment_list( - page_size=20, - expressions=((Comment.timestamp < 1583272089) & - (Comment.claim_id ** '420%')) - ) - - print(json.dumps(comments, indent=4)) diff --git a/src/database/queries.py b/src/database/queries.py new file mode 100644 index 0000000..16e6f0b --- /dev/null +++ b/src/database/queries.py @@ -0,0 +1,296 @@ +import atexit +import logging +import math +import sqlite3 +import time +import typing + +import nacl.hash + +from src.database.schema import CREATE_TABLES_QUERY + +logger = logging.getLogger(__name__) + +SELECT_COMMENTS_ON_CLAIMS = """ + SELECT comment, comment_id, channel_name, channel_id, channel_url, + timestamp, signature, signing_ts, parent_id, is_hidden + FROM COMMENTS_ON_CLAIMS +""" + +SELECT_COMMENTS_ON_CLAIMS_CLAIMID = """ + SELECT comment, comment_id, claim_id, channel_name, channel_id, channel_url, + timestamp, signature, signing_ts, parent_id, is_hidden + FROM COMMENTS_ON_CLAIMS +""" + + +def clean(thing: dict) -> dict: + if 'is_hidden' in thing: + thing.update({'is_hidden': bool(thing['is_hidden'])}) + return {k: v for k, v in thing.items() if v is not None} + + +def obtain_connection(filepath: str = None, row_factory: bool = True): + connection = sqlite3.connect(filepath) + if row_factory: + connection.row_factory = sqlite3.Row + return connection + + +def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, + page: int = 1, page_size: int = 50, top_level=False): + with conn: + if top_level: + results = [clean(dict(row)) for row in conn.execute( + SELECT_COMMENTS_ON_CLAIMS + " WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ?", + (claim_id, page_size, page_size * (page - 1)) + )] + count = conn.execute( + "SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id IS NULL", + (claim_id,) + ) + elif parent_id is None: + results = [clean(dict(row)) for row in conn.execute( + SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? LIMIT ? OFFSET ? ", + (claim_id, page_size, page_size * (page - 1)) + )] + count = conn.execute( + "SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ?", + (claim_id,) + ) + else: + results = [clean(dict(row)) for row in conn.execute( + SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ? ", + (claim_id, parent_id, page_size, page_size * (page - 1)) + )] + count = conn.execute( + "SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id = ?", + (claim_id, parent_id) + ) + count = tuple(count.fetchone())[0] + return { + 'items': results, + 'page': page, + 'page_size': page_size, + 'total_pages': math.ceil(count / page_size), + 'total_items': count, + 'has_hidden_comments': claim_has_hidden_comments(conn, claim_id) + } + + +def get_claim_hidden_comments(conn: sqlite3.Connection, claim_id: str, hidden=True, page=1, page_size=50): + with conn: + results = conn.execute( + SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? AND is_hidden IS ? LIMIT ? OFFSET ?", + (claim_id, hidden, page_size, page_size * (page - 1)) + ) + count = conn.execute( + "SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND is_hidden IS ?", (claim_id, hidden) + ) + results = [clean(dict(row)) for row in results.fetchall()] + count = tuple(count.fetchone())[0] + + return { + 'items': results, + 'page': page, + 'page_size': page_size, + 'total_pages': math.ceil(count / page_size), + 'total_items': count, + 'has_hidden_comments': claim_has_hidden_comments(conn, claim_id) + } + + +def claim_has_hidden_comments(conn, claim_id): + with conn: + result = conn.execute( + "SELECT COUNT(DISTINCT is_hidden) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND is_hidden IS 1", + (claim_id,) + ) + return bool(tuple(result.fetchone())[0]) + + +def insert_comment(conn: sqlite3.Connection, claim_id: str, comment: str, + channel_id: str = None, signature: str = None, signing_ts: str = None, **extra) -> str: + timestamp = int(time.time()) + prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),)) + comment_id = nacl.hash.sha256(prehash).decode() + with conn: + curs = conn.execute( + """ + INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body, ParentId, + Timestamp, Signature, SigningTs, IsHidden) + VALUES (:comment_id, :claim_id, :channel_id, :comment, NULL, + :timestamp, :signature, :signing_ts, 0) """, + { + 'comment_id': comment_id, + 'claim_id': claim_id, + 'channel_id': channel_id, + 'comment': comment, + 'timestamp': timestamp, + 'signature': signature, + 'signing_ts': signing_ts + } + ) + logging.info('attempted to insert comment with comment_id [%s] | %d rows affected', comment_id, curs.rowcount) + return comment_id + + +def insert_reply(conn: sqlite3.Connection, comment: str, parent_id: str, + channel_id: str = None, signature: str = None, + signing_ts: str = None, **extra) -> str: + timestamp = int(time.time()) + prehash = b':'.join((parent_id.encode(), comment.encode(), str(timestamp).encode(),)) + comment_id = nacl.hash.sha256(prehash).decode() + with conn: + curs = conn.execute( + """ + INSERT INTO COMMENT + (CommentId, LbryClaimId, ChannelId, Body, ParentId, Signature, Timestamp, SigningTs, IsHidden) + SELECT :comment_id, LbryClaimId, :channel_id, :comment, :parent_id, :signature, :timestamp, :signing_ts, 0 + FROM COMMENT WHERE CommentId = :parent_id + """, { + 'comment_id': comment_id, + 'parent_id': parent_id, + 'timestamp': timestamp, + 'comment': comment, + 'channel_id': channel_id, + 'signature': signature, + 'signing_ts': signing_ts + } + ) + logging.info('attempted to insert reply with comment_id [%s] | %d rows affected', comment_id, curs.rowcount) + return comment_id + + +def get_comment_or_none(conn: sqlite3.Connection, comment_id: str) -> dict: + with conn: + curry = conn.execute(SELECT_COMMENTS_ON_CLAIMS_CLAIMID + "WHERE comment_id = ?", (comment_id,)) + thing = curry.fetchone() + return clean(dict(thing)) if thing else None + + +def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50): + """ Just return a list of the comment IDs that are associated with the given claim_id. + If get_all is specified then it returns all the IDs, otherwise only the IDs at that level. + if parent_id is left null then it only returns the top level comments. + + For pagination the parameters are: + get_all XOR (page_size + page) + """ + with conn: + if parent_id is None: + curs = conn.execute(""" + SELECT comment_id FROM COMMENTS_ON_CLAIMS + WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ? + """, (claim_id, page_size, page_size * abs(page - 1),) + ) + else: + curs = conn.execute(""" + SELECT comment_id FROM COMMENTS_ON_CLAIMS + WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ? + """, (claim_id, parent_id, page_size, page_size * abs(page - 1),) + ) + return [tuple(row)[0] for row in curs.fetchall()] + + +def get_comments_by_id(conn, comment_ids: typing.Union[list, tuple]) -> typing.Union[list, None]: + """ Returns a list containing the comment data associated with each ID within the list""" + # format the input, under the assumption that the + placeholders = ', '.join('?' for _ in comment_ids) + with conn: + return [clean(dict(row)) for row in conn.execute( + SELECT_COMMENTS_ON_CLAIMS_CLAIMID + f'WHERE comment_id IN ({placeholders})', + tuple(comment_ids) + )] + + +def delete_comment_by_id(conn: sqlite3.Connection, comment_id: str) -> bool: + with conn: + curs = conn.execute("DELETE FROM COMMENT WHERE CommentId = ?", (comment_id,)) + return bool(curs.rowcount) + + +def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str): + with conn: + curs = conn.execute('INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)', (channel_id, channel_name)) + return bool(curs.rowcount) + + +def get_channel_id_from_comment_id(conn: sqlite3.Connection, comment_id: str): + with conn: + channel = conn.execute( + "SELECT channel_id, channel_name FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?", (comment_id,) + ).fetchone() + return dict(channel) if channel else {} + + +def get_claim_ids_from_comment_ids(conn: sqlite3.Connection, comment_ids: list): + with conn: + cids = conn.execute( + f""" SELECT CommentId as comment_id, LbryClaimId AS claim_id FROM COMMENT + WHERE CommentId IN ({', '.join('?' for _ in comment_ids)}) """, + tuple(comment_ids) + ) + return {row['comment_id']: row['claim_id'] for row in cids.fetchall()} + + +def hide_comments_by_id(conn: sqlite3.Connection, comment_ids: list) -> bool: + with conn: + curs = conn.cursor() + curs.executemany( + "UPDATE COMMENT SET IsHidden = 1 WHERE CommentId = ?", + [[c] for c in comment_ids] + ) + return bool(curs.rowcount) + + +def edit_comment_by_id(conn: sqlite3.Connection, comment_id: str, comment: str, + signature: str, signing_ts: str) -> bool: + with conn: + curs = conn.execute( + """ + UPDATE COMMENT + SET Body = :comment, Signature = :signature, SigningTs = :signing_ts + WHERE CommentId = :comment_id + """, + { + 'comment': comment, + 'signature': signature, + 'signing_ts': signing_ts, + 'comment_id': comment_id + }) + logger.info("updated comment with `comment_id`: %s", comment_id) + return bool(curs.rowcount) + + +class DatabaseWriter(object): + _writer = None + + def __init__(self, db_file): + if not DatabaseWriter._writer: + self.conn = obtain_connection(db_file) + DatabaseWriter._writer = self + atexit.register(self.cleanup) + logging.info('Database writer has been created at %s', repr(self)) + else: + logging.warning('Someone attempted to insantiate DatabaseWriter') + raise TypeError('Database Writer already exists!') + + def cleanup(self): + logging.info('Cleaning up database writer') + self.conn.close() + DatabaseWriter._writer = None + + @property + def connection(self): + return self.conn + + +def setup_database(db_path): + with sqlite3.connect(db_path) as conn: + conn.executescript(CREATE_TABLES_QUERY) + + +def backup_database(conn: sqlite3.Connection, back_fp): + with sqlite3.connect(back_fp) as back: + conn.backup(back) diff --git a/src/database/schema.py b/src/database/schema.py new file mode 100644 index 0000000..c75681b --- /dev/null +++ b/src/database/schema.py @@ -0,0 +1,76 @@ +PRAGMAS = """ + PRAGMA FOREIGN_KEYS = ON; +""" + +CREATE_COMMENT_TABLE = """ + CREATE TABLE IF NOT EXISTS COMMENT ( + CommentId TEXT NOT NULL, + LbryClaimId TEXT NOT NULL, + ChannelId TEXT DEFAULT NULL, + Body TEXT NOT NULL, + ParentId TEXT DEFAULT NULL, + Signature TEXT DEFAULT NULL, + Timestamp INTEGER NOT NULL, + SigningTs TEXT DEFAULT NULL, + IsHidden BOOLEAN NOT NULL DEFAULT 0, + CONSTRAINT COMMENT_PRIMARY_KEY PRIMARY KEY (CommentId) ON CONFLICT IGNORE, + CONSTRAINT COMMENT_SIGNATURE_SK UNIQUE (Signature) ON CONFLICT ABORT, + CONSTRAINT COMMENT_CHANNEL_FK FOREIGN KEY (ChannelId) REFERENCES CHANNEL (ClaimId) + ON DELETE NO ACTION ON UPDATE NO ACTION, + CONSTRAINT COMMENT_PARENT_FK FOREIGN KEY (ParentId) REFERENCES COMMENT (CommentId) + ON UPDATE CASCADE ON DELETE NO ACTION -- setting null implies comment is top level + ); +""" + +CREATE_COMMENT_INDEXES = """ + CREATE INDEX IF NOT EXISTS CLAIM_COMMENT_INDEX ON COMMENT (LbryClaimId, CommentId); + CREATE INDEX IF NOT EXISTS CHANNEL_COMMENT_INDEX ON COMMENT (ChannelId, CommentId); +""" + +CREATE_CHANNEL_TABLE = """ + CREATE TABLE IF NOT EXISTS CHANNEL ( + ClaimId TEXT NOT NULL, + Name TEXT NOT NULL, + CONSTRAINT CHANNEL_PK PRIMARY KEY (ClaimId) + ON CONFLICT IGNORE + ); +""" + +CREATE_COMMENTS_ON_CLAIMS_VIEW = """ + CREATE VIEW IF NOT EXISTS COMMENTS_ON_CLAIMS AS SELECT + C.CommentId AS comment_id, + C.Body AS comment, + C.LbryClaimId AS claim_id, + C.Timestamp AS timestamp, + CHAN.Name AS channel_name, + CHAN.ClaimId AS channel_id, + ('lbry://' || CHAN.Name || '#' || CHAN.ClaimId) AS channel_url, + C.Signature AS signature, + C.SigningTs AS signing_ts, + C.ParentId AS parent_id, + C.IsHidden AS is_hidden + FROM COMMENT AS C + LEFT OUTER JOIN CHANNEL CHAN ON C.ChannelId = CHAN.ClaimId + ORDER BY C.Timestamp DESC; +""" + +# not being used right now but should be kept around when Tom finally asks for replies +CREATE_COMMENT_REPLIES_VIEW = """ +CREATE VIEW IF NOT EXISTS COMMENT_REPLIES (Author, CommentBody, ParentAuthor, ParentCommentBody) AS +SELECT AUTHOR.Name, OG.Body, PCHAN.Name, PARENT.Body +FROM COMMENT AS OG + JOIN COMMENT AS PARENT + ON OG.ParentId = PARENT.CommentId + JOIN CHANNEL AS PCHAN ON PARENT.ChannelId = PCHAN.ClaimId + JOIN CHANNEL AS AUTHOR ON OG.ChannelId = AUTHOR.ClaimId +ORDER BY OG.Timestamp; +""" + +CREATE_TABLES_QUERY = ( + PRAGMAS + + CREATE_COMMENT_TABLE + + CREATE_COMMENT_INDEXES + + CREATE_CHANNEL_TABLE + + CREATE_COMMENTS_ON_CLAIMS_VIEW + + CREATE_COMMENT_REPLIES_VIEW +) diff --git a/src/database/writes.py b/src/database/writes.py index b86e31b..e724521 100644 --- a/src/database/writes.py +++ b/src/database/writes.py @@ -1,4 +1,3 @@ -# TODO: scrap notification routines from these files & supply them in handles import logging import sqlite3 from asyncio import coroutine @@ -8,7 +7,7 @@ from src.server.validation import is_valid_base_comment from src.server.validation import is_valid_credential_input from src.server.validation import validate_signature_from_claim from src.server.validation import body_is_valid -from src.misc import get_claim_from_id +from src.server.misc import get_claim_from_id from src.server.external import send_notifications from src.server.external import send_notification import src.database.queries as db @@ -19,7 +18,8 @@ logger = logging.getLogger(__name__) def create_comment_or_error(conn, comment, claim_id=None, channel_id=None, channel_name=None, signature=None, signing_ts=None, parent_id=None) -> dict: - insert_channel_or_error(conn, channel_name, channel_id) + if channel_id and channel_name: + insert_channel_or_error(conn, channel_name, channel_id) fn = db.insert_comment if parent_id is None else db.insert_reply comment_id = fn( conn=conn, @@ -65,7 +65,7 @@ async def _abandon_comment(app, comment_id): # DELETE async def create_comment(app, params): - if is_valid_base_comment(**params): + if is_valid_base_comment(**params) and is_valid_credential_input(**params): job = await app['comment_scheduler'].spawn(_create_comment(app, params)) comment = await job.wait() if comment: @@ -85,17 +85,10 @@ async def hide_comments(app, pieces: list) -> list: # TODO: Amortize this process claims = {} comments_to_hide = [] - # go through a list of dict objects for p in pieces: - # maps the comment_id from the piece to a claim_id claim_id = comment_cids[p['comment_id']] - # resolve the claim from its id if claim_id not in claims: - claim = await get_claim_from_id(app, claim_id) - if claim: - claims[claim_id] = claim - - # get the claim's signing channel, then use it to validate the hidden comment + claims[claim_id] = await get_claim_from_id(app, claim_id, no_totals=True) channel = claims[claim_id].get('signing_channel') if validate_signature_from_claim(channel, p['signature'], p['signing_ts'], p['comment_id']): comments_to_hide.append(p) @@ -107,6 +100,7 @@ async def hide_comments(app, pieces: list) -> list: app, 'UPDATE', db.get_comments_by_id(app['reader'], comment_ids) ) ) + await job.wait() return comment_ids diff --git a/src/definitions.py b/src/definitions.py deleted file mode 100644 index 51654e1..0000000 --- a/src/definitions.py +++ /dev/null @@ -1,7 +0,0 @@ -import os - -SRC_DIR = os.path.dirname(os.path.abspath(__file__)) -ROOT_DIR = os.path.dirname(SRC_DIR) -CONFIG_FILE = os.path.join(ROOT_DIR, 'config', 'conf.yml') -LOGGING_DIR = os.path.join(ROOT_DIR, 'logs') -DATABASE_DIR = os.path.join(ROOT_DIR, 'database') diff --git a/src/main.py b/src/main.py index d4b5a53..b31bcea 100644 --- a/src/main.py +++ b/src/main.py @@ -1,19 +1,13 @@ import argparse -import yaml import logging import logging.config -import os import sys from src.server.app import run_app -from src.definitions import LOGGING_DIR, CONFIG_FILE, DATABASE_DIR +from src.settings import config -def setup_logging_from_config(conf: dict): - # set the logging directory here from the settings file - if not os.path.exists(LOGGING_DIR): - os.mkdir(LOGGING_DIR) - +def config_logging_from_settings(conf): _config = { "version": 1, "disable_existing_loggers": False, @@ -38,7 +32,7 @@ def setup_logging_from_config(conf: dict): "level": "DEBUG", "formatter": "standard", "class": "logging.handlers.RotatingFileHandler", - "filename": os.path.join(LOGGING_DIR, 'debug.log'), + "filename": conf['path']['debug_log'], "maxBytes": 10485760, "backupCount": 5 }, @@ -46,7 +40,7 @@ def setup_logging_from_config(conf: dict): "level": "ERROR", "formatter": "standard", "class": "logging.handlers.RotatingFileHandler", - "filename": os.path.join(LOGGING_DIR, 'error.log'), + "filename": conf['path']['error_log'], "maxBytes": 10485760, "backupCount": 5 }, @@ -54,7 +48,7 @@ def setup_logging_from_config(conf: dict): "level": "NOTSET", "formatter": "aiohttp", "class": "logging.handlers.RotatingFileHandler", - "filename": os.path.join(LOGGING_DIR, 'server.log'), + "filename": conf['path']['server_log'], "maxBytes": 10485760, "backupCount": 5 } @@ -76,42 +70,15 @@ def setup_logging_from_config(conf: dict): logging.config.dictConfig(_config) -def get_config(filepath): - with open(filepath, 'r') as cfile: - config = yaml.load(cfile, Loader=yaml.Loader) - return config - - -def setup_db_from_config(config: dict): - mode = config['mode'] - if config[mode]['database'] == 'sqlite': - if not os.path.exists(DATABASE_DIR): - os.mkdir(DATABASE_DIR) - - config[mode]['db_file'] = os.path.join( - DATABASE_DIR, config[mode]['name'] - ) - - def main(argv=None): argv = argv or sys.argv[1:] parser = argparse.ArgumentParser(description='LBRY Comment Server') parser.add_argument('--port', type=int) - parser.add_argument('--config', type=str) - parser.add_argument('--mode', type=str) args = parser.parse_args(argv) - - config = get_config(CONFIG_FILE) if not args.config else args.config - setup_logging_from_config(config) - - if args.mode: - config['mode'] = args.mode - - setup_db_from_config(config) - + config_logging_from_settings(config) if args.port: config['port'] = args.port - + config_logging_from_settings(config) run_app(config) diff --git a/src/server/app.py b/src/server/app.py index 37e89da..a25787b 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -1,6 +1,7 @@ # cython: language_level=3 import asyncio import logging +import pathlib import signal import time @@ -8,68 +9,81 @@ import aiojobs import aiojobs.aiohttp from aiohttp import web -from peewee import * +from src.database.queries import obtain_connection, DatabaseWriter +from src.database.queries import setup_database, backup_database from src.server.handles import api_endpoint, get_api_endpoint -from src.database.models import Comment, Channel -MODELS = [Comment, Channel] logger = logging.getLogger(__name__) -def setup_database(app): - config = app['config'] - mode = config['mode'] +async def setup_db_schema(app): + if not pathlib.Path(app['db_path']).exists(): + logger.info(f'Setting up schema in {app["db_path"]}') + setup_database(app['db_path']) + else: + logger.info(f'Database already exists in {app["db_path"]}, skipping setup') - # switch between Database objects - if config[mode]['database'] == 'mysql': - app['db'] = MySQLDatabase( - database=config[mode]['name'], - user=config[mode]['user'], - host=config[mode]['host'], - password=config[mode]['password'], - port=config[mode]['port'], - charset=config[mode]['charset'], - ) - elif config[mode]['database'] == 'sqlite': - app['db'] = SqliteDatabase( - config[mode]['file'], - pragmas=config[mode]['pragmas'] - ) - # bind the Model list to the database - app['db'].bind(MODELS, bind_refs=False, bind_backrefs=False) +async def database_backup_routine(app): + try: + while True: + await asyncio.sleep(app['config']['backup_int']) + with app['reader'] as conn: + logger.debug('backing up database') + backup_database(conn, app['backup']) + except asyncio.CancelledError: + pass async def start_background_tasks(app): - app['db'].connect() - app['db'].create_tables(MODELS) + # Reading the DB + app['reader'] = obtain_connection(app['db_path'], True) + app['waitful_backup'] = asyncio.create_task(database_backup_routine(app)) + + # Scheduler to prevent multiple threads from writing to DB simulataneously + app['comment_scheduler'] = await aiojobs.create_scheduler(limit=1, pending_limit=0) + app['db_writer'] = DatabaseWriter(app['db_path']) + app['writer'] = app['db_writer'].connection # for requesting to external and internal APIs app['webhooks'] = await aiojobs.create_scheduler(pending_limit=0) async def close_database_connections(app): - app['db'].close() + logger.info('Ending background backup loop') + app['waitful_backup'].cancel() + await app['waitful_backup'] + app['reader'].close() + app['writer'].close() + app['db_writer'].cleanup() async def close_schedulers(app): + logger.info('Closing comment_scheduler') + await app['comment_scheduler'].close() + logger.info('Closing scheduler for webhook requests') await app['webhooks'].close() class CommentDaemon: - def __init__(self, config, **kwargs): + def __init__(self, config, db_file=None, backup=None, **kwargs): app = web.Application() - app['config'] = config # configure the config - self.config = config - self.host = config['host'] - self.port = config['port'] + app['config'] = config + self.config = app['config'] - setup_database(app) + # configure the db file + if db_file: + app['db_path'] = db_file + app['backup'] = backup + else: + app['db_path'] = config['path']['database'] + app['backup'] = backup or (app['db_path'] + '.backup') # configure the order of tasks to run during app lifetime + app.on_startup.append(setup_db_schema) app.on_startup.append(start_background_tasks) app.on_shutdown.append(close_schedulers) app.on_cleanup.append(close_database_connections) @@ -91,19 +105,20 @@ class CommentDaemon: await self.app_runner.setup() self.app_site = web.TCPSite( runner=self.app_runner, - host=host or self.host, - port=port or self.port, + host=host or self.config['host'], + port=port or self.config['port'], ) await self.app_site.start() - logger.info(f'Comment Server is running on {self.host}:{self.port}') + logger.info(f'Comment Server is running on {self.config["host"]}:{self.config["port"]}') async def stop(self): await self.app_runner.shutdown() await self.app_runner.cleanup() -def run_app(config): - comment_app = CommentDaemon(config=config) +def run_app(config, db_file=None): + comment_app = CommentDaemon(config=config, db_file=db_file, close_timeout=5.0) + loop = asyncio.get_event_loop() def __exit(): diff --git a/src/server/errors.py b/src/server/errors.py index 644ce9e..c0b8af2 100644 --- a/src/server/errors.py +++ b/src/server/errors.py @@ -1,5 +1,3 @@ -import json - import logging import aiohttp @@ -31,17 +29,16 @@ def make_error(error, exc=None) -> dict: return body -async def report_error(app, exc, body: dict): +async def report_error(app, exc, msg=''): try: if 'slack_webhook' in app['config']: - body_dump = json.dumps(body, indent=4) - exec_name = type(exc).__name__ - exec_body = str(exc) - message = { - "text": f"Got `{exec_name}`: `\n{exec_body}`\n```{body_dump}```" + if msg: + msg = f'"{msg}"' + body = { + "text": f"Got `{type(exc).__name__}`: ```\n{str(exc)}```\n{msg}" } async with aiohttp.ClientSession() as sesh: - async with sesh.post(app['config']['slack_webhook'], json=message) as resp: + async with sesh.post(app['config']['slack_webhook'], json=body) as resp: await resp.wait_for_close() except Exception: diff --git a/src/server/external.py b/src/server/external.py index 25bd634..aa0e7ee 100644 --- a/src/server/external.py +++ b/src/server/external.py @@ -37,10 +37,6 @@ def create_notification_batch(action: str, comments: List[dict]) -> List[dict]: } if comment.get('channel_id'): event['channel_id'] = comment['channel_id'] - if comment.get('parent_id'): - event['parent_id'] = comment['parent_id'] - if comment.get('comment'): - event['comment'] = comment['comment'] events.append(event) return events diff --git a/src/server/handles.py b/src/server/handles.py index a7fe43b..63503f3 100644 --- a/src/server/handles.py +++ b/src/server/handles.py @@ -1,24 +1,16 @@ import asyncio import logging import time -import typing from aiohttp import web from aiojobs.aiohttp import atomic -from peewee import DoesNotExist -from src.server.external import send_notification -from src.server.validation import validate_signature_from_claim -from src.misc import clean_input_params, get_claim_from_id +import src.database.queries as db +from src.database.writes import abandon_comment, create_comment +from src.database.writes import hide_comments +from src.database.writes import edit_comment +from src.server.misc import clean_input_params from src.server.errors import make_error, report_error -from src.database.models import Comment, Channel -from src.database.models import get_comment -from src.database.models import comment_list -from src.database.models import create_comment -from src.database.models import edit_comment -from src.database.models import delete_comment -from src.database.models import set_hidden_flag - logger = logging.getLogger(__name__) @@ -28,198 +20,51 @@ def ping(*args): return 'pong' -def handle_get_channel_from_comment_id(app: web.Application, comment_id: str) -> dict: - comment = get_comment(comment_id) - return { - 'channel_id': comment['channel_id'], - 'channel_name': comment['channel_name'] - } +def handle_get_channel_from_comment_id(app, kwargs: dict): + return db.get_channel_id_from_comment_id(app['reader'], **kwargs) -def handle_get_comment_ids( - app: web.Application, - claim_id: str, - parent_id: str = None, - page: int = 1, - page_size: int = 50, - flattened=False -) -> dict: - results = comment_list( - claim_id=claim_id, - parent_id=parent_id, - top_level=(parent_id is None), - page=page, - page_size=page_size, - select_fields=['comment_id', 'parent_id'] - ) - if flattened: - results.update({ - 'items': [item['comment_id'] for item in results['items']], - 'replies': [(item['comment_id'], item.get('parent_id')) - for item in results['items']] - }) - return results +def handle_get_comment_ids(app, kwargs): + return db.get_comment_ids(app['reader'], **kwargs) -def handle_get_comments_by_id( - app: web.Application, - comment_ids: typing.Union[list, tuple] -) -> dict: - expression = Comment.comment_id.in_(comment_ids) - return comment_list(expressions=expression, page_size=len(comment_ids)) +def handle_get_claim_comments(app, kwargs): + return db.get_claim_comments(app['reader'], **kwargs) -def handle_get_claim_comments( - app: web.Application, - claim_id: str, - parent_id: str = None, - page: int = 1, - page_size: int = 50, - top_level: bool = False -) -> dict: - return comment_list( - claim_id=claim_id, - parent_id=parent_id, - page=page, - page_size=page_size, - top_level=top_level - ) +def handle_get_comments_by_id(app, kwargs): + return db.get_comments_by_id(app['reader'], **kwargs) -def handle_get_claim_hidden_comments( - app: web.Application, - claim_id: str, - hidden: bool, - page: int = 1, - page_size: int = 50, -) -> dict: - exclude = 'hidden' if hidden else 'visible' - return comment_list( - claim_id=claim_id, - exclude_mode=exclude, - page=page, - page_size=page_size - ) +def handle_get_claim_hidden_comments(app, kwargs): + return db.get_claim_hidden_comments(app['reader'], **kwargs) -async def handle_abandon_comment( - app: web.Application, - comment_id: str, - signature: str, - signing_ts: str, - **kwargs, -) -> dict: - comment = get_comment(comment_id) - try: - channel = await get_claim_from_id(app, comment['channel_id']) - except DoesNotExist: - raise ValueError('Could not find a channel associated with the given comment') - else: - if not validate_signature_from_claim(channel, signature, signing_ts, comment_id): - raise ValueError('Abandon signature could not be validated') - await app['webhooks'].spawn(send_notification(app, 'DELETE', comment)) - with app['db'].atomic(): - return { - 'abandoned': delete_comment(comment_id) - } +async def handle_abandon_comment(app, params): + return {'abandoned': await abandon_comment(app, **params)} -async def handle_hide_comments(app: web.Application, pieces: list, hide: bool = True) -> dict: - # let's get all the distinct claim_ids from the list of comment_ids - pieces_by_id = {p['comment_id']: p for p in pieces} - comment_ids = list(pieces_by_id.keys()) - comments = (Comment - .select(Comment.comment_id, Comment.claim_id) - .where(Comment.comment_id.in_(comment_ids)) - .tuples()) - - # resolve the claims and map them to their corresponding comment_ids - claims = {} - for comment_id, claim_id in comments: - try: - # try and resolve the claim, if fails then we mark it as null - # and remove the associated comment from the pieces - if claim_id not in claims: - claims[claim_id] = await get_claim_from_id(app, claim_id) - - # try to get a public key to validate - if claims[claim_id] is None or 'signing_channel' not in claims[claim_id]: - raise ValueError(f'could not get signing channel from claim_id: {claim_id}') - - # try to validate signature - else: - channel = claims[claim_id]['signing_channel'] - piece = pieces_by_id[comment_id] - is_valid_signature = validate_signature_from_claim( - claim=channel, - signature=piece['signature'], - signing_ts=piece['signing_ts'], - data=piece['comment_id'] - ) - if not is_valid_signature: - raise ValueError(f'could not validate signature on comment_id: {comment_id}') - - except ValueError: - # remove the piece from being hidden - pieces_by_id.pop(comment_id) - - # remaining items in pieces_by_id have been able to successfully validate - with app['db'].atomic(): - set_hidden_flag(list(pieces_by_id.keys()), hidden=hide) - - query = Comment.select().where(Comment.comment_id.in_(comment_ids)).objects() - result = { - 'hidden': [c.comment_id for c in query if c.is_hidden], - 'visible': [c.comment_id for c in query if not c.is_hidden], - } - return result +async def handle_hide_comments(app, params): + return {'hidden': await hide_comments(app, **params)} -async def handle_edit_comment(app, comment: str = None, comment_id: str = None, - signature: str = None, signing_ts: str = None, **params) -> dict: - current = get_comment(comment_id) - channel_claim = await get_claim_from_id(app, current['channel_id']) - if not validate_signature_from_claim(channel_claim, signature, signing_ts, comment): - raise ValueError('Signature could not be validated') - - with app['db'].atomic(): - if not edit_comment(comment_id, comment, signature, signing_ts): - raise ValueError('Comment could not be edited') - updated_comment = get_comment(comment_id) - await app['webhooks'].spawn(send_notification(app, 'UPDATE', updated_comment)) - return updated_comment - - -# TODO: retrieve stake amounts for each channel & store in db -async def handle_create_comment(app, comment: str = None, claim_id: str = None, - parent_id: str = None, channel_id: str = None, channel_name: str = None, - signature: str = None, signing_ts: str = None) -> dict: - with app['db'].atomic(): - comment = create_comment( - comment=comment, - claim_id=claim_id, - parent_id=parent_id, - channel_id=channel_id, - channel_name=channel_name, - signature=signature, - signing_ts=signing_ts - ) - await app['webhooks'].spawn(send_notification(app, 'CREATE', comment)) - return comment +async def handle_edit_comment(app, params): + if await edit_comment(app, **params): + return db.get_comment_or_none(app['reader'], params['comment_id']) METHODS = { 'ping': ping, - 'get_claim_comments': handle_get_claim_comments, # this gets used - 'get_claim_hidden_comments': handle_get_claim_hidden_comments, # this gets used + 'get_claim_comments': handle_get_claim_comments, + 'get_claim_hidden_comments': handle_get_claim_hidden_comments, 'get_comment_ids': handle_get_comment_ids, - 'get_comments_by_id': handle_get_comments_by_id, # this gets used - 'get_channel_from_comment_id': handle_get_channel_from_comment_id, # this gets used - 'create_comment': handle_create_comment, # this gets used + 'get_comments_by_id': handle_get_comments_by_id, + 'get_channel_from_comment_id': handle_get_channel_from_comment_id, + 'create_comment': create_comment, 'delete_comment': handle_abandon_comment, - 'abandon_comment': handle_abandon_comment, # this gets used - 'hide_comments': handle_hide_comments, # this gets used - 'edit_comment': handle_edit_comment # this gets used + 'abandon_comment': handle_abandon_comment, + 'hide_comments': handle_hide_comments, + 'edit_comment': handle_edit_comment } @@ -233,19 +78,17 @@ async def process_json(app, body: dict) -> dict: start = time.time() try: if asyncio.iscoroutinefunction(METHODS[method]): - result = await METHODS[method](app, **params) + result = await METHODS[method](app, params) else: - result = METHODS[method](app, **params) - + result = METHODS[method](app, params) + response['result'] = result except Exception as err: - logger.exception(f'Got {type(err).__name__}:\n{err}') + logger.exception(f'Got {type(err).__name__}:') if type(err) in (ValueError, TypeError): # param error, not too important response['error'] = make_error('INVALID_PARAMS', err) else: response['error'] = make_error('INTERNAL', err) - await app['webhooks'].spawn(report_error(app, err, body)) - else: - response['result'] = result + await app['webhooks'].spawn(report_error(app, err)) finally: end = time.time() diff --git a/src/misc.py b/src/server/misc.py similarity index 57% rename from src/misc.py rename to src/server/misc.py index 7c3d4bf..593e61a 100644 --- a/src/misc.py +++ b/src/server/misc.py @@ -8,19 +8,12 @@ ID_LIST = {'claim_id', 'parent_id', 'comment_id', 'channel_id'} async def get_claim_from_id(app, claim_id, **kwargs): - try: - return (await request_lbrynet(app, 'claim_search', claim_id=claim_id, **kwargs))['items'][0] - except IndexError: - return + return (await request_lbrynet(app, 'claim_search', claim_id=claim_id, **kwargs))['items'][0] def clean_input_params(kwargs: dict): for k, v in kwargs.items(): - if type(v) is str and k != 'comment': + if type(v) is str and k is not 'comment': kwargs[k] = v.strip() if k in ID_LIST: kwargs[k] = v.lower() - - -def clean(thing: dict) -> dict: - return {k: v for k, v in thing.items() if v is not None} \ No newline at end of file diff --git a/src/server/validation.py b/src/server/validation.py index 3b251ba..ad99f0c 100644 --- a/src/server/validation.py +++ b/src/server/validation.py @@ -50,48 +50,24 @@ def claim_id_is_valid(claim_id: str) -> bool: return re.fullmatch('([a-z0-9]{40}|[A-Z0-9]{40})', claim_id) is not None -# default to None so params can be treated as kwargs; param count becomes more manageable -def is_valid_base_comment( - comment: str = None, - claim_id: str = None, - parent_id: str = None, - strict: bool = False, - **kwargs, -) -> bool: - try: - assert comment and body_is_valid(comment) - # strict mode assumes that the parent_id might not exist - if strict: - assert claim_id and claim_id_is_valid(claim_id) - assert parent_id is None or comment_id_is_valid(parent_id) - # non-strict removes reference restrictions - else: - assert claim_id or parent_id - if claim_id: - assert claim_id_is_valid(claim_id) - else: - assert comment_id_is_valid(parent_id) - - except AssertionError: - return False - else: - return is_valid_credential_input(**kwargs) +def is_valid_base_comment(comment: str, claim_id: str, parent_id: str = None, **kwargs) -> bool: + return comment is not None and body_is_valid(comment) and \ + ((claim_id is not None and claim_id_is_valid(claim_id)) or + (parent_id is not None and comment_id_is_valid(parent_id))) def is_valid_credential_input(channel_id: str = None, channel_name: str = None, - signature: str = None, signing_ts: str = None) -> bool: - try: - assert None not in (channel_id, channel_name, signature, signing_ts) - assert is_valid_channel(channel_id, channel_name) - assert len(signature) == 128 - assert signing_ts.isalnum() + signature: str = None, signing_ts: str = None, **kwargs) -> bool: + if channel_id or channel_name or signature or signing_ts: + try: + assert channel_id and channel_name and signature and signing_ts + assert is_valid_channel(channel_id, channel_name) + assert len(signature) == 128 + assert signing_ts.isalnum() - return True - - except Exception as e: - logger.exception(f'Failed to validate channel: lbry://{channel_name}#{channel_id}, ' - f'signature: {signature} signing_ts: {signing_ts}') - return False + except Exception: + return False + return True def validate_signature_from_claim(claim: dict, signature: typing.Union[str, bytes], diff --git a/src/settings.py b/src/settings.py new file mode 100644 index 0000000..2d720c2 --- /dev/null +++ b/src/settings.py @@ -0,0 +1,17 @@ +# cython: language_level=3 +import json +import pathlib + +root_dir = pathlib.Path(__file__).parent.parent +config_path = root_dir / 'config' / 'conf.json' + + +def get_config(filepath): + with open(filepath, 'r') as cfile: + conf = json.load(cfile) + for key, path in conf['path'].items(): + conf['path'][key] = str(root_dir / path) + return conf + + +config = get_config(config_path) diff --git a/test/test_database.py b/test/test_database.py index c698ad1..347494d 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -1,13 +1,18 @@ +import unittest + from random import randint import faker from faker.providers import internet from faker.providers import lorem from faker.providers import misc -from src.database.models import create_comment -from src.database.models import delete_comment -from src.database.models import comment_list, get_comment -from src.database.models import set_hidden_flag +from src.database.queries import get_comments_by_id +from src.database.queries import get_comment_ids +from src.database.queries import get_claim_comments +from src.database.queries import get_claim_hidden_comments +from src.database.writes import create_comment_or_error +from src.database.queries import hide_comments_by_id +from src.database.queries import delete_comment_by_id from test.testcase import DatabaseTestCase fake = faker.Faker() @@ -22,58 +27,69 @@ class TestDatabaseOperations(DatabaseTestCase): self.claimId = '529357c3422c6046d3fec76be2358004ba22e340' def test01NamedComments(self): - comment = create_comment( + comment = create_comment_or_error( + conn=self.conn, claim_id=self.claimId, comment='This is a named comment', channel_name='@username', channel_id='529357c3422c6046d3fec76be2358004ba22abcd', - signature='22'*64, + signature=fake.uuid4(), signing_ts='aaa' ) self.assertIsNotNone(comment) self.assertNotIn('parent_in', comment) - previous_id = comment['comment_id'] - reply = create_comment( + reply = create_comment_or_error( + conn=self.conn, claim_id=self.claimId, comment='This is a named response', channel_name='@another_username', channel_id='529357c3422c6046d3fec76be2358004ba224bcd', parent_id=previous_id, - signature='11'*64, + signature=fake.uuid4(), signing_ts='aaa' ) self.assertIsNotNone(reply) self.assertEqual(reply['parent_id'], comment['comment_id']) def test02AnonymousComments(self): - self.assertRaises( - ValueError, - create_comment, + comment = create_comment_or_error( + conn=self.conn, claim_id=self.claimId, comment='This is an ANONYMOUS comment' ) + self.assertIsNotNone(comment) + previous_id = comment['comment_id'] + reply = create_comment_or_error( + conn=self.conn, + claim_id=self.claimId, + comment='This is an unnamed response', + parent_id=previous_id + ) + self.assertIsNotNone(reply) + self.assertEqual(reply['parent_id'], comment['comment_id']) def test03SignedComments(self): - comment = create_comment( + comment = create_comment_or_error( + conn=self.conn, claim_id=self.claimId, comment='I like big butts and i cannot lie', channel_name='@sirmixalot', channel_id='529357c3422c6046d3fec76be2358005ba22abcd', - signature='24'*64, + signature=fake.uuid4(), signing_ts='asdasd' ) self.assertIsNotNone(comment) self.assertIn('signing_ts', comment) - previous_id = comment['comment_id'] - reply = create_comment( + reply = create_comment_or_error( + conn=self.conn, claim_id=self.claimId, comment='This is a LBRY verified response', channel_name='@LBRY', channel_id='529357c3422c6046d3fec76be2358001ba224bcd', parent_id=previous_id, - signature='12'*64, + signature=fake.uuid4(), signing_ts='sfdfdfds' ) self.assertIsNotNone(reply) @@ -82,109 +98,129 @@ class TestDatabaseOperations(DatabaseTestCase): def test04UsernameVariations(self): self.assertRaises( - ValueError, - create_comment, + AssertionError, + callable=create_comment_or_error, + conn=self.conn, claim_id=self.claimId, channel_name='$#(@#$@#$', channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this is an invalid username', - signature='1' * 128, - signing_ts='123' + comment='this is an invalid username' ) - - valid_username = create_comment( + valid_username = create_comment_or_error( + conn=self.conn, claim_id=self.claimId, channel_name='@' + 'a' * 255, channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this is a valid username', - signature='1'*128, - signing_ts='123' + comment='this is a valid username' ) self.assertIsNotNone(valid_username) + self.assertRaises(AssertionError, + callable=create_comment_or_error, + conn=self.conn, + claim_id=self.claimId, + channel_name='@' + 'a' * 256, + channel_id='529357c3422c6046d3fec76be2358001ba224b23', + comment='this username is too long' + ) self.assertRaises( - ValueError, - create_comment, - claim_id=self.claimId, - channel_name='@' + 'a' * 256, - channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this username is too long', - signature='2' * 128, - signing_ts='123' - ) - - self.assertRaises( - ValueError, - create_comment, + AssertionError, + callable=create_comment_or_error, + conn=self.conn, claim_id=self.claimId, channel_name='', channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this username should not default to ANONYMOUS', - signature='3' * 128, - signing_ts='123' + comment='this username should not default to ANONYMOUS' ) - self.assertRaises( - ValueError, - create_comment, + AssertionError, + callable=create_comment_or_error, + conn=self.conn, claim_id=self.claimId, channel_name='@', channel_id='529357c3422c6046d3fec76be2358001ba224b23', - comment='this username is too short', - signature='3' * 128, - signing_ts='123' + comment='this username is too short' ) - def test05HideComments(self): - comm = create_comment( - comment='Comment #1', - claim_id=self.claimId, - channel_id='1'*40, - channel_name='@Doge123', - signature='a'*128, - signing_ts='123' - ) - comment = get_comment(comm['comment_id']) + def test05InsertRandomComments(self): + # TODO: Fix this test into something practical + self.skipTest('This is a bad test') + top_comments, claim_ids = generate_top_comments_random() + total = 0 + success = 0 + for _, comments in top_comments.items(): + for i, comment in enumerate(comments): + with self.subTest(comment=comment): + result = create_comment_or_error(self.conn, **comment) + if result: + success += 1 + comments[i] = result + del comment + total += len(comments) + self.assertLessEqual(success, total) + self.assertGreater(success, 0) + success = 0 + for reply in generate_replies_random(top_comments): + reply_id = create_comment_or_error(self.conn, **reply) + if reply_id: + success += 1 + self.assertGreater(success, 0) + self.assertLess(success, total) + del top_comments + del claim_ids + + def test06GenerateAndListComments(self): + # TODO: Make this test not suck + self.skipTest('this is a stupid test') + top_comments, claim_ids = generate_top_comments() + total, success = 0, 0 + for _, comments in top_comments.items(): + for i, comment in enumerate(comments): + result = create_comment_or_error(self.conn, **comment) + if result: + success += 1 + comments[i] = result + del comment + total += len(comments) + self.assertEqual(total, success) + self.assertGreater(total, 0) + for reply in generate_replies(top_comments): + create_comment_or_error(self.conn, **reply) + for claim_id in claim_ids: + comments_ids = get_comment_ids(self.conn, claim_id) + with self.subTest(comments_ids=comments_ids): + self.assertIs(type(comments_ids), list) + self.assertGreaterEqual(len(comments_ids), 0) + self.assertLessEqual(len(comments_ids), 50) + replies = get_comments_by_id(self.conn, comments_ids) + self.assertLessEqual(len(replies), 50) + self.assertEqual(len(replies), len(comments_ids)) + + def test07HideComments(self): + comm = create_comment_or_error(self.conn, 'Comment #1', self.claimId, '1'*40, '@Doge123', 'a'*128, '123') + comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop() self.assertFalse(comment['is_hidden']) - - success = set_hidden_flag([comm['comment_id']]) + success = hide_comments_by_id(self.conn, [comm['comment_id']]) self.assertTrue(success) - - comment = get_comment(comm['comment_id']) + comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop() + self.assertTrue(comment['is_hidden']) + success = hide_comments_by_id(self.conn, [comm['comment_id']]) + self.assertTrue(success) + comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop() self.assertTrue(comment['is_hidden']) - success = set_hidden_flag([comm['comment_id']]) - self.assertTrue(success) - - comment = get_comment(comm['comment_id']) - self.assertTrue(comment['is_hidden']) - - def test06DeleteComments(self): - # make sure that the comment was created - comm = create_comment( - comment='Comment #1', - claim_id=self.claimId, - channel_id='1'*40, - channel_name='@Doge123', - signature='a'*128, - signing_ts='123' - ) - comments = comment_list(self.claimId) - match = [x for x in comments['items'] if x['comment_id'] == comm['comment_id']] - self.assertTrue(len(match) > 0) - - deleted = delete_comment(comm['comment_id']) + def test08DeleteComments(self): + comm = create_comment_or_error(self.conn, 'Comment #1', self.claimId, '1'*40, '@Doge123', 'a'*128, '123') + comments = get_claim_comments(self.conn, self.claimId) + match = list(filter(lambda x: comm['comment_id'] == x['comment_id'], comments['items'])) + self.assertTrue(match) + deleted = delete_comment_by_id(self.conn, comm['comment_id']) self.assertTrue(deleted) - - # make sure that we can't find the comment here - comments = comment_list(self.claimId) - match = [x for x in comments['items'] if x['comment_id'] == comm['comment_id']] + comments = get_claim_comments(self.conn, self.claimId) + match = list(filter(lambda x: comm['comment_id'] == x['comment_id'], comments['items'])) self.assertFalse(match) - self.assertRaises( - ValueError, - delete_comment, - comment_id=comm['comment_id'], - ) + deleted = delete_comment_by_id(self.conn, comm['comment_id']) + self.assertFalse(deleted) class ListDatabaseTest(DatabaseTestCase): @@ -195,75 +231,61 @@ class ListDatabaseTest(DatabaseTestCase): def testLists(self): for claim_id in self.claim_ids: with self.subTest(claim_id=claim_id): - comments = comment_list(claim_id) + comments = get_claim_comments(self.conn, claim_id) self.assertIsNotNone(comments) self.assertGreater(comments['page_size'], 0) self.assertIn('has_hidden_comments', comments) self.assertFalse(comments['has_hidden_comments']) - top_comments = comment_list(claim_id, top_level=True, page=1, page_size=50) + top_comments = get_claim_comments(self.conn, claim_id, top_level=True, page=1, page_size=50) self.assertIsNotNone(top_comments) self.assertEqual(top_comments['page_size'], 50) self.assertEqual(top_comments['page'], 1) self.assertGreaterEqual(top_comments['total_pages'], 0) self.assertGreaterEqual(top_comments['total_items'], 0) - comment_ids = comment_list(claim_id, page_size=50, page=1) + comment_ids = get_comment_ids(self.conn, claim_id, page_size=50, page=1) with self.subTest(comment_ids=comment_ids): self.assertIsNotNone(comment_ids) self.assertLessEqual(len(comment_ids), 50) - matching_comments = (comment_ids) + matching_comments = get_comments_by_id(self.conn, comment_ids) self.assertIsNotNone(matching_comments) self.assertEqual(len(matching_comments), len(comment_ids)) def testHiddenCommentLists(self): claim_id = 'a'*40 - comm1 = create_comment( - 'Comment #1', - claim_id, - channel_id='1'*40, - channel_name='@Doge123', - signature='a'*128, - signing_ts='123' - ) - comm2 = create_comment( - 'Comment #2', claim_id, - channel_id='1'*40, - channel_name='@Doge123', - signature='b'*128, - signing_ts='123' - ) - comm3 = create_comment( - 'Comment #3', claim_id, - channel_id='1'*40, - channel_name='@Doge123', - signature='c'*128, - signing_ts='123' - ) + comm1 = create_comment_or_error(self.conn, 'Comment #1', claim_id, '1'*40, '@Doge123', 'a'*128, '123') + comm2 = create_comment_or_error(self.conn, 'Comment #2', claim_id, '1'*40, '@Doge123', 'b'*128, '123') + comm3 = create_comment_or_error(self.conn, 'Comment #3', claim_id, '1'*40, '@Doge123', 'c'*128, '123') comments = [comm1, comm2, comm3] - listed_comments = comment_list(claim_id) - self.assertEqual(len(comments), listed_comments['total_items']) - self.assertFalse(listed_comments['has_hidden_comments']) + comment_list = get_claim_comments(self.conn, claim_id) + self.assertIn('items', comment_list) + self.assertIn('has_hidden_comments', comment_list) + self.assertEqual(len(comments), comment_list['total_items']) + self.assertIn('has_hidden_comments', comment_list) + self.assertFalse(comment_list['has_hidden_comments']) + hide_comments_by_id(self.conn, [comm2['comment_id']]) - set_hidden_flag([comm2['comment_id']]) - hidden = comment_list(claim_id, exclude_mode='hidden') + default_comments = get_claim_hidden_comments(self.conn, claim_id) + self.assertIn('has_hidden_comments', default_comments) - self.assertTrue(hidden['has_hidden_comments']) - self.assertGreater(len(hidden['items']), 0) + hidden_comments = get_claim_hidden_comments(self.conn, claim_id, hidden=True) + self.assertIn('has_hidden_comments', hidden_comments) + self.assertEqual(default_comments, hidden_comments) - visible = comment_list(claim_id, exclude_mode='visible') - self.assertFalse(visible['has_hidden_comments']) - self.assertNotEqual(listed_comments['items'], visible['items']) - - # make sure the hidden comment is the one we marked as hidden - hidden_comment = hidden['items'][0] + hidden_comment = hidden_comments['items'][0] self.assertEqual(hidden_comment['comment_id'], comm2['comment_id']) - hidden_ids = [c['comment_id'] for c in hidden['items']] - visible_ids = [c['comment_id'] for c in visible['items']] + visible_comments = get_claim_hidden_comments(self.conn, claim_id, hidden=False) + self.assertIn('has_hidden_comments', visible_comments) + self.assertNotIn(hidden_comment, visible_comments['items']) + + hidden_ids = [c['comment_id'] for c in hidden_comments['items']] + visible_ids = [c['comment_id'] for c in visible_comments['items']] composite_ids = hidden_ids + visible_ids - listed_comments = comment_list(claim_id) - all_ids = [c['comment_id'] for c in listed_comments['items']] composite_ids.sort() + + comment_list = get_claim_comments(self.conn, claim_id) + all_ids = [c['comment_id'] for c in comment_list['items']] all_ids.sort() self.assertEqual(composite_ids, all_ids) diff --git a/test/test_server.py b/test/test_server.py index 56bccfb..3c659d9 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -9,18 +9,14 @@ from faker.providers import internet from faker.providers import lorem from faker.providers import misc -from src.main import get_config, CONFIG_FILE +from src.settings import config from src.server import app +from src.server.validation import is_valid_channel from src.server.validation import is_valid_base_comment from test.testcase import AsyncioTestCase -config = get_config(CONFIG_FILE) -config['mode'] = 'testing' -config['testing']['file'] = ':memory:' - - if 'slack_webhook' in config: config.pop('slack_webhook') @@ -76,10 +72,10 @@ def create_test_comments(values: iter, **default): class ServerTest(AsyncioTestCase): + db_file = 'test.db' + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - config['mode'] = 'testing' - config['testing']['file'] = ':memory:' self.host = 'localhost' self.port = 5931 @@ -90,16 +86,33 @@ class ServerTest(AsyncioTestCase): @classmethod def tearDownClass(cls) -> None: print('exit reached') + os.remove(cls.db_file) async def asyncSetUp(self): await super().asyncSetUp() - self.server = app.CommentDaemon(config) + self.server = app.CommentDaemon(config, db_file=self.db_file) await self.server.start(host=self.host, port=self.port) self.addCleanup(self.server.stop) async def post_comment(self, **params): return await jsonrpc_post(self.url, 'create_comment', **params) + @staticmethod + def is_valid_message(comment=None, claim_id=None, parent_id=None, + channel_name=None, channel_id=None, signature=None, signing_ts=None): + try: + assert is_valid_base_comment(comment, claim_id, parent_id) + + if channel_name or channel_id or signature or signing_ts: + assert channel_id and channel_name and signature and signing_ts + assert is_valid_channel(channel_id, channel_name) + assert len(signature) == 128 + assert signing_ts.isalnum() + + except Exception: + return False + return True + async def test01CreateCommentNoReply(self): anonymous_test = create_test_comments( ('claim_id', 'channel_id', 'channel_name', 'comment'), @@ -109,13 +122,13 @@ class ServerTest(AsyncioTestCase): claim_id=None ) for test in anonymous_test: - with self.subTest(test='null fields: ' + ', '.join(k for k, v in test.items() if not v)): + with self.subTest(test=test): message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: - self.assertFalse(is_valid_base_comment(**test)) + self.assertFalse(self.is_valid_message(**test)) else: - self.assertTrue(is_valid_base_comment(**test)) + self.assertTrue(self.is_valid_message(**test)) async def test02CreateNamedCommentsNoReply(self): named_test = create_test_comments( @@ -131,24 +144,22 @@ class ServerTest(AsyncioTestCase): message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: - self.assertFalse(is_valid_base_comment(**test)) + self.assertFalse(self.is_valid_message(**test)) else: - self.assertTrue(is_valid_base_comment(**test)) + self.assertTrue(self.is_valid_message(**test)) async def test03CreateAllTestComments(self): test_all = create_test_comments(replace.keys(), **{ k: None for k in replace.keys() }) - test_all.reverse() for test in test_all: - nulls = 'null fields: ' + ', '.join(k for k, v in test.items() if not v) - with self.subTest(test=nulls): + with self.subTest(test=test): message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: - self.assertFalse(is_valid_base_comment(**test, strict=True)) + self.assertFalse(self.is_valid_message(**test)) else: - self.assertTrue(is_valid_base_comment(**test, strict=True)) + self.assertTrue(self.is_valid_message(**test)) async def test04CreateAllReplies(self): claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' @@ -178,37 +189,9 @@ class ServerTest(AsyncioTestCase): message = await self.post_comment(**test) self.assertTrue('result' in message or 'error' in message) if 'error' in message: - self.assertFalse(is_valid_base_comment(**test)) + self.assertFalse(self.is_valid_message(**test)) else: - self.assertTrue(is_valid_base_comment(**test)) - - async def testSlackWebhook(self): - claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' - channel_name = '@name' - channel_id = fake.sha1() - signature = '{}'*64 - signing_ts = '1234' - - base = await self.post_comment( - channel_name=channel_name, - channel_id=channel_id, - comment='duplicate', - claim_id=claim_id, - signing_ts=signing_ts, - signature=signature - ) - - comment_id = base['result']['comment_id'] - - with self.subTest(test=comment_id): - await self.post_comment( - channel_name=channel_name, - channel_id=channel_id, - comment='duplicate', - claim_id=claim_id, - signing_ts=signing_ts, - signature=signature - ) + self.assertTrue(self.is_valid_message(**test)) class ListCommentsTest(AsyncioTestCase): @@ -226,8 +209,7 @@ class ListCommentsTest(AsyncioTestCase): super().__init__(*args, **kwargs) self.host = 'localhost' self.port = 5931 - config['mode'] = 'testing' - config['testing']['file'] = ':memory:' + self.db_file = 'list_test.db' self.claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' self.comment_ids = None @@ -238,6 +220,10 @@ class ListCommentsTest(AsyncioTestCase): async def post_comment(self, **params): return await jsonrpc_post(self.url, 'create_comment', **params) + def tearDown(self) -> None: + print('exit reached') + os.remove(self.db_file) + async def create_lots_of_comments(self, n=23): self.comment_list = [{key: self.replace[key]() for key in self.replace.keys()} for _ in range(23)] for comment in self.comment_list: @@ -247,7 +233,7 @@ class ListCommentsTest(AsyncioTestCase): async def asyncSetUp(self): await super().asyncSetUp() - self.server = app.CommentDaemon(config) + self.server = app.CommentDaemon(config, db_file=self.db_file) await self.server.start(self.host, self.port) self.addCleanup(self.server.stop) diff --git a/test/testcase.py b/test/testcase.py index 415041d..ddd8b44 100644 --- a/test/testcase.py +++ b/test/testcase.py @@ -1,38 +1,12 @@ import os import pathlib import unittest +from asyncio.runners import _cancel_all_tasks # type: ignore from unittest.case import _Outcome import asyncio -from asyncio.runners import _cancel_all_tasks # type: ignore -from peewee import * -from src.database.models import Channel, Comment - - -test_db = SqliteDatabase(':memory:') - - -MODELS = [Channel, Comment] - - -class DatabaseTestCase(unittest.TestCase): - def __init__(self, methodName='DatabaseTest'): - super().__init__(methodName) - - def setUp(self) -> None: - super().setUp() - test_db.bind(MODELS, bind_refs=False, bind_backrefs=False) - - test_db.connect() - test_db.create_tables(MODELS) - - def tearDown(self) -> None: - # drop tables for next test - test_db.drop_tables(MODELS) - - # close connection - test_db.close() +from src.database.queries import obtain_connection, setup_database class AsyncioTestCase(unittest.TestCase): @@ -143,3 +117,21 @@ class AsyncioTestCase(unittest.TestCase): self.loop.run_until_complete(maybe_coroutine) +class DatabaseTestCase(unittest.TestCase): + db_file = 'test.db' + + def __init__(self, methodName='DatabaseTest'): + super().__init__(methodName) + if pathlib.Path(self.db_file).exists(): + os.remove(self.db_file) + + def setUp(self) -> None: + super().setUp() + setup_database(self.db_file) + self.conn = obtain_connection(self.db_file) + self.addCleanup(self.conn.close) + self.addCleanup(os.remove, self.db_file) + + def tearDown(self) -> None: + self.conn.close() +