12345678910111213141516171819202122232425262728293031323334353637383940 |
- from typing import Iterable, Optional, Union
- from langchain.callbacks.manager import CallbackManager
- from langchain.callbacks.stdout import StdOutCallbackHandler
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
- from langchain_community.llms import VLLM as BaseVLLM
- from embedchain.config import BaseLlmConfig
- from embedchain.helpers.json_serializable import register_deserializable
- from embedchain.llm.base import BaseLlm
- @register_deserializable
- class VLLM(BaseLlm):
- def __init__(self, config: Optional[BaseLlmConfig] = None):
- super().__init__(config=config)
- if self.config.model is None:
- self.config.model = "mosaicml/mpt-7b"
- def get_llm_model_answer(self, prompt):
- return self._get_answer(prompt=prompt, config=self.config)
- @staticmethod
- def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
- callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
- # Prepare the arguments for BaseVLLM
- llm_args = {
- "model": config.model,
- "temperature": config.temperature,
- "top_p": config.top_p,
- "callback_manager": CallbackManager(callback_manager),
- }
- # Add model_kwargs if they are not None
- if config.model_kwargs is not None:
- llm_args.update(config.model_kwargs)
- llm = BaseVLLM(**llm_args)
- return llm(prompt)
|