test_base_chunker.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import hashlib
  2. from unittest.mock import MagicMock
  3. import pytest
  4. from embedchain.chunkers.base_chunker import BaseChunker
  5. from embedchain.config.add_config import ChunkerConfig
  6. from embedchain.models.data_type import DataType
  @pytest.fixture
  def text_splitter_mock():
      """Provide a fresh ``MagicMock`` to be used in place of a text splitter."""
      splitter = MagicMock()
      return splitter
  @pytest.fixture
  def loader_mock():
      """Provide a fresh ``MagicMock`` to be used in place of a data loader."""
      loader = MagicMock()
      return loader
  @pytest.fixture
  def app_id():
      """Provide the application identifier string shared by the tests."""
      identifier = "test_app"
      return identifier
  @pytest.fixture
  def data_type():
      """Provide the ``DataType`` member (``TEXT``) assigned to the chunker under test."""
      selected = DataType.TEXT
      return selected
  @pytest.fixture
  def chunker(text_splitter_mock, data_type):
      """Build a ``BaseChunker`` around the mocked splitter and set its data type."""
      instance = BaseChunker(text_splitter_mock)
      instance.set_data_type(data_type)
      return instance
  def test_create_chunks_with_config(chunker, text_splitter_mock, loader_mock, app_id, data_type):
      """Check ``create_chunks`` with an explicit ``ChunkerConfig``.

      The splitter is stubbed to return "Chunk 1" and "long chunk"; with
      ``min_chunk_size=10`` only "long chunk" is expected in the documents.
      """
      # Stub the loader first, then the splitter; the two are independent.
      loaded_payload = {
          "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
          "doc_id": "DocID",
      }
      loader_mock.load_data.return_value = loaded_payload
      text_splitter_mock.split_text.return_value = ["Chunk 1", "long chunk"]

      # NOTE(review): presumably "Chunk 1" (7 chars) is dropped for being under
      # min_chunk_size -- confirm against BaseChunker.create_chunks.
      chunker_config = ChunkerConfig(
          chunk_size=50,
          chunk_overlap=0,
          length_function=len,
          min_chunk_size=10,
      )
      outcome = chunker.create_chunks(loader_mock, "test_src", app_id, chunker_config)

      assert outcome["documents"] == ["long chunk"]
  def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
      """Check the documents, ids, metadatas and doc_id returned by ``create_chunks``.

      Expected ids are ``"<app_id>--"`` followed by the SHA-256 hex digest of the
      chunk text concatenated with the record's url; the expected doc id is
      ``"<app_id>--DocID"``.
      """
      chunk_texts = ("Chunk 1", "Chunk 2")
      text_splitter_mock.split_text.return_value = list(chunk_texts)
      loader_mock.load_data.return_value = {
          "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
          "doc_id": "DocID",
      }

      outcome = chunker.create_chunks(loader_mock, "test_src", app_id)

      expected_doc_id = f"{app_id}--DocID"
      expected_ids = [
          f"{app_id}--" + hashlib.sha256((text + "URL 1").encode()).hexdigest()
          for text in chunk_texts
      ]
      # One metadata entry is expected per chunk, each with the same contents.
      expected_metadatas = [
          {
              "url": "URL 1",
              "data_type": data_type.value,
              "doc_id": expected_doc_id,
          }
          for _ in chunk_texts
      ]

      assert outcome["documents"] == list(chunk_texts)
      assert outcome["ids"] == expected_ids
      assert outcome["metadatas"] == expected_metadatas
      assert outcome["doc_id"] == expected_doc_id
  def test_get_chunks(chunker, text_splitter_mock):
      """Check ``get_chunks`` returns the two chunks configured on the mocked ``split_text``."""
      text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]

      chunks = chunker.get_chunks("This is a test content.")

      assert len(chunks) == 2
      assert chunks == ["Chunk 1", "Chunk 2"]
  def test_set_data_type(chunker):
      """Check ``set_data_type`` makes ``chunker.data_type`` equal the value passed in."""
      replacement = DataType.MDX
      chunker.set_data_type(replacement)
      assert chunker.data_type == replacement
  def test_get_word_count(chunker):
      """Check ``get_word_count`` reports 6 for a four-word and a two-word sentence."""
      sentences = ["This is a test.", "Another test."]
      assert chunker.get_word_count(sentences) == 6