CustomApp.py

import logging
from typing import List, Optional

from langchain.schema import BaseMessage

from embedchain.config import ChatConfig, CustomAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models import Providers


@register_deserializable
class CustomApp(EmbedChain):
    """
    The custom EmbedChain app.
    Has three main methods: add, query and dry_run.

    add(data_type, url): adds the data from the given URL to the vector db.
    query(query): finds the answer to the given query using the vector database and the LLM.
    dry_run(query): tests your prompt without consuming tokens.
    """

    def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
        """
        :param config: Optional. `CustomAppConfig` instance to load as configuration.
        :param system_prompt: Optional. System prompt string.
        :raises ValueError: Config must be provided for custom app.
        """
        if config is None:
            raise ValueError("Config must be provided for custom app")

        self.provider = config.provider

        if config.provider == Providers.GPT4ALL:
            from embedchain import OpenSourceApp

            # Because these models run locally, they should have an instance
            # running when the custom app is created.
            self.open_source_app = OpenSourceApp(config=config.open_source_app_config)

        super().__init__(config, system_prompt)

    def set_llm_model(self, provider: Providers):
        self.provider = provider

        if provider == Providers.GPT4ALL:
            raise ValueError(
                "GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
            )
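
    # Example (illustrative): switching an already-created app to another hosted
    # provider at runtime, assuming the matching API key is configured:
    #
    #   app.set_llm_model(Providers.ANTHROPHIC)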

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        # TODO: Quitting the streaming response here for now.
        # Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
        if config.stream:
            raise NotImplementedError(
                "Streaming responses have not been implemented for this model yet. Please disable streaming."
            )

        if config.system_prompt is None and self.system_prompt is not None:
            config.system_prompt = self.system_prompt

        try:
            if self.provider == Providers.OPENAI:
                return CustomApp._get_openai_answer(prompt, config)

            if self.provider == Providers.ANTHROPHIC:
                return CustomApp._get_anthropic_answer(prompt, config)

            if self.provider == Providers.VERTEX_AI:
                return CustomApp._get_vertex_answer(prompt, config)

            if self.provider == Providers.GPT4ALL:
                return self.open_source_app._get_gpt4all_answer(prompt, config)

            if self.provider == Providers.AZURE_OPENAI:
                return CustomApp._get_azure_openai_answer(prompt, config)

        except ImportError as e:
            raise ModuleNotFoundError(e.msg) from None
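
    # Note on config handling in the helpers below: the OpenAI, Vertex AI and
    # Azure OpenAI paths log a warning and ignore `top_p`, and the Anthropic path
    # logs a warning and ignores `max_tokens`, because those options are not
    # passed through to the underlying langchain chat models.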

    @staticmethod
    def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatOpenAI

        chat = ChatOpenAI(
            temperature=config.temperature,
            model=config.model or "gpt-3.5-turbo",
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_anthropic_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatAnthropic

        chat = ChatAnthropic(temperature=config.temperature, model=config.model)

        if config.max_tokens and config.max_tokens != 1000:
            logging.warning("Config option `max_tokens` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatVertexAI

        chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import AzureChatOpenAI

        if not config.deployment_name:
            raise ValueError("Deployment name must be provided for Azure OpenAI")

        chat = AzureChatOpenAI(
            deployment_name=config.deployment_name,
            openai_api_version="2023-05-15",
            model_name=config.model or "gpt-3.5-turbo",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
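
    # For reference, CustomApp._get_messages("Hello", system_prompt="Be concise")
    # returns [SystemMessage(content="Be concise"), HumanMessage(content="Hello")].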

    def _stream_llm_model_response(self, response):
        """
        Generator that yields the content chunks of a streaming OpenAI
        completions API response.
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk
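

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions not shown in this file: `CustomAppConfig`
    # accepts the provider via a `provider` keyword argument (it may require
    # further settings, e.g. an embedding function), and the matching API key
    # (e.g. OPENAI_API_KEY) is set in the environment.
    app = CustomApp(config=CustomAppConfig(provider=Providers.OPENAI))
    app.add("web_page", "https://example.com")  # add(data_type, url), per the class docstring
    print(app.query("What is this page about?"))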