Merge pull request #39 from lbryio/orm-rewrite

Replace database methods with peewee ORM
Oleg Silkin 2020-04-03 17:40:42 -04:00 committed by GitHub
commit 80f77218f9
18 changed files with 783 additions and 538 deletions

.gitignore (4 changes)

@@ -1 +1,3 @@
config/conf.json
config/conf.yml
docker-compose.yml

.travis.yml

@@ -1,7 +1,28 @@
sudo: required
language: python
dist: xenial
python: 3.7
python: 3.8
# for docker-compose
services:
- docker
# to avoid "well it works on my computer" moments
env:
- DOCKER_COMPOSE_VERSION=1.25.4
before_install:
# ensure docker-compose version is as specified above
- sudo rm /usr/local/bin/docker-compose
- curl -L https://github.com/docker/compose/releases/download/${DOCKER_COMPOSE_VERSION}/docker-compose-`uname -s`-`uname -m` > docker-compose
- chmod +x docker-compose
- sudo mv docker-compose /usr/local/bin
# refresh docker images
- sudo apt-get update
before_script:
- docker-compose up -d
jobs:
include:

config/conf.yml (new file, 30 lines)

@@ -0,0 +1,30 @@
---
# for running local-tests without using MySQL for now
testing:
database: sqlite
file: comments.db
pragmas:
journal_mode: wal
cache_size: 64000
foreign_keys: 0
ignore_check_constraints: 1
synchronous: 0
# actual database should be running MySQL
production:
database: mysql
name: lbry
user: lbry
password: lbry
host: localhost
port: 3306
mode: production
logging:
format: "%(asctime)s | %(levelname)s | %(name)s | %(module)s.%(funcName)s:%(lineno)d
| %(message)s"
aiohttp_format: "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
datefmt: "%Y-%m-%d %H:%M:%S"
host: localhost
port: 5921
lbrynet: http://localhost:5279
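The pragmas mapping under testing is handed straight to peewee, and the production block maps onto a MySQL connection. A minimal sketch of how the server (src/server/app.py below) turns this file into a database object, assuming only the keys shown above:

import yaml
from peewee import SqliteDatabase, MySQLDatabase

with open('config/conf.yml') as f:
    conf = yaml.safe_load(f)

mode = conf['mode']          # 'testing' or 'production'
section = conf[mode]
if section['database'] == 'sqlite':
    # the pragmas block is passed through unchanged (journal_mode, cache_size, ...)
    db = SqliteDatabase(section['file'], pragmas=section['pragmas'])
else:
    db = MySQLDatabase(
        database=section['name'],
        user=section['user'],
        password=section['password'],
        host=section['host'],
        port=section['port'],
    )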

docker-compose.yml (new file, 24 lines)

@@ -0,0 +1,24 @@
version: "3.7"
services:
###########
## MySQL ##
###########
mysql:
image: mysql/mysql-server:5.7.27
restart: "no"
ports:
- "3306:3306"
environment:
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- MYSQL_DATABASE=lbry
- MYSQL_USER=lbry
- MYSQL_PASSWORD=lbry
- MYSQL_LOG_CONSOLE=true
#############
## Adminer ##
#############
adminer:
image: adminer
restart: always
ports:
- 8080:8080
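These credentials line up with the production block in config/conf.yml. A quick connectivity check against the container, assuming pymysql (already listed in setup.py) and the default port mapping above:

import pymysql

# connect with the credentials from the environment block above
conn = pymysql.connect(host='127.0.0.1', port=3306,
                       user='lbry', password='lbry', database='lbry')
with conn.cursor() as cur:
    cur.execute('SELECT VERSION()')
    print(cur.fetchone())
conn.close()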

setup.py

@@ -14,15 +14,17 @@ setup(
data_files=[('config', ['config/conf.json',])],
include_package_data=True,
install_requires=[
'pymysql',
'pyyaml',
'Faker>=1.0.7',
'asyncio>=3.4.3',
'aiohttp==3.5.4',
'aiojobs==0.2.2',
'asyncio',
'aiohttp',
'aiojobs',
'ecdsa>=0.13.3',
'cryptography==2.5',
'aiosqlite==0.10.0',
'PyNaCl>=1.3.0',
'requests',
'cython',
'peewee'
]
)

src/database/models.py (new file, 217 lines)

@@ -0,0 +1,217 @@
import json
import time
import logging
import math
import typing
from peewee import *
import nacl.hash
from src.server.validation import is_valid_base_comment
from src.misc import clean
class Channel(Model):
claim_id = CharField(column_name='ClaimId', primary_key=True, max_length=40)
name = CharField(column_name='Name', max_length=256)
class Meta:
table_name = 'CHANNEL'
class Comment(Model):
comment = CharField(column_name='Body', max_length=2000)
channel = ForeignKeyField(
backref='comments',
column_name='ChannelId',
field='claim_id',
model=Channel,
null=True
)
comment_id = CharField(column_name='CommentId', primary_key=True, max_length=64)
is_hidden = BooleanField(column_name='IsHidden', constraints=[SQL("DEFAULT 0")])
claim_id = CharField(max_length=40, column_name='LbryClaimId')
parent = ForeignKeyField(
column_name='ParentId',
field='comment_id',
model='self',
null=True,
backref='replies'
)
signature = CharField(max_length=128, column_name='Signature', null=True, unique=True)
signing_ts = TextField(column_name='SigningTs', null=True)
timestamp = IntegerField(column_name='Timestamp')
class Meta:
table_name = 'COMMENT'
indexes = (
(('channel', 'comment_id'), False),
(('claim_id', 'comment_id'), False),
)
FIELDS = {
'comment': Comment.comment,
'comment_id': Comment.comment_id,
'claim_id': Comment.claim_id,
'timestamp': Comment.timestamp,
'signature': Comment.signature,
'signing_ts': Comment.signing_ts,
'is_hidden': Comment.is_hidden,
'parent_id': Comment.parent.alias('parent_id'),
'channel_id': Channel.claim_id.alias('channel_id'),
'channel_name': Channel.name.alias('channel_name'),
'channel_url': ('lbry://' + Channel.name + '#' + Channel.claim_id).alias('channel_url')
}
def comment_list(claim_id: str = None, parent_id: str = None,
top_level: bool = False, exclude_mode: str = None,
page: int = 1, page_size: int = 50, expressions=None,
select_fields: list = None, exclude_fields: list = None) -> dict:
fields = FIELDS.keys()
if exclude_fields:
fields -= set(exclude_fields)
if select_fields:
fields &= set(select_fields)
attributes = [FIELDS[field] for field in fields]
query = Comment.select(*attributes)
# todo: allow this process to be more automated, so it can just be an expression
if claim_id:
query = query.where(Comment.claim_id == claim_id)
if top_level:
query = query.where(Comment.parent.is_null())
if parent_id:
query = query.where(Comment.parent == parent_id)
if exclude_mode:
show_hidden = exclude_mode.lower() == 'hidden'
query = query.where((Comment.is_hidden == show_hidden))
if expressions:
query = query.where(expressions)
total = query.count()
query = (query
.join(Channel, JOIN.LEFT_OUTER)
.order_by(Comment.timestamp.desc())
.paginate(page, page_size))
items = [clean(item) for item in query.dicts()]
# has_hidden_comments is deprecated
data = {
'page': page,
'page_size': page_size,
'total_pages': math.ceil(total / page_size),
'total_items': total,
'items': items,
'has_hidden_comments': exclude_mode is not None and exclude_mode == 'hidden',
}
return data
def get_comment(comment_id: str) -> dict:
try:
comment = comment_list(expressions=(Comment.comment_id == comment_id), page_size=1).get('items').pop()
except IndexError:
raise ValueError(f'Comment does not exist with id {comment_id}')
else:
return comment
def create_comment_id(comment: str, channel_id: str, timestamp: int):
# We convert the timestamp from seconds into minutes
# to prevent spammers from commenting the same BS everywhere.
nearest_minute = str(math.floor(timestamp / 60))
# don't use claim_id for the comment_id anymore so comments
# are not unique to just one claim
prehash = b':'.join([
comment.encode(),
channel_id.encode(),
nearest_minute.encode()
])
return nacl.hash.sha256(prehash).decode()
def create_comment(comment: str = None, claim_id: str = None,
parent_id: str = None, channel_id: str = None,
channel_name: str = None, signature: str = None,
signing_ts: str = None) -> dict:
if not is_valid_base_comment(
comment=comment,
claim_id=claim_id,
parent_id=parent_id,
channel_id=channel_id,
channel_name=channel_name,
signature=signature,
signing_ts=signing_ts
):
raise ValueError('Invalid Parameters given for comment')
channel, _ = Channel.get_or_create(name=channel_name, claim_id=channel_id)
if parent_id and not claim_id:
parent: Comment = Comment.get_by_id(parent_id)
claim_id = parent.claim_id
timestamp = int(time.time())
comment_id = create_comment_id(comment, channel_id, timestamp)
new_comment = Comment.create(
claim_id=claim_id,
comment_id=comment_id,
comment=comment,
parent=parent_id,
channel=channel,
signature=signature,
signing_ts=signing_ts,
timestamp=timestamp
)
return get_comment(new_comment.comment_id)
def delete_comment(comment_id: str) -> bool:
try:
comment: Comment = Comment.get_by_id(comment_id)
except DoesNotExist as e:
raise ValueError from e
else:
return 0 < comment.delete_instance(True, delete_nullable=True)
def edit_comment(comment_id: str, new_comment: str, new_sig: str, new_ts: str) -> bool:
try:
comment: Comment = Comment.get_by_id(comment_id)
except DoesNotExist as e:
raise ValueError from e
else:
comment.comment = new_comment
comment.signature = new_sig
comment.signing_ts = new_ts
# todo: add a 'last-modified' timestamp
comment.timestamp = int(time.time())
return comment.save() > 0
def set_hidden_flag(comment_ids: typing.List[str], hidden=True) -> bool:
# sets `is_hidden` flag for all `comment_ids` to the `hidden` param
update = (Comment
.update(is_hidden=hidden)
.where(Comment.comment_id.in_(comment_ids)))
return update.execute() > 0
if __name__ == '__main__':
logger = logging.getLogger('peewee')
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG)
comments = comment_list(
page_size=20,
expressions=((Comment.timestamp < 1583272089) &
(Comment.claim_id ** '420%'))
)
print(json.dumps(comments, indent=4))
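The Meta classes above deliberately omit a database, so callers bind one at runtime (the server does this in src/server/app.py, the tests in test/testcase.py). A minimal standalone sketch against an in-memory SQLite database, with made-up ids and a placeholder signature:

from peewee import SqliteDatabase
from src.database.models import Channel, Comment, create_comment, comment_list

db = SqliteDatabase(':memory:')
db.bind([Channel, Comment], bind_refs=False, bind_backrefs=False)
db.connect()
db.create_tables([Channel, Comment])

created = create_comment(
    comment='hello from the sketch',
    claim_id='a' * 40,        # hypothetical claim
    channel_id='b' * 40,      # hypothetical channel
    channel_name='@example',
    signature='c' * 128,      # placeholder, not a real signature
    signing_ts='123',
)
print(comment_list(claim_id='a' * 40)['total_items'])  # -> 1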

src/database/queries.py (deleted)

@@ -1,290 +0,0 @@
import atexit
import logging
import math
import sqlite3
import time
import typing
import nacl.hash
from src.database.schema import CREATE_TABLES_QUERY
logger = logging.getLogger(__name__)
SELECT_COMMENTS_ON_CLAIMS = """
SELECT comment, comment_id, claim_id, timestamp, is_hidden, parent_id,
channel_name, channel_id, channel_url, signature, signing_ts
FROM COMMENTS_ON_CLAIMS
"""
def clean(thing: dict) -> dict:
if 'is_hidden' in thing:
thing.update({'is_hidden': bool(thing['is_hidden'])})
return {k: v for k, v in thing.items() if v is not None}
def obtain_connection(filepath: str = None, row_factory: bool = True):
connection = sqlite3.connect(filepath)
if row_factory:
connection.row_factory = sqlite3.Row
return connection
def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str = None,
page: int = 1, page_size: int = 50, top_level=False):
with conn:
if top_level:
results = [clean(dict(row)) for row in conn.execute(
SELECT_COMMENTS_ON_CLAIMS + " WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ?",
(claim_id, page_size, page_size * (page - 1))
)]
count = conn.execute(
"SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id IS NULL",
(claim_id,)
)
elif parent_id is None:
results = [clean(dict(row)) for row in conn.execute(
SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? LIMIT ? OFFSET ? ",
(claim_id, page_size, page_size * (page - 1))
)]
count = conn.execute(
"SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ?",
(claim_id,)
)
else:
results = [clean(dict(row)) for row in conn.execute(
SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ? ",
(claim_id, parent_id, page_size, page_size * (page - 1))
)]
count = conn.execute(
"SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND parent_id = ?",
(claim_id, parent_id)
)
count = tuple(count.fetchone())[0]
return {
'items': results,
'page': page,
'page_size': page_size,
'total_pages': math.ceil(count / page_size),
'total_items': count,
'has_hidden_comments': claim_has_hidden_comments(conn, claim_id)
}
def get_claim_hidden_comments(conn: sqlite3.Connection, claim_id: str, hidden=True, page=1, page_size=50):
with conn:
results = conn.execute(
SELECT_COMMENTS_ON_CLAIMS + "WHERE claim_id = ? AND is_hidden IS ? LIMIT ? OFFSET ?",
(claim_id, hidden, page_size, page_size * (page - 1))
)
count = conn.execute(
"SELECT COUNT(*) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND is_hidden IS ?", (claim_id, hidden)
)
results = [clean(dict(row)) for row in results.fetchall()]
count = tuple(count.fetchone())[0]
return {
'items': results,
'page': page,
'page_size': page_size,
'total_pages': math.ceil(count / page_size),
'total_items': count,
'has_hidden_comments': claim_has_hidden_comments(conn, claim_id)
}
def claim_has_hidden_comments(conn, claim_id):
with conn:
result = conn.execute(
"SELECT COUNT(DISTINCT is_hidden) FROM COMMENTS_ON_CLAIMS WHERE claim_id = ? AND is_hidden IS 1",
(claim_id,)
)
return bool(tuple(result.fetchone())[0])
def insert_comment(conn: sqlite3.Connection, claim_id: str, comment: str,
channel_id: str = None, signature: str = None, signing_ts: str = None, **extra) -> str:
timestamp = int(time.time())
prehash = b':'.join((claim_id.encode(), comment.encode(), str(timestamp).encode(),))
comment_id = nacl.hash.sha256(prehash).decode()
with conn:
curs = conn.execute(
"""
INSERT INTO COMMENT(CommentId, LbryClaimId, ChannelId, Body, ParentId,
Timestamp, Signature, SigningTs, IsHidden)
VALUES (:comment_id, :claim_id, :channel_id, :comment, NULL,
:timestamp, :signature, :signing_ts, 0) """,
{
'comment_id': comment_id,
'claim_id': claim_id,
'channel_id': channel_id,
'comment': comment,
'timestamp': timestamp,
'signature': signature,
'signing_ts': signing_ts
}
)
logging.info('attempted to insert comment with comment_id [%s] | %d rows affected', comment_id, curs.rowcount)
return comment_id
def insert_reply(conn: sqlite3.Connection, comment: str, parent_id: str,
channel_id: str = None, signature: str = None,
signing_ts: str = None, **extra) -> str:
timestamp = int(time.time())
prehash = b':'.join((parent_id.encode(), comment.encode(), str(timestamp).encode(),))
comment_id = nacl.hash.sha256(prehash).decode()
with conn:
curs = conn.execute(
"""
INSERT INTO COMMENT
(CommentId, LbryClaimId, ChannelId, Body, ParentId, Signature, Timestamp, SigningTs, IsHidden)
SELECT :comment_id, LbryClaimId, :channel_id, :comment, :parent_id, :signature, :timestamp, :signing_ts, 0
FROM COMMENT WHERE CommentId = :parent_id
""", {
'comment_id': comment_id,
'parent_id': parent_id,
'timestamp': timestamp,
'comment': comment,
'channel_id': channel_id,
'signature': signature,
'signing_ts': signing_ts
}
)
logging.info('attempted to insert reply with comment_id [%s] | %d rows affected', comment_id, curs.rowcount)
return comment_id
def get_comment_or_none(conn: sqlite3.Connection, comment_id: str) -> dict:
with conn:
curry = conn.execute(SELECT_COMMENTS_ON_CLAIMS + "WHERE comment_id = ?", (comment_id,))
thing = curry.fetchone()
return clean(dict(thing)) if thing else None
def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50):
""" Just return a list of the comment IDs that are associated with the given claim_id.
If get_all is specified then it returns all the IDs, otherwise only the IDs at that level.
if parent_id is left null then it only returns the top level comments.
For pagination the parameters are:
get_all XOR (page_size + page)
"""
with conn:
if parent_id is None:
curs = conn.execute("""
SELECT comment_id FROM COMMENTS_ON_CLAIMS
WHERE claim_id = ? AND parent_id IS NULL LIMIT ? OFFSET ?
""", (claim_id, page_size, page_size * abs(page - 1),)
)
else:
curs = conn.execute("""
SELECT comment_id FROM COMMENTS_ON_CLAIMS
WHERE claim_id = ? AND parent_id = ? LIMIT ? OFFSET ?
""", (claim_id, parent_id, page_size, page_size * abs(page - 1),)
)
return [tuple(row)[0] for row in curs.fetchall()]
def get_comments_by_id(conn, comment_ids: typing.Union[list, tuple]) -> typing.Union[list, None]:
""" Returns a list containing the comment data associated with each ID within the list"""
# format the input, under the assumption that the
placeholders = ', '.join('?' for _ in comment_ids)
with conn:
return [clean(dict(row)) for row in conn.execute(
SELECT_COMMENTS_ON_CLAIMS + f'WHERE comment_id IN ({placeholders})',
tuple(comment_ids)
)]
def delete_comment_by_id(conn: sqlite3.Connection, comment_id: str) -> bool:
with conn:
curs = conn.execute("DELETE FROM COMMENT WHERE CommentId = ?", (comment_id,))
return bool(curs.rowcount)
def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
with conn:
curs = conn.execute('INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)', (channel_id, channel_name))
return bool(curs.rowcount)
def get_channel_id_from_comment_id(conn: sqlite3.Connection, comment_id: str):
with conn:
channel = conn.execute(
"SELECT channel_id, channel_name FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?", (comment_id,)
).fetchone()
return dict(channel) if channel else {}
def get_claim_ids_from_comment_ids(conn: sqlite3.Connection, comment_ids: list):
with conn:
cids = conn.execute(
f""" SELECT CommentId as comment_id, LbryClaimId AS claim_id FROM COMMENT
WHERE CommentId IN ({', '.join('?' for _ in comment_ids)}) """,
tuple(comment_ids)
)
return {row['comment_id']: row['claim_id'] for row in cids.fetchall()}
def hide_comments_by_id(conn: sqlite3.Connection, comment_ids: list) -> bool:
with conn:
curs = conn.cursor()
curs.executemany(
"UPDATE COMMENT SET IsHidden = 1 WHERE CommentId = ?",
[[c] for c in comment_ids]
)
return bool(curs.rowcount)
def edit_comment_by_id(conn: sqlite3.Connection, comment_id: str, comment: str,
signature: str, signing_ts: str) -> bool:
with conn:
curs = conn.execute(
"""
UPDATE COMMENT
SET Body = :comment, Signature = :signature, SigningTs = :signing_ts
WHERE CommentId = :comment_id
""",
{
'comment': comment,
'signature': signature,
'signing_ts': signing_ts,
'comment_id': comment_id
})
logger.info("updated comment with `comment_id`: %s", comment_id)
return bool(curs.rowcount)
class DatabaseWriter(object):
_writer = None
def __init__(self, db_file):
if not DatabaseWriter._writer:
self.conn = obtain_connection(db_file)
DatabaseWriter._writer = self
atexit.register(self.cleanup)
logging.info('Database writer has been created at %s', repr(self))
else:
logging.warning('Someone attempted to instantiate DatabaseWriter')
raise TypeError('Database Writer already exists!')
def cleanup(self):
logging.info('Cleaning up database writer')
self.conn.close()
DatabaseWriter._writer = None
@property
def connection(self):
return self.conn
def setup_database(db_path):
with sqlite3.connect(db_path) as conn:
conn.executescript(CREATE_TABLES_QUERY)
def backup_database(conn: sqlite3.Connection, back_fp):
with sqlite3.connect(back_fp) as back:
conn.backup(back)

src/database/writes.py

@@ -7,7 +7,7 @@ from src.server.validation import is_valid_base_comment
from src.server.validation import is_valid_credential_input
from src.server.validation import validate_signature_from_claim
from src.server.validation import body_is_valid
from src.server.misc import get_claim_from_id
from src.misc import get_claim_from_id
from src.server.external import send_notifications
from src.server.external import send_notification
import src.database.queries as db
@@ -84,13 +84,17 @@ async def hide_comments(app, pieces: list) -> list:
# TODO: Amortize this process
claims = {}
comments_to_hide = []
# go through a list of dict objects
for p in pieces:
# maps the comment_id from the piece to a claim_id
claim_id = comment_cids[p['comment_id']]
# resolve the claim from its id
if claim_id not in claims:
claim = await get_claim_from_id(app, claim_id)
if claim:
claims[claim_id] = claim
# get the claim's signing channel, then use it to validate the hidden comment
channel = claims[claim_id].get('signing_channel')
if validate_signature_from_claim(channel, p['signature'], p['signing_ts'], p['comment_id']):
comments_to_hide.append(p)

src/definitions.py (new file, 7 lines)

@@ -0,0 +1,7 @@
import os
SRC_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(SRC_DIR)
CONFIG_FILE = os.path.join(ROOT_DIR, 'config', 'conf.yml')
LOGGING_DIR = os.path.join(ROOT_DIR, 'logs')
DATABASE_DIR = os.path.join(ROOT_DIR, 'database')

src/main.py

@@ -1,13 +1,20 @@
import argparse
import json
import yaml
import logging
import logging.config
import os
import sys
from src.server.app import run_app
from src.settings import config
from src.definitions import LOGGING_DIR, CONFIG_FILE, DATABASE_DIR
def config_logging_from_settings(conf):
def setup_logging_from_config(conf: dict):
# set the logging directory here from the settings file
if not os.path.exists(LOGGING_DIR):
os.mkdir(LOGGING_DIR)
_config = {
"version": 1,
"disable_existing_loggers": False,
@@ -32,7 +39,7 @@ def config_logging_from_settings(conf):
"level": "DEBUG",
"formatter": "standard",
"class": "logging.handlers.RotatingFileHandler",
"filename": conf['path']['debug_log'],
"filename": os.path.join(LOGGING_DIR, 'debug.log'),
"maxBytes": 10485760,
"backupCount": 5
},
@@ -40,7 +47,7 @@ def config_logging_from_settings(conf):
"level": "ERROR",
"formatter": "standard",
"class": "logging.handlers.RotatingFileHandler",
"filename": conf['path']['error_log'],
"filename": os.path.join(LOGGING_DIR, 'error.log'),
"maxBytes": 10485760,
"backupCount": 5
},
@@ -48,7 +55,7 @@ def config_logging_from_settings(conf):
"level": "NOTSET",
"formatter": "aiohttp",
"class": "logging.handlers.RotatingFileHandler",
"filename": conf['path']['server_log'],
"filename": os.path.join(LOGGING_DIR, 'server.log'),
"maxBytes": 10485760,
"backupCount": 5
}
@@ -70,15 +77,42 @@ def config_logging_from_settings(conf):
logging.config.dictConfig(_config)
def get_config(filepath):
with open(filepath, 'r') as cfile:
config = yaml.load(cfile, Loader=yaml.Loader)
return config
def setup_db_from_config(config: dict):
mode = config['mode']
if config[mode]['database'] == 'sqlite':
if not os.path.exists(DATABASE_DIR):
os.mkdir(DATABASE_DIR)
config[mode]['db_file'] = os.path.join(
DATABASE_DIR, config[mode]['name']
)
def main(argv=None):
argv = argv or sys.argv[1:]
parser = argparse.ArgumentParser(description='LBRY Comment Server')
parser.add_argument('--port', type=int)
parser.add_argument('--config', type=str)
parser.add_argument('--mode', type=str)
args = parser.parse_args(argv)
config_logging_from_settings(config)
config = get_config(args.config or CONFIG_FILE)
setup_logging_from_config(config)
if args.mode:
config['mode'] = args.mode
setup_db_from_config(config)
if args.port:
config['port'] = args.port
config_logging_from_settings(config)
run_app(config)
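The entry point can also be driven programmatically with the same flags the CLI accepts. A sketch, assuming config/conf.yml exists at the path defined in src/definitions.py (the call blocks while the aiohttp server runs):

from src.main import main

# equivalent to running the server from the command line in testing mode
main(['--mode', 'testing', '--port', '5921'])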

src/misc.py

@@ -20,3 +20,7 @@ def clean_input_params(kwargs: dict):
kwargs[k] = v.strip()
if k in ID_LIST:
kwargs[k] = v.lower()
def clean(thing: dict) -> dict:
return {k: v for k, v in thing.items() if v is not None}

src/server/app.py

@@ -1,7 +1,6 @@
# cython: language_level=3
import asyncio
import logging
import pathlib
import signal
import time
@@ -9,81 +8,67 @@ import aiojobs
import aiojobs.aiohttp
from aiohttp import web
from src.database.queries import obtain_connection, DatabaseWriter
from src.database.queries import setup_database, backup_database
from peewee import *
from src.server.handles import api_endpoint, get_api_endpoint
from src.database.models import Comment, Channel
MODELS = [Comment, Channel]
logger = logging.getLogger(__name__)
async def setup_db_schema(app):
if not pathlib.Path(app['db_path']).exists():
logger.info(f'Setting up schema in {app["db_path"]}')
setup_database(app['db_path'])
else:
logger.info(f'Database already exists in {app["db_path"]}, skipping setup')
def setup_database(app):
config = app['config']
mode = config['mode']
# switch between Database objects
if config[mode]['database'] == 'mysql':
app['db'] = MySQLDatabase(
database=config[mode]['name'],
user=config[mode]['user'],
host=config[mode]['host'],
password=config[mode]['password'],
port=config[mode]['port'],
)
elif config[mode]['database'] == 'sqlite':
app['db'] = SqliteDatabase(
config[mode]['file'],
pragmas=config[mode]['pragmas']
)
async def database_backup_routine(app):
try:
while True:
await asyncio.sleep(app['config']['backup_int'])
with app['reader'] as conn:
logger.debug('backing up database')
backup_database(conn, app['backup'])
except asyncio.CancelledError:
pass
# bind the Model list to the database
app['db'].bind(MODELS, bind_refs=False, bind_backrefs=False)
async def start_background_tasks(app):
# Reading the DB
app['reader'] = obtain_connection(app['db_path'], True)
app['waitful_backup'] = asyncio.create_task(database_backup_routine(app))
# Scheduler to prevent multiple threads from writing to DB simultaneously
app['comment_scheduler'] = await aiojobs.create_scheduler(limit=1, pending_limit=0)
app['db_writer'] = DatabaseWriter(app['db_path'])
app['writer'] = app['db_writer'].connection
app['db'].connect()
app['db'].create_tables(MODELS)
# for requesting to external and internal APIs
app['webhooks'] = await aiojobs.create_scheduler(pending_limit=0)
async def close_database_connections(app):
logger.info('Ending background backup loop')
app['waitful_backup'].cancel()
await app['waitful_backup']
app['reader'].close()
app['writer'].close()
app['db_writer'].cleanup()
app['db'].close()
async def close_schedulers(app):
logger.info('Closing comment_scheduler')
await app['comment_scheduler'].close()
logger.info('Closing scheduler for webhook requests')
await app['webhooks'].close()
class CommentDaemon:
def __init__(self, config, db_file=None, backup=None, **kwargs):
def __init__(self, config, **kwargs):
app = web.Application()
app['config'] = config
# configure the config
app['config'] = config
self.config = app['config']
self.config = config
self.host = config['host']
self.port = config['port']
# configure the db file
if db_file:
app['db_path'] = db_file
app['backup'] = backup
else:
app['db_path'] = config['path']['database']
app['backup'] = backup or (app['db_path'] + '.backup')
setup_database(app)
# configure the order of tasks to run during app lifetime
app.on_startup.append(setup_db_schema)
app.on_startup.append(start_background_tasks)
app.on_shutdown.append(close_schedulers)
app.on_cleanup.append(close_database_connections)
@@ -105,20 +90,19 @@ class CommentDaemon:
await self.app_runner.setup()
self.app_site = web.TCPSite(
runner=self.app_runner,
host=host or self.config['host'],
port=port or self.config['port'],
host=host or self.host,
port=port or self.port,
)
await self.app_site.start()
logger.info(f'Comment Server is running on {self.config["host"]}:{self.config["port"]}')
logger.info(f'Comment Server is running on {self.host}:{self.port}')
async def stop(self):
await self.app_runner.shutdown()
await self.app_runner.cleanup()
def run_app(config, db_file=None):
comment_app = CommentDaemon(config=config, db_file=db_file, close_timeout=5.0)
def run_app(config):
comment_app = CommentDaemon(config=config)
loop = asyncio.get_event_loop()
def __exit():

src/server/handles.py

@@ -1,16 +1,23 @@
import asyncio
import logging
import time
import typing
from aiohttp import web
from aiojobs.aiohttp import atomic
from peewee import DoesNotExist
import src.database.queries as db
from src.database.writes import abandon_comment, create_comment
from src.database.writes import hide_comments
from src.database.writes import edit_comment
from src.server.misc import clean_input_params
from src.server.validation import validate_signature_from_claim
from src.misc import clean_input_params, get_claim_from_id
from src.server.errors import make_error, report_error
from src.database.models import Comment, Channel
from src.database.models import get_comment
from src.database.models import comment_list
from src.database.models import create_comment
from src.database.models import edit_comment
from src.database.models import delete_comment
from src.database.models import set_hidden_flag
logger = logging.getLogger(__name__)
@@ -20,51 +27,194 @@ def ping(*args):
return 'pong'
def handle_get_channel_from_comment_id(app, kwargs: dict):
return db.get_channel_id_from_comment_id(app['reader'], **kwargs)
def handle_get_channel_from_comment_id(app: web.Application, comment_id: str) -> dict:
comment = get_comment(comment_id)
return {
'channel_id': comment['channel_id'],
'channel_name': comment['channel_name']
}
def handle_get_comment_ids(app, kwargs):
return db.get_comment_ids(app['reader'], **kwargs)
def handle_get_comment_ids(
app: web.Application,
claim_id: str,
parent_id: str = None,
page: int = 1,
page_size: int = 50,
flattened=False
) -> dict:
results = comment_list(
claim_id=claim_id,
parent_id=parent_id,
top_level=(parent_id is None),
page=page,
page_size=page_size,
select_fields=['comment_id', 'parent_id']
)
if flattened:
results.update({
'items': [item['comment_id'] for item in results['items']],
'replies': [(item['comment_id'], item.get('parent_id'))
for item in results['items']]
})
return results
def handle_get_claim_comments(app, kwargs):
return db.get_claim_comments(app['reader'], **kwargs)
def handle_get_comments_by_id(
app: web.Application,
comment_ids: typing.Union[list, tuple]
) -> dict:
expression = Comment.comment_id.in_(comment_ids)
return comment_list(expressions=expression, page_size=len(comment_ids))
def handle_get_comments_by_id(app, kwargs):
return db.get_comments_by_id(app['reader'], **kwargs)
def handle_get_claim_comments(
app: web.Application,
claim_id: str,
parent_id: str = None,
page: int = 1,
page_size: int = 50,
top_level: bool = False
) -> dict:
return comment_list(
claim_id=claim_id,
parent_id=parent_id,
page=page,
page_size=page_size,
top_level=top_level
)
def handle_get_claim_hidden_comments(app, kwargs):
return db.get_claim_hidden_comments(app['reader'], **kwargs)
def handle_get_claim_hidden_comments(
app: web.Application,
claim_id: str,
hidden: bool,
page: int = 1,
page_size: int = 50,
) -> dict:
exclude = 'hidden' if hidden else 'visible'
return comment_list(
claim_id=claim_id,
exclude_mode=exclude,
page=page,
page_size=page_size
)
async def handle_abandon_comment(app, params):
return {'abandoned': await abandon_comment(app, **params)}
async def handle_abandon_comment(
app: web.Application,
comment_id: str,
signature: str,
signing_ts: str,
**kwargs,
) -> dict:
comment = get_comment(comment_id)
try:
channel = await get_claim_from_id(app, comment['channel_id'])
except DoesNotExist:
raise ValueError('Could not find a channel associated with the given comment')
else:
if not validate_signature_from_claim(channel, signature, signing_ts, comment_id):
raise ValueError('Abandon signature could not be validated')
with app['db'].atomic():
return {
'abandoned': delete_comment(comment_id)
}
async def handle_hide_comments(app, params):
return {'hidden': await hide_comments(app, **params)}
async def handle_hide_comments(app: web.Application, pieces: list, hide: bool = True) -> dict:
# let's get all the distinct claim_ids from the list of comment_ids
pieces_by_id = {p['comment_id']: p for p in pieces}
comment_ids = list(pieces_by_id.keys())
comments = (Comment
.select(Comment.comment_id, Comment.claim_id)
.where(Comment.comment_id.in_(comment_ids))
.tuples())
# resolve the claims and map them to their corresponding comment_ids
claims = {}
for comment_id, claim_id in comments:
try:
# try and resolve the claim, if fails then we mark it as null
# and remove the associated comment from the pieces
if claim_id not in claims:
claims[claim_id] = await get_claim_from_id(app, claim_id)
# try to get a public key to validate
if claims[claim_id] is None or 'signing_channel' not in claims[claim_id]:
raise ValueError(f'could not get signing channel from claim_id: {claim_id}')
# try to validate signature
else:
channel = claims[claim_id]['signing_channel']
piece = pieces_by_id[comment_id]
is_valid_signature = validate_signature_from_claim(
claim=channel,
signature=piece['signature'],
signing_ts=piece['signing_ts'],
data=piece['comment_id']
)
if not is_valid_signature:
raise ValueError(f'could not validate signature on comment_id: {comment_id}')
except ValueError:
# remove the piece from being hidden
pieces_by_id.pop(comment_id)
# remaining items in pieces_by_id have been able to successfully validate
with app['db'].atomic():
set_hidden_flag(list(pieces_by_id.keys()), hidden=hide)
query = Comment.select().where(Comment.comment_id.in_(comment_ids)).objects()
result = {
'hidden': [c.comment_id for c in query if c.is_hidden],
'visible': [c.comment_id for c in query if not c.is_hidden],
}
return result
async def handle_edit_comment(app, params):
if await edit_comment(app, **params):
return db.get_comment_or_none(app['reader'], params['comment_id'])
async def handle_edit_comment(app, comment: str = None, comment_id: str = None,
signature: str = None, signing_ts: str = None, **params) -> dict:
current = get_comment(comment_id)
channel_claim = await get_claim_from_id(app, current['channel_id'])
if not validate_signature_from_claim(channel_claim, signature, signing_ts, comment):
raise ValueError('Signature could not be validated')
with app['db'].atomic():
if not edit_comment(comment_id, comment, signature, signing_ts):
raise ValueError('Comment could not be edited')
return get_comment(comment_id)
# TODO: retrieve stake amounts for each channel & store in db
def handle_create_comment(app, comment: str = None, claim_id: str = None,
parent_id: str = None, channel_id: str = None, channel_name: str = None,
signature: str = None, signing_ts: str = None) -> dict:
with app['db'].atomic():
return create_comment(
comment=comment,
claim_id=claim_id,
parent_id=parent_id,
channel_id=channel_id,
channel_name=channel_name,
signature=signature,
signing_ts=signing_ts
)
METHODS = {
'ping': ping,
'get_claim_comments': handle_get_claim_comments,
'get_claim_hidden_comments': handle_get_claim_hidden_comments,
'get_claim_comments': handle_get_claim_comments, # this gets used
'get_claim_hidden_comments': handle_get_claim_hidden_comments, # this gets used
'get_comment_ids': handle_get_comment_ids,
'get_comments_by_id': handle_get_comments_by_id,
'get_channel_from_comment_id': handle_get_channel_from_comment_id,
'create_comment': create_comment,
'get_comments_by_id': handle_get_comments_by_id, # this gets used
'get_channel_from_comment_id': handle_get_channel_from_comment_id, # this gets used
'create_comment': handle_create_comment, # this gets used
'delete_comment': handle_abandon_comment,
'abandon_comment': handle_abandon_comment,
'hide_comments': handle_hide_comments,
'edit_comment': handle_edit_comment
'abandon_comment': handle_abandon_comment, # this gets used
'hide_comments': handle_hide_comments, # this gets used
'edit_comment': handle_edit_comment # this gets used
}
@@ -78,17 +228,19 @@ async def process_json(app, body: dict) -> dict:
start = time.time()
try:
if asyncio.iscoroutinefunction(METHODS[method]):
result = await METHODS[method](app, params)
result = await METHODS[method](app, **params)
else:
result = METHODS[method](app, params)
response['result'] = result
result = METHODS[method](app, **params)
except Exception as err:
logger.exception(f'Got {type(err).__name__}:')
logger.exception(f'Got {type(err).__name__}:\n{err}')
if type(err) in (ValueError, TypeError): # param error, not too important
response['error'] = make_error('INVALID_PARAMS', err)
else:
response['error'] = make_error('INTERNAL', err)
await app['webhooks'].spawn(report_error(app, err, body))
await app['webhooks'].spawn(report_error(app, err, body))
else:
response['result'] = result
finally:
end = time.time()
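For reference, the shape of a request body the dispatcher above consumes. Only method and params are read by process_json, so any JSON-RPC envelope fields beyond those are assumptions here, and the field values are placeholders:

body = {
    'method': 'create_comment',
    'params': {
        'comment': 'Example comment',
        'claim_id': 'a' * 40,        # hypothetical claim
        'channel_id': 'b' * 40,      # hypothetical channel
        'channel_name': '@example',
        'signature': 'c' * 128,      # placeholder signature
        'signing_ts': '1585949042',
    },
}
# The dispatcher resolves METHODS['create_comment'] and ends up calling
# handle_create_comment(app, **body['params']).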

src/server/validation.py

@@ -51,11 +51,31 @@ def claim_id_is_valid(claim_id: str) -> bool:
# default to None so params can be treated as kwargs; param count becomes more manageable
def is_valid_base_comment(comment: str = None, claim_id: str = None, parent_id: str = None, **kwargs) -> bool:
return comment and body_is_valid(comment) and \
((claim_id and claim_id_is_valid(claim_id)) or # parentid is used in place of claimid in replies
(parent_id and comment_id_is_valid(parent_id))) \
and is_valid_credential_input(**kwargs)
def is_valid_base_comment(
comment: str = None,
claim_id: str = None,
parent_id: str = None,
strict: bool = False,
**kwargs,
) -> bool:
try:
assert comment and body_is_valid(comment)
# strict mode assumes that the parent_id might not exist
if strict:
assert claim_id and claim_id_is_valid(claim_id)
assert parent_id is None or comment_id_is_valid(parent_id)
# non-strict removes reference restrictions
else:
assert claim_id or parent_id
if claim_id:
assert claim_id_is_valid(claim_id)
else:
assert comment_id_is_valid(parent_id)
except AssertionError:
return False
else:
return is_valid_credential_input(**kwargs)
def is_valid_credential_input(channel_id: str = None, channel_name: str = None,
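A short sketch of what the two modes above require, with placeholder ids and credentials; the return values also depend on is_valid_credential_input, which this hunk leaves unchanged:

from src.server.validation import is_valid_base_comment

# strict mode: claim_id must be present and valid; parent_id, if given, must also be valid
is_valid_base_comment(
    comment='hi', claim_id='a' * 40, strict=True,
    channel_id='b' * 40, channel_name='@example',
    signature='c' * 128, signing_ts='123',
)
# default mode: a valid parent_id may stand in for claim_id (a reply)
is_valid_base_comment(
    comment='hi', parent_id='d' * 64,
    channel_id='b' * 40, channel_name='@example',
    signature='c' * 128, signing_ts='123',
)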

src/settings.py (deleted)

@@ -1,17 +0,0 @@
# cython: language_level=3
import json
import pathlib
root_dir = pathlib.Path(__file__).parent.parent
config_path = root_dir / 'config' / 'conf.json'
def get_config(filepath):
with open(filepath, 'r') as cfile:
conf = json.load(cfile)
for key, path in conf['path'].items():
conf['path'][key] = str(root_dir / path)
return conf
config = get_config(config_path)

View file

@@ -1,18 +1,13 @@
import sqlite3
from random import randint
import faker
from faker.providers import internet
from faker.providers import lorem
from faker.providers import misc
from src.database.queries import get_comments_by_id
from src.database.queries import get_comment_ids
from src.database.queries import get_claim_comments
from src.database.queries import get_claim_hidden_comments
from src.database.writes import create_comment_or_error
from src.database.queries import hide_comments_by_id
from src.database.queries import delete_comment_by_id
from src.database.models import create_comment
from src.database.models import delete_comment
from src.database.models import comment_list, get_comment
from src.database.models import set_hidden_flag
from test.testcase import DatabaseTestCase
fake = faker.Faker()
@@ -27,26 +22,25 @@ class TestDatabaseOperations(DatabaseTestCase):
self.claimId = '529357c3422c6046d3fec76be2358004ba22e340'
def test01NamedComments(self):
comment = create_comment_or_error(
conn=self.conn,
comment = create_comment(
claim_id=self.claimId,
comment='This is a named comment',
channel_name='@username',
channel_id='529357c3422c6046d3fec76be2358004ba22abcd',
signature=fake.uuid4(),
signature='22'*64,
signing_ts='aaa'
)
self.assertIsNotNone(comment)
self.assertNotIn('parent_in', comment)
previous_id = comment['comment_id']
reply = create_comment_or_error(
conn=self.conn,
reply = create_comment(
claim_id=self.claimId,
comment='This is a named response',
channel_name='@another_username',
channel_id='529357c3422c6046d3fec76be2358004ba224bcd',
parent_id=previous_id,
signature=fake.uuid4(),
signature='11'*64,
signing_ts='aaa'
)
self.assertIsNotNone(reply)
@@ -54,34 +48,32 @@ class TestDatabaseOperations(DatabaseTestCase):
def test02AnonymousComments(self):
self.assertRaises(
sqlite3.IntegrityError,
create_comment_or_error,
conn=self.conn,
ValueError,
create_comment,
claim_id=self.claimId,
comment='This is an ANONYMOUS comment'
)
def test03SignedComments(self):
comment = create_comment_or_error(
conn=self.conn,
comment = create_comment(
claim_id=self.claimId,
comment='I like big butts and i cannot lie',
channel_name='@sirmixalot',
channel_id='529357c3422c6046d3fec76be2358005ba22abcd',
signature=fake.uuid4(),
signature='24'*64,
signing_ts='asdasd'
)
self.assertIsNotNone(comment)
self.assertIn('signing_ts', comment)
previous_id = comment['comment_id']
reply = create_comment_or_error(
conn=self.conn,
reply = create_comment(
claim_id=self.claimId,
comment='This is a LBRY verified response',
channel_name='@LBRY',
channel_id='529357c3422c6046d3fec76be2358001ba224bcd',
parent_id=previous_id,
signature=fake.uuid4(),
signature='12'*64,
signing_ts='sfdfdfds'
)
self.assertIsNotNone(reply)
@@ -90,75 +82,109 @@ class TestDatabaseOperations(DatabaseTestCase):
def test04UsernameVariations(self):
self.assertRaises(
AssertionError,
callable=create_comment_or_error,
conn=self.conn,
ValueError,
create_comment,
claim_id=self.claimId,
channel_name='$#(@#$@#$',
channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this is an invalid username'
comment='this is an invalid username',
signature='1' * 128,
signing_ts='123'
)
valid_username = create_comment_or_error(
conn=self.conn,
valid_username = create_comment(
claim_id=self.claimId,
channel_name='@' + 'a' * 255,
channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this is a valid username'
comment='this is a valid username',
signature='1'*128,
signing_ts='123'
)
self.assertIsNotNone(valid_username)
self.assertRaises(AssertionError,
callable=create_comment_or_error,
conn=self.conn,
claim_id=self.claimId,
channel_name='@' + 'a' * 256,
channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this username is too long'
)
self.assertRaises(
AssertionError,
callable=create_comment_or_error,
conn=self.conn,
ValueError,
create_comment,
claim_id=self.claimId,
channel_name='@' + 'a' * 256,
channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this username is too long',
signature='2' * 128,
signing_ts='123'
)
self.assertRaises(
ValueError,
create_comment,
claim_id=self.claimId,
channel_name='',
channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this username should not default to ANONYMOUS'
comment='this username should not default to ANONYMOUS',
signature='3' * 128,
signing_ts='123'
)
self.assertRaises(
AssertionError,
callable=create_comment_or_error,
conn=self.conn,
ValueError,
create_comment,
claim_id=self.claimId,
channel_name='@',
channel_id='529357c3422c6046d3fec76be2358001ba224b23',
comment='this username is too short'
comment='this username is too short',
signature='3' * 128,
signing_ts='123'
)
def test05HideComments(self):
comm = create_comment_or_error(self.conn, 'Comment #1', self.claimId, '1'*40, '@Doge123', 'a'*128, '123')
comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop()
comm = create_comment(
comment='Comment #1',
claim_id=self.claimId,
channel_id='1'*40,
channel_name='@Doge123',
signature='a'*128,
signing_ts='123'
)
comment = get_comment(comm['comment_id'])
self.assertFalse(comment['is_hidden'])
success = hide_comments_by_id(self.conn, [comm['comment_id']])
success = set_hidden_flag([comm['comment_id']])
self.assertTrue(success)
comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop()
comment = get_comment(comm['comment_id'])
self.assertTrue(comment['is_hidden'])
success = hide_comments_by_id(self.conn, [comm['comment_id']])
success = set_hidden_flag([comm['comment_id']])
self.assertTrue(success)
comment = get_comments_by_id(self.conn, [comm['comment_id']]).pop()
comment = get_comment(comm['comment_id'])
self.assertTrue(comment['is_hidden'])
def test06DeleteComments(self):
comm = create_comment_or_error(self.conn, 'Comment #1', self.claimId, '1'*40, '@Doge123', 'a'*128, '123')
comments = get_claim_comments(self.conn, self.claimId)
match = list(filter(lambda x: comm['comment_id'] == x['comment_id'], comments['items']))
self.assertTrue(match)
deleted = delete_comment_by_id(self.conn, comm['comment_id'])
# make sure that the comment was created
comm = create_comment(
comment='Comment #1',
claim_id=self.claimId,
channel_id='1'*40,
channel_name='@Doge123',
signature='a'*128,
signing_ts='123'
)
comments = comment_list(self.claimId)
match = [x for x in comments['items'] if x['comment_id'] == comm['comment_id']]
self.assertTrue(len(match) > 0)
deleted = delete_comment(comm['comment_id'])
self.assertTrue(deleted)
comments = get_claim_comments(self.conn, self.claimId)
match = list(filter(lambda x: comm['comment_id'] == x['comment_id'], comments['items']))
# make sure that we can't find the comment here
comments = comment_list(self.claimId)
match = [x for x in comments['items'] if x['comment_id'] == comm['comment_id']]
self.assertFalse(match)
deleted = delete_comment_by_id(self.conn, comm['comment_id'])
self.assertFalse(deleted)
self.assertRaises(
ValueError,
delete_comment,
comment_id=comm['comment_id'],
)
class ListDatabaseTest(DatabaseTestCase):
@@ -169,61 +195,75 @@ class ListDatabaseTest(DatabaseTestCase):
def testLists(self):
for claim_id in self.claim_ids:
with self.subTest(claim_id=claim_id):
comments = get_claim_comments(self.conn, claim_id)
comments = comment_list(claim_id)
self.assertIsNotNone(comments)
self.assertGreater(comments['page_size'], 0)
self.assertIn('has_hidden_comments', comments)
self.assertFalse(comments['has_hidden_comments'])
top_comments = get_claim_comments(self.conn, claim_id, top_level=True, page=1, page_size=50)
top_comments = comment_list(claim_id, top_level=True, page=1, page_size=50)
self.assertIsNotNone(top_comments)
self.assertEqual(top_comments['page_size'], 50)
self.assertEqual(top_comments['page'], 1)
self.assertGreaterEqual(top_comments['total_pages'], 0)
self.assertGreaterEqual(top_comments['total_items'], 0)
comment_ids = get_comment_ids(self.conn, claim_id, page_size=50, page=1)
comment_ids = comment_list(claim_id, page_size=50, page=1)
with self.subTest(comment_ids=comment_ids):
self.assertIsNotNone(comment_ids)
self.assertLessEqual(len(comment_ids), 50)
matching_comments = get_comments_by_id(self.conn, comment_ids)
matching_comments = (comment_ids)
self.assertIsNotNone(matching_comments)
self.assertEqual(len(matching_comments), len(comment_ids))
def testHiddenCommentLists(self):
claim_id = 'a'*40
comm1 = create_comment_or_error(self.conn, 'Comment #1', claim_id, '1'*40, '@Doge123', 'a'*128, '123')
comm2 = create_comment_or_error(self.conn, 'Comment #2', claim_id, '1'*40, '@Doge123', 'b'*128, '123')
comm3 = create_comment_or_error(self.conn, 'Comment #3', claim_id, '1'*40, '@Doge123', 'c'*128, '123')
comm1 = create_comment(
'Comment #1',
claim_id,
channel_id='1'*40,
channel_name='@Doge123',
signature='a'*128,
signing_ts='123'
)
comm2 = create_comment(
'Comment #2', claim_id,
channel_id='1'*40,
channel_name='@Doge123',
signature='b'*128,
signing_ts='123'
)
comm3 = create_comment(
'Comment #3', claim_id,
channel_id='1'*40,
channel_name='@Doge123',
signature='c'*128,
signing_ts='123'
)
comments = [comm1, comm2, comm3]
comment_list = get_claim_comments(self.conn, claim_id)
self.assertIn('items', comment_list)
self.assertIn('has_hidden_comments', comment_list)
self.assertEqual(len(comments), comment_list['total_items'])
self.assertIn('has_hidden_comments', comment_list)
self.assertFalse(comment_list['has_hidden_comments'])
hide_comments_by_id(self.conn, [comm2['comment_id']])
listed_comments = comment_list(claim_id)
self.assertEqual(len(comments), listed_comments['total_items'])
self.assertFalse(listed_comments['has_hidden_comments'])
default_comments = get_claim_hidden_comments(self.conn, claim_id)
self.assertIn('has_hidden_comments', default_comments)
set_hidden_flag([comm2['comment_id']])
hidden = comment_list(claim_id, exclude_mode='hidden')
hidden_comments = get_claim_hidden_comments(self.conn, claim_id, hidden=True)
self.assertIn('has_hidden_comments', hidden_comments)
self.assertEqual(default_comments, hidden_comments)
self.assertTrue(hidden['has_hidden_comments'])
self.assertGreater(len(hidden['items']), 0)
hidden_comment = hidden_comments['items'][0]
visible = comment_list(claim_id, exclude_mode='visible')
self.assertFalse(visible['has_hidden_comments'])
self.assertNotEqual(listed_comments['items'], visible['items'])
# make sure the hidden comment is the one we marked as hidden
hidden_comment = hidden['items'][0]
self.assertEqual(hidden_comment['comment_id'], comm2['comment_id'])
visible_comments = get_claim_hidden_comments(self.conn, claim_id, hidden=False)
self.assertIn('has_hidden_comments', visible_comments)
self.assertNotIn(hidden_comment, visible_comments['items'])
hidden_ids = [c['comment_id'] for c in hidden_comments['items']]
visible_ids = [c['comment_id'] for c in visible_comments['items']]
hidden_ids = [c['comment_id'] for c in hidden['items']]
visible_ids = [c['comment_id'] for c in visible['items']]
composite_ids = hidden_ids + visible_ids
listed_comments = comment_list(claim_id)
all_ids = [c['comment_id'] for c in listed_comments['items']]
composite_ids.sort()
comment_list = get_claim_comments(self.conn, claim_id)
all_ids = [c['comment_id'] for c in comment_list['items']]
all_ids.sort()
self.assertEqual(composite_ids, all_ids)

View file

@@ -9,13 +9,18 @@ from faker.providers import internet
from faker.providers import lorem
from faker.providers import misc
from src.settings import config
from src.main import get_config, CONFIG_FILE
from src.server import app
from src.server.validation import is_valid_base_comment
from test.testcase import AsyncioTestCase
config = get_config(CONFIG_FILE)
config['mode'] = 'testing'
config['testing']['file'] = ':memory:'
if 'slack_webhook' in config:
config.pop('slack_webhook')
@@ -71,10 +76,10 @@ def create_test_comments(values: iter, **default):
class ServerTest(AsyncioTestCase):
db_file = 'test.db'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
config['mode'] = 'testing'
config['testing']['file'] = ':memory:'
self.host = 'localhost'
self.port = 5931
@@ -85,11 +90,10 @@ class ServerTest(AsyncioTestCase):
@classmethod
def tearDownClass(cls) -> None:
print('exit reached')
os.remove(cls.db_file)
async def asyncSetUp(self):
await super().asyncSetUp()
self.server = app.CommentDaemon(config, db_file=self.db_file)
self.server = app.CommentDaemon(config)
await self.server.start(host=self.host, port=self.port)
self.addCleanup(self.server.stop)
@@ -135,14 +139,16 @@ class ServerTest(AsyncioTestCase):
test_all = create_test_comments(replace.keys(), **{
k: None for k in replace.keys()
})
test_all.reverse()
for test in test_all:
with self.subTest(test=test):
nulls = 'null fields: ' + ', '.join(k for k, v in test.items() if not v)
with self.subTest(test=nulls):
message = await self.post_comment(**test)
self.assertTrue('result' in message or 'error' in message)
if 'error' in message:
self.assertFalse(is_valid_base_comment(**test))
self.assertFalse(is_valid_base_comment(**test, strict=True))
else:
self.assertTrue(is_valid_base_comment(**test))
self.assertTrue(is_valid_base_comment(**test, strict=True))
async def test04CreateAllReplies(self):
claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8'
@@ -220,7 +226,8 @@ class ListCommentsTest(AsyncioTestCase):
super().__init__(*args, **kwargs)
self.host = 'localhost'
self.port = 5931
self.db_file = 'list_test.db'
config['mode'] = 'testing'
config['testing']['file'] = ':memory:'
self.claim_id = '1d8a5cc39ca02e55782d619e67131c0a20843be8'
self.comment_ids = None
@@ -231,10 +238,6 @@ class ListCommentsTest(AsyncioTestCase):
async def post_comment(self, **params):
return await jsonrpc_post(self.url, 'create_comment', **params)
def tearDown(self) -> None:
print('exit reached')
os.remove(self.db_file)
async def create_lots_of_comments(self, n=23):
self.comment_list = [{key: self.replace[key]() for key in self.replace.keys()} for _ in range(23)]
for comment in self.comment_list:
@@ -244,7 +247,7 @@ class ListCommentsTest(AsyncioTestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
self.server = app.CommentDaemon(config, db_file=self.db_file)
self.server = app.CommentDaemon(config)
await self.server.start(self.host, self.port)
self.addCleanup(self.server.stop)

test/testcase.py

@@ -1,12 +1,38 @@
import os
import pathlib
import unittest
from asyncio.runners import _cancel_all_tasks # type: ignore
from unittest.case import _Outcome
import asyncio
from asyncio.runners import _cancel_all_tasks # type: ignore
from peewee import *
from src.database.queries import obtain_connection, setup_database
from src.database.models import Channel, Comment
test_db = SqliteDatabase(':memory:')
MODELS = [Channel, Comment]
class DatabaseTestCase(unittest.TestCase):
def __init__(self, methodName='DatabaseTest'):
super().__init__(methodName)
def setUp(self) -> None:
super().setUp()
test_db.bind(MODELS, bind_refs=False, bind_backrefs=False)
test_db.connect()
test_db.create_tables(MODELS)
def tearDown(self) -> None:
# drop tables for next test
test_db.drop_tables(MODELS)
# close connection
test_db.close()
class AsyncioTestCase(unittest.TestCase):
@@ -117,21 +143,3 @@ class AsyncioTestCase(unittest.TestCase):
self.loop.run_until_complete(maybe_coroutine)
class DatabaseTestCase(unittest.TestCase):
db_file = 'test.db'
def __init__(self, methodName='DatabaseTest'):
super().__init__(methodName)
if pathlib.Path(self.db_file).exists():
os.remove(self.db_file)
def setUp(self) -> None:
super().setUp()
setup_database(self.db_file)
self.conn = obtain_connection(self.db_file)
self.addCleanup(self.conn.close)
self.addCleanup(os.remove, self.db_file)
def tearDown(self) -> None:
self.conn.close()