assistants.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import logging
  2. import os
  3. import re
  4. import tempfile
  5. import time
  6. from pathlib import Path
  7. from typing import cast
  8. from openai import OpenAI
  9. from openai.types.beta.threads import MessageContentText, ThreadMessage
  10. from embedchain.config import AddConfig
  11. from embedchain.data_formatter import DataFormatter
  12. from embedchain.models.data_type import DataType
  13. from embedchain.utils import detect_datatype
# Configure the root logger when this module is imported, with WARNING as the
# threshold. basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.WARN)
  15. class OpenAIAssistant:
  16. def __init__(
  17. self,
  18. name=None,
  19. instructions=None,
  20. tools=None,
  21. thread_id=None,
  22. model="gpt-4-1106-preview",
  23. data_sources=None,
  24. assistant_id=None,
  25. log_level=logging.WARN,
  26. ):
  27. self.name = name or "OpenAI Assistant"
  28. self.instructions = instructions
  29. self.tools = tools or [{"type": "retrieval"}]
  30. self.model = model
  31. self.data_sources = data_sources or []
  32. self.log_level = log_level
  33. self._client = OpenAI()
  34. self._initialize_assistant(assistant_id)
  35. self.thread_id = thread_id or self._create_thread()
  36. def add(self, source, data_type=None):
  37. file_path = self._prepare_source_path(source, data_type)
  38. self._add_file_to_assistant(file_path)
  39. logging.info("Data successfully added to the assistant.")
  40. def chat(self, message):
  41. self._send_message(message)
  42. return self._get_latest_response()
  43. def delete_thread(self):
  44. self._client.beta.threads.delete(self.thread_id)
  45. self.thread_id = self._create_thread()
  46. # Internal methods
  47. def _initialize_assistant(self, assistant_id):
  48. file_ids = self._generate_file_ids(self.data_sources)
  49. self.assistant = (
  50. self._client.beta.assistants.retrieve(assistant_id)
  51. if assistant_id
  52. else self._client.beta.assistants.create(
  53. name=self.name, model=self.model, file_ids=file_ids, instructions=self.instructions, tools=self.tools
  54. )
  55. )
  56. def _create_thread(self):
  57. thread = self._client.beta.threads.create()
  58. return thread.id
  59. def _prepare_source_path(self, source, data_type=None):
  60. if Path(source).is_file():
  61. return source
  62. data_type = data_type or detect_datatype(source)
  63. formatter = DataFormatter(data_type=DataType(data_type), config=AddConfig(), kwargs={})
  64. data = formatter.loader.load_data(source)["data"]
  65. return self._save_temp_data(data=data[0]["content"].encode(), source=source)
  66. def _add_file_to_assistant(self, file_path):
  67. file_obj = self._client.files.create(file=open(file_path, "rb"), purpose="assistants")
  68. self._client.beta.assistants.files.create(assistant_id=self.assistant.id, file_id=file_obj.id)
  69. def _generate_file_ids(self, data_sources):
  70. return [
  71. self._add_file_to_assistant(self._prepare_source_path(ds["source"], ds.get("data_type")))
  72. for ds in data_sources
  73. ]
  74. def _send_message(self, message):
  75. self._client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=message)
  76. self._wait_for_completion()
  77. def _wait_for_completion(self):
  78. run = self._client.beta.threads.runs.create(
  79. thread_id=self.thread_id,
  80. assistant_id=self.assistant.id,
  81. instructions=self.instructions,
  82. )
  83. run_id = run.id
  84. run_status = run.status
  85. while run_status in ["queued", "in_progress", "requires_action"]:
  86. time.sleep(0.1) # Sleep before making the next API call to avoid hitting rate limits
  87. run = self._client.beta.threads.runs.retrieve(thread_id=self.thread_id, run_id=run_id)
  88. run_status = run.status
  89. if run_status == "failed":
  90. raise ValueError(f"Thread run failed with the following error: {run.last_error}")
  91. def _get_latest_response(self):
  92. history = self._get_history()
  93. return self._format_message(history[0]) if history else None
  94. def _get_history(self):
  95. messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
  96. return list(messages)
  97. def _format_message(self, thread_message):
  98. thread_message = cast(ThreadMessage, thread_message)
  99. content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
  100. return " ".join(content)
  101. def _save_temp_data(self, data, source):
  102. special_chars_pattern = r'[\\/:*?"<>|&=% ]+'
  103. sanitized_source = re.sub(special_chars_pattern, "_", source)[:256]
  104. temp_dir = tempfile.mkdtemp()
  105. file_path = os.path.join(temp_dir, sanitized_source)
  106. with open(file_path, "wb") as file:
  107. file.write(data)
  108. return file_path