llama2.py

import importlib
import os
from typing import Optional

from langchain_community.llms.replicate import Replicate

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class Llama2Llm(BaseLlm):
    """Llama 2 LLM wrapper that generates answers via a Replicate-hosted model."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        try:
            importlib.import_module("replicate")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for Llama2 are not installed. "
                'Please install with `pip install --upgrade "embedchain[llama2]"`'
            ) from None

        # Set default config values specific to this llm
        if not config:
            config = BaseLlmConfig()
            # Add variables to this block that have a default value in the parent class
            config.max_tokens = 500
            config.temperature = 0.75

        # Add variables that are `None` by default to this block.
        if not config.model:
            config.model = (
                "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5"
            )

        super().__init__(config=config)
        if not self.config.api_key and "REPLICATE_API_TOKEN" not in os.environ:
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable or pass it in the config.")

    def get_llm_model_answer(self, prompt):
        """Send the prompt to the configured Replicate model and return its response."""
        # TODO: Move the model and other inputs into config
        if self.config.system_prompt:
            raise ValueError("Llama2 does not support `system_prompt`")
        api_key = self.config.api_key or os.getenv("REPLICATE_API_TOKEN")
        llm = Replicate(
            model=self.config.model,
            replicate_api_token=api_key,
            input={
                "temperature": self.config.temperature,
                "max_length": self.config.max_tokens,
                "top_p": self.config.top_p,
            },
        )
        return llm.invoke(prompt)
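

# Illustrative usage sketch (not part of the original module): shows one way this
# wrapper might be instantiated and queried. It assumes the `replicate` extra is
# installed and that REPLICATE_API_TOKEN is set in the environment (or an api_key
# is passed via BaseLlmConfig); the config values and prompt text are hypothetical.
if __name__ == "__main__":
    config = BaseLlmConfig(temperature=0.5, max_tokens=256, top_p=0.9)
    llm = Llama2Llm(config=config)
    # get_llm_model_answer() forwards the prompt to the Replicate-hosted Llama 2
    # model selected above and returns the generated text.
    print(llm.get_llm_model_answer("Summarize what embedchain does in one sentence."))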