
[Feature] add google ai embedder (#1019)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 year ago
parent
commit
c0b5e93967

+ 5 - 0
configs/google.yaml

@@ -6,3 +6,8 @@ llm:
     temperature: 0.9
     top_p: 1.0
     stream: false
+
+embedder:
+  provider: google
+  config:
+    model: models/embedding-001

+ 29 - 0
docs/components/embedding-models.mdx

@@ -8,6 +8,7 @@ Embedchain supports several embedding models from the following providers:
 
 <CardGroup cols={4}>
   <Card title="OpenAI" href="#openai"></Card>
+  <Card title="GoogleAI" href="#google-ai"></Card>
   <Card title="Azure OpenAI" href="#azure-openai"></Card>
   <Card title="GPT4All" href="#gpt4all"></Card>
   <Card title="Hugging Face" href="#hugging-face"></Card>
@@ -44,6 +45,34 @@ embedder:
 
 </CodeGroup>
 
+## Google AI
+
+To use the Google AI embedding function, you have to set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from the [Google Maker Suite](https://makersuite.google.com/app/apikey).
+
+<CodeGroup>
+```python main.py
+import os
+from embedchain import Pipeline as App
+
+os.environ["GOOGLE_API_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+```
+
+```yaml config.yaml
+embedder:
+  provider: google
+  config:
+    model: 'models/embedding-001'
+    task_type: "retrieval_document"
+    title: "Embeddings for Embedchain"
+```
+</CodeGroup>
+<br/>
+<Note>
+For more details regarding the Google AI embedding model, please refer to the [Google AI documentation](https://ai.google.dev/tutorials/python_quickstart#use_embeddings).
+</Note>
+
 ## Azure OpenAI
 
 To use Azure OpenAI embedding model, you have to set some of the azure openai related environment variables as given in the code block below:

+ 7 - 1
docs/components/llms.mdx

@@ -72,7 +72,6 @@ To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variabl
 import os
 from embedchain import Pipeline as App
 
-os.environ["OPENAI_API_KEY"] = "sk-xxxx"
 os.environ["GOOGLE_API_KEY"] = "xxx"
 
 app = App.from_config(config_path="config.yaml")
@@ -96,6 +95,13 @@ llm:
     temperature: 0.5
     top_p: 1
     stream: false
+
+embedder:
+  provider: google
+  config:
+    model: 'models/embedding-001'
+    task_type: "retrieval_document"
+    title: "Embeddings for Embedchain"
 ```
 </CodeGroup>
 

+ 18 - 0
embedchain/config/embedder/google.py

@@ -0,0 +1,18 @@
+from typing import Optional
+
+from embedchain.config.embedder.base import BaseEmbedderConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class GoogleAIEmbedderConfig(BaseEmbedderConfig):
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        deployment_name: Optional[str] = None,
+        task_type: Optional[str] = None,
+        title: Optional[str] = None,
+    ):
+        super().__init__(model, deployment_name)
+        self.task_type = task_type or "retrieval_document"
+        self.title = title or "Embeddings for Embedchain"

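The new `GoogleAIEmbedderConfig` falls back to `retrieval_document` and `"Embeddings for Embedchain"` when `task_type` and `title` are omitted. A minimal sketch, not part of the diff; the `retrieval_query` value is only an illustrative alternative task type:

```python
from embedchain.config.embedder.google import GoogleAIEmbedderConfig

# Defaults kick in when only the model is provided.
config = GoogleAIEmbedderConfig(model="models/embedding-001")
assert config.task_type == "retrieval_document"
assert config.title == "Embeddings for Embedchain"

# Overriding the defaults; "retrieval_query" is an illustrative alternative.
custom = GoogleAIEmbedderConfig(
    model="models/embedding-001",
    task_type="retrieval_query",
    title="My corpus",
)
```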
+ 31 - 0
embedchain/embedder/google.py

@@ -0,0 +1,31 @@
+from typing import Optional
+
+import google.generativeai as genai
+from chromadb import EmbeddingFunction, Embeddings
+
+from embedchain.config.embedder.google import GoogleAIEmbedderConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.models import VectorDimensions
+
+
+class GoogleAIEmbeddingFunction(EmbeddingFunction):
+    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
+        super().__init__()
+        self.config = config or GoogleAIEmbedderConfig()
+
+    def __call__(self, input: str) -> Embeddings:
+        model = self.config.model
+        title = self.config.title
+        task_type = self.config.task_type
+        embeddings = genai.embed_content(model=model, content=input, task_type=task_type, title=title)
+        return embeddings["embedding"]
+
+
+class GoogleAIEmbedder(BaseEmbedder):
+    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
+        super().__init__(config)
+        embedding_fn = GoogleAIEmbeddingFunction(config=config)
+        self.set_embedding_fn(embedding_fn=embedding_fn)
+
+        vector_dimension = VectorDimensions.GOOGLE_AI.value
+        self.set_vector_dimension(vector_dimension=vector_dimension)

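For context, a sketch of how the new embedding function could be exercised directly, assuming `GOOGLE_API_KEY` is set; the `genai.configure` call and the sample sentence are illustrative, not part of the diff:

```python
import os

import google.generativeai as genai

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.google import GoogleAIEmbeddingFunction

# Configure the Google client from the environment before embedding anything.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

embed_fn = GoogleAIEmbeddingFunction(
    config=GoogleAIEmbedderConfig(model="models/embedding-001")
)
vector = embed_fn("Embedchain is a data platform for LLMs.")
print(len(vector))  # expected to match VectorDimensions.GOOGLE_AI (768)
```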
+ 2 - 0
embedchain/factory.py

@@ -47,11 +47,13 @@ class EmbedderFactory:
         "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
         "openai": "embedchain.embedder.openai.OpenAIEmbedder",
         "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
+        "google": "embedchain.embedder.google.GoogleAIEmbedder",
     }
     provider_to_config_class = {
         "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
         "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
         "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
+        "google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
     }
 
     @classmethod

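With the two new mappings, the `google` provider resolves to `GoogleAIEmbedder` and `GoogleAIEmbedderConfig`. A hedged sketch, assuming the factory exposes a `create(provider_name, config_data)` classmethod as the truncated context above suggests:

```python
from embedchain.factory import EmbedderFactory

# "google" now maps to embedchain.embedder.google.GoogleAIEmbedder
# and its GoogleAIEmbedderConfig config class.
embedder = EmbedderFactory.create("google", {"model": "models/embedding-001"})
```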
+ 3 - 19
embedchain/llm/base.py

@@ -146,21 +146,7 @@ class BaseLlm(JSONSerializable):
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)
 
-    def _stream_query_response(self, answer: Any) -> Generator[Any, Any, None]:
-        """Generator to be used as streaming response
-
-        :param answer: Answer chunk from llm
-        :type answer: Any
-        :yield: Answer chunk from llm
-        :rtype: Generator[Any, Any, None]
-        """
-        streamed_answer = ""
-        for chunk in answer:
-            streamed_answer = streamed_answer + chunk
-            yield chunk
-        logging.info(f"Answer: {streamed_answer}")
-
-    def _stream_chat_response(self, answer: Any) -> Generator[Any, Any, None]:
+    def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
         """Generator to be used as streaming response
 
         :param answer: Answer chunk from llm
@@ -220,7 +206,7 @@ class BaseLlm(JSONSerializable):
                 logging.info(f"Answer: {answer}")
                 return answer
             else:
-                return self._stream_query_response(answer)
+                return self._stream_response(answer)
         finally:
             if config:
                 # Restore previous config
@@ -269,14 +255,12 @@ class BaseLlm(JSONSerializable):
                 return prompt
 
             answer = self.get_answer_from_llm(prompt)
-
             if isinstance(answer, str):
                 logging.info(f"Answer: {answer}")
-
                 return answer
             else:
                 # this is a streamed response and needs to be handled differently.
-                return self._stream_chat_response(answer)
+                return self._stream_response(answer)
         finally:
             if config:
                 # Restore previous config

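The two removed generators were identical, so chat and query now share `_stream_response`. Its behaviour, reconstructed from the deleted `_stream_query_response` body above: yield each chunk as it arrives and log the accumulated answer once the stream is exhausted. A standalone sketch of that pattern:

```python
import logging
from typing import Any, Generator


def stream_response(answer: Any) -> Generator[Any, Any, None]:
    streamed_answer = ""
    for chunk in answer:
        streamed_answer += chunk
        yield chunk  # pass each chunk straight through to the caller
    logging.info(f"Answer: {streamed_answer}")  # full answer logged at the end


# Consumed the same way the tests below do:
print(list(stream_response(["Chunk1", "Chunk2", "Chunk3"])))  # ['Chunk1', 'Chunk2', 'Chunk3']
```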
+ 14 - 14
embedchain/llm/google.py

@@ -1,7 +1,7 @@
 import importlib
 import logging
 import os
-from typing import Optional
+from typing import Any, Generator, Optional, Union
 
 import google.generativeai as genai
 
@@ -30,22 +30,22 @@ class GoogleLlm(BaseLlm):
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:
             raise ValueError("GoogleLlm does not support `system_prompt`")
-        return GoogleLlm._get_answer(prompt, self.config)
+        response = self._get_answer(prompt)
+        return response
 
-    @staticmethod
-    def _get_answer(prompt: str, config: BaseLlmConfig):
-        model_name = config.model or "gemini-pro"
+    def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
+        model_name = self.config.model or "gemini-pro"
         logging.info(f"Using Google LLM model: {model_name}")
         model = genai.GenerativeModel(model_name=model_name)
 
         generation_config_params = {
             "candidate_count": 1,
-            "max_output_tokens": config.max_tokens,
-            "temperature": config.temperature or 0.5,
+            "max_output_tokens": self.config.max_tokens,
+            "temperature": self.config.temperature or 0.5,
         }
 
-        if config.top_p >= 0.0 and config.top_p <= 1.0:
-            generation_config_params["top_p"] = config.top_p
+        if self.config.top_p >= 0.0 and self.config.top_p <= 1.0:
+            generation_config_params["top_p"] = self.config.top_p
         else:
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
 
@@ -54,11 +54,11 @@ class GoogleLlm(BaseLlm):
         response = model.generate_content(
             prompt,
             generation_config=generation_config,
-            stream=config.stream,
+            stream=self.config.stream,
         )
-
-        if config.stream:
-            for chunk in response:
-                yield chunk.text
+        if self.config.stream:
+            # TODO: Implement streaming
+            response.resolve()
+            return response.text
         else:
             return response.text

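The generation-config handling can be hard to follow inside the diff; below is a standalone sketch of the same logic. The `genai.types.GenerationConfig` call is an assumption based on the surrounding, unchanged code, and the concrete values only stand in for the config attributes:

```python
import google.generativeai as genai

max_tokens = 1000   # stands in for self.config.max_tokens
temperature = 0.9   # stands in for self.config.temperature
top_p = 1.0         # stands in for self.config.top_p

generation_config_params = {
    "candidate_count": 1,
    "max_output_tokens": max_tokens,
    "temperature": temperature or 0.5,  # falls back to 0.5 when unset
}

# top_p is only forwarded when it lies in [0.0, 1.0]; otherwise the call fails fast.
if 0.0 <= top_p <= 1.0:
    generation_config_params["top_p"] = top_p
else:
    raise ValueError("`top_p` must be > 0.0 and < 1.0")

generation_config = genai.types.GenerationConfig(**generation_config_params)
```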
+ 1 - 0
embedchain/models/vector_dimensions.py

@@ -7,3 +7,4 @@ class VectorDimensions(Enum):
     OPENAI = 1536
     VERTEX_AI = 768
     HUGGING_FACE = 384
+    GOOGLE_AI = 768

+ 2 - 2
embedchain/utils.py

@@ -411,14 +411,14 @@ def validate_config(config_data):
                 Optional("config"): object,  # TODO: add particular config schema for each provider
             },
             Optional("embedder"): {
-                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai"),
+                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
                 Optional("config"): {
                     Optional("model"): Optional(str),
                     Optional("deployment_name"): Optional(str),
                 },
             },
             Optional("embedding_model"): {
-                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai"),
+                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
                 Optional("config"): {
                     Optional("model"): str,
                     Optional("deployment_name"): str,

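After this change, a config naming `google` as the embedder provider should pass validation. A sketch, assuming `validate_config` raises (as the `schema` library does) on an unknown provider; note that this schema only covers the `model` and `deployment_name` keys:

```python
from embedchain.utils import validate_config

# Passes now that "google" is an accepted embedder provider.
validate_config(
    {
        "embedder": {
            "provider": "google",
            "config": {"model": "models/embedding-001"},
        }
    }
)
```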
+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.33"
+version = "0.1.34"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 2 - 8
tests/llm/test_base_llm.py

@@ -38,15 +38,9 @@ def test_is_get_llm_model_answer_implemented():
     assert llm.get_llm_model_answer() == "Implemented"
 
 
-def test_stream_query_response(base_llm):
+def test_stream_response(base_llm):
     answer = ["Chunk1", "Chunk2", "Chunk3"]
-    result = list(base_llm._stream_query_response(answer))
-    assert result == answer
-
-
-def test_stream_chat_response(base_llm):
-    answer = ["Chunk1", "Chunk2", "Chunk3"]
-    result = list(base_llm._stream_chat_response(answer))
+    result = list(base_llm._stream_response(answer))
     assert result == answer