Add support for loading api_key from config or env variable (#1421)

Dev Khant, 1 year ago
parent commit 2855f1635b
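
For context, here is a minimal usage sketch of the new behavior (assuming BaseLlmConfig is importable from embedchain.config and its constructor accepts an api_key argument, as this diff implies; the model name and key value below are placeholders):

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.anthropic import AnthropicLlm

    # Pass the key directly in the config; the environment variable is no longer required.
    config = BaseLlmConfig(model="claude-instant-1", temperature=0.5, api_key="sk-ant-placeholder")
    llm = AnthropicLlm(config=config)

    # Or keep the old behavior: leave api_key unset and export ANTHROPIC_API_KEY instead.

The same pattern applies to every provider touched by this commit.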

+ 4 - 5
embedchain/llm/anthropic.py

@@ -17,18 +17,17 @@ logger = logging.getLogger(__name__)
 @register_deserializable
 class AnthropicLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "ANTHROPIC_API_KEY" not in os.environ:
-            raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
         super().__init__(config=config)
+        if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ:
+            raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        chat = ChatAnthropic(
-            anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model_name=config.model
-        )
+        api_key = config.api_key or os.getenv("ANTHROPIC_API_KEY")
+        chat = ChatAnthropic(anthropic_api_key=api_key, temperature=config.temperature, model_name=config.model)
 
         if config.max_tokens and config.max_tokens != 1000:
             logger.warning("Config option `max_tokens` is not supported by this model.")

+ 4 - 4
embedchain/llm/cohere.py

@@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class CohereLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "COHERE_API_KEY" not in os.environ:
-            raise ValueError("Please set the COHERE_API_KEY environment variable.")
-
         try:
             importlib.import_module("cohere")
         except ModuleNotFoundError:
@@ -24,6 +21,8 @@ class CohereLlm(BaseLlm):
             ) from None
 
         super().__init__(config=config)
+        if not self.config.api_key and "COHERE_API_KEY" not in os.environ:
+            raise ValueError("Please set the COHERE_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:
@@ -32,8 +31,9 @@ class CohereLlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        api_key = config.api_key or os.getenv("COHERE_API_KEY")
         llm = Cohere(
-            cohere_api_key=os.environ["COHERE_API_KEY"],
+            cohere_api_key=api_key,
             model=config.model,
             max_tokens=config.max_tokens,
             temperature=config.temperature,

+ 5 - 4
embedchain/llm/google.py

@@ -16,9 +16,6 @@ logger = logging.getLogger(__name__)
 @register_deserializable
 class GoogleLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "GOOGLE_API_KEY" not in os.environ:
-            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
-
         try:
             importlib.import_module("google.generativeai")
         except ModuleNotFoundError:
@@ -28,7 +25,11 @@ class GoogleLlm(BaseLlm):
             ) from None
 
         super().__init__(config)
-        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+        if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ:
+            raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.")
+
+        api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
+        genai.configure(api_key=api_key)
 
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:

+ 2 - 0
embedchain/llm/groq.py

@@ -19,6 +19,8 @@ from embedchain.llm.base import BaseLlm
 class GroqLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)
+        if not self.config.api_key and "GROQ_API_KEY" not in os.environ:
+            raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt) -> str:
         response = self._get_answer(prompt, self.config)

+ 6 - 5
embedchain/llm/huggingface.py

@@ -17,9 +17,6 @@ logger = logging.getLogger(__name__)
 @register_deserializable
 class HuggingFaceLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
-            raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable.")
-
         try:
             importlib.import_module("huggingface_hub")
         except ModuleNotFoundError:
@@ -29,6 +26,8 @@ class HuggingFaceLlm(BaseLlm):
             ) from None
 
         super().__init__(config=config)
+        if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
+            raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:
@@ -60,9 +59,10 @@ class HuggingFaceLlm(BaseLlm):
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
 
         model = config.model
+        api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
         logger.info(f"Using HuggingFaceHub with model {model}")
         llm = HuggingFaceHub(
-            huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
+            huggingfacehub_api_token=api_key,
             repo_id=model,
             model_kwargs=model_kwargs,
         )
@@ -70,8 +70,9 @@ class HuggingFaceLlm(BaseLlm):
 
     @staticmethod
     def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str:
+        api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
         llm = HuggingFaceEndpoint(
-            huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
+            huggingfacehub_api_token=api_key,
             endpoint_url=config.endpoint,
             task="text-generation",
             model_kwargs=config.model_kwargs,

+ 4 - 4
embedchain/llm/jina.py

@@ -12,9 +12,9 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class JinaLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "JINACHAT_API_KEY" not in os.environ:
-            raise ValueError("Please set the JINACHAT_API_KEY environment variable.")
         super().__init__(config=config)
+        if not self.config.api_key and "JINACHAT_API_KEY" not in os.environ:
+            raise ValueError("Please set the JINACHAT_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         response = JinaLlm._get_answer(prompt, self.config)
@@ -29,13 +29,13 @@ class JinaLlm(BaseLlm):
         kwargs = {
             "temperature": config.temperature,
             "max_tokens": config.max_tokens,
+            "jinachat_api_key": config.api_key or os.environ["JINACHAT_API_KEY"],
             "model_kwargs": {},
         }
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
         else:

+ 4 - 2
embedchain/llm/llama2.py

@@ -19,8 +19,6 @@ class Llama2Llm(BaseLlm):
                 "The required dependencies for Llama2 are not installed."
                 'Please install with `pip install --upgrade "embedchain[llama2]"`'
             ) from None
-        if "REPLICATE_API_TOKEN" not in os.environ:
-            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
 
         # Set default config values specific to this llm
         if not config:
@@ -35,13 +33,17 @@ class Llama2Llm(BaseLlm):
             )
 
         super().__init__(config=config)
+        if not self.config.api_key and "REPLICATE_API_TOKEN" not in os.environ:
+            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         # TODO: Move the model and other inputs into config
         if self.config.system_prompt:
             raise ValueError("Llama2 does not support `system_prompt`")
+        api_key = self.config.api_key or os.getenv("REPLICATE_API_TOKEN")
         llm = Replicate(
             model=self.config.model,
+            replicate_api_token=api_key,
             input={
                 "temperature": self.config.temperature,
                 "max_length": self.config.max_tokens,

+ 3 - 4
embedchain/llm/nvidia.py

@@ -21,10 +21,9 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class NvidiaLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "NVIDIA_API_KEY" not in os.environ:
-            raise ValueError("NVIDIA_API_KEY environment variable must be set")
-
         super().__init__(config=config)
+        if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ:
+            raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
@@ -34,7 +33,7 @@ class NvidiaLlm(BaseLlm):
         callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
         model_kwargs = config.model_kwargs or {}
         labels = model_kwargs.get("labels", None)
-        params = {"model": config.model}
+        params = {"model": config.model, "nvidia_api_key": config.api_key or os.getenv("NVIDIA_API_KEY")}
         if config.system_prompt:
             params["system_prompt"] = config.system_prompt
         if config.temperature:

+ 4 - 4
embedchain/llm/together.py

@@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class TogetherLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "TOGETHER_API_KEY" not in os.environ:
-            raise ValueError("Please set the TOGETHER_API_KEY environment variable.")
-
         try:
             importlib.import_module("together")
         except ModuleNotFoundError:
@@ -24,6 +21,8 @@ class TogetherLlm(BaseLlm):
             ) from None
 
         super().__init__(config=config)
+        if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ:
+            raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:
@@ -32,8 +31,9 @@ class TogetherLlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        api_key = config.api_key or os.getenv("TOGETHER_API_KEY")
         llm = Together(
-            together_api_key=os.environ["TOGETHER_API_KEY"],
+            together_api_key=api_key,
             model=config.model,
             max_tokens=config.max_tokens,
             temperature=config.temperature,

+ 1 - 0
tests/llm/test_jina.py

@@ -74,5 +74,6 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
     mocked_jinachat.assert_called_once_with(
         temperature=config.temperature,
         max_tokens=config.max_tokens,
+        jinachat_api_key=os.environ["JINACHAT_API_KEY"],
         model_kwargs={"top_p": config.top_p},
     )
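
Across all providers changed here the precedence is the same: an explicit config.api_key wins, otherwise the provider-specific environment variable is read. A hedged sketch using the Jina classes from the test above (key values are placeholders; assumes BaseLlmConfig accepts api_key):

    import os

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.jina import JinaLlm

    os.environ["JINACHAT_API_KEY"] = "env-key"  # placeholder

    # The key passed in the config takes precedence over the environment variable.
    llm_from_config = JinaLlm(config=BaseLlmConfig(api_key="cfg-key"))

    # With no config key, the constructor falls back to JINACHAT_API_KEY.
    llm_from_env = JinaLlm()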