google.py

import importlib
import logging
import os
from collections.abc import Generator
from typing import Any, Optional, Union

import google.generativeai as genai

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

logger = logging.getLogger(__name__)


@register_deserializable
class GoogleLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "GOOGLE_API_KEY" not in os.environ:
            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

        try:
            importlib.import_module("google.generativeai")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for GoogleLlm are not installed. "
                'Please install with `pip install --upgrade "embedchain[google]"`'
            ) from None

        super().__init__(config)
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    def get_llm_model_answer(self, prompt):
        if self.config.system_prompt:
            raise ValueError("GoogleLlm does not support `system_prompt`")
        return self._get_answer(prompt)

    def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
        model_name = self.config.model or "gemini-pro"
        logger.info(f"Using Google LLM model: {model_name}")
        model = genai.GenerativeModel(model_name=model_name)

        generation_config_params = {
            "candidate_count": 1,
            "max_output_tokens": self.config.max_tokens,
            "temperature": self.config.temperature or 0.5,
        }

        if 0.0 <= self.config.top_p <= 1.0:
            generation_config_params["top_p"] = self.config.top_p
        else:
            raise ValueError("`top_p` must be between 0.0 and 1.0")

        generation_config = genai.types.GenerationConfig(**generation_config_params)

        response = model.generate_content(
            prompt,
            generation_config=generation_config,
            stream=self.config.stream,
        )

        if self.config.stream:
            # TODO: Implement streaming. For now the streamed chunks are
            # resolved into the complete response before returning.
            response.resolve()
        return response.text
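

if __name__ == "__main__":
    # Usage sketch (not part of the original module): assumes GOOGLE_API_KEY
    # is exported and the extra was installed with
    # `pip install --upgrade "embedchain[google]"`.
    llm = GoogleLlm()
    print(llm.get_llm_model_answer("Summarize the Gemini API in one sentence."))

    # Hypothetical illustration of the streaming that the TODO above leaves
    # unimplemented: iterate the streamed google.generativeai response chunk
    # by chunk instead of calling `resolve()` on it.
    model = genai.GenerativeModel(model_name="gemini-pro")
    for chunk in model.generate_content("Count to five.", stream=True):
        print(chunk.text, end="", flush=True)
    print()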