Преглед изворни кода

feat: csv loader (#470)

Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
cachho пре 1 година
родитељ
комит
bd595f84e8

+ 9 - 0
docs/advanced/data_types.mdx

@@ -73,6 +73,15 @@ app.add('https://example.com/content/intro.docx', data_type="docx")
 app.add('content/intro.docx', data_type="docx")
 ```
 
+### CSV file
+
+To add any csv file, use the data_type as `csv`. `csv` allows remote urls and conventional file paths. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg:
+
+```python
+app.add('https://example.com/content/sheet.csv', data_type="csv")
+app.add('content/sheet.csv', data_type="csv")
+```
+
 ### Code documentation website loader
 
 To add any code documentation website as a loader, use the data_type as `docs_site`. Eg:

+ 20 - 0
embedchain/chunkers/table.py

@@ -0,0 +1,20 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
+
+
+class TableChunker(BaseChunker):
+    """Chunker for tables, for instance csv, google sheets or databases."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 4 - 0
embedchain/data_formatter/data_formatter.py

@@ -3,11 +3,13 @@ from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
+from embedchain.chunkers.table import TableChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
 from embedchain.config import AddConfig
 from embedchain.helper_classes.json_serializable import JSONSerializable
+from embedchain.loaders.csv import CsvLoader
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
@@ -47,6 +49,7 @@ class DataFormatter(JSONSerializable):
             DataType.DOCX: DocxFileLoader,
             DataType.SITEMAP: SitemapLoader,
             DataType.DOCS_SITE: DocsSiteLoader,
+            DataType.CSV: CsvLoader,
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
@@ -81,6 +84,7 @@ class DataFormatter(JSONSerializable):
             DataType.WEB_PAGE: WebPageChunker,
             DataType.DOCS_SITE: DocsSiteChunker,
             DataType.NOTION: NotionChunker,
+            DataType.CSV: TableChunker,
         }
         if data_type in chunker_classes:
             chunker_class = chunker_classes[data_type]

+ 46 - 0
embedchain/loaders/csv.py

@@ -0,0 +1,46 @@
+import csv
+from io import StringIO
+from urllib.parse import urlparse
+
+import requests
+
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class CsvLoader(BaseLoader):
+    @staticmethod
+    def _detect_delimiter(first_line):
+        delimiters = [",", "\t", ";", "|"]
+        counts = {delimiter: first_line.count(delimiter) for delimiter in delimiters}
+        return max(counts, key=counts.get)
+
+    @staticmethod
+    def _get_file_content(content):
+        url = urlparse(content)
+        if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
+            raise ValueError("Not a valid URL.")
+
+        if url.scheme in ["http", "https"]:
+            response = requests.get(content)
+            response.raise_for_status()
+            return StringIO(response.text)
+        elif url.scheme == "file":
+            path = url.path
+            return open(path, newline="")  # Open the file using the path from the URI
+        else:
+            return open(content, newline="")  # Treat content as a regular file path
+
+    @staticmethod
+    def load_data(content):
+        """Load a csv file with headers. Each line is a document"""
+        result = []
+
+        with CsvLoader._get_file_content(content) as file:
+            first_line = file.readline()
+            delimiter = CsvLoader._detect_delimiter(first_line)
+            file.seek(0)  # Reset the file pointer to the start
+            reader = csv.DictReader(file, delimiter=delimiter)
+            for i, row in enumerate(reader):
+                line = ", ".join([f"{field}: {value}" for field, value in row.items()])
+                result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
+        return result

+ 1 - 0
embedchain/models/data_type.py

@@ -11,3 +11,4 @@ class DataType(Enum):
     TEXT = "text"
     QNA_PAIR = "qna_pair"
     NOTION = "notion"
+    CSV = "csv"

+ 8 - 0
embedchain/utils.py

@@ -147,6 +147,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
             return DataType.SITEMAP
 
+        if url.path.endswith(".csv"):
+            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
+            return DataType.CSV
+
         if url.path.endswith(".docx"):
             logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
@@ -182,6 +186,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
 
+        if source.endswith(".csv"):
+            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
+            return DataType.CSV
+
         # If the source is a valid file, that's not detectable as a type, an error is raised.
         # It does not fallback to text.
         raise ValueError(

+ 84 - 0
tests/loaders/test_csv.py

@@ -0,0 +1,84 @@
+import csv
+import os
+import pathlib
+import tempfile
+
+import pytest
+
+from embedchain.loaders.csv import CsvLoader
+
+
+@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
+def test_load_data(delimiter):
+    """
+    Test csv loader
+
+    Tests that file is loaded, metadata is correct and content is correct
+    """
+    # Creating temporary CSV file
+    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
+        writer = csv.writer(tmpfile, delimiter=delimiter)
+        writer.writerow(["Name", "Age", "Occupation"])
+        writer.writerow(["Alice", "28", "Engineer"])
+        writer.writerow(["Bob", "35", "Doctor"])
+        writer.writerow(["Charlie", "22", "Student"])
+
+        tmpfile.seek(0)
+        filename = tmpfile.name
+
+        # Loading CSV using CsvLoader
+        loader = CsvLoader()
+        result = loader.load_data(filename)
+
+        # Assertions
+        assert len(result) == 3
+        assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+        assert result[0]["meta_data"]["url"] == filename
+        assert result[0]["meta_data"]["row"] == 1
+        assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+        assert result[1]["meta_data"]["url"] == filename
+        assert result[1]["meta_data"]["row"] == 2
+        assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+        assert result[2]["meta_data"]["url"] == filename
+        assert result[2]["meta_data"]["row"] == 3
+
+        # Cleaning up the temporary file
+        os.unlink(filename)
+
+
+@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
+def test_load_data_with_file_uri(delimiter):
+    """
+    Test csv loader with file URI
+
+    Tests that file is loaded, metadata is correct and content is correct
+    """
+    # Creating temporary CSV file
+    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
+        writer = csv.writer(tmpfile, delimiter=delimiter)
+        writer.writerow(["Name", "Age", "Occupation"])
+        writer.writerow(["Alice", "28", "Engineer"])
+        writer.writerow(["Bob", "35", "Doctor"])
+        writer.writerow(["Charlie", "22", "Student"])
+
+        tmpfile.seek(0)
+        filename = pathlib.Path(tmpfile.name).as_uri()  # Convert path to file URI
+
+        # Loading CSV using CsvLoader
+        loader = CsvLoader()
+        result = loader.load_data(filename)
+
+        # Assertions
+        assert len(result) == 3
+        assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+        assert result[0]["meta_data"]["url"] == filename
+        assert result[0]["meta_data"]["row"] == 1
+        assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+        assert result[1]["meta_data"]["url"] == filename
+        assert result[1]["meta_data"]["row"] == 2
+        assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+        assert result[2]["meta_data"]["url"] == filename
+        assert result[2]["meta_data"]["row"] == 3
+
+        # Cleaning up the temporary file
+        os.unlink(tmpfile.name)