# app.py
  1. import ast
  2. import json
  3. import logging
  4. import os
  5. import sqlite3
  6. import uuid
  7. from typing import Any, Dict, Optional
  8. import requests
  9. import yaml
  10. from embedchain.cache import (
  11. Config,
  12. ExactMatchEvaluation,
  13. SearchDistanceEvaluation,
  14. cache,
  15. gptcache_data_manager,
  16. gptcache_pre_function,
  17. )
  18. from embedchain.client import Client
  19. from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
  20. from embedchain.constants import SQLITE_PATH
  21. from embedchain.embedchain import EmbedChain
  22. from embedchain.embedder.base import BaseEmbedder
  23. from embedchain.embedder.openai import OpenAIEmbedder
  24. from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
  25. from embedchain.helpers.json_serializable import register_deserializable
  26. from embedchain.llm.base import BaseLlm
  27. from embedchain.llm.openai import OpenAILlm
  28. from embedchain.telemetry.posthog import AnonymousTelemetry
  29. from embedchain.utils.misc import validate_config
  30. from embedchain.vectordb.base import BaseVectorDB
  31. from embedchain.vectordb.chroma import ChromaDB
  32. # Set up the user directory if it doesn't exist already
  33. Client.setup_dir()
  34. @register_deserializable
  35. class App(EmbedChain):
  36. """
  37. EmbedChain App lets you create a LLM powered app for your unstructured
  38. data by defining your chosen data source, embedding model,
  39. and vector database.
  40. """
  41. def __init__(
  42. self,
  43. id: str = None,
  44. name: str = None,
  45. config: AppConfig = None,
  46. db: BaseVectorDB = None,
  47. embedding_model: BaseEmbedder = None,
  48. llm: BaseLlm = None,
  49. config_data: dict = None,
  50. log_level=logging.WARN,
  51. auto_deploy: bool = False,
  52. chunker: ChunkerConfig = None,
  53. cache_config: CacheConfig = None,
  54. ):
  55. """
  56. Initialize a new `App` instance.
  57. :param config: Configuration for the pipeline, defaults to None
  58. :type config: AppConfig, optional
  59. :param db: The database to use for storing and retrieving embeddings, defaults to None
  60. :type db: BaseVectorDB, optional
  61. :param embedding_model: The embedding model used to calculate embeddings, defaults to None
  62. :type embedding_model: BaseEmbedder, optional
  63. :param llm: The LLM model used to calculate embeddings, defaults to None
  64. :type llm: BaseLlm, optional
  65. :param config_data: Config dictionary, defaults to None
  66. :type config_data: dict, optional
  67. :param log_level: Log level to use, defaults to logging.WARN
  68. :type log_level: int, optional
  69. :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
  70. :type auto_deploy: bool, optional
  71. :raises Exception: If an error occurs while creating the pipeline
  72. """
  73. if id and config_data:
  74. raise Exception("Cannot provide both id and config. Please provide only one of them.")
  75. if id and name:
  76. raise Exception("Cannot provide both id and name. Please provide only one of them.")
  77. if name and config:
  78. raise Exception("Cannot provide both name and config. Please provide only one of them.")
  79. logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
  80. self.logger = logging.getLogger(__name__)
  81. self.auto_deploy = auto_deploy
  82. # Store the dict config as an attribute to be able to send it
  83. self.config_data = config_data if (config_data and validate_config(config_data)) else None
  84. self.client = None
  85. # pipeline_id from the backend
  86. self.id = None
  87. self.chunker = None
  88. if chunker:
  89. self.chunker = ChunkerConfig(**chunker)
  90. self.cache_config = cache_config
  91. self.config = config or AppConfig()
  92. self.name = self.config.name
  93. self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
  94. if id is not None:
  95. # Init client first since user is trying to fetch the pipeline
  96. # details from the platform
  97. self._init_client()
  98. pipeline_details = self._get_pipeline(id)
  99. self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
  100. self.id = id
  101. if name is not None:
  102. self.name = name
  103. self.embedding_model = embedding_model or OpenAIEmbedder()
  104. self.db = db or ChromaDB()
  105. self.llm = llm or OpenAILlm()
  106. self._init_db()
  107. # If cache_config is provided, initializing the cache ...
  108. if self.cache_config is not None:
  109. self._init_cache()
  110. # Send anonymous telemetry
  111. self._telemetry_props = {"class": self.__class__.__name__}
  112. self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
  113. # Establish a connection to the SQLite database
  114. self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
  115. self.cursor = self.connection.cursor()
  116. # Create the 'data_sources' table if it doesn't exist
  117. self.cursor.execute(
  118. """
  119. CREATE TABLE IF NOT EXISTS data_sources (
  120. pipeline_id TEXT,
  121. hash TEXT,
  122. type TEXT,
  123. value TEXT,
  124. metadata TEXT,
  125. is_uploaded INTEGER DEFAULT 0,
  126. PRIMARY KEY (pipeline_id, hash)
  127. )
  128. """
  129. )
  130. self.connection.commit()
  131. # Send anonymous telemetry
  132. self.telemetry.capture(event_name="init", properties=self._telemetry_props)
  133. self.user_asks = []
  134. if self.auto_deploy:
  135. self.deploy()
  136. def _init_db(self):
  137. """
  138. Initialize the database.
  139. """
  140. self.db._set_embedder(self.embedding_model)
  141. self.db._initialize()
  142. self.db.set_collection_name(self.db.config.collection_name)
  143. def _init_cache(self):
  144. if self.cache_config.similarity_eval_config.strategy == "exact":
  145. similarity_eval_func = ExactMatchEvaluation()
  146. else:
  147. similarity_eval_func = SearchDistanceEvaluation(
  148. max_distance=self.cache_config.similarity_eval_config.max_distance,
  149. positive=self.cache_config.similarity_eval_config.positive,
  150. )
  151. cache.init(
  152. pre_embedding_func=gptcache_pre_function,
  153. embedding_func=self.embedding_model.to_embeddings,
  154. data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
  155. similarity_evaluation=similarity_eval_func,
  156. config=Config(**self.cache_config.init_config.as_dict()),
  157. )
  158. def _init_client(self):
  159. """
  160. Initialize the client.
  161. """
  162. config = Client.load_config()
  163. if config.get("api_key"):
  164. self.client = Client()
  165. else:
  166. api_key = input(
  167. "🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n" # noqa: E501
  168. )
  169. self.client = Client(api_key=api_key)
  170. def _get_pipeline(self, id):
  171. """
  172. Get existing pipeline
  173. """
  174. print("🛠️ Fetching pipeline details from the platform...")
  175. url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
  176. r = requests.get(
  177. url,
  178. headers={"Authorization": f"Token {self.client.api_key}"},
  179. )
  180. if r.status_code == 404:
  181. raise Exception(f"❌ Pipeline with id {id} not found!")
  182. print(
  183. f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
  184. )
  185. return r.json()
  186. def _create_pipeline(self):
  187. """
  188. Create a pipeline on the platform.
  189. """
  190. print("🛠️ Creating pipeline on the platform...")
  191. # self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
  192. payload = {
  193. "yaml_config": json.dumps(self.config_data),
  194. "name": self.name,
  195. "local_id": self.local_id,
  196. }
  197. url = f"{self.client.host}/api/v1/pipelines/cli/create/"
  198. r = requests.post(
  199. url,
  200. json=payload,
  201. headers={"Authorization": f"Token {self.client.api_key}"},
  202. )
  203. if r.status_code not in [200, 201]:
  204. raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")
  205. if r.status_code == 200:
  206. print(
  207. f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
  208. ) # noqa: E501
  209. elif r.status_code == 201:
  210. print(
  211. f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
  212. )
  213. return r.json()
  214. def _get_presigned_url(self, data_type, data_value):
  215. payload = {"data_type": data_type, "data_value": data_value}
  216. r = requests.post(
  217. f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
  218. json=payload,
  219. headers={"Authorization": f"Token {self.client.api_key}"},
  220. )
  221. r.raise_for_status()
  222. return r.json()
  223. def search(self, query, num_documents=3):
  224. """
  225. Search for similar documents related to the query in the vector database.
  226. """
  227. # Send anonymous telemetry
  228. self.telemetry.capture(event_name="search", properties=self._telemetry_props)
  229. # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
  230. if self.id is None:
  231. where = {"app_id": self.local_id}
  232. context = self.db.query(
  233. query,
  234. n_results=num_documents,
  235. where=where,
  236. citations=True,
  237. )
  238. result = []
  239. for c in context:
  240. result.append({"context": c[0], "metadata": c[1]})
  241. return result
  242. else:
  243. # Make API call to the backend to get the results
  244. NotImplementedError("Search is not implemented yet for the prod mode.")
  245. def _upload_file_to_presigned_url(self, presigned_url, file_path):
  246. try:
  247. with open(file_path, "rb") as file:
  248. response = requests.put(presigned_url, data=file)
  249. response.raise_for_status()
  250. return response.status_code == 200
  251. except Exception as e:
  252. self.logger.exception(f"Error occurred during file upload: {str(e)}")
  253. print("❌ Error occurred during file upload!")
  254. return False
  255. def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
  256. payload = {
  257. "data_type": data_type,
  258. "data_value": data_value,
  259. "metadata": metadata,
  260. }
  261. try:
  262. self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
  263. # print the local file path if user tries to upload a local file
  264. printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
  265. print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
  266. except Exception as e:
  267. print(f"❌ Error occurred during data upload for type {data_type}!. Error: {str(e)}")
  268. def _send_api_request(self, endpoint, payload):
  269. url = f"{self.client.host}{endpoint}"
  270. headers = {"Authorization": f"Token {self.client.api_key}"}
  271. response = requests.post(url, json=payload, headers=headers)
  272. response.raise_for_status()
  273. return response
  274. def _process_and_upload_data(self, data_hash, data_type, data_value):
  275. if os.path.isabs(data_value):
  276. presigned_url_data = self._get_presigned_url(data_type, data_value)
  277. presigned_url = presigned_url_data["presigned_url"]
  278. s3_key = presigned_url_data["s3_key"]
  279. if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
  280. metadata = {"file_path": data_value, "s3_key": s3_key}
  281. data_value = presigned_url
  282. else:
  283. self.logger.error(f"File upload failed for hash: {data_hash}")
  284. return False
  285. else:
  286. if data_type == "qna_pair":
  287. data_value = list(ast.literal_eval(data_value))
  288. metadata = {}
  289. try:
  290. self._upload_data_to_pipeline(data_type, data_value, metadata)
  291. self._mark_data_as_uploaded(data_hash)
  292. return True
  293. except Exception:
  294. print(f"❌ Error occurred during data upload for hash {data_hash}!")
  295. return False
  296. def _mark_data_as_uploaded(self, data_hash):
  297. self.cursor.execute(
  298. "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
  299. (data_hash, self.local_id),
  300. )
  301. self.connection.commit()
  302. def get_data_sources(self):
  303. db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
  304. data_sources = []
  305. for data in db_data:
  306. data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
  307. return data_sources
  308. def deploy(self):
  309. if self.client is None:
  310. self._init_client()
  311. pipeline_data = self._create_pipeline()
  312. self.id = pipeline_data["id"]
  313. results = self.cursor.execute(
  314. "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,) # noqa:E501
  315. ).fetchall()
  316. if len(results) > 0:
  317. print("🛠️ Adding data to your pipeline...")
  318. for result in results:
  319. data_hash, data_type, data_value = result[1], result[2], result[3]
  320. self._process_and_upload_data(data_hash, data_type, data_value)
  321. # Send anonymous telemetry
  322. self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
  323. @classmethod
  324. def from_config(
  325. cls,
  326. config_path: Optional[str] = None,
  327. config: Optional[Dict[str, Any]] = None,
  328. auto_deploy: bool = False,
  329. yaml_path: Optional[str] = None,
  330. ):
  331. """
  332. Instantiate a Pipeline object from a configuration.
  333. :param config_path: Path to the YAML or JSON configuration file.
  334. :type config_path: Optional[str]
  335. :param config: A dictionary containing the configuration.
  336. :type config: Optional[Dict[str, Any]]
  337. :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
  338. :type auto_deploy: bool, optional
  339. :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
  340. :type yaml_path: Optional[str]
  341. :return: An instance of the Pipeline class.
  342. :rtype: Pipeline
  343. """
  344. # Backward compatibility for yaml_path
  345. if yaml_path and not config_path:
  346. config_path = yaml_path
  347. if config_path and config:
  348. raise ValueError("Please provide only one of config_path or config.")
  349. config_data = None
  350. if config_path:
  351. file_extension = os.path.splitext(config_path)[1]
  352. with open(config_path, "r") as file:
  353. if file_extension in [".yaml", ".yml"]:
  354. config_data = yaml.safe_load(file)
  355. elif file_extension == ".json":
  356. config_data = json.load(file)
  357. else:
  358. raise ValueError("config_path must be a path to a YAML or JSON file.")
  359. elif config and isinstance(config, dict):
  360. config_data = config
  361. else:
  362. logging.error(
  363. "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
  364. )
  365. config_data = {}
  366. try:
  367. validate_config(config_data)
  368. except Exception as e:
  369. raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
  370. app_config_data = config_data.get("app", {}).get("config", {})
  371. db_config_data = config_data.get("vectordb", {})
  372. embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
  373. llm_config_data = config_data.get("llm", {})
  374. chunker_config_data = config_data.get("chunker", {})
  375. cache_config_data = config_data.get("cache", None)
  376. app_config = AppConfig(**app_config_data)
  377. db_provider = db_config_data.get("provider", "chroma")
  378. db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
  379. if llm_config_data:
  380. llm_provider = llm_config_data.get("provider", "openai")
  381. llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
  382. else:
  383. llm = None
  384. embedding_model_provider = embedding_model_config_data.get("provider", "openai")
  385. embedding_model = EmbedderFactory.create(
  386. embedding_model_provider, embedding_model_config_data.get("config", {})
  387. )
  388. if cache_config_data is not None:
  389. cache_config = CacheConfig.from_config(cache_config_data)
  390. else:
  391. cache_config = None
  392. # Send anonymous telemetry
  393. event_properties = {"init_type": "config_data"}
  394. AnonymousTelemetry().capture(event_name="init", properties=event_properties)
  395. return cls(
  396. config=app_config,
  397. llm=llm,
  398. db=db,
  399. embedding_model=embedding_model,
  400. config_data=config_data,
  401. auto_deploy=auto_deploy,
  402. chunker=chunker_config_data,
  403. cache_config=cache_config,
  404. )