# pipeline.py

import ast
import json
import logging
import os
import sqlite3
import uuid

import requests
import yaml
from fastapi import FastAPI, HTTPException

from embedchain import Client
from embedchain.config import PipelineConfig
from embedchain.embedchain import CONFIG_DIR, EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB

SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")

@register_deserializable
class Pipeline(EmbedChain):
    """
    EmbedChain pipeline lets you create an LLM-powered app for your unstructured
    data by defining a pipeline with your chosen data source, embedding model,
    and vector database.
    """

    def __init__(
        self,
        config: PipelineConfig = None,
        db: BaseVectorDB = None,
        embedding_model: BaseEmbedder = None,
        llm: BaseLlm = None,
        yaml_path: str = None,
        log_level=logging.INFO,
    ):
        """
        Initialize a new `Pipeline` instance.

        :param config: Configuration for the pipeline, defaults to None
        :type config: PipelineConfig, optional
        :param db: The database to use for storing and retrieving embeddings, defaults to None
        :type db: BaseVectorDB, optional
        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
        :type embedding_model: BaseEmbedder, optional
        :param llm: The LLM to use for answering queries, defaults to None
        :type llm: BaseLlm, optional
        :param yaml_path: Path to a YAML config to attach to the pipeline, defaults to None
        :type yaml_path: str, optional
        """
        logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        self.logger = logging.getLogger(__name__)
        # Store the yaml config as an attribute to be able to send it
        self.yaml_config = None
        self.client = None
        if yaml_path:
            with open(yaml_path, "r") as file:
                config_data = yaml.safe_load(file)
                self.yaml_config = config_data

        self.config = config or PipelineConfig()
        self.name = self.config.name
        self.local_id = self.config.id or str(uuid.uuid4())
        self.embedding_model = embedding_model or OpenAIEmbedder()
        self.db = db or ChromaDB()
        self.llm = llm
        self._init_db()

        # Platform id; set by deploy() once the pipeline exists on the platform.
        self.id = None

        # setup user id and directory
        self.u_id = self._load_or_generate_user_id()

        # Establish a connection to the SQLite database
        self.connection = sqlite3.connect(SQLITE_PATH)
        self.cursor = self.connection.cursor()

        # Create the 'data_sources' table if it doesn't exist
        self.cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS data_sources (
                pipeline_id TEXT,
                hash TEXT,
                type TEXT,
                value TEXT,
                metadata TEXT,
                is_uploaded INTEGER DEFAULT 0,
                PRIMARY KEY (pipeline_id, hash)
            )
            """
        )
        self.connection.commit()
        self.user_asks = []  # legacy defaults

    def _init_db(self):
        """
        Initialize the database.
        """
        self.db._set_embedder(self.embedding_model)
        self.db._initialize()
        self.db.set_collection_name(self.db.config.collection_name)

    def _init_client(self):
        """
        Initialize the client.
        """
        config = Client.load_config()
        if config.get("api_key"):
            self.client = Client()
        else:
            api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n")
            self.client = Client(api_key=api_key)

    def _create_pipeline(self):
        """
        Create a pipeline on the platform.
        """
        print("Creating pipeline on the platform...")
        # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
        payload = {
            "yaml_config": json.dumps(self.yaml_config),
            "name": self.name,
            "local_id": self.local_id,
        }
        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
        r = requests.post(
            url,
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code not in [200, 201]:
            raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")
        pipeline_data = r.json()
        print(f"Pipeline created. Link: https://app.embedchain.ai/pipelines/{pipeline_data['id']}")
        return pipeline_data

    def _get_presigned_url(self, data_type, data_value):
        payload = {"data_type": data_type, "data_value": data_value}
        r = requests.post(
            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        r.raise_for_status()
        return r.json()

    def search(self, query, num_documents=3):
        """
        Search for similar documents related to the query in the vector database.
        """
        # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
        # Note: `self.deploy` is a method, so the original `self.deploy is False`
        # check could never be true; test the platform id set by deploy() instead.
        if self.id is None:
            where = {"app_id": self.local_id}
            return self.db.query(
                query,
                n_results=num_documents,
                where=where,
                skip_embedding=False,
            )
        else:
            # Make API call to the backend to get the results
            raise NotImplementedError("Search is not implemented yet for the prod mode.")
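
    # Example (a sketch): before deploy() is called, search() queries the local
    # vector store directly and returns its raw matches.
    #
    #   matches = pipeline.search("refund policy", num_documents=5)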

    def _upload_file_to_presigned_url(self, presigned_url, file_path):
        try:
            with open(file_path, "rb") as file:
                response = requests.put(presigned_url, data=file)
                response.raise_for_status()
                return response.status_code == 200
        except Exception as e:
            self.logger.exception(f"Error occurred during file upload: {str(e)}")
            return False

    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
        payload = {
            "data_type": data_type,
            "data_value": data_value,
            "metadata": metadata,
        }
        return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)

    def _send_api_request(self, endpoint, payload):
        url = f"{self.client.host}{endpoint}"
        headers = {"Authorization": f"Token {self.client.api_key}"}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response

    def _process_and_upload_data(self, data_hash, data_type, data_value):
        if os.path.isabs(data_value):
            presigned_url_data = self._get_presigned_url(data_type, data_value)
            presigned_url = presigned_url_data["presigned_url"]
            s3_key = presigned_url_data["s3_key"]
            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
                # Record the original local path before replacing data_value
                # with the presigned URL the backend should ingest from.
                metadata = {"file_path": data_value, "s3_key": s3_key}
                data_value = presigned_url
            else:
                self.logger.error(f"File upload failed for hash: {data_hash}")
                return False
        else:
            if data_type == "qna_pair":
                data_value = list(ast.literal_eval(data_value))
            metadata = {}

        try:
            self._upload_data_to_pipeline(data_type, data_value, metadata)
            self._mark_data_as_uploaded(data_hash)
            self.logger.info(f"Data of type {data_type} uploaded successfully.")
            return True
        except Exception as e:
            self.logger.error(f"Error occurred during data upload: {str(e)}")
            return False

    def _mark_data_as_uploaded(self, data_hash):
        self.cursor.execute(
            "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ? AND is_uploaded = 0",
            (data_hash, self.local_id),
        )
        self.connection.commit()
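
    # Row lifecycle (for orientation; the insert side lives in the add() path
    # inherited from EmbedChain, not in this file): data sources land in
    # data_sources with is_uploaded=0, deploy() pushes each pending row to the
    # platform, and _mark_data_as_uploaded() flips is_uploaded to 1. The
    # (pipeline_id, hash) primary key keeps re-added sources idempotent.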

    def deploy(self):
        try:
            if self.client is None:
                self._init_client()

            pipeline_data = self._create_pipeline()
            self.id = pipeline_data["id"]

            # Select the columns explicitly: with SELECT * the first column is
            # pipeline_id, so unpacking result[0] as the hash would mark the
            # wrong rows as uploaded.
            results = self.cursor.execute(
                "SELECT hash, type, value FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0",
                (self.local_id,),
            ).fetchall()
            for data_hash, data_type, data_value in results:
                if self._process_and_upload_data(data_hash, data_type, data_value):
                    self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
        except Exception as e:
            self.logger.exception(f"Error occurred during deployment: {str(e)}")
            raise HTTPException(status_code=500, detail="Error occurred during deployment.")

    @classmethod
    def from_config(cls, yaml_path: str):
        """
        Instantiate a Pipeline object from a YAML configuration file.

        :param yaml_path: Path to the YAML configuration file.
        :type yaml_path: str
        :return: An instance of the Pipeline class.
        :rtype: Pipeline
        """
        with open(yaml_path, "r") as file:
            config_data = yaml.safe_load(file)

        pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
        db_config_data = config_data.get("vectordb", {})
        embedding_model_config_data = config_data.get("embedding_model", {})

        pipeline_config = PipelineConfig(**pipeline_config_data)

        db_provider = db_config_data.get("provider", "chroma")
        db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))

        embedding_model_provider = embedding_model_config_data.get("provider", "openai")
        embedding_model = EmbedderFactory.create(
            embedding_model_provider, embedding_model_config_data.get("config", {})
        )
        return cls(
            config=pipeline_config,
            db=db,
            embedding_model=embedding_model,
            yaml_path=yaml_path,
        )
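
    # Expected YAML shape, inferred from the keys from_config() reads above
    # (values are illustrative; each lookup falls back to a default, so every
    # section is optional):
    #
    #   pipeline:
    #     config:
    #       name: my-pipeline
    #   vectordb:
    #     provider: chroma
    #     config: {}
    #   embedding_model:
    #     provider: openai
    #     config: {}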

    def start(self, host="0.0.0.0", port=8000):
        app = FastAPI()

        @app.post("/add")
        async def add_document(data_value: str, data_type: str = None):
            """
            Add a document to the pipeline.
            """
            try:
                # add() is inherited from EmbedChain and takes the source and
                # its data type directly, not a wrapping dict.
                self.add(data_value, data_type=data_type)
                return {"message": "Document added successfully"}
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e))

        @app.post("/query")
        async def query_documents(query: str, num_documents: int = 3):
            """
            Query for similar documents in the pipeline.
            """
            try:
                results = self.search(query, num_documents)
                return results
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e))

        # Imported lazily so uvicorn is only required when serving over HTTP.
        import uvicorn

        uvicorn.run(app, host=host, port=port)
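

# Minimal local smoke test (a sketch, not part of the library API): builds a
# default Pipeline and serves the /add and /query endpoints defined in start().
# Assumes OPENAI_API_KEY is set so the default OpenAIEmbedder can authenticate.
if __name__ == "__main__":
    pipeline = Pipeline()
    pipeline.start(host="127.0.0.1", port=8000)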