123456789101112131415161718192021222324252627282930313233 |
- import os
- import unittest
- from unittest.mock import patch
- from embedchain.config import BaseLlmConfig
- from embedchain.llm.cohere import CohereLlm
class TestCohereLlm(unittest.TestCase):
    """Unit tests for ``CohereLlm`` construction and answer generation."""

    def setUp(self):
        """Install a fake ``COHERE_API_KEY`` and build a shared LLM config.

        The variable is set through ``patch.dict`` and the patch is stopped via
        ``addCleanup``, so ``os.environ`` is restored to its pre-test state after
        every test. Assigning ``os.environ[...]`` directly would leak the fake key
        (or a test's removal of it) into every test that runs afterwards.
        """
        env_patcher = patch.dict(os.environ, {"COHERE_API_KEY": "test_api_key"})
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        self.config = BaseLlmConfig(
            model="gptd-instruct-tft",
            max_tokens=50,
            temperature=0.7,
            top_p=0.8,
        )

    def test_init_raises_value_error_without_api_key(self):
        """``CohereLlm()`` is expected to raise ValueError when the API key is unset."""
        # Safe to remove globally: the patch.dict cleanup from setUp restores it.
        os.environ.pop("COHERE_API_KEY", None)
        with self.assertRaises(ValueError):
            CohereLlm()

    def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
        """Setting a system prompt is expected to make ``get_llm_model_answer`` raise ValueError."""
        llm = CohereLlm(self.config)
        llm.config.system_prompt = "system_prompt"
        with self.assertRaises(ValueError):
            llm.get_llm_model_answer("prompt")

    @patch("embedchain.llm.cohere.CohereLlm._get_answer")
    def test_get_llm_model_answer(self, mock_get_answer):
        """``get_llm_model_answer`` returns whatever ``_get_answer`` produces.

        ``_get_answer`` is replaced with a mock so the test does not exercise the
        real backend implementation.
        """
        mock_get_answer.return_value = "Test answer"
        llm = CohereLlm(self.config)

        answer = llm.get_llm_model_answer("Test query")

        self.assertEqual(answer, "Test answer")
        mock_get_answer.assert_called_once()
|