Merge pull request #39 from lbryio/orm-rewrite

Replace database methods with peewee ORM
This commit is contained in:
Oleg Silkin 2020-04-03 17:40:42 -04:00 committed by GitHub
commit 80f77218f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 783 additions and 538 deletions

4
.gitignore vendored
View file

@ -1 +1,3 @@
config/conf.json config/conf.yml
docker-compose.yml

View file

@ -1,7 +1,28 @@
sudo: required sudo: required
language: python language: python
dist: xenial dist: xenial
python: 3.7 python: 3.8
# for docker-compose
services:
- docker
# to avoid "well it works on my computer" moments
env:
- DOCKER_COMPOSE_VERSION=1.25.4
before_install:
# ensure docker-compose version is as specified above
- sudo rm /usr/local/bin/docker-compose
- curl -L https://github.com/docker/compose/releases/download/${DOCKER_COMPOSE_VERSION}/docker-compose-`uname -s`-`uname -m` > docker-compose
- chmod +x docker-compose
- sudo mv docker-compose /usr/local/bin
# refresh docker images
- sudo apt-get update
before_script:
- docker-compose up -d
jobs: jobs:
include: include:

30
config/conf.yml Normal file
View file

@ -0,0 +1,30 @@
---
# for running local-tests without using MySQL for now
testing:
database: sqlite
file: comments.db
pragmas:
journal_mode: wal
cache_size: 64000
foreign_keys: 0
ignore_check_constraints: 1
synchronous: 0
# actual database should be running MySQL
production:
database: mysql
name: lbry
user: lbry
password: lbry
host: localhost
port: 3306
mode: production
logging:
format: "%(asctime)s | %(levelname)s | %(name)s | %(module)s.%(funcName)s:%(lineno)d
| %(message)s"
aiohttp_format: "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
datefmt: "%Y-%m-%d %H:%M:%S"
host: localhost
port: 5921
lbrynet: http://localhost:5279

24
docker-compose.yml Normal file
View file

@ -0,0 +1,24 @@
version: "3.7"
services:
###########
## MySQL ##
###########
mysql:
image: mysql/mysql-server:5.7.27
restart: "no"
ports:
- "3306:3306"
environment:
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- MYSQL_DATABASE=lbry
- MYSQL_USER=lbry
- MYSQL_PASSWORD=lbry
- MYSQL_LOG_CONSOLE=true
#############
## Adminer ##
#############
adminer:
image: adminer
restart: always
ports:
- 8080:8080

View file

@ -14,15 +14,17 @@ setup(
data_files=[('config', ['config/conf.json',])], data_files=[('config', ['config/conf.json',])],
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'pymysql',
'pyyaml',
'Faker>=1.0.7', 'Faker>=1.0.7',
'asyncio>=3.4.3', 'asyncio',
'aiohttp==3.5.4', 'aiohttp',
'aiojobs==0.2.2', 'aiojobs',
'ecdsa>=0.13.3', 'ecdsa>=0.13.3',
'cryptography==2.5', 'cryptography==2.5',
'aiosqlite==0.10.0',
'PyNaCl>=1.3.0', 'PyNaCl>=1.3.0',
'requests', 'requests',
'cython', 'cython',
'peewee'
] ]
) )

217
src/database/models.py Normal file
View file

@ -0,0 +1,217 @@
import json
import time
import logging
import math
import typing
from peewee import *
import nacl.hash
from src.server.validation import is_valid_base_comment
from src.misc import clean
class Channel(Model):
claim_id = CharField(column_name='ClaimId', primary_key=True, max_length=40)
name = CharField(column_name='Name', max_length=256)
class Meta:
table_name = 'CHANNEL'
class Comment(Model):
comment = CharField(column_name='Body', max_length=2000)
channel = ForeignKeyField(
backref='comments',
column_name='ChannelId',
field='claim_id',
model=Channel,
null=True
)
comment_id = CharField(column_name='CommentId', primary_key=True, max_length=64)
is_hidden = BooleanField(column_name='IsHidden', constraints=[SQL("DEFAULT 0")])
claim_id = CharField(max_length=40, column_name='LbryClaimId')
parent = ForeignKeyField(
column_name='ParentId',
field='comment_id',
model='self',
null=True,
backref='replies'
)
signature = CharField(max_length=128, column_name='Signature', null=True, unique=True)
signing_ts = TextField(column_name='SigningTs', null=True)
timestamp = IntegerField(column_name='Timestamp')
class Meta:
table_name = 'COMMENT'
indexes = (
(('channel', 'comment_id'), False),
(('claim_id', 'comment_id'), False),
)
FIELDS = {
'comment': Comment.comment,
'comment_id': Comment.comment_id,
'claim_id': Comment.claim_id,
'timestamp': Comment.timestamp,
'signature': Comment.signature,
'signing_ts': Comment.signing_ts,
'is_hidden': Comment.is_hidden,
'parent_id': Comment.parent.alias('parent_id'),
'channel_id': Channel.claim_id.alias('channel_id'),
'channel_name': Channel.name.alias('channel_name'),
'channel_url': ('lbry://' + Channel.name + '#' + Channel.claim_id).alias('channel_url')
}
def comment_list(claim_id: str = None, parent_id: str = None,
top_level: bool = False, exclude_mode: str = None,
page: int = 1, page_size: int = 50, expressions=None,
select_fields: list = None, exclude_fields: list = None) -> dict:
fields = FIELDS.keys()
if exclude_fields:
fields -= set(exclude_fields)
if select_fields:
fields &= set(select_fields)
attributes = [FIELDS[field] for field in fields]
query = Comment.select(*attributes)
# todo: allow this process to be more automated, so it can just be an expression
if claim_id:
query = query.where(Comment.claim_id == claim_id)
if top_level:
query = query.where(Comment.parent.is_null())
if parent_id:
query = query.where(Comment.ParentId == parent_id)
if exclude_mode:
show_hidden = exclude_mode.lower() == 'hidden'
query = query.where((Comment.is_hidden == show_hidden))
if expressions:
query = query.where(expressions)
total = query.count()
query = (query
.join(Channel, JOIN.LEFT_OUTER)
.order_by(Comment.timestamp.desc())
.paginate(page, page_size))
items = [clean(item) for item in query.dicts()]
# has_hidden_comments is deprecated
data = {
'page': page,
'page_size': page_size,
'total_pages': math.ceil(total / page_size),
'total_items': total,
'items': items,
'has_hidden_comments': exclude_mode is not None and exclude_mode == 'hidden',
}
return data
def get_comment(comment_id: str) -> dict:
try:
comment = comment_list(expressions=(Comment.comment_id == comment_id), page_size=1).get('items').pop()
except IndexError:
raise ValueError(f'Comment does not exist with id {comment_id}')
else:
return comment
def create_comment_id(comment: str, channel_id: str, timestamp: int):
# We convert the timestamp from seconds into minutes
# to prevent spammers from commenting the same BS everywhere.
nearest_minute = str(math.floor(timestamp / 60))
# don't use claim_id for the comment_id anymore so comments
# are not unique to just one claim
prehash = b':'.join([
comment.encode(),
channel_id.encode(),
nearest_minute.encode()
])
return nacl.hash.sha256(prehash).decode()
def create_comment(comment: str = None, claim_id: str = None,
parent_id: str = None, channel_id: str = None,
channel_name: str = None, signature: str = None,
signing_ts: str = None) -> dict:
if not is_valid_base_comment(
comment=comment,
claim_id=claim_id,
parent_id=parent_id,
channel_id=channel_id,
channel_name=channel_name,
signature=signature,
signing_ts=signing_ts
):
raise ValueError('Invalid Parameters given for comment')
channel, _ = Channel.get_or_create(name=channel_name, claim_id=channel_id)
if parent_id and not claim_id:
parent: Comment = Comment.get_by_id(parent_id)
claim_id = parent.claim_id
timestamp = int(time.time())
comment_id = create_comment_id(comment, channel_id, timestamp)
new_comment = Comment.create(
claim_id=claim_id,
comment_id=comment_id,
comment=comment,
parent=parent_id,
channel=channel,
signature=signature,
signing_ts=signing_ts,
timestamp=timestamp
)
return get_comment(new_comment.comment_id)
def delete_comment(comment_id: str) -> bool:
try:
comment: Comment = Comment.get_by_id(comment_id)
except DoesNotExist as e:
raise ValueError from e
else:
return 0 < comment.delete_instance(True, delete_nullable=True)
def edit_comment(comment_id: str, new_comment: str, new_sig: str, new_ts: str) -> bool:
try:
comment: Comment = Comment.get_by_id(comment_id)
except DoesNotExist as e:
raise ValueError from e
else:
comment.comment = new_comment
comment.signature = new_sig
comment.signing_ts = new_ts
# todo: add a 'last-modified' timestamp
comment.timestamp = int(time.time())
return comment.save() > 0
def set_hidden_flag(comment_ids: typing.List[str], hidden=True) -> bool:
# sets `is_hidden` flag for all `comment_ids` to the `hidden` param
update = (Comment
.update(is_hidden=hidden)
.where(Comment.comment_id.in_(comment_ids)))
return update.execute() > 0
if __name__ == '__main__':
logger = logging.getLogger('peewee')
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG)
comments = comment_list(
page_size=20,
expressions=((Comment.timestamp < 1583272089) &
(Comment.claim_id ** '420%'))
)
print(json.dumps(comments, indent=4))

