llama2.py

import os
from typing import Optional

from langchain.llms import Replicate

from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class Llama2Llm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "REPLICATE_API_TOKEN" not in os.environ:
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")

        # Set default config values specific to this llm
        if not config:
            config = BaseLlmConfig()
            # Add variables to this block that have a default value in the parent class
            config.max_tokens = 500
            config.temperature = 0.75
        # Add variables that are `None` by default to this block.
        if not config.model:
            config.model = (
                "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5"
            )

        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        # TODO: Move the model and other inputs into config
        if self.config.system_prompt:
            raise ValueError("Llama2App does not support `system_prompt`")
        llm = Replicate(
            model=self.config.model,
            input={
                "temperature": self.config.temperature,
                "max_length": self.config.max_tokens,
                "top_p": self.config.top_p,
            },
        )
        return llm(prompt)
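
For reference, a minimal usage sketch. It assumes this file is importable as embedchain.llm.llama2 (the module path is inferred from the imports above, not stated in the file) and that a valid Replicate token is available; the constructor raises a ValueError if REPLICATE_API_TOKEN is unset.

import os

# Assumption: the class is exposed at this module path.
from embedchain.llm.llama2 import Llama2Llm

# The token must be set before instantiation, since __init__ checks os.environ.
os.environ["REPLICATE_API_TOKEN"] = "<your-replicate-token>"

# With no config passed, the defaults above apply: max_tokens=500,
# temperature=0.75, and the pinned a16z-infra/llama13b-v2-chat model.
llm = Llama2Llm()

# get_llm_model_answer() forwards the prompt to Replicate and returns the completion text.
print(llm.get_llm_model_answer("Summarize what a vector database does in one sentence."))

To override the defaults, pass a BaseLlmConfig with the desired temperature, max_tokens, top_p, or model string; note that setting system_prompt will cause get_llm_model_answer to raise, as shown in the class above.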