gpt4all.py

from typing import Iterable, Optional, Union

from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class GPT4ALLLlm(BaseLlm):
    """LLM provider backed by a local GPT4All model."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)
        # Fall back to the default orca-mini model when none is configured.
        if self.config.model is None:
            self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
        self.instance = GPT4ALLLlm._get_instance(self.config.model)

    def get_llm_model_answer(self, prompt):
        return self._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_instance(model):
        # Import lazily so the optional `gpt4all` dependency is only required here.
        try:
            from gpt4all import GPT4All
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa: E501
            ) from None

        return GPT4All(model_name=model)

    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        # The loaded GPT4All instance is bound to a single model; switching
        # models requires constructing a new app instance.
        if config.model and config.model != self.config.model:
            raise RuntimeError(
                "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
            )

        if config.system_prompt:
            raise ValueError("OpenSourceApp does not support `system_prompt`")

        # Returns a plain string, or an iterator of tokens when streaming is enabled.
        response = self.instance.generate(
            prompt=prompt,
            streaming=config.stream,
            top_p=config.top_p,
            max_tokens=config.max_tokens,
            temp=config.temperature,
        )
        return response
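
For reference, a minimal usage sketch of driving this class directly, separate from the file above. It assumes `BaseLlmConfig` accepts the `temperature`, `max_tokens`, `top_p`, and `stream` keyword arguments mirrored in `_get_answer`; the prompt string is illustrative.

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.gpt4all import GPT4ALLLlm

    # Assumed BaseLlmConfig keyword arguments, matching those read in _get_answer.
    config = BaseLlmConfig(temperature=0.2, max_tokens=256, top_p=0.9, stream=False)
    llm = GPT4ALLLlm(config=config)  # downloads/loads the GGML model on first use
    print(llm.get_llm_model_answer("What is embedchain?"))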