Pārlūkot izejas kodu

[Feature] Add support for Google Gemini (#1009)

Deven Patel 1 gadu atpakaļ
vecāks
revīzija
151746beec

+ 8 - 0
configs/google.yaml

@@ -0,0 +1,8 @@
+llm:
+  provider: google
+  config:
+    model: gemini-pro
+    max_tokens: 1000
+    temperature: 0.9
+    top_p: 1.0
+    stream: false

+ 36 - 0
docs/components/llms.mdx

@@ -8,6 +8,7 @@ Embedchain comes with built-in support for various popular large language models
 
 <CardGroup cols={4}>
   <Card title="OpenAI" href="#openai"></Card>
+  <Card title="Google AI" href="#google-ai"></Card>
   <Card title="Azure OpenAI" href="#azure-openai"></Card>
   <Card title="Anthropic" href="#anthropic"></Card>
   <Card title="Cohere" href="#cohere"></Card>
@@ -62,6 +63,41 @@ llm:
 
 </CodeGroup>
 
+## Google AI
+
+To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from the [Google Maker Suite](https://makersuite.google.com/app/apikey)
+
+<CodeGroup>
+```python main.py
+import os
+from embedchain import Pipeline as App
+
+os.environ["OPENAI_API_KEY"] = "sk-xxxx"
+os.environ["GOOGLE_API_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+
+app.add("https://www.forbes.com/profile/elon-musk")
+
+response = app.query("What is the net worth of Elon Musk?")
+if app.llm.config.stream: # if stream is enabled, response is a generator
+    for chunk in response:
+        print(chunk)
+else:
+    print(response)
+```
+
+```yaml config.yaml
+llm:
+  provider: google
+  config:
+    model: gemini-pro
+    max_tokens: 1000
+    temperature: 0.5
+    top_p: 1
+    stream: false
+```
+</CodeGroup>
 
 ## Azure OpenAI
 

+ 1 - 0
embedchain/factory.py

@@ -18,6 +18,7 @@ class LlmFactory:
         "llama2": "embedchain.llm.llama2.Llama2Llm",
         "openai": "embedchain.llm.openai.OpenAILlm",
         "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
+        "google": "embedchain.llm.google.GoogleLlm",
     }
     provider_to_config_class = {
         "embedchain": "embedchain.config.llm.base.BaseLlmConfig",

+ 0 - 1
embedchain/llm/base.py

@@ -217,7 +217,6 @@ class BaseLlm(JSONSerializable):
                 return prompt
 
             answer = self.get_answer_from_llm(prompt)
-
             if isinstance(answer, str):
                 logging.info(f"Answer: {answer}")
                 return answer

+ 64 - 0
embedchain/llm/google.py

@@ -0,0 +1,64 @@
+import importlib
+import logging
+import os
+from typing import Optional
+
+import google.generativeai as genai
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class GoogleLlm(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        if "GOOGLE_API_KEY" not in os.environ:
+            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
+
+        try:
+            importlib.import_module("google.generativeai")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for GoogleLlm are not installed."
+                'Please install with `pip install --upgrade "embedchain[google]"`'
+            ) from None
+
+        super().__init__(config)
+        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+
+    def get_llm_model_answer(self, prompt):
+        if self.config.system_prompt:
+            raise ValueError("GoogleLlm does not support `system_prompt`")
+        return GoogleLlm._get_answer(prompt, self.config)
+
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig):
+        model_name = config.model or "gemini-pro"
+        logging.info(f"Using Google LLM model: {model_name}")
+        model = genai.GenerativeModel(model_name=model_name)
+
+        generation_config_params = {
+            "candidate_count": 1,
+            "max_output_tokens": config.max_tokens,
+            "temperature": config.temperature or 0.5,
+        }
+
+        if config.top_p >= 0.0 and config.top_p <= 1.0:
+            generation_config_params["top_p"] = config.top_p
+        else:
+            raise ValueError("`top_p` must be > 0.0 and < 1.0")
+
+        generation_config = genai.types.GenerationConfig(**generation_config_params)
+
+        response = model.generate_content(
+            prompt,
+            generation_config=generation_config,
+            stream=config.stream,
+        )
+
+        if config.stream:
+            for chunk in response:
+                yield chunk.text
+        else:
+            return response.text

+ 1 - 0
embedchain/utils.py

@@ -387,6 +387,7 @@ def validate_config(config_data):
                     "jina",
                     "llama2",
                     "vertexai",
+                    "google",
                 ),
                 Optional("config"): {
                     Optional("model"): str,

+ 51 - 12
poetry.lock

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
 
 [[package]]
 name = "aiofiles"
@@ -1762,6 +1762,22 @@ gitdb = ">=4.0.1,<5"
 [package.extras]
 test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
 
+[[package]]
+name = "google-ai-generativelanguage"
+version = "0.4.0"
+description = "Google Ai Generativelanguage API client library"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"},
+    {file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
+proto-plus = ">=1.22.3,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
+
 [[package]]
 name = "google-api-core"
 version = "2.12.0"
@@ -1777,12 +1793,12 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
 requests = ">=2.18.0,<3.0.0.dev0"
@@ -1870,8 +1886,8 @@ google-api-core = {version = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev", extras =
 google-cloud-core = ">=1.6.0,<3.0.0dev"
 google-resumable-media = ">=0.6.0,<3.0dev"
 grpcio = [
-    {version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
     {version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
+    {version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
 ]
 packaging = ">=20.0.0"
 proto-plus = ">=1.15.0,<2.0.0dev"
@@ -1922,8 +1938,8 @@ files = [
 google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
 grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
 proto-plus = [
-    {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
     {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
+    {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
 
@@ -2029,6 +2045,26 @@ files = [
 [package.extras]
 testing = ["pytest"]
 
+[[package]]
+name = "google-generativeai"
+version = "0.3.1"
+description = "Google Generative AI High level API client library and tools."
+optional = true
+python-versions = ">=3.9"
+files = [
+    {file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
+]
+
+[package.dependencies]
+google-ai-generativelanguage = "0.4.0"
+google-api-core = "*"
+google-auth = "*"
+protobuf = "*"
+tqdm = "*"
+
+[package.extras]
+dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
+
 [[package]]
 name = "google-resumable-media"
 version = "2.6.0"
@@ -4004,11 +4040,13 @@ files = [
 
 [package.dependencies]
 numpy = [
+    {version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
+    {version = ">=1.21.2", markers = "python_version >= \"3.10\""},
+    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
+    {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
+    {version = ">=1.17.0", markers = "python_version >= \"3.7\""},
+    {version = ">=1.17.3", markers = "python_version >= \"3.8\""},
     {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
-    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
-    {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
-    {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
-    {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
 ]
 
 [[package]]
@@ -4202,8 +4240,8 @@ files = [
 
 [package.dependencies]
 numpy = [
-    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
     {version = ">=1.22.4", markers = "python_version < \"3.11\""},
+    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -6344,7 +6382,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"
 
 [package.extras]
@@ -7820,6 +7858,7 @@ discord = ["discord"]
 elasticsearch = ["elasticsearch"]
 github = ["PyGithub", "gitpython"]
 gmail = ["llama-hub", "requests"]
+google = ["google-generativeai"]
 huggingface-hub = ["huggingface_hub"]
 images = ["ftfy", "pillow", "regex", "torch", "torchvision"]
 json = ["llama-hub"]
@@ -7843,4 +7882,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.12"
-content-hash = "776ae7f49adab8a5dc98f6fe7c2887d2e700fd2d7c447383ea81ef05a463c8f3"
+content-hash = "846cca158ccd7a2ecc3d0a08218d273846aad15e0ac5c19ddeb1b5de00fa9a3f"

+ 7 - 1
pyproject.toml

@@ -143,6 +143,7 @@ PyGithub = { version = "^1.59.1", optional = true }
 feedparser = { version = "^6.0.10", optional = true }
 newspaper3k = { version = "^0.2.8", optional = true }
 listparser = { version = "^0.19", optional = true }
+google-generativeai = { version = "^0.3.0", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -204,7 +205,12 @@ youtube = [
     "yt_dlp",
     "youtube-transcript-api",
 ]
-rss_feed = ["feedparser", "listparser", "newspaper3k"]
+rss_feed = [
+    "feedparser",
+    "listparser",
+    "newspaper3k"
+]
+google = ["google-generativeai"]
 
 [tool.poetry.group.docs.dependencies]
 

+ 43 - 0
tests/llm/test_google.py

@@ -0,0 +1,43 @@
+import pytest
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.google import GoogleLlm
+
+
+@pytest.fixture
+def google_llm_config():
+    return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
+
+
+def test_google_llm_init_missing_api_key(monkeypatch):
+    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
+    with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
+        GoogleLlm()
+
+
+def test_google_llm_init(monkeypatch):
+    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
+    with monkeypatch.context() as m:
+        m.setattr("importlib.import_module", lambda x: None)
+        google_llm = GoogleLlm()
+    assert google_llm is not None
+
+
+def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
+    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
+    monkeypatch.setattr("importlib.import_module", lambda x: None)
+    google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
+    with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
+        google_llm.get_llm_model_answer("test prompt")
+
+
+def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
+    def mock_get_answer(prompt, config):
+        return "Generated Text"
+
+    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
+    monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
+    google_llm = GoogleLlm(config=google_llm_config)
+    result = google_llm.get_llm_model_answer("test prompt")
+
+    assert result == "Generated Text"