# Llama2App.py
  1. import os
  2. from typing import Optional
  3. from langchain.llms import Replicate
  4. from embedchain.config import AppConfig, ChatConfig
  5. from embedchain.embedchain import EmbedChain
  6. class Llama2App(EmbedChain):
  7. """
  8. The EmbedChain Llama2App class.
  9. Has two functions: add and query.
  10. adds(data_type, url): adds the data from the given URL to the vector db.
  11. query(query): finds answer to the given query using vector database and LLM.
  12. """
  13. def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
  14. """
  15. :param config: AppConfig instance to load as configuration. Optional.
  16. :param system_prompt: System prompt string. Optional.
  17. """
  18. if "REPLICATE_API_TOKEN" not in os.environ:
  19. raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
  20. if config is None:
  21. config = AppConfig()
  22. super().__init__(config, system_prompt)
  23. def get_llm_model_answer(self, prompt, config: ChatConfig = None):
  24. # TODO: Move the model and other inputs into config
  25. if self.system_prompt or config.system_prompt:
  26. raise ValueError("Llama2App does not support `system_prompt`")
  27. llm = Replicate(
  28. model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
  29. input={"temperature": 0.75, "max_length": 500, "top_p": 1},
  30. )
  31. return llm(prompt)