Selaa lähdekoodia

Adding model_kwargs for huggingface embedders. (#1450)

Pranav Puranik 1 vuosi sitten
vanhempi
commit
7c24601d0f

+ 1 - 0
docs/api-reference/advanced/configuration.mdx

@@ -229,6 +229,7 @@ Alright, let's dive into what each key means in the yaml config above:
         - `deployment_name` (String): The deployment name for the embedding model.
         - `title` (String): The title for the embedding model for Google Embedder.
         - `task_type` (String): The task type for the embedding model for Google Embedder.
+        - `model_kwargs` (Dict): Extra keyword arguments passed through to the embedding model (for example, `trust_remote_code` for Hugging Face models).
 5. `chunker` Section:
     - `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
     - `chunk_overlap` (Integer): The amount of overlap between each chunk of text.

+ 2 - 0
docs/components/embedding-models.mdx

@@ -192,6 +192,8 @@ embedder:
   provider: huggingface
   config:
     model: 'sentence-transformers/all-mpnet-base-v2'
+    model_kwargs:
+        trust_remote_code: True # Only use if you trust your embedder
 ```
 
 </CodeGroup>

+ 5 - 1
embedchain/config/embedder/base.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from embedchain.helpers.json_serializable import register_deserializable
 
@@ -13,6 +13,7 @@ class BaseEmbedderConfig:
         endpoint: Optional[str] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Initialize a new instance of an embedder config class.
@@ -29,6 +30,8 @@ class BaseEmbedderConfig:
         :type api_key: Optional[str], optional
         :param api_base: huggingface api base, defaults to None
         :type api_base: Optional[str], optional
+        :param model_kwargs: extra keyword arguments for the embedding model, defaults to None (stored as an empty dict)
+        :type model_kwargs: Optional[Dict[str, Any]], optional
         """
         self.model = model
         self.deployment_name = deployment_name
@@ -36,3 +39,4 @@ class BaseEmbedderConfig:
         self.endpoint = endpoint
         self.api_key = api_key
         self.api_base = api_base
+        self.model_kwargs = model_kwargs or {}

+ 2 - 1
embedchain/embedder/huggingface.py

@@ -31,7 +31,8 @@ class HuggingFaceEmbedder(BaseEmbedder):
                 huggingfacehub_api_token=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
             )
         else:
-            embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
+            embeddings = HuggingFaceEmbeddings(model_name=self.config.model, model_kwargs=self.config.model_kwargs)
+
         embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)
 

+ 1 - 0
embedchain/utils/misc.py

@@ -474,6 +474,7 @@ def validate_config(config_data):
                     Optional("vector_dimension"): int,
                     Optional("base_url"): str,
                     Optional("endpoint"): str,
+                    Optional("model_kwargs"): dict,
                 },
             },
             Optional("embedding_model"): {

+ 18 - 0
tests/embedder/test_huggingface_embedder.py

@@ -0,0 +1,18 @@
+
+from unittest.mock import patch
+from embedchain.config import BaseEmbedderConfig
+from embedchain.embedder.huggingface import HuggingFaceEmbedder
+
+
+def test_huggingface_embedder_with_model(monkeypatch):
+    config = BaseEmbedderConfig(model="test-model", model_kwargs={"param": "value"})
+    with patch('embedchain.embedder.huggingface.HuggingFaceEmbeddings') as mock_embeddings:
+        embedder = HuggingFaceEmbedder(config=config)
+        assert embedder.config.model == "test-model"
+        assert embedder.config.model_kwargs == {"param": "value"}
+        mock_embeddings.assert_called_once_with(
+            model_name="test-model",
+            model_kwargs={"param": "value"}
+        )
+
+