
[Feature] Add support for AIAssistant (#938)

Deshraj Yadav · 1 year ago
parent commit 1364975396
3 changed files with 66 additions and 2 deletions
  1. embedchain/embedchain.py (+1 -1)
  2. embedchain/store/assistants.py (+64 -0)
  3. pyproject.toml (+1 -1)

+ 1 - 1
embedchain/embedchain.py

@@ -482,7 +482,7 @@ class EmbedChain(JSONSerializable):
             where = {}
             if query_config is not None and query_config.where is not None:
                 where = query_config.where
-            
+
             if self.config.id is not None:
                 where.update({"app_id": self.config.id})
 

+ 64 - 0
embedchain/store/assistants.py

@@ -3,12 +3,14 @@ import os
 import re
 import tempfile
 import time
+import uuid
 from pathlib import Path
 from typing import cast
 
 from openai import OpenAI
 from openai.types.beta.threads import MessageContentText, ThreadMessage
 
+from embedchain import Pipeline
 from embedchain.config import AddConfig
 from embedchain.data_formatter import DataFormatter
 from embedchain.models.data_type import DataType
@@ -138,3 +140,65 @@ class OpenAIAssistant:
         with open(file_path, "wb") as file:
             file.write(data)
         return file_path
+
+
+class AIAssistant:
+    def __init__(
+        self,
+        name=None,
+        instructions=None,
+        yaml_path=None,
+        assistant_id=None,
+        thread_id=None,
+        data_sources=None,
+        log_level=logging.WARN,
+        collect_metrics=True,
+    ):
+        logging.basicConfig(level=log_level)
+
+        self.name = name or "AI Assistant"
+        self.data_sources = data_sources or []
+        self.log_level = log_level
+        self.instructions = instructions
+        self.assistant_id = assistant_id or str(uuid.uuid4())
+        self.thread_id = thread_id or str(uuid.uuid4())
+        self.pipeline = Pipeline.from_config(yaml_path=yaml_path) if yaml_path else Pipeline()
+        self.pipeline.local_id = self.pipeline.config.id = self.thread_id
+
+        if self.instructions:
+            self.pipeline.system_prompt = self.instructions
+
+        print(
+            f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}"  # noqa: E501
+        )
+
+        # telemetry related properties
+        self._telemetry_props = {"class": self.__class__.__name__}
+        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
+        self.telemetry.capture(event_name="init", properties=self._telemetry_props)
+
+        if self.data_sources:
+            for data_source in self.data_sources:
+                metadata = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
+                self.pipeline.add(data_source["source"], data_source.get("data_type"), metadata=metadata)
+
+    def add(self, source, data_type=None):
+        metadata = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
+        self.pipeline.add(source, data_type=data_type, metadata=metadata)
+        event_props = {
+            **self._telemetry_props,
+            "data_type": data_type or detect_datatype(source),
+        }
+        self.telemetry.capture(event_name="add", properties=event_props)
+
+    def chat(self, query):
+        where = {
+            "$and": [
+                {"assistant_id": {"$eq": self.assistant_id}},
+                {"thread_id": {"$in": [self.thread_id, "global_knowledge"]}},
+            ]
+        }
+        return self.pipeline.chat(query, where=where)
+
+    def delete(self):
+        self.pipeline.reset()
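
The new AIAssistant class wraps an embedchain Pipeline and tags every added source with an assistant_id/thread_id pair, so chat() can restrict retrieval to the current thread plus the shared "global_knowledge" bucket. A minimal usage sketch, assuming an OpenAI API key is configured in the environment and the class is imported from the embedchain/store/assistants.py module added above (the assistant name, instructions, and URLs below are illustrative, not part of this commit):

    from embedchain.store.assistants import AIAssistant

    # Create an assistant; data_sources passed here are stored under the
    # shared "global_knowledge" thread and are visible to every thread.
    assistant = AIAssistant(
        name="Docs Helper",
        instructions="Answer using only the added sources.",
        data_sources=[{"source": "https://docs.embedchain.ai", "data_type": "web_page"}],
    )

    # Thread-scoped knowledge, tagged with this assistant_id/thread_id.
    assistant.add("https://example.com/notes.html")

    # Retrieval is filtered to this thread plus "global_knowledge".
    print(assistant.chat("What do the docs say about adding data?"))

    # Reset the underlying pipeline when finished.
    assistant.delete()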

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.6"
+version = "0.1.7"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",