
[Feature] OpenAI Function Calling (#1224)

UnMonsieur 1 year ago
parent commit
41bd258b93

+ 49 - 99
docs/components/llms.mdx

@@ -68,125 +68,75 @@ llm:
 </CodeGroup>
 
 ### Function Calling
-To enable [function calling](https://platform.openai.com/docs/guides/function-calling) in your application using embedchain and OpenAI, you need to pass functions into `OpenAILlm` class as an array of functions. Here are several ways in which you can achieve that:
+Embedchain supports OpenAI [function calling](https://platform.openai.com/docs/guides/function-calling) with a single function at a time. The function can be passed in any of the input formats accepted by the [Langchain interface](https://python.langchain.com/docs/modules/model_io/chat/function_calling#legacy-args-functions-and-function_call).
 
-Examples:
-<Accordion title="Using Pydantic Models">
+<Accordion title="Pydantic Model">
   ```python
-import os
-from embedchain import App
-from embedchain.llm.openai import OpenAILlm
-import requests
-from pydantic import BaseModel, Field, ValidationError, field_validator
-
-os.environ["OPENAI_API_KEY"] = "sk-xxx"
-
-class QA(BaseModel):
-    """
-    A question and answer pair.
-    """
-
-    question: str = Field(
-        ..., description="The question.", example="What is a mountain?"
-    )
-    answer: str = Field(
-        ..., description="The answer.", example="A mountain is a hill."
-    )
-    person_who_is_asking: str = Field(
-        ..., description="The person who is asking the question.", example="John"
-    )
-
-    @field_validator("question")
-    def question_must_end_with_a_question_mark(cls, v):
-        """
-        Validate that the question ends with a question mark.
-        """
-        if not v.endswith("?"):
-            raise ValueError("question must end with a question mark")
-        return v
-
-    @field_validator("answer")
-    def answer_must_end_with_a_period(cls, v):
-        """
-        Validate that the answer ends with a period.
-        """
-        if not v.endswith("."):
-            raise ValueError("answer must end with a period")
-        return v
-
-llm = OpenAILlm(config=None,functions=[QA])
-app = App(llm=llm)
+  from pydantic import BaseModel, Field
 
-result = app.query("Hey I am Sid. What is a mountain? A mountain is a hill.")
+  class multiply(BaseModel):
+      """Multiply two integers together."""
 
-print(result)
+      a: int = Field(..., description="First integer")
+      b: int = Field(..., description="Second integer")
   ```
-  </Accordion>
+</Accordion>
   
-  <Accordion title="Using OpenAI JSON schema">
-```python
-import os
-from embedchain import App
-from embedchain.llm.openai import OpenAILlm
-import requests
-from pydantic import BaseModel, Field, ValidationError, field_validator
-
-os.environ["OPENAI_API_KEY"] = "sk-xxx"
-
-json_schema = {
-    "name": "get_qa",
-    "description": "A question and answer pair and the user who is asking the question.",
-    "parameters": {
+<Accordion title="Python function">
+  ```python
+  def multiply(a: int, b: int) -> int:
+      """Multiply two integers together.
+
+      Args:
+          a: First integer
+          b: Second integer
+      """
+      return a * b
+  ```
+</Accordion>
+<Accordion title="OpenAI tool dictionary">
+  ```python
+  multiply = {
+    "type": "function",
+    "function": {
+      "name": "multiply",
+      "description": "Multiply two integers together.",
+      "parameters": {
         "type": "object",
         "properties": {
-            "question": {"type": "string", "description": "The question."},
-            "answer": {"type": "string", "description": "The answer."},
-            "person_who_is_asking": {
-                "type": "string",
-                "description": "The person who is asking the question.",
-            }
+          "a": {
+            "description": "First integer",
+            "type": "integer"
+          },
+          "b": {
+            "description": "Second integer",
+            "type": "integer"
+          }
         },
-        "required": ["question", "answer", "person_who_is_asking"],
-    },
-}
-
-llm = OpenAILlm(config=None,functions=[json_schema])
-app = App(llm=llm)
+        "required": [
+          "a",
+          "b"
+        ]
+      }
+    }
+  }
+  ```
+</Accordion>
 
-result = app.query("Hey I am Sid. What is a mountain? A mountain is a hill.")
+With any of the inputs above, the OpenAI LLM can be queried and will return the arguments for the function rather than a plain-text answer.
 
-print(result)
-  ```
-  </Accordion>
-  <Accordion title="Using actual python functions">
-  ```python
+```python
 import os
 from embedchain import App
 from embedchain.llm.openai import OpenAILlm
-import requests
-from pydantic import BaseModel, Field, ValidationError, field_validator
 
 os.environ["OPENAI_API_KEY"] = "sk-xxx"
 
-def find_info_of_pokemon(pokemon: str):
-    """
-    Find the information of the given pokemon.
-    Args:
-        pokemon: The pokemon.
-    """
-    req = requests.get(f"https://pokeapi.co/api/v2/pokemon/{pokemon}")
-    if req.status_code == 404:
-        raise ValueError("pokemon not found")
-    return req.json()
-
-llm = OpenAILlm(config=None,functions=[find_info_of_pokemon])
+llm = OpenAILlm(tools=multiply)
 app = App(llm=llm)
 
-result = app.query("Tell me more about the pokemon pikachu.")
-
-print(result)
+result = app.query("What is the result of 125 multiplied by fifteen?")
 ```
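
Here, `result` is a JSON string describing the tool call rather than a natural-language answer. A minimal sketch of consuming it, assuming the `{"type": ..., "args": ...}` shape produced by langchain's `JsonOutputToolsParser` (the stand-in string below is illustrative, not actual model output):

```python
import json

def multiply(a: int, b: int) -> int:
    """Multiply two integers together."""
    return a * b

# Stand-in for the string returned by app.query(...) above; the exact
# shape is an assumption based on langchain's JsonOutputToolsParser.
result = '{"type": "multiply", "args": {"a": 125, "b": 15}}'

call = json.loads(result)
if call["type"] == "multiply":
    print(multiply(**call["args"]))  # 1875
```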
-</Accordion>
 
 ## Google AI
 

+ 1 - 1
embedchain/embedder/huggingface.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
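
This rename, and the matching ones in the files below, follow the langchain 0.1 package split: third-party integrations moved out of the `langchain` core package into `langchain_community` (and partner packages such as `langchain_openai`). A minimal before/after sketch; the model name is just the default for this class, picked for illustration:

```python
# Before langchain 0.1, integrations were imported from the core package:
#   from langchain.embeddings import HuggingFaceEmbeddings
# From 0.1 on, community integrations live in langchain_community:
from langchain_community.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector = embeddings.embed_query("hello world")
print(len(vector))  # embedding dimensionality (384 for this model)
```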

+ 1 - 1
embedchain/embedder/openai.py

@@ -2,7 +2,7 @@ import os
 from typing import Optional
 
 from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
-from langchain.embeddings import AzureOpenAIEmbeddings
+from langchain_community.embeddings import AzureOpenAIEmbeddings
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder

+ 1 - 1
embedchain/embedder/vertexai.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from langchain.embeddings import VertexAIEmbeddings
+from langchain_community.embeddings import VertexAIEmbeddings
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder

+ 1 - 1
embedchain/llm/anthropic.py

@@ -19,7 +19,7 @@ class AnthropicLlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        from langchain.chat_models import ChatAnthropic
+        from langchain_community.chat_models import ChatAnthropic
 
         chat = ChatAnthropic(
             anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model=config.model

+ 1 - 1
embedchain/llm/aws_bedrock.py

@@ -1,7 +1,7 @@
 import os
 from typing import Optional
 
-from langchain.llms import Bedrock
+from langchain_community.llms import Bedrock
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable

+ 1 - 1
embedchain/llm/azure_openai.py

@@ -16,7 +16,7 @@ class AzureOpenAILlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        from langchain.chat_models import AzureChatOpenAI
+        from langchain_community.chat_models import AzureChatOpenAI
 
         if not config.deployment_name:
             raise ValueError("Deployment name must be provided for Azure OpenAI")

+ 1 - 1
embedchain/llm/cohere.py

@@ -2,7 +2,7 @@ import importlib
 import os
 from typing import Optional
 
-from langchain.llms.cohere import Cohere
+from langchain_community.llms.cohere import Cohere
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable

+ 2 - 1
embedchain/llm/gpt4all.py

@@ -26,7 +26,8 @@ class GPT4ALLLlm(BaseLlm):
     @staticmethod
     def _get_instance(model):
         try:
-            from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
+            from langchain_community.llms.gpt4all import \
+                GPT4All as LangchainGPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501

+ 2 - 2
embedchain/llm/huggingface.py

@@ -3,8 +3,8 @@ import logging
 import os
 from typing import Optional
 
-from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
-from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_community.llms.huggingface_hub import HuggingFaceHub
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable

+ 1 - 1
embedchain/llm/jina.py

@@ -1,8 +1,8 @@
 import os
 from typing import Optional
 
-from langchain.chat_models import JinaChat
 from langchain.schema import HumanMessage, SystemMessage
+from langchain_community.chat_models import JinaChat
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable

+ 1 - 1
embedchain/llm/llama2.py

@@ -2,7 +2,7 @@ import importlib
 import os
 from typing import Optional
 
-from langchain.llms.replicate import Replicate
+from langchain_community.llms.replicate import Replicate
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable

+ 1 - 1
embedchain/llm/ollama.py

@@ -4,7 +4,7 @@ from typing import Optional, Union
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.llms.ollama import Ollama
+from langchain_community.llms.ollama import Ollama
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable

+ 34 - 22
embedchain/llm/openai.py

@@ -1,9 +1,11 @@
 import json
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Dict, Optional, Type, Union
 
-from langchain.chat_models import ChatOpenAI
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain.schema import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.tools import BaseTool
+from langchain_openai import ChatOpenAI
+from pydantic import BaseModel
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -12,8 +14,12 @@ from embedchain.llm.base import BaseLlm
 
 @register_deserializable
 class OpenAILlm(BaseLlm):
-    def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
-        self.functions = functions
+    def __init__(
+        self,
+        config: Optional[BaseLlmConfig] = None,
+        tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]] = None,
+    ):
+        self.tools = tools
         super().__init__(config=config)
 
     def get_llm_model_answer(self, prompt) -> str:
@@ -38,21 +44,27 @@ class OpenAILlm(BaseLlm):
             from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
-            llm = ChatOpenAI(**kwargs, api_key=api_key)
-
-        if self.functions is not None:
-            from langchain.chains.openai_functions import create_openai_fn_runnable
-            from langchain.prompts import ChatPromptTemplate
-
-            structured_prompt = ChatPromptTemplate.from_messages(messages)
-            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
-            fn_res = runnable.invoke(
-                {
-                    "input": prompt,
-                }
-            )
-            messages.append(AIMessage(content=json.dumps(fn_res)))
-
-        return llm(messages).content
+            chat = ChatOpenAI(**kwargs, api_key=api_key)
+        if self.tools:
+            return self._query_function_call(chat, self.tools, messages)
+
+        return chat.invoke(messages).content
+
+    def _query_function_call(
+        self,
+        chat: ChatOpenAI,
+        tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]],
+        messages: list[BaseMessage],
+    ) -> str:
+        from langchain.output_parsers.openai_tools import JsonOutputToolsParser
+        from langchain_core.utils.function_calling import \
+            convert_to_openai_tool
+
+        openai_tools = [convert_to_openai_tool(tools)]
+        chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
+        try:
+            return json.dumps(chat.invoke(messages)[0])
+        except IndexError:
+            return "Input could not be mapped to the function!"

+ 1 - 1
embedchain/llm/together.py

@@ -2,7 +2,7 @@ import importlib
 import os
 from typing import Optional
 
-from langchain.llms import Together
+from langchain_community.llms import Together
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable

+ 1 - 1
embedchain/llm/vertex_ai.py

@@ -24,7 +24,7 @@ class VertexAILlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        from langchain.chat_models import ChatVertexAI
+        from langchain_community.chat_models import ChatVertexAI
 
         chat = ChatVertexAI(temperature=config.temperature, model=config.model)
 

+ 1 - 1
embedchain/loaders/docx_file.py

@@ -1,7 +1,7 @@
 import hashlib
 
 try:
-    from langchain.document_loaders import Docx2txtLoader
+    from langchain_community.document_loaders import Docx2txtLoader
 except ImportError:
     raise ImportError(
         'Docx file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'

+ 2 - 2
embedchain/loaders/google_drive.py

@@ -8,8 +8,8 @@ except ImportError:
         "Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
     ) from None
 
-from langchain.document_loaders import GoogleDriveLoader as Loader
-from langchain.document_loaders import UnstructuredFileIOLoader
+from langchain_community.document_loaders import GoogleDriveLoader as Loader
+from langchain_community.document_loaders import UnstructuredFileIOLoader
 
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader

+ 1 - 1
embedchain/loaders/pdf_file.py

@@ -1,7 +1,7 @@
 import hashlib
 
 try:
-    from langchain.document_loaders import PyPDFLoader
+    from langchain_community.document_loaders import PyPDFLoader
 except ImportError:
     raise ImportError(
         'PDF File requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'

+ 1 - 1
embedchain/loaders/rss_feed.py

@@ -28,7 +28,7 @@ class RSSFeedLoader(BaseLoader):
     @staticmethod
     def get_rss_content(url: str):
         try:
-            from langchain.document_loaders import \
+            from langchain_community.document_loaders import \
                 RSSFeedLoader as LangchainRSSFeedLoader
         except ImportError:
             raise ImportError(

+ 2 - 1
embedchain/loaders/unstructured_file.py

@@ -10,7 +10,8 @@ class UnstructuredLoader(BaseLoader):
     def load_data(self, url):
         """Load data from an Unstructured file."""
         try:
-            from langchain.document_loaders import UnstructuredFileLoader
+            from langchain_community.document_loaders import \
+                UnstructuredFileLoader
         except ImportError:
             raise ImportError(
                 'Unstructured file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'  # noqa: E501

+ 1 - 1
embedchain/loaders/xml.py

@@ -1,7 +1,7 @@
 import hashlib
 
 try:
-    from langchain.document_loaders import UnstructuredXMLLoader
+    from langchain_community.document_loaders import UnstructuredXMLLoader
 except ImportError:
     raise ImportError(
         'XML file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'

+ 1 - 1
embedchain/loaders/youtube_video.py

@@ -1,7 +1,7 @@
 import hashlib
 
 try:
-    from langchain.document_loaders import YoutubeLoader
+    from langchain_community.document_loaders import YoutubeLoader
 except ImportError:
     raise ImportError(
         'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'

+ 2 - 2
embedchain/vectordb/opensearch.py

@@ -12,8 +12,8 @@ except ImportError:
         "OpenSearch requires extra dependencies. Install with `pip install --upgrade embedchain[opensearch]`"
     ) from None
 
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.vectorstores import OpenSearchVectorSearch
+from langchain_community.embeddings.openai import OpenAIEmbeddings
+from langchain_community.vectorstores import OpenSearchVectorSearch
 
 from embedchain.config import OpenSearchDBConfig
 from embedchain.helpers.json_serializable import register_deserializable

File diff suppressed because it is too large
+ 6 - 5
poetry.lock


+ 3 - 2
pyproject.toml

@@ -93,7 +93,7 @@ color = true
 [tool.poetry.dependencies]
 python = ">=3.9,<3.12"
 python-dotenv = "^1.0.0"
-langchain = "^0.0.336"
+langchain = "^0.1.4"
 requests = "^2.31.0"
 openai = ">=1.1.1"
 chromadb = "^0.4.17"
@@ -103,7 +103,7 @@ beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"
 gptcache = "^0.1.43"
 pysbd = "^0.3.4"
-tiktoken = { version = "^0.4.0", optional = true }
+tiktoken = { version = "^0.5.2", optional = true }
 youtube-transcript-api = { version = "^0.6.1", optional = true }
 pytube = { version = "^15.0.0", optional = true }
 duckduckgo-search = { version = "^3.8.5", optional = true }
@@ -151,6 +151,7 @@ google-auth-httplib2 = { version = "^0.2.0", optional = true }
 google-api-core = { version = "^2.15.0", optional = true }
 boto3 = { version = "^1.34.20", optional = true }
 langchain-mistralai = { version = "^0.0.3", optional = true }
+langchain-openai = "^0.0.5"
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"

+ 2 - 2
tests/llm/test_anthrophic.py

@@ -24,7 +24,7 @@ def test_get_llm_model_answer(anthropic_llm):
 
 
 def test_get_answer(anthropic_llm):
-    with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
+    with patch("langchain_community.chat_models.ChatAnthropic") as mock_chat:
         mock_chat_instance = mock_chat.return_value
         mock_chat_instance.return_value = MagicMock(content="Test Response")
 
@@ -53,7 +53,7 @@ def test_get_messages(anthropic_llm):
 
 
 def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
-    with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
+    with patch("langchain_community.chat_models.ChatAnthropic") as mock_chat:
         mock_chat_instance = mock_chat.return_value
         mock_chat_instance.return_value = MagicMock(content="Test Response")
 

+ 2 - 2
tests/llm/test_azure_openai.py

@@ -28,7 +28,7 @@ def test_get_llm_model_answer(azure_openai_llm):
 
 
 def test_get_answer(azure_openai_llm):
-    with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
+    with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
         mock_chat_instance = mock_chat.return_value
         mock_chat_instance.return_value = MagicMock(content="Test Response")
 
@@ -60,7 +60,7 @@ def test_get_messages(azure_openai_llm):
 
 
 def test_get_answer_top_p_is_provided(azure_openai_llm, caplog):
-    with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
+    with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
         mock_chat_instance = mock_chat.return_value
         mock_chat_instance.return_value = MagicMock(content="Test Response")
 

+ 1 - 1
tests/llm/test_gpt4all.py

@@ -1,5 +1,5 @@
 import pytest
-from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
+from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
 
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.gpt4all import GPT4ALLLlm

+ 29 - 0
tests/llm/test_openai.py

@@ -74,3 +74,32 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
         model_kwargs={"top_p": config.top_p},
         api_key=os.environ["OPENAI_API_KEY"],
     )
+
+
+@pytest.mark.parametrize(
+    "mock_return, expected",
+    [
+        ([{"test": "test"}], '{"test": "test"}'),
+        ([], "Input could not be mapped to the function!"),
+    ],
+)
+def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
+    mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
+    mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
+
+    llm = OpenAILlm(config, tools={"test": "test"})
+    answer = llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+        api_key=os.environ["OPENAI_API_KEY"],
+    )
+    mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
+    mocked_json_output_tools_parser.assert_called_once()
+
+    assert answer == expected

+ 2 - 2
tests/llm/test_vertex_ai.py

@@ -22,7 +22,7 @@ def test_get_llm_model_answer(vertexai_llm):
 
 
 def test_get_answer_with_warning(vertexai_llm, caplog):
-    with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
+    with patch("langchain_community.chat_models.ChatVertexAI") as mock_chat:
         mock_chat_instance = mock_chat.return_value
         mock_chat_instance.return_value = MagicMock(content="Test Response")
 
@@ -39,7 +39,7 @@ def test_get_answer_with_warning(vertexai_llm, caplog):
 
 
 def test_get_answer_no_warning(vertexai_llm, caplog):
-    with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
+    with patch("langchain_community.chat_models.ChatVertexAI") as mock_chat:
         mock_chat_instance = mock_chat.return_value
         mock_chat_instance.return_value = MagicMock(content="Test Response")
 

Some files were not shown because too many files changed in this diff