|
@@ -3,12 +3,14 @@ import os
|
|
|
import re
|
|
|
import tempfile
|
|
|
import time
|
|
|
+import uuid
|
|
|
from pathlib import Path
|
|
|
from typing import cast
|
|
|
|
|
|
from openai import OpenAI
|
|
|
from openai.types.beta.threads import MessageContentText, ThreadMessage
|
|
|
|
|
|
+from embedchain import Pipeline
|
|
|
from embedchain.config import AddConfig
|
|
|
from embedchain.data_formatter import DataFormatter
|
|
|
from embedchain.models.data_type import DataType
|
|
@@ -138,3 +140,65 @@ class OpenAIAssistant:
|
|
|
with open(file_path, "wb") as file:
|
|
|
file.write(data)
|
|
|
return file_path
|
|
|
+
|
|
|
+
|
|
|
class AIAssistant:
    """Assistant facade over an embedchain ``Pipeline``.

    Every document added through this class is tagged with the assistant id
    and a thread id; ``chat`` passes a matching ``where`` filter to the
    pipeline so a query is scoped to this assistant's current thread plus
    the sources supplied at construction time (tagged ``"global_knowledge"``).
    """

    def __init__(
        self,
        name=None,
        instructions=None,
        yaml_path=None,
        assistant_id=None,
        thread_id=None,
        data_sources=None,
        log_level=logging.WARN,
        collect_metrics=True,
    ):
        """Build the assistant, its pipeline and its telemetry client.

        Args:
            name: Display name; falls back to ``"AI Assistant"`` when falsy.
            instructions: Optional text assigned to the pipeline's
                ``system_prompt``.
            yaml_path: Optional config path; when truthy the pipeline is
                built with ``Pipeline.from_config``, otherwise ``Pipeline()``.
            assistant_id: Identifier for this assistant; a random UUID4
                string is generated when falsy.
            thread_id: Identifier for the conversation thread; a random
                UUID4 string is generated when falsy.
            data_sources: Optional iterable of dicts with a ``"source"`` key
                and an optional ``"data_type"`` key, added to the pipeline
                immediately.
            log_level: Level handed to ``logging.basicConfig``.
            collect_metrics: Forwarded as ``enabled`` to ``AnonymousTelemetry``.
        """
        logging.basicConfig(level=log_level)

        self.name = name if name else "AI Assistant"
        self.data_sources = data_sources if data_sources else []
        self.log_level = log_level
        self.instructions = instructions
        self.assistant_id = assistant_id if assistant_id else str(uuid.uuid4())
        self.thread_id = thread_id if thread_id else str(uuid.uuid4())

        if yaml_path:
            self.pipeline = Pipeline.from_config(yaml_path=yaml_path)
        else:
            self.pipeline = Pipeline()
        # Both pipeline identifiers are pinned to the thread id.
        self.pipeline.local_id = self.thread_id
        self.pipeline.config.id = self.thread_id

        if self.instructions:
            self.pipeline.system_prompt = self.instructions

        print(
            f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}"  # noqa: E501
        )

        # telemetry related properties
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

        # Sources given up front are tagged "global_knowledge" rather than
        # with this instance's thread id. A fresh metadata dict is built per
        # source, as in the original loop.
        for entry in self.data_sources:
            shared_tags = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
            self.pipeline.add(entry["source"], entry.get("data_type"), metadata=shared_tags)

    def add(self, source, data_type=None):
        """Add ``source`` to the pipeline, tagged with this assistant and thread.

        Args:
            source: The data source handed to ``pipeline.add``.
            data_type: Optional explicit data type; when falsy, the type
                reported to telemetry is obtained from ``detect_datatype``.
        """
        thread_tags = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
        self.pipeline.add(source, data_type=data_type, metadata=thread_tags)

        event_props = dict(self._telemetry_props)
        event_props["data_type"] = data_type if data_type else detect_datatype(source)
        self.telemetry.capture(event_name="add", properties=event_props)

    def chat(self, query):
        """Send ``query`` to the pipeline and return its result.

        The ``where`` filter requires this assistant's id and a thread id
        that is either the current thread or ``"global_knowledge"``.
        """
        assistant_clause = {"assistant_id": {"$eq": self.assistant_id}}
        thread_clause = {"thread_id": {"$in": [self.thread_id, "global_knowledge"]}}
        scope = {"$and": [assistant_clause, thread_clause]}
        return self.pipeline.chat(query, where=scope)

    def delete(self):
        """Reset the underlying pipeline.

        NOTE(review): this calls ``pipeline.reset()`` with no filter, so it
        presumably clears more than this assistant's documents - confirm
        against the Pipeline implementation.
        """
        self.pipeline.reset()
|