View file

@ -1,290 +0,0 @@
import atexit
import logging
import math
import sqlite3
import time
import typing
import nacl.hash
from src.database.schema import CREATE_TABLES_QUERY
logger = logging.getLogger(__name__)
SELECT_COMMENTS_ON_CLAIMS = """
SELECT comment, comment_id, claim_id, timestamp, is_hidden, parent_id,
channel_name, channel_id, channel_url, signature, signing_ts
FROM COMMENTS_ON_CLAIMS
"""
def clean(thing: dict) -> dict:
if 'is_hidden' in thing:
thing.update({'is_hidden': bool(thing['is_hidden'])})
return {k: v for k, v in thing.items() if v is not None}
def obtain_connection(filepath: str = None, row_factory: bool = True):
connection = sqlite3.connect(filepath)
if row_factory:
connection.row_factory = sqlite3.Row
return connection
def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None,
page: int = 1, page_size: int = 50, top_level=False):
with conn:
if top_level:
results = [clean(dict(row)) for row in conn.execute(
SELECT_COMMENTS_ON_CLAIMS + " WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ?",
(claim_id, page_size, page_size * (page - 1))
)]
count = conn.execute(
"SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id IS NULL",
(claim_id,)
)
elif parent_id is None:
results = [clean(dict(row)) for row in conn.execute(
SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? LIMIT ? OFFSET ? ",
(claim_id, page_size, page_size * (page - 1))
)]
count = conn.execute(
"SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ?",
(claim_id,)
)
else:
results = [clean(dict(row)) for row in conn.execute(
SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ? ",
(claim_id, parent_id, page_size, page_size * (page - 1))
)]
count = conn.execute(
"SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id = ?",
(claim_id, parent_id)
)
count = tuple(count.fetchone())[0]
return {
'items': results,
'page': page,
'page_size': page_size,
'total_pages': math.ceil(count / page_size),
'total_items': count,
'has_hidden_comments': claim_has_hidden_comments(conn, claim_id)
}
def get_claim_hidden_comments(conn: sqlite3.Connection, claim_id: str, hidden=True, page=1, page_size=50):
with conn:
results = conn.execute(
SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? AND is_hidden IS ? LIMIT ? OFFSET ?",
(claim_id, hidden, page_size, page_size * (page - 1))
)
count = conn.execute(
"SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND is_hidden IS ?", (claim_id, hidden)
)
results = [clean(dict(row)) for row in results.fetchall()]
count = tuple(count.fetchone())[0]
return {
'items': results,
'page': page,
'page_size': page_size,
'total_pages': math.ceil(count / page_size),
'total_items': count,
'has_hidden_comments': claim_has_hidden_comments(conn, claim_id)
}
def claim_has_hidden_comments(conn, claim_id):
with conn:
result = conn.execute(
"SELECT COUNT(DISTINCT is_hidden) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND is_hidden IS 1",
(claim_id,)
)
return bool(tuple(result.fetchone())[0])
def insert_comment(conn: sqlite3.Connection, claim_id: str, comment: str,
channel_id: str = None, signature: str = None, signing_ts: str = None, **extra) -> str:
timestamp = int(time.time())
prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),))
comment_id = nacl.hash.sha256(prehash).decode()
with conn:
curs = conn.execute(
"""
INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body, ParentId,
Timestamp, Signature, SigningTs, IsHidden)
VALUES (:comment_id, :claim_id, :channel_id, :comment, NULL,
:timestamp, :signature, :signing_ts, 0) """,
{
'comment_id': comment_id,
'claim_id': claim_id,
'channel_id': channel_id,
'comment': comment,
'timestamp': timestamp,
'signature': signature,
'signing_ts': signing_ts
}
)
logging.info('attempted to insert comment with comment_id [%s] | %d rows affected', comment_id, curs.rowcount)
return comment_id
def insert_reply(conn: sqlite3.Connection, comment: str, parent_id: str,
channel_id: str = None, signature: str = None,
signing_ts: str = None, **extra) -> str:
timestamp = int(time.time())
prehash = b':'.join((parent_id.encode(), comment.encode(), str(timestamp).encode(),))
comment_id = nacl.hash.sha256(prehash).decode()
with conn:
curs = conn.execute(
"""
INSERT INTO COMMENT
(CommentId, LbryClaimId, ChannelId, Body, ParentId, Signature, Timestamp, SigningTs, IsHidden)
SELECT :comment_id, LbryClaimId, :channel_id, :comment, :parent_id, :signature, :timestamp, :signing_ts, 0
FROM COMMENT WHERE CommentId = :parent_id
""", {
'comment_id': comment_id,
'parent_id': parent_id,
'timestamp': timestamp,
'comment': comment,
'channel_id': channel_id,
'signature': signature,
'signing_ts': signing_ts
}
)
logging.info('attempted to insert reply with comment_id [%s] | %d rows affected', comment_id, curs.rowcount)
return comment_id
def get_comment_or_none(conn: sqlite3.Connection, comment_id: str) -> dict:
with conn:
curry = conn.execute(SELECT_COMMENTS_ON_CLAIMS + "WHERE comment_id = ?", (comment_id,))
thing = curry.fetchone()
return clean(dict(thing)) if thing else None
def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50):
""" Just return a list of the comment IDs that are associated with the given claim_id.
If get_all is specified then it returns all the IDs, otherwise only the IDs at that level.
if parent_id is left null then it only returns the top level comments.
For pagination the parameters are:
get_all XOR (page_size + page)
"""
with conn:
if parent_id is None:
curs = conn.execute("""
SELECT comment_id FROM COMMENTS_ON_CLAIMS
WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ?
""", (claim_id, page_size, page_size * abs(page - 1),)
)
else:
curs = conn.execute("""
SELECT comment_id FROM COMMENTS_ON_CLAIMS
WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ?
""", (claim_id, parent_id, page_size, page_size * abs(page - 1),)
)
return [tuple(row)[0] for row in curs.fetchall()]
def get_comments_by_id(conn, comment_ids: typing.Union[list, tuple]) -> typing.Union[list, None]:
""" Returns a list containing the comment data associated with each ID within the list"""
# format the input, under the assumption that the
placeholders = ', '.join('?' for _ in comment_ids)
with conn:
return [clean(dict(row)) for row in conn.execute(
SELECT_COMMENTS_ON_CLAIMS + f'WHERE comment_id IN ({placeholders})',
tuple(comment_ids)
)]
def delete_comment_by_id(conn: sqlite3.Connection, comment_id: str) -> bool:
with conn:
curs = conn.execute("DELETE FROM COMMENT WHERE CommentId = ?", (comment_id,))
return bool(curs.rowcount)
def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
with conn:
curs = conn.execute('INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)', (channel_id, channel_name))
return bool(curs.rowcount)
def get_channel_id_from_comment_id(conn: sqlite3.Connection, comment_id: str):
with conn:
channel = conn.execute(
"SELECT channel_id, channel_name FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?", (comment_id,)
).fetchone()
return dict(channel) if channel else {}
def get_claim_ids_from_comment_ids(conn: sqlite3.Connection, comment_ids: list):
with conn:
cids = conn.execute(
f""" SELECT CommentId as comment_id, LbryClaimId AS claim_id FROM COMMENT
WHERE CommentId IN ({', '.join('?' for _ in comment_ids)}) """,
tuple(comment_ids)
)
return {row['comment_id']: row['claim_id'] for row in cids.fetchall()}
def hide_comments_by_id(conn: sqlite3.Connection, comment_ids: list) -> bool:
with conn:
curs = conn.cursor()
curs.executemany(
"UPDATE COMMENT SET IsHidden = 1 WHERE CommentId = ?",
[[c] for c in comment_ids]
)
return bool(curs.rowcount)
def edit_comment_by_id(conn: sqlite3.Connection, comment_id: str, comment: str,
signature: str, signing_ts: str) -> bool:
with conn:
curs = conn.execute(
"""
UPDATE COMMENT
SET Body = :comment, Signature = :signature, SigningTs = :signing_ts
WHERE CommentId = :comment_id
""",
{
'comment': comment,
'signature': signature,
'signing_ts': signing_ts,
'comment_id': comment_id
})
logger.info("updated comment with `comment_id`: %s", comment_id)
return bool(curs.rowcount)
class DatabaseWriter(object):
_writer = None
def __init__(self, db_file):
if not DatabaseWriter._writer:
self.conn = obtain_connection(db_file)
DatabaseWriter._writer = self
atexit.register(self.cleanup)
logging.info('Database writer has been created at %s', repr(self))
else:
logging.warning('Someone attempted to insantiate DatabaseWriter')
raise TypeError('Database Writer already exists!')
def cleanup(self):
logging.info('Cleaning up database writer')
self.conn.close()
DatabaseWriter._writer = None
@property
def connection(self):
return self.conn
def setup_database(db_path):
with sqlite3.connect(db_path) as conn:
conn.executescript(CREATE_TABLES_QUERY)
def backup_database(conn: sqlite3.Connection, back_fp):
with sqlite3.connect(back_fp) as back:
conn.backup(back)

