diff --git a/server/database.py b/server/database.py index 1d258de..1de26de 100644 --- a/server/database.py +++ b/server/database.py @@ -21,7 +21,6 @@ def validate_input(**kwargs): class DatabaseConnection: - def obtain_connection(self): self.connection = sqlite3.connect(self.filepath) if self.row_factory: @@ -66,7 +65,7 @@ class DatabaseConnection: ) return [dict(row) for row in curs.fetchall()] - def _insert_channel(self, channel_name, channel_id, *args, **kwargs): + def _insert_channel(self, channel_name, channel_id): with self.connection: self.connection.execute( 'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)', @@ -91,37 +90,33 @@ class DatabaseConnection: ) return comment_id - def create_comment(self, comment: str, claim_id: str, channel_name: str = None, - channel_id: str = None, **kwargs) -> dict: - thing = None - try: - validate_input( - comment=comment, - claim_id=claim_id, - channel_id=channel_id, - channel_name=channel_name, - ) - if channel_id and channel_name: + def create_comment(self, comment: str, claim_id: str, **kwargs) -> typing.Union[dict, None]: + channel_id = kwargs.pop('channel_id', '') + channel_name = kwargs.pop('channel_name', '') + if channel_id or channel_name: + try: + validate_input( + comment=comment, + claim_id=claim_id, + channel_id=channel_id, + channel_name=channel_name, + ) self._insert_channel(channel_name, channel_id) - else: - channel_id = anonymous['channel_id'] - comcast_id = self._insert_comment( - comment=comment, - claim_id=claim_id, - channel_id=channel_id, - **kwargs - ) - curry = self.connection.execute( - 'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?', (comcast_id,) - ) - thing = curry.fetchone() - except AssertionError as e: - print(e) - finally: - return dict(thing) if thing else None - - - + except AssertionError: + return None + else: + channel_id = anonymous['channel_id'] + comcast_id = self._insert_comment( + comment=comment, + claim_id=claim_id, + channel_id=channel_id, + **kwargs + ) + curry = self.connection.execute( + 'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?', (comcast_id,) + ) + thing = curry.fetchone() + return dict(thing) if thing else None def get_comment_ids(self, claim_id: str, parent_id: str = None, page=1, page_size=50): """ Just return a list of the comment IDs that are associated with the given claim_id. @@ -143,14 +138,17 @@ class DatabaseConnection: WHERE claim_id LIKE ? AND parent_id LIKE ? LIMIT ? OFFSET ? """, (claim_id, parent_id, page_size, page_size * abs(page - 1),) ) - return [tuple(row) for row in curs.fetchall()] + return [tuple(row)[0] for row in curs.fetchall()] def get_comments_by_id(self, comment_ids: list) -> typing.Union[list, None]: """ Returns a list containing the comment data associated with each ID within the list""" + if len(comment_ids) == 0: + return None placeholders = ', '.join('?' for _ in comment_ids) curs = self.connection.execute( f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})', - comment_ids + tuple(comment_ids) + ) return [dict(row) for row in curs.fetchall()]