# test_base_chunker.py - unit tests for embedchain.chunkers.base_chunker.BaseChunker

  1. import hashlib
  2. from unittest.mock import MagicMock
  3. import pytest
  4. from embedchain.chunkers.base_chunker import BaseChunker
  5. from embedchain.models.data_type import DataType
  6. @pytest.fixture
  7. def text_splitter_mock():
  8. return MagicMock()
  9. @pytest.fixture
  10. def loader_mock():
  11. return MagicMock()
  12. @pytest.fixture
  13. def app_id():
  14. return "test_app"
  15. @pytest.fixture
  16. def data_type():
  17. return DataType.TEXT
  18. @pytest.fixture
  19. def chunker(text_splitter_mock, data_type):
  20. text_splitter = text_splitter_mock
  21. chunker = BaseChunker(text_splitter)
  22. chunker.set_data_type(data_type)
  23. return chunker
  24. def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
  25. text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
  26. loader_mock.load_data.return_value = {
  27. "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
  28. "doc_id": "DocID",
  29. }
  30. result = chunker.create_chunks(loader_mock, "test_src", app_id)
  31. expected_ids = [
  32. f"{app_id}--" + hashlib.sha256(("Chunk 1" + "URL 1").encode()).hexdigest(),
  33. f"{app_id}--" + hashlib.sha256(("Chunk 2" + "URL 1").encode()).hexdigest(),
  34. ]
  35. assert result["documents"] == ["Chunk 1", "Chunk 2"]
  36. assert result["ids"] == expected_ids
  37. assert result["metadatas"] == [
  38. {
  39. "url": "URL 1",
  40. "data_type": data_type.value,
  41. "doc_id": f"{app_id}--DocID",
  42. },
  43. {
  44. "url": "URL 1",
  45. "data_type": data_type.value,
  46. "doc_id": f"{app_id}--DocID",
  47. },
  48. ]
  49. assert result["doc_id"] == f"{app_id}--DocID"
  50. def test_get_chunks(chunker, text_splitter_mock):
  51. text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
  52. content = "This is a test content."
  53. result = chunker.get_chunks(content)
  54. assert len(result) == 2
  55. assert result == ["Chunk 1", "Chunk 2"]
  56. def test_set_data_type(chunker):
  57. chunker.set_data_type(DataType.MDX)
  58. assert chunker.data_type == DataType.MDX
  59. def test_get_word_count(chunker):
  60. documents = ["This is a test.", "Another test."]
  61. result = chunker.get_word_count(documents)
  62. assert result == 6