
fix: use openai llm via langchain (#670)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Taranjeet Singh 1 year ago
parent
commit
0f16c72762
3 changed files with 42 additions and 51 deletions
  1. embedchain/embedchain.py (+2 -1)
  2. embedchain/llm/openai.py (+27 -25)
  3. tests/llm/test_query.py (+13 -25)
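
For orientation, a minimal sketch of the user-facing path this commit touches, mirroring the updated tests further down; the import paths for `App`, `AppConfig`, and `BaseLlmConfig` are assumptions, and a real run needs `OPENAI_API_KEY` set:

```python
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig

# Same configuration shape the updated tests construct below.
app = App(
    config=AppConfig(collect_metrics=False),
    llm_config=BaseLlmConfig(system_prompt="Test system prompt"),
)

# After this commit the answer is produced by langchain's ChatOpenAI inside
# OpenAILlm._get_answer, instead of a direct openai.ChatCompletion.create call.
answer = app.llm.get_llm_model_answer("Test query")
print(answer)
```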

+ 2 - 1
embedchain/embedchain.py

@@ -21,7 +21,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.utils import detect_datatype
 from embedchain.vectordb.base import BaseVectorDB
 

+ 27 - 25
embedchain/llm/openai.py

@@ -1,6 +1,7 @@
 from typing import Optional
 
-import openai
+from langchain.chat_models import ChatOpenAI
+from langchain.schema import HumanMessage, SystemMessage
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
@@ -12,31 +13,32 @@ class OpenAILlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)
 
-    # NOTE: This class does not use langchain. One reason is that `top_p` is not supported.
-
     def get_llm_model_answer(self, prompt):
-        messages = []
-        if self.config.system_prompt:
-            messages.append({"role": "system", "content": self.config.system_prompt})
-        messages.append({"role": "user", "content": prompt})
-        response = openai.ChatCompletion.create(
-            model=self.config.model or "gpt-3.5-turbo-0613",
-            messages=messages,
-            temperature=self.config.temperature,
-            max_tokens=self.config.max_tokens,
-            top_p=self.config.top_p,
-            stream=self.config.stream,
-        )
+        response = OpenAILlm._get_answer(prompt, self.config)
 
         if self.config.stream:
-            return self._stream_llm_model_response(response)
+            return response
+        else:
+            return response.content
+
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        messages = []
+        if config.system_prompt:
+            messages.append(SystemMessage(content=config.system_prompt))
+        messages.append(HumanMessage(content=prompt))
+        kwargs = {
+            "model": config.model or "gpt-3.5-turbo-0613",
+            "temperature": config.temperature,
+            "max_tokens": config.max_tokens,
+            "model_kwargs": {},
+        }
+        if config.top_p:
+            kwargs["model_kwargs"]["top_p"] = config.top_p
+        if config.stream:
+            from langchain.callbacks.streaming_stdout import \
+                StreamingStdOutCallbackHandler
+
+            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
         else:
-            return response["choices"][0]["message"]["content"]
-
-    def _stream_llm_model_response(self, response):
-        """
-        This is a generator for streaming response from the OpenAI completions API
-        """
-        for line in response:
-            chunk = line["choices"][0].get("delta", {}).get("content", "")
-            yield chunk
+            chat = ChatOpenAI(**kwargs)
+        return chat(messages)
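
A standalone sketch of the call shape the new `_get_answer` builds, using the same langchain imports as the diff; the model, temperature, max_tokens and top_p values are illustrative, and `OPENAI_API_KEY` must be set in the environment:

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

# The same message list _get_answer assembles: optional system prompt first,
# then the user prompt.
messages = [
    SystemMessage(content="You are a concise assistant."),
    HumanMessage(content="What does this commit change?"),
]

# top_p is not a named ChatOpenAI argument, so it is routed through
# model_kwargs, exactly as the kwargs dict above does.
chat = ChatOpenAI(
    model="gpt-3.5-turbo-0613",
    temperature=0.1,
    max_tokens=256,
    model_kwargs={"top_p": 0.9},
)

# Calling the chat model returns a message object whose .content holds the
# text, which is why the non-streaming branch returns response.content.
response = chat(messages)
print(response.content)
```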

+ 13 - 25
tests/llm/test_query.py

@@ -46,41 +46,29 @@ class TestApp(unittest.TestCase):
         self.assertEqual(input_query_arg, "Test query")
         mock_answer.assert_called_once()
 
-    @patch("openai.ChatCompletion.create")
-    def test_query_config_app_passing(self, mock_create):
-        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+    @patch("embedchain.llm.openai.OpenAILlm._get_answer")
+    def test_query_config_app_passing(self, mock_get_answer):
+        mock_get_answer.return_value = MagicMock()
+        mock_get_answer.return_value.content = "Test answer"
 
         config = AppConfig(collect_metrics=False)
         chat_config = BaseLlmConfig(system_prompt="Test system prompt")
         app = App(config=config, llm_config=chat_config)
+        answer = app.llm.get_llm_model_answer("Test query")
 
-        app.llm.get_llm_model_answer("Test query")
-
-        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
-        messages_arg = mock_create.call_args.kwargs["messages"]
-        self.assertTrue(messages_arg[0].get("role"), "system")
-        self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
-        self.assertTrue(messages_arg[1].get("role"), "user")
-        self.assertEqual(messages_arg[1].get("content"), "Test query")
-
-        # TODO: Add tests for other config variables
-
-    @patch("openai.ChatCompletion.create")
-    def test_app_passing(self, mock_create):
-        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+        self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
+        self.assertEqual(answer, "Test answer")
 
+    @patch("embedchain.llm.openai.OpenAILlm._get_answer")
+    def test_app_passing(self, mock_get_answer):
+        mock_get_answer.return_value = MagicMock()
+        mock_get_answer.return_value.content = "Test answer"
         config = AppConfig(collect_metrics=False)
         chat_config = BaseLlmConfig()
         app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
-
+        answer = app.llm.get_llm_model_answer("Test query")
         self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
-
-        app.llm.get_llm_model_answer("Test query")
-
-        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
-        messages_arg = mock_create.call_args.kwargs["messages"]
-        self.assertTrue(messages_arg[0].get("role"), "system")
-        self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
+        self.assertEqual(answer, "Test answer")
 
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     def test_query_with_where_in_params(self):
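
Because `_get_answer` now returns a langchain message object rather than the old OpenAI response dict, the tests mock it with an object exposing `.content`. A minimal self-contained sketch of that pattern; the `embedchain` import paths are assumed to match the test module's own imports:

```python
import unittest
from unittest.mock import MagicMock, patch

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig


class TestOpenAILlmMocking(unittest.TestCase):
    @patch("embedchain.llm.openai.OpenAILlm._get_answer")
    def test_returns_message_content(self, mock_get_answer):
        # The mock mimics langchain's AIMessage: the answer text lives on .content.
        mock_get_answer.return_value = MagicMock(content="Test answer")

        app = App(
            config=AppConfig(collect_metrics=False),
            llm_config=BaseLlmConfig(system_prompt="Test system prompt"),
        )

        answer = app.llm.get_llm_model_answer("Test query")
        self.assertEqual(answer, "Test answer")
        mock_get_answer.assert_called_once()


if __name__ == "__main__":
    unittest.main()
```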