test_utils.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import tempfile
  2. import unittest
  3. from unittest.mock import patch
  4. from embedchain.models.data_type import DataType
  5. from embedchain.utils import detect_datatype
  6. class TestApp(unittest.TestCase):
  7. """Test that the datatype detection is working, based on the input."""
  8. def test_detect_datatype_youtube(self):
  9. self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
  10. self.assertEqual(detect_datatype("https://m.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
  11. self.assertEqual(
  12. detect_datatype("https://www.youtube-nocookie.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO
  13. )
  14. self.assertEqual(detect_datatype("https://vid.plus/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
  15. self.assertEqual(detect_datatype("https://youtu.be/dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
  16. def test_detect_datatype_local_file(self):
  17. self.assertEqual(detect_datatype("file:///home/user/file.txt"), DataType.WEB_PAGE)
  18. def test_detect_datatype_pdf(self):
  19. self.assertEqual(detect_datatype("https://www.example.com/document.pdf"), DataType.PDF_FILE)
  20. def test_detect_datatype_local_pdf(self):
  21. self.assertEqual(detect_datatype("file:///home/user/document.pdf"), DataType.PDF_FILE)
  22. def test_detect_datatype_xml(self):
  23. self.assertEqual(detect_datatype("https://www.example.com/sitemap.xml"), DataType.SITEMAP)
  24. def test_detect_datatype_local_xml(self):
  25. self.assertEqual(detect_datatype("file:///home/user/sitemap.xml"), DataType.SITEMAP)
  26. def test_detect_datatype_docx(self):
  27. self.assertEqual(detect_datatype("https://www.example.com/document.docx"), DataType.DOCX)
  28. def test_detect_datatype_local_docx(self):
  29. self.assertEqual(detect_datatype("file:///home/user/document.docx"), DataType.DOCX)
  30. def test_detect_data_type_json(self):
  31. self.assertEqual(detect_datatype("https://www.example.com/data.json"), DataType.JSON)
  32. def test_detect_data_type_local_json(self):
  33. self.assertEqual(detect_datatype("file:///home/user/data.json"), DataType.JSON)
  34. @patch("os.path.isfile")
  35. def test_detect_datatype_regular_filesystem_docx(self, mock_isfile):
  36. with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
  37. mock_isfile.return_value = True
  38. self.assertEqual(detect_datatype(tmp.name), DataType.DOCX)
  39. def test_detect_datatype_docs_site(self):
  40. self.assertEqual(detect_datatype("https://docs.example.com"), DataType.DOCS_SITE)
  41. def test_detect_datatype_docs_sitein_path(self):
  42. self.assertEqual(detect_datatype("https://www.example.com/docs/index.html"), DataType.DOCS_SITE)
  43. self.assertNotEqual(detect_datatype("file:///var/www/docs/index.html"), DataType.DOCS_SITE) # NOT equal
  44. def test_detect_datatype_web_page(self):
  45. self.assertEqual(detect_datatype("https://nav.al/agi"), DataType.WEB_PAGE)
  46. def test_detect_datatype_invalid_url(self):
  47. self.assertEqual(detect_datatype("not a url"), DataType.TEXT)
  48. def test_detect_datatype_qna_pair(self):
  49. self.assertEqual(
  50. detect_datatype(("Question?", "Answer. Content of the string is irrelevant.")), DataType.QNA_PAIR
  51. ) #
  52. def test_detect_datatype_qna_pair_types(self):
  53. """Test that a QnA pair needs to be a tuple of length two, and both items have to be strings."""
  54. with self.assertRaises(TypeError):
  55. self.assertNotEqual(
  56. detect_datatype(("How many planets are in our solar system?", 8)), DataType.QNA_PAIR
  57. ) # NOT equal
  58. def test_detect_datatype_text(self):
  59. self.assertEqual(detect_datatype("Just some text."), DataType.TEXT)
  60. def test_detect_datatype_non_string_error(self):
  61. """Test type error if the value passed is not a string, and not a valid non-string data_type"""
  62. with self.assertRaises(TypeError):
  63. detect_datatype(["foo", "bar"])
  64. @patch("os.path.isfile")
  65. def test_detect_datatype_regular_filesystem_file_txt(self, mock_isfile):
  66. with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
  67. mock_isfile.return_value = True
  68. self.assertEqual(detect_datatype(tmp.name), DataType.TEXT_FILE)
  69. def test_detect_datatype_regular_filesystem_no_file(self):
  70. """Test that if a filepath is not actually an existing file, it is not handled as a file path."""
  71. self.assertEqual(detect_datatype("/var/not-an-existing-file.txt"), DataType.TEXT)
  72. def test_doc_examples_quickstart(self):
  73. """Test examples used in the documentation."""
  74. self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Elon_Musk"), DataType.WEB_PAGE)
  75. self.assertEqual(detect_datatype("https://www.tesla.com/elon-musk"), DataType.WEB_PAGE)
  76. def test_doc_examples_introduction(self):
  77. """Test examples used in the documentation."""
  78. self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=3qHkcs3kG44"), DataType.YOUTUBE_VIDEO)
  79. self.assertEqual(
  80. detect_datatype(
  81. "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf"
  82. ),
  83. DataType.PDF_FILE,
  84. )
  85. self.assertEqual(detect_datatype("https://nav.al/feedback"), DataType.WEB_PAGE)
  86. def test_doc_examples_app_types(self):
  87. """Test examples used in the documentation."""
  88. self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=Ff4fRgnuFgQ"), DataType.YOUTUBE_VIDEO)
  89. self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Mark_Zuckerberg"), DataType.WEB_PAGE)
  90. def test_doc_examples_configuration(self):
  91. """Test examples used in the documentation."""
  92. import subprocess
  93. import sys
  94. subprocess.check_call([sys.executable, "-m", "pip", "install", "wikipedia"])
  95. import wikipedia
  96. page = wikipedia.page("Albert Einstein")
  97. # TODO: Add a wikipedia type, so wikipedia is a dependency and we don't need this slow test.
  98. # (timings: import: 1.4s, fetch wiki: 0.7s)
  99. self.assertEqual(detect_datatype(page.content), DataType.TEXT)
  100. if __name__ == "__main__":
  101. unittest.main()