Pārlūkot izejas kodu

[Feature] Setup base for creating pipelines in embedchain (#834)

Deshraj Yadav 1 gadu atpakaļ
vecāks
revīzija
d18e533adf

+ 1 - 0
embedchain/__init__.py

@@ -3,4 +3,5 @@ import importlib.metadata
 __version__ = importlib.metadata.version(__package__ or __name__)
 
 from embedchain.apps.app import App  # noqa: F401
+from embedchain.pipeline import Pipeline  # noqa: F401
 from embedchain.vectordb.chroma import ChromaDB  # noqa: F401

+ 2 - 0
embedchain/config/__init__.py

@@ -2,10 +2,12 @@
 
 from .add_config import AddConfig, ChunkerConfig
 from .apps.app_config import AppConfig
+from .pipeline_config import PipelineConfig
 from .base_config import BaseConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig
+from .pipeline_config import PipelineConfig
 from .vectordb.chroma import ChromaDbConfig
 from .vectordb.elasticsearch import ElasticsearchDBConfig
 from .vectordb.opensearch import OpenSearchDBConfig

+ 38 - 0
embedchain/config/pipeline_config.py

@@ -0,0 +1,38 @@
+from typing import Optional
+
+from embedchain.helper.json_serializable import register_deserializable
+
+from .apps.base_app_config import BaseAppConfig
+
+
+@register_deserializable
+class PipelineConfig(BaseAppConfig):
+    """
+    Config to initialize an embedchain custom `App` instance, with extra config options.
+    """
+
+    def __init__(
+        self,
+        log_level: str = "WARNING",
+        id: Optional[str] = None,
+        name: Optional[str] = None,
+        collect_metrics: Optional[bool] = False,
+    ):
+        """
+        Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
+        Most of the configuration is done in the `App` class itself.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param id: ID of the app. Document metadata will have this id., defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: Optional[bool], optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
+        """
+        self._setup_logging(log_level)
+        self.id = id
+        self.name = name
+        self.collect_metrics = collect_metrics

+ 98 - 0
embedchain/pipeline.py

@@ -0,0 +1,98 @@
+import threading
+import uuid
+
+import yaml
+
+from embedchain.config import PipelineConfig
+from embedchain.embedchain import EmbedChain
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.embedder.openai import OpenAIEmbedder
+from embedchain.factory import EmbedderFactory, VectorDBFactory
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.vectordb.base import BaseVectorDB
+from embedchain.vectordb.chroma import ChromaDB
+
+
+@register_deserializable
+class Pipeline(EmbedChain):
+    """
+    EmbedChain pipeline lets you create a LLM powered app for your unstructured
+    data by defining a pipeline with your chosen data source, embedding model,
+    and vector database.
+    """
+
+    def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
+        """
+        Initialize a new `App` instance.
+
+        :param config: Configuration for the pipeline, defaults to None
+        :type config: PipelineConfig, optional
+        :param db: The database to use for storing and retrieving embeddings, defaults to None
+        :type db: BaseVectorDB, optional
+        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
+        :type embedding_model: BaseEmbedder, optional
+        """
+        super().__init__()
+        self.config = config or PipelineConfig()
+        self.name = self.config.name
+        self.id = self.config.id or str(uuid.uuid4())
+
+        self.embedding_model = embedding_model or OpenAIEmbedder()
+        self.db = db or ChromaDB()
+        self._initialize_db()
+
+        self.user_asks = []  # legacy defaults
+
+        self.s_id = self.config.id or str(uuid.uuid4())
+        self.u_id = self._load_or_generate_user_id()
+
+        thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
+        thread_telemetry.start()
+
+    def _initialize_db(self):
+        """
+        Initialize the database.
+        """
+        self.db._set_embedder(self.embedding_model)
+        self.db._initialize()
+        self.db.set_collection_name(self.name)
+
+    def search(self, query, num_documents=3):
+        """
+        Search for similar documents related to the query in the vector database.
+        """
+        where = {"app_id": self.id}
+        return self.db.query(
+            query,
+            n_results=num_documents,
+            where=where,
+            skip_embedding=False,
+        )
+
+    @classmethod
+    def from_config(cls, yaml_path: str):
+        """
+        Instantiate a Pipeline object from a YAML configuration file.
+
+        :param yaml_path: Path to the YAML configuration file.
+        :type yaml_path: str
+        :return: An instance of the Pipeline class.
+        :rtype: Pipeline
+        """
+        with open(yaml_path, "r") as file:
+            config_data = yaml.safe_load(file)
+
+        pipeline_config_data = config_data.get("pipeline", {})
+        db_config_data = config_data.get("vectordb", {})
+        embedding_model_config_data = config_data.get("embedding_model", {})
+
+        pipeline_config = PipelineConfig(**pipeline_config_data)
+
+        db_provider = db_config_data.get("provider", "chroma")
+        db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
+
+        embedding_model_provider = embedding_model_config_data.get("provider", "openai")
+        embedding_model = EmbedderFactory.create(
+            embedding_model_provider, embedding_model_config_data.get("config", {})
+        )
+        return cls(config=pipeline_config, db=db, embedding_model=embedding_model)