# utils.py (7.5 KB)
  1. import logging
  2. import os
  3. import re
  4. import string
  5. from typing import Any
  6. from embedchain.models.data_type import DataType
  7. def clean_string(text):
  8. """
  9. This function takes in a string and performs a series of text cleaning operations.
  10. Args:
  11. text (str): The text to be cleaned. This is expected to be a string.
  12. Returns:
  13. cleaned_text (str): The cleaned text after all the cleaning operations
  14. have been performed.
  15. """
  16. # Replacement of newline characters:
  17. text = text.replace("\n", " ")
  18. # Stripping and reducing multiple spaces to single:
  19. cleaned_text = re.sub(r"\s+", " ", text.strip())
  20. # Removing backslashes:
  21. cleaned_text = cleaned_text.replace("\\", "")
  22. # Replacing hash characters:
  23. cleaned_text = cleaned_text.replace("#", " ")
  24. # Eliminating consecutive non-alphanumeric characters:
  25. # This regex identifies consecutive non-alphanumeric characters (i.e., not
  26. # a word character [a-zA-Z0-9_] and not a whitespace) in the string
  27. # and replaces each group of such characters with a single occurrence of
  28. # that character.
  29. # For example, "!!! hello !!!" would become "! hello !".
  30. cleaned_text = re.sub(r"([^\w\s])\1*", r"\1", cleaned_text)
  31. return cleaned_text
  32. def is_readable(s):
  33. """
  34. Heuristic to determine if a string is "readable" (mostly contains printable characters and forms meaningful words)
  35. :param s: string
  36. :return: True if the string is more than 95% printable.
  37. """
  38. try:
  39. printable_ratio = sum(c in string.printable for c in s) / len(s)
  40. except ZeroDivisionError:
  41. logging.warning("Empty string processed as unreadable")
  42. printable_ratio = 0
  43. return printable_ratio > 0.95 # 95% of characters are printable
  44. def use_pysqlite3():
  45. """
  46. Swap std-lib sqlite3 with pysqlite3.
  47. """
  48. import platform
  49. import sqlite3
  50. if platform.system() == "Linux" and sqlite3.sqlite_version_info < (3, 35, 0):
  51. try:
  52. # According to the Chroma team, this patch only works on Linux
  53. import datetime
  54. import subprocess
  55. import sys
  56. subprocess.check_call(
  57. [sys.executable, "-m", "pip", "install", "pysqlite3-binary", "--quiet", "--disable-pip-version-check"]
  58. )
  59. __import__("pysqlite3")
  60. sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
  61. # Let the user know what happened.
  62. current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
  63. print(
  64. f"{current_time} [embedchain] [INFO]",
  65. "Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
  66. f"Your original version was {sqlite3.sqlite_version}.",
  67. )
  68. except Exception as e:
  69. # Escape all exceptions
  70. current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
  71. print(
  72. f"{current_time} [embedchain] [ERROR]",
  73. "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
  74. "Error:",
  75. e,
  76. )
  77. def format_source(source: str, limit: int = 20) -> str:
  78. """
  79. Format a string to only take the first x and last x letters.
  80. This makes it easier to display a URL, keeping familiarity while ensuring a consistent length.
  81. If the string is too short, it is not sliced.
  82. """
  83. if len(source) > 2 * limit:
  84. return source[:limit] + "..." + source[-limit:]
  85. return source
  86. def detect_datatype(source: Any) -> DataType:
  87. """
  88. Automatically detect the datatype of the given source.
  89. :param source: the source to base the detection on
  90. :return: data_type string
  91. """
  92. from urllib.parse import urlparse
  93. try:
  94. if not isinstance(source, str):
  95. raise ValueError("Source is not a string and thus cannot be a URL.")
  96. url = urlparse(source)
  97. # Check if both scheme and netloc are present. Local file system URIs are acceptable too.
  98. if not all([url.scheme, url.netloc]) and url.scheme != "file":
  99. raise ValueError("Not a valid URL.")
  100. except ValueError:
  101. url = False
  102. formatted_source = format_source(str(source), 30)
  103. if url:
  104. from langchain.document_loaders.youtube import \
  105. ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
  106. if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
  107. logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
  108. return DataType.YOUTUBE_VIDEO
  109. if url.netloc in {"notion.so", "notion.site"}:
  110. logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
  111. return DataType.NOTION
  112. if url.path.endswith(".pdf"):
  113. logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
  114. return DataType.PDF_FILE
  115. if url.path.endswith(".xml"):
  116. logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
  117. return DataType.SITEMAP
  118. if url.path.endswith(".docx"):
  119. logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
  120. return DataType.DOCX
  121. if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
  122. # `docs_site` detection via path is not accepted for local filesystem URIs,
  123. # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
  124. logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
  125. return DataType.DOCS_SITE
  126. # If none of the above conditions are met, it's a general web page
  127. logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
  128. return DataType.WEB_PAGE
  129. elif not isinstance(source, str):
  130. # For datatypes where source is not a string.
  131. if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
  132. logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
  133. return DataType.QNA_PAIR
  134. # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
  135. # We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
  136. raise TypeError(
  137. "Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`." # noqa: E501
  138. )
  139. elif os.path.isfile(source):
  140. # For datatypes that support conventional file references.
  141. # Note: checking for string is not necessary anymore.
  142. if source.endswith(".docx"):
  143. logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
  144. return DataType.DOCX
  145. # If the source is a valid file, that's not detectable as a type, an error is raised.
  146. # It does not fallback to text.
  147. raise ValueError(
  148. "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)." # noqa: E501
  149. )
  150. else:
  151. # Source is not a URL.
  152. # Use text as final fallback.
  153. logging.debug(f"Source of `{formatted_source}` detected as `text`.")
  154. return DataType.TEXT