Procházet zdrojové kódy

Add support for Hugging Face Inference Endpoint as LLM (#1143)

Madison Ebersole před 1 rokem
rodič
revize
62c0c52e31

+ 43 - 0
docs/components/llms.mdx

@@ -494,6 +494,49 @@ llm:
 ```
 </CodeGroup>
 
+### Custom Endpoints
+
+
+You can also use [Hugging Face Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index#-inference-endpoints) to access custom endpoints. First, set the `HUGGINGFACE_ACCESS_TOKEN` as above.
+
+Then, load the app using the config yaml file:
+
+<CodeGroup>
+
+```python main.py
+import os
+from embedchain import App
+
+os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "xxx"
+
+# load llm configuration from config.yaml file
+app = App.from_config(config_path="config.yaml")
+```
+
+```yaml config.yaml
+llm:
+  provider: huggingface 
+  config:
+    endpoint: https://api-inference.huggingface.co/models/gpt2 # replace with your personal endpoint
+```
+</CodeGroup>
+
+If your endpoint requires additional parameters, you can pass them in the `model_kwargs` field:
+
+```
+llm:
+  provider: huggingface 
+  config:
+    endpoint: <YOUR_ENDPOINT_URL_HERE>
+    model_kwargs:
+      max_new_tokens: 100
+      temperature: 0.5
+```
+
+Currently only supports `text-generation` and `text2text-generation` for now [[ref](https://api.python.langchain.com/en/latest/llms/langchain_community.llms.huggingface_endpoint.HuggingFaceEndpoint.html?highlight=huggingfaceendpoint#)].
+
+See langchain's [hugging face endpoint](https://python.langchain.com/docs/integrations/chat/huggingface#huggingfaceendpoint) for more information. 
+
 ## Llama2
 
 Llama2 is integrated through [Replicate](https://replicate.com/).  Set `REPLICATE_API_TOKEN` in environment variable which you can obtain from [their platform](https://replicate.com/account/api-tokens).

+ 10 - 1
embedchain/config/llm/base.py

@@ -72,6 +72,8 @@ class BaseLlmConfig(BaseConfig):
         query_type: Optional[str] = None,
         callbacks: Optional[List] = None,
         api_key: Optional[str] = None,
+        endpoint: Optional[str] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -105,6 +107,12 @@ class BaseLlmConfig(BaseConfig):
         :type system_prompt: Optional[str], optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Dict[str, Any], optional
+        :param api_key: The api key of the custom endpoint, defaults to None
+        :type api_key: Optional[str], optional
+        :param endpoint: The api url of the custom endpoint, defaults to None
+        :type endpoint: Optional[str], optional
+        :param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
+        :type model_kwargs: Optional[Dict[str, Any]], optional
         :param callbacks: Langchain callback functions to use, defaults to None
         :type callbacks: Optional[List], optional
         :raises ValueError: If the template is not valid as template should
@@ -132,7 +140,8 @@ class BaseLlmConfig(BaseConfig):
         self.query_type = query_type
         self.callbacks = callbacks
         self.api_key = api_key
-
+        self.endpoint = endpoint
+        self.model_kwargs = model_kwargs
         if type(prompt) is str:
             prompt = Template(prompt)
 

+ 20 - 0
embedchain/llm/huggingface.py

@@ -3,6 +3,7 @@ import logging
 import os
 from typing import Optional
 
+from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain.llms.huggingface_hub import HuggingFaceHub
 
 from embedchain.config import BaseLlmConfig
@@ -33,6 +34,15 @@ class HuggingFaceLlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        if config.model:
+            return HuggingFaceLlm._from_model(prompt=prompt, config=config)
+        elif config.endpoint:
+            return HuggingFaceLlm._from_endpoint(prompt=prompt, config=config)
+        else:
+            raise ValueError("Either `model` or `endpoint` must be set")
+
+    @staticmethod
+    def _from_model(prompt: str, config: BaseLlmConfig) -> str:
         model_kwargs = {
             "temperature": config.temperature or 0.1,
             "max_new_tokens": config.max_tokens,
@@ -52,3 +62,13 @@ class HuggingFaceLlm(BaseLlm):
         )
 
         return llm(prompt)
+
+    @staticmethod
+    def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str:
+        llm = HuggingFaceEndpoint(
+            huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
+            endpoint_url=config.endpoint,
+            task="text-generation",
+            model_kwargs=config.model_kwargs,
+        )
+        return llm(prompt)

+ 1 - 0
embedchain/utils/misc.py

@@ -415,6 +415,7 @@ def validate_config(config_data):
                     Optional("where"): dict,
                     Optional("query_type"): str,
                     Optional("api_key"): str,
+                    Optional("endpoint"): str,
                 },
             },
             Optional("vectordb"): {

+ 19 - 0
tests/llm/test_huggingface.py

@@ -15,6 +15,14 @@ def huggingface_llm_config():
     os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
 
 
+@pytest.fixture
+def huggingface_endpoint_config():
+    os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
+    config = BaseLlmConfig(endpoint="https://api-inference.huggingface.co/models/gpt2", model_kwargs={"device": "cpu"})
+    yield config
+    os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
+
+
 def test_init_raises_value_error_without_api_key(mocker):
     mocker.patch.dict(os.environ, clear=True)
     with pytest.raises(ValueError):
@@ -61,3 +69,14 @@ def test_hugging_face_mock(huggingface_llm_config, mocker):
 
     assert answer == "Test answer"
     mock_llm_instance.assert_called_once_with("Test query")
+
+
+def test_custom_endpoint(huggingface_endpoint_config, mocker):
+    mock_llm_instance = mocker.Mock(return_value="Test answer")
+    mocker.patch("embedchain.llm.huggingface.HuggingFaceEndpoint", return_value=mock_llm_instance)
+
+    llm = HuggingFaceLlm(huggingface_endpoint_config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mock_llm_instance.assert_called_once_with("Test query")