# Llama2App.py
  1. import os
  2. from langchain.llms import Replicate
  3. from embedchain.config import AppConfig
  4. from embedchain.embedchain import EmbedChain
  5. class Llama2App(EmbedChain):
  6. """
  7. The EmbedChain Llama2App class.
  8. Has two functions: add and query.
  9. adds(data_type, url): adds the data from the given URL to the vector db.
  10. query(query): finds answer to the given query using vector database and LLM.
  11. """
  12. def __init__(self, config: AppConfig = None):
  13. """
  14. :param config: AppConfig instance to load as configuration. Optional.
  15. """
  16. if "REPLICATE_API_TOKEN" not in os.environ:
  17. raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
  18. if config is None:
  19. config = AppConfig()
  20. super().__init__(config)
  21. def get_llm_model_answer(self, prompt, config: AppConfig = None):
  22. # TODO: Move the model and other inputs into config
  23. llm = Replicate(
  24. model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
  25. input={"temperature": 0.75, "max_length": 500, "top_p": 1},
  26. )
  27. return llm(prompt)