llama2.py

import os
from typing import Optional

from langchain.llms import Replicate

from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class Llama2Llm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        # The model is served via Replicate, so an API token must be set
        # before this class can be used.
        if "REPLICATE_API_TOKEN" not in os.environ:
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        # TODO: Move the model and other inputs into config
        if self.config.system_prompt:
            raise ValueError("Llama2App does not support `system_prompt`")
        # The model string pins a specific Replicate version hash of the
        # Llama 2 13B chat model for reproducible results.
        llm = Replicate(
            model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
            input={"temperature": self.config.temperature or 0.75, "max_length": 500, "top_p": self.config.top_p},
        )
        return llm(prompt)
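

# A minimal usage sketch (not part of the original file). It assumes
# REPLICATE_API_TOKEN is exported in the environment and that BaseLlmConfig
# accepts `temperature` and `top_p` keyword arguments, which this class reads
# from `self.config` above:
#
#     llm = Llama2Llm(config=BaseLlmConfig(temperature=0.5, top_p=0.9))
#     answer = llm.get_llm_model_answer("Explain retrieval-augmented generation.")
#     print(answer)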