Jelajahi Sumber

Allow ollama llm to take custom callback for handling streaming (#1376)

Aditya Veer Parmar 1 tahun lalu
induk
melakukan
dc0d8e0932
2 mengubah file dengan 23 tambahan dan 2 penghapusan
  1. 5 2
      embedchain/llm/ollama.py
  2. 18 0
      tests/llm/test_ollama.py

+ 5 - 2
embedchain/llm/ollama.py

@@ -33,14 +33,17 @@ class OllamaLlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
-        callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
+        if config.stream:
+            callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
+        else:
+            callbacks = [StdOutCallbackHandler()]
 
         llm = Ollama(
             model=config.model,
             system=config.system_prompt,
             temperature=config.temperature,
             top_p=config.top_p,
-            callback_manager=CallbackManager(callback_manager),
+            callback_manager=CallbackManager(callbacks),
             base_url=config.base_url,
         )
 

+ 18 - 0
tests/llm/test_ollama.py

@@ -2,6 +2,7 @@ import pytest
 
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.ollama import OllamaLlm
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
 @pytest.fixture
@@ -31,3 +32,20 @@ def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
     answer = llm.get_llm_model_answer(prompt)
 
     assert answer == "Mocked answer"
+
+
+def test_get_llm_model_answer_with_streaming(ollama_llm_config, mocker):
+    ollama_llm_config.stream = True
+    ollama_llm_config.callbacks = [StreamingStdOutCallbackHandler()]
+    mocked_ollama_chat = mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
+
+    llm = OllamaLlm(ollama_llm_config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_ollama_chat.assert_called_once()
+    call_args = mocked_ollama_chat.call_args
+    config_arg = call_args[1]["config"]
+    callbacks = config_arg.callbacks
+
+    assert len(callbacks) == 1
+    assert isinstance(callbacks[0], StreamingStdOutCallbackHandler)