base.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import json
  2. import logging
  3. import sqlite3
  4. import uuid
  5. from typing import Any, Optional
  6. from embedchain.constants import SQLITE_PATH
  7. from embedchain.memory.message import ChatMessage
  8. from embedchain.memory.utils import merge_metadata_dict
# DDL for the chat-history table, run by ChatHistory.__init__ on every construction
# (safe to repeat because of IF NOT EXISTS). Each row stores one question/answer round;
# `metadata` holds a JSON-encoded string, and `created_at` defaults to SQLite's
# CURRENT_TIMESTAMP ("YYYY-MM-DD HH:MM:SS" in UTC, one-second resolution).
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS ec_chat_history (
app_id TEXT,
id TEXT,
session_id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id, session_id)
)
"""
  21. class ChatHistory:
  22. def __init__(self) -> None:
  23. with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
  24. self.cursor = self.connection.cursor()
  25. self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
  26. self.connection.commit()
  27. def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
  28. memory_id = str(uuid.uuid4())
  29. metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
  30. if metadata_dict:
  31. metadata = self._serialize_json(metadata_dict)
  32. ADD_CHAT_MESSAGE_QUERY = """
  33. INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
  34. VALUES (?, ?, ?, ?, ?, ?)
  35. """
  36. self.cursor.execute(
  37. ADD_CHAT_MESSAGE_QUERY,
  38. (
  39. app_id,
  40. memory_id,
  41. session_id,
  42. chat_message.human_message.content,
  43. chat_message.ai_message.content,
  44. metadata if metadata_dict else "{}",
  45. ),
  46. )
  47. self.connection.commit()
  48. logging.info(f"Added chat memory to db with id: {memory_id}")
  49. return memory_id
  50. def delete(self, app_id: str, session_id: Optional[str] = None):
  51. """
  52. Delete all chat history for a given app_id and session_id.
  53. This is useful for deleting chat history for a given user.
  54. :param app_id: The app_id to delete chat history for
  55. :param session_id: The session_id to delete chat history for
  56. :return: None
  57. """
  58. if session_id:
  59. DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
  60. params = (app_id, session_id)
  61. else:
  62. DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=?"
  63. params = (app_id,)
  64. self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
  65. self.connection.commit()
  66. def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
  67. """
  68. Get the most recent num_rounds rounds of conversations
  69. between human and AI, for a given app_id.
  70. """
  71. QUERY = """
  72. SELECT * FROM ec_chat_history
  73. WHERE app_id=? AND session_id=?
  74. ORDER BY created_at DESC
  75. LIMIT ?
  76. """
  77. self.cursor.execute(
  78. QUERY,
  79. (app_id, session_id, num_rounds),
  80. )
  81. results = self.cursor.fetchall()
  82. history = []
  83. for result in results:
  84. app_id, _, session_id, question, answer, metadata, timestamp = result
  85. metadata = self._deserialize_json(metadata=metadata)
  86. # Return list of dict if display_format is True
  87. if display_format:
  88. history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
  89. else:
  90. memory = ChatMessage()
  91. memory.add_user_message(question, metadata=metadata)
  92. memory.add_ai_message(answer, metadata=metadata)
  93. history.append(memory)
  94. return history
  95. def count(self, app_id: str, session_id: Optional[str] = None):
  96. """
  97. Count the number of chat messages for a given app_id and session_id.
  98. :param app_id: The app_id to count chat history for
  99. :param session_id: The session_id to count chat history for
  100. :return: The number of chat messages for a given app_id and session_id
  101. """
  102. if session_id:
  103. QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
  104. params = (app_id, session_id)
  105. else:
  106. QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?"
  107. params = (app_id,)
  108. self.cursor.execute(QUERY, params)
  109. count = self.cursor.fetchone()[0]
  110. return count
  111. @staticmethod
  112. def _serialize_json(metadata: dict[str, Any]):
  113. return json.dumps(metadata)
  114. @staticmethod
  115. def _deserialize_json(metadata: str):
  116. return json.loads(metadata)
  117. def close_connection(self):
  118. self.connection.close()