Browse Source

[Feature] Setup base for creating pipelines in embedchain (#834)

Deshraj Yadav 1 year ago
parent
commit
d18e533adf

+ 1 - 0
embedchain/__init__.py

@@ -3,4 +3,5 @@ import importlib.metadata
 __version__ = importlib.metadata.version(__package__ or __name__)
 __version__ = importlib.metadata.version(__package__ or __name__)
 
 
 from embedchain.apps.app import App  # noqa: F401
 from embedchain.apps.app import App  # noqa: F401
+from embedchain.pipeline import Pipeline  # noqa: F401
 from embedchain.vectordb.chroma import ChromaDB  # noqa: F401
 from embedchain.vectordb.chroma import ChromaDB  # noqa: F401

+ 2 - 0
embedchain/config/__init__.py

@@ -2,10 +2,12 @@
 
 
 from .add_config import AddConfig, ChunkerConfig
 from .add_config import AddConfig, ChunkerConfig
 from .apps.app_config import AppConfig
 from .apps.app_config import AppConfig
+from .pipeline_config import PipelineConfig
 from .base_config import BaseConfig
 from .base_config import BaseConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig
 from .llm.base import BaseLlmConfig
+from .pipeline_config import PipelineConfig
 from .vectordb.chroma import ChromaDbConfig
 from .vectordb.chroma import ChromaDbConfig
 from .vectordb.elasticsearch import ElasticsearchDBConfig
 from .vectordb.elasticsearch import ElasticsearchDBConfig
 from .vectordb.opensearch import OpenSearchDBConfig
 from .vectordb.opensearch import OpenSearchDBConfig

+ 38 - 0
embedchain/config/pipeline_config.py

@@ -0,0 +1,38 @@
+from typing import Optional
+
+from embedchain.helper.json_serializable import register_deserializable
+
+from .apps.base_app_config import BaseAppConfig
+
+
+@register_deserializable
+class PipelineConfig(BaseAppConfig):
+    """
+    Config to initialize an embedchain `Pipeline` instance.
+    """
+
+    def __init__(
+        self,
+        log_level: str = "WARNING",
+        id: Optional[str] = None,
+        name: Optional[str] = None,
+        collect_metrics: Optional[bool] = False,
+    ):
+        """
+        Initializes a configuration class instance for a `Pipeline`.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param id: ID of the pipeline. When left as None, `Pipeline` generates a random UUID for itself,
+        defaults to None
+        :type id: Optional[str], optional
+        :param name: Name of the pipeline. `Pipeline` passes it to the vector database as the collection name,
+        defaults to None
+        :type name: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to False
+        :type collect_metrics: Optional[bool], optional
+        """
+        # Logging setup is delegated to a helper not defined in this class (presumably on BaseAppConfig).
+        self._setup_logging(log_level)
+        self.id = id
+        self.name = name
+        self.collect_metrics = collect_metrics

+ 98 - 0
embedchain/pipeline.py

@@ -0,0 +1,98 @@
+import threading
+import uuid
+
+import yaml
+
+from embedchain.config import PipelineConfig
+from embedchain.embedchain import EmbedChain
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.embedder.openai import OpenAIEmbedder
+from embedchain.factory import EmbedderFactory, VectorDBFactory
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.vectordb.base import BaseVectorDB
+from embedchain.vectordb.chroma import ChromaDB
+
+
+@register_deserializable
+class Pipeline(EmbedChain):
+    """
+    EmbedChain pipeline lets you create a LLM powered app for your unstructured
+    data by defining a pipeline with your chosen data source, embedding model,
+    and vector database.
+    """
+
+    def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
+        """
+        Initialize a new `App` instance.
+
+        :param config: Configuration for the pipeline, defaults to None
+        :type config: PipelineConfig, optional
+        :param db: The database to use for storing and retrieving embeddings, defaults to None
+        :type db: BaseVectorDB, optional
+        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
+        :type embedding_model: BaseEmbedder, optional
+        """
+        super().__init__()
+        self.config = config or PipelineConfig()
+        self.name = self.config.name
+        self.id = self.config.id or str(uuid.uuid4())
+
+        self.embedding_model = embedding_model or OpenAIEmbedder()
+        self.db = db or ChromaDB()
+        self._initialize_db()
+
+        self.user_asks = []  # legacy defaults
+
+        self.s_id = self.config.id or str(uuid.uuid4())
+        self.u_id = self._load_or_generate_user_id()
+
+        thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
+        thread_telemetry.start()
+
+    def _initialize_db(self):
+        """
+        Initialize the database.
+        """
+        self.db._set_embedder(self.embedding_model)
+        self.db._initialize()
+        self.db.set_collection_name(self.name)
+
+    def search(self, query, num_documents=3):
+        """
+        Search for similar documents related to the query in the vector database.
+        """
+        where = {"app_id": self.id}
+        return self.db.query(
+            query,
+            n_results=num_documents,
+            where=where,
+            skip_embedding=False,
+        )
+
+    @classmethod
+    def from_config(cls, yaml_path: str):
+        """
+        Instantiate a Pipeline object from a YAML configuration file.
+
+        :param yaml_path: Path to the YAML configuration file.
+        :type yaml_path: str
+        :return: An instance of the Pipeline class.
+        :rtype: Pipeline
+        """
+        with open(yaml_path, "r") as file:
+            config_data = yaml.safe_load(file)
+
+        pipeline_config_data = config_data.get("pipeline", {})
+        db_config_data = config_data.get("vectordb", {})
+        embedding_model_config_data = config_data.get("embedding_model", {})
+
+        pipeline_config = PipelineConfig(**pipeline_config_data)
+
+        db_provider = db_config_data.get("provider", "chroma")
+        db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
+
+        embedding_model_provider = embedding_model_config_data.get("provider", "openai")
+        embedding_model = EmbedderFactory.create(
+            embedding_model_provider, embedding_model_config_data.get("config", {})
+        )
+        return cls(config=pipeline_config, db=db, embedding_model=embedding_model)