# Llama2App.py (1.3 KB)
  1. import os
  2. from langchain.llms import Replicate
  3. from embedchain.config import AppConfig, ChatConfig
  4. from embedchain.embedchain import EmbedChain
  class Llama2App(EmbedChain):
      """
      The EmbedChain Llama2App class.

      Has two functions: add and query.

      add(data_type, url): adds the data from the given URL to the vector db.
      query(query): finds answer to the given query using vector database and LLM.
      """

      def __init__(self, config: AppConfig = None):
          """
          :param config: AppConfig instance to load as configuration. Optional.
          :raises ValueError: if the REPLICATE_API_TOKEN environment variable
              is not set.
          """
          # Fail fast with a clear message instead of erroring later, when the
          # Replicate client is first used.
          if "REPLICATE_API_TOKEN" not in os.environ:
              raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")

          if config is None:
              config = AppConfig()

          super().__init__(config)

      def get_llm_model_answer(self, prompt, config: ChatConfig = None):
          """
          Pass the prompt to the Llama 2 chat model via langchain's Replicate
          wrapper and return its answer.

          :param prompt: the prompt handed to the model.
          :param config: ChatConfig for this call. Optional. Only its
              `system_prompt` attribute is inspected here.
          :return: whatever the Replicate LLM call returns for the prompt
              (presumably the generated text -- confirm against langchain).
          :raises ValueError: if `config.system_prompt` is set, since this app
              does not support system prompts.
          """
          # TODO: Move the model and other inputs into config
          # `config` defaults to None, so guard before reading its attributes;
          # previously a call without a config raised AttributeError here.
          if config is not None and config.system_prompt:
              raise ValueError("Llama2App does not support `system_prompt`")

          llm = Replicate(
              model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
              input={"temperature": 0.75, "max_length": 500, "top_p": 1},
          )
          return llm(prompt)