# gpt4all.py
  1. import os
  2. from collections.abc import Iterable
  3. from pathlib import Path
  4. from typing import Optional, Union
  5. from langchain.callbacks.stdout import StdOutCallbackHandler
  6. from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
  7. from embedchain.config import BaseLlmConfig
  8. from embedchain.helpers.json_serializable import register_deserializable
  9. from embedchain.llm.base import BaseLlm
  10. @register_deserializable
  11. class GPT4ALLLlm(BaseLlm):
  12. def __init__(self, config: Optional[BaseLlmConfig] = None):
  13. super().__init__(config=config)
  14. if self.config.model is None:
  15. self.config.model = "orca-mini-3b-gguf2-q4_0.gguf"
  16. self.instance = GPT4ALLLlm._get_instance(self.config.model)
  17. self.instance.streaming = self.config.stream
  18. def get_llm_model_answer(self, prompt):
  19. return self._get_answer(prompt=prompt, config=self.config)
  20. @staticmethod
  21. def _get_instance(model):
  22. try:
  23. from langchain_community.llms.gpt4all import \
  24. GPT4All as LangchainGPT4All
  25. except ModuleNotFoundError:
  26. raise ModuleNotFoundError(
  27. "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
  28. ) from None
  29. model_path = Path(model).expanduser()
  30. if os.path.isabs(model_path):
  31. if os.path.exists(model_path):
  32. return LangchainGPT4All(model=str(model_path))
  33. else:
  34. raise ValueError(f"Model does not exist at {model_path=}")
  35. else:
  36. return LangchainGPT4All(model=model, allow_download=True)
  37. def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
  38. if config.model and config.model != self.config.model:
  39. raise RuntimeError(
  40. "GPT4ALLLlm does not support switching models at runtime. Please create a new app instance."
  41. )
  42. messages = []
  43. if config.system_prompt:
  44. messages.append(config.system_prompt)
  45. messages.append(prompt)
  46. kwargs = {
  47. "temp": config.temperature,
  48. "max_tokens": config.max_tokens,
  49. }
  50. if config.top_p:
  51. kwargs["top_p"] = config.top_p
  52. callbacks = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
  53. response = self.instance.generate(prompts=messages, callbacks=callbacks, **kwargs)
  54. answer = ""
  55. for generations in response.generations:
  56. answer += " ".join(map(lambda generation: generation.text, generations))
  57. return answer