CustomApp.py

import logging
from typing import List

from langchain.schema import BaseMessage

from embedchain.config import ChatConfig, CustomAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.models import Providers


class CustomApp(EmbedChain):
    """
    The custom EmbedChain app.
    Has three functions: add, query and dry_run.

    add(data_type, url): adds the data from the given URL to the vector db.
    query(query): finds the answer to the given query using the vector database and the LLM.
    dry_run(query): tests your prompt without consuming tokens.
    """

    def __init__(self, config: CustomAppConfig = None):
        """
        :param config: Optional. `CustomAppConfig` instance to load as configuration.
        :raises ValueError: Config must be provided for custom app
        """
        if config is None:
            raise ValueError("Config must be provided for custom app")

        self.provider = config.provider

        if config.provider == Providers.GPT4ALL:
            from embedchain import OpenSourceApp

            # Because these models run locally, they should have an instance
            # running when the custom app is created.
            self.open_source_app = OpenSourceApp(config=config.open_source_app_config)

        super().__init__(config)

    def set_llm_model(self, provider: Providers):
        self.provider = provider

        if provider == Providers.GPT4ALL:
            raise ValueError(
                "GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
            )
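
    # Example (sketch, not part of the original class): switching the LLM
    # provider on an already-constructed CustomApp named `app` (a hypothetical
    # instance), assuming the target provider's credentials are configured.
    # GPT4ALL cannot be switched to after construction, because the local
    # model has to be loaded when the app is created.
    #
    #   app.set_llm_model(Providers.VERTEX_AI)
    #   answer = app.query("What is on the page I added?")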

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        # TODO: Quitting the streaming response here for now.
        # Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
        if config.stream:
            raise NotImplementedError(
                "Streaming responses have not been implemented for this model yet. Please disable."
            )

        try:
            if self.provider == Providers.OPENAI:
                return CustomApp._get_openai_answer(prompt, config)

            if self.provider == Providers.ANTHROPHIC:
                return CustomApp._get_anthropic_answer(prompt, config)

            if self.provider == Providers.VERTEX_AI:
                return CustomApp._get_vertex_answer(prompt, config)

            if self.provider == Providers.GPT4ALL:
                return self.open_source_app._get_gpt4all_answer(prompt, config)

            if self.provider == Providers.AZURE_OPENAI:
                return CustomApp._get_azure_openai_answer(prompt, config)

        except ImportError as e:
            raise ImportError(e.msg) from None

    @staticmethod
    def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatOpenAI

        chat = ChatOpenAI(
            temperature=config.temperature,
            model=config.model or "gpt-3.5-turbo",
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt)

        return chat(messages).content

    @staticmethod
    def _get_anthropic_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatAnthropic

        chat = ChatAnthropic(temperature=config.temperature, model=config.model)

        if config.max_tokens and config.max_tokens != 1000:
            logging.warning("Config option `max_tokens` is not supported by this model.")

        messages = CustomApp._get_messages(prompt)

        return chat(messages).content

    @staticmethod
    def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatVertexAI

        chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt)

        return chat(messages).content

    @staticmethod
    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import AzureChatOpenAI

        if not config.deployment_name:
            raise ValueError("Deployment name must be provided for Azure OpenAI")

        chat = AzureChatOpenAI(
            deployment_name=config.deployment_name,
            openai_api_version="2023-05-15",
            model_name=config.model or "gpt-3.5-turbo",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt)

        return chat(messages).content

    @staticmethod
    def _get_messages(prompt: str) -> List[BaseMessage]:
        from langchain.schema import HumanMessage, SystemMessage

        return [SystemMessage(content="You are a helpful assistant."), HumanMessage(content=prompt)]

    def _stream_llm_model_response(self, response):
        """
        Generator that yields chunks of a streaming response from the OpenAI completions API.
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk
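

# Usage sketch (added for illustration; not part of the original module).
# It assumes CustomAppConfig accepts a `provider` keyword, inferred from the
# `config.provider` attribute read in __init__; depending on the embedchain
# version, additional arguments (e.g. an embedding function) may be required.
# `add`, `query` and `dry_run` come from the EmbedChain base class, as
# described in the class docstring; the URL below is a placeholder.
if __name__ == "__main__":
    config = CustomAppConfig(provider=Providers.OPENAI)  # assumed constructor signature
    app = CustomApp(config=config)

    # data_type, url - as documented in the class docstring
    app.add("web_page", "https://example.com/docs")

    # Answer a question over the added data using the configured provider.
    print(app.query("What does the added page say?"))

    # dry_run builds the final prompt without consuming LLM tokens.
    print(app.dry_run("What does the added page say?"))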