test_xml.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import tempfile
  2. import pytest
  3. from embedchain.loaders.xml import XmlLoader
# Taken from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/tests/integration_tests/examples/factbook.xml
#
# Fixture document: a single <factbook> root holding four <country> records,
# each with <name>, <capital>, <leader> and <sport> children.
# NOTE: the "&amp;" entities and the "Tobado" spelling are kept as-is because
# the assertions in test_load_data match on the text "Trinidad & Tobado".
SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?>
<factbook>
<country>
<name>United States</name>
<capital>Washington, DC</capital>
<leader>Joe Biden</leader>
<sport>Baseball</sport>
</country>
<country>
<name>Canada</name>
<capital>Ottawa</capital>
<leader>Justin Trudeau</leader>
<sport>Hockey</sport>
</country>
<country>
<name>France</name>
<capital>Paris</capital>
<leader>Emmanuel Macron</leader>
<sport>Soccer</sport>
</country>
<country>
<name>Trinidad &amp; Tobado</name>
<capital>Port of Spain</capital>
<leader>Keith Rowley</leader>
<sport>Track &amp; Field</sport>
</country>
</factbook>"""
  32. @pytest.mark.parametrize("xml", [SAMPLE_XML])
  33. def test_load_data(xml: str):
  34. """
  35. Test XML loader
  36. Tests that XML file is loaded, metadata is correct and content is correct
  37. """
  38. # Creating temporary XML file
  39. with tempfile.NamedTemporaryFile(mode="w+") as tmpfile:
  40. tmpfile.write(xml)
  41. tmpfile.seek(0)
  42. filename = tmpfile.name
  43. # Loading CSV using XmlLoader
  44. loader = XmlLoader()
  45. result = loader.load_data(filename)
  46. data = result["data"]
  47. # Assertions
  48. assert len(data) == 1
  49. assert "United States Washington, DC Joe Biden" in data[0]["content"]
  50. assert "Canada Ottawa Justin Trudeau" in data[0]["content"]
  51. assert "France Paris Emmanuel Macron" in data[0]["content"]
  52. assert "Trinidad & Tobado Port of Spain Keith Rowley" in data[0]["content"]
  53. assert data[0]["meta_data"]["url"] == filename