base_chunker.py

import hashlib
import logging
from typing import Optional

from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.models.data_type import DataType

logger = logging.getLogger(__name__)


class BaseChunker(JSONSerializable):
    def __init__(self, text_splitter):
        """Initialize the chunker."""
        self.text_splitter = text_splitter
        self.data_type = None

    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
        """
        Loads data and chunks it.

        :param loader: The loader whose `load_data` method is used to create
            the raw data.
        :param src: The data to be handled by the loader. Can be a URL for
            remote sources or local content for local loaders.
        :param app_id: App id used to generate the doc_id.
        """
        documents = []
        chunk_ids = []
        id_map = {}
        min_chunk_size = config.min_chunk_size if config is not None else 1
        logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
        data_result = loader.load_data(src)
        data_records = data_result["data"]
        doc_id = data_result["doc_id"]
        # Prefix the doc_id with app_id (when given) to distinguish between
        # different documents stored in the same elasticsearch or opensearch
        # index.
        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
        metadatas = []
        for data in data_records:
            content = data["content"]
            metadata = data["meta_data"]
            # Add the data type to the metadata so queries can filter by it.
            metadata["data_type"] = self.data_type.value
            metadata["doc_id"] = doc_id
            # TODO: Currently defaulting to src as the url. This is done
            # intentionally, since some data types (e.g. the 'gmail' loader)
            # don't have a url in their metadata.
            url = metadata.get("url", src)
            chunks = self.get_chunks(content)
            for chunk in chunks:
                chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
                # Keep the chunk only if it is new and large enough.
                if id_map.get(chunk_id) is None and len(chunk) >= min_chunk_size:
                    id_map[chunk_id] = True
                    chunk_ids.append(chunk_id)
                    documents.append(chunk)
                    metadatas.append(metadata)
        return {
            "documents": documents,
            "ids": chunk_ids,
            "metadatas": metadatas,
            "doc_id": doc_id,
        }

    def get_chunks(self, content):
        """
        Returns chunks using the text splitter instance.

        Override in a child class for custom chunking logic.
        """
        return self.text_splitter.split_text(content)

    def set_data_type(self, data_type: DataType):
        """
        Set the data type of the chunker.
        """
        self.data_type = data_type
        # TODO: This should happen during initialization, which means it has
        # to be done in the child classes.

    @staticmethod
    def get_word_count(documents) -> int:
        """Return the total number of space-separated words across all documents."""
        return sum(len(document.split(" ")) for document in documents)
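

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). A minimal, hypothetical
# example of driving BaseChunker end to end: EchoLoader and the demo values
# below are illustrative assumptions, not embedchain API. It assumes
# langchain's RecursiveCharacterTextSplitter is installed and that the loader
# returns the {"data": [...], "doc_id": ...} payload create_chunks expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    class EchoLoader:
        """Hypothetical loader stub returning the payload shape create_chunks expects."""

        def load_data(self, src):
            return {
                "doc_id": hashlib.sha256(src.encode()).hexdigest(),
                "data": [{"content": src, "meta_data": {"url": "local"}}],
            }

    splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    chunker = BaseChunker(splitter)
    # The data type must be set before create_chunks runs, because the
    # metadata loop reads self.data_type.value (see the TODO in set_data_type).
    chunker.set_data_type(DataType.TEXT)

    result = chunker.create_chunks(EchoLoader(), "some long text " * 50, app_id="demo")
    print(len(result["ids"]), "chunks,", BaseChunker.get_word_count(result["documents"]), "words")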