base.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import json
  2. import logging
  3. import sqlite3
  4. import uuid
  5. from typing import Any, Optional
  6. from embedchain.constants import SQLITE_PATH
  7. from embedchain.memory.message import ChatMessage
  8. from embedchain.memory.utils import merge_metadata_dict
# DDL for the chat history table. Each row stores one question/answer round of an
# app_id/session_id pair; `metadata` holds a JSON-encoded dict and `created_at` is
# filled by SQLite at insert time. ChatHistory.__init__ runs this on every start-up,
# which is safe because of "IF NOT EXISTS".
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS ec_chat_history (
app_id TEXT,
id TEXT,
session_id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id, session_id)
)
"""
  21. class ChatHistory:
  22. def __init__(self) -> None:
  23. with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
  24. self.cursor = self.connection.cursor()
  25. self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
  26. self.connection.commit()
  27. def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
  28. memory_id = str(uuid.uuid4())
  29. metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
  30. if metadata_dict:
  31. metadata = self._serialize_json(metadata_dict)
  32. ADD_CHAT_MESSAGE_QUERY = """
  33. INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
  34. VALUES (?, ?, ?, ?, ?, ?)
  35. """
  36. self.cursor.execute(
  37. ADD_CHAT_MESSAGE_QUERY,
  38. (
  39. app_id,
  40. memory_id,
  41. session_id,
  42. chat_message.human_message.content,
  43. chat_message.ai_message.content,
  44. metadata if metadata_dict else "{}",
  45. ),
  46. )
  47. self.connection.commit()
  48. logging.info(f"Added chat memory to db with id: {memory_id}")
  49. return memory_id
  50. def delete(self, app_id: str, session_id: Optional[str] = None):
  51. """
  52. Delete all chat history for a given app_id and session_id.
  53. This is useful for deleting chat history for a given user.
  54. :param app_id: The app_id to delete chat history for
  55. :param session_id: The session_id to delete chat history for
  56. :return: None
  57. """
  58. if session_id:
  59. DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
  60. params = (app_id, session_id)
  61. else:
  62. DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=?"
  63. params = (app_id,)
  64. self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
  65. self.connection.commit()
  66. def get(
  67. self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
  68. ) -> list[ChatMessage]:
  69. """
  70. Get the chat history for a given app_id.
  71. param: app_id - The app_id to get chat history
  72. param: session_id (optional) - The session_id to get chat history. Defaults to "default"
  73. param: num_rounds (optional) - The number of rounds to get chat history. Defaults to 10
  74. param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
  75. param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
  76. """
  77. base_query = """
  78. SELECT * FROM ec_chat_history
  79. WHERE app_id=?
  80. """
  81. if fetch_all:
  82. additional_query = "ORDER BY created_at ASC"
  83. params = (app_id,)
  84. else:
  85. additional_query = """
  86. AND session_id=?
  87. ORDER BY created_at ASC
  88. LIMIT ?
  89. """
  90. params = (app_id, session_id, num_rounds)
  91. QUERY = base_query + additional_query
  92. self.cursor.execute(
  93. QUERY,
  94. params,
  95. )
  96. results = self.cursor.fetchall()
  97. history = []
  98. for result in results:
  99. app_id, _, session_id, question, answer, metadata, timestamp = result
  100. metadata = self._deserialize_json(metadata=metadata)
  101. # Return list of dict if display_format is True
  102. if display_format:
  103. history.append(
  104. {
  105. "session_id": session_id,
  106. "human": question,
  107. "ai": answer,
  108. "metadata": metadata,
  109. "timestamp": timestamp,
  110. }
  111. )
  112. else:
  113. memory = ChatMessage()
  114. memory.add_user_message(question, metadata=metadata)
  115. memory.add_ai_message(answer, metadata=metadata)
  116. history.append(memory)
  117. return history
  118. def count(self, app_id: str, session_id: Optional[str] = None):
  119. """
  120. Count the number of chat messages for a given app_id and session_id.
  121. :param app_id: The app_id to count chat history for
  122. :param session_id: The session_id to count chat history for
  123. :return: The number of chat messages for a given app_id and session_id
  124. """
  125. if session_id:
  126. QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
  127. params = (app_id, session_id)
  128. else:
  129. QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?"
  130. params = (app_id,)
  131. self.cursor.execute(QUERY, params)
  132. count = self.cursor.fetchone()[0]
  133. return count
  134. @staticmethod
  135. def _serialize_json(metadata: dict[str, Any]):
  136. return json.dumps(metadata)
  137. @staticmethod
  138. def _deserialize_json(metadata: str):
  139. return json.loads(metadata)
  140. def close_connection(self):
  141. self.connection.close()