test_csv.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import csv
  2. import os
  3. import pathlib
  4. import tempfile
  5. import pytest
  6. from embedchain.loaders.csv import CsvLoader
  7. @pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
  8. def test_load_data(delimiter):
  9. """
  10. Test csv loader
  11. Tests that file is loaded, metadata is correct and content is correct
  12. """
  13. # Creating temporary CSV file
  14. with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
  15. writer = csv.writer(tmpfile, delimiter=delimiter)
  16. writer.writerow(["Name", "Age", "Occupation"])
  17. writer.writerow(["Alice", "28", "Engineer"])
  18. writer.writerow(["Bob", "35", "Doctor"])
  19. writer.writerow(["Charlie", "22", "Student"])
  20. tmpfile.seek(0)
  21. filename = tmpfile.name
  22. # Loading CSV using CsvLoader
  23. loader = CsvLoader()
  24. result = loader.load_data(filename)
  25. data = result["data"]
  26. # Assertions
  27. assert len(data) == 3
  28. assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
  29. assert data[0]["meta_data"]["url"] == filename
  30. assert data[0]["meta_data"]["row"] == 1
  31. assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
  32. assert data[1]["meta_data"]["url"] == filename
  33. assert data[1]["meta_data"]["row"] == 2
  34. assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
  35. assert data[2]["meta_data"]["url"] == filename
  36. assert data[2]["meta_data"]["row"] == 3
  37. # Cleaning up the temporary file
  38. os.unlink(filename)
  39. @pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
  40. def test_load_data_with_file_uri(delimiter):
  41. """
  42. Test csv loader with file URI
  43. Tests that file is loaded, metadata is correct and content is correct
  44. """
  45. # Creating temporary CSV file
  46. with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
  47. writer = csv.writer(tmpfile, delimiter=delimiter)
  48. writer.writerow(["Name", "Age", "Occupation"])
  49. writer.writerow(["Alice", "28", "Engineer"])
  50. writer.writerow(["Bob", "35", "Doctor"])
  51. writer.writerow(["Charlie", "22", "Student"])
  52. tmpfile.seek(0)
  53. filename = pathlib.Path(tmpfile.name).as_uri() # Convert path to file URI
  54. # Loading CSV using CsvLoader
  55. loader = CsvLoader()
  56. result = loader.load_data(filename)
  57. data = result["data"]
  58. # Assertions
  59. assert len(data) == 3
  60. assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
  61. assert data[0]["meta_data"]["url"] == filename
  62. assert data[0]["meta_data"]["row"] == 1
  63. assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
  64. assert data[1]["meta_data"]["url"] == filename
  65. assert data[1]["meta_data"]["row"] == 2
  66. assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
  67. assert data[2]["meta_data"]["url"] == filename
  68. assert data[2]["meta_data"]["row"] == 3
  69. # Cleaning up the temporary file
  70. os.unlink(tmpfile.name)