import logging
from typing import Any, Dict, Generator, List, Optional

from langchain.schema import BaseMessage as LCBaseMessage

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                        DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage

class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize a base LLM class

        :param config: LLM configuration option class, defaults to None
        :type config: Optional[BaseLlmConfig], optional
        """
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ChatHistory()
        self.is_docs_site_instance = False
        self.online = False
        self.history: Any = None
    def get_llm_model_answer(self):
        """
        Usually implemented by child class
        """
        raise NotImplementedError
    def set_history(self, history: Any):
        """
        Provide your own history.
        Especially useful for the `query` method, which does not manage conversation history internally.

        :param history: History to set
        :type history: Any
        """
        self.history = history
    def update_history(self, app_id: str, session_id: str = "default"):
        """Update the class history attribute with the history stored in memory (used by the chat method)."""
        chat_history = self.memory.get(app_id=app_id, session_id=session_id, num_rounds=10)
        self.set_history([str(message) for message in chat_history])
    def add_history(
        self,
        app_id: str,
        question: str,
        answer: str,
        metadata: Optional[Dict[str, Any]] = None,
        session_id: str = "default",
    ):
        """Add a question/answer round to memory and refresh the in-class history."""
        chat_message = ChatMessage()
        chat_message.add_user_message(question, metadata=metadata)
        chat_message.add_ai_message(answer, metadata=metadata)
        self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
        self.update_history(app_id=app_id, session_id=session_id)
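    # Illustrative usage (hypothetical `llm` instance of a BaseLlm subclass): record one
    # conversation round, which also refreshes `llm.history` from memory.
    #
    #   llm.add_history(app_id="my-app", question="What is Embedchain?",
    #                   answer="A framework for building RAG apps.", session_id="default")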
    def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: List of similar documents to the query used as context.
        :type contexts: List[str]
        :return: The prompt
        :rtype: str
        """
        context_string = " | ".join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)

        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
        if prompt_contains_history:
            # Prompt contains history.
            # If there is no history yet, we insert `- no history -`.
            prompt = self.config.prompt.substitute(
                context=context_string, query=input_query, history=self.history or "- no history -"
            )
        elif self.history:
            # History is present, but not included in the prompt.
            # Check if it's the default prompt without history.
            if self.config.prompt.template == DEFAULT_PROMPT:
                # Swap in the template with history.
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self.history
                )
            else:
                # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                logging.warning(
                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                )
                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
        else:
            # Basic use case, no history.
            prompt = self.config.prompt.substitute(context=context_string, query=input_query)

        return prompt
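    # A minimal sketch of the substitution that `generate_prompt` performs. The prompts on
    # `BaseLlmConfig` are `string.Template` objects, so the default template behaves roughly
    # like this (the template text below is illustrative, not the real DEFAULT_PROMPT):
    #
    #   from string import Template
    #   example = Template("Use the context to answer.\nContext: $context\nQuery: $query")
    #   example.substitute(context="doc1 | doc2", query="What is X?")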
    def _append_search_and_context(self, context: str, web_search_result: str) -> str:
        """Append web search context to existing context

        :param context: Existing context
        :type context: str
        :param web_search_result: Web search result
        :type web_search_result: str
        :return: Context string with the web search result appended
        :rtype: str
        """
        return f"{context}\nWeb Search Result: {web_search_result}"
    def get_answer_from_llm(self, prompt: str):
        """
        Gets an answer based on the given query and context by passing it
        to an LLM.

        :param prompt: The prompt to pass to the LLM.
        :type prompt: str
        :return: The answer, or a generator of chunks when streaming is enabled.
        :rtype: Any
        """
        return self.get_llm_model_answer(prompt)
    def access_search_and_get_results(self, input_query: str):
        """
        Search the internet for additional context

        :param input_query: search query
        :type input_query: str
        :return: Search results
        :rtype: str
        """
        try:
            from langchain.tools import DuckDuckGoSearchRun
        except ImportError:
            raise ImportError(
                'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
            ) from None
        search = DuckDuckGoSearchRun()
        logging.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)
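    # Note: the DuckDuckGo tool above is an optional dependency pulled in by the
    # `embedchain[dataloaders]` extra. Hypothetical call on an online-enabled instance:
    #
    #   results = llm.access_search_and_get_results("latest Embedchain release")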
    def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
        """Generator to be used as streaming response

        :param answer: Answer chunk from llm
        :type answer: Any
        :yield: Answer chunk from llm
        :rtype: Generator[Any, Any, None]
        """
        streamed_answer = ""
        for chunk in answer:
            streamed_answer += chunk
            yield chunk
        logging.info(f"Answer: {streamed_answer}")
    def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: List[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we will save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if config is not None and config.query_type == "Images":
                return contexts

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logging.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logging.info(f"Answer: {answer}")
                return answer
            else:
                # This is a streamed response and needs to be handled differently.
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
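    # Illustrative `dry_run` usage (hypothetical `llm` instance): the rendered prompt is
    # returned without ever calling the LLM, which is handy for prompt debugging.
    #
    #   rendered = llm.query("What is X?", contexts=["doc1", "doc2"], dry_run=True)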
    def chat(
        self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
    ):
        """
        Queries the vector database on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: List[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :param session_id: Session ID to use for the conversation, defaults to None
        :type session_id: str, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we will save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logging.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logging.info(f"Answer: {answer}")
                return answer
            else:
                # This is a streamed response and needs to be handled differently.
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
        """
        Construct a list of langchain messages

        :param prompt: User prompt
        :type prompt: str
        :param system_prompt: System prompt, defaults to None
        :type system_prompt: Optional[str], optional
        :return: List of messages
        :rtype: List[BaseMessage]
        """
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
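
# A minimal, illustrative sketch of how a concrete LLM plugs into this base class.
# `EchoLlm` is hypothetical and only shows the override point; real integrations
# (OpenAI, Anthropic, etc.) live in the embedchain.llm package. Assumes embedchain
# and its dependencies are installed.
if __name__ == "__main__":
    class EchoLlm(BaseLlm):
        def get_llm_model_answer(self, prompt: str) -> str:
            # A real subclass would call its provider's API here.
            return f"(echo) {prompt[:80]}"

    llm = EchoLlm()
    contexts = ["Embedchain is a framework for building RAG applications."]
    # dry_run=True returns the rendered prompt without invoking the LLM.
    print(llm.query("What is Embedchain?", contexts=contexts, dry_run=True))
    # A normal call routes the rendered prompt through get_llm_model_answer.
    print(llm.query("What is Embedchain?", contexts=contexts))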