assistants.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import logging
  2. import os
  3. import re
  4. import tempfile
  5. import time
  6. import uuid
  7. from pathlib import Path
  8. from typing import cast
  9. from openai import OpenAI
  10. from openai.types.beta.threads import MessageContentText, ThreadMessage
  11. from embedchain import Client, Pipeline
  12. from embedchain.config import AddConfig
  13. from embedchain.data_formatter import DataFormatter
  14. from embedchain.models.data_type import DataType
  15. from embedchain.telemetry.posthog import AnonymousTelemetry
  16. from embedchain.utils.misc import detect_datatype
  17. # Set up the user directory if it doesn't exist already
  18. Client.setup()
  19. class OpenAIAssistant:
  20. def __init__(
  21. self,
  22. name=None,
  23. instructions=None,
  24. tools=None,
  25. thread_id=None,
  26. model="gpt-4-1106-preview",
  27. data_sources=None,
  28. assistant_id=None,
  29. log_level=logging.INFO,
  30. collect_metrics=True,
  31. ):
  32. self.name = name or "OpenAI Assistant"
  33. self.instructions = instructions
  34. self.tools = tools or [{"type": "retrieval"}]
  35. self.model = model
  36. self.data_sources = data_sources or []
  37. self.log_level = log_level
  38. self._client = OpenAI()
  39. self._initialize_assistant(assistant_id)
  40. self.thread_id = thread_id or self._create_thread()
  41. self._telemetry_props = {"class": self.__class__.__name__}
  42. self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
  43. self.telemetry.capture(event_name="init", properties=self._telemetry_props)
  44. def add(self, source, data_type=None):
  45. file_path = self._prepare_source_path(source, data_type)
  46. self._add_file_to_assistant(file_path)
  47. event_props = {
  48. **self._telemetry_props,
  49. "data_type": data_type or detect_datatype(source),
  50. }
  51. self.telemetry.capture(event_name="add", properties=event_props)
  52. logging.info("Data successfully added to the assistant.")
  53. def chat(self, message):
  54. self._send_message(message)
  55. self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
  56. return self._get_latest_response()
  57. def delete_thread(self):
  58. self._client.beta.threads.delete(self.thread_id)
  59. self.thread_id = self._create_thread()
  60. # Internal methods
  61. def _initialize_assistant(self, assistant_id):
  62. file_ids = self._generate_file_ids(self.data_sources)
  63. self.assistant = (
  64. self._client.beta.assistants.retrieve(assistant_id)
  65. if assistant_id
  66. else self._client.beta.assistants.create(
  67. name=self.name, model=self.model, file_ids=file_ids, instructions=self.instructions, tools=self.tools
  68. )
  69. )
  70. def _create_thread(self):
  71. thread = self._client.beta.threads.create()
  72. return thread.id
  73. def _prepare_source_path(self, source, data_type=None):
  74. if Path(source).is_file():
  75. return source
  76. data_type = data_type or detect_datatype(source)
  77. formatter = DataFormatter(data_type=DataType(data_type), config=AddConfig())
  78. data = formatter.loader.load_data(source)["data"]
  79. return self._save_temp_data(data=data[0]["content"].encode(), source=source)
  80. def _add_file_to_assistant(self, file_path):
  81. file_obj = self._client.files.create(file=open(file_path, "rb"), purpose="assistants")
  82. self._client.beta.assistants.files.create(assistant_id=self.assistant.id, file_id=file_obj.id)
  83. def _generate_file_ids(self, data_sources):
  84. return [
  85. self._add_file_to_assistant(self._prepare_source_path(ds["source"], ds.get("data_type")))
  86. for ds in data_sources
  87. ]
  88. def _send_message(self, message):
  89. self._client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=message)
  90. self._wait_for_completion()
  91. def _wait_for_completion(self):
  92. run = self._client.beta.threads.runs.create(
  93. thread_id=self.thread_id,
  94. assistant_id=self.assistant.id,
  95. instructions=self.instructions,
  96. )
  97. run_id = run.id
  98. run_status = run.status
  99. while run_status in ["queued", "in_progress", "requires_action"]:
  100. time.sleep(0.1) # Sleep before making the next API call to avoid hitting rate limits
  101. run = self._client.beta.threads.runs.retrieve(thread_id=self.thread_id, run_id=run_id)
  102. run_status = run.status
  103. if run_status == "failed":
  104. raise ValueError(f"Thread run failed with the following error: {run.last_error}")
  105. def _get_latest_response(self):
  106. history = self._get_history()
  107. return self._format_message(history[0]) if history else None
  108. def _get_history(self):
  109. messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
  110. return list(messages)
  111. @staticmethod
  112. def _format_message(thread_message):
  113. thread_message = cast(ThreadMessage, thread_message)
  114. content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
  115. return " ".join(content)
  116. @staticmethod
  117. def _save_temp_data(data, source):
  118. special_chars_pattern = r'[\\/:*?"<>|&=% ]+'
  119. sanitized_source = re.sub(special_chars_pattern, "_", source)[:256]
  120. temp_dir = tempfile.mkdtemp()
  121. file_path = os.path.join(temp_dir, sanitized_source)
  122. with open(file_path, "wb") as file:
  123. file.write(data)
  124. return file_path
  125. class AIAssistant:
  126. def __init__(
  127. self,
  128. name=None,
  129. instructions=None,
  130. yaml_path=None,
  131. assistant_id=None,
  132. thread_id=None,
  133. data_sources=None,
  134. log_level=logging.INFO,
  135. collect_metrics=True,
  136. ):
  137. self.name = name or "AI Assistant"
  138. self.data_sources = data_sources or []
  139. self.log_level = log_level
  140. self.instructions = instructions
  141. self.assistant_id = assistant_id or str(uuid.uuid4())
  142. self.thread_id = thread_id or str(uuid.uuid4())
  143. self.pipeline = Pipeline.from_config(config_path=yaml_path) if yaml_path else Pipeline()
  144. self.pipeline.local_id = self.pipeline.config.id = self.thread_id
  145. if self.instructions:
  146. self.pipeline.system_prompt = self.instructions
  147. print(
  148. f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}" # noqa: E501
  149. )
  150. # telemetry related properties
  151. self._telemetry_props = {"class": self.__class__.__name__}
  152. self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
  153. self.telemetry.capture(event_name="init", properties=self._telemetry_props)
  154. if self.data_sources:
  155. for data_source in self.data_sources:
  156. metadata = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
  157. self.pipeline.add(data_source["source"], data_source.get("data_type"), metadata=metadata)
  158. def add(self, source, data_type=None):
  159. metadata = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
  160. self.pipeline.add(source, data_type=data_type, metadata=metadata)
  161. event_props = {
  162. **self._telemetry_props,
  163. "data_type": data_type or detect_datatype(source),
  164. }
  165. self.telemetry.capture(event_name="add", properties=event_props)
  166. def chat(self, query):
  167. where = {
  168. "$and": [
  169. {"assistant_id": {"$eq": self.assistant_id}},
  170. {"thread_id": {"$in": [self.thread_id, "global_knowledge"]}},
  171. ]
  172. }
  173. return self.pipeline.chat(query, where=where)
  174. def delete(self):
  175. self.pipeline.reset()