浏览代码

Improve tests (#780)

Sidharth Mohanty 1 年之前
父节点
当前提交
b91d922600

+ 20 - 15
embedchain/loaders/web_page.py

@@ -15,7 +15,25 @@ class WebPageLoader(BaseLoader):
         """Load data from a web page."""
         response = requests.get(url)
         data = response.content
-        soup = BeautifulSoup(data, "html.parser")
+        content = self._get_clean_content(data, url)
+
+        meta_data = {
+            "url": url,
+        }
+
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ],
+        }
+
+    def _get_clean_content(self, html, url) -> str:
+        soup = BeautifulSoup(html, "html.parser")
         original_size = len(str(soup.get_text()))
 
         tags_to_exclude = [
@@ -61,17 +79,4 @@ class WebPageLoader(BaseLoader):
                 f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
             )
 
-        meta_data = {
-            "url": url,
-        }
-        content = content
-        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
-        return {
-            "doc_id": doc_id,
-            "data": [
-                {
-                    "content": content,
-                    "meta_data": meta_data,
-                }
-            ],
-        }
+        return content

+ 84 - 0
tests/chunkers/test_base_chunker.py

@@ -0,0 +1,84 @@
+import hashlib
+import pytest
+from unittest.mock import MagicMock
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.models.data_type import DataType
+
+
+@pytest.fixture
+def text_splitter_mock():
+    return MagicMock()
+
+
+@pytest.fixture
+def loader_mock():
+    return MagicMock()
+
+
+@pytest.fixture
+def app_id():
+    return "test_app"
+
+
+@pytest.fixture
+def data_type():
+    return DataType.TEXT
+
+
+@pytest.fixture
+def chunker(text_splitter_mock, data_type):
+    text_splitter = text_splitter_mock
+    chunker = BaseChunker(text_splitter)
+    chunker.set_data_type(data_type)
+    return chunker
+
+
+def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
+    text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
+    loader_mock.load_data.return_value = {
+        "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
+        "doc_id": "DocID",
+    }
+
+    result = chunker.create_chunks(loader_mock, "test_src", app_id)
+    expected_ids = [
+        hashlib.sha256(("Chunk 1" + "URL 1").encode()).hexdigest(),
+        hashlib.sha256(("Chunk 2" + "URL 1").encode()).hexdigest(),
+    ]
+
+    assert result["documents"] == ["Chunk 1", "Chunk 2"]
+    assert result["ids"] == expected_ids
+    assert result["metadatas"] == [
+        {
+            "url": "URL 1",
+            "data_type": data_type.value,
+            "doc_id": f"{app_id}--DocID",
+        },
+        {
+            "url": "URL 1",
+            "data_type": data_type.value,
+            "doc_id": f"{app_id}--DocID",
+        },
+    ]
+    assert result["doc_id"] == f"{app_id}--DocID"
+
+
+def test_get_chunks(chunker, text_splitter_mock):
+    text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
+
+    content = "This is a test content."
+    result = chunker.get_chunks(content)
+
+    assert len(result) == 2
+    assert result == ["Chunk 1", "Chunk 2"]
+
+
+def test_set_data_type(chunker):
+    chunker.set_data_type(DataType.MDX)
+    assert chunker.data_type == DataType.MDX
+
+
+def test_get_word_count(chunker):
+    documents = ["This is a test.", "Another test."]
+    result = chunker.get_word_count(documents)
+    assert result == 6

+ 46 - 0
tests/chunkers/test_chunkers.py