View file

@ -7,7 +7,7 @@ from src.server.validation import is_valid_base_comment
from src.server.validation import is_valid_credential_input from src.server.validation import is_valid_credential_input
from src.server.validation import validate_signature_from_claim from src.server.validation import validate_signature_from_claim
from src.server.validation import body_is_valid from src.server.validation import body_is_valid
from src.server.misc import get_claim_from_id from src.misc import get_claim_from_id
from src.server.external import send_notifications from src.server.external import send_notifications
from src.server.external import send_notification from src.server.external import send_notification
import src.database.queries as db import src.database.queries as db
@ -84,13 +84,17 @@ async def hide_comments(app, pieces: list) -> list:
# TODO: Amortize this process # TODO: Amortize this process
claims = {} claims = {}
comments_to_hide = [] comments_to_hide = []
# go through a list of dict objects
for p in pieces: for p in pieces:
# maps the comment_id from the piece to a claim_id
claim_id = comment_cids[p['comment_id']] claim_id = comment_cids[p['comment_id']]
# resolve the claim from its id
if claim_id not in claims: if claim_id not in claims:
claim = await get_claim_from_id(app, claim_id) claim = await get_claim_from_id(app, claim_id)
if claim: if claim:
claims[claim_id] = claim claims[claim_id] = claim
# get the claim's signing channel, then use it to validate the hidden comment
channel = claims[claim_id].get('signing_channel') channel = claims[claim_id].get('signing_channel')
if validate_signature_from_claim(channel, p['signature'], p['signing_ts'], p['comment_id']): if validate_signature_from_claim(channel, p['signature'], p['signing_ts'], p['comment_id']):
comments_to_hide.append(p) comments_to_hide.append(p)

7
src/definitions.py Normal file
View file

@ -0,0 +1,7 @@
import os
SRC_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(SRC_DIR)
CONFIG_FILE = os.path.join(ROOT_DIR, 'config', 'conf.yml')
LOGGING_DIR = os.path.join(ROOT_DIR, 'logs')
DATABASE_DIR = os.path.join(ROOT_DIR, 'database')

View file

@ -1,13 +1,20 @@
import argparse import argparse
import json
import yaml
import logging import logging
import logging.config import logging.config
import os
import sys import sys
from src.server.app import run_app from src.server.app import run_app
from src.settings import config from src.definitions import LOGGING_DIR, CONFIG_FILE, DATABASE_DIR
def config_logging_from_settings(conf): def setup_logging_from_config(conf: dict):
# set the logging directory here from the settings file
if not os.path.exists(LOGGING_DIR):
os.mkdir(LOGGING_DIR)
_config = { _config = {
"version": 1, "version": 1,
"disable_existing_loggers": False, "disable_existing_loggers": False,
@ -32,7 +39,7 @@ def config_logging_from_settings(conf):
"level": "DEBUG", "level": "DEBUG",
"formatter": "standard", "formatter": "standard",
"class": "logging.handlers.RotatingFileHandler", "class": "logging.handlers.RotatingFileHandler",
"filename": conf['path']['debug_log'], "filename": os.path.join(LOGGING_DIR, 'debug.log'),
"maxBytes": 10485760, "maxBytes": 10485760,
"backupCount": 5 "backupCount": 5
}, },
@ -40,7 +47,7 @@ def config_logging_from_settings(conf):
"level": "ERROR", "level": "ERROR",
"formatter": "standard", "formatter": "standard",
"class": "logging.handlers.RotatingFileHandler", "class": "logging.handlers.RotatingFileHandler",
"filename": conf['path']['error_log'], "filename": os.path.join(LOGGING_DIR, 'error.log'),
"maxBytes": 10485760, "maxBytes": 10485760,
"backupCount": 5 "backupCount": 5
}, },
@ -48,7 +55,7 @@ def config_logging_from_settings(conf):
"level": "NOTSET", "level": "NOTSET",
"formatter": "aiohttp", "formatter": "aiohttp",
"class": "logging.handlers.RotatingFileHandler", "class": "logging.handlers.RotatingFileHandler",
"filename": conf['path']['server_log'], "filename": os.path.join(LOGGING_DIR, 'server.log'),
"maxBytes": 10485760, "maxBytes": 10485760,
"backupCount": 5 "backupCount": 5
} }
@ -70,15 +77,42 @@ def config_logging_from_settings(conf):
logging.config.dictConfig(_config) logging.config.dictConfig(_config)
def get_config(filepath):
with open(filepath, 'r') as cfile:
config = yaml.load(cfile, Loader=yaml.Loader)
return config
def setup_db_from_config(config: dict):
mode = config['mode']
if config[mode]['database'] == 'sqlite':
if not os.path.exists(DATABASE_DIR):
os.mkdir(DATABASE_DIR)
config[mode]['db_file'] = os.path.join(
DATABASE_DIR, config[mode]['name']
)
def main(argv=None): def main(argv=None):
argv = argv or sys.argv[1:] argv = argv or sys.argv[1:]
parser = argparse.ArgumentParser(description='LBRY Comment Server') parser = argparse.ArgumentParser(description='LBRY Comment Server')
parser.add_argument('--port', type=int) parser.add_argument('--port', type=int)
parser.add_argument('--config', type=str)
parser.add_argument('--mode', type=str)
args = parser.parse_args(argv) args = parser.parse_args(argv)
config_logging_from_settings(config)
config = get_config(CONFIG_FILE) if not args.config else args.config
setup_logging_from_config(config)
if args.mode:
config['mode'] = args.mode
setup_db_from_config(config)
if args.port: if args.port:
config['port'] = args.port config['port'] = args.port
config_logging_from_settings(config)
run_app(config) run_app(config)

