Przeglądaj źródła

[Feature] Add support for AIAssistant (#938)

Deshraj Yadav 1 rok temu
rodzic
commit
1364975396
3 zmienionych plików z 66 dodań i 2 usunięć
  1. 1 1
      embedchain/embedchain.py
  2. 64 0
      embedchain/store/assistants.py
  3. 1 1
      pyproject.toml

+ 1 - 1
embedchain/embedchain.py

@@ -482,7 +482,7 @@ class EmbedChain(JSONSerializable):
             where = {}
             if query_config is not None and query_config.where is not None:
                 where = query_config.where
-            
+
             if self.config.id is not None:
                 where.update({"app_id": self.config.id})
 

+ 64 - 0
embedchain/store/assistants.py

@@ -3,12 +3,14 @@ import os
 import re
 import tempfile
 import time
+import uuid
 from pathlib import Path
 from typing import cast
 
 from openai import OpenAI
 from openai.types.beta.threads import MessageContentText, ThreadMessage
 
+from embedchain import Pipeline
 from embedchain.config import AddConfig
 from embedchain.data_formatter import DataFormatter
 from embedchain.models.data_type import DataType
@@ -138,3 +140,65 @@ class OpenAIAssistant:
         with open(file_path, "wb") as file:
             file.write(data)
         return file_path
+
+
+class AIAssistant:
+    """Thin wrapper around an embedchain ``Pipeline`` that scopes data and chat
+    retrieval to an ``(assistant_id, thread_id)`` pair.
+
+    Data passed via ``data_sources`` at construction time is tagged with the
+    sentinel thread_id ``"global_knowledge"`` so it is retrievable from every
+    thread of this assistant; data added later via :meth:`add` is tagged with
+    this instance's own ``thread_id`` and is only visible to that thread.
+    """
+
+    def __init__(
+        self,
+        name=None,
+        instructions=None,
+        yaml_path=None,
+        assistant_id=None,
+        thread_id=None,
+        data_sources=None,
+        log_level=logging.WARN,
+        collect_metrics=True,
+    ):
+        """Create the assistant and its backing pipeline.
+
+        :param name: Display name; defaults to ``"AI Assistant"``.
+        :param instructions: If given, installed as the pipeline's system prompt.
+        :param yaml_path: Optional pipeline config file; when omitted a default
+            ``Pipeline()`` is created.
+        :param assistant_id: Reuse an existing assistant id; a fresh UUID4
+            string is generated otherwise.
+        :param thread_id: Reuse an existing thread id; a fresh UUID4 string is
+            generated otherwise.
+        :param data_sources: Iterable of dicts with a required ``"source"`` key
+            and optional ``"data_type"`` key, ingested at startup as
+            assistant-wide ("global_knowledge") data.
+        :param log_level: Level handed to ``logging.basicConfig`` (note: this
+            configures the root logger process-wide).
+        :param collect_metrics: Enables/disables anonymous telemetry capture.
+        """
+        logging.basicConfig(level=log_level)
+
+        self.name = name or "AI Assistant"
+        self.data_sources = data_sources or []
+        self.log_level = log_level
+        self.instructions = instructions
+        self.assistant_id = assistant_id or str(uuid.uuid4())
+        self.thread_id = thread_id or str(uuid.uuid4())
+        # The pipeline id is keyed to the thread (not the assistant), so the
+        # underlying store is partitioned per conversation thread.
+        self.pipeline = Pipeline.from_config(yaml_path=yaml_path) if yaml_path else Pipeline()
+        self.pipeline.local_id = self.pipeline.config.id = self.thread_id
+
+        if self.instructions:
+            self.pipeline.system_prompt = self.instructions
+
+        print(
+            f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}"  # noqa: E501
+        )
+
+        # telemetry related properties
+        self._telemetry_props = {"class": self.__class__.__name__}
+        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
+        self.telemetry.capture(event_name="init", properties=self._telemetry_props)
+
+        if self.data_sources:
+            for data_source in self.data_sources:
+                # Startup data is shared across all threads of this assistant
+                # via the "global_knowledge" sentinel thread_id.
+                metadata = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
+                # NOTE(review): data_type is passed positionally here but as a
+                # keyword in add() below — confirm Pipeline.add's second
+                # positional parameter really is data_type.
+                self.pipeline.add(data_source["source"], data_source.get("data_type"), metadata=metadata)
+
+    def add(self, source, data_type=None):
+        """Ingest *source* into this assistant's current thread.
+
+        :param source: The data source handed to ``Pipeline.add``.
+        :param data_type: Optional explicit data type; when ``None`` the type
+            reported in telemetry is inferred via ``detect_datatype(source)``.
+        """
+        # Tag with this instance's thread_id so chat() retrieval can scope to it.
+        metadata = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
+        self.pipeline.add(source, data_type=data_type, metadata=metadata)
+        event_props = {
+            **self._telemetry_props,
+            "data_type": data_type or detect_datatype(source),
+        }
+        self.telemetry.capture(event_name="add", properties=event_props)
+
+    def chat(self, query):
+        """Answer *query* using only this assistant's data.
+
+        Retrieval is restricted to documents tagged with this assistant_id AND
+        a thread_id of either this thread or the shared "global_knowledge"
+        bucket ($and/$eq/$in filter syntax, as used by e.g. Chroma).
+        """
+        where = {
+            "$and": [
+                {"assistant_id": {"$eq": self.assistant_id}},
+                {"thread_id": {"$in": [self.thread_id, "global_knowledge"]}},
+            ]
+        }
+        return self.pipeline.chat(query, where=where)
+
+    def delete(self):
+        """Delete this assistant's data by resetting the backing pipeline.
+
+        NOTE(review): Pipeline.reset() is defined elsewhere — confirm it only
+        drops this pipeline's (thread-scoped) store and does not also discard
+        data belonging to other threads or the global knowledge bucket.
+        """
+        self.pipeline.reset()

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.6"
+version = "0.1.7"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",