Преглед на файлове

[Bugfix] fix cache session id in chat method (#1107)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel преди 1 година
родител
ревизия
ae2e9cb890
променени са 2 файла, в които са добавени 12 реда и са изтрити 2 реда
  1. 2 1
      embedchain/embedchain.py
  2. 10 1
      embedchain/llm/gpt4all.py

+ 2 - 1
embedchain/embedchain.py

@@ -612,11 +612,12 @@ class EmbedChain(JSONSerializable):
 
         if self.cache_config is not None:
             logging.info("Cache enabled. Checking cache...")
+            cache_id = f"{session_id}--{self.config.id}"
             answer = adapt(
                 llm_handler=self.llm.chat,
                 cache_data_convert=gptcache_data_convert,
                 update_cache_callback=gptcache_update_cache_callback,
-                session=get_gptcache_session(session_id=self.config.id),
+                session=get_gptcache_session(session_id=cache_id),
                 input_query=input_query,
                 contexts=contexts_data_for_llm_query,
                 config=config,

+ 10 - 1
embedchain/llm/gpt4all.py

@@ -1,3 +1,5 @@
+import os
+from pathlib import Path
 from typing import Iterable, Optional, Union
 
 from langchain.callbacks.stdout import StdOutCallbackHandler
@@ -29,7 +31,14 @@ class GPT4ALLLlm(BaseLlm):
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
 
-        return LangchainGPT4All(model=model)
+        model_path = Path(model).expanduser()
+        if os.path.isabs(model_path):
+            if os.path.exists(model_path):
+                return LangchainGPT4All(model=str(model_path))
+            else:
+                raise ValueError(f"Model does not exist at {model_path=}")
+        else:
+            return LangchainGPT4All(model=model, allow_download=True)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         if config.model and config.model != self.config.model: