# images.py

import hashlib
import logging
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig


class ImagesChunker(BaseChunker):
    """Chunker for an image."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
        image_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(image_splitter)
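
    # Note: the splitter built above is handed to BaseChunker, but
    # create_chunks below emits exactly one chunk per image and never
    # invokes it; images are not split into smaller pieces.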

    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
        """
        Loads the image(s) via the loader and collects the embeddings it
        returns, creating exactly one chunk per image.

        :param loader: The loader whose `load_data` method is used to create
        the raw data.
        :param src: The data to be handled by the loader. Can be a URL for
        remote sources or local content for local loaders.
        :param app_id: Optional app identifier, prefixed to the document id.
        :param config: Optional chunker config; only `min_chunk_size` is read here.
        """
        documents = []
        embeddings = []
        ids = []
        min_chunk_size = config.min_chunk_size if config is not None else 0
        # min_chunk_size is only logged here; image chunks are never filtered by size
        logging.info(f"Skipping chunks smaller than {min_chunk_size} characters")
        data_result = loader.load_data(src)
        data_records = data_result["data"]
        doc_id = data_result["doc_id"]
        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
        metadatas = []
        for data in data_records:
            meta_data = data["meta_data"]
            # add data type to meta data to allow query using data type
            meta_data["data_type"] = self.data_type.value
            # deterministic chunk id derived from the image URL
            chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()
            ids.append(chunk_id)
            documents.append(data["content"])
            embeddings.append(data["embedding"])
            meta_data["doc_id"] = doc_id
            metadatas.append(meta_data)
        return {
            "documents": documents,
            "embeddings": embeddings,
            "ids": ids,
            "metadatas": metadatas,
            "doc_id": doc_id,
        }

    def get_word_count(self, documents):
        """
        The chunk count and word count for an image are fixed at 1, since one
        embedding is created per image.
        """
        return 1
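

# --- Usage sketch (not part of the original module) ---
# A minimal, hypothetical example of how this chunker is driven. The fake
# loader is a stand-in: it returns the dict shape that create_chunks expects
# ({"doc_id": ..., "data": [...]}), each record carrying "content",
# "embedding", and "meta_data" (including a "url" key). The `DataType`
# import path is assumed from embedchain's layout; create_chunks only needs
# `self.data_type` to expose a `.value`.
if __name__ == "__main__":
    from embedchain.models.data_type import DataType

    class FakeImageLoader:
        """Stand-in loader yielding one pre-embedded image record."""

        def load_data(self, src):
            return {
                "doc_id": "example-doc",
                "data": [
                    {
                        "content": "a photo of a red bicycle",
                        "embedding": [0.1, 0.2, 0.3],  # placeholder vector
                        "meta_data": {"url": src},
                    }
                ],
            }

    chunker = ImagesChunker()
    # normally assigned by the embedding pipeline; set directly for the demo
    chunker.data_type = DataType.IMAGES
    result = chunker.create_chunks(FakeImageLoader(), "https://example.com/bike.png", app_id="demo")
    print(result["doc_id"])  # "demo--example-doc"
    print(result["ids"][0])  # sha256 of the image URL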