Selaa lähdekoodia

Change HF embedding library (#1440)

Dev Khant 1 vuosi sitten
vanhempi
commit
5070a1d83e
2 muutettua tiedostoa jossa 12 lisäystä ja 6 poistoa
  1. 1 1
      Makefile
  2. 11 5
      embedchain/embedder/huggingface.py

+ 1 - 1
Makefile

@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface
 
 install_es:
 	poetry install --extras elasticsearch

+ 11 - 5
embedchain/embedder/huggingface.py

@@ -2,7 +2,14 @@ import os
 from typing import Optional
 
 from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
+
+try:
+    from langchain_huggingface import HuggingFaceEndpointEmbeddings
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "The required dependencies for HuggingFaceHub are not installed."
+        "Please install with `pip install langchain_huggingface`"
+    ) from None
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
@@ -19,10 +26,9 @@ class HuggingFaceEmbedder(BaseEmbedder):
                     "Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass API Key in the config."
                 )
 
-            embeddings = HuggingFaceInferenceAPIEmbeddings(
-                model_name=self.config.model,
-                api_url=self.config.endpoint,
-                api_key=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
+            embeddings = HuggingFaceEndpointEmbeddings(
+                model=self.config.endpoint,
+                huggingfacehub_api_token=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
             )
         else:
             embeddings = HuggingFaceEmbeddings(model_name=self.config.model)