View file

@ -20,3 +20,7 @@ def clean_input_params(kwargs: dict):
kwargs[k] = v.strip() kwargs[k] = v.strip()
if k in ID_LIST: if k in ID_LIST:
kwargs[k] = v.lower() kwargs[k] = v.lower()
def clean(thing: dict) -> dict:
return {k: v for k, v in thing.items() if v is not None}

View file

@ -1,7 +1,6 @@
# cython: language_level=3 # cython: language_level=3
import asyncio import asyncio
import logging import logging
import pathlib
import signal import signal
import time import time
@ -9,81 +8,67 @@ import aiojobs
import aiojobs.aiohttp import aiojobs.aiohttp
from aiohttp import web from aiohttp import web
from src.database.queries import obtain_connection, DatabaseWriter from peewee import *
from src.database.queries import setup_database, backup_database
from src.server.handles import api_endpoint, get_api_endpoint from src.server.handles import api_endpoint, get_api_endpoint
from src.database.models import Comment, Channel
MODELS = [Comment, Channel]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def setup_db_schema(app): def setup_database(app):
if not pathlib.Path(app['db_path']).exists(): config = app['config']
logger.info(f'Setting up schema in {app["db_path"]}') mode = config['mode']
setup_database(app['db_path'])
else:
logger.info(f'Database already exists in {app["db_path"]}, skipping setup')
# switch between Database objects
if config[mode]['database'] == 'mysql':
app['db'] = MySQLDatabase(
database=config[mode]['name'],
user=config[mode]['user'],
host=config[mode]['host'],
password=config[mode]['password'],
port=config[mode]['port'],
)
elif config[mode]['database'] == 'sqlite':
app['db'] = SqliteDatabase(
config[mode]['file'],
pragmas=config[mode]['pragmas']
)
async def database_backup_routine(app): # bind the Model list to the database
try: app['db'].bind(MODELS, bind_refs=False, bind_backrefs=False)
while True:
await asyncio.sleep(app['config']['backup_int'])
with app['reader'] as conn:
logger.debug('backing up database')
backup_database(conn, app['backup'])
except asyncio.CancelledError:
pass
async def start_background_tasks(app): async def start_background_tasks(app):
# Reading the DB app['db'].connect()
app['reader'] = obtain_connection(app['db_path'], True) app['db'].create_tables(MODELS)
app['waitful_backup'] = asyncio.create_task(database_backup_routine(app))
# Scheduler to prevent multiple threads from writing to DB simulataneously
app['comment_scheduler'] = await aiojobs.create_scheduler(limit=1, pending_limit=0)
app['db_writer'] = DatabaseWriter(app['db_path'])
app['writer'] = app['db_writer'].connection
# for requesting to external and internal APIs # for requesting to external and internal APIs
app['webhooks'] = await aiojobs.create_scheduler(pending_limit=0) app['webhooks'] = await aiojobs.create_scheduler(pending_limit=0)
async def close_database_connections(app): async def close_database_connections(app):
logger.info('Ending background backup loop') app['db'].close()
app['waitful_backup'].cancel()
await app['waitful_backup']
app['reader'].close()
app['writer'].close()
app['db_writer'].cleanup()
async def close_schedulers(app): async def close_schedulers(app):
logger.info('Closing comment_scheduler')
await app['comment_scheduler'].close()
logger.info('Closing scheduler for webhook requests') logger.info('Closing scheduler for webhook requests')
await app['webhooks'].close() await app['webhooks'].close()
class CommentDaemon: class CommentDaemon:
def __init__(self, config, db_file=None, backup=None, **kwargs): def __init__(self, config, **kwargs):
app = web.Application() app = web.Application()
app['config'] = config
# configure the config # configure the config
app['config'] = config self.config = config
self.config = app['config'] self.host = config['host']
self.port = config['port']
# configure the db file setup_database(app)
if db_file:
app['db_path'] = db_file
app['backup'] = backup
else:
app['db_path'] = config['path']['database']
app['backup'] = backup or (app['db_path'] + '.backup')
# configure the order of tasks to run during app lifetime # configure the order of tasks to run during app lifetime
app.on_startup.append(setup_db_schema)
app.on_startup.append(start_background_tasks) app.on_startup.append(start_background_tasks)
app.on_shutdown.append(close_schedulers) app.on_shutdown.append(close_schedulers)
app.on_cleanup.append(close_database_connections) app.on_cleanup.append(close_database_connections)
@ -105,20 +90,19 @@ class CommentDaemon:
await self.app_runner.setup() await self.app_runner.setup()
self.app_site = web.TCPSite( self.app_site = web.TCPSite(
runner=self.app_runner, runner=self.app_runner,
host=host or self.config['host'], host=host or self.host,
port=port or self.config['port'], port=port or self.port,
) )
await self.app_site.start() await self.app_site.start()
logger.info(f'Comment Server is running on {self.config["host"]}:{self.config["port"]}') logger.info(f'Comment Server is running on {self.host}:{self.port}')
async def stop(self): async def stop(self):
await self.app_runner.shutdown() await self.app_runner.shutdown()
await self.app_runner.cleanup() await self.app_runner.cleanup()
def run_app(config, db_file=None): def run_app(config):
comment_app = CommentDaemon(config=config, db_file=db_file, close_timeout=5.0) comment_app = CommentDaemon(config=config)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
def __exit(): def __exit():

View file

