llama2.py

import importlib
import os
from typing import Optional

from langchain_community.llms.replicate import Replicate

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class Llama2Llm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        # Fail fast if the optional `replicate` dependency is missing.
        try:
            importlib.import_module("replicate")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for Llama2 are not installed. "
                'Please install with `pip install --upgrade "embedchain[llama2]"`'
            ) from None

        if "REPLICATE_API_TOKEN" not in os.environ:
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")

        # Set default config values specific to this llm
        if not config:
            config = BaseLlmConfig()
            # Add variables to this block that have a default value in the parent class
            config.max_tokens = 500
            config.temperature = 0.75

        # Add variables that are `None` by default to this block.
        if not config.model:
            config.model = (
                "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5"
            )

        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        # TODO: Move the model and other inputs into config
        if self.config.system_prompt:
            raise ValueError("Llama2 does not support `system_prompt`")

        # Replicate's Llama 2 endpoint takes generation parameters via `input`.
        llm = Replicate(
            model=self.config.model,
            input={
                "temperature": self.config.temperature,
                "max_length": self.config.max_tokens,
                "top_p": self.config.top_p,
            },
        )
        return llm.invoke(prompt)
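
For context, here is a minimal usage sketch (not part of the file above). It assumes the module lives at embedchain.llm.llama2 as the filename suggests, that the package was installed with the llama2 extra, and that BaseLlmConfig accepts the temperature, max_tokens, and top_p keyword arguments; the token value is a placeholder.

import os

from embedchain.config import BaseLlmConfig
from embedchain.llm.llama2 import Llama2Llm

# Placeholder token; use a real Replicate API token in practice.
os.environ["REPLICATE_API_TOKEN"] = "r8_..."

# Passing an explicit config skips the 500/0.75 defaults set in __init__,
# so supply the values you want here (assumed keyword arguments).
config = BaseLlmConfig(temperature=0.2, max_tokens=200, top_p=0.9)
llm = Llama2Llm(config=config)
print(llm.get_llm_model_answer("What is a llama?"))

Because config.model is left unset, __init__ falls back to the a16z-infra/llama13b-v2-chat model pinned in the file.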