csv.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import csv
  2. import hashlib
  3. from io import StringIO
  4. from urllib.parse import urlparse
  5. import requests
  6. from embedchain.loaders.base_loader import BaseLoader
  7. class CsvLoader(BaseLoader):
  8. @staticmethod
  9. def _detect_delimiter(first_line):
  10. delimiters = [",", "\t", ";", "|"]
  11. counts = {delimiter: first_line.count(delimiter) for delimiter in delimiters}
  12. return max(counts, key=counts.get)
  13. @staticmethod
  14. def _get_file_content(content):
  15. url = urlparse(content)
  16. if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
  17. raise ValueError("Not a valid URL.")
  18. if url.scheme in ["http", "https"]:
  19. response = requests.get(content)
  20. response.raise_for_status()
  21. return StringIO(response.text)
  22. elif url.scheme == "file":
  23. path = url.path
  24. return open(path, newline="", encoding="utf-8") # Open the file using the path from the URI
  25. else:
  26. return open(content, newline="", encoding="utf-8") # Treat content as a regular file path
  27. @staticmethod
  28. def load_data(content):
  29. """Load a csv file with headers. Each line is a document"""
  30. result = []
  31. lines = []
  32. with CsvLoader._get_file_content(content) as file:
  33. first_line = file.readline()
  34. delimiter = CsvLoader._detect_delimiter(first_line)
  35. file.seek(0) # Reset the file pointer to the start
  36. reader = csv.DictReader(file, delimiter=delimiter)
  37. for i, row in enumerate(reader):
  38. line = ", ".join([f"{field}: {value}" for field, value in row.items()])
  39. lines.append(line)
  40. result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
  41. doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
  42. return {"doc_id": doc_id, "data": result}