@ -1,16 +1,23 @@
import asyncio import asyncio
import logging import logging
import time import time
import typing
from aiohttp import web from aiohttp import web
from aiojobs.aiohttp import atomic from aiojobs.aiohttp import atomic
from peewee import DoesNotExist
import src.database.queries as db from src.server.validation import validate_signature_from_claim
from src.database.writes import abandon_comment, create_comment from src.misc import clean_input_params, get_claim_from_id
from src.database.writes import hide_comments
from src.database.writes import edit_comment
from src.server.misc import clean_input_params
from src.server.errors import make_error, report_error from src.server.errors import make_error, report_error
from src.database.models import Comment, Channel
from src.database.models import get_comment
from src.database.models import comment_list
from src.database.models import create_comment
from src.database.models import edit_comment
from src.database.models import delete_comment
from src.database.models import set_hidden_flag
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,51 +27,194 @@ def ping(*args):
return 'pong' return 'pong'
def handle_get_channel_from_comment_id(app, kwargs: dict): def handle_get_channel_from_comment_id(app: web.Application, comment_id: str) -> dict:
return db.get_channel_id_from_comment_id(app['reader'], **kwargs) comment = get_comment(comment_id)
return {
'channel_id': comment['channel_id'],
'channel_name': comment['channel_name']
}
def handle_get_comment_ids(app, kwargs): def handle_get_comment_ids(
return db.get_comment_ids(app['reader'], **kwargs) app: web.Application,
claim_id: str,
parent_id: str = None,
page: int = 1,
page_size: int = 50,
flattened=False
) -> dict:
results = comment_list(
claim_id=claim_id,
parent_id=parent_id,
top_level=(parent_id is None),
page=page,
page_size=page_size,
select_fields=['comment_id', 'parent_id']
)
if flattened:
results.update({
'items': [item['comment_id'] for item in results['items']],
'replies': [(item['comment_id'], item.get('parent_id'))
for item in results['items']]
})
return results
def handle_get_claim_comments(app, kwargs): def handle_get_comments_by_id(
return db.get_claim_comments(app['reader'], **kwargs) app: web.Application,
comment_ids: typing.Union[list, tuple]
) -> dict:
expression = Comment.comment_id.in_(comment_ids)
return comment_list(expressions=expression, page_size=len(comment_ids))
def handle_get_comments_by_id(app, kwargs): def handle_get_claim_comments(
return db.get_comments_by_id(app['reader'], **kwargs) app: web.Application,
claim_id: str,
parent_id: str = None,
page: int = 1,
page_size: int = 50,
top_level: bool = False
) -> dict:
return comment_list(
claim_id=claim_id,
parent_id=parent_id,
page=page,
page_size=page_size,
top_level=top_level
)
def handle_get_claim_hidden_comments(app, kwargs): def handle_get_claim_hidden_comments(
return db.get_claim_hidden_comments(app['reader'], **kwargs) app: web.Application,
claim_id: str,
hidden: bool,
page: int = 1,
page_size: int = 50,
) -> dict:
exclude = 'hidden' if hidden else 'visible'
return comment_list(
claim_id=claim_id,
exclude_mode=exclude,
page=page,
page_size=page_size
)
async def handle_abandon_comment(app, params): async def handle_abandon_comment(
return {'abandoned': await abandon_comment(app, **params)} app: web.Application,
comment_id: str,
signature: str,
signing_ts: str,
**kwargs,
) -> dict:
comment = get_comment(comment_id)
try:
channel = await get_claim_from_id(app, comment['channel_id'])
except DoesNotExist:
raise ValueError('Could not find a channel associated with the given comment')
else:
if not validate_signature_from_claim(channel, signature, signing_ts, comment_id):
raise ValueError('Abandon signature could not be validated')
with app['db'].atomic():
return {
'abandoned': delete_comment(comment_id)
}
async def handle_hide_comments(app, params): async def handle_hide_comments(app: web.Application, pieces: list, hide: bool = True) -> dict:
return {'hidden': await hide_comments(app, **params)} # let's get all the distinct claim_ids from the list of comment_ids
pieces_by_id = {p['comment_id']: p for p in pieces}
comment_ids = list(pieces_by_id.keys())
comments = (Comment
.select(Comment.comment_id, Comment.claim_id)
.where(Comment.comment_id.in_(comment_ids))
.tuples())
# resolve the claims and map them to their corresponding comment_ids
claims = {}
for comment_id, claim_id in comments:
try:
# try and resolve the claim, if fails then we mark it as null
# and remove the associated comment from the pieces
if claim_id not in claims:
claims[claim_id] = await get_claim_from_id(app, claim_id)
# try to get a public key to validate
if claims[claim_id] is None or 'signing_channel' not in claims[claim_id]:
raise ValueError(f'could not get signing channel from claim_id: {claim_id}')
# try to validate signature
else:
channel = claims[claim_id]['signing_channel']
piece = pieces_by_id[comment_id]
is_valid_signature = validate_signature_from_claim(
claim=channel,
signature=piece['signature'],
signing_ts=piece['signing_ts'],
data=piece['comment_id']
)
if not is_valid_signature:
raise ValueError(f'could not validate signature on comment_id: {comment_id}')
except ValueError:
# remove the piece from being hidden
pieces_by_id.pop(comment_id)
# remaining items in pieces_by_id have been able to successfully validate
with app['db'].atomic():
set_hidden_flag(list(pieces_by_id.keys()), hidden=hide)
query = Comment.select().where(Comment.comment_id.in_(comment_ids)).objects()
result = {
'hidden': [c.comment_id for c in query if c.is_hidden],
'visible': [c.comment_id for c in query if not c.is_hidden],
}
return result
async def handle_edit_comment(app, params): async def handle_edit_comment(app, comment: str = None, comment_id: str = None,
if await edit_comment(app, **params): signature: str = None, signing_ts: str = None, **params) -> dict:
return db.get_comment_or_none(app['reader'], params['comment_id']) current = get_comment(comment_id)
channel_claim = await get_claim_from_id(app, current['channel_id'])
if not validate_signature_from_claim(channel_claim, signature, signing_ts, comment):
raise ValueError('Signature could not be validated')
with app['db'].atomic():
if not edit_comment(comment_id, comment, signature, signing_ts):
raise ValueError('Comment could not be edited')
return get_comment(comment_id)
# TODO: retrieve stake amounts for each channel & store in db
def handle_create_comment(app, comment: str = None, claim_id: str = None,
parent_id: str = None, channel_id: str = None, channel_name: str = None,
signature: str = None, signing_ts: str = None) -> dict:
with app['db'].atomic():
return create_comment(
comment=comment,
claim_id=claim_id,
parent_id=parent_id,
channel_id=channel_id,
channel_name=channel_name,
signature=signature,
signing_ts=signing_ts
)
METHODS = { METHODS = {
'ping': ping, 'ping': ping,
'get_claim_comments': handle_get_claim_comments, 'get_claim_comments': handle_get_claim_comments, # this gets used
'get_claim_hidden_comments': handle_get_claim_hidden_comments, 'get_claim_hidden_comments': handle_get_claim_hidden_comments, # this gets used
'get_comment_ids': handle_get_comment_ids, 'get_comment_ids': handle_get_comment_ids,
'get_comments_by_id': handle_get_comments_by_id, 'get_comments_by_id': handle_get_comments_by_id, # this gets used
'get_channel_from_comment_id': handle_get_channel_from_comment_id, 'get_channel_from_comment_id': handle_get_channel_from_comment_id, # this gets used
'create_comment': create_comment, 'create_comment': handle_create_comment, # this gets used
'delete_comment': handle_abandon_comment, 'delete_comment': handle_abandon_comment,
'abandon_comment': handle_abandon_comment, 'abandon_comment': handle_abandon_comment, # this gets used
'hide_comments': handle_hide_comments, 'hide_comments': handle_hide_comments, # this gets used
'edit_comment': handle_edit_comment 'edit_comment': handle_edit_comment # this gets used
} }
@ -78,17 +228,19 @@ async def process_json(app, body: dict) -> dict:
start = time.time() start = time.time()
try: try:
if asyncio.iscoroutinefunction(METHODS[method]): if asyncio.iscoroutinefunction(METHODS[method]):
result = await METHODS[method](app, params) result = await METHODS[method](app, **params)
else: else:
result = METHODS[method](app, params) result = METHODS[method](app, **params)
response['result'] = result
except Exception as err: except Exception as err:
logger.exception(f'Got {type(err).__name__}:') logger.exception(f'Got {type(err).__name__}:\n{err}')
if type(err) in (ValueError, TypeError): # param error, not too important if type(err) in (ValueError, TypeError): # param error, not too important
response['error'] = make_error('INVALID_PARAMS', err) response['error'] = make_error('INVALID_PARAMS', err)
else: else:
response['error'] = make_error('INTERNAL', err) response['error'] = make_error('INTERNAL', err)
await app['webhooks'].spawn(report_error(app, err, body)) await app['webhooks'].spawn(report_error(app, err, body))
else:
response['result'] = result
finally: finally:
end = time.time() end = time.time()

