import logging
from typing import List, Optional

from langchain.schema import BaseMessage

from embedchain.config import ChatConfig, CustomAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.models import Providers


class CustomApp(EmbedChain):
    """
    The custom EmbedChain app.
    Provides three methods: add, query and dry_run.

    add(data_type, url): adds the data from the given URL to the vector db.
    query(query): finds an answer to the given query using the vector database and the LLM.
    dry_run(query): tests your prompt without consuming tokens.
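
    A minimal usage sketch (illustrative only: it assumes `CustomAppConfig`
    accepts a `provider` argument, as the `config.provider` access below
    suggests, and that "web_page" is a supported data type):

        config = CustomAppConfig(provider=Providers.OPENAI)
        app = CustomApp(config=config)
        app.add("web_page", "https://example.com")
        answer = app.query("What does the page describe?")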
- """

    def __init__(self, config: CustomAppConfig = None):
        """
        :param config: `CustomAppConfig` instance to load as configuration.
            Required despite the keyword default; there is no fallback config.
        :raises ValueError: If no config is provided.
        """
        if config is None:
            raise ValueError("Config must be provided for custom app")

        self.provider = config.provider

        if config.provider == Providers.GPT4ALL:
            from embedchain import OpenSourceApp

            # GPT4All models run locally, so an OpenSourceApp instance has to
            # be up and running when the custom app is created.
            self.open_source_app = OpenSourceApp(config=config.open_source_app_config)

        super().__init__(config)

    def set_llm_model(self, provider: Providers):
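        """
        Switch the provider used to answer queries.

        :param provider: The `Providers` member to use from now on.
        :raises ValueError: If the provider is GPT4ALL, which must be chosen
            at construction time because the model is loaded locally.
        """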
        if provider == Providers.GPT4ALL:
            raise ValueError(
                "GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
            )
        # Only update the provider once the switch is known to be valid.
        self.provider = provider

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        """Dispatch the prompt to the configured LLM provider and return its answer."""
        # TODO: Quitting the streaming response here for now.
        # Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
        if config.stream:
            raise NotImplementedError(
                "Streaming responses have not been implemented for this model yet. Please disable."
            )

        try:
            if self.provider == Providers.OPENAI:
                return CustomApp._get_openai_answer(prompt, config)

            if self.provider == Providers.ANTHROPHIC:  # sic: matches the enum spelling
                return CustomApp._get_anthropic_answer(prompt, config)

            if self.provider == Providers.VERTEX_AI:
                return CustomApp._get_vertex_answer(prompt, config)

            if self.provider == Providers.GPT4ALL:
                return self.open_source_app._get_gpt4all_answer(prompt, config)

            if self.provider == Providers.AZURE_OPENAI:
                return CustomApp._get_azure_openai_answer(prompt, config)
        except ImportError as e:
            # Provider-specific langchain imports are optional dependencies;
            # surface a clean error instead of a nested traceback.
            raise ModuleNotFoundError(e.msg) from None

        # No branch matched: fail loudly rather than silently returning None.
        raise ValueError(f"Unsupported provider: {self.provider}")

    @staticmethod
    def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
        """Answer the prompt with an OpenAI chat model via langchain."""
        from langchain.chat_models import ChatOpenAI

        chat = ChatOpenAI(
            temperature=config.temperature,
            model=config.model or "gpt-3.5-turbo",
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_anthropic_answer(prompt: str, config: ChatConfig) -> str:
        """Answer the prompt with an Anthropic chat model via langchain."""
        from langchain.chat_models import ChatAnthropic

        chat = ChatAnthropic(temperature=config.temperature, model=config.model)

        # Only warn if max_tokens was changed from its default (1000).
        if config.max_tokens and config.max_tokens != 1000:
            logging.warning("Config option `max_tokens` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
        """Answer the prompt with a Google Vertex AI chat model via langchain."""
        from langchain.chat_models import ChatVertexAI

        chat = ChatVertexAI(
            temperature=config.temperature,
            model=config.model,
            max_output_tokens=config.max_tokens,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
        """Answer the prompt with an Azure OpenAI chat deployment via langchain."""
        from langchain.chat_models import AzureChatOpenAI

        if not config.deployment_name:
            raise ValueError("Deployment name must be provided for Azure OpenAI")

        chat = AzureChatOpenAI(
            deployment_name=config.deployment_name,
            openai_api_version="2023-05-15",
            model_name=config.model or "gpt-3.5-turbo",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
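        """
        Build the langchain message list for a chat call. For example,
        `_get_messages("Hi", system_prompt="Be brief")` returns
        `[SystemMessage(content="Be brief"), HumanMessage(content="Hi")]`.
        """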
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages

    def _stream_llm_model_response(self, response):
        """
        A generator that yields content chunks from an OpenAI streaming
        completions response.
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk