utils.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import logging
  2. import os
  3. import re
  4. import string
  5. from typing import Any
  6. from embedchain.models.data_type import DataType
  7. def clean_string(text):
  8. """
  9. This function takes in a string and performs a series of text cleaning operations.
  10. Args:
  11. text (str): The text to be cleaned. This is expected to be a string.
  12. Returns:
  13. cleaned_text (str): The cleaned text after all the cleaning operations
  14. have been performed.
  15. """
  16. # Replacement of newline characters:
  17. text = text.replace("\n", " ")
  18. # Stripping and reducing multiple spaces to single:
  19. cleaned_text = re.sub(r"\s+", " ", text.strip())
  20. # Removing backslashes:
  21. cleaned_text = cleaned_text.replace("\\", "")
  22. # Replacing hash characters:
  23. cleaned_text = cleaned_text.replace("#", " ")
  24. # Eliminating consecutive non-alphanumeric characters:
  25. # This regex identifies consecutive non-alphanumeric characters (i.e., not
  26. # a word character [a-zA-Z0-9_] and not a whitespace) in the string
  27. # and replaces each group of such characters with a single occurrence of
  28. # that character.
  29. # For example, "!!! hello !!!" would become "! hello !".
  30. cleaned_text = re.sub(r"([^\w\s])\1*", r"\1", cleaned_text)
  31. return cleaned_text
  32. def is_readable(s):
  33. """
  34. Heuristic to determine if a string is "readable" (mostly contains printable characters and forms meaningful words)
  35. :param s: string
  36. :return: True if the string is more than 95% printable.
  37. """
  38. try:
  39. printable_ratio = sum(c in string.printable for c in s) / len(s)
  40. except ZeroDivisionError:
  41. logging.warning("Empty string processed as unreadable")
  42. printable_ratio = 0
  43. return printable_ratio > 0.95 # 95% of characters are printable
  44. def use_pysqlite3():
  45. """
  46. Swap std-lib sqlite3 with pysqlite3.
  47. """
  48. import platform
  49. import sqlite3
  50. if platform.system() == "Linux" and sqlite3.sqlite_version_info < (3, 35, 0):
  51. try:
  52. # According to the Chroma team, this patch only works on Linux
  53. import datetime
  54. import subprocess
  55. import sys
  56. subprocess.check_call(
  57. [sys.executable, "-m", "pip", "install", "pysqlite3-binary", "--quiet", "--disable-pip-version-check"]
  58. )
  59. __import__("pysqlite3")
  60. sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
  61. # Let the user know what happened.
  62. current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
  63. print(
  64. f"{current_time} [embedchain] [INFO]",
  65. "Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
  66. f"Your original version was {sqlite3.sqlite_version}.",
  67. )
  68. except Exception as e:
  69. # Escape all exceptions
  70. current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
  71. print(
  72. f"{current_time} [embedchain] [ERROR]",
  73. "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
  74. "Error:",
  75. e,
  76. )
  77. def format_source(source: str, limit: int = 20) -> str:
  78. """
  79. Format a string to only take the first x and last x letters.
  80. This makes it easier to display a URL, keeping familiarity while ensuring a consistent length.
  81. If the string is too short, it is not sliced.
  82. """
  83. if len(source) > 2 * limit:
  84. return source[:limit] + "..." + source[-limit:]
  85. return source
  86. def detect_datatype(source: Any) -> DataType:
  87. """
  88. Automatically detect the datatype of the given source.
  89. :param source: the source to base the detection on
  90. :return: data_type string
  91. """
  92. from urllib.parse import urlparse
  93. try:
  94. if not isinstance(source, str):
  95. raise ValueError("Source is not a string and thus cannot be a URL.")
  96. url = urlparse(source)
  97. # Check if both scheme and netloc are present. Local file system URIs are acceptable too.
  98. if not all([url.scheme, url.netloc]) and url.scheme != "file":
  99. raise ValueError("Not a valid URL.")
  100. except ValueError:
  101. url = False
  102. formatted_source = format_source(str(source), 30)
  103. if url:
  104. from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
  105. if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
  106. logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
  107. return DataType.YOUTUBE_VIDEO
  108. if url.netloc in {"notion.so", "notion.site"}:
  109. logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
  110. return DataType.NOTION
  111. if url.path.endswith(".pdf"):
  112. logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
  113. return DataType.PDF_FILE
  114. if url.path.endswith(".xml"):
  115. logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
  116. return DataType.SITEMAP
  117. if url.path.endswith(".csv"):
  118. logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
  119. return DataType.CSV
  120. if url.path.endswith(".docx"):
  121. logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
  122. return DataType.DOCX
  123. if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
  124. # `docs_site` detection via path is not accepted for local filesystem URIs,
  125. # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
  126. logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
  127. return DataType.DOCS_SITE
  128. # If none of the above conditions are met, it's a general web page
  129. logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
  130. return DataType.WEB_PAGE
  131. elif not isinstance(source, str):
  132. # For datatypes where source is not a string.
  133. if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
  134. logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
  135. return DataType.QNA_PAIR
  136. # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
  137. # We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
  138. raise TypeError(
  139. "Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`." # noqa: E501
  140. )
  141. elif os.path.isfile(source):
  142. # For datatypes that support conventional file references.
  143. # Note: checking for string is not necessary anymore.
  144. if source.endswith(".docx"):
  145. logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
  146. return DataType.DOCX
  147. if source.endswith(".csv"):
  148. logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
  149. return DataType.CSV
  150. if source.endswith(".xml"):
  151. logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
  152. return DataType.XML
  153. # If the source is a valid file, that's not detectable as a type, an error is raised.
  154. # It does not fallback to text.
  155. raise ValueError(
  156. "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)." # noqa: E501
  157. )
  158. else:
  159. # Source is not a URL.
  160. # Use text as final fallback.
  161. logging.debug(f"Source of `{formatted_source}` detected as `text`.")
  162. return DataType.TEXT