diff --git a/lbry_comment_server/database.py b/lbry_comment_server/database.py index 57b7a6e..68c965a 100644 --- a/lbry_comment_server/database.py +++ b/lbry_comment_server/database.py @@ -149,7 +149,7 @@ async def _insert_channel_async(db_file: str, channel_name: str, channel_id: str async def _insert_comment_async(db_file: str, claim_id: str = None, comment: str = None, - channel_id: str = None, signature: str = None, parent_id: str = None) -> str: + channel_id: str = None, signature: str = None, parent_id: str = None) -> str: timestamp = time.time_ns() comment_prehash = ':'.join((claim_id, comment, str(timestamp),)) comment_prehash = bytes(comment_prehash.encode('utf-8')) @@ -186,7 +186,7 @@ async def create_comment_async(db_file: str, comment: str, claim_id: str, **kwar comment_id = await _insert_comment_async( db_file=db_file, comment=comment, claim_id=claim_id, channel_id=channel_id, **kwargs ) - async with await aiosqlite.connect(db_file) as db: + async with aiosqlite.connect(db_file) as db: db.row_factory = aiosqlite.Row curs = await db.execute( 'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?', (comment_id,) diff --git a/tests/database_test.py b/tests/database_test.py index 39a4379..ff7ea55 100644 --- a/tests/database_test.py +++ b/tests/database_test.py @@ -1,13 +1,15 @@ -import unittest - +from random import randint +import faker from faker.providers import internet from faker.providers import lorem from faker.providers import misc -import lbry_comment_server.database as db -import faker -from random import randint +import lbry_comment_server.database as db +import schema.db_helpers as schema +from lbry_comment_server.settings import config +from tests.testcase import DatabaseTestCase, AsyncioTestCase + fake = faker.Faker() fake.add_provider(internet) fake.add_provider(lorem) @@ -28,9 +30,6 @@ class TestCommentCreation(DatabaseTestCase): channel_id='529357c3422c6046d3fec76be2358004ba22abcd', ) self.assertIsNotNone(comment) - self.assertIn('comment', comment) - self.assertIn('comment_id', comment) - self.assertIn('parent_id', comment) self.assertIsNone(comment['parent_id']) previous_id = comment['comment_id'] reply = db.create_comment( @@ -42,9 +41,6 @@ class TestCommentCreation(DatabaseTestCase): parent_id=previous_id ) self.assertIsNotNone(reply) - self.assertIn('comment', reply) - self.assertIn('comment_id', reply) - self.assertIn('parent_id', reply) self.assertEqual(reply['parent_id'], comment['comment_id']) self.assertEqual(reply['claim_id'], comment['claim_id']) @@ -55,9 +51,6 @@ class TestCommentCreation(DatabaseTestCase): comment='This is an ANONYMOUS comment' ) self.assertIsNotNone(comment) - self.assertIn('comment', comment) - self.assertIn('comment_id', comment) - self.assertIn('parent_id', comment) self.assertIsNone(comment['parent_id']) previous_id = comment['comment_id'] reply = db.create_comment( @@ -67,9 +60,6 @@ class TestCommentCreation(DatabaseTestCase): parent_id=previous_id ) self.assertIsNotNone(reply) - self.assertIn('comment', reply) - self.assertIn('comment_id', reply) - self.assertIn('parent_id', reply) self.assertEqual(reply['parent_id'], comment['comment_id']) self.assertEqual(reply['claim_id'], comment['claim_id']) @@ -83,9 +73,6 @@ class TestCommentCreation(DatabaseTestCase): signature='siggy' ) self.assertIsNotNone(comment) - self.assertIn('comment', comment) - self.assertIn('comment_id', comment) - self.assertIn('parent_id', comment) self.assertIsNone(comment['parent_id']) previous_id = comment['comment_id'] reply = db.create_comment( @@ -98,9 +85,6 @@ class TestCommentCreation(DatabaseTestCase): signature='Cursive Font Goes Here' ) self.assertIsNotNone(reply) - self.assertIn('comment', reply) - self.assertIn('comment_id', reply) - self.assertIn('parent_id', reply) self.assertEqual(reply['parent_id'], comment['comment_id']) self.assertEqual(reply['claim_id'], comment['claim_id']) @@ -185,10 +169,8 @@ class TestCommentCreation(DatabaseTestCase): total += len(comments) self.assertEqual(total, success) self.assertGreater(total, 0) - success, total = 0, 0 for reply in generate_replies(top_comments): db.create_comment(self.conn, **reply) - self.assertEqual(success, total) for claim_id in claim_ids: comments_ids = db.get_comment_ids(self.conn, claim_id) with self.subTest(comments_ids=comments_ids): @@ -231,8 +213,152 @@ class ListDatabaseTest(DatabaseTestCase): self.assertEqual(len(matching_comments), len(comment_ids)) -class AsyncDatabaseTestCase(unittest.TestCase): - async def asyncSetup +class AsyncWriteTest(AsyncioTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.db_path = config['PATH']['TEST'] + self.claimId = '529357c3422c6046d3fec76be2358004ba22e340' + + async def asyncSetUp(self): + await super().asyncSetUp() + schema.setup_database(self.db_path) + + async def asyncTearDown(self): + await super().asyncTearDown() + schema.teardown_database(self.db_path) + + async def test01NamedComments(self): + comment = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + comment='This is a named comment', + channel_name='@username', + channel_id='529357c3422c6046d3fec76be2358004ba22abcd', + ) + self.assertIsNotNone(comment) + self.assertIsNone(comment['parent_id']) + previous_id = comment['comment_id'] + reply = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + comment='This is a named response', + channel_name='@another_username', + channel_id='529357c3422c6046d3fec76be2358004ba224bcd', + parent_id=previous_id + ) + self.assertIsNotNone(reply) + self.assertEqual(reply['parent_id'], comment['comment_id']) + self.assertEqual(reply['claim_id'], comment['claim_id']) + + async def test02AnonymousComments(self): + comment = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + comment='This is an ANONYMOUS comment' + ) + self.assertIsNotNone(comment) + self.assertIsNone(comment['parent_id']) + previous_id = comment['comment_id'] + reply = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + comment='This is an unnamed response', + parent_id=previous_id + ) + self.assertIsNotNone(reply) + self.assertEqual(reply['parent_id'], comment['comment_id']) + self.assertEqual(reply['claim_id'], comment['claim_id']) + + async def test03SignedComments(self): + comment = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + comment='I like big butts and i cannot lie', + channel_name='@sirmixalot', + channel_id='529357c3422c6046d3fec76be2358005ba22abcd', + signature='siggy' + ) + self.assertIsNotNone(comment) + self.assertIsNone(comment['parent_id']) + previous_id = comment['comment_id'] + reply = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + comment='This is a LBRY verified response', + channel_name='@LBRY', + channel_id='529357c3422c6046d3fec76be2358001ba224bcd', + parent_id=previous_id, + signature='Cursive Font Goes Here' + ) + self.assertIsNotNone(reply) + self.assertEqual(reply['parent_id'], comment['comment_id']) + self.assertEqual(reply['claim_id'], comment['claim_id']) + + async def test04UsernameVariations(self): + invalid_comment = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + channel_name='$#(@#$@#$', + channel_id='529357c3422c6046d3fec76be2358001ba224b23', + comment='this is an invalid username' + ) + self.assertIsNone(invalid_comment) + valid_username = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + channel_name='@' + 'a'*255, + channel_id='529357c3422c6046d3fec76be2358001ba224b23', + comment='this is a valid username' + ) + self.assertIsNotNone(valid_username) + + lengthy_username = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + channel_name='@' + 'a'*256, + channel_id='529357c3422c6046d3fec76be2358001ba224b23', + comment='this username is too long' + ) + self.assertIsNone(lengthy_username) + comment = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + channel_name='', + channel_id='529357c3422c6046d3fec76be2358001ba224b23', + comment='this username should not default to ANONYMOUS' + ) + self.assertIsNone(comment) + short_username = await db.create_comment_async( + self.db_path, + claim_id=self.claimId, + channel_name='@', + channel_id='529357c3422c6046d3fec76be2358001ba224b23', + comment='this username is too short' + ) + self.assertIsNone(short_username) + + async def test06GenerateAndListComments(self): + top_comments, claim_ids = generate_top_comments() + total, success = 0, 0 + for _, comments in top_comments.items(): + for i, comment in enumerate(comments): + result = await db.create_comment_async(self.db_path, **comment) + if result: + success += 1 + comments[i] = result + del comment + total += len(comments) + self.assertEqual(total, success) + self.assertGreater(total, 0) + success, total = 0, 0 + for reply in generate_replies(top_comments): + inserted_reply = await db.create_comment_async(self.db_path, **reply) + if inserted_reply: + success += 1 + total += 1 + + self.assertEqual(success, total) + self.assertGreater(success, 0) def generate_replies(top_comments): diff --git a/tests/testcase.py b/tests/testcase.py index 16c913f..26f34ed 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -7,6 +7,7 @@ import lbry_comment_server.database as db from lbry_comment_server import config import schema.db_helpers as schema + class AsyncioTestCase(unittest.TestCase): # Implementation inspired by discussion: # https://bugs.python.org/issue32972 @@ -115,7 +116,6 @@ class AsyncioTestCase(unittest.TestCase): self.loop.run_until_complete(maybe_coroutine) - class DatabaseTestCase(unittest.TestCase): def setUp(self) -> None: super().setUp() @@ -123,21 +123,7 @@ class DatabaseTestCase(unittest.TestCase): self.conn = db.obtain_connection(config['PATH']['TEST']) def tearDown(self) -> None: - curs = self.conn.execute('SELECT * FROM COMMENT') - results = {'COMMENT': [dict(r) for r in curs.fetchall()]} - curs = self.conn.execute('SELECT * FROM CHANNEL') - results['CHANNEL'] = [dict(r) for r in curs.fetchall()] - curs = self.conn.execute('SELECT * FROM COMMENTS_ON_CLAIMS') - results['COMMENTS_ON_CLAIMS'] = [dict(r) for r in curs.fetchall()] - curs = self.conn.execute('SELECT * FROM COMMENT_REPLIES') - results['COMMENT_REPLIES'] = [dict(r) for r in curs.fetchall()] - # print(json.dumps(results, indent=4)) - with self.conn: - self.conn.executescript(""" - DROP TABLE IF EXISTS COMMENT; - DROP TABLE IF EXISTS CHANNEL; - DROP VIEW IF EXISTS COMMENTS_ON_CLAIMS; - DROP VIEW IF EXISTS COMMENT_REPLIES; - """) self.conn.close() + schema.teardown_database(config['PATH']['TEST']) +