瀏覽代碼

[Feature] Add support for deploying local pipelines to Embedchain platform (#847)

Deshraj Yadav 1 年之前
父節點
當前提交
3979480532

+ 1 - 0
embedchain/__init__.py

@@ -3,5 +3,6 @@ import importlib.metadata
 __version__ = importlib.metadata.version(__package__ or __name__)
 
 from embedchain.apps.app import App  # noqa: F401
+from embedchain.client import Client  # noqa: F401
 from embedchain.pipeline import Pipeline  # noqa: F401
 from embedchain.vectordb.chroma import ChromaDB  # noqa: F401

+ 102 - 0
embedchain/client.py

@@ -0,0 +1,102 @@
+import json
+import logging
+import os
+import uuid
+
+import requests
+
+from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
+
+
+class Client:
+    def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
+        self.config_data = self.load_config()
+        self.host = host
+
+        if api_key:
+            if self.check(api_key):
+                self.api_key = api_key
+                self.save()
+            else:
+                raise ValueError(
+                    "Invalid API key provided. You can find your API key on https://app.embedchain.ai/settings/keys."
+                )
+        else:
+            if "api_key" in self.config_data:
+                self.api_key = self.config_data["api_key"]
+                logging.info("API key loaded successfully!")
+            else:
+                raise ValueError(
+                    "You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
+                )
+
+    @classmethod
+    def setup_dir(self):
+        """
+        Loads the user id from the config file if it exists, otherwise generates a new
+        one and saves it to the config file.
+
+        :return: user id
+        :rtype: str
+        """
+        if not os.path.exists(CONFIG_DIR):
+            os.makedirs(CONFIG_DIR)
+
+        if os.path.exists(CONFIG_FILE):
+            with open(CONFIG_FILE, "r") as f:
+                data = json.load(f)
+                if "user_id" in data:
+                    return data["user_id"]
+
+        u_id = str(uuid.uuid4())
+        with open(CONFIG_FILE, "w") as f:
+            json.dump({"user_id": u_id}, f)
+
+    @classmethod
+    def load_config(cls):
+        if not os.path.exists(CONFIG_FILE):
+            cls.setup_dir()
+
+        with open(CONFIG_FILE, "r") as config_file:
+            return json.load(config_file)
+
+    def save(self):
+        self.config_data["api_key"] = self.api_key
+        with open(CONFIG_FILE, "w") as config_file:
+            json.dump(self.config_data, config_file, indent=4)
+
+        logging.info("API key saved successfully!")
+
+    def clear(self):
+        if "api_key" in self.config_data:
+            del self.config_data["api_key"]
+            with open(CONFIG_FILE, "w") as config_file:
+                json.dump(self.config_data, config_file, indent=4)
+            self.api_key = None
+            logging.info("API key deleted successfully!")
+        else:
+            logging.warning("API key not found in the configuration file.")
+
+    def update(self, api_key):
+        if self.check(api_key):
+            self.api_key = api_key
+            self.save()
+            logging.info("API key updated successfully!")
+        else:
+            logging.warning("Invalid API key provided. API key not updated.")
+
+    def check(self, api_key):
+        validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
+        response = requests.post(validation_url, headers={"Authorization": f"Token {api_key}"})
+        if response.status_code == 200:
+            return True
+        else:
+            logging.warning(f"Response from API: {response.text}")
+            logging.warning("Invalid API key. Unable to validate.")
+            return False
+
+    def get(self):
+        return self.api_key
+
+    def __str__(self):
+        return self.api_key

+ 0 - 1
embedchain/config/__init__.py

@@ -2,7 +2,6 @@
 
 from .add_config import AddConfig, ChunkerConfig
 from .apps.app_config import AppConfig
-from .pipeline_config import PipelineConfig
 from .base_config import BaseConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig

+ 56 - 12
embedchain/embedchain.py

@@ -3,6 +3,7 @@ import importlib.metadata
 import json
 import logging
 import os
+import sqlite3
 import threading
 import uuid
 from pathlib import Path
@@ -32,6 +33,7 @@ ABS_PATH = os.getcwd()
 HOME_DIR = str(Path.home())
 CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
 CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
 
 
 class EmbedChain(JSONSerializable):
@@ -89,6 +91,27 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()
+
+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT,
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+        """
+        )
+        self.connection.commit()
+
         # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
         # if (self.config.collect_metrics):
         #     raise ConnectionRefusedError("Collection of metrics should not be allowed.")
@@ -163,7 +186,7 @@ class EmbedChain(JSONSerializable):
         :raises ValueError: Invalid data type
         :param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended.
         deafaults to False
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
         if config is None:
@@ -192,18 +215,40 @@ class EmbedChain(JSONSerializable):
         if not data_type:
             data_type = detect_datatype(source)
 
-        # `source_id` is the hash of the source argument
+        # `source_hash` is the md5 hash of the source argument
         hash_object = hashlib.md5(str(source).encode("utf-8"))
-        source_id = hash_object.hexdigest()
+        source_hash = hash_object.hexdigest()
+
+        # Check if the data hash already exists, if so, skip the addition
+        self.cursor.execute(
+            "SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?", (source_hash, self.config.id)
+        )
+        existing_data = self.cursor.fetchone()
+
+        if existing_data:
+            print(f"Data with hash {source_hash} already exists. Skipping addition.")
+            return source_hash
 
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
         documents, metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
 
+        # Insert the data into the 'data' table
+        self.cursor.execute(
+            """
+            INSERT INTO data_sources (hash, pipeline_id, type, value, metadata)
+            VALUES (?, ?, ?, ?, ?)
+        """,
+            (source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
+        )
+
+        # Commit the transaction
+        self.connection.commit()
+
         if dry_run:
             data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
             logging.debug(f"Dry run info : {data_chunks_info}")
@@ -218,7 +263,7 @@ class EmbedChain(JSONSerializable):
             thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
             thread_telemetry.start()
 
-        return source_id
+        return source_hash
 
     def add_local(
         self,
@@ -245,7 +290,7 @@ class EmbedChain(JSONSerializable):
         :param config: The `AddConfig` instance to use as configuration options., defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
         logging.warning(
@@ -313,7 +358,7 @@ class EmbedChain(JSONSerializable):
         chunker: BaseChunker,
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
+        source_hash: Optional[str] = None,
         dry_run=False,
     ):
         """
@@ -324,7 +369,7 @@ class EmbedChain(JSONSerializable):
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
         :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :param source_hash: Hexadecimal hash of the source.
         :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
         :type dry_run: bool, defaults to False
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
@@ -382,7 +427,7 @@ class EmbedChain(JSONSerializable):
                 m["app_id"] = self.config.id
 
             # Add hashed source
-            m["hash"] = source_id
+            m["hash"] = source_hash
 
             # Note: Metadata is the function argument
             if metadata:
@@ -558,15 +603,14 @@ class EmbedChain(JSONSerializable):
         """
         Resets the database. Deletes all embeddings irreversibly.
         `App` does not have to be reinitialized after using this method.
-
-        DEPRECATED IN FAVOR OF `db.reset()`
         """
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()
 
-        logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
         self.db.reset()
+        self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
+        self.connection.commit()
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):

