|
@@ -1,3 +1,5 @@
|
|
|
+import os
|
|
|
+from pathlib import Path
|
|
|
from typing import Iterable, Optional, Union
|
|
|
|
|
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
|
@@ -29,7 +31,14 @@ class GPT4ALLLlm(BaseLlm):
|
|
|
"The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
|
|
|
) from None
|
|
|
|
|
|
- return LangchainGPT4All(model=model)
|
|
|
+ model_path = Path(model).expanduser()
|
|
|
+ if os.path.isabs(model_path):
|
|
|
+ if os.path.exists(model_path):
|
|
|
+ return LangchainGPT4All(model=str(model_path))
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Model does not exist at {model_path=}")
|
|
|
+ else:
|
|
|
+ return LangchainGPT4All(model=model, allow_download=True)
|
|
|
|
|
|
def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
|
|
|
if config.model and config.model != self.config.model:
|