base_chunker.py

import hashlib

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.models.data_type import DataType


class BaseChunker(JSONSerializable):
    def __init__(self, text_splitter):
        """Initialize the chunker."""
        if text_splitter is None:
            # No splitter supplied: fall back to a recursive character splitter
            # built from the default chunker config.
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
            self.text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=config.chunk_size,
                chunk_overlap=config.chunk_overlap,
                length_function=config.length_function,
            )
        else:
            self.text_splitter = text_splitter
        self.data_type = None

    def create_chunks(self, loader, src, app_id=None):
        """
        Loads data and chunks it.

        :param loader: The loader whose `load_data` method is used to create
            the raw data.
        :param src: The data to be handled by the loader. Can be a URL for
            remote sources or local content for local loaders.
        :param app_id: App id used to generate the doc_id.
        """
        documents = []
        chunk_ids = []
        id_map = {}
        data_result = loader.load_data(src)
        data_records = data_result["data"]
        doc_id = data_result["doc_id"]
        # Prefix app_id in the document id if app_id is not None to
        # distinguish between different documents stored in the same
        # elasticsearch or opensearch index.
        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
        metadatas = []
        for data in data_records:
            content = data["content"]
            meta_data = data["meta_data"]
            # Add the data type to the metadata to allow querying by data type.
            meta_data["data_type"] = self.data_type.value
            meta_data["doc_id"] = doc_id
            url = meta_data["url"]
            chunks = self.get_chunks(content)
            for chunk in chunks:
                # Hash the chunk together with its source URL so that equal
                # chunks from different sources still get distinct ids.
                chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
                # Deduplicate: record each chunk id only once.
                if id_map.get(chunk_id) is None:
                    id_map[chunk_id] = True
                    chunk_ids.append(chunk_id)
                    documents.append(chunk)
                    metadatas.append(meta_data)
        return {
            "documents": documents,
            "ids": chunk_ids,
            "metadatas": metadatas,
            "doc_id": doc_id,
        }
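
    # Note: the parallel "documents" / "ids" / "metadatas" lists returned above
    # mirror the add() signature of vector stores such as Chroma; treating that
    # as the consumer contract is an assumption, not stated in this file.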

    def get_chunks(self, content):
        """
        Returns chunks using the text splitter instance.

        Override in a child class to apply custom chunking logic.
        """
        return self.text_splitter.split_text(content)

    def set_data_type(self, data_type: DataType):
        """
        Sets the data type of the chunker.
        """
        self.data_type = data_type

        # TODO: This should be done during initialization. This means it has
        # to be done in the child classes.

    def get_word_count(self, documents):
        # Count space-separated words across all documents.
        return sum(len(document.split(" ")) for document in documents)
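

# Usage sketch: a minimal, hypothetical driver for this class. `_DummyLoader`
# and its sample payload are illustrative assumptions that follow the loader
# contract `create_chunks` expects; they are not part of embedchain itself.
if __name__ == "__main__":

    class _DummyLoader:
        def load_data(self, src):
            # A loader returns a dict with a "doc_id" plus a list of records,
            # each carrying "content" and "meta_data".
            return {
                "doc_id": "example-doc",
                "data": [{"content": src, "meta_data": {"url": "local"}}],
            }

    chunker = BaseChunker(text_splitter=None)
    chunker.set_data_type(DataType.TEXT)
    result = chunker.create_chunks(_DummyLoader(), "Some text to split into chunks.", app_id="demo")
    print(result["doc_id"], len(result["documents"]))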