+ 228 - 21
embedchain/pipeline.py

@@ -1,17 +1,27 @@
-import threading
+import ast
+import json
+import logging
+import os
+import sqlite3
 import uuid
 
+import requests
 import yaml
+from fastapi import FastAPI, HTTPException
 
+from embedchain import Client
 from embedchain.config import PipelineConfig
-from embedchain.embedchain import EmbedChain
+from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, VectorDBFactory
 from embedchain.helper.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
+
 
 @register_deserializable
 class Pipeline(EmbedChain):
@@ -21,7 +31,15 @@ class Pipeline(EmbedChain):
     and vector database.
     """
 
-    def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
+    def __init__(
+        self,
+        config: PipelineConfig = None,
+        db: BaseVectorDB = None,
+        embedding_model: BaseEmbedder = None,
+        llm: BaseLlm = None,
+        yaml_path: str = None,
+        log_level=logging.INFO,
+    ):
         """
         Initialize a new `App` instance.
 
@@ -32,42 +50,196 @@ class Pipeline(EmbedChain):
         :param embedding_model: The embedding model used to calculate embeddings, defaults to None
         :type embedding_model: BaseEmbedder, optional
         """
-        super().__init__()
+        logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+        self.logger = logging.getLogger(__name__)
+        # Store the yaml config as an attribute to be able to send it
+        self.yaml_config = None
+        self.client = None
+        if yaml_path:
+            with open(yaml_path, "r") as file:
+                config_data = yaml.safe_load(file)
+                self.yaml_config = config_data
+
         self.config = config or PipelineConfig()
         self.name = self.config.name
-        self.id = self.config.id or str(uuid.uuid4())
+        self.local_id = self.config.id or str(uuid.uuid4())
 
         self.embedding_model = embedding_model or OpenAIEmbedder()
         self.db = db or ChromaDB()
-        self._initialize_db()
+        self.llm = llm or None
+        self._init_db()
 
-        self.user_asks = []  # legacy defaults
-
-        self.s_id = self.config.id or str(uuid.uuid4())
+        # setup user id and directory
         self.u_id = self._load_or_generate_user_id()
 
-        thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
-        thread_telemetry.start()
+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+        """
+        )
+        self.connection.commit()
+
+        self.user_asks = []  # legacy defaults
 
