misc.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. import datetime
  2. import itertools
  3. import json
  4. import logging
  5. import os
  6. import re
  7. import string
  8. from typing import Any
  9. from schema import Optional, Or, Schema
  10. from tqdm import tqdm
  11. from embedchain.models.data_type import DataType
  12. logger = logging.getLogger(__name__)
  13. def parse_content(content, type):
  14. implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
  15. if type not in implemented:
  16. raise ValueError(f"Parser type {type} not implemented. Please choose one of {implemented}")
  17. from bs4 import BeautifulSoup
  18. soup = BeautifulSoup(content, type)
  19. original_size = len(str(soup.get_text()))
  20. tags_to_exclude = [
  21. "nav",
  22. "aside",
  23. "form",
  24. "header",
  25. "noscript",
  26. "svg",
  27. "canvas",
  28. "footer",
  29. "script",
  30. "style",
  31. ]
  32. for tag in soup(tags_to_exclude):
  33. tag.decompose()
  34. ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
  35. for id in ids_to_exclude:
  36. tags = soup.find_all(id=id)
  37. for tag in tags:
  38. tag.decompose()
  39. classes_to_exclude = [
  40. "elementor-location-header",
  41. "navbar-header",
  42. "nav",
  43. "header-sidebar-wrapper",
  44. "blog-sidebar-wrapper",
  45. "related-posts",
  46. ]
  47. for class_name in classes_to_exclude:
  48. tags = soup.find_all(class_=class_name)
  49. for tag in tags:
  50. tag.decompose()
  51. content = soup.get_text()
  52. content = clean_string(content)
  53. cleaned_size = len(content)
  54. if original_size != 0:
  55. logger.info(
  56. f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501
  57. )
  58. return content
  59. def clean_string(text):
  60. """
  61. This function takes in a string and performs a series of text cleaning operations.
  62. Args:
  63. text (str): The text to be cleaned. This is expected to be a string.
  64. Returns:
  65. cleaned_text (str): The cleaned text after all the cleaning operations
  66. have been performed.
  67. """
  68. # Stripping and reducing multiple spaces to single:
  69. cleaned_text = re.sub(r"\s+", " ", text.strip())
  70. # Removing backslashes:
  71. cleaned_text = cleaned_text.replace("\\", "")
  72. # Replacing hash characters:
  73. cleaned_text = cleaned_text.replace("#", " ")
  74. # Eliminating consecutive non-alphanumeric characters:
  75. # This regex identifies consecutive non-alphanumeric characters (i.e., not
  76. # a word character [a-zA-Z0-9_] and not a whitespace) in the string
  77. # and replaces each group of such characters with a single occurrence of
  78. # that character.
  79. # For example, "!!! hello !!!" would become "! hello !".
  80. cleaned_text = re.sub(r"([^\w\s])\1*", r"\1", cleaned_text)
  81. return cleaned_text
  82. def is_readable(s):
  83. """
  84. Heuristic to determine if a string is "readable" (mostly contains printable characters and forms meaningful words)
  85. :param s: string
  86. :return: True if the string is more than 95% printable.
  87. """
  88. len_s = len(s)
  89. if len_s == 0:
  90. return False
  91. printable_chars = set(string.printable)
  92. printable_ratio = sum(c in printable_chars for c in s) / len_s
  93. return printable_ratio > 0.95 # 95% of characters are printable
  94. def use_pysqlite3():
  95. """
  96. Swap std-lib sqlite3 with pysqlite3.
  97. """
  98. import platform
  99. import sqlite3
  100. if platform.system() == "Linux" and sqlite3.sqlite_version_info < (3, 35, 0):
  101. try:
  102. # According to the Chroma team, this patch only works on Linux
  103. import datetime
  104. import subprocess
  105. import sys
  106. subprocess.check_call(
  107. [sys.executable, "-m", "pip", "install", "pysqlite3-binary", "--quiet", "--disable-pip-version-check"]
  108. )
  109. __import__("pysqlite3")
  110. sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
  111. # Let the user know what happened.
  112. current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
  113. print(
  114. f"{current_time} [embedchain] [INFO]",
  115. "Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
  116. f"Your original version was {sqlite3.sqlite_version}.",
  117. )
  118. except Exception as e:
  119. # Escape all exceptions
  120. current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
  121. print(
  122. f"{current_time} [embedchain] [ERROR]",
  123. "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
  124. "Error:",
  125. e,
  126. )
  127. def format_source(source: str, limit: int = 20) -> str:
  128. """
  129. Format a string to only take the first x and last x letters.
  130. This makes it easier to display a URL, keeping familiarity while ensuring a consistent length.
  131. If the string is too short, it is not sliced.
  132. """
  133. if len(source) > 2 * limit:
  134. return source[:limit] + "..." + source[-limit:]
  135. return source
  136. def detect_datatype(source: Any) -> DataType:
  137. """
  138. Automatically detect the datatype of the given source.
  139. :param source: the source to base the detection on
  140. :return: data_type string
  141. """
  142. from urllib.parse import urlparse
  143. import requests
  144. import yaml
  145. def is_openapi_yaml(yaml_content):
  146. # currently the following two fields are required in openapi spec yaml config
  147. return "openapi" in yaml_content and "info" in yaml_content
  148. def is_google_drive_folder(url):
  149. # checks if url is a Google Drive folder url against a regex
  150. regex = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
  151. return re.match(regex, url)
  152. try:
  153. if not isinstance(source, str):
  154. raise ValueError("Source is not a string and thus cannot be a URL.")
  155. url = urlparse(source)
  156. # Check if both scheme and netloc are present. Local file system URIs are acceptable too.
  157. if not all([url.scheme, url.netloc]) and url.scheme != "file":
  158. raise ValueError("Not a valid URL.")
  159. except ValueError:
  160. url = False
  161. formatted_source = format_source(str(source), 30)
  162. if url:
  163. YOUTUBE_ALLOWED_NETLOCKS = {
  164. "www.youtube.com",
  165. "m.youtube.com",
  166. "youtu.be",
  167. "youtube.com",
  168. "vid.plus",
  169. "www.youtube-nocookie.com",
  170. }
  171. if url.netloc in YOUTUBE_ALLOWED_NETLOCKS:
  172. logger.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
  173. return DataType.YOUTUBE_VIDEO
  174. if url.netloc in {"notion.so", "notion.site"}:
  175. logger.debug(f"Source of `{formatted_source}` detected as `notion`.")
  176. return DataType.NOTION
  177. if url.path.endswith(".pdf"):
  178. logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
  179. return DataType.PDF_FILE
  180. if url.path.endswith(".xml"):
  181. logger.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
  182. return DataType.SITEMAP
  183. if url.path.endswith(".csv"):
  184. logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
  185. return DataType.CSV
  186. if url.path.endswith(".mdx") or url.path.endswith(".md"):
  187. logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
  188. return DataType.MDX
  189. if url.path.endswith(".docx"):
  190. logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
  191. return DataType.DOCX
  192. if url.path.endswith(
  193. (".mp3", ".mp4", ".mp2", ".aac", ".wav", ".flac", ".pcm", ".m4a", ".ogg", ".opus", ".webm")
  194. ):
  195. logger.debug(f"Source of `{formatted_source}` detected as `audio`.")
  196. return DataType.AUDIO
  197. if url.path.endswith(".yaml"):
  198. try:
  199. response = requests.get(source)
  200. response.raise_for_status()
  201. try:
  202. yaml_content = yaml.safe_load(response.text)
  203. except yaml.YAMLError as exc:
  204. logger.error(f"Error parsing YAML: {exc}")
  205. raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")
  206. if is_openapi_yaml(yaml_content):
  207. logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
  208. return DataType.OPENAPI
  209. else:
  210. logger.error(
  211. f"Source of `{formatted_source}` does not contain all the required \
  212. fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
  213. )
  214. raise TypeError(
  215. "Not a valid data type. Check 'https://spec.openapis.org/oas/v3.1.0', \
  216. make sure you have all the required fields in YAML config data"
  217. )
  218. except requests.exceptions.RequestException as e:
  219. logger.error(f"Error fetching URL {formatted_source}: {e}")
  220. if url.path.endswith(".json"):
  221. logger.debug(f"Source of `{formatted_source}` detected as `json_file`.")
  222. return DataType.JSON
  223. if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
  224. # `docs_site` detection via path is not accepted for local filesystem URIs,
  225. # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
  226. logger.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
  227. return DataType.DOCS_SITE
  228. if "github.com" in url.netloc:
  229. logger.debug(f"Source of `{formatted_source}` detected as `github`.")
  230. return DataType.GITHUB
  231. if is_google_drive_folder(url.netloc + url.path):
  232. logger.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
  233. return DataType.GOOGLE_DRIVE_FOLDER
  234. # If none of the above conditions are met, it's a general web page
  235. logger.debug(f"Source of `{formatted_source}` detected as `web_page`.")
  236. return DataType.WEB_PAGE
  237. elif not isinstance(source, str):
  238. # For datatypes where source is not a string.
  239. if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
  240. logger.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
  241. return DataType.QNA_PAIR
  242. # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
  243. # We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
  244. raise TypeError(
  245. "Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`." # noqa: E501
  246. )
  247. elif os.path.isfile(source):
  248. # For datatypes that support conventional file references.
  249. # Note: checking for string is not necessary anymore.
  250. if source.endswith(".docx"):
  251. logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
  252. return DataType.DOCX
  253. if source.endswith(".csv"):
  254. logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
  255. return DataType.CSV
  256. if source.endswith(".xml"):
  257. logger.debug(f"Source of `{formatted_source}` detected as `xml`.")
  258. return DataType.XML
  259. if source.endswith(".mdx") or source.endswith(".md"):
  260. logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
  261. return DataType.MDX
  262. if source.endswith(".txt"):
  263. logger.debug(f"Source of `{formatted_source}` detected as `text`.")
  264. return DataType.TEXT_FILE
  265. if source.endswith(".pdf"):
  266. logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
  267. return DataType.PDF_FILE
  268. if source.endswith(".yaml"):
  269. with open(source, "r") as file:
  270. yaml_content = yaml.safe_load(file)
  271. if is_openapi_yaml(yaml_content):
  272. logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
  273. return DataType.OPENAPI
  274. else:
  275. logger.error(
  276. f"Source of `{formatted_source}` does not contain all the required \
  277. fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
  278. )
  279. raise ValueError(
  280. "Invalid YAML data. Check 'https://spec.openapis.org/oas/v3.1.0', \
  281. make sure to add all the required params"
  282. )
  283. if source.endswith(".json"):
  284. logger.debug(f"Source of `{formatted_source}` detected as `json`.")
  285. return DataType.JSON
  286. if os.path.exists(source) and is_readable(open(source).read()):
  287. logger.debug(f"Source of `{formatted_source}` detected as `text_file`.")
  288. return DataType.TEXT_FILE
  289. # If the source is a valid file, that's not detectable as a type, an error is raised.
  290. # It does not fall back to text.
  291. raise ValueError(
  292. "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)." # noqa: E501
  293. )
  294. else:
  295. # Source is not a URL.
  296. # TODO: check if source is gmail query
  297. # check if the source is valid json string
  298. if is_valid_json_string(source):
  299. logger.debug(f"Source of `{formatted_source}` detected as `json`.")
  300. return DataType.JSON
  301. # Use text as final fallback.
  302. logger.debug(f"Source of `{formatted_source}` detected as `text`.")
  303. return DataType.TEXT
  304. # check if the source is valid json string
  305. def is_valid_json_string(source: str):
  306. try:
  307. _ = json.loads(source)
  308. return True
  309. except json.JSONDecodeError:
  310. return False
  311. def validate_config(config_data):
  312. schema = Schema(
  313. {
  314. Optional("app"): {
  315. Optional("config"): {
  316. Optional("id"): str,
  317. Optional("name"): str,
  318. Optional("log_level"): Or("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
  319. Optional("collect_metrics"): bool,
  320. Optional("collection_name"): str,
  321. }
  322. },
  323. Optional("llm"): {
  324. Optional("provider"): Or(
  325. "openai",
  326. "azure_openai",
  327. "anthropic",
  328. "huggingface",
  329. "cohere",
  330. "together",
  331. "gpt4all",
  332. "ollama",
  333. "jina",
  334. "llama2",
  335. "vertexai",
  336. "google",
  337. "aws_bedrock",
  338. "mistralai",
  339. "clarifai",
  340. "vllm",
  341. "groq",
  342. "nvidia",
  343. ),
  344. Optional("config"): {
  345. Optional("model"): str,
  346. Optional("model_name"): str,
  347. Optional("number_documents"): int,
  348. Optional("temperature"): float,
  349. Optional("max_tokens"): int,
  350. Optional("top_p"): Or(float, int),
  351. Optional("stream"): bool,
  352. Optional("online"): bool,
  353. Optional("template"): str,
  354. Optional("prompt"): str,
  355. Optional("system_prompt"): str,
  356. Optional("deployment_name"): str,
  357. Optional("where"): dict,
  358. Optional("query_type"): str,
  359. Optional("api_key"): str,
  360. Optional("base_url"): str,
  361. Optional("endpoint"): str,
  362. Optional("model_kwargs"): dict,
  363. Optional("local"): bool,
  364. Optional("base_url"): str,
  365. Optional("default_headers"): dict,
  366. Optional("api_version"): Or(str, datetime.date),
  367. Optional("http_client_proxies"): Or(str, dict),
  368. Optional("http_async_client_proxies"): Or(str, dict),
  369. },
  370. },
  371. Optional("vectordb"): {
  372. Optional("provider"): Or(
  373. "chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz"
  374. ),
  375. Optional("config"): object, # TODO: add particular config schema for each provider
  376. },
  377. Optional("embedder"): {
  378. Optional("provider"): Or(
  379. "openai",
  380. "gpt4all",
  381. "huggingface",
  382. "vertexai",
  383. "azure_openai",
  384. "google",
  385. "mistralai",
  386. "clarifai",
  387. "nvidia",
  388. "ollama",
  389. "cohere",
  390. ),
  391. Optional("config"): {
  392. Optional("model"): Optional(str),
  393. Optional("deployment_name"): Optional(str),
  394. Optional("api_key"): str,
  395. Optional("api_base"): str,
  396. Optional("title"): str,
  397. Optional("task_type"): str,
  398. Optional("vector_dimension"): int,
  399. Optional("base_url"): str,
  400. Optional("endpoint"): str,
  401. Optional("model_kwargs"): dict,
  402. },
  403. },
  404. Optional("embedding_model"): {
  405. Optional("provider"): Or(
  406. "openai",
  407. "gpt4all",
  408. "huggingface",
  409. "vertexai",
  410. "azure_openai",
  411. "google",
  412. "mistralai",
  413. "clarifai",
  414. "nvidia",
  415. "ollama",
  416. ),
  417. Optional("config"): {
  418. Optional("model"): str,
  419. Optional("deployment_name"): str,
  420. Optional("api_key"): str,
  421. Optional("title"): str,
  422. Optional("task_type"): str,
  423. Optional("vector_dimension"): int,
  424. Optional("base_url"): str,
  425. },
  426. },
  427. Optional("chunker"): {
  428. Optional("chunk_size"): int,
  429. Optional("chunk_overlap"): int,
  430. Optional("length_function"): str,
  431. Optional("min_chunk_size"): int,
  432. },
  433. Optional("cache"): {
  434. Optional("similarity_evaluation"): {
  435. Optional("strategy"): Or("distance", "exact"),
  436. Optional("max_distance"): float,
  437. Optional("positive"): bool,
  438. },
  439. Optional("config"): {
  440. Optional("similarity_threshold"): float,
  441. Optional("auto_flush"): int,
  442. },
  443. },
  444. }
  445. )
  446. return schema.validate(config_data)
  447. def chunks(iterable, batch_size=100, desc="Processing chunks"):
  448. """A helper function to break an iterable into chunks of size batch_size."""
  449. it = iter(iterable)
  450. total_size = len(iterable)
  451. with tqdm(total=total_size, desc=desc, unit="batch") as pbar:
  452. chunk = tuple(itertools.islice(it, batch_size))
  453. while chunk:
  454. yield chunk
  455. pbar.update(len(chunk))
  456. chunk = tuple(itertools.islice(it, batch_size))