utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import json
  2. import logging
  3. import os
  4. import re
  5. import string
  6. from typing import Any
  7. from embedchain.models.data_type import DataType
  8. def clean_string(text):
  9. """
  10. This function takes in a string and performs a series of text cleaning operations.
  11. Args:
  12. text (str): The text to be cleaned. This is expected to be a string.
  13. Returns:
  14. cleaned_text (str): The cleaned text after all the cleaning operations
  15. have been performed.
  16. """
  17. # Replacement of newline characters:
  18. text = text.replace("\n", " ")
  19. # Stripping and reducing multiple spaces to single:
  20. cleaned_text = re.sub(r"\s+", " ", text.strip())
  21. # Removing backslashes:
  22. cleaned_text = cleaned_text.replace("\\", "")
  23. # Replacing hash characters:
  24. cleaned_text = cleaned_text.replace("#", " ")
  25. # Eliminating consecutive non-alphanumeric characters:
  26. # This regex identifies consecutive non-alphanumeric characters (i.e., not
  27. # a word character [a-zA-Z0-9_] and not a whitespace) in the string
  28. # and replaces each group of such characters with a single occurrence of
  29. # that character.
  30. # For example, "!!! hello !!!" would become "! hello !".
  31. cleaned_text = re.sub(r"([^\w\s])\1*", r"\1", cleaned_text)
  32. return cleaned_text
  33. def is_readable(s):
  34. """
  35. Heuristic to determine if a string is "readable" (mostly contains printable characters and forms meaningful words)
  36. :param s: string
  37. :return: True if the string is more than 95% printable.
  38. """
  39. try:
  40. printable_ratio = sum(c in string.printable for c in s) / len(s)
  41. except ZeroDivisionError:
  42. logging.warning("Empty string processed as unreadable")
  43. printable_ratio = 0
  44. return printable_ratio > 0.95 # 95% of characters are printable
  45. def use_pysqlite3():
  46. """
  47. Swap std-lib sqlite3 with pysqlite3.
  48. """
  49. import platform
  50. import sqlite3
  51. if platform.system() == "Linux" and sqlite3.sqlite_version_info < (3, 35, 0):
  52. try:
  53. # According to the Chroma team, this patch only works on Linux
  54. import datetime
  55. import subprocess
  56. import sys
  57. subprocess.check_call(
  58. [sys.executable, "-m", "pip", "install", "pysqlite3-binary", "--quiet", "--disable-pip-version-check"]
  59. )
  60. __import__("pysqlite3")
  61. sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
  62. # Let the user know what happened.
  63. current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
  64. print(
  65. f"{current_time} [embedchain] [INFO]",
  66. "Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
  67. f"Your original version was {sqlite3.sqlite_version}.",
  68. )
  69. except Exception as e:
  70. # Escape all exceptions
  71. current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
  72. print(
  73. f"{current_time} [embedchain] [ERROR]",
  74. "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
  75. "Error:",
  76. e,
  77. )
  78. def format_source(source: str, limit: int = 20) -> str:
  79. """
  80. Format a string to only take the first x and last x letters.
  81. This makes it easier to display a URL, keeping familiarity while ensuring a consistent length.
  82. If the string is too short, it is not sliced.
  83. """
  84. if len(source) > 2 * limit:
  85. return source[:limit] + "..." + source[-limit:]
  86. return source
  87. def detect_datatype(source: Any) -> DataType:
  88. """
  89. Automatically detect the datatype of the given source.
  90. :param source: the source to base the detection on
  91. :return: data_type string
  92. """
  93. from urllib.parse import urlparse
  94. import requests
  95. import yaml
  96. def is_openapi_yaml(yaml_content):
  97. # currently the following two fields are required in openapi spec yaml config
  98. return "openapi" in yaml_content and "info" in yaml_content
  99. try:
  100. if not isinstance(source, str):
  101. raise ValueError("Source is not a string and thus cannot be a URL.")
  102. url = urlparse(source)
  103. # Check if both scheme and netloc are present. Local file system URIs are acceptable too.
  104. if not all([url.scheme, url.netloc]) and url.scheme != "file":
  105. raise ValueError("Not a valid URL.")
  106. except ValueError:
  107. url = False
  108. formatted_source = format_source(str(source), 30)
  109. if url:
  110. from langchain.document_loaders.youtube import \
  111. ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
  112. if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
  113. logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
  114. return DataType.YOUTUBE_VIDEO
  115. if url.netloc in {"notion.so", "notion.site"}:
  116. logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
  117. return DataType.NOTION
  118. if url.path.endswith(".pdf"):
  119. logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
  120. return DataType.PDF_FILE
  121. if url.path.endswith(".xml"):
  122. logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
  123. return DataType.SITEMAP
  124. if url.path.endswith(".csv"):
  125. logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
  126. return DataType.CSV
  127. if url.path.endswith(".docx"):
  128. logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
  129. return DataType.DOCX
  130. if url.path.endswith(".yaml"):
  131. try:
  132. response = requests.get(source)
  133. response.raise_for_status()
  134. try:
  135. yaml_content = yaml.safe_load(response.text)
  136. except yaml.YAMLError as exc:
  137. logging.error(f"Error parsing YAML: {exc}")
  138. raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")
  139. if is_openapi_yaml(yaml_content):
  140. logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
  141. return DataType.OPENAPI
  142. else:
  143. logging.error(
  144. f"Source of `{formatted_source}` does not contain all the required \
  145. fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
  146. )
  147. raise TypeError(
  148. "Not a valid data type. Check 'https://spec.openapis.org/oas/v3.1.0', \
  149. make sure you have all the required fields in YAML config data"
  150. )
  151. except requests.exceptions.RequestException as e:
  152. logging.error(f"Error fetching URL {formatted_source}: {e}")
  153. if url.path.endswith(".json"):
  154. logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
  155. return DataType.JSON
  156. if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
  157. # `docs_site` detection via path is not accepted for local filesystem URIs,
  158. # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
  159. logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
  160. return DataType.DOCS_SITE
  161. # If none of the above conditions are met, it's a general web page
  162. logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
  163. return DataType.WEB_PAGE
  164. elif not isinstance(source, str):
  165. # For datatypes where source is not a string.
  166. if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
  167. logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
  168. return DataType.QNA_PAIR
  169. # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
  170. # We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
  171. raise TypeError(
  172. "Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`." # noqa: E501
  173. )
  174. elif os.path.isfile(source):
  175. # For datatypes that support conventional file references.
  176. # Note: checking for string is not necessary anymore.
  177. if source.endswith(".docx"):
  178. logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
  179. return DataType.DOCX
  180. if source.endswith(".csv"):
  181. logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
  182. return DataType.CSV
  183. if source.endswith(".xml"):
  184. logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
  185. return DataType.XML
  186. if source.endswith(".yaml"):
  187. with open(source, "r") as file:
  188. yaml_content = yaml.safe_load(file)
  189. if is_openapi_yaml(yaml_content):
  190. logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
  191. return DataType.OPENAPI
  192. else:
  193. logging.error(
  194. f"Source of `{formatted_source}` does not contain all the required \
  195. fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
  196. )
  197. raise ValueError(
  198. "Invalid YAML data. Check 'https://spec.openapis.org/oas/v3.1.0', \
  199. make sure to add all the required params"
  200. )
  201. if source.endswith(".json"):
  202. logging.debug(f"Source of `{formatted_source}` detected as `json`.")
  203. return DataType.JSON
  204. # If the source is a valid file, that's not detectable as a type, an error is raised.
  205. # It does not fallback to text.
  206. raise ValueError(
  207. "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)." # noqa: E501
  208. )
  209. else:
  210. # Source is not a URL.
  211. # TODO: check if source is gmail query
  212. # check if the source is valid json string
  213. if is_valid_json_string(source):
  214. logging.debug(f"Source of `{formatted_source}` detected as `json`.")
  215. return DataType.JSON
  216. # Use text as final fallback.
  217. logging.debug(f"Source of `{formatted_source}` detected as `text`.")
  218. return DataType.TEXT
  219. # check if the source is valid json string
  220. def is_valid_json_string(source: str):
  221. try:
  222. _ = json.loads(source)
  223. return True
  224. except json.JSONDecodeError:
  225. logging.error(
  226. "Insert valid string format of JSON. \
  227. Check the docs to see the supported formats - `https://docs.embedchain.ai/data-sources/json`"
  228. )
  229. return False