# OpenSourceApp.py
  1. import logging
  2. from typing import Iterable, Optional, Union
  3. from embedchain.config import ChatConfig, OpenSourceAppConfig
  4. from embedchain.embedchain import EmbedChain
  5. from embedchain.helper_classes.json_serializable import register_deserializable
  6. gpt4all_model = None
  7. @register_deserializable
  8. class OpenSourceApp(EmbedChain):
  9. """
  10. The OpenSource app.
  11. Same as App, but uses an open source embedding model and LLM.
  12. Has two function: add and query.
  13. adds(data_type, url): adds the data from the given URL to the vector db.
  14. query(query): finds answer to the given query using vector database and LLM.
  15. """
  16. def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
  17. """
  18. :param config: OpenSourceAppConfig instance to load as configuration. Optional.
  19. `ef` defaults to open source.
  20. :param system_prompt: System prompt string. Optional.
  21. """
  22. logging.info("Loading open source embedding model. This may take some time...") # noqa:E501
  23. if not config:
  24. config = OpenSourceAppConfig()
  25. if not config.model:
  26. raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")
  27. self.instance = OpenSourceApp._get_instance(config.model)
  28. logging.info("Successfully loaded open source embedding model.")
  29. super().__init__(config, system_prompt)
  30. def get_llm_model_answer(self, prompt, config: ChatConfig):
  31. return self._get_gpt4all_answer(prompt=prompt, config=config)
  32. @staticmethod
  33. def _get_instance(model):
  34. try:
  35. from gpt4all import GPT4All
  36. except ModuleNotFoundError:
  37. raise ModuleNotFoundError(
  38. "The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`" # noqa E501
  39. ) from None
  40. return GPT4All(model)
  41. def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]:
  42. if config.model and config.model != self.config.model:
  43. raise RuntimeError(
  44. "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
  45. )
  46. if self.system_prompt or config.system_prompt:
  47. raise ValueError("OpenSourceApp does not support `system_prompt`")
  48. response = self.instance.generate(
  49. prompt=prompt,
  50. streaming=config.stream,
  51. top_p=config.top_p,
  52. max_tokens=config.max_tokens,
  53. temp=config.temperature,
  54. )
  55. return response