
OpenAI function calling support (#1011)

Sidharth Mohanty, 1 year ago
commit cd2c40a9c4
3 changed files with 143 additions and 13 deletions
  1. docs/components/llms.mdx (+117 -1)
  2. embedchain/llm/openai.py (+21 -7)
  3. tests/llm/test_openai.py (+5 -5)

+ 117 - 1
docs/components/llms.mdx

@@ -60,9 +60,125 @@ llm:
     top_p: 1
     stream: false
 ```
-
 </CodeGroup>
 
+### Function Calling
+To enable [function calling](https://platform.openai.com/docs/guides/function-calling) in your application with embedchain and OpenAI, pass your functions to the `OpenAILlm` class as a list. Each entry can be a Pydantic model, an OpenAI-style JSON schema, or a plain Python function.
+
+Examples:
+<Accordion title="Using Pydantic Models">
+  ```python
+import os
+from embedchain import Pipeline as App
+from embedchain.llm.openai import OpenAILlm
+from pydantic import BaseModel, Field, field_validator
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+class QA(BaseModel):
+    """
+    A question and answer pair.
+    """
+
+    question: str = Field(
+        ..., description="The question.", json_schema_extra={"example": "What is a mountain?"}
+    )
+    answer: str = Field(
+        ..., description="The answer.", json_schema_extra={"example": "A mountain is a hill."}
+    )
+    person_who_is_asking: str = Field(
+        ..., description="The person who is asking the question.", json_schema_extra={"example": "John"}
+    )
+
+    @field_validator("question")
+    def question_must_end_with_a_question_mark(cls, v):
+        """
+        Validate that the question ends with a question mark.
+        """
+        if not v.endswith("?"):
+            raise ValueError("question must end with a question mark")
+        return v
+
+    @field_validator("answer")
+    def answer_must_end_with_a_period(cls, v):
+        """
+        Validate that the answer ends with a period.
+        """
+        if not v.endswith("."):
+            raise ValueError("answer must end with a period")
+        return v
+
+llm = OpenAILlm(config=None, functions=[QA])
+app = App(llm=llm)
+
+result = app.query("Hey I am Sid. What is a mountain? A mountain is a hill.")
+
+print(result)
+  ```
+  </Accordion>
+  
+  <Accordion title="Using OpenAI JSON schema">
+```python
+import os
+from embedchain import Pipeline as App
+from embedchain.llm.openai import OpenAILlm
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+json_schema = {
+    "name": "get_qa",
+    "description": "A question and answer pair and the user who is asking the question.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "question": {"type": "string", "description": "The question."},
+            "answer": {"type": "string", "description": "The answer."},
+            "person_who_is_asking": {
+                "type": "string",
+                "description": "The person who is asking the question.",
+            }
+        },
+        "required": ["question", "answer", "person_who_is_asking"],
+    },
+}
+
+llm = OpenAILlm(config=None, functions=[json_schema])
+app = App(llm=llm)
+
+result = app.query("Hey I am Sid. What is a mountain? A mountain is a hill.")
+
+print(result)
+  ```
+  </Accordion>
+  <Accordion title="Using actual Python functions">
+  ```python
+import os
+from embedchain import Pipeline as App
+from embedchain.llm.openai import OpenAILlm
+import requests
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+def find_info_of_pokemon(pokemon: str):
+    """
+    Find information about the given pokemon.
+    Args:
+        pokemon: The name of the pokemon to look up.
+    """
+    req = requests.get(f"https://pokeapi.co/api/v2/pokemon/{pokemon}")
+    if req.status_code == 404:
+        raise ValueError("pokemon not found")
+    return req.json()
+
+llm = OpenAILlm(config=None, functions=[find_info_of_pokemon])
+app = App(llm=llm)
+
+result = app.query("Tell me more about the pokemon pikachu.")
+
+print(result)
+```
+</Accordion>
+
 ## Google AI
 
To use a Google AI model, you must set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from the [Google Maker Suite](https://makersuite.google.com/app/apikey).
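
The Pydantic and JSON-schema accordions in the Function Calling section above describe the same function in two notations. A minimal sketch of the correspondence, assuming pydantic v2 (the `QA` fields are trimmed to the essentials):

```python
from pydantic import BaseModel, Field

class QA(BaseModel):
    """A question and answer pair."""

    question: str = Field(..., description="The question.")
    answer: str = Field(..., description="The answer.")
    person_who_is_asking: str = Field(..., description="The person who is asking the question.")

# model_json_schema() yields the same object/properties/required structure that
# the json_schema example spells out by hand under "parameters".
parameters = QA.model_json_schema()
print(parameters["type"])      # object
print(parameters["required"])  # ['question', 'answer', 'person_who_is_asking']
```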

+ 21 - 7
embedchain/llm/openai.py

@@ -1,7 +1,8 @@
-from typing import Optional
+import json
+from typing import Any, List, Optional
 
 from langchain.chat_models import ChatOpenAI
-from langchain.schema import HumanMessage, SystemMessage
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -10,14 +11,15 @@ from embedchain.llm.base import BaseLlm
 
 @register_deserializable
 class OpenAILlm(BaseLlm):
-    def __init__(self, config: Optional[BaseLlmConfig] = None):
+    def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[List[Any]] = None):
+        self.functions = functions
         super().__init__(config=config)
 
     def get_llm_model_answer(self, prompt) -> str:
-        response = OpenAILlm._get_answer(prompt, self.config)
+        response = self._get_answer(prompt, self.config)
         return response
 
-    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
         messages = []
         if config.system_prompt:
             messages.append(SystemMessage(content=config.system_prompt))
@@ -31,11 +33,23 @@ class OpenAILlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
             chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
         else:
             chat = ChatOpenAI(**kwargs)
+        if self.functions is not None:
+            from langchain.chains.openai_functions import create_openai_fn_runnable
+            from langchain.prompts import ChatPromptTemplate
+
+            structured_prompt = ChatPromptTemplate.from_messages(messages)
+            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=chat)
+            fn_res = runnable.invoke(
+                {
+                    "input": prompt,
+                }
+            )
+            messages.append(AIMessage(content=json.dumps(fn_res)))
+
         return chat(messages).content
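
When `functions` is provided, the new branch builds a `ChatPromptTemplate` from the collected messages, hands it to LangChain's `create_openai_fn_runnable`, and serializes the result into the message history before the final chat call. A minimal sketch of that runnable in isolation (the model name and prompt template are illustrative; the stub reuses the docs example above):

```python
from langchain.chat_models import ChatOpenAI
from langchain.chains.openai_functions import create_openai_fn_runnable
from langchain.prompts import ChatPromptTemplate

def find_info_of_pokemon(pokemon: str):
    """
    Find information about the given pokemon.
    Args:
        pokemon: The name of the pokemon to look up.
    """

chat = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
prompt = ChatPromptTemplate.from_messages([("human", "{input}")])
runnable = create_openai_fn_runnable(functions=[find_info_of_pokemon], llm=chat, prompt=prompt)

# Returns the arguments the model picked for the matched function; _get_answer
# then json.dumps() this result into an AIMessage before the final chat call.
args = runnable.invoke({"input": "Tell me more about the pokemon pikachu."})
print(args)  # e.g. {'pokemon': 'pikachu'}
```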

+ 5 - 5
tests/llm/test_openai.py

@@ -50,24 +50,24 @@ def test_get_llm_model_answer_empty_prompt(config, mocker):
 
 def test_get_llm_model_answer_with_streaming(config, mocker):
     config.stream = True
-    mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
 
     llm = OpenAILlm(config)
     llm.get_llm_model_answer("Test query")
 
-    mocked_jinachat.assert_called_once()
-    callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
+    mocked_openai_chat.assert_called_once()
+    callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
     assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
 
 
 def test_get_llm_model_answer_without_system_prompt(config, mocker):
     config.system_prompt = None
-    mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
 
     llm = OpenAILlm(config)
     llm.get_llm_model_answer("Test query")
 
-    mocked_jinachat.assert_called_once_with(
+    mocked_openai_chat.assert_called_once_with(
         model=config.model,
         temperature=config.temperature,
         max_tokens=config.max_tokens,
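
The renamed mocks only exercise the existing paths; a sketch of a companion test for the new `functions` argument could look like the following, in the same style (the patch target relies on the lazy import inside `_get_answer`; the schema literal is illustrative):

```python
def test_get_llm_model_answer_with_functions(config, mocker):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    # Patch at the source module; openai.py imports this lazily at call time.
    mocked_fn_runnable = mocker.patch("langchain.chains.openai_functions.create_openai_fn_runnable")
    mocked_fn_runnable.return_value.invoke.return_value = {"answer": "A mountain is a hill."}

    json_schema = {
        "name": "get_qa",
        "description": "A question and answer pair.",
        "parameters": {"type": "object", "properties": {}},
    }
    llm = OpenAILlm(config, functions=[json_schema])
    llm.get_llm_model_answer("What is a mountain?")

    mocked_openai_chat.assert_called_once()
    mocked_fn_runnable.return_value.invoke.assert_called_once_with({"input": "What is a mountain?"})
```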