Procházet zdrojové kódy

[Bug fix] Fix issues related to creating pipelines (#850)

Deshraj Yadav před 1 rokem
rodič
revize
413ccb83e6
3 změnil soubory, kde provedl 81 přidání a 13 odebrání
  1. 2 1
      .gitignore
  2. 26 0
      configs/pipeline.yaml
  3. 53 12
      embedchain/pipeline.py

+ 2 - 1
.gitignore

@@ -173,4 +173,5 @@ test-db
 .DS_Store
 
 notebooks/*.yaml
-.ipynb_checkpoints/
+.ipynb_checkpoints/
+!configs/*.yaml

+ 26 - 0
configs/pipeline.yaml

@@ -0,0 +1,26 @@
+pipeline:
+  config:
+    name: Example pipeline
+    id: pipeline-1  # Make sure that id is different every time you create a new pipeline
+
+vectordb:
+  provider: chroma
+  config:
+    collection_name: pipeline-1
+    dir: db
+    allow_reset: true
+
+llm:
+  provider: gpt4all
+  config:
+    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
+    temperature: 0.5
+    max_tokens: 1000
+    top_p: 1
+    stream: false
+
+embedding_model:
+  provider: gpt4all
+  config:
+    model: 'all-MiniLM-L6-v2'
+    deployment_name: null

+ 53 - 12
embedchain/pipeline.py

@@ -14,7 +14,7 @@ from embedchain.config import PipelineConfig
 from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.factory import EmbedderFactory, VectorDBFactory
+from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 from embedchain.vectordb.base import BaseVectorDB
@@ -39,6 +39,7 @@ class Pipeline(EmbedChain):
         llm: BaseLlm = None,
         yaml_path: str = None,
         log_level=logging.INFO,
+        auto_deploy: bool = False,
     ):
         """
         Initialize a new `App` instance.
@@ -49,12 +50,26 @@ class Pipeline(EmbedChain):
         :type db: BaseVectorDB, optional
         :param embedding_model: The embedding model used to calculate embeddings, defaults to None
         :type embedding_model: BaseEmbedder, optional
+        :param llm: The LLM model used to calculate embeddings, defaults to None
+        :type llm: BaseLlm, optional
+        :param yaml_path: Path to the YAML configuration file, defaults to None
+        :type yaml_path: str, optional
+        :param log_level: Log level to use, defaults to logging.INFO
+        :type log_level: int, optional
+        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
+        :type auto_deploy: bool, optional
+        :raises Exception: If an error occurs while creating the pipeline
         """
         logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
         self.logger = logging.getLogger(__name__)
+
+        self.auto_deploy = auto_deploy
+
         # Store the yaml config as an attribute to be able to send it
         self.yaml_config = None
         self.client = None
+        # pipeline_id from the backend
+        self.id = None
         if yaml_path:
             with open(yaml_path, "r") as file:
                 config_data = yaml.safe_load(file)
@@ -84,7 +99,7 @@ class Pipeline(EmbedChain):
                 hash TEXT,
                 type TEXT,
                 value TEXT,
-                metadata TEXT
+                metadata TEXT,
                 is_uploaded INTEGER DEFAULT 0,
                 PRIMARY KEY (pipeline_id, hash)
             )
@@ -93,6 +108,8 @@ class Pipeline(EmbedChain):
         self.connection.commit()
 
         self.user_asks = []  # legacy defaults
+        if self.auto_deploy:
+            self.deploy()
 
     def _init_db(self):
         """
@@ -110,14 +127,16 @@ class Pipeline(EmbedChain):
         if config.get("api_key"):
             self.client = Client()
         else:
-            api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n")
+            api_key = input(
+                "Enter Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"
+            )
             self.client = Client(api_key=api_key)
 
     def _create_pipeline(self):
         """
         Create a pipeline on the platform.
         """
-        print("Creating pipeline on the platform...")
+        print("🛠️ Creating pipeline on the platform...")
         # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
         payload = {
             "yaml_config": json.dumps(self.yaml_config),
@@ -133,7 +152,9 @@ class Pipeline(EmbedChain):
         if r.status_code not in [200, 201]:
             raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")
 
-        print(f"Pipeline created. link: https://app.embedchain.ai/pipelines/{r.json()['id']}")
+        print(
+            f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
+        )
         return r.json()
 
     def _get_presigned_url(self, data_type, data_value):
@@ -151,7 +172,7 @@ class Pipeline(EmbedChain):
         Search for similar documents related to the query in the vector database.
         """
         # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
-        if self.deploy is False:
+        if self.id is None:
             where = {"app_id": self.local_id}
             return self.db.query(
                 query,
@@ -171,6 +192,7 @@ class Pipeline(EmbedChain):
                 return response.status_code == 200
         except Exception as e:
             self.logger.exception(f"Error occurred during file upload: {str(e)}")
+            print("❌ Error occurred during file upload!")
             return False
 
     def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
@@ -179,7 +201,14 @@ class Pipeline(EmbedChain):
             "data_value": data_value,
             "metadata": metadata,
         }
-        return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
+        try:
+            self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
+            # print the local file path if user tries to upload a local file
+            printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
+            print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
+        except Exception as e:
+            self.logger.error(f"Error occurred during data upload: {str(e)}")
+            print(f"❌ Error occurred during data upload for type {data_type}!")
 
     def _send_api_request(self, endpoint, payload):
         url = f"{self.client.host}{endpoint}"
@@ -194,8 +223,8 @@ class Pipeline(EmbedChain):
             presigned_url = presigned_url_data["presigned_url"]
             s3_key = presigned_url_data["s3_key"]
             if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
-                data_value = presigned_url
                 metadata = {"file_path": data_value, "s3_key": s3_key}
+                data_value = presigned_url
             else:
                 self.logger.error(f"File upload failed for hash: {data_hash}")
                 return False
@@ -207,10 +236,10 @@ class Pipeline(EmbedChain):
         try:
             self._upload_data_to_pipeline(data_type, data_value, metadata)
             self._mark_data_as_uploaded(data_hash)
-            self.logger.info(f"Data of type {data_type} uploaded successfully.")
             return True
         except Exception as e:
             self.logger.error(f"Error occurred during data upload: {str(e)}")
+            print(f"❌ Error occurred during data upload for hash {data_hash}!")
             return False
 
     def _mark_data_as_uploaded(self, data_hash):
@@ -232,22 +261,25 @@ class Pipeline(EmbedChain):
                 "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)
             ).fetchall()
 
+            if len(results) > 0:
+                print("🛠️ Adding data to your pipeline...")
             for result in results:
                 data_hash, data_type, data_value = result[0], result[2], result[3]
-                if self._process_and_upload_data(data_hash, data_type, data_value):
-                    self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
+                self._process_and_upload_data(data_hash, data_type, data_value)
 
         except Exception as e:
             self.logger.exception(f"Error occurred during deployment: {str(e)}")
             raise HTTPException(status_code=500, detail="Error occurred during deployment.")
 
     @classmethod
-    def from_config(cls, yaml_path: str):
+    def from_config(cls, yaml_path: str, auto_deploy: bool = False):
         """
         Instantiate a Pipeline object from a YAML configuration file.
 
         :param yaml_path: Path to the YAML configuration file.
         :type yaml_path: str
+        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
+        :type auto_deploy: bool, optional
         :return: An instance of the Pipeline class.
         :rtype: Pipeline
         """
@@ -257,21 +289,30 @@ class Pipeline(EmbedChain):
         pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", {})
+        llm_config_data = config_data.get("llm", {})
 
         pipeline_config = PipelineConfig(**pipeline_config_data)
 
         db_provider = db_config_data.get("provider", "chroma")
         db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
 
+        if llm_config_data:
+            llm_provider = llm_config_data.get("provider", "openai")
+            llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
+        else:
+            llm = None
+
         embedding_model_provider = embedding_model_config_data.get("provider", "openai")
         embedding_model = EmbedderFactory.create(
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
         return cls(
             config=pipeline_config,
+            llm=llm,
             db=db,
             embedding_model=embedding_model,
             yaml_path=yaml_path,
+            auto_deploy=auto_deploy,
         )
 
     def start(self, host="0.0.0.0", port=8000):