# clarifai.py — Clarifai LLM backend for embedchain
  1. import logging
  2. import os
  3. from typing import Optional
  4. from embedchain.config import BaseLlmConfig
  5. from embedchain.helpers.json_serializable import register_deserializable
  6. from embedchain.llm.base import BaseLlm
  7. @register_deserializable
  8. class ClarifaiLlm(BaseLlm):
  9. def __init__(self, config: Optional[BaseLlmConfig] = None):
  10. super().__init__(config=config)
  11. if not self.config.api_key and "CLARIFAI_PAT" not in os.environ:
  12. raise ValueError("Please set the CLARIFAI_PAT environment variable.")
  13. def get_llm_model_answer(self, prompt):
  14. return self._get_answer(prompt=prompt, config=self.config)
  15. @staticmethod
  16. def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
  17. try:
  18. from clarifai.client.model import Model
  19. except ModuleNotFoundError:
  20. raise ModuleNotFoundError(
  21. "The required dependencies for Clarifai are not installed."
  22. "Please install with `pip install clarifai==10.0.1`"
  23. ) from None
  24. model_name = config.model
  25. logging.info(f"Using clarifai LLM model: {model_name}")
  26. api_key = config.api_key or os.getenv("CLARIFAI_PAT")
  27. model = Model(url=model_name, pat=api_key)
  28. params = config.model_kwargs
  29. try:
  30. (params := {}) if config.model_kwargs is None else config.model_kwargs
  31. predict_response = model.predict_by_bytes(
  32. bytes(prompt, "utf-8"),
  33. input_type="text",
  34. inference_params=params,
  35. )
  36. text = predict_response.outputs[0].data.text.raw
  37. return text
  38. except Exception as e:
  39. logging.error(f"Predict failed, exception: {e}")