Browse Source

Chunkers: Refactor each chunker & add base class

Adds a base chunker from which any chunker can inherit.
Existing chunkers are refactored to inherit from this base
chunker.
Taranjeet Singh 2 years ago
parent
commit
4329caa17c

+ 27 - 0
embedchain/chunkers/base_chunker.py

@@ -0,0 +1,27 @@
+import hashlib
+
+
+class BaseChunker:
+    def __init__(self, text_splitter):
+        self.text_splitter = text_splitter
+
+    def create_chunks(self, loader, url):
+        documents = []
+        ids = []
+        datas = loader.load_data(url)
+        metadatas = []
+        for data in datas:
+            content = data["content"]
+            meta_data = data["meta_data"]
+            chunks = self.text_splitter.split_text(content)
+            url = meta_data["url"]
+            for chunk in chunks:
+                chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
+                ids.append(chunk_id)
+                documents.append(chunk)
+                metadatas.append(meta_data)
+        return {
+            "documents": documents,
+            "ids": ids,
+            "metadatas": metadatas,
+        }

+ 5 - 25
embedchain/chunkers/pdf_file.py

@@ -1,4 +1,4 @@
-import hashlib
+from embedchain.chunkers.base_chunker import BaseChunker
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
@@ -9,28 +9,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
     "length_function": len,
 }
 
-TEXT_SPLITTER = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
 
-
-class PdfFileChunker:
-
-    def create_chunks(self, loader, url):
-        documents = []
-        ids = []
-        datas = loader.load_data(url)
-        metadatas = []
-        for data in datas:
-            content = data["content"]
-            meta_data = data["meta_data"]
-            chunks = TEXT_SPLITTER.split_text(content)
-            url = meta_data["url"]
-            for chunk in chunks:
-                chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
-                ids.append(chunk_id)
-                documents.append(chunk)
-                metadatas.append(meta_data)
-        return {
-            "documents": documents,
-            "ids": ids,
-            "metadatas": metadatas,
-        }
+class PdfFileChunker(BaseChunker):
+    def __init__(self):
+        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
+        super().__init__(text_splitter)

+ 5 - 25
embedchain/chunkers/website.py

@@ -1,4 +1,4 @@
-import hashlib
+from embedchain.chunkers.base_chunker import BaseChunker
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
@@ -9,28 +9,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
     "length_function": len,
 }
 
-TEXT_SPLITTER = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
 
-
-class WebsiteChunker:
-
-    def create_chunks(self, loader, url):
-        documents = []
-        ids = []
-        datas = loader.load_data(url)
-        metadatas = []
-        for data in datas:
-            content = data["content"]
-            meta_data = data["meta_data"]
-            chunks = TEXT_SPLITTER.split_text(content)
-            url = meta_data["url"]
-            for chunk in chunks:
-                chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
-                ids.append(chunk_id)
-                documents.append(chunk)
-                metadatas.append(meta_data)
-        return {
-            "documents": documents,
-            "ids": ids,
-            "metadatas": metadatas,
-        }
+class WebsiteChunker(BaseChunker):
+    def __init__(self):
+        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
+        super().__init__(text_splitter)

+ 5 - 25
embedchain/chunkers/youtube_video.py

@@ -1,4 +1,4 @@
-import hashlib
+from embedchain.chunkers.base_chunker import BaseChunker
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
@@ -9,28 +9,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
     "length_function": len,
 }
 
-TEXT_SPLITTER = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
 
-
-class YoutubeVideoChunker:
-
-    def create_chunks(self, loader, url):
-        documents = []
-        ids = []
-        datas = loader.load_data(url)
-        metadatas = []
-        for data in datas:
-            content = data["content"]
-            meta_data = data["meta_data"]
-            chunks = TEXT_SPLITTER.split_text(content)
-            url = meta_data["url"]
-            for chunk in chunks:
-                chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
-                ids.append(chunk_id)
-                documents.append(chunk)
-                metadatas.append(meta_data)
-        return {
-            "documents": documents,
-            "ids": ids,
-            "metadatas": metadatas,
-        }
+class YoutubeVideoChunker(BaseChunker):
+    def __init__(self):
+        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
+        super().__init__(text_splitter)