-    def _initialize_db(self):
+    def _init_db(self):
         """
         Initialize the database.
         """
         self.db._set_embedder(self.embedding_model)
         self.db._initialize()
-        self.db.set_collection_name(self.name)
+        self.db.set_collection_name(self.db.config.collection_name)
+
+    def _init_client(self):
+        """
+        Initialize the client.
+        """
+        config = Client.load_config()
+        if config.get("api_key"):
+            self.client = Client()
+        else:
+            api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n")
+            self.client = Client(api_key=api_key)
+
+    def _create_pipeline(self):
+        """
+        Create a pipeline on the platform.
+        """
+        print("Creating pipeline on the platform...")
+        # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
+        payload = {
+            "yaml_config": json.dumps(self.yaml_config),
+            "name": self.name,
+            "local_id": self.local_id,
+        }
+        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
+        r = requests.post(
+            url,
+            json=payload,
+            headers={"Authorization": f"Token {self.client.api_key}"},
+        )
+        if r.status_code not in [200, 201]:
+            raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")
+
+        print(f"Pipeline created. link: https://app.embedchain.ai/pipelines/{r.json()['id']}")
+        return r.json()
+
+    def _get_presigned_url(self, data_type, data_value):
+        payload = {"data_type": data_type, "data_value": data_value}
+        r = requests.post(
+            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
+            json=payload,
+            headers={"Authorization": f"Token {self.client.api_key}"},
+        )
+        r.raise_for_status()
+        return r.json()
 
     def search(self, query, num_documents=3):
         """
         Search for similar documents related to the query in the vector database.
         """
