CustomApp.py

import logging
from typing import List, Optional

from langchain.schema import BaseMessage

from embedchain.config import ChatConfig, CustomAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.models import Providers


class CustomApp(EmbedChain):
    """
    The custom EmbedChain app.
    Has three functions: add, query and dry_run.

    add(data_type, url): adds the data from the given URL to the vector db.
    query(query): finds the answer to the given query using the vector database and the LLM.
    dry_run(query): tests your prompt without consuming tokens.

    (A usage sketch appears at the end of this file.)
    """
    def __init__(self, config: Optional[CustomAppConfig] = None, system_prompt: Optional[str] = None):
        """
        :param config: `CustomAppConfig` instance to load as configuration.
        :param system_prompt: Optional. System prompt string.
        :raises ValueError: If no config is provided.
        """
        if config is None:
            raise ValueError("Config must be provided for custom app")

        self.provider = config.provider

        if config.provider == Providers.GPT4ALL:
            from embedchain import OpenSourceApp

            # Because these models run locally, an instance should already be
            # running when the custom app is created.
            self.open_source_app = OpenSourceApp(config=config.open_source_app_config)

        super().__init__(config, system_prompt)

    def set_llm_model(self, provider: Providers):
        self.provider = provider

        if provider == Providers.GPT4ALL:
            raise ValueError(
                "GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
            )
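    # Note: switching to Providers.GPT4ALL after construction is deliberately
    # rejected above; the local GPT4All model is only set up in __init__, so a
    # new CustomApp instance with a GPT4ALL config has to be created instead.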
    def get_llm_model_answer(self, prompt, config: ChatConfig):
        # TODO: Streaming responses are not handled here yet.
        # Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
        if config.stream:
            raise NotImplementedError(
                "Streaming responses have not been implemented for this model yet. Please disable."
            )

        if config.system_prompt is None and self.system_prompt is not None:
            config.system_prompt = self.system_prompt

        try:
            if self.provider == Providers.OPENAI:
                return CustomApp._get_openai_answer(prompt, config)

            if self.provider == Providers.ANTHROPHIC:  # enum spelling as used in embedchain.models
                return CustomApp._get_anthropic_answer(prompt, config)

            if self.provider == Providers.VERTEX_AI:
                return CustomApp._get_vertex_answer(prompt, config)

            if self.provider == Providers.GPT4ALL:
                return self.open_source_app._get_gpt4all_answer(prompt, config)

            if self.provider == Providers.AZURE_OPENAI:
                return CustomApp._get_azure_openai_answer(prompt, config)

        except ImportError as e:
            raise ModuleNotFoundError(e.msg) from None
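    # Note: if self.provider matches none of the branches above, this method
    # falls through and implicitly returns None. An ImportError raised inside a
    # branch (a missing provider-specific dependency) is re-raised as
    # ModuleNotFoundError.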
    @staticmethod
    def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatOpenAI

        chat = ChatOpenAI(
            temperature=config.temperature,
            model=config.model or "gpt-3.5-turbo",
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
    @staticmethod
    def _get_anthropic_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatAnthropic

        chat = ChatAnthropic(temperature=config.temperature, model=config.model)

        if config.max_tokens and config.max_tokens != 1000:
            logging.warning("Config option `max_tokens` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
    @staticmethod
    def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatVertexAI

        chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
    @staticmethod
    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import AzureChatOpenAI

        if not config.deployment_name:
            raise ValueError("Deployment name must be provided for Azure OpenAI")

        chat = AzureChatOpenAI(
            deployment_name=config.deployment_name,
            openai_api_version="2023-05-15",
            model_name=config.model or "gpt-3.5-turbo",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
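    # Illustrative example (not executed): with the helper above,
    #   CustomApp._get_messages("What is Embedchain?", system_prompt="Answer briefly.")
    # returns roughly
    #   [SystemMessage(content="Answer briefly."), HumanMessage(content="What is Embedchain?")]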
    def _stream_llm_model_response(self, response):
        """
        Generator for streaming the response from the OpenAI completions API.
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk
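# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). The exact
# CustomAppConfig constructor arguments are an assumption; check
# embedchain.config.CustomAppConfig for the real signature. The add/query
# calls follow the class docstring, with an assumed "web_page" data type.
#
#   from embedchain import CustomApp
#   from embedchain.config import CustomAppConfig
#   from embedchain.models import Providers
#
#   config = CustomAppConfig(provider=Providers.OPENAI)  # assumed kwargs
#   app = CustomApp(config=config)
#   app.add("web_page", "https://example.com/faq")       # add(data_type, url)
#   print(app.query("What does the FAQ say about refunds?"))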