test_utils.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import tempfile
  2. import unittest
  3. from unittest.mock import patch
  4. from embedchain.models.data_type import DataType
  5. from embedchain.utils import detect_datatype
  6. class TestApp(unittest.TestCase):
  7. """Test that the datatype detection is working, based on the input."""
  8. def test_detect_datatype_youtube(self):
  9. self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
  10. self.assertEqual(detect_datatype("https://m.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
  11. self.assertEqual(
  12. detect_datatype("https://www.youtube-nocookie.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO
  13. )
  14. self.assertEqual(detect_datatype("https://vid.plus/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
  15. self.assertEqual(detect_datatype("https://youtu.be/dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
  16. def test_detect_datatype_local_file(self):
  17. self.assertEqual(detect_datatype("file:///home/user/file.txt"), DataType.WEB_PAGE)
  18. def test_detect_datatype_pdf(self):
  19. self.assertEqual(detect_datatype("https://www.example.com/document.pdf"), DataType.PDF_FILE)
  20. def test_detect_datatype_local_pdf(self):
  21. self.assertEqual(detect_datatype("file:///home/user/document.pdf"), DataType.PDF_FILE)
  22. def test_detect_datatype_xml(self):
  23. self.assertEqual(detect_datatype("https://www.example.com/sitemap.xml"), DataType.SITEMAP)
  24. def test_detect_datatype_local_xml(self):
  25. self.assertEqual(detect_datatype("file:///home/user/sitemap.xml"), DataType.SITEMAP)
  26. def test_detect_datatype_docx(self):
  27. self.assertEqual(detect_datatype("https://www.example.com/document.docx"), DataType.DOCX)
  28. def test_detect_datatype_local_docx(self):
  29. self.assertEqual(detect_datatype("file:///home/user/document.docx"), DataType.DOCX)
  30. @patch("os.path.isfile")
  31. def test_detect_datatype_regular_filesystem_docx(self, mock_isfile):
  32. with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
  33. mock_isfile.return_value = True
  34. self.assertEqual(detect_datatype(tmp.name), DataType.DOCX)
  35. def test_detect_datatype_docs_site(self):
  36. self.assertEqual(detect_datatype("https://docs.example.com"), DataType.DOCS_SITE)
  37. def test_detect_datatype_docs_sitein_path(self):
  38. self.assertEqual(detect_datatype("https://www.example.com/docs/index.html"), DataType.DOCS_SITE)
  39. self.assertNotEqual(detect_datatype("file:///var/www/docs/index.html"), DataType.DOCS_SITE) # NOT equal
  40. def test_detect_datatype_web_page(self):
  41. self.assertEqual(detect_datatype("https://nav.al/agi"), DataType.WEB_PAGE)
  42. def test_detect_datatype_invalid_url(self):
  43. self.assertEqual(detect_datatype("not a url"), DataType.TEXT)
  44. def test_detect_datatype_qna_pair(self):
  45. self.assertEqual(
  46. detect_datatype(("Question?", "Answer. Content of the string is irrelevant.")), DataType.QNA_PAIR
  47. ) #
  48. def test_detect_datatype_qna_pair_types(self):
  49. """Test that a QnA pair needs to be a tuple of length two, and both items have to be strings."""
  50. with self.assertRaises(TypeError):
  51. self.assertNotEqual(
  52. detect_datatype(("How many planets are in our solar system?", 8)), DataType.QNA_PAIR
  53. ) # NOT equal
  54. def test_detect_datatype_text(self):
  55. self.assertEqual(detect_datatype("Just some text."), DataType.TEXT)
  56. def test_detect_datatype_non_string_error(self):
  57. """Test type error if the value passed is not a string, and not a valid non-string data_type"""
  58. with self.assertRaises(TypeError):
  59. detect_datatype(["foo", "bar"])
  60. @patch("os.path.isfile")
  61. def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile):
  62. """Test error if a valid file is referenced, but it isn't a valid data_type"""
  63. with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
  64. mock_isfile.return_value = True
  65. with self.assertRaises(ValueError):
  66. detect_datatype(tmp.name)
  67. def test_detect_datatype_regular_filesystem_no_file(self):
  68. """Test that if a filepath is not actually an existing file, it is not handled as a file path."""
  69. self.assertEqual(detect_datatype("/var/not-an-existing-file.txt"), DataType.TEXT)
  70. def test_doc_examples_quickstart(self):
  71. """Test examples used in the documentation."""
  72. self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Elon_Musk"), DataType.WEB_PAGE)
  73. self.assertEqual(detect_datatype("https://www.tesla.com/elon-musk"), DataType.WEB_PAGE)
  74. def test_doc_examples_introduction(self):
  75. """Test examples used in the documentation."""
  76. self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=3qHkcs3kG44"), DataType.YOUTUBE_VIDEO)
  77. self.assertEqual(
  78. detect_datatype(
  79. "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf"
  80. ),
  81. DataType.PDF_FILE,
  82. )
  83. self.assertEqual(detect_datatype("https://nav.al/feedback"), DataType.WEB_PAGE)
  84. def test_doc_examples_app_types(self):
  85. """Test examples used in the documentation."""
  86. self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=Ff4fRgnuFgQ"), DataType.YOUTUBE_VIDEO)
  87. self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Mark_Zuckerberg"), DataType.WEB_PAGE)
  88. def test_doc_examples_configuration(self):
  89. """Test examples used in the documentation."""
  90. import subprocess
  91. import sys
  92. subprocess.check_call([sys.executable, "-m", "pip", "install", "wikipedia"])
  93. import wikipedia
  94. page = wikipedia.page("Albert Einstein")
  95. # TODO: Add a wikipedia type, so wikipedia is a dependency and we don't need this slow test.
  96. # (timings: import: 1.4s, fetch wiki: 0.7s)
  97. self.assertEqual(detect_datatype(page.content), DataType.TEXT)
  98. if __name__ == "__main__":
  99. unittest.main()