
[Improvement] Add support for gpt4all through langchain (#838)

Deven Patel, 1 year ago
Commit 0f8a2e624a
8 changed files with 36 additions and 29 deletions
  1. configs/gpt4all.yaml (+1, -3)
  2. configs/opensource.yaml (+1, -2)
  3. docs/components/embedding-models.mdx (+0, -2)
  4. docs/components/llms.mdx (+0, -2)
  5. embedchain/embedder/gpt4all.py (+5, -6)
  6. embedchain/llm/gpt4all.py (+27, -12)
  7. pyproject.toml (+1, -1)
  8. tests/apps/test_apps.py (+1, -1)

+ 1 - 3
configs/gpt4all.yaml

@@ -1,7 +1,7 @@
 llm:
   provider: gpt4all
-  model: 'orca-mini-3b.ggmlv3.q4_0.bin'
   config:
+    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
     temperature: 0.5
     max_tokens: 1000
     top_p: 1
@@ -9,5 +9,3 @@ llm:
 
 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
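
The net effect of this hunk: `model` moves under the llm `config` block, and the embedder no longer takes a model override (it falls back to langchain's default GPT4All embedding). A minimal sketch of consuming the updated config, assuming the `App.from_config` YAML loader embedchain shipped around this release; the data source URL and query are illustrative:

```python
# Minimal sketch, assuming embedchain's App.from_config YAML loader.
from embedchain import App

app = App.from_config("configs/gpt4all.yaml")
app.add("https://en.wikipedia.org/wiki/GPT4All")  # hypothetical source
print(app.query("What is GPT4All?"))
```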

+ 1 - 2
configs/opensource.yaml

@@ -6,8 +6,8 @@ app:
 
 llm:
   provider: gpt4all
-  model: 'orca-mini-3b.ggmlv3.q4_0.bin'
   config:
+    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
     temperature: 0.5
     max_tokens: 1000
     top_p: 1
@@ -23,5 +23,4 @@ vectordb:
 embedder:
   provider: gpt4all
   config:
-    model: 'all-MiniLM-L6-v2'
     deployment_name: null

+ 0 - 2
docs/components/embedding-models.mdx

@@ -108,8 +108,6 @@ llm:
 
 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
 ```
 
 </CodeGroup>

+ 0 - 2
docs/components/llms.mdx

@@ -198,8 +198,6 @@ llm:
 
 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
 ```
 </CodeGroup>
 

+ 5 - 6
embedchain/embedder/gpt4all.py

@@ -1,7 +1,5 @@
 from typing import Optional
 
-from chromadb.utils import embedding_functions
-
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
@@ -9,12 +7,13 @@ from embedchain.models import VectorDimensions
 
 class GPT4AllEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
-        # Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
         super().__init__(config=config)
-        if self.config.model is None:
-            self.config.model = "all-MiniLM-L6-v2"
 
-        embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
+        from langchain.embeddings import \
+            GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
+
+        embeddings = LangchainGPT4AllEmbeddings()
+        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)
 
         vector_dimension = VectorDimensions.GPT4ALL.value

+ 27 - 12
embedchain/llm/gpt4all.py

@@ -1,5 +1,8 @@
 from typing import Iterable, Optional, Union
 
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.stdout import StdOutCallbackHandler
+
 from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -12,6 +15,7 @@ class GPT4ALLLlm(BaseLlm):
         if self.config.model is None:
             self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
         self.instance = GPT4ALLLlm._get_instance(self.config.model)
+        self.instance.streaming = config.stream
 
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
@@ -19,13 +23,13 @@ class GPT4ALLLlm(BaseLlm):
     @staticmethod
     def _get_instance(model):
         try:
-            from gpt4all import GPT4All
+            from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
 
-        return GPT4All(model_name=model)
+        return LangchainGPT4All(model=model, allow_download=True)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         if config.model and config.model != self.config.model:
@@ -33,14 +37,25 @@ class GPT4ALLLlm(BaseLlm):
                 "GPT4ALLLlm does not support switching models at runtime. Please create a new app instance."
             )
 
+        messages = []
         if config.system_prompt:
-            raise ValueError("GPT4ALLLlm does not support `system_prompt`")
-
-        response = self.instance.generate(
-            prompt=prompt,
-            streaming=config.stream,
-            top_p=config.top_p,
-            max_tokens=config.max_tokens,
-            temp=config.temperature,
-        )
-        return response
+            messages.append(config.system_prompt)
+        messages.append(prompt)
+        kwargs = {
+            "temp": config.temperature,
+            "max_tokens": config.max_tokens,
+        }
+        if config.top_p:
+            kwargs["top_p"] = config.top_p
+
+        callbacks = None
+        if config.stream:
+            callbacks = [StreamingStdOutCallbackHandler()]
+        else:
+            callbacks = [StdOutCallbackHandler()]
+
+        response = self.instance.generate(prompts=messages, callbacks=callbacks, **kwargs)
+        answer = ""
+        for generations in response.generations:
+            answer += " ".join(map(lambda generation: generation.text, generations))
+        return answer
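
The rewritten `_get_answer` no longer rejects `system_prompt`; it prepends it as an extra entry in the `prompts` list, forwards the sampling kwargs, and attaches a stdout callback whether or not streaming is on. The same call pattern, sketched outside embedchain with illustrative parameter values:

```python
# Sketch of the call pattern used above, with illustrative values.
# temp / max_tokens are gpt4all backend parameters that langchain
# forwards through generate(**kwargs), as the patched code does.
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.gpt4all import GPT4All

llm = GPT4All(model="orca-mini-3b.ggmlv3.q4_0.bin", allow_download=True)
result = llm.generate(
    prompts=["Explain GPT4All in one sentence."],
    callbacks=[StreamingStdOutCallbackHandler()],  # stream tokens to stdout
    temp=0.5,
    max_tokens=1000,
)
# generate returns an LLMResult; flatten the nested generations the same
# way the patched code does.
answer = " ".join(g.text for gens in result.generations for g in gens)
print(answer)
```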

+ 1 - 1
pyproject.toml

@@ -143,7 +143,7 @@ pytest-asyncio = "^0.21.1"
 [tool.poetry.extras]
 streamlit = ["streamlit"]
 community = ["llama-hub"]
-opensource = ["sentence-transformers", "torch", "gpt4all"]
+opensource = ["sentence-transformers", "torch", "gpt4all", "langchain"]
 elasticsearch = ["elasticsearch"]
 opensearch = ["opensearch-py"]
 poe = ["fastapi-poe"]

+ 1 - 1
tests/apps/test_apps.py

@@ -135,6 +135,7 @@ class TestAppFromConfig:
 
         # Validate the LLM config values
         llm_config = config_data["llm"]["config"]
+        assert app.llm.config.model == llm_config["model"]
         assert app.llm.config.temperature == llm_config["temperature"]
         assert app.llm.config.max_tokens == llm_config["max_tokens"]
         assert app.llm.config.top_p == llm_config["top_p"]
@@ -148,5 +149,4 @@ class TestAppFromConfig:
 
         # Validate the Embedder config values
         embedder_config = config_data["embedder"]["config"]
-        assert app.embedder.config.model == embedder_config["model"]
         assert app.embedder.config.deployment_name == embedder_config["deployment_name"]