|
@@ -4,6 +4,7 @@ import logging
|
|
|
import os
|
|
|
import sqlite3
|
|
|
import uuid
|
|
|
+from typing import Any, Dict, Optional
|
|
|
|
|
|
import requests
|
|
|
import yaml
|
|
@@ -19,7 +20,7 @@ from embedchain.helpers.json_serializable import register_deserializable
|
|
|
from embedchain.llm.base import BaseLlm
|
|
|
from embedchain.llm.openai import OpenAILlm
|
|
|
from embedchain.telemetry.posthog import AnonymousTelemetry
|
|
|
-from embedchain.utils import validate_yaml_config
|
|
|
+from embedchain.utils import validate_config
|
|
|
from embedchain.vectordb.base import BaseVectorDB
|
|
|
from embedchain.vectordb.chroma import ChromaDB
|
|
|
|
|
@@ -43,7 +44,7 @@ class Pipeline(EmbedChain):
|
|
|
db: BaseVectorDB = None,
|
|
|
embedding_model: BaseEmbedder = None,
|
|
|
llm: BaseLlm = None,
|
|
|
- yaml_path: str = None,
|
|
|
+ config_data: dict = None,
|
|
|
log_level=logging.WARN,
|
|
|
auto_deploy: bool = False,
|
|
|
chunker: ChunkerConfig = None,
|
|
@@ -59,15 +60,15 @@ class Pipeline(EmbedChain):
|
|
|
:type embedding_model: BaseEmbedder, optional
|
|
|
:param llm: The LLM model used to calculate embeddings, defaults to None
|
|
|
:type llm: BaseLlm, optional
|
|
|
- :param yaml_path: Path to the YAML configuration file, defaults to None
|
|
|
- :type yaml_path: str, optional
|
|
|
+ :param config_data: Config dictionary, defaults to None
|
|
|
+ :type config_data: dict, optional
|
|
|
:param log_level: Log level to use, defaults to logging.WARN
|
|
|
:type log_level: int, optional
|
|
|
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
|
|
|
:type auto_deploy: bool, optional
|
|
|
:raises Exception: If an error occurs while creating the pipeline
|
|
|
"""
|
|
|
- if id and yaml_path:
|
|
|
+ if id and config_data:
|
|
|
raise Exception("Cannot provide both id and config. Please provide only one of them.")
|
|
|
|
|
|
if id and name:
|
|
@@ -79,8 +80,8 @@ class Pipeline(EmbedChain):
|
|
|
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
|
self.logger = logging.getLogger(__name__)
|
|
|
self.auto_deploy = auto_deploy
|
|
|
- # Store the yaml config as an attribute to be able to send it
|
|
|
- self.yaml_config = None
|
|
|
+ # Store the dict config as an attribute to be able to send it
|
|
|
+ self.config_data = config_data if (config_data and validate_config(config_data)) else None
|
|
|
self.client = None
|
|
|
# pipeline_id from the backend
|
|
|
self.id = None
|
|
@@ -92,11 +93,6 @@ class Pipeline(EmbedChain):
|
|
|
self.name = self.config.name
|
|
|
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
|
|
|
|
|
|
- if yaml_path:
|
|
|
- with open(yaml_path, "r") as file:
|
|
|
- config_data = yaml.safe_load(file)
|
|
|
- self.yaml_config = config_data
|
|
|
-
|
|
|
if id is not None:
|
|
|
# Init client first since user is trying to fetch the pipeline
|
|
|
# details from the platform
|
|
@@ -187,9 +183,9 @@ class Pipeline(EmbedChain):
|
|
|
Create a pipeline on the platform.
|
|
|
"""
|
|
|
print("🛠️ Creating pipeline on the platform...")
|
|
|
- # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
|
|
|
+ # self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
|
|
|
payload = {
|
|
|
- "yaml_config": json.dumps(self.yaml_config),
|
|
|
+ "yaml_config": json.dumps(self.config_data),
|
|
|
"name": self.name,
|
|
|
"local_id": self.local_id,
|
|
|
}
|
|
@@ -346,24 +342,57 @@ class Pipeline(EmbedChain):
|
|
|
self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
|
|
|
|
|
|
@classmethod
|
|
|
- def from_config(cls, yaml_path: str, auto_deploy: bool = False):
|
|
|
+ def from_config(
|
|
|
+ cls,
|
|
|
+ config_path: Optional[str] = None,
|
|
|
+ config: Optional[Dict[str, Any]] = None,
|
|
|
+ auto_deploy: bool = False,
|
|
|
+ yaml_path: Optional[str] = None,
|
|
|
+ ):
|
|
|
"""
|
|
|
- Instantiate a Pipeline object from a YAML configuration file.
|
|
|
+ Instantiate a Pipeline object from a configuration.
|
|
|
|
|
|
- :param yaml_path: Path to the YAML configuration file.
|
|
|
- :type yaml_path: str
|
|
|
+ :param config_path: Path to the YAML or JSON configuration file.
|
|
|
+ :type config_path: Optional[str]
|
|
|
+ :param config: A dictionary containing the configuration.
|
|
|
+ :type config: Optional[Dict[str, Any]]
|
|
|
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
|
|
|
:type auto_deploy: bool, optional
|
|
|
+ :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
|
|
|
+ :type yaml_path: Optional[str]
|
|
|
:return: An instance of the Pipeline class.
|
|
|
:rtype: Pipeline
|
|
|
"""
|
|
|
- with open(yaml_path, "r") as file:
|
|
|
- config_data = yaml.safe_load(file)
|
|
|
+ # Backward compatibility for yaml_path
|
|
|
+ if yaml_path and not config_path:
|
|
|
+ config_path = yaml_path
|
|
|
+
|
|
|
+ if config_path and config:
|
|
|
+ raise ValueError("Please provide only one of config_path or config.")
|
|
|
+
|
|
|
+ config_data = None
|
|
|
+
|
|
|
+ if config_path:
|
|
|
+ file_extension = os.path.splitext(config_path)[1]
|
|
|
+ with open(config_path, "r") as file:
|
|
|
+ if file_extension in [".yaml", ".yml"]:
|
|
|
+ config_data = yaml.safe_load(file)
|
|
|
+ elif file_extension == ".json":
|
|
|
+ config_data = json.load(file)
|
|
|
+ else:
|
|
|
+ raise ValueError("config_path must be a path to a YAML or JSON file.")
|
|
|
+ elif config and isinstance(config, dict):
|
|
|
+ config_data = config
|
|
|
+ else:
|
|
|
+ logging.error(
|
|
|
+ "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
|
|
|
+ )
|
|
|
+ config_data = {}
|
|
|
|
|
|
try:
|
|
|
- validate_yaml_config(config_data)
|
|
|
+ validate_config(config_data)
|
|
|
except Exception as e:
|
|
|
- raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")
|
|
|
+ raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
|
|
|
|
|
|
pipeline_config_data = config_data.get("app", {}).get("config", {})
|
|
|
db_config_data = config_data.get("vectordb", {})
|
|
@@ -388,7 +417,7 @@ class Pipeline(EmbedChain):
|
|
|
)
|
|
|
|
|
|
# Send anonymous telemetry
|
|
|
- event_properties = {"init_type": "yaml_config"}
|
|
|
+ event_properties = {"init_type": "config_data"}
|
|
|
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
|
|
|
|
|
|
return cls(
|
|
@@ -396,7 +425,7 @@ class Pipeline(EmbedChain):
|
|
|
llm=llm,
|
|
|
db=db,
|
|
|
embedding_model=embedding_model,
|
|
|
- yaml_path=yaml_path,
|
|
|
+ config_data=config_data,
|
|
|
auto_deploy=auto_deploy,
|
|
|
chunker=chunker_config_data,
|
|
|
)
|