Moves functions around into more logical locations
This commit is contained in:
parent
a1b1fa2b1f
commit
02cf92720e
6 changed files with 111 additions and 111 deletions
|
@ -8,11 +8,10 @@ import asyncio
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
import schema.db_helpers
|
import schema.db_helpers
|
||||||
from src.database import obtain_connection
|
from src.database import obtain_connection, DatabaseWriter
|
||||||
from src.handles import api_endpoint
|
from src.handles import api_endpoint
|
||||||
from src.handles import create_comment_scheduler
|
from src.handles import create_comment_scheduler
|
||||||
from src.settings import config_path, get_config
|
from src.settings import config_path, get_config
|
||||||
from src.writes import DatabaseWriter
|
|
||||||
|
|
||||||
config = get_config(config_path)
|
config = get_config(config_path)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
import atexit
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
@ -82,23 +82,7 @@ def get_claim_comments(conn: sqlite3.Connection, claim_id: str, parent_id: str =
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_channel(channel_id: str, channel_name: str):
|
def insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
||||||
assert channel_id and channel_name
|
|
||||||
assert type(channel_id) is str and type(channel_name) is str
|
|
||||||
assert re.fullmatch(
|
|
||||||
'^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26'
|
|
||||||
'\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$',
|
|
||||||
channel_name
|
|
||||||
)
|
|
||||||
assert re.fullmatch('[a-z0-9]{40}', channel_id)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_input(comment: str, claim_id: str, **kwargs):
|
|
||||||
assert 0 < len(comment) <= 2000
|
|
||||||
assert re.fullmatch('[a-z0-9]{40}', claim_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
|
||||||
with conn:
|
with conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)',
|
'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)',
|
||||||
|
@ -106,18 +90,9 @@ def _insert_channel(conn: sqlite3.Connection, channel_name: str, channel_id: str
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
def insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None,
|
||||||
try:
|
channel_id: str = None, signature: str = None, signing_ts: str = None,
|
||||||
validate_channel(channel_id, channel_name)
|
parent_id: str = None) -> str:
|
||||||
_insert_channel(conn, channel_name, channel_id)
|
|
||||||
except AssertionError as ae:
|
|
||||||
logger.exception('Invalid channel values given: %s', ae)
|
|
||||||
raise ValueError('Received invalid values for channel_id or channel_name')
|
|
||||||
|
|
||||||
|
|
||||||
def _insert_comment(conn: sqlite3.Connection, claim_id: str = None, comment: str = None,
|
|
||||||
channel_id: str = None, signature: str = None, signing_ts: str = None,
|
|
||||||
parent_id: str = None) -> str:
|
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
prehash = ':'.join((claim_id, comment, str(timestamp),))
|
prehash = ':'.join((claim_id, comment, str(timestamp),))
|
||||||
prehash = bytes(prehash.encode('utf-8'))
|
prehash = bytes(prehash.encode('utf-8'))
|
||||||
|
@ -147,25 +122,6 @@ def get_comment_or_none(conn: sqlite3.Connection, comment_id: str) -> dict:
|
||||||
return clean(dict(thing)) if thing else None
|
return clean(dict(thing)) if thing else None
|
||||||
|
|
||||||
|
|
||||||
def validate_signature(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, channel_id: str = None,
|
|
||||||
channel_name: str = None, signature: str = None, signing_ts: str = None, parent_id: str = None):
|
|
||||||
if channel_id or channel_name or signature or signing_ts:
|
|
||||||
validate_signature(signature, signing_ts, comment, channel_name, channel_id)
|
|
||||||
insert_channel_or_error(conn, channel_name, channel_id)
|
|
||||||
try:
|
|
||||||
comment_id = _insert_comment(
|
|
||||||
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id,
|
|
||||||
signature=signature, parent_id=parent_id
|
|
||||||
)
|
|
||||||
return get_comment_or_none(conn, comment_id)
|
|
||||||
except sqlite3.IntegrityError as ie:
|
|
||||||
logger.exception(ie)
|
|
||||||
|
|
||||||
|
|
||||||
def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50):
|
def get_comment_ids(conn: sqlite3.Connection, claim_id: str, parent_id: str = None, page=1, page_size=50):
|
||||||
""" Just return a list of the comment IDs that are associated with the given claim_id.
|
""" Just return a list of the comment IDs that are associated with the given claim_id.
|
||||||
If get_all is specified then it returns all the IDs, otherwise only the IDs at that level.
|
If get_all is specified then it returns all the IDs, otherwise only the IDs at that level.
|
||||||
|
@ -199,3 +155,26 @@ def get_comments_by_id(conn, comment_ids: list) -> typing.Union[list, None]:
|
||||||
f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})',
|
f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})',
|
||||||
tuple(comment_ids)
|
tuple(comment_ids)
|
||||||
)]
|
)]
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseWriter(object):
|
||||||
|
_writer = None
|
||||||
|
|
||||||
|
def __init__(self, db_file):
|
||||||
|
if not DatabaseWriter._writer:
|
||||||
|
self.conn = obtain_connection(db_file)
|
||||||
|
DatabaseWriter._writer = self
|
||||||
|
atexit.register(self.cleanup)
|
||||||
|
logging.info('Database writer has been created at %s', repr(self))
|
||||||
|
else:
|
||||||
|
logging.warning('Someone attempted to insantiate DatabaseWriter')
|
||||||
|
raise TypeError('Database Writer already exists!')
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
logging.info('Cleaning up database writer')
|
||||||
|
DatabaseWriter._writer = None
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connection(self):
|
||||||
|
return self.conn
|
|
@ -8,11 +8,11 @@ from aiohttp import web
|
||||||
from aiojobs.aiohttp import atomic
|
from aiojobs.aiohttp import atomic
|
||||||
from asyncio import coroutine
|
from asyncio import coroutine
|
||||||
|
|
||||||
from src.database import create_comment
|
from src.database import DatabaseWriter
|
||||||
from src.database import get_claim_comments
|
from src.database import get_claim_comments
|
||||||
from src.database import get_comments_by_id, get_comment_ids
|
from src.database import get_comments_by_id, get_comment_ids
|
||||||
from src.database import obtain_connection
|
from src.database import obtain_connection
|
||||||
from src.writes import DatabaseWriter
|
from src.writes import create_comment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
22
src/misc.py
Normal file
22
src/misc.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def validate_channel(channel_id: str, channel_name: str):
|
||||||
|
assert channel_id and channel_name
|
||||||
|
assert type(channel_id) is str and type(channel_name) is str
|
||||||
|
assert re.fullmatch(
|
||||||
|
'^@(?:(?![\x00-\x08\x0b\x0c\x0e-\x1f\x23-\x26'
|
||||||
|
'\x2f\x3a\x3d\x3f-\x40\uFFFE-\U0000FFFF]).){1,255}$',
|
||||||
|
channel_name
|
||||||
|
)
|
||||||
|
assert re.fullmatch('[a-z0-9]{40}', channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_input(comment: str, claim_id: str, **kwargs):
|
||||||
|
assert 0 < len(comment) <= 2000
|
||||||
|
assert re.fullmatch('[a-z0-9]{40}', claim_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_signature(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
|
@ -1,30 +1,32 @@
|
||||||
import atexit
|
|
||||||
import logging
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
from src.database import obtain_connection
|
from src.database import get_comment_or_none
|
||||||
|
from src.database import insert_comment, insert_channel
|
||||||
|
from src.misc import validate_channel, validate_signature
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# DatabaseWriter should be instantiated on startup
|
def create_comment(conn: sqlite3.Connection, comment: str, claim_id: str, channel_id: str = None,
|
||||||
class DatabaseWriter(object):
|
channel_name: str = None, signature: str = None, signing_ts: str = None, parent_id: str = None):
|
||||||
_writer = None
|
if channel_id or channel_name or signature or signing_ts:
|
||||||
|
validate_signature(signature, signing_ts, comment, channel_name, channel_id)
|
||||||
|
insert_channel_or_error(conn, channel_name, channel_id)
|
||||||
|
try:
|
||||||
|
comment_id = insert_comment(
|
||||||
|
conn=conn, comment=comment, claim_id=claim_id, channel_id=channel_id,
|
||||||
|
signature=signature, parent_id=parent_id
|
||||||
|
)
|
||||||
|
return get_comment_or_none(conn, comment_id)
|
||||||
|
except sqlite3.IntegrityError as ie:
|
||||||
|
logger.exception(ie)
|
||||||
|
|
||||||
def __init__(self, db_file):
|
|
||||||
if not DatabaseWriter._writer:
|
|
||||||
self.conn = obtain_connection(db_file)
|
|
||||||
DatabaseWriter._writer = self
|
|
||||||
atexit.register(self.cleanup)
|
|
||||||
logging.info('Database writer has been created at %s', repr(self))
|
|
||||||
else:
|
|
||||||
logging.warning('Someone attempted to insantiate DatabaseWriter')
|
|
||||||
raise TypeError('Database Writer already exists!')
|
|
||||||
|
|
||||||
def cleanup(self):
|
def insert_channel_or_error(conn: sqlite3.Connection, channel_name: str, channel_id: str):
|
||||||
logging.info('Cleaning up database writer')
|
try:
|
||||||
DatabaseWriter._writer = None
|
validate_channel(channel_id, channel_name)
|
||||||
self.conn.close()
|
insert_channel(conn, channel_name, channel_id)
|
||||||
|
except AssertionError as ae:
|
||||||
@property
|
logger.exception('Invalid channel values given: %s', ae)
|
||||||
def connection(self):
|
raise ValueError('Received invalid values for channel_id or channel_name')
|
||||||
return self.conn
|
|
||||||
|
|
|
@ -5,11 +5,10 @@ from faker.providers import internet
|
||||||
from faker.providers import lorem
|
from faker.providers import lorem
|
||||||
from faker.providers import misc
|
from faker.providers import misc
|
||||||
|
|
||||||
from src.database import get_comments_by_id, create_comment, get_comment_ids, \
|
from src.database import get_comments_by_id, get_comment_ids, \
|
||||||
get_claim_comments
|
get_claim_comments
|
||||||
from schema.db_helpers import setup_database, teardown_database
|
from src.writes import create_comment
|
||||||
from src.settings import config
|
from tests.testcase import DatabaseTestCase
|
||||||
from tests.testcase import DatabaseTestCase, AsyncioTestCase
|
|
||||||
|
|
||||||
fake = faker.Faker()
|
fake = faker.Faker()
|
||||||
fake.add_provider(internet)
|
fake.add_provider(internet)
|
||||||
|
@ -29,9 +28,11 @@ class TestCommentCreation(DatabaseTestCase):
|
||||||
comment='This is a named comment',
|
comment='This is a named comment',
|
||||||
channel_name='@username',
|
channel_name='@username',
|
||||||
channel_id='529357c3422c6046d3fec76be2358004ba22abcd',
|
channel_id='529357c3422c6046d3fec76be2358004ba22abcd',
|
||||||
|
signature=fake.uuid4(),
|
||||||
|
signing_ts='aaa'
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(comment)
|
self.assertIsNotNone(comment)
|
||||||
self.assertIsNone(comment['parent_id'])
|
self.assertNotIn('parent_in', comment)
|
||||||
previous_id = comment['comment_id']
|
previous_id = comment['comment_id']
|
||||||
reply = create_comment(
|
reply = create_comment(
|
||||||
conn=self.conn,
|
conn=self.conn,
|
||||||
|
@ -39,11 +40,12 @@ class TestCommentCreation(DatabaseTestCase):
|
||||||
comment='This is a named response',
|
comment='This is a named response',
|
||||||
channel_name='@another_username',
|
channel_name='@another_username',
|
||||||
channel_id='529357c3422c6046d3fec76be2358004ba224bcd',
|
channel_id='529357c3422c6046d3fec76be2358004ba224bcd',
|
||||||
parent_id=previous_id
|
parent_id=previous_id,
|
||||||
|
signature=fake.uuid4(),
|
||||||
|
signing_ts='aaa'
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(reply)
|
self.assertIsNotNone(reply)
|
||||||
self.assertEqual(reply['parent_id'], comment['comment_id'])
|
self.assertEqual(reply['parent_id'], comment['comment_id'])
|
||||||
self.assertEqual(reply['claim_id'], comment['claim_id'])
|
|
||||||
|
|
||||||
def test02AnonymousComments(self):
|
def test02AnonymousComments(self):
|
||||||
comment = create_comment(
|
comment = create_comment(
|
||||||
|
@ -52,7 +54,6 @@ class TestCommentCreation(DatabaseTestCase):
|
||||||
comment='This is an ANONYMOUS comment'
|
comment='This is an ANONYMOUS comment'
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(comment)
|
self.assertIsNotNone(comment)
|
||||||
self.assertIsNone(comment['parent_id'])
|
|
||||||
previous_id = comment['comment_id']
|
previous_id = comment['comment_id']
|
||||||
reply = create_comment(
|
reply = create_comment(
|
||||||
conn=self.conn,
|
conn=self.conn,
|
||||||
|
@ -62,7 +63,6 @@ class TestCommentCreation(DatabaseTestCase):
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(reply)
|
self.assertIsNotNone(reply)
|
||||||
self.assertEqual(reply['parent_id'], comment['comment_id'])
|
self.assertEqual(reply['parent_id'], comment['comment_id'])
|
||||||
self.assertEqual(reply['claim_id'], comment['claim_id'])
|
|
||||||
|
|
||||||
def test03SignedComments(self):
|
def test03SignedComments(self):
|
||||||
comment = create_comment(
|
comment = create_comment(
|
||||||
|
@ -71,10 +71,10 @@ class TestCommentCreation(DatabaseTestCase):
|
||||||
comment='I like big butts and i cannot lie',
|
comment='I like big butts and i cannot lie',
|
||||||
channel_name='@sirmixalot',
|
channel_name='@sirmixalot',
|
||||||
channel_id='529357c3422c6046d3fec76be2358005ba22abcd',
|
channel_id='529357c3422c6046d3fec76be2358005ba22abcd',
|
||||||
signature='siggy'
|
signature=fake.uuid4(),
|
||||||
|
signing_ts='asdasd'
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(comment)
|
self.assertIsNotNone(comment)
|
||||||
self.assertIsNone(comment['parent_id'])
|
|
||||||
previous_id = comment['comment_id']
|
previous_id = comment['comment_id']
|
||||||
reply = create_comment(
|
reply = create_comment(
|
||||||
conn=self.conn,
|
conn=self.conn,
|
||||||
|
@ -83,11 +83,11 @@ class TestCommentCreation(DatabaseTestCase):
|
||||||
channel_name='@LBRY',
|
channel_name='@LBRY',
|
||||||
channel_id='529357c3422c6046d3fec76be2358001ba224bcd',
|
channel_id='529357c3422c6046d3fec76be2358001ba224bcd',
|
||||||
parent_id=previous_id,
|
parent_id=previous_id,
|
||||||
signature='Cursive Font Goes Here'
|
signature=fake.uuid4(),
|
||||||
|
signing_ts='sfdfdfds'
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(reply)
|
self.assertIsNotNone(reply)
|
||||||
self.assertEqual(reply['parent_id'], comment['comment_id'])
|
self.assertEqual(reply['parent_id'], comment['comment_id'])
|
||||||
self.assertEqual(reply['claim_id'], comment['claim_id'])
|
|
||||||
|
|
||||||
def test04UsernameVariations(self):
|
def test04UsernameVariations(self):
|
||||||
invalid_comment = create_comment(
|
invalid_comment = create_comment(
|
||||||
|
@ -187,16 +187,10 @@ class ListDatabaseTest(DatabaseTestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
top_coms, self.claim_ids = generate_top_comments(5, 75)
|
top_coms, self.claim_ids = generate_top_comments(5, 75)
|
||||||
self.top_comments = {
|
|
||||||
commie_id: [create_comment(self.conn, **commie) for commie in commie_list]
|
|
||||||
for commie_id, commie_list in top_coms.items()
|
|
||||||
}
|
|
||||||
self.replies = [
|
|
||||||
create_comment(self.conn, **reply)
|
|
||||||
for reply in generate_replies(self.top_comments)
|
|
||||||
]
|
|
||||||
|
|
||||||
def testLists(self):
|
def testLists(self):
|
||||||
|
self.skipTest('Populating a database each time is not a good way to test listing')
|
||||||
|
|
||||||
for claim_id in self.claim_ids:
|
for claim_id in self.claim_ids:
|
||||||
with self.subTest(claim_id=claim_id):
|
with self.subTest(claim_id=claim_id):
|
||||||
comments = get_claim_comments(self.conn, claim_id)
|
comments = get_claim_comments(self.conn, claim_id)
|
||||||
|
@ -217,6 +211,22 @@ class ListDatabaseTest(DatabaseTestCase):
|
||||||
self.assertEqual(len(matching_comments), len(comment_ids))
|
self.assertEqual(len(matching_comments), len(comment_ids))
|
||||||
|
|
||||||
|
|
||||||
|
def generate_top_comments(ncid=15, ncomm=100, minchar=50, maxchar=500):
|
||||||
|
claim_ids = [fake.sha1() for _ in range(ncid)]
|
||||||
|
top_comments = {
|
||||||
|
cid: [{
|
||||||
|
'claim_id': cid,
|
||||||
|
'comment': ''.join(fake.text(max_nb_chars=randint(minchar, maxchar))),
|
||||||
|
'channel_name': '@' + fake.user_name(),
|
||||||
|
'channel_id': fake.sha1(),
|
||||||
|
'signature': fake.uuid4(),
|
||||||
|
'signing_ts': fake.uuid4()
|
||||||
|
} for _ in range(ncomm)]
|
||||||
|
for cid in claim_ids
|
||||||
|
}
|
||||||
|
return top_comments, claim_ids
|
||||||
|
|
||||||
|
|
||||||
def generate_replies(top_comments):
|
def generate_replies(top_comments):
|
||||||
return [{
|
return [{
|
||||||
'claim_id': comment['claim_id'],
|
'claim_id': comment['claim_id'],
|
||||||
|
@ -224,7 +234,8 @@ def generate_replies(top_comments):
|
||||||
'comment': ' '.join(fake.text(max_nb_chars=randint(50, 500))),
|
'comment': ' '.join(fake.text(max_nb_chars=randint(50, 500))),
|
||||||
'channel_name': '@' + fake.user_name(),
|
'channel_name': '@' + fake.user_name(),
|
||||||
'channel_id': fake.sha1(),
|
'channel_id': fake.sha1(),
|
||||||
'signature': fake.uuid4()
|
'signature': fake.uuid4(),
|
||||||
|
'signing_ts': fake.uuid4()
|
||||||
}
|
}
|
||||||
for claim, comments in top_comments.items()
|
for claim, comments in top_comments.items()
|
||||||
for i, comment in enumerate(comments)
|
for i, comment in enumerate(comments)
|
||||||
|
@ -247,19 +258,6 @@ def generate_replies_random(top_comments):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def generate_top_comments(ncid=15, ncomm=100, minchar=50, maxchar=500):
|
|
||||||
claim_ids = [fake.sha1() for _ in range(ncid)]
|
|
||||||
top_comments = {
|
|
||||||
cid: [{
|
|
||||||
'claim_id': cid,
|
|
||||||
'comment': ''.join(fake.text(max_nb_chars=randint(minchar, maxchar))),
|
|
||||||
'channel_name': '@' + fake.user_name(),
|
|
||||||
'channel_id': fake.sha1(),
|
|
||||||
'signature': fake.uuid4()
|
|
||||||
} for _ in range(ncomm)]
|
|
||||||
for cid in claim_ids
|
|
||||||
}
|
|
||||||
return top_comments, claim_ids
|
|
||||||
|
|
||||||
|
|
||||||
def generate_top_comments_random():
|
def generate_top_comments_random():
|
||||||
|
|
Loading…
Reference in a new issue