# huggingface.py

import importlib
import os
from typing import Optional

from langchain.llms import HuggingFaceHub

from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class HuggingFaceLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        # Fail fast if the Hugging Face API token is missing from the environment.
        if "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
            raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable.")

        # Verify that the optional `huggingface_hub` dependency is installed.
        try:
            importlib.import_module("huggingface_hub")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for HuggingFaceHub are not installed. "
                'Please install with `pip install --upgrade "embedchain[huggingface_hub]"`'
            ) from None

        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        # The Hub inference API used here has no separate system-prompt slot.
        if self.config.system_prompt:
            raise ValueError("HuggingFaceLlm does not support `system_prompt`")
        return HuggingFaceLlm._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
        # Translate the generic embedchain config into Hugging Face generation kwargs.
        model_kwargs = {
            "temperature": config.temperature or 0.1,
            "max_new_tokens": config.max_tokens,
        }

        # `top_p` must lie strictly inside (0, 1) to be a valid nucleus-sampling value.
        if 0.0 < config.top_p < 1.0:
            model_kwargs["top_p"] = config.top_p
        else:
            raise ValueError("`top_p` must be > 0.0 and < 1.0")

        llm = HuggingFaceHub(
            huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
            repo_id=config.model or "google/flan-t5-xxl",
            model_kwargs=model_kwargs,
        )
        return llm(prompt)
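

# Usage sketch, not part of the original module: a minimal example of wiring this
# class up, assuming BaseLlmConfig accepts `model`, `temperature`, `max_tokens`,
# and `top_p` keyword arguments (check your embedchain version if it differs).
if __name__ == "__main__":
    # Requires a valid Hugging Face token in the environment, e.g.
    #   export HUGGINGFACE_ACCESS_TOKEN="hf_..."
    example_config = BaseLlmConfig(
        model="google/flan-t5-xxl",  # same repo the class falls back to by default
        temperature=0.2,
        max_tokens=256,
        top_p=0.9,  # must be strictly between 0.0 and 1.0, or _get_answer raises
    )
    llm = HuggingFaceLlm(config=example_config)
    print(llm.get_llm_model_answer("What is the Hugging Face Hub?"))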