utils.py

import logging
import os
import re
import string
from typing import Any

from embedchain.models.data_type import DataType


def clean_string(text):
    """
    This function takes in a string and performs a series of text cleaning operations.

    Args:
        text (str): The text to be cleaned. This is expected to be a string.

    Returns:
        cleaned_text (str): The cleaned text after all the cleaning operations
        have been performed.
    """
    # Replacement of newline characters:
    text = text.replace("\n", " ")

    # Stripping and reducing multiple spaces to single:
    cleaned_text = re.sub(r"\s+", " ", text.strip())

    # Removing backslashes:
    cleaned_text = cleaned_text.replace("\\", "")

    # Replacing hash characters:
    cleaned_text = cleaned_text.replace("#", " ")

    # Eliminating consecutive non-alphanumeric characters:
    # This regex identifies consecutive non-alphanumeric characters (i.e., not
    # a word character [a-zA-Z0-9_] and not a whitespace) in the string
    # and replaces each group of such characters with a single occurrence of
    # that character.
    # For example, "!!! hello !!!" would become "! hello !".
    cleaned_text = re.sub(r"([^\w\s])\1*", r"\1", cleaned_text)

    return cleaned_text
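
# Example (doctest-style sketch) of clean_string: newlines collapse to spaces and
# repeated punctuation is reduced to a single occurrence.
# >>> clean_string("Hello\n  world!!!")
# 'Hello world!'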


def is_readable(s):
    """
    Heuristic to determine if a string is "readable" (mostly contains printable characters and forms meaningful words)

    :param s: string
    :return: True if the string is more than 95% printable.
    """
    try:
        printable_ratio = sum(c in string.printable for c in s) / len(s)
    except ZeroDivisionError:
        logging.warning("Empty string processed as unreadable")
        printable_ratio = 0
    return printable_ratio > 0.95  # 95% of characters are printable
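
# Example (doctest-style sketch) of is_readable: ordinary text passes the 95% threshold,
# while a string of control bytes does not (an empty string also returns False).
# >>> is_readable("embedchain makes retrieval simple")
# True
# >>> is_readable("\x00\x01\x02\x03")
# False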


def use_pysqlite3():
    """
    Swap std-lib sqlite3 with pysqlite3.
    """
    import platform
    import sqlite3

    if platform.system() == "Linux" and sqlite3.sqlite_version_info < (3, 35, 0):
        try:
            # According to the Chroma team, this patch only works on Linux
            import datetime
            import subprocess
            import sys

            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "pysqlite3-binary", "--quiet", "--disable-pip-version-check"]
            )

            __import__("pysqlite3")
            sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

            # Let the user know what happened.
            current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
            print(
                f"{current_time} [embedchain] [INFO]",
                "Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
                f"Your original version was {sqlite3.sqlite_version}.",
            )
        except Exception as e:
            # Escape all exceptions
            current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
            print(
                f"{current_time} [embedchain] [ERROR]",
                "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
                "Error:",
                e,
            )
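
# Example (sketch): call this before importing chromadb, so that the patched
# sqlite3 module (if the swap succeeded) is the one chromadb picks up;
# Chroma requires sqlite3 >= 3.35.0, which older Linux distros may not ship.
# use_pysqlite3()
# import chromadb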


def format_source(source: str, limit: int = 20) -> str:
    """
    Format a string to only take the first x and last x letters.
    This makes it easier to display a URL, keeping familiarity while ensuring a consistent length.
    If the string is too short, it is not sliced.
    """
    if len(source) > 2 * limit:
        return source[:limit] + "..." + source[-limit:]
    return source
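
# Example (doctest-style sketch) of format_source with the default limit of 20;
# the URL is purely illustrative.
# >>> format_source("https://example.com/some/very/long/path/to/a/page")
# 'https://example.com/.../long/path/to/a/page'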


def detect_datatype(source: Any) -> DataType:
    """
    Automatically detect the datatype of the given source.

    :param source: the source to base the detection on
    :return: data_type string
    """
    from urllib.parse import urlparse

    import requests
    import yaml

    def is_openapi_yaml(yaml_content):
        # currently the following two fields are required in openapi spec yaml config
        return "openapi" in yaml_content and "info" in yaml_content

    try:
        if not isinstance(source, str):
            raise ValueError("Source is not a string and thus cannot be a URL.")
        url = urlparse(source)
        # Check if both scheme and netloc are present. Local file system URIs are acceptable too.
        if not all([url.scheme, url.netloc]) and url.scheme != "file":
            raise ValueError("Not a valid URL.")
    except ValueError:
        url = False

    formatted_source = format_source(str(source), 30)

    if url:
        from langchain.document_loaders.youtube import \
            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS

        if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
            logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
            return DataType.YOUTUBE_VIDEO

        if url.netloc in {"notion.so", "notion.site"}:
            logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
            return DataType.NOTION

        if url.path.endswith(".pdf"):
            logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
            return DataType.PDF_FILE

        if url.path.endswith(".xml"):
            logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
            return DataType.SITEMAP

        if url.path.endswith(".csv"):
            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
            return DataType.CSV

        if url.path.endswith(".docx"):
            logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
            return DataType.DOCX

        if url.path.endswith(".yaml"):
            try:
                response = requests.get(source)
                response.raise_for_status()
                try:
                    yaml_content = yaml.safe_load(response.text)
                except yaml.YAMLError as exc:
                    logging.error(f"Error parsing YAML: {exc}")
                    raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")

                if is_openapi_yaml(yaml_content):
                    logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
                    return DataType.OPENAPI
                else:
                    logging.error(
                        f"Source of `{formatted_source}` does not contain all the required \
                        fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
                    )
                    raise TypeError(
                        "Not a valid data type. Check 'https://spec.openapis.org/oas/v3.1.0', \
                        make sure you have all the required fields in YAML config data"
                    )
            except requests.exceptions.RequestException as e:
                logging.error(f"Error fetching URL {formatted_source}: {e}")

        if url.path.endswith(".json"):
            logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
            return DataType.JSON

        if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
            # `docs_site` detection via path is not accepted for local filesystem URIs,
            # because that would mean all paths that contain `docs` are now doc sites,
            # which is too aggressive.
            logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
            return DataType.DOCS_SITE

        # If none of the above conditions are met, it's a general web page
        logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
        return DataType.WEB_PAGE
    elif not isinstance(source, str):
        # For datatypes where source is not a string.

        if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
            logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
            return DataType.QNA_PAIR

        # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
        # We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
        raise TypeError(
            "Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`."  # noqa: E501
        )

    elif os.path.isfile(source):
        # For datatypes that support conventional file references.
        # Note: checking for string is not necessary anymore.

        if source.endswith(".docx"):
            logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
            return DataType.DOCX

        if source.endswith(".csv"):
            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
            return DataType.CSV

        if source.endswith(".xml"):
            logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
            return DataType.XML

        if source.endswith(".yaml"):
            with open(source, "r") as file:
                yaml_content = yaml.safe_load(file)
                if is_openapi_yaml(yaml_content):
                    logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
                    return DataType.OPENAPI
                else:
                    logging.error(
                        f"Source of `{formatted_source}` does not contain all the required \
                        fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
                    )
                    raise ValueError(
                        "Invalid YAML data. Check 'https://spec.openapis.org/oas/v3.1.0', \
                        make sure to add all the required params"
                    )

        if source.endswith(".json"):
            logging.debug(f"Source of `{formatted_source}` detected as `json`.")
            return DataType.JSON

        # If the source is a valid file, that's not detectable as a type, an error is raised.
        # It does not fallback to text.
        raise ValueError(
            "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)."  # noqa: E501
        )

    else:
        # Source is not a URL.

        # TODO: check if source is gmail query

        # Use text as final fallback.
        logging.debug(f"Source of `{formatted_source}` detected as `text`.")
        return DataType.TEXT
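
# Example (sketch) of detect_datatype; the URLs are illustrative, and the YouTube case
# assumes the hostname appears in langchain's ALLOWED_NETLOCK set.
# detect_datatype("https://example.com/whitepaper.pdf")             -> DataType.PDF_FILE
# detect_datatype("https://www.youtube.com/watch?v=dQw4w9WgXcQ")    -> DataType.YOUTUBE_VIDEO
# detect_datatype(("What is embedchain?", "A framework for RAG."))  -> DataType.QNA_PAIR
# detect_datatype("plain text that is neither a URL nor a file")    -> DataType.TEXT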