@@ -1,5 +1,8 @@
 from typing import Iterable, Optional, Union
 
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.stdout import StdOutCallbackHandler
+
 from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -12,6 +15,7 @@ class GPT4ALLLlm(BaseLlm):
         if self.config.model is None:
             self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
         self.instance = GPT4ALLLlm._get_instance(self.config.model)
+        self.instance.streaming = self.config.stream
 
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
@@ -19,13 +23,13 @@ class GPT4ALLLlm(BaseLlm):
     @staticmethod
     def _get_instance(model):
         try:
-            from gpt4all import GPT4All
+            from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
 
-        return GPT4All(model_name=model)
+        return LangchainGPT4All(model=model, allow_download=True)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         if config.model and config.model != self.config.model:
@@ -33,14 +37,25 @@ class GPT4ALLLlm(BaseLlm):
                 "GPT4ALLLlm does not support switching models at runtime. Please create a new app instance."
             )
 
+        messages = []
         if config.system_prompt:
-            raise ValueError("GPT4ALLLlm does not support `system_prompt`")
-
-        response = self.instance.generate(
-            prompt=prompt,
-            streaming=config.stream,
-            top_p=config.top_p,
-            max_tokens=config.max_tokens,
-            temp=config.temperature,
-        )
-        return response
+            messages.append(config.system_prompt)
+        messages.append(prompt)
+        kwargs = {
+            "temp": config.temperature,
+            "max_tokens": config.max_tokens,
+        }
+        if config.top_p:
+            kwargs["top_p"] = config.top_p
+
+        callbacks = None
+        if config.stream:
+            callbacks = [StreamingStdOutCallbackHandler()]
+        else:
+            callbacks = [StdOutCallbackHandler()]
+
+        response = self.instance.generate(prompts=messages, callbacks=callbacks, **kwargs)
+        answer = ""
+        for generations in response.generations:
+            answer += " ".join(map(lambda generation: generation.text, generations))
+        return answer
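A minimal usage sketch of the new code path follows. It assumes the module lives at `embedchain.llm.gpt4all` and that `BaseLlmConfig` accepts the `model`, `temperature`, `max_tokens`, `stream`, and `system_prompt` attributes read in this diff as constructor keyword arguments; check the config class for the exact signature. With `stream=True`, the langchain `GPT4All` wrapper echoes tokens to stdout through `StreamingStdOutCallbackHandler`, and `_get_answer` still returns the joined generation text.

# Usage sketch only; the module path and constructor keywords are assumptions, not confirmed by this diff.
from embedchain.config import BaseLlmConfig
from embedchain.llm.gpt4all import GPT4ALLLlm  # assumed module path

config = BaseLlmConfig(
    model="orca-mini-3b.ggmlv3.q4_0.bin",  # fetched on first use, since allow_download=True
    temperature=0.2,
    max_tokens=256,
    stream=True,  # tokens are printed as they arrive via StreamingStdOutCallbackHandler
    system_prompt="Answer concisely.",  # now passed as an extra prompt instead of raising
)

llm = GPT4ALLLlm(config=config)
answer = llm.get_llm_model_answer("What is GPT4All?")
print(answer)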