"""pipeline.py — defines the `Pipeline` class for building LLM-powered apps over unstructured data."""
  1. import threading
  2. import uuid
  3. import yaml
  4. from embedchain.config import PipelineConfig
  5. from embedchain.embedchain import EmbedChain
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.embedder.openai import OpenAIEmbedder
  8. from embedchain.factory import EmbedderFactory, VectorDBFactory
  9. from embedchain.helper.json_serializable import register_deserializable
  10. from embedchain.vectordb.base import BaseVectorDB
  11. from embedchain.vectordb.chroma import ChromaDB
  12. @register_deserializable
  13. class Pipeline(EmbedChain):
  14. """
  15. EmbedChain pipeline lets you create a LLM powered app for your unstructured
  16. data by defining a pipeline with your chosen data source, embedding model,
  17. and vector database.
  18. """
  19. def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
  20. """
  21. Initialize a new `App` instance.
  22. :param config: Configuration for the pipeline, defaults to None
  23. :type config: PipelineConfig, optional
  24. :param db: The database to use for storing and retrieving embeddings, defaults to None
  25. :type db: BaseVectorDB, optional
  26. :param embedding_model: The embedding model used to calculate embeddings, defaults to None
  27. :type embedding_model: BaseEmbedder, optional
  28. """
  29. super().__init__()
  30. self.config = config or PipelineConfig()
  31. self.name = self.config.name
  32. self.id = self.config.id or str(uuid.uuid4())
  33. self.embedding_model = embedding_model or OpenAIEmbedder()
  34. self.db = db or ChromaDB()
  35. self._initialize_db()
  36. self.user_asks = [] # legacy defaults
  37. self.s_id = self.config.id or str(uuid.uuid4())
  38. self.u_id = self._load_or_generate_user_id()
  39. thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
  40. thread_telemetry.start()
  41. def _initialize_db(self):
  42. """
  43. Initialize the database.
  44. """
  45. self.db._set_embedder(self.embedding_model)
  46. self.db._initialize()
  47. self.db.set_collection_name(self.name)
  48. def search(self, query, num_documents=3):
  49. """
  50. Search for similar documents related to the query in the vector database.
  51. """
  52. where = {"app_id": self.id}
  53. return self.db.query(
  54. query,
  55. n_results=num_documents,
  56. where=where,
  57. skip_embedding=False,
  58. )
  59. @classmethod
  60. def from_config(cls, yaml_path: str):
  61. """
  62. Instantiate a Pipeline object from a YAML configuration file.
  63. :param yaml_path: Path to the YAML configuration file.
  64. :type yaml_path: str
  65. :return: An instance of the Pipeline class.
  66. :rtype: Pipeline
  67. """
  68. with open(yaml_path, "r") as file:
  69. config_data = yaml.safe_load(file)
  70. pipeline_config_data = config_data.get("pipeline", {})
  71. db_config_data = config_data.get("vectordb", {})
  72. embedding_model_config_data = config_data.get("embedding_model", {})
  73. pipeline_config = PipelineConfig(**pipeline_config_data)
  74. db_provider = db_config_data.get("provider", "chroma")
  75. db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
  76. embedding_model_provider = embedding_model_config_data.get("provider", "openai")
  77. embedding_model = EmbedderFactory.create(
  78. embedding_model_provider, embedding_model_config_data.get("config", {})
  79. )
  80. return cls(config=pipeline_config, db=db, embedding_model=embedding_model)