import importlib
import logging
import os
from collections.abc import Generator
from typing import Any, Optional, Union

import google.generativeai as genai

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class GoogleLlm(BaseLlm):
    """LLM wrapper for Google's Generative AI (Gemini) models."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "GOOGLE_API_KEY" not in os.environ:
            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

        try:
            importlib.import_module("google.generativeai")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for GoogleLlm are not installed. "
                'Please install with `pip install --upgrade "embedchain[google]"`.'
            ) from None

        super().__init__(config)
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    def get_llm_model_answer(self, prompt):
        if self.config.system_prompt:
            raise ValueError("GoogleLlm does not support `system_prompt`")
        return self._get_answer(prompt)

    def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
        model_name = self.config.model or "gemini-pro"
        logging.info(f"Using Google LLM model: {model_name}")
        model = genai.GenerativeModel(model_name=model_name)

        generation_config_params = {
            "candidate_count": 1,
            "max_output_tokens": self.config.max_tokens,
            "temperature": self.config.temperature or 0.5,
        }
        if 0.0 <= self.config.top_p <= 1.0:
            generation_config_params["top_p"] = self.config.top_p
        else:
            raise ValueError("`top_p` must be between 0.0 and 1.0.")
        generation_config = genai.types.GenerationConfig(**generation_config_params)

        response = model.generate_content(
            prompt,
            generation_config=generation_config,
            stream=self.config.stream,
        )
        if self.config.stream:
            # TODO: Implement streaming. For now, resolve() consumes the
            # streamed chunks so that `response.text` is available.
            response.resolve()
        return response.text
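

# A minimal usage sketch (kept as a comment; assumes GOOGLE_API_KEY is set,
# the `embedchain[google]` extra is installed, and that BaseLlmConfig accepts
# the keyword arguments shown, as in embedchain's base LLM config):
#
#     config = BaseLlmConfig(model="gemini-pro", temperature=0.7, max_tokens=256, top_p=0.9)
#     llm = GoogleLlm(config=config)
#     print(llm.get_llm_model_answer("Summarize retrieval-augmented generation in one line."))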