|
@@ -53,7 +53,7 @@ class ChatHistory:
|
|
|
logging.info(f"Added chat memory to db with id: {memory_id}")
|
|
|
return memory_id
|
|
|
|
|
|
- def delete(self, app_id: str, session_id: str):
|
|
|
+ def delete(self, app_id: str, session_id: Optional[str] = None):
|
|
|
"""
|
|
|
Delete all chat history for a given app_id and session_id.
|
|
|
This is useful for deleting chat history for a given user.
|
|
@@ -63,8 +63,14 @@ class ChatHistory:
|
|
|
|
|
|
:return: None
|
|
|
"""
|
|
|
- DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
|
|
|
- self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id))
|
|
|
+ if session_id:
|
|
|
+ DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
|
|
|
+ params = (app_id, session_id)
|
|
|
+ else:
|
|
|
+ DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=?"
|
|
|
+ params = (app_id,)
|
|
|
+
|
|
|
+ self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
|
|
|
self.connection.commit()
|
|
|
|
|
|
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
|
|
@@ -99,7 +105,7 @@ class ChatHistory:
|
|
|
history.append(memory)
|
|
|
return history
|
|
|
|
|
|
- def count(self, app_id: str, session_id: str):
|
|
|
+ def count(self, app_id: str, session_id: Optional[str] = None):
|
|
|
"""
|
|
|
Count the number of chat messages for a given app_id and session_id.
|
|
|
|
|
@@ -108,8 +114,14 @@ class ChatHistory:
|
|
|
|
|
|
:return: The number of chat messages for a given app_id and session_id
|
|
|
"""
|
|
|
- QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
|
|
|
- self.cursor.execute(QUERY, (app_id, session_id))
|
|
|
+ if session_id:
|
|
|
+ QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
|
|
|
+ params = (app_id, session_id)
|
|
|
+ else:
|
|
|
+ QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?"
|
|
|
+ params = (app_id,)
|
|
|
+
|
|
|
+ self.cursor.execute(QUERY, params)
|
|
|
count = self.cursor.fetchone()[0]
|
|
|
return count
|
|
|
|