pipeline.py

import ast
import json
import logging
import os
import sqlite3
import uuid

import requests
import yaml

from embedchain import Client
from embedchain.config import PipelineConfig, ChunkerConfig
from embedchain.embedchain import CONFIG_DIR, EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import validate_yaml_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB

SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")


@register_deserializable
class Pipeline(EmbedChain):
  25. """
  26. EmbedChain pipeline lets you create a LLM powered app for your unstructured
  27. data by defining a pipeline with your chosen data source, embedding model,
  28. and vector database.
  29. """

    def __init__(
        self,
        id: str = None,
        name: str = None,
        config: PipelineConfig = None,
        db: BaseVectorDB = None,
        embedding_model: BaseEmbedder = None,
        llm: BaseLlm = None,
        yaml_path: str = None,
        log_level=logging.INFO,
        auto_deploy: bool = False,
        chunker: dict = None,
    ):
        """
        Initialize a new `Pipeline` instance.

        :param id: id of an existing pipeline on the platform to load, defaults to None
        :type id: str, optional
        :param name: Name of the pipeline, defaults to None
        :type name: str, optional
        :param config: Configuration for the pipeline, defaults to None
        :type config: PipelineConfig, optional
        :param db: The database to use for storing and retrieving embeddings, defaults to None
        :type db: BaseVectorDB, optional
        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
        :type embedding_model: BaseEmbedder, optional
        :param llm: The LLM used to answer queries, defaults to None
        :type llm: BaseLlm, optional
        :param yaml_path: Path to the YAML configuration file, defaults to None
        :type yaml_path: str, optional
        :param log_level: Log level to use, defaults to logging.INFO
        :type log_level: int, optional
        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
        :type auto_deploy: bool, optional
        :param chunker: Chunker configuration (keyword arguments for ChunkerConfig), defaults to None
        :type chunker: dict, optional
        :raises Exception: If an error occurs while creating the pipeline
        """
        if id and yaml_path:
            raise Exception("Cannot provide both id and yaml_path. Please provide only one of them.")
        if id and name:
            raise Exception("Cannot provide both id and name. Please provide only one of them.")
        if name and config:
            raise Exception("Cannot provide both name and config. Please provide only one of them.")

        logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        self.logger = logging.getLogger(__name__)
        self.auto_deploy = auto_deploy
        # Store the yaml config as an attribute to be able to send it
        self.yaml_config = None
        self.client = None
        # pipeline_id from the backend
        self.id = None
        self.chunker = None
        if chunker:
            self.chunker = ChunkerConfig(**chunker)

        self.config = config or PipelineConfig()
        self.name = self.config.name
        self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id

        if yaml_path:
            with open(yaml_path, "r") as file:
                config_data = yaml.safe_load(file)
                self.yaml_config = config_data

        if id is not None:
            # Init client first since user is trying to fetch the pipeline
            # details from the platform
            self._init_client()
            pipeline_details = self._get_pipeline(id)
            self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
            self.id = id

        if name is not None:
            self.name = name

        self.embedding_model = embedding_model or OpenAIEmbedder()
        self.db = db or ChromaDB()
        self.llm = llm or OpenAILlm()
        self._init_db()

        # Send anonymous telemetry
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)

        # Establish a connection to the SQLite database
        self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
        self.cursor = self.connection.cursor()

        # Create the 'data_sources' table if it doesn't exist
        self.cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS data_sources (
                pipeline_id TEXT,
                hash TEXT,
                type TEXT,
                value TEXT,
                metadata TEXT,
                is_uploaded INTEGER DEFAULT 0,
                PRIMARY KEY (pipeline_id, hash)
            )
            """
        )
        self.connection.commit()

        # Send anonymous telemetry
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

        self.user_asks = []
        if self.auto_deploy:
            self.deploy()

    def _init_db(self):
        """
        Initialize the vector database.
        """
        self.db._set_embedder(self.embedding_model)
        self.db._initialize()
        self.db.set_collection_name(self.db.config.collection_name)

    def _init_client(self):
        """
        Initialize the platform client, prompting for an API key if none is configured.
        """
        config = Client.load_config()
        if config.get("api_key"):
            self.client = Client()
        else:
            api_key = input(
                "🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"  # noqa: E501
            )
            self.client = Client(api_key=api_key)

    def _get_pipeline(self, id):
        """
        Get details of an existing pipeline from the platform.
        """
        print("🛠️ Fetching pipeline details from the platform...")
        url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
        r = requests.get(
            url,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code == 404:
            raise Exception(f"❌ Pipeline with id {id} not found!")
        print(
            f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
        )
        return r.json()

    def _create_pipeline(self):
        """
        Create a pipeline on the platform.
        """
        print("🛠️ Creating pipeline on the platform...")
        # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
        payload = {
            "yaml_config": json.dumps(self.yaml_config),
            "name": self.name,
            "local_id": self.local_id,
        }
        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
        r = requests.post(
            url,
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code not in [200, 201]:
            raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")
        if r.status_code == 200:
            print(
                f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
            )
        elif r.status_code == 201:
            print(
                f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
            )
        return r.json()

    def _get_presigned_url(self, data_type, data_value):
        """
        Get a presigned URL from the platform for uploading a local file.
        """
        payload = {"data_type": data_type, "data_value": data_value}
        r = requests.post(
            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        r.raise_for_status()
        return r.json()

    def search(self, query, num_documents=3):
        """
        Search for similar documents related to the query in the vector database.
        """
        # Send anonymous telemetry
        self.telemetry.capture(event_name="search", properties=self._telemetry_props)
        # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
        if self.id is None:
            where = {"app_id": self.local_id}
            context = self.db.query(
                query,
                n_results=num_documents,
                where=where,
                skip_embedding=False,
                citations=True,
            )
            result = []
            for c in context:
                result.append(
                    {
                        "context": c[0],
                        "source": c[1],
                        "document_id": c[2],
                    }
                )
            return result
        else:
            # Make API call to the backend to get the results
            raise NotImplementedError("Search is not implemented yet for the prod mode.")
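
    # A local-mode usage sketch (hypothetical query string); each result is a
    # dict with "context", "source", and "document_id" keys, as built above:
    #
    #     results = pipeline.search("refund policy", num_documents=5)
    #     for item in results:
    #         print(item["source"], item["context"][:80])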

    def _upload_file_to_presigned_url(self, presigned_url, file_path):
        try:
            with open(file_path, "rb") as file:
                response = requests.put(presigned_url, data=file)
                response.raise_for_status()
                return response.status_code == 200
        except Exception as e:
            self.logger.exception(f"Error occurred during file upload: {str(e)}")
            print("❌ Error occurred during file upload!")
            return False

    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
        payload = {
            "data_type": data_type,
            "data_value": data_value,
            "metadata": metadata,
        }
        try:
            self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
            # Print the local file path if the user uploaded a local file
            printed_value = (metadata or {}).get("file_path") or data_value
            print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
        except Exception as e:
            print(f"❌ Error occurred during data upload for type {data_type}! Error: {str(e)}")

    def _send_api_request(self, endpoint, payload):
        url = f"{self.client.host}{endpoint}"
        headers = {"Authorization": f"Token {self.client.api_key}"}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response

    def _process_and_upload_data(self, data_hash, data_type, data_value):
        if os.path.isabs(data_value):
            # Local file: upload it via a presigned URL first, then send the
            # URL to the pipeline instead of the local path
            presigned_url_data = self._get_presigned_url(data_type, data_value)
            presigned_url = presigned_url_data["presigned_url"]
            s3_key = presigned_url_data["s3_key"]
            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
                metadata = {"file_path": data_value, "s3_key": s3_key}
                data_value = presigned_url
            else:
                self.logger.error(f"File upload failed for hash: {data_hash}")
                return False
        else:
            if data_type == "qna_pair":
                data_value = list(ast.literal_eval(data_value))
            metadata = {}

        try:
            self._upload_data_to_pipeline(data_type, data_value, metadata)
            self._mark_data_as_uploaded(data_hash)
            return True
        except Exception:
            print(f"❌ Error occurred during data upload for hash {data_hash}!")
            return False

    def _mark_data_as_uploaded(self, data_hash):
        self.cursor.execute(
            "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
            (data_hash, self.local_id),
        )
        self.connection.commit()

    def get_data_sources(self):
        db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
        data_sources = []
        for data in db_data:
            # Row layout: (pipeline_id, hash, type, value, metadata, is_uploaded)
            data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
        return data_sources

    def deploy(self):
        """
        Deploy the pipeline to the platform and upload any local data sources
        that have not been uploaded yet.
        """
        if self.client is None:
            self._init_client()

        pipeline_data = self._create_pipeline()
        self.id = pipeline_data["id"]

        results = self.cursor.execute(
            "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)  # noqa:E501
        ).fetchall()
        if len(results) > 0:
            print("🛠️ Adding data to your pipeline...")
        for result in results:
            data_hash, data_type, data_value = result[1], result[2], result[3]
            self._process_and_upload_data(data_hash, data_type, data_value)

        # Send anonymous telemetry
        self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
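
    # A deploy sketch (hypothetical values): _init_client() prompts for an
    # Embedchain API key if none is configured locally, then any pending rows
    # from the local data_sources table are uploaded:
    #
    #     pipeline = Pipeline(name="my-pipeline")
    #     pipeline.add("https://example.com/docs")
    #     pipeline.deploy()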

    @classmethod
    def from_config(cls, yaml_path: str, auto_deploy: bool = False):
        """
        Instantiate a Pipeline object from a YAML configuration file.

        :param yaml_path: Path to the YAML configuration file.
        :type yaml_path: str
        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
        :type auto_deploy: bool, optional
        :return: An instance of the Pipeline class.
        :rtype: Pipeline
        """
        with open(yaml_path, "r") as file:
            config_data = yaml.safe_load(file)

        try:
            validate_yaml_config(config_data)
        except Exception as e:
            raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")

        pipeline_config_data = config_data.get("app", {}).get("config", {})
        db_config_data = config_data.get("vectordb", {})
        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
        llm_config_data = config_data.get("llm", {})
        chunker_config_data = config_data.get("chunker", {})

        pipeline_config = PipelineConfig(**pipeline_config_data)

        db_provider = db_config_data.get("provider", "chroma")
        db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))

        if llm_config_data:
            llm_provider = llm_config_data.get("provider", "openai")
            llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
        else:
            llm = None

        embedding_model_provider = embedding_model_config_data.get("provider", "openai")
        embedding_model = EmbedderFactory.create(
            embedding_model_provider, embedding_model_config_data.get("config", {})
        )

        # Send anonymous telemetry
        event_properties = {"init_type": "yaml_config"}
        AnonymousTelemetry().capture(event_name="init", properties=event_properties)

        return cls(
            config=pipeline_config,
            llm=llm,
            db=db,
            embedding_model=embedding_model,
            yaml_path=yaml_path,
            auto_deploy=auto_deploy,
            chunker=chunker_config_data,
        )
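
# A YAML sketch accepted by Pipeline.from_config. The top-level keys mirror the
# config_data.get(...) lookups above; the options inside each "config" mapping
# (and the chunker keys shown) are assumptions that depend on the installed
# embedchain version:
#
#     app:
#       config:
#         name: my-pipeline
#     llm:
#       provider: openai
#     vectordb:
#       provider: chroma
#     embedding_model:
#       provider: openai
#     chunker:
#       chunk_size: 500
#       chunk_overlap: 50
#
# followed by:
#
#     pipeline = Pipeline.from_config("config.yaml", auto_deploy=False)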