google.py

import importlib
import logging
import os
from collections.abc import Generator
from typing import Any, Optional, Union

import google.generativeai as genai

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

logger = logging.getLogger(__name__)

@register_deserializable
class GoogleLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        # Fail fast if the optional Google dependency is not installed.
        try:
            importlib.import_module("google.generativeai")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for GoogleLlm are not installed. "
                'Please install with `pip install --upgrade "embedchain[google]"`'
            ) from None

        super().__init__(config)
        if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ:
            raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.")

        api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
        genai.configure(api_key=api_key)
    def get_llm_model_answer(self, prompt):
        if self.config.system_prompt:
            raise ValueError("GoogleLlm does not support `system_prompt`")
        response = self._get_answer(prompt)
        return response
    def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
        model_name = self.config.model or "gemini-pro"
        logger.info(f"Using Google LLM model: {model_name}")
        model = genai.GenerativeModel(model_name=model_name)

        generation_config_params = {
            "candidate_count": 1,
            "max_output_tokens": self.config.max_tokens,
            "temperature": self.config.temperature or 0.5,
        }

        if 0.0 <= self.config.top_p <= 1.0:
            generation_config_params["top_p"] = self.config.top_p
        else:
            raise ValueError("`top_p` must be between 0.0 and 1.0")

        generation_config = genai.types.GenerationConfig(**generation_config_params)

        response = model.generate_content(
            prompt,
            generation_config=generation_config,
            stream=self.config.stream,
        )
        if self.config.stream:
            # TODO: Implement streaming; for now, resolve the streamed
            # response and return the full text.
            response.resolve()
            return response.text
        else:
            return response.text
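

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how GoogleLlm might be driven, assuming embedchain's
# BaseLlmConfig accepts the fields referenced above (model, temperature,
# max_tokens, top_p, stream) and that GOOGLE_API_KEY is set in the environment.
if __name__ == "__main__":
    from embedchain.config import BaseLlmConfig

    config = BaseLlmConfig(
        model="gemini-pro",  # same default _get_answer falls back to
        temperature=0.2,
        max_tokens=512,
        top_p=0.9,  # must fall within [0.0, 1.0] or _get_answer raises
        stream=False,
    )
    llm = GoogleLlm(config=config)
    print(llm.get_llm_model_answer("What is retrieval-augmented generation?"))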