Pārlūkot izejas kodu

Add HF endpoint in embedder (#1436)

Dev Khant 1 gadu atpakaļ
vecāks
revīzija
f6ddd5ffc5

+ 1 - 0
docs/api-reference/advanced/configuration.mdx

@@ -224,6 +224,7 @@ Alright, let's dive into what each key means in the yaml config above:
         - `model` (String): The specific model used for text embedding, 'text-embedding-ada-002'.
         - `vector_dimension` (Integer): The vector dimension of the embedding model. [Defaults](https://github.com/embedchain/embedchain/blob/main/embedchain/models/vector_dimensions.py)
         - `api_key` (String): The API key for the embedding model.
+        - `endpoint` (String): The endpoint for the HuggingFace embedding model.
         - `deployment_name` (String): The deployment name for the embedding model.
         - `title` (String): The title for the embedding model for Google Embedder.
         - `task_type` (String): The task type for the embedding model for Google Embedder.

+ 10 - 0
embedchain/config/embedder/base.py

@@ -10,6 +10,7 @@ class BaseEmbedderConfig:
         model: Optional[str] = None,
         deployment_name: Optional[str] = None,
         vector_dimension: Optional[int] = None,
+        endpoint: Optional[str] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ):
@@ -20,9 +21,18 @@ class BaseEmbedderConfig:
         :type model: Optional[str], optional
         :param deployment_name: deployment name for llm embedding model, defaults to None
         :type deployment_name: Optional[str], optional
+        :param vector_dimension: vector dimension of the embedding model, defaults to None
+        :type vector_dimension: Optional[int], optional
+        :param endpoint: endpoint for the embedding model, defaults to None
+        :type endpoint: Optional[str], optional
+        :param api_key: huggingface api key, defaults to None
+        :type api_key: Optional[str], optional
+        :param api_base: huggingface api base, defaults to None
+        :type api_base: Optional[str], optional
         """
         self.model = model
         self.deployment_name = deployment_name
         self.vector_dimension = vector_dimension
+        self.endpoint = endpoint
         self.api_key = api_key
         self.api_base = api_base

+ 15 - 1
embedchain/embedder/huggingface.py

@@ -1,6 +1,8 @@
+import os
 from typing import Optional
 
 from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
@@ -11,7 +13,19 @@ class HuggingFaceEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config=config)
 
-        embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
+        if self.config.endpoint:
+            if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
+                raise ValueError(
+                    "Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass API Key in the config."
+                )
+
+            embeddings = HuggingFaceInferenceAPIEmbeddings(
+                model_name=self.config.model,
+                api_url=self.config.endpoint,
+                api_key=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
+            )
+        else:
+            embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
         embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)
 

+ 2 - 1
embedchain/utils/misc.py

@@ -441,7 +441,7 @@ def validate_config(config_data):
                     Optional("local"): bool,
                     Optional("base_url"): str,
                     Optional("default_headers"): dict,
-                    Optional("api_version"): Or(str, datetime.date)
+                    Optional("api_version"): Or(str, datetime.date),
                 },
             },
             Optional("vectordb"): {
@@ -473,6 +473,7 @@ def validate_config(config_data):
                     Optional("task_type"): str,
                     Optional("vector_dimension"): int,
                     Optional("base_url"): str,
+                    Optional("endpoint"): str,
                 },
             },
             Optional("embedding_model"): {