@@ -0,0 +1,46 @@
+from embedchain.chunkers.docs_site import DocsSiteChunker
+from embedchain.chunkers.docx_file import DocxFileChunker
+from embedchain.chunkers.mdx import MdxChunker
+from embedchain.chunkers.notion import NotionChunker
+from embedchain.chunkers.pdf_file import PdfFileChunker
+from embedchain.chunkers.qna_pair import QnaPairChunker
+from embedchain.chunkers.sitemap import SitemapChunker
+from embedchain.chunkers.table import TableChunker
+from embedchain.chunkers.text import TextChunker
+from embedchain.chunkers.web_page import WebPageChunker
+from embedchain.chunkers.xml import XmlChunker
+from embedchain.chunkers.youtube_video import YoutubeVideoChunker
+from embedchain.config.add_config import ChunkerConfig
+
+chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
+
+chunker_common_config = {
+    DocsSiteChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
+    DocxFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    PdfFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    TextChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
+    MdxChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    NotionChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
+    QnaPairChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
+    TableChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
+    SitemapChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
+    WebPageChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
+    XmlChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
+    YoutubeVideoChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
+}
+
+
+def test_default_config_values():
+    for chunker_class, config in chunker_common_config.items():
+        chunker = chunker_class()
+        assert chunker.text_splitter._chunk_size == config["chunk_size"]
+        assert chunker.text_splitter._chunk_overlap == config["chunk_overlap"]
+        assert chunker.text_splitter._length_function == config["length_function"]
+
+
+def test_custom_config_values():
+    for chunker_class, _ in chunker_common_config.items():
+        chunker = chunker_class(config=chunker_config)
+        assert chunker.text_splitter._chunk_size == 500
+        assert chunker.text_splitter._chunk_overlap == 0
+        assert chunker.text_splitter._length_function == len

+ 27 - 0
tests/loaders/test_csv.py

@@ -2,6 +2,7 @@ import csv
 import os
 import pathlib
 import tempfile
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -84,3 +85,29 @@ def test_load_data_with_file_uri(delimiter):
 
         # Cleaning up the temporary file
         os.unlink(tmpfile.name)
+
+
+@pytest.mark.parametrize("content", ["ftp://example.com", "sftp://example.com", "mailto://example.com"])
+def test_get_file_content(content):
+    with pytest.raises(ValueError):
+        loader = CsvLoader()
+        loader._get_file_content(content)
+
+
+@pytest.mark.parametrize("content", ["http://example.com", "https://example.com"])
+def test_get_file_content_http(content):
+    """
+    Test _get_file_content method of CsvLoader for http and https URLs
+    """
+
+    with patch("requests.get") as mock_get:
+        mock_response = MagicMock()
+        mock_response.text = "Name,Age,Occupation\nAlice,28,Engineer\nBob,35,Doctor\nCharlie,22,Student"
+        mock_get.return_value = mock_response
+
+        loader = CsvLoader()
+        file_content = loader._get_file_content(content)
+
+        mock_get.assert_called_once_with(content)
+        mock_response.raise_for_status.assert_called_once()
+        assert file_content.read() == mock_response.text

+ 128 - 0
tests/loaders/test_docs_site.py

