from typing import Iterable, Optional, Union

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class GPT4ALLLlm(BaseLlm):
    """LLM backend that answers prompts with a locally-run GPT4All model."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize the backend, falling back to the default orca-mini model.

        :param config: optional LLM configuration; model name, sampling
            parameters and streaming flag are read from it.
        """
        super().__init__(config=config)
        # No model configured: use the default quantized orca-mini build.
        if self.config.model is None:
            self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
        # The underlying client is bound to exactly one model for the
        # lifetime of this instance (see the guard in _get_gpt4all_answer).
        self.instance = GPT4ALLLlm._get_instance(self.config.model)

    def get_llm_model_answer(self, prompt):
        """Return the model's answer for ``prompt`` using this app's config."""
        return self._get_gpt4all_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_instance(model):
        """Construct a ``GPT4All`` client for ``model``.

        The import is deferred so the package is only required when this
        backend is actually used; a missing install raises a
        ``ModuleNotFoundError`` with installation instructions.
        """
        try:
            from gpt4all import GPT4All
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`"  # noqa E501
            ) from None

        return GPT4All(model_name=model)

    def _get_gpt4all_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        """Generate a completion for ``prompt``.

        :param prompt: the fully-rendered prompt text.
        :param config: per-call configuration; must reference the same model
            this instance was built with.
        :return: the generated text, or an iterator of tokens when
            ``config.stream`` is set.
        :raises RuntimeError: if ``config`` names a different model.
        :raises ValueError: if ``config.system_prompt`` is set (unsupported).
        """
        # Guard first: the gpt4all client cannot switch models after creation.
        if config.model and config.model != self.config.model:
            raise RuntimeError(
                "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
            )
        if config.system_prompt:
            raise ValueError("OpenSourceApp does not support `system_prompt`")
        return self.instance.generate(
            prompt=prompt,
            streaming=config.stream,
            top_p=config.top_p,
            max_tokens=config.max_tokens,
            temp=config.temperature,
        )