-        where = {"app_id": self.id}
-        return self.db.query(
-            query,
-            n_results=num_documents,
-            where=where,
-            skip_embedding=False,
+        # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
+        if self.deploy is False:
+            where = {"app_id": self.local_id}
+            return self.db.query(
+                query,
+                n_results=num_documents,
+                where=where,
+                skip_embedding=False,
+            )
+        else:
+            # Make API call to the backend to get the results
+            NotImplementedError("Search is not implemented yet for the prod mode.")
+
+    def _upload_file_to_presigned_url(self, presigned_url, file_path):
+        try:
+            with open(file_path, "rb") as file:
+                response = requests.put(presigned_url, data=file)
+                response.raise_for_status()
+                return response.status_code == 200
+        except Exception as e:
+            self.logger.exception(f"Error occurred during file upload: {str(e)}")
+            return False
+
+    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
+        payload = {
+            "data_type": data_type,
+            "data_value": data_value,
+            "metadata": metadata,
+        }
+        return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
+
+    def _send_api_request(self, endpoint, payload):
+        url = f"{self.client.host}{endpoint}"
+        headers = {"Authorization": f"Token {self.client.api_key}"}
+        response = requests.post(url, json=payload, headers=headers)
+        response.raise_for_status()
+        return response
+
+    def _process_and_upload_data(self, data_hash, data_type, data_value):
+        if os.path.isabs(data_value):
+            presigned_url_data = self._get_presigned_url(data_type, data_value)
+            presigned_url = presigned_url_data["presigned_url"]
+            s3_key = presigned_url_data["s3_key"]
+            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
+                data_value = presigned_url
+                metadata = {"file_path": data_value, "s3_key": s3_key}
+            else:
+                self.logger.error(f"File upload failed for hash: {data_hash}")
+                return False
+        else:
+            if data_type == "qna_pair":
+                data_value = list(ast.literal_eval(data_value))
+            metadata = {}
+
+        try:
+            self._upload_data_to_pipeline(data_type, data_value, metadata)
+            self._mark_data_as_uploaded(data_hash)
+            self.logger.info(f"Data of type {data_type} uploaded successfully.")
+            return True
+        except Exception as e:
+            self.logger.error(f"Error occurred during data upload: {str(e)}")
+            return False
+
+    def _mark_data_as_uploaded(self, data_hash):
+        self.cursor.execute(
+            "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ? AND is_uploaded = 0",
+            (data_hash, self.local_id),
         )
+        self.connection.commit()
+
+    def deploy(self):
+        try:
+            if self.client is None:
+                self._init_client()
+
+            pipeline_data = self._create_pipeline()
+            self.id = pipeline_data["id"]
+
+            results = self.cursor.execute(
+                "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)
+            ).fetchall()
+
+            for result in results:
+                data_hash, data_type, data_value = result[0], result[2], result[3]
+                if self._process_and_upload_data(data_hash, data_type, data_value):
+                    self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
+
+        except Exception as e:
+            self.logger.exception(f"Error occurred during deployment: {str(e)}")
+            raise HTTPException(status_code=500, detail="Error occurred during deployment.")
 
     @classmethod
     def from_config(cls, yaml_path: str):
@@ -82,7 +254,7 @@ class Pipeline(EmbedChain):
         with open(yaml_path, "r") as file:
             config_data = yaml.safe_load(file)
 
-        pipeline_config_data = config_data.get("pipeline", {})
+        pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", {})
 
@@ -95,4 +267,39 @@ class Pipeline(EmbedChain):
         embedding_model = EmbedderFactory.create(
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
-        return cls(config=pipeline_config, db=db, embedding_model=embedding_model)
+        return cls(
+            config=pipeline_config,
+            db=db,
+            embedding_model=embedding_model,
+            yaml_path=yaml_path,
+        )
+
+    def start(self, host="0.0.0.0", port=8000):
+        app = FastAPI()
+
+        @app.post("/add")
+        async def add_document(data_value: str, data_type: str = None):
+            """
+            Add a document to the pipeline.
+            """
+            try:
+                document = {"data_value": data_value, "data_type": data_type}
+                self.add(document)
+                return {"message": "Document added successfully"}
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        @app.post("/query")
+        async def query_documents(query: str, num_documents: int = 3):
+            """
+            Query for similar documents in the pipeline.
+            """
+            try:
+                results = self.search(query, num_documents)
+                return results
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        import uvicorn
+
+        uvicorn.run(app, host=host, port=port)

+ 16 - 0
tests/conftest.py

@@ -0,0 +1,16 @@
+import os
+
+import pytest
+
+
+def clean_db():
+    db_path = os.path.expanduser("~/.embedchain/embedchain.db")
+    if os.path.exists(db_path):
+        os.remove(db_path)
+
+
+@pytest.fixture
+def setup():
+    clean_db()
+    yield
+    clean_db()

+ 8 - 7
tests/embedchain/test_add.py

@@ -16,19 +16,20 @@ def app(mocker):
 
 
 def test_add(app):
-    app.add("https://example.com", metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://example.com", "web_page", {"meta": "meta-data"}]]
+    app.add("https://example.com", metadata={"foo": "bar"})
+    assert app.user_asks == [["https://example.com", "web_page", {"foo": "bar"}]]
 
 
-def test_add_sitemap(app):
-    app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]
+# TODO: Make this test faster by generating a sitemap locally rather than using a remote one
+# def test_add_sitemap(app):
+#     app.add("https://www.google.com/sitemap.xml", metadata={"foo": "bar"})
+#     assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"foo": "bar"}]]
 
 
 def test_add_forced_type(app):
     data_type = "text"
-    app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://example.com", data_type, {"meta": "meta-data"}]]
+    app.add("https://example.com", data_type=data_type, metadata={"foo": "bar"})
+    assert app.user_asks == [["https://example.com", data_type, {"foo": "bar"}]]
 
 
 def test_dry_run(app):

+ 53 - 0
tests/test_client.py

@@ -0,0 +1,53 @@
+import pytest
+
+from embedchain import Client
+
+
+class TestClient:
+    @pytest.fixture
+    def mock_requests_post(self, mocker):
+        return mocker.patch("embedchain.client.requests.post")
+
+    def test_valid_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 200
+        client = Client(api_key="valid_api_key")
+        assert client.check("valid_api_key") is True
+
+    def test_invalid_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 401
+        with pytest.raises(ValueError):
+            Client(api_key="invalid_api_key")
+
+    def test_update_valid_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 200
+        client = Client(api_key="valid_api_key")
+        client.update("new_valid_api_key")
+        assert client.get() == "new_valid_api_key"
+
+    def test_clear_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 200
+        client = Client(api_key="valid_api_key")
+        client.clear()
+        assert client.get() is None
+
+    def test_save_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 200
+        api_key_to_save = "valid_api_key"
+        client = Client(api_key=api_key_to_save)
+        client.save()
+        assert client.get() == api_key_to_save
+
+    def test_load_api_key_from_config(self, mocker):
+        mocker.patch("embedchain.Client.load_config", return_value={"api_key": "test_api_key"})
+        client = Client()
+        assert client.get() == "test_api_key"
+
+    def test_load_invalid_api_key_from_config(self, mocker):
+        mocker.patch("embedchain.Client.load_config", return_value={})
+        with pytest.raises(ValueError):
+            Client()
+
+    def test_load_missing_api_key_from_config(self, mocker):
+        mocker.patch("embedchain.Client.load_config", return_value={})
+        with pytest.raises(ValueError):
+            Client()

+ 2 - 1
tests/vectordb/test_chroma_db.py

@@ -1,9 +1,10 @@
 import os
 import shutil
-import pytest
 from unittest.mock import patch
 
+import pytest
 from chromadb.config import Settings
+
 from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.vectordb.chroma import ChromaDB