# base.py (4.8 KB)
  1. import json
  2. import logging
  3. import uuid
  4. from typing import Any, Optional
  5. from embedchain.core.db.database import get_session
  6. from embedchain.core.db.models import ChatHistory as ChatHistoryModel
  7. from embedchain.memory.message import ChatMessage
  8. from embedchain.memory.utils import merge_metadata_dict
  9. logger = logging.getLogger(__name__)
  10. class ChatHistory:
  11. def __init__(self) -> None:
  12. self.db_session = get_session()
  13. def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
  14. memory_id = str(uuid.uuid4())
  15. metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
  16. if metadata_dict:
  17. metadata = self._serialize_json(metadata_dict)
  18. self.db_session.add(
  19. ChatHistoryModel(
  20. app_id=app_id,
  21. id=memory_id,
  22. session_id=session_id,
  23. question=chat_message.human_message.content,
  24. answer=chat_message.ai_message.content,
  25. metadata=metadata if metadata_dict else "{}",
  26. )
  27. )
  28. try:
  29. self.db_session.commit()
  30. except Exception as e:
  31. logger.error(f"Error adding chat memory to db: {e}")
  32. self.db_session.rollback()
  33. return None
  34. logger.info(f"Added chat memory to db with id: {memory_id}")
  35. return memory_id
  36. def delete(self, app_id: str, session_id: Optional[str] = None):
  37. """
  38. Delete all chat history for a given app_id and session_id.
  39. This is useful for deleting chat history for a given user.
  40. :param app_id: The app_id to delete chat history for
  41. :param session_id: The session_id to delete chat history for
  42. :return: None
  43. """
  44. params = {"app_id": app_id}
  45. if session_id:
  46. params["session_id"] = session_id
  47. self.db_session.query(ChatHistoryModel).filter_by(**params).delete()
  48. try:
  49. self.db_session.commit()
  50. except Exception as e:
  51. logger.error(f"Error deleting chat history: {e}")
  52. self.db_session.rollback()
  53. def get(
  54. self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
  55. ) -> list[ChatMessage]:
  56. """
  57. Get the chat history for a given app_id.
  58. param: app_id - The app_id to get chat history
  59. param: session_id (optional) - The session_id to get chat history. Defaults to "default"
  60. param: num_rounds (optional) - The number of rounds to get chat history. Defaults to 10
  61. param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
  62. param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
  63. """
  64. params = {"app_id": app_id}
  65. if not fetch_all:
  66. params["session_id"] = session_id
  67. results = (
  68. self.db_session.query(ChatHistoryModel).filter_by(**params).order_by(ChatHistoryModel.created_at.asc())
  69. )
  70. results = results.limit(num_rounds) if not fetch_all else results
  71. history = []
  72. for result in results:
  73. metadata = self._deserialize_json(metadata=result.meta_data or "{}")
  74. # Return list of dict if display_format is True
  75. if display_format:
  76. history.append(
  77. {
  78. "session_id": result.session_id,
  79. "human": result.question,
  80. "ai": result.answer,
  81. "metadata": result.meta_data,
  82. "timestamp": result.created_at,
  83. }
  84. )
  85. else:
  86. memory = ChatMessage()
  87. memory.add_user_message(result.question, metadata=metadata)
  88. memory.add_ai_message(result.answer, metadata=metadata)
  89. history.append(memory)
  90. return history
  91. def count(self, app_id: str, session_id: Optional[str] = None):
  92. """
  93. Count the number of chat messages for a given app_id and session_id.
  94. :param app_id: The app_id to count chat history for
  95. :param session_id: The session_id to count chat history for
  96. :return: The number of chat messages for a given app_id and session_id
  97. """
  98. # Rewrite the logic below with sqlalchemy
  99. params = {"app_id": app_id}
  100. if session_id:
  101. params["session_id"] = session_id
  102. return self.db_session.query(ChatHistoryModel).filter_by(**params).count()
  103. @staticmethod
  104. def _serialize_json(metadata: dict[str, Any]):
  105. return json.dumps(metadata)
  106. @staticmethod
  107. def _deserialize_json(metadata: str):
  108. return json.loads(metadata)
  109. def close_connection(self):
  110. self.connection.close()