google.py

import importlib
import logging
import os
from typing import Any, Generator, Optional, Union

import google.generativeai as genai

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class GoogleLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "GOOGLE_API_KEY" not in os.environ:
            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

        # Fail fast with an actionable message if the optional dependency is missing.
        try:
            importlib.import_module("google.generativeai")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for GoogleLlm are not installed. "
                'Please install with `pip install --upgrade "embedchain[google]"`'
            ) from None

        super().__init__(config)
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    def get_llm_model_answer(self, prompt):
        if self.config.system_prompt:
            raise ValueError("GoogleLlm does not support `system_prompt`")
        return self._get_answer(prompt)

    def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
        model_name = self.config.model or "gemini-pro"
        logging.info(f"Using Google LLM model: {model_name}")
        model = genai.GenerativeModel(model_name=model_name)

        generation_config_params = {
            "candidate_count": 1,
            "max_output_tokens": self.config.max_tokens,
            "temperature": self.config.temperature or 0.5,
        }

        # The bounds check is inclusive on both ends, so the error message matches it.
        if 0.0 <= self.config.top_p <= 1.0:
            generation_config_params["top_p"] = self.config.top_p
        else:
            raise ValueError("`top_p` must be between 0.0 and 1.0")

        generation_config = genai.types.GenerationConfig(**generation_config_params)

        response = model.generate_content(
            prompt,
            generation_config=generation_config,
            stream=self.config.stream,
        )

        if self.config.stream:
            # TODO: Implement streaming. For now, resolve the streamed response
            # so that `response.text` is available, then return the full text.
            response.resolve()
        return response.text
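

# Minimal usage sketch (illustrative; not part of the original module). It
# assumes a valid GOOGLE_API_KEY in the environment and that BaseLlmConfig
# accepts the keyword arguments shown below, as in current embedchain releases.
if __name__ == "__main__":
    config = BaseLlmConfig(
        model="gemini-pro",
        max_tokens=256,
        temperature=0.5,
        top_p=0.9,
        stream=False,
    )
    llm = GoogleLlm(config=config)
    print(llm.get_llm_model_answer("Explain what this module does in one sentence."))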