gpt4all.py

import os
from collections.abc import Iterable
from pathlib import Path
from typing import Optional, Union

from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class GPT4ALLLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)
        # Fall back to the default GGUF model if none was configured.
        if self.config.model is None:
            self.config.model = "orca-mini-3b-gguf2-q4_0.gguf"
        self.instance = GPT4ALLLlm._get_instance(self.config.model)
        self.instance.streaming = self.config.stream

    def get_llm_model_answer(self, prompt):
        return self._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_instance(model):
        try:
            from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
            ) from None

        # Absolute paths must point at an existing model file; bare model
        # names are handed to GPT4All, which downloads them on demand.
        model_path = Path(model).expanduser()
        if os.path.isabs(model_path):
            if os.path.exists(model_path):
                return LangchainGPT4All(model=str(model_path))
            else:
                raise ValueError(f"Model does not exist at {model_path=}")
        else:
            return LangchainGPT4All(model=model, allow_download=True)

    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        if config.model and config.model != self.config.model:
            raise RuntimeError(
                "GPT4ALLLlm does not support switching models at runtime. Please create a new app instance."
            )

        # GPT4All takes a flat list of prompt strings; prepend the system
        # prompt when one is configured.
        messages = []
        if config.system_prompt:
            messages.append(config.system_prompt)
        messages.append(prompt)
        kwargs = {
            "temp": config.temperature,
            "max_tokens": config.max_tokens,
        }
        if config.top_p:
            kwargs["top_p"] = config.top_p
        callbacks = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
        response = self.instance.generate(prompts=messages, callbacks=callbacks, **kwargs)
        # Concatenate the text of every generation into a single answer.
        answer = ""
        for generations in response.generations:
            answer += " ".join(generation.text for generation in generations)
        return answer
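
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving this class directly, assuming BaseLlmConfig
# accepts `model`, `temperature`, `max_tokens`, `top_p`, and `stream`
# keyword arguments, and that the opensource extra is installed
# (`pip install --upgrade "embedchain[opensource]"`). Passing a bare model
# name (rather than an absolute path) triggers a download on first use.
if __name__ == "__main__":
    config = BaseLlmConfig(
        model="orca-mini-3b-gguf2-q4_0.gguf",
        temperature=0.2,
        max_tokens=256,
        top_p=0.9,
        stream=False,
    )
    llm = GPT4ALLLlm(config=config)
    print(llm.get_llm_model_answer("What is embedchain?"))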