llama2.py
import importlib
import os
from typing import Optional

from langchain.llms import Replicate

from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

try:
    importlib.import_module("replicate")
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "The required dependencies for Llama2 are not installed. "
        'Please install with `pip install --upgrade "embedchain[llama2]"`'
    ) from None


@register_deserializable
class Llama2Llm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "REPLICATE_API_TOKEN" not in os.environ:
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")

        # Set default config values specific to this LLM.
        if not config:
            config = BaseLlmConfig()
            # Add variables to this block that have a default value in the parent class.
            config.max_tokens = 500
            config.temperature = 0.75
        # Add variables that are `None` by default to this block.
        if not config.model:
            config.model = (
                "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5"
            )

        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        # TODO: Move the model and other inputs into config
        if self.config.system_prompt:
            raise ValueError("Llama2Llm does not support `system_prompt`")
        llm = Replicate(
            model=self.config.model,
            input={
                "temperature": self.config.temperature,
                "max_length": self.config.max_tokens,
                "top_p": self.config.top_p,
            },
        )
        return llm(prompt)
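

# Minimal usage sketch (illustrative, not part of the upstream file). It
# assumes the `embedchain[llama2]` extra is installed, that the placeholder
# token below is replaced with a real Replicate API token, and that
# `BaseLlmConfig` accepts the `temperature`/`max_tokens` keyword arguments
# shown; the prompt text is an arbitrary example.
if __name__ == "__main__":
    os.environ.setdefault("REPLICATE_API_TOKEN", "<your-replicate-api-token>")

    # Passing an explicit config keeps its own temperature/max_tokens instead
    # of the 0.75/500 defaults set in __init__; `model` is still None here, so
    # the Replicate model string above is filled in automatically.
    llm = Llama2Llm(config=BaseLlmConfig(temperature=0.5, max_tokens=300))
    print(llm.get_llm_model_answer("What is a vector database?"))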