test_google.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import pytest
  2. from embedchain.config import BaseLlmConfig
  3. from embedchain.llm.google import GoogleLlm
  4. @pytest.fixture
  5. def google_llm_config():
  6. return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
  7. def test_google_llm_init_missing_api_key(monkeypatch):
  8. monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
  9. with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
  10. GoogleLlm()
  11. def test_google_llm_init(monkeypatch):
  12. monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
  13. with monkeypatch.context() as m:
  14. m.setattr("importlib.import_module", lambda x: None)
  15. google_llm = GoogleLlm()
  16. assert google_llm is not None
  17. def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
  18. monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
  19. monkeypatch.setattr("importlib.import_module", lambda x: None)
  20. google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
  21. with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
  22. google_llm.get_llm_model_answer("test prompt")
  23. def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
  24. def mock_get_answer(prompt, config):
  25. return "Generated Text"
  26. monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
  27. monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
  28. google_llm = GoogleLlm(config=google_llm_config)
  29. result = google_llm.get_llm_model_answer("test prompt")
  30. assert result == "Generated Text"