@@ -0,0 +1,128 @@
+import hashlib
+import pytest
+from unittest.mock import Mock, patch
+from requests import Response
+from embedchain.loaders.docs_site_loader import DocsSiteLoader
+
+
+@pytest.fixture
+def mock_requests_get():
+    with patch("requests.get") as mock_get:
+        yield mock_get
+
+
+@pytest.fixture
+def docs_site_loader():
+    return DocsSiteLoader()
+
+
+def test_get_child_links_recursive(mock_requests_get, docs_site_loader):
+    mock_response = Mock()
+    mock_response.status_code = 200
+    mock_response.text = """
+        <html>
+            <a href="/page1">Page 1</a>
+            <a href="/page2">Page 2</a>
+        </html>
+    """
+    mock_requests_get.return_value = mock_response
+
+    docs_site_loader._get_child_links_recursive("https://example.com")
+
+    assert len(docs_site_loader.visited_links) == 2
+    assert "https://example.com/page1" in docs_site_loader.visited_links
+    assert "https://example.com/page2" in docs_site_loader.visited_links
+
+
+def test_get_child_links_recursive_status_not_200(mock_requests_get, docs_site_loader):
+    mock_response = Mock()
+    mock_response.status_code = 404
+    mock_requests_get.return_value = mock_response
+
+    docs_site_loader._get_child_links_recursive("https://example.com")
+
+    assert len(docs_site_loader.visited_links) == 0
+
+
+def test_get_all_urls(mock_requests_get, docs_site_loader):
+    mock_response = Mock()
+    mock_response.status_code = 200
+    mock_response.text = """
+        <html>
+            <a href="/page1">Page 1</a>
+            <a href="/page2">Page 2</a>
+            <a href="https://example.com/external">External</a>
+        </html>
+    """
+    mock_requests_get.return_value = mock_response
+
+    all_urls = docs_site_loader._get_all_urls("https://example.com")
+
+    assert len(all_urls) == 3
+    assert "https://example.com/page1" in all_urls
+    assert "https://example.com/page2" in all_urls
+    assert "https://example.com/external" in all_urls
+
+
+def test_load_data_from_url(mock_requests_get, docs_site_loader):
+    mock_response = Mock()
+    mock_response.status_code = 200
+    mock_response.content = """
+        <html>
+            <nav>
+                <h1>Navigation</h1>
+            </nav>
+            <article class="bd-article">
+                <p>Article Content</p>
+            </article>
+        </html>
+    """.encode()
+    mock_requests_get.return_value = mock_response
+
+    data = docs_site_loader._load_data_from_url("https://example.com/page1")
+
+    assert len(data) == 1
+    assert data[0]["content"] == "Article Content"
+    assert data[0]["meta_data"]["url"] == "https://example.com/page1"
+
+
+def test_load_data_from_url_status_not_200(mock_requests_get, docs_site_loader):
+    mock_response = Mock()
+    mock_response.status_code = 404
+    mock_requests_get.return_value = mock_response
+
+    data = docs_site_loader._load_data_from_url("https://example.com/page1")
+
+    assert data == []
+    assert len(data) == 0
+
+
+def test_load_data(mock_requests_get, docs_site_loader):
+    mock_response = Response()
+    mock_response.status_code = 200
+    mock_response._content = """
+        <html>
+            <a href="/page1">Page 1</a>
+            <a href="/page2">Page 2</a>
+        """.encode()
+    mock_requests_get.return_value = mock_response
+
+    url = "https://example.com"
+    data = docs_site_loader.load_data(url)
+    expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()
+
+    assert len(data["data"]) == 2
+    assert data["doc_id"] == expected_doc_id
+
+
+def test_if_response_status_not_200(mock_requests_get, docs_site_loader):
+    mock_response = Response()
+    mock_response.status_code = 404
+    mock_requests_get.return_value = mock_response
+
+    url = "https://example.com"
+    data = docs_site_loader.load_data(url)
+    expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()
+
+    assert len(data["data"]) == 0
+    assert data["doc_id"] == expected_doc_id

+ 37 - 0
tests/loaders/test_docx_file.py

@@ -0,0 +1,37 @@
+import hashlib
+import pytest
+from unittest.mock import MagicMock, patch
+from embedchain.loaders.docx_file import DocxFileLoader
+
+
+@pytest.fixture
+def mock_docx2txt_loader():
+    with patch("embedchain.loaders.docx_file.Docx2txtLoader") as mock_loader:
+        yield mock_loader
+
+
+@pytest.fixture
+def docx_file_loader():
+    return DocxFileLoader()
+
+
+def test_load_data(mock_docx2txt_loader, docx_file_loader):
+    mock_url = "mock_docx_file.docx"
+
+    mock_loader = MagicMock()
+    mock_loader.load.return_value = [MagicMock(page_content="Sample Docx Content", metadata={"url": "local"})]
+
+    mock_docx2txt_loader.return_value = mock_loader
+
+    result = docx_file_loader.load_data(mock_url)
+
+    assert "doc_id" in result
+    assert "data" in result
+
+    expected_content = "Sample Docx Content"
+    assert result["data"][0]["content"] == expected_content
+
+    assert result["data"][0]["meta_data"]["url"] == "local"
+
+    expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()
+    assert result["doc_id"] == expected_doc_id

+ 30 - 0
tests/loaders/test_local_qna_pair.py

