Fixes input validation for comment creation

This commit is contained in:
Oleg Silkin 2019-05-19 17:15:54 -04:00
parent dcbf4be04c
commit 436608d69b

View file

@ -21,7 +21,6 @@ def validate_input(**kwargs):
class DatabaseConnection: class DatabaseConnection:
def obtain_connection(self): def obtain_connection(self):
self.connection = sqlite3.connect(self.filepath) self.connection = sqlite3.connect(self.filepath)
if self.row_factory: if self.row_factory:
@ -66,7 +65,7 @@ class DatabaseConnection:
) )
return [dict(row) for row in curs.fetchall()] return [dict(row) for row in curs.fetchall()]
def _insert_channel(self, channel_name, channel_id, *args, **kwargs): def _insert_channel(self, channel_name, channel_id):
with self.connection: with self.connection:
self.connection.execute( self.connection.execute(
'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)', 'INSERT INTO CHANNEL(ClaimId, Name) VALUES (?, ?)',
@ -91,37 +90,33 @@ class DatabaseConnection:
) )
return comment_id return comment_id
def create_comment(self, comment: str, claim_id: str, channel_name: str = None, def create_comment(self, comment: str, claim_id: str, **kwargs) -> typing.Union[dict, None]:
channel_id: str = None, **kwargs) -> dict: channel_id = kwargs.pop('channel_id', '')
thing = None channel_name = kwargs.pop('channel_name', '')
try: if channel_id or channel_name:
validate_input( try:
comment=comment, validate_input(
claim_id=claim_id, comment=comment,
channel_id=channel_id, claim_id=claim_id,
channel_name=channel_name, channel_id=channel_id,
) channel_name=channel_name,
if channel_id and channel_name: )
self._insert_channel(channel_name, channel_id) self._insert_channel(channel_name, channel_id)
else: except AssertionError:
channel_id = anonymous['channel_id'] return None
comcast_id = self._insert_comment( else:
comment=comment, channel_id = anonymous['channel_id']
claim_id=claim_id, comcast_id = self._insert_comment(
channel_id=channel_id, comment=comment,
**kwargs claim_id=claim_id,
) channel_id=channel_id,
curry = self.connection.execute( **kwargs
'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?', (comcast_id,) )
) curry = self.connection.execute(
thing = curry.fetchone() 'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id = ?', (comcast_id,)
except AssertionError as e: )
print(e) thing = curry.fetchone()
finally: return dict(thing) if thing else None
return dict(thing) if thing else None
def get_comment_ids(self, claim_id: str, parent_id: str = None, page=1, page_size=50): def get_comment_ids(self, claim_id: str, parent_id: str = None, page=1, page_size=50):
""" Just return a list of the comment IDs that are associated with the given claim_id. """ Just return a list of the comment IDs that are associated with the given claim_id.
@ -143,14 +138,17 @@ class DatabaseConnection:
WHERE claim_id LIKE ? AND parent_id LIKE ? LIMIT ? OFFSET ? WHERE claim_id LIKE ? AND parent_id LIKE ? LIMIT ? OFFSET ?
""", (claim_id, parent_id, page_size, page_size * abs(page - 1),) """, (claim_id, parent_id, page_size, page_size * abs(page - 1),)
) )
return [tuple(row) for row in curs.fetchall()] return [tuple(row)[0] for row in curs.fetchall()]
def get_comments_by_id(self, comment_ids: list) -> typing.Union[list, None]: def get_comments_by_id(self, comment_ids: list) -> typing.Union[list, None]:
""" Returns a list containing the comment data associated with each ID within the list""" """ Returns a list containing the comment data associated with each ID within the list"""
if len(comment_ids) == 0:
return None
placeholders = ', '.join('?' for _ in comment_ids) placeholders = ', '.join('?' for _ in comment_ids)
curs = self.connection.execute( curs = self.connection.execute(
f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})', f'SELECT * FROM COMMENTS_ON_CLAIMS WHERE comment_id IN ({placeholders})',
comment_ids tuple(comment_ids)
) )
return [dict(row) for row in curs.fetchall()] return [dict(row) for row in curs.fetchall()]