base.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import json
  2. import logging
  3. import uuid
  4. from typing import Any, Optional
  5. from embedchain.core.db.database import get_session
  6. from embedchain.core.db.models import ChatHistory as ChatHistoryModel
  7. from embedchain.memory.message import ChatMessage
  8. from embedchain.memory.utils import merge_metadata_dict
  9. class ChatHistory:
  10. def __init__(self) -> None:
  11. self.db_session = get_session()
  12. def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
  13. memory_id = str(uuid.uuid4())
  14. metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
  15. if metadata_dict:
  16. metadata = self._serialize_json(metadata_dict)
  17. self.db_session.add(
  18. ChatHistoryModel(
  19. app_id=app_id,
  20. id=memory_id,
  21. session_id=session_id,
  22. question=chat_message.human_message.content,
  23. answer=chat_message.ai_message.content,
  24. metadata=metadata if metadata_dict else "{}",
  25. )
  26. )
  27. try:
  28. self.db_session.commit()
  29. except Exception as e:
  30. logging.error(f"Error adding chat memory to db: {e}")
  31. self.db_session.rollback()
  32. return None
  33. logging.info(f"Added chat memory to db with id: {memory_id}")
  34. return memory_id
  35. def delete(self, app_id: str, session_id: Optional[str] = None):
  36. """
  37. Delete all chat history for a given app_id and session_id.
  38. This is useful for deleting chat history for a given user.
  39. :param app_id: The app_id to delete chat history for
  40. :param session_id: The session_id to delete chat history for
  41. :return: None
  42. """
  43. params = {"app_id": app_id}
  44. if session_id:
  45. params["session_id"] = session_id
  46. self.db_session.query(ChatHistoryModel).filter_by(**params).delete()
  47. try:
  48. self.db_session.commit()
  49. except Exception as e:
  50. logging.error(f"Error deleting chat history: {e}")
  51. self.db_session.rollback()
  52. def get(
  53. self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
  54. ) -> list[ChatMessage]:
  55. """
  56. Get the chat history for a given app_id.
  57. param: app_id - The app_id to get chat history
  58. param: session_id (optional) - The session_id to get chat history. Defaults to "default"
  59. param: num_rounds (optional) - The number of rounds to get chat history. Defaults to 10
  60. param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
  61. param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
  62. """
  63. params = {"app_id": app_id}
  64. if not fetch_all:
  65. params["session_id"] = session_id
  66. results = (
  67. self.db_session.query(ChatHistoryModel).filter_by(**params).order_by(ChatHistoryModel.created_at.asc())
  68. )
  69. results = results.limit(num_rounds) if not fetch_all else results
  70. history = []
  71. for result in results:
  72. metadata = self._deserialize_json(metadata=result.meta_data or "{}")
  73. # Return list of dict if display_format is True
  74. if display_format:
  75. history.append(
  76. {
  77. "session_id": result.session_id,
  78. "human": result.question,
  79. "ai": result.answer,
  80. "metadata": result.meta_data,
  81. "timestamp": result.created_at,
  82. }
  83. )
  84. else:
  85. memory = ChatMessage()
  86. memory.add_user_message(result.question, metadata=metadata)
  87. memory.add_ai_message(result.answer, metadata=metadata)
  88. history.append(memory)
  89. return history
  90. def count(self, app_id: str, session_id: Optional[str] = None):
  91. """
  92. Count the number of chat messages for a given app_id and session_id.
  93. :param app_id: The app_id to count chat history for
  94. :param session_id: The session_id to count chat history for
  95. :return: The number of chat messages for a given app_id and session_id
  96. """
  97. # Rewrite the logic below with sqlalchemy
  98. params = {"app_id": app_id}
  99. if session_id:
  100. params["session_id"] = session_id
  101. return self.db_session.query(ChatHistoryModel).filter_by(**params).count()
  102. @staticmethod
  103. def _serialize_json(metadata: dict[str, Any]):
  104. return json.dumps(metadata)
  105. @staticmethod
  106. def _deserialize_json(metadata: str):
  107. return json.loads(metadata)
  108. def close_connection(self):
  109. self.connection.close()