OpenSourceApp.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import logging
  2. from typing import Iterable, Union, Optional
  3. from embedchain.config import ChatConfig, OpenSourceAppConfig
  4. from embedchain.embedchain import EmbedChain
# Module-level name initialized to None.
# NOTE(review): no code in this module reads or assigns it (OpenSourceApp keeps
# its loaded model on `self.instance` instead). It may be a leftover shared-model
# cache, or it may be imported elsewhere; confirm before removing.
gpt4all_model = None
  6. class OpenSourceApp(EmbedChain):
  7. """
  8. The OpenSource app.
  9. Same as App, but uses an open source embedding model and LLM.
  10. Has two function: add and query.
  11. adds(data_type, url): adds the data from the given URL to the vector db.
  12. query(query): finds answer to the given query using vector database and LLM.
  13. """
  14. def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
  15. """
  16. :param config: OpenSourceAppConfig instance to load as configuration. Optional.
  17. `ef` defaults to open source.
  18. :param system_prompt: System prompt string. Optional.
  19. """
  20. logging.info("Loading open source embedding model. This may take some time...") # noqa:E501
  21. if not config:
  22. config = OpenSourceAppConfig()
  23. if not config.model:
  24. raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")
  25. self.instance = OpenSourceApp._get_instance(config.model)
  26. logging.info("Successfully loaded open source embedding model.")
  27. super().__init__(config, system_prompt)
  28. def get_llm_model_answer(self, prompt, config: ChatConfig):
  29. return self._get_gpt4all_answer(prompt=prompt, config=config)
  30. @staticmethod
  31. def _get_instance(model):
  32. try:
  33. from gpt4all import GPT4All
  34. except ModuleNotFoundError:
  35. raise ModuleNotFoundError(
  36. "The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`" # noqa E501
  37. ) from None
  38. return GPT4All(model)
  39. def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]:
  40. if config.model and config.model != self.config.model:
  41. raise RuntimeError(
  42. "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
  43. )
  44. if self.system_prompt or config.system_prompt:
  45. raise ValueError("OpenSourceApp does not support `system_prompt`")
  46. response = self.instance.generate(
  47. prompt=prompt,
  48. streaming=config.stream,
  49. top_p=config.top_p,
  50. max_tokens=config.max_tokens,
  51. temp=config.temperature,
  52. )
  53. return response