# app.py

import ast
import concurrent.futures
import json
import logging
import os
from typing import Any, Optional, Union

import requests
import yaml
from tqdm import tqdm

from mem0 import Mem0

from embedchain.cache import (
    Config,
    ExactMatchEvaluation,
    SearchDistanceEvaluation,
    cache,
    gptcache_data_manager,
    gptcache_pre_function,
)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.evaluation import EvalData, EvalMetric
from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB

logger = logging.getLogger(__name__)


@register_deserializable
class App(EmbedChain):
    """
    EmbedChain App lets you create an LLM-powered app for your unstructured
    data by defining your chosen data source, embedding model,
    and vector database.
    """

    def __init__(
        self,
        id: str = None,
        name: str = None,
        config: AppConfig = None,
        db: BaseVectorDB = None,
        embedding_model: BaseEmbedder = None,
        llm: BaseLlm = None,
        config_data: dict = None,
        auto_deploy: bool = False,
        chunker: ChunkerConfig = None,
        cache_config: CacheConfig = None,
        memory_config: Mem0Config = None,
        log_level: int = logging.WARN,
    ):
        """
        Initialize a new `App` instance.

        :param config: Configuration for the pipeline, defaults to None
        :type config: AppConfig, optional
        :param db: The database to use for storing and retrieving embeddings, defaults to None
        :type db: BaseVectorDB, optional
        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
        :type embedding_model: BaseEmbedder, optional
        :param llm: The LLM used to answer queries, defaults to None
        :type llm: BaseLlm, optional
        :param config_data: Config dictionary, defaults to None
        :type config_data: dict, optional
        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
        :type auto_deploy: bool, optional
        :raises Exception: If an error occurs while creating the pipeline
        """
        if id and config_data:
            raise Exception("Cannot provide both id and config. Please provide only one of them.")

        if id and name:
            raise Exception("Cannot provide both id and name. Please provide only one of them.")

        if name and config:
            raise Exception("Cannot provide both name and config. Please provide only one of them.")

        # Initialize the metadata db for the app
        setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
        init_db()

        self.auto_deploy = auto_deploy
        # Store the dict config as an attribute to be able to send it
        self.config_data = config_data if (config_data and validate_config(config_data)) else None
        self.client = None
        # pipeline_id from the backend
        self.id = None
        self.chunker = ChunkerConfig(**chunker) if chunker else None
        self.cache_config = cache_config
        self.memory_config = memory_config

        self.config = config or AppConfig()
        self.name = self.config.name
        self.config.id = self.local_id = "default-app-id" if self.config.id is None else self.config.id

        if id is not None:
            # Init client first since the user is trying to fetch the pipeline
            # details from the platform
            self._init_client()
            pipeline_details = self._get_pipeline(id)
            self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
            self.id = id

        if name is not None:
            self.name = name

        self.embedding_model = embedding_model or OpenAIEmbedder()
        self.db = db or ChromaDB()
        self.llm = llm or OpenAILlm()
        self._init_db()

        # Session for the metadata db
        self.db_session = get_session()

        # If cache_config is provided, initialize the cache
        if self.cache_config is not None:
            self._init_cache()

        # If memory_config is provided, initialize the mem0 client
        self.mem0_client = None
        if self.memory_config is not None:
            self.mem0_client = Mem0(api_key=self.memory_config.api_key)

        # Send anonymous telemetry
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

        self.user_asks = []
        if self.auto_deploy:
            self.deploy()

    def _init_db(self):
        """
        Initialize the database.
        """
        self.db._set_embedder(self.embedding_model)
        self.db._initialize()
        self.db.set_collection_name(self.db.config.collection_name)

    def _init_cache(self):
        if self.cache_config.similarity_eval_config.strategy == "exact":
            similarity_eval_func = ExactMatchEvaluation()
        else:
            similarity_eval_func = SearchDistanceEvaluation(
                max_distance=self.cache_config.similarity_eval_config.max_distance,
                positive=self.cache_config.similarity_eval_config.positive,
            )

        cache.init(
            pre_embedding_func=gptcache_pre_function,
            embedding_func=self.embedding_model.to_embeddings,
            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
            similarity_evaluation=similarity_eval_func,
            config=Config(**self.cache_config.init_config.as_dict()),
        )
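
    # Cache config sketch (illustrative): the keys mirror the attributes read
    # above (`similarity_eval_config.strategy`, `.max_distance`, `.positive`,
    # and `init_config`); the exact CacheConfig schema is an assumption, not
    # verified here.
    #
    #     cache_config = CacheConfig.from_config({
    #         "similarity_evaluation": {"strategy": "distance", "max_distance": 1.0},
    #         "config": {"similarity_threshold": 0.8},
    #     })
    #     app = App(cache_config=cache_config)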

    def _init_client(self):
        """
        Initialize the client.
        """
        config = Client.load_config()
        if config.get("api_key"):
            self.client = Client()
        else:
            api_key = input(
                "🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"  # noqa: E501
            )
            self.client = Client(api_key=api_key)

    def _get_pipeline(self, id):
        """
        Get an existing pipeline from the platform.
        """
        print("🛠️ Fetching pipeline details from the platform...")
        url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
        r = requests.get(
            url,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code == 404:
            raise Exception(f"❌ Pipeline with id {id} not found!")
        print(
            f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
        )
        return r.json()

    def _create_pipeline(self):
        """
        Create a pipeline on the platform.
        """
        print("🛠️ Creating pipeline on the platform...")
        # self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
        payload = {
            "yaml_config": json.dumps(self.config_data),
            "name": self.name,
            "local_id": self.local_id,
        }
        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
        r = requests.post(
            url,
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code not in [200, 201]:
            raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")

        if r.status_code == 200:
            print(
                f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
            )
        elif r.status_code == 201:
            print(
                f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
            )
        return r.json()

    def _get_presigned_url(self, data_type, data_value):
        payload = {"data_type": data_type, "data_value": data_value}
        r = requests.post(
            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        r.raise_for_status()
        return r.json()

    def _upload_file_to_presigned_url(self, presigned_url, file_path):
        try:
            with open(file_path, "rb") as file:
                response = requests.put(presigned_url, data=file)
                response.raise_for_status()
                return response.status_code == 200
        except Exception as e:
            logger.exception(f"Error occurred during file upload: {str(e)}")
            print("❌ Error occurred during file upload!")
            return False

    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
        payload = {
            "data_type": data_type,
            "data_value": data_value,
            "metadata": metadata,
        }
        try:
            self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
            # Print the local file path if the user uploaded a local file
            printed_value = metadata.get("file_path") if metadata and metadata.get("file_path") else data_value
            print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
        except Exception as e:
            print(f"❌ Error occurred during data upload for type {data_type}! Error: {str(e)}")

    def _send_api_request(self, endpoint, payload):
        url = f"{self.client.host}{endpoint}"
        headers = {"Authorization": f"Token {self.client.api_key}"}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response

    def _process_and_upload_data(self, data_hash, data_type, data_value):
        if os.path.isabs(data_value):
            presigned_url_data = self._get_presigned_url(data_type, data_value)
            presigned_url = presigned_url_data["presigned_url"]
            s3_key = presigned_url_data["s3_key"]
            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
                metadata = {"file_path": data_value, "s3_key": s3_key}
                data_value = presigned_url
            else:
                logger.error(f"File upload failed for hash: {data_hash}")
                return False
        else:
            if data_type == "qna_pair":
                data_value = list(ast.literal_eval(data_value))
            metadata = {}

        try:
            self._upload_data_to_pipeline(data_type, data_value, metadata)
            self._mark_data_as_uploaded(data_hash)
            return True
        except Exception:
            print(f"❌ Error occurred during data upload for hash {data_hash}!")
            return False

    def _mark_data_as_uploaded(self, data_hash):
        self.db_session.query(DataSource).filter_by(hash=data_hash, app_id=self.local_id).update({"is_uploaded": 1})

    def get_data_sources(self):
        data_sources = self.db_session.query(DataSource).filter_by(app_id=self.local_id).all()
        results = []
        for row in data_sources:
            results.append({"data_type": row.type, "data_value": row.value, "metadata": row.meta_data})
        return results

    def deploy(self):
        if self.client is None:
            self._init_client()

        pipeline_data = self._create_pipeline()
        self.id = pipeline_data["id"]

        results = self.db_session.query(DataSource).filter_by(app_id=self.local_id, is_uploaded=0).all()
        if len(results) > 0:
            print("🛠️ Adding data to your pipeline...")
        for result in results:
            # The DataSource model exposes `type` and `value` columns (see get_data_sources above)
            data_hash, data_type, data_value = result.hash, result.type, result.value
            self._process_and_upload_data(data_hash, data_type, data_value)

        # Send anonymous telemetry
        self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
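
    # Deployment sketch (illustrative; `add` is inherited from EmbedChain and
    # the URL is a hypothetical placeholder):
    #
    #     app = App()
    #     app.add("https://example.com/docs")  # tracked locally with is_uploaded=0
    #     app.deploy()  # creates the pipeline, then uploads pending data sources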

    @classmethod
    def from_config(
        cls,
        config_path: Optional[str] = None,
        config: Optional[dict[str, Any]] = None,
        auto_deploy: bool = False,
        yaml_path: Optional[str] = None,
    ):
        """
        Instantiate an App object from a configuration. See the config sketch
        in the comment following this method for the expected top-level keys.

        :param config_path: Path to the YAML or JSON configuration file.
        :type config_path: Optional[str]
        :param config: A dictionary containing the configuration.
        :type config: Optional[dict[str, Any]]
        :param auto_deploy: Whether to deploy the app automatically, defaults to False
        :type auto_deploy: bool, optional
        :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
        :type yaml_path: Optional[str]
        :return: An instance of the App class.
        :rtype: App
        """
        # Backward compatibility for yaml_path
        if yaml_path and not config_path:
            config_path = yaml_path

        if config_path and config:
            raise ValueError("Please provide only one of config_path or config.")

        config_data = None

        if config_path:
            file_extension = os.path.splitext(config_path)[1]
            with open(config_path, "r", encoding="UTF-8") as file:
                if file_extension in [".yaml", ".yml"]:
                    config_data = yaml.safe_load(file)
                elif file_extension == ".json":
                    config_data = json.load(file)
                else:
                    raise ValueError("config_path must be a path to a YAML or JSON file.")
        elif config and isinstance(config, dict):
            config_data = config
        else:
            logger.error(
                "No config provided. Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults.",  # noqa: E501
            )
            config_data = {}

        # Validate the config
        validate_config(config_data)

        app_config_data = config_data.get("app", {}).get("config", {})
        vector_db_config_data = config_data.get("vectordb", {})
        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
        memory_config_data = config_data.get("memory", {})
        llm_config_data = config_data.get("llm", {})
        chunker_config_data = config_data.get("chunker", {})
        cache_config_data = config_data.get("cache", None)

        app_config = AppConfig(**app_config_data)
        memory_config = Mem0Config(**memory_config_data) if memory_config_data else None

        vector_db_provider = vector_db_config_data.get("provider", "chroma")
        vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))

        if llm_config_data:
            # Initialize the metadata db for the app here, since LlmFactory needs it
            # to set up the llm memory
            setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
            init_db()
            llm_provider = llm_config_data.get("provider", "openai")
            llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
        else:
            llm = None

        embedding_model_provider = embedding_model_config_data.get("provider", "openai")
        embedding_model = EmbedderFactory.create(
            embedding_model_provider, embedding_model_config_data.get("config", {})
        )

        if cache_config_data is not None:
            cache_config = CacheConfig.from_config(cache_config_data)
        else:
            cache_config = None

        return cls(
            config=app_config,
            llm=llm,
            db=vector_db,
            embedding_model=embedding_model,
            config_data=config_data,
            auto_deploy=auto_deploy,
            chunker=chunker_config_data,
            cache_config=cache_config,
            memory_config=memory_config,
        )
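
    # A minimal YAML config sketch for from_config (illustrative): the
    # top-level keys match those read above ("app", "llm", "vectordb",
    # "embedding_model", "chunker", "cache", "memory"); the nested option
    # names are assumptions, not a complete schema.
    #
    #     app:
    #       config:
    #         name: my-app
    #     llm:
    #       provider: openai
    #     vectordb:
    #       provider: chroma
    #     embedding_model:
    #       provider: openai
    #
    # Then: app = App.from_config(config_path="config.yaml")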

    def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
        """
        Evaluate the app on a dataset for a given metric.
        """
        metric_str = metric.name if isinstance(metric, BaseMetric) else metric
        eval_class_map = {
            EvalMetric.CONTEXT_RELEVANCY.value: ContextRelevance,
            EvalMetric.ANSWER_RELEVANCY.value: AnswerRelevance,
            EvalMetric.GROUNDEDNESS.value: Groundedness,
        }

        if metric_str in eval_class_map:
            return eval_class_map[metric_str]().evaluate(dataset)

        # Handle the case for custom metrics
        if isinstance(metric, BaseMetric):
            return metric.evaluate(dataset)
        else:
            raise ValueError(f"Invalid metric: {metric}")
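
    # Custom metric sketch (illustrative): the interface is inferred from the
    # usage above — a `name` attribute and an `evaluate(dataset)` method;
    # BaseMetric's actual constructor signature is an assumption.
    #
    #     class AnswerLength(BaseMetric):
    #         def __init__(self):
    #             super().__init__(name="answer_length")
    #
    #         def evaluate(self, dataset: list[EvalData]):
    #             # Average answer length across the dataset
    #             return sum(len(d.answer) for d in dataset) / max(len(dataset), 1)
    #
    #     app.evaluate(questions=["..."], metrics=[AnswerLength()])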

    def evaluate(
        self,
        questions: Union[str, list[str]],
        metrics: Optional[list[Union[BaseMetric, str]]] = None,
        num_workers: int = 4,
    ):
        """
        Evaluate the app on one or more questions.

        :param questions: A question or a list of questions to evaluate.
        :type questions: Union[str, list[str]]
        :param metrics: A list of metrics to evaluate. Defaults to all metrics.
        :type metrics: Optional[list[Union[BaseMetric, str]]]
        :param num_workers: Number of workers to use for parallel processing.
        :type num_workers: int
        :return: A dictionary containing the evaluation results.
        :rtype: dict
        """
  403. if "OPENAI_API_KEY" not in os.environ:
  404. raise ValueError("Please set the OPENAI_API_KEY environment variable with permission to use `gpt4` model.")
  405. queries, answers, contexts = [], [], []
  406. if isinstance(questions, list):
  407. with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
  408. future_to_data = {executor.submit(self.query, q, citations=True): q for q in questions}
  409. for future in tqdm(
  410. concurrent.futures.as_completed(future_to_data),
  411. total=len(future_to_data),
  412. desc="Getting answer and contexts for questions",
  413. ):
  414. question = future_to_data[future]
  415. queries.append(question)
  416. answer, context = future.result()
  417. answers.append(answer)
  418. contexts.append(list(map(lambda x: x[0], context)))
  419. else:
  420. answer, context = self.query(questions, citations=True)
  421. queries = [questions]
  422. answers = [answer]
  423. contexts = [list(map(lambda x: x[0], context))]
  424. metrics = metrics or [
  425. EvalMetric.CONTEXT_RELEVANCY.value,
  426. EvalMetric.ANSWER_RELEVANCY.value,
  427. EvalMetric.GROUNDEDNESS.value,
  428. ]
  429. logger.info(f"Collecting data from {len(queries)} questions for evaluation...")
  430. dataset = []
  431. for q, a, c in zip(queries, answers, contexts):
  432. dataset.append(EvalData(question=q, answer=a, contexts=c))
  433. logger.info(f"Evaluating {len(dataset)} data points...")
  434. result = {}
  435. with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
  436. future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}
  437. for future in tqdm(
  438. concurrent.futures.as_completed(future_to_metric),
  439. total=len(future_to_metric),
  440. desc="Evaluating metrics",
  441. ):
  442. metric = future_to_metric[future]
  443. if isinstance(metric, BaseMetric):
  444. result[metric.name] = future.result()
  445. else:
  446. result[metric] = future.result()
  447. if self.config.collect_metrics:
  448. telemetry_props = self._telemetry_props
  449. metrics_names = []
  450. for metric in metrics:
  451. if isinstance(metric, BaseMetric):
  452. metrics_names.append(metric.name)
  453. else:
  454. metrics_names.append(metric)
  455. telemetry_props["metrics"] = metrics_names
  456. self.telemetry.capture(event_name="evaluate", properties=telemetry_props)
  457. return result
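
    # Usage sketch for evaluate (illustrative; assumes data was added first and
    # that OPENAI_API_KEY is set — the URL and question are hypothetical):
    #
    #     app = App()
    #     app.add("https://example.com/article")
    #     results = app.evaluate(questions="What is the article about?")
    #     # results maps each metric name to its score, e.g.
    #     # {"context_relevancy": ..., "answer_relevancy": ..., "groundedness": ...}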