瀏覽代碼

refactor: Use src instead of url as argument value (#111)

cachho 2 年之前
父節點
當前提交
51adc5c886
共有 2 個文件被更改，包括 17 次插入和 9 次删除
  1. 11 3
      embedchain/chunkers/base_chunker.py
  2. 6 6
      embedchain/embedchain.py

+ 11 - 3
embedchain/chunkers/base_chunker.py

@@ -5,16 +5,24 @@ class BaseChunker:
     def __init__(self, text_splitter):
         self.text_splitter = text_splitter
 
-    def create_chunks(self, loader, url):
+    def create_chunks(self, loader, src):
+        """
+        Loads data and chunks it.
+
+        :param loader: The loader whose `load_data` method is used to create the raw data.
+        :param src: The data to be handled by the loader. Can be a URL for remote sources or local content for local loaders. 
+        """
         documents = []
         ids = []
-        datas = loader.load_data(url)
+        datas = loader.load_data(src)
         metadatas = []
         for data in datas:
             content = data["content"]
             meta_data = data["meta_data"]
-            chunks = self.text_splitter.split_text(content)
             url = meta_data["url"]
+
+            chunks = self.text_splitter.split_text(content)
+
             for chunk in chunks:
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 ids.append(chunk_id)

+ 6 - 6
embedchain/embedchain.py

@@ -121,22 +121,22 @@ class EmbedChain:
         self.user_asks.append([data_type, content])
         self.load_and_embed(loader, chunker, content)
 
-    def load_and_embed(self, loader, chunker, url):
+    def load_and_embed(self, loader, chunker, src):
         """
         Loads the data from the given source, chunks it, and adds it to the database.
 
         :param loader: The loader to use to load the data.
         :param chunker: The chunker to use to chunk the data.
-        :param url: The URL where the data is located.
+        :param src: The data to be handled by the loader. Can be a URL for remote sources or local content for local loaders.
         """
-        embeddings_data = chunker.create_chunks(loader, url)
+        embeddings_data = chunker.create_chunks(loader, src)
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         # get existing ids, and discard doc if any common id exist.
         existing_docs = self.collection.get(
             ids=ids,
-            # where={"url": url}
+            # where={"url": src}
         )
         existing_ids = set(existing_docs["ids"])
 
@@ -145,7 +145,7 @@ class EmbedChain:
             data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
 
             if not data_dict:
-                print(f"All data from {url} already exists in the database.")
+                print(f"All data from {src} already exists in the database.")
                 return
 
             ids = list(data_dict.keys())
@@ -156,7 +156,7 @@ class EmbedChain:
             metadatas=list(metadatas),
             ids=ids
         )
-        print(f"Successfully saved {url}. Total chunks count: {self.collection.count()}")
+        print(f"Successfully saved {src}. Total chunks count: {self.collection.count()}")
 
     def _format_result(self, results):
         return [