View file

@ -51,11 +51,31 @@ def claim_id_is_valid(claim_id: str) -> bool:
# default to None so params can be treated as kwargs; param count becomes more manageable # default to None so params can be treated as kwargs; param count becomes more manageable
def is_valid_base_comment(comment: str = None, claim_id: str = None, parent_id: str = None, **kwargs) -> bool: def is_valid_base_comment(
return comment and body_is_valid(comment) and \ comment: str = None,
((claim_id and claim_id_is_valid(claim_id)) or # parentid is used in place of claimid in replies claim_id: str = None,
(parent_id and comment_id_is_valid(parent_id))) \ parent_id: str = None,
and is_valid_credential_input(**kwargs) strict: bool = False,
**kwargs,
) -> bool:
try:
assert comment and body_is_valid(comment)
# strict mode assumes that the parent_id might not exist
if strict:
assert claim_id and claim_id_is_valid(claim_id)
assert parent_id is None or comment_id_is_valid(parent_id)
# non-strict removes reference restrictions
else:
assert claim_id or parent_id
if claim_id:
assert claim_id_is_valid(claim_id)
else:
assert comment_id_is_valid(parent_id)
except AssertionError:
return False
else:
return is_valid_credential_input(**kwargs)
def is_valid_credential_input(channel_id: str = None, channel_name: str = None, def is_valid_credential_input(channel_id: str = None, channel_name: str = None,

View file

@ -1,17 +0,0 @@
# cython: language_level=3
import json
import pathlib
root_dir = pathlib.Path(__file__).parent.parent
config_path = root_dir / 'config' / 'conf.json'
def get_config(filepath):
with open(filepath, 'r') as cfile:
conf = json.load(cfile)
for key, path in conf['path'].items():
conf['path'][key] = str(root_dir / path)
return conf
config = get_config(config_path)

View file

@ -1,18 +1,13 @@
import sqlite3
from random import randint from random import randint
import faker import faker
from faker.providers import internet from faker.providers import internet
from faker.providers import lorem from faker.providers import lorem
from faker.providers import misc from faker.providers import misc
from src.database.queries import get_comments_by_id from src.database.models import create_comment
from src.database.queries import get_comment_ids from src.database.models import delete_comment
from src.database.queries import get_claim_comments from src.database.models import comment_list, get_comment
from src.database.queries import get_claim_hidden_comments from src.database.models import set_hidden_flag
from src.database.writes import create_comment_or_error
from src.database.queries import hide_comments_by_id
from src.database.queries import delete_comment_by_id
from test.testcase import DatabaseTestCase from test.testcase import DatabaseTestCase
fake = faker.Faker() fake = faker.Faker()
@ -27,26 +22,25 @@ class TestDatabaseOperations(DatabaseTestCase):
self.claimId = '529357c3422c6046d3fec76be2358004ba22e340' self.claimId = '529357c3422c6046d3fec76be2358004ba22e340'
def test01NamedComments(self): def test01NamedComments(self):
comment = create_comment_or_error( comment = create_comment(
conn=self.conn,
claim_id=self.claimId, claim_id=self.claimId,
comment='This is a named comment', comment='This is a named comment',
channel_name='@username', channel_name='@username',
channel_id='529357c3422c6046d3fec76be2358004ba22abcd', channel_id='529357c3422c6046d3fec76be2358004ba22abcd',
signature=fake.uuid4(), signature='22'*64,
signing_ts='aaa' signing_ts='aaa'
) )
self.assertIsNotNone(comment) self.assertIsNotNone(comment)
self.assertNotIn('parent_in', comment) self.assertNotIn('parent_in', comment)
previous_id = comment['comment_id'] previous_id = comment['comment_id']
reply = create_comment_or_error( reply = create_comment(
conn=self.conn,
claim_id=self.claimId, claim_id=self.claimId,
comment='This is a named response', comment='This is a named response',
channel_name='@another_username', channel_name='@another_username',
channel_id='529357c3422c6046d3fec76be2358004ba224bcd', channel_id='529357c3422c6046d3fec76be2358004ba224bcd',
parent_id=previous_id, parent_id=previous_id,
signature=fake.uuid4(), signature='11'*64,
signing_ts='aaa' signing_ts='aaa'
) )
self.assertIsNotNone(reply) self.assertIsNotNone(reply)
@ -54,34 +48,32 @@ class TestDatabaseOperations(DatabaseTestCase):
def test02AnonymousComments(self): def test02AnonymousComments(self):
self.assertRaises( self.assertRaises(
sqlite3.IntegrityError, ValueError,
create_comment_or_error, create_comment,
conn=self.conn,
claim_id=self.claimId, claim_id=self.claimId,
comment='This is an ANONYMOUS comment' comment='This is an ANONYMOUS comment'
) )
def test03SignedComments(self): def test03SignedComments(self):
comment = create_comment_or_error( comment = create_comment(
conn=self.conn,
claim_id=self.claimId, claim_id=self.claimId,
comment='I like big butts and i cannot lie', comment='I like big butts and i cannot lie',
channel_name='@sirmixalot', channel_name='@sirmixalot',
channel_id='529357c3422c6046d3fec76be2358005ba22abcd', channel_id='529357c3422c6046d3fec76be2358005ba22abcd',
signature=fake.uuid4(), signature='24'*64,
signing_ts='asdasd' signing_ts='asdasd'
) )
self.assertIsNotNone(comment) self.assertIsNotNone(comment)
self.assertIn('signing_ts', comment) self.assertIn('signing_ts', comment)
previous_id = comment['comment_id'] previous_id = comment['comment_id']
reply = create_comment_or_error( reply = create_comment(
conn=self.conn,
claim_id=self.claimId, claim_id=self.claimId,
comment='This is a LBRY verified response', comment='This is a LBRY verified response',
channel_name='@LBRY', channel_name='@LBRY',
channel_id='529357c3422c6046d3fec76be2358001ba224bcd', channel_id='529357c3422c6046d3fec76be2358001ba224bcd',
parent_id=previous_id, parent_id=previous_id,
signature=fake.uuid4(), signature='12'*64,
signing_ts='sfdfdfds' signing_ts='sfdfdfds'
) )
self.assertIsNotNone(reply) self.assertIsNotNone(reply)
@ -90,75 +82,109 @@ class TestDatabaseOperations(DatabaseTestCase):
def test04UsernameVariations(self): def test04UsernameVariations(self):
self.assertRaises( self.assertRaises(
AssertionError, ValueError,
callable=create_comment_or_error, create_comment,
conn=self.conn,
claim_id=self.claimId, claim_id=self.claimId,
channel_name='$#(@#$@#$', channel_name='$#(@#$@#$',
channel_id='529357c3422c6046d3fec76be2358001ba224b23', channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this is an invalid username' comment='this is an invalid username',
signature='1' * 128,
signing_ts='123'
) )
valid_username = create_comment_or_error(
conn=self.conn, valid_username = create_comment(
claim_id=self.claimId, claim_id=self.claimId,
channel_name='@' + 'a' * 255, channel_name='@' + 'a' * 255,
channel_id='529357c3422c6046d3fec76be2358001ba224b23', channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this is a valid username' comment='this is a valid username',
signature='1'*128,
signing_ts='123'
) )
self.assertIsNotNone(valid_username) self.assertIsNotNone(valid_username)
self.assertRaises(AssertionError,
callable=create_comment_or_error,
conn=self.conn,
claim_id=self.claimId,
channel_name='@' + 'a' * 256,
channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this username is too long'
)
self.assertRaises( self.assertRaises(
AssertionError, ValueError,
callable=create_comment_or_error, create_comment,
conn=self.conn, claim_id=self.claimId,
channel_name='@' + 'a' * 256,
channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this username is too long',
signature='2' * 128,
signing_ts='123'
)
self.assertRaises(
ValueError,
create_comment,
claim_id=self.claimId, claim_id=self.claimId,
channel_name='', channel_name='',
channel_id='529357c3422c6046d3fec76be2358001ba224b23', channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this username should not default to ANONYMOUS' comment='this username should not default to ANONYMOUS',
signature='3' * 128,
signing_ts='123'
) )
self.assertRaises( self.assertRaises(
AssertionError, ValueError,
callable=create_comment_or_error, create_comment,
conn=self.conn,
claim_id=self.claimId, claim_id=self.claimId,
channel_name='@', channel_name='@',
channel_id='529357c3422c6046d3fec76be2358001ba224b23', channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this username is too short' comment='this username is too short',
signature='3' * 128,
signing_ts='123'
) )
def test05HideComments(self): def test05HideComments(self):
comm = create_comment_or_error(self.conn, 'Comment #1', self.claimId, '1'*40, '@Doge123', 'a'*128, '123') comm = create_comment(
comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop() comment='Comment #1',
claim_id=self.claimId,
channel_id='1'*40,
channel_name='@Doge123',
signature='a'*128,
signing_ts='123'
)
comment = get_comment(comm['comment_id'])
self.assertFalse(comment['is_hidden']) self.assertFalse(comment['is_hidden'])
success = hide_comments_by_id(self.conn, [comm['comment_id']])
success = set_hidden_flag([comm['comment_id']])
self.assertTrue(success) self.assertTrue(success)
comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop()
comment = get_comment(comm['comment_id'])
self.assertTrue(comment['is_hidden']) self.assertTrue(comment['is_hidden'])
success = hide_comments_by_id(self.conn, [comm['comment_id']])
success = set_hidden_flag([comm['comment_id']])
self.assertTrue(success) self.assertTrue(success)
comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop()
comment = get_comment(comm['comment_id'])
self.assertTrue(comment['is_hidden']) self.assertTrue(comment['is_hidden'])
def test06DeleteComments(self): def test06DeleteComments(self):
comm = create_comment_or_error(self.conn, 'Comment #1', self.claimId, '1'*40, '@Doge123', 'a'*128, '123') # make sure that the comment was created
comments = get_claim_comments(self.conn, self.claimId) comm = create_comment(
match = list(filter(lambda x: comm['comment_id'] == x['comment_id'], comments['items'])) comment='Comment #1',
self.assertTrue(match) claim_id=self.claimId,
deleted = delete_comment_by_id(self.conn, comm['comment_id']) channel_id='1'*40,
channel_name='@Doge123',
signature='a'*128,
signing_ts='123'
)
comments = comment_list(self.claimId)
match = [x for x in comments['items'] if x['comment_id'] == comm['comment_id']]
self.assertTrue(len(match) > 0)
deleted = delete_comment(comm['comment_id'])
self.assertTrue(deleted) self.assertTrue(deleted)
comments = get_claim_comments(self.conn, self.claimId)
match = list(filter(lambda x: comm['comment_id'] == x['comment_id'], comments['items'])) # make sure that we can't find the comment here
comments = comment_list(self.claimId)
match = [x for x in comments['items'] if x['comment_id'] == comm['comment_id']]
self.assertFalse(match) self.assertFalse(match)
deleted = delete_comment_by_id(self.conn, comm['comment_id']) self.assertRaises(
self.assertFalse(deleted) ValueError,
delete_comment,
comment_id=comm['comment_id'],
)
class ListDatabaseTest(DatabaseTestCase): class ListDatabaseTest(DatabaseTestCase):
@ -169,61 +195,75 @@ class ListDatabaseTest(DatabaseTestCase):
def testLists(self): def testLists(self):
for claim_id in self.claim_ids: for claim_id in self.claim_ids:
with self.subTest(claim_id=claim_id): with self.subTest(claim_id=claim_id):
comments = get_claim_comments(self.conn, claim_id) comments = comment_list(claim_id)
self.assertIsNotNone(comments) self.assertIsNotNone(comments)
self.assertGreater(comments['page_size'], 0) self.assertGreater(comments['page_size'], 0)
self.assertIn('has_hidden_comments', comments) self.assertIn('has_hidden_comments', comments)
self.assertFalse(comments['has_hidden_comments']) self.assertFalse(comments['has_hidden_comments'])
top_comments = get_claim_comments(self.conn, claim_id, top_level=True, page=1, page_size=50) top_comments = comment_list(claim_id, top_level=True, page=1, page_size=50)
self.assertIsNotNone(top_comments) self.assertIsNotNone(top_comments)
self.assertEqual(top_comments['page_size'], 50) self.assertEqual(top_comments['page_size'], 50)
self.assertEqual(top_comments['page'], 1) self.assertEqual(top_comments['page'], 1)
self.assertGreaterEqual(top_comments['total_pages'], 0) self.assertGreaterEqual(top_comments['total_pages'], 0)
self.assertGreaterEqual(top_comments['total_items'], 0) self.assertGreaterEqual(top_comments['total_items'], 0)
comment_ids = get_comment_ids(self.conn, claim_id, page_size=50, page=1) comment_ids = comment_list(claim_id, page_size=50, page=1)
with self.subTest(comment_ids=comment_ids): with self.subTest(comment_ids=comment_ids):
self.assertIsNotNone(comment_ids) self.assertIsNotNone(comment_ids)
self.assertLessEqual(len(comment_ids), 50) self.assertLessEqual(len(comment_ids), 50)
matching_comments = get_comments_by_id(self.conn, comment_ids) matching_comments = (comment_ids)
self.assertIsNotNone(matching_comments) self.assertIsNotNone(matching_comments)
self.assertEqual(len(matching_comments), len(comment_ids)) self.assertEqual(len(matching_comments), len(comment_ids))
def testHiddenCommentLists(self): def testHiddenCommentLists(self):
claim_id = 'a'*40 claim_id = 'a'*40
comm1 = create_comment_or_error(self.conn, 'Comment #1', claim_id, '1'*40, '@Doge123', 'a'*128, '123') comm1 = create_comment(
comm2 = create_comment_or_error(self.conn, 'Comment #2', claim_id, '1'*40, '@Doge123', 'b'*128, '123') 'Comment #1',
comm3 = create_comment_or_error(self.conn, 'Comment #3', claim_id, '1'*40, '@Doge123', 'c'*128, '123') claim_id,
channel_id='1'*40,
channel_name='@Doge123',
signature='a'*128,
signing_ts='123'
)
comm2 = create_comment(
'Comment #2', claim_id,
channel_id='1'*40,
channel_name='@Doge123',
signature='b'*128,
signing_ts='123'
)
comm3 = create_comment(
'Comment #3', claim_id,
channel_id='1'*40,
channel_name='@Doge123',
signature='c'*128,
signing_ts='123'
)
comments = [comm1, comm2, comm3] comments = [comm1, comm2, comm3]
comment_list = get_claim_comments(self.conn, claim_id) listed_comments = comment_list(claim_id)
self.assertIn('items', comment_list) self.assertEqual(len(comments), listed_comments['total_items'])
self.assertIn('has_hidden_comments', comment_list) self.assertFalse(listed_comments['has_hidden_comments'])
self.assertEqual(len(comments), comment_list['total_items'])
self.assertIn('has_hidden_comments', comment_list)
self.assertFalse(comment_list['has_hidden_comments'])
hide_comments_by_id(self.conn, [comm2['comment_id']])
default_comments = get_claim_hidden_comments(self.conn, claim_id) set_hidden_flag([comm2['comment_id']])
self.assertIn('has_hidden_comments', default_comments) hidden = comment_list(claim_id, exclude_mode='hidden')
hidden_comments = get_claim_hidden_comments(self.conn, claim_id, hidden=True) self.assertTrue(hidden['has_hidden_comments'])
self.assertIn('has_hidden_comments', hidden_comments) self.assertGreater(len(hidden['items']), 0)
self.assertEqual(default_comments, hidden_comments)
hidden_comment = hidden_comments['items'][0] visible = comment_list(claim_id, exclude_mode='visible')
self.assertFalse(visible['has_hidden_comments'])
self.assertNotEqual(listed_comments['items'], visible['items'])
# make sure the hidden comment is the one we marked as hidden
hidden_comment = hidden['items'][0]
self.assertEqual(hidden_comment['comment_id'], comm2['comment_id']) self.assertEqual(hidden_comment['comment_id'], comm2['comment_id'])
visible_comments = get_claim_hidden_comments(self.conn, claim_id, hidden=False) hidden_ids = [c['comment_id'] for c in hidden['items']]
self.assertIn('has_hidden_comments', visible_comments) visible_ids = [c['comment_id'] for c in visible['items']]
self.assertNotIn(hidden_comment, visible_comments['items'])
hidden_ids = [c['comment_id'] for c in hidden_comments['items']]
visible_ids = [c['comment_id'] for c in visible_comments['items']]
composite_ids = hidden_ids + visible_ids composite_ids = hidden_ids + visible_ids
listed_comments = comment_list(claim_id)
all_ids = [c['comment_id'] for c in listed_comments['items']]
composite_ids.sort() composite_ids.sort()
comment_list = get_claim_comments(self.conn, claim_id)
all_ids = [c['comment_id'] for c in comment_list['items']]
all_ids.sort() all_ids.sort()
self.assertEqual(composite_ids, all_ids) self.assertEqual(composite_ids, all_ids)

View file

@ -9,13 +9,18 @@ from faker.providers import internet
from faker.providers import lorem from faker.providers import lorem
from faker.providers import misc from faker.providers import misc
from src.settings import config from src.main import get_config, CONFIG_FILE
from src.server import app from src.server import app
from src.server.validation import is_valid_base_comment from src.server.validation import is_valid_base_comment
from test.testcase import AsyncioTestCase from test.testcase import AsyncioTestCase
config = get_config(CONFIG_FILE)
config['mode'] = 'testing'
config['testing']['file'] = ':memory:'
if 'slack_webhook' in config: if 'slack_webhook' in config:
config.pop('slack_webhook') config.pop('slack_webhook')
@ -71,10 +76,10 @@ def create_test_comments(values: iter, **default):
class ServerTest(AsyncioTestCase): class ServerTest(AsyncioTestCase):
db_file = 'test.db'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
config['mode'] = 'testing'
config['testing']['file'] = ':memory:'
self.host = 'localhost' self.host = 'localhost'
self.port = 5931 self.port = 5931
@ -85,11 +90,10 @@ class ServerTest(AsyncioTestCase):
@classmethod @classmethod
def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:
print('exit reached') print('exit reached')
os.remove(cls.db_file)
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp() await super().asyncSetUp()
self.server = app.CommentDaemon(config, db_file=self.db_file) self.server = app.CommentDaemon(config)
await self.server.start(host=self.host, port=self.port) await self.server.start(host=self.host, port=self.port)
self.addCleanup(self.server.stop) self.addCleanup(self.server.stop)
@ -135,14 +139,16 @@ class ServerTest(AsyncioTestCase):
test_all = create_test_comments(replace.keys(), **{ test_all = create_test_comments(replace.keys(), **{
k: None for k in replace.keys() k: None for k in replace.keys()
}) })
test_all.reverse()
for test in test_all: for test in test_all:
with self.subTest(test=test): nulls = 'null fields: ' + ', '.join(k for k, v in test.items() if not v)
with self.subTest(test=nulls):
message = await self.post_comment(**test) message = await self.post_comment(**test)
self.assertTrue('result' in message or 'error' in message) self.assertTrue('result' in message or 'error' in message)
if 'error' in message: if 'error' in message:
self.assertFalse(is_valid_base_comment(**test)) self.assertFalse(is_valid_base_comment(**test, strict=True))
else: else:
self.assertTrue(is_valid_base_comment(**test)) self.assertTrue(is_valid_base_comment(**test, strict=True))
async def test04CreateAllReplies(self): async def test04CreateAllReplies(self):
claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8'
@ -220,7 +226,8 @@ class ListCommentsTest(AsyncioTestCase):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.host = 'localhost' self.host = 'localhost'
self.port = 5931 self.port = 5931
self.db_file = 'list_test.db' config['mode'] = 'testing'
config['testing']['file'] = ':memory:'
self.claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8' self.claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8'
self.comment_ids = None self.comment_ids = None
@ -231,10 +238,6 @@ class ListCommentsTest(AsyncioTestCase):
async def post_comment(self, **params): async def post_comment(self, **params):
return await jsonrpc_post(self.url, 'create_comment', **params) return await jsonrpc_post(self.url, 'create_comment', **params)
def tearDown(self) -> None:
print('exit reached')
os.remove(self.db_file)
async def create_lots_of_comments(self, n=23): async def create_lots_of_comments(self, n=23):
self.comment_list = [{key: self.replace[key]() for key in self.replace.keys()} for _ in range(23)] self.comment_list = [{key: self.replace[key]() for key in self.replace.keys()} for _ in range(23)]
for comment in self.comment_list: for comment in self.comment_list:
@ -244,7 +247,7 @@ class ListCommentsTest(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp() await super().asyncSetUp()
self.server = app.CommentDaemon(config, db_file=self.db_file) self.server = app.CommentDaemon(config)
await self.server.start(self.host, self.port) await self.server.start(self.host, self.port)
self.addCleanup(self.server.stop) self.addCleanup(self.server.stop)

View file

@ -1,12 +1,38 @@
import os import os
import pathlib import pathlib
import unittest import unittest
from asyncio.runners import _cancel_all_tasks # type: ignore
from unittest.case import _Outcome from unittest.case import _Outcome
import asyncio import asyncio
from asyncio.runners import _cancel_all_tasks # type: ignore
from peewee import *
from src.database.queries import obtain_connection, setup_database from src.database.models import Channel, Comment
test_db = SqliteDatabase(':memory:')
MODELS = [Channel, Comment]
class DatabaseTestCase(unittest.TestCase):
def __init__(self, methodName='DatabaseTest'):
super().__init__(methodName)
def setUp(self) -> None:
super().setUp()
test_db.bind(MODELS, bind_refs=False, bind_backrefs=False)
test_db.connect()
test_db.create_tables(MODELS)
def tearDown(self) -> None:
# drop tables for next test
test_db.drop_tables(MODELS)
# close connection
test_db.close()
class AsyncioTestCase(unittest.TestCase): class AsyncioTestCase(unittest.TestCase):
@ -117,21 +143,3 @@ class AsyncioTestCase(unittest.TestCase):
self.loop.run_until_complete(maybe_coroutine) self.loop.run_until_complete(maybe_coroutine)
class DatabaseTestCase(unittest.TestCase):
db_file = 'test.db'
def __init__(self, methodName='DatabaseTest'):
super().__init__(methodName)
if pathlib.Path(self.db_file).exists():
os.remove(self.db_file)
def setUp(self) -> None:
super().setUp()
setup_database(self.db_file)
self.conn = obtain_connection(self.db_file)
self.addCleanup(self.conn.close)
self.addCleanup(os.remove, self.db_file)
def tearDown(self) -> None:
self.conn.close()