Sidharth Mohanty преди 1 година
родител
ревизия
e226a89637
променени са 3 файла, в които са добавени 83 реда и са изтрити 0 реда
  1. 1 0
      docs/advanced/app_types.mdx
  2. 42 0
      embedchain/llm/jina.py
  3. 40 0
      tests/llm/test_jina.py

+ 1 - 0
docs/advanced/app_types.mdx

@@ -27,6 +27,7 @@ The following LLM providers are supported by Embedchain:
 - GPT4ALL
 - AZURE_OPENAI
 - LLAMA2
+- JINA
 - COHERE
 
 You can choose one by importing it from `embedchain.llm`. E.g.:

+ 42 - 0
embedchain/llm/jina.py

@@ -0,0 +1,42 @@
+import os
+from typing import Optional
+
+from langchain.chat_models import JinaChat
+from langchain.schema import HumanMessage, SystemMessage
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class JinaLlm(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        if "JINACHAT_API_KEY" not in os.environ:
+            raise ValueError("Please set the JINACHAT_API_KEY environment variable.")
+        super().__init__(config=config)
+
+    def get_llm_model_answer(self, prompt):
+        response = JinaLlm._get_answer(prompt, self.config)
+        return response
+
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        messages = []
+        if config.system_prompt:
+            messages.append(SystemMessage(content=config.system_prompt))
+        messages.append(HumanMessage(content=prompt))
+        kwargs = {
+            "temperature": config.temperature,
+            "max_tokens": config.max_tokens,
+            "model_kwargs": {},
+        }
+        if config.top_p:
+            kwargs["model_kwargs"]["top_p"] = config.top_p
+        if config.stream:
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+            chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
+        else:
+            chat = JinaChat(**kwargs)
+        return chat(messages).content

+ 40 - 0
tests/llm/test_jina.py

@@ -0,0 +1,40 @@
+import os
+import unittest
+from unittest.mock import patch
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.jina import JinaLlm
+
+
+class TestJinaLlm(unittest.TestCase):
+    def setUp(self):
+        os.environ["JINACHAT_API_KEY"] = "test_api_key"
+        self.config = BaseLlmConfig(
+            temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt"
+        )
+
+    def test_init_raises_value_error_without_api_key(self):
+        os.environ.pop("JINACHAT_API_KEY")
+        with self.assertRaises(ValueError):
+            JinaLlm()
+
+    @patch("embedchain.llm.jina.JinaLlm._get_answer")
+    def test_get_llm_model_answer(self, mock_get_answer):
+        mock_get_answer.return_value = "Test answer"
+
+        llm = JinaLlm(self.config)
+        answer = llm.get_llm_model_answer("Test query")
+
+        self.assertEqual(answer, "Test answer")
+        mock_get_answer.assert_called_once()
+
+    @patch("embedchain.llm.jina.JinaLlm._get_answer")
+    def test_get_llm_model_answer_with_system_prompt(self, mock_get_answer):
+        self.config.system_prompt = "Custom system prompt"
+        mock_get_answer.return_value = "Test answer"
+
+        llm = JinaLlm(self.config)
+        answer = llm.get_llm_model_answer("Test query")
+
+        self.assertEqual(answer, "Test answer")
+        mock_get_answer.assert_called_once()