@@ -0,0 +1,30 @@
+import hashlib
+import pytest
+from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
+
+
+@pytest.fixture
+def qna_pair_loader():
+    return LocalQnaPairLoader()
+
+
+def test_load_data(qna_pair_loader):
+    question = "What is the capital of France?"
+    answer = "The capital of France is Paris."
+
+    content = (question, answer)
+    result = qna_pair_loader.load_data(content)
+
+    assert "doc_id" in result
+    assert "data" in result
+    url = "local"
+
+    expected_content = f"Q: {question}\nA: {answer}"
+    assert result["data"][0]["content"] == expected_content
+
+    assert result["data"][0]["meta_data"]["url"] == url
+
+    assert result["data"][0]["meta_data"]["question"] == question
+
+    expected_doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
+    assert result["doc_id"] == expected_doc_id

+ 25 - 0
tests/loaders/test_local_text.py

@@ -0,0 +1,25 @@
+import hashlib
+import pytest
+from embedchain.loaders.local_text import LocalTextLoader
+
+
+@pytest.fixture
+def text_loader():
+    return LocalTextLoader()
+
+
+def test_load_data(text_loader):
+    mock_content = "This is a sample text content."
+
+    result = text_loader.load_data(mock_content)
+
+    assert "doc_id" in result
+    assert "data" in result
+
+    url = "local"
+    assert result["data"][0]["content"] == mock_content
+
+    assert result["data"][0]["meta_data"]["url"] == url
+
+    expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
+    assert result["doc_id"] == expected_doc_id

+ 28 - 0
tests/loaders/test_mdx.py

@@ -0,0 +1,28 @@
+import hashlib
+import pytest
+from unittest.mock import patch, mock_open
+from embedchain.loaders.mdx import MdxLoader
+
+
+@pytest.fixture
+def mdx_loader():
+    return MdxLoader()
+
+
+def test_load_data(mdx_loader):
+    mock_content = "Sample MDX Content"
+
+    # Mock open function to simulate file reading
+    with patch("builtins.open", mock_open(read_data=mock_content)):
+        url = "mock_file.mdx"
+        result = mdx_loader.load_data(url)
+
+        assert "doc_id" in result
+        assert "data" in result
+
+        assert result["data"][0]["content"] == mock_content
+
+        assert result["data"][0]["meta_data"]["url"] == url
+
+        expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
+        assert result["doc_id"] == expected_doc_id

+ 34 - 0
tests/loaders/test_notion.py

@@ -0,0 +1,34 @@
+import hashlib
+import os
+import pytest
+from unittest.mock import Mock, patch
+from embedchain.loaders.notion import NotionLoader
+
+
+@pytest.fixture
+def notion_loader():
+    with patch.dict(os.environ, {"NOTION_INTEGRATION_TOKEN": "test_notion_token"}):
+        yield NotionLoader()
+
+
+def test_load_data(notion_loader):
+    source = "https://www.notion.so/Test-Page-1234567890abcdef1234567890abcdef"
+    mock_text = "This is a test page."
+    expected_doc_id = hashlib.sha256((mock_text + source).encode()).hexdigest()
+    expected_data = [
+        {
+            "content": mock_text,
+            "meta_data": {"url": "notion-12345678-90ab-cdef-1234-567890abcdef"},  # formatted_id
+        }
+    ]
+
+    mock_page = Mock()
+    mock_page.text = mock_text
+    mock_documents = [mock_page]
+
+    with patch("embedchain.loaders.notion.NotionPageReader") as mock_reader:
+        mock_reader.return_value.load_data.return_value = mock_documents
+        result = notion_loader.load_data(source)
+
+    assert result["doc_id"] == expected_doc_id
+    assert result["data"] == expected_data

+ 115 - 0
tests/loaders/test_web_page.py

@@ -0,0 +1,115 @@
+import hashlib
+import pytest
+from unittest.mock import Mock, patch
+from embedchain.loaders.web_page import WebPageLoader
+
+
+@pytest.fixture
+def web_page_loader():
+    return WebPageLoader()
+
+
+def test_load_data(web_page_loader):
+    page_url = "https://example.com/page"
+    mock_response = Mock()
+    mock_response.status_code = 200
+    mock_response.content = """
+        <html>
+            <head>
+                <title>Test Page</title>
+            </head>
+            <body>
+                <div id="content">
+                    <p>This is some test content.</p>
+                </div>
+            </body>
+        </html>
+    """
+    with patch("embedchain.loaders.web_page.requests.get", return_value=mock_response):
+        result = web_page_loader.load_data(page_url)
+
+    content = web_page_loader._get_clean_content(mock_response.content, page_url)
+    expected_doc_id = hashlib.sha256((content + page_url).encode()).hexdigest()
+    assert result["doc_id"] == expected_doc_id
+
+    expected_data = [
+        {
+            "content": content,
+            "meta_data": {
+                "url": page_url,
+            },
+        }
+    ]
+
+    assert result["data"] == expected_data
+
+
+def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
+    mock_html = """
+        <html>
+        <head>
+            <title>Sample HTML</title>
+            <style>
+                /* Stylesheet to be excluded */
+                .elementor-location-header {
+                    background-color: #f0f0f0;
+                }
+            </style>
+        </head>
+        <body>
+            <header id="header">Header Content</header>
+            <nav class="nav">Nav Content</nav>
+            <aside>Aside Content</aside>
+            <form>Form Content</form>
+            <main>Main Content</main>
+            <footer class="footer">Footer Content</footer>
+            <script>Some Script</script>
+            <noscript>NoScript Content</noscript>
+            <svg>SVG Content</svg>
+            <canvas>Canvas Content</canvas>
+            
+            <div id="sidebar">Sidebar Content</div>
+            <div id="main-navigation">Main Navigation Content</div>
+            <div id="menu-main-menu">Menu Main Menu Content</div>
+            
+            <div class="header-sidebar-wrapper">Header Sidebar Wrapper Content</div>
+            <div class="blog-sidebar-wrapper">Blog Sidebar Wrapper Content</div>
+            <div class="related-posts">Related Posts Content</div>
+        </body>
+        </html>
+    """
+
+    tags_to_exclude = [
+        "nav",
+        "aside",
+        "form",
+        "header",
+        "noscript",
+        "svg",
+        "canvas",
+        "footer",
+        "script",
+        "style",
+    ]
+    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
+    classes_to_exclude = [
+        "elementor-location-header",
+        "navbar-header",
+        "nav",
+        "header-sidebar-wrapper",
+        "blog-sidebar-wrapper",
+        "related-posts",
+    ]
+
+    content = web_page_loader._get_clean_content(mock_html, "https://example.com/page")
+
+    for tag in tags_to_exclude:
+        assert tag not in content
+
+    for id in ids_to_exclude:
+        assert id not in content
+
+    for class_name in classes_to_exclude:
+        assert class_name not in content
+
+    assert len(content) > 0

+ 47 - 0
tests/loaders/test_youtube_video.py

@@ -0,0 +1,47 @@
+import hashlib
+import pytest
+from unittest.mock import MagicMock, Mock, patch
+from embedchain.loaders.youtube_video import YoutubeVideoLoader
+
+
+@pytest.fixture
+def youtube_video_loader():
+    return YoutubeVideoLoader()
+
+
+def test_load_data(youtube_video_loader):
+    video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
+    mock_loader = Mock()
+    mock_page_content = "This is a YouTube video content."
+    mock_loader.load.return_value = [
+        MagicMock(
+            page_content=mock_page_content,
+            metadata={"url": video_url, "title": "Test Video"},
+        )
+    ]
+
+    with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader):
+        result = youtube_video_loader.load_data(video_url)
+
+    expected_doc_id = hashlib.sha256((mock_page_content + video_url).encode()).hexdigest()
+
+    assert result["doc_id"] == expected_doc_id
+
+    expected_data = [
+        {
+            "content": "This is a YouTube video content.",
+            "meta_data": {"url": video_url, "title": "Test Video"},
+        }
+    ]
+
+    assert result["data"] == expected_data
+
+
+def test_load_data_with_empty_doc(youtube_video_loader):
+    video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
+    mock_loader = Mock()
+    mock_loader.load.return_value = []
+
+    with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader):
+        with pytest.raises(ValueError):
+            youtube_video_loader.load_data(video_url)