Browse Source

Add dry_run to add() (#545)

Dev Khant 1 year ago
parent
commit
7c39d9f0c1

+ 1 - 1
docs/advanced/interface_types.mdx

@@ -36,7 +36,7 @@ print(naval_chat_bot.chat("what did the author say about happiness?"))
 
 #### Dry Run
 
-Dry Run is an option in the `query` and `chat` methods that allows the user to not send their constructed prompt to the LLM, to save money. It's used for [testing](/advanced/testing#dry-run).
+Dry Run is an option in the `add`, `query` and `chat` methods that allows the user to display the data chunks and their constructed prompt, which is not sent to the LLM, to save money. It's used for [testing](/advanced/testing#dry-run).
 
 
 ### Stream Response

+ 16 - 5
docs/advanced/testing.mdx

@@ -6,11 +6,9 @@ title: '🧪 Testing'
 
 ### Dry Run
 
-Before you consume valueable tokens, you should make sure that the embedding you have done works and that it's receiving the correct document from the database.
+Before you consume valuable tokens, you should make sure that data chunks are properly created, that the embedding you have done works, and that it's receiving the correct document from the database.
 
-For this you can use the `dry_run` option in your `query` or `chat` method.
-
-Following the example above, add this to your script:
+- For `query` or `chat` method, you can add this to your script:
 
 ```python
 print(naval_chat_bot.query('Can you tell me who Naval Ravikant is?', dry_run=True))
@@ -26,4 +24,17 @@ A: Naval Ravikant is an Indian-American entrepreneur and investor.
 
 _The embedding is confirmed to work as expected. It returns the right document, even if the question is asked slightly different. No prompt tokens have been consumed._
 
-**The dry run will still consume tokens to embed your query, but it is only ~1/15 of the prompt.**
+The dry run will still consume tokens to embed your query, but it is only **~1/15 of the prompt.**
+
+
+- For `add` method, you can add this to your script:
+
+```python
+print(naval_chat_bot.add('https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf', dry_run=True))
+
+'''
+{'chunks': ['THE ALMANACK OF NAVAL RAVIKANT', 'GETTING RICH IS NOT JUST ABOUT LUCK;', 'HAPPINESS IS NOT JUST A TRAIT WE ARE'], 'metadata': [{'source': 'C:\\Users\\Dev\\AppData\\Local\\Temp\\tmp3g5mjoiz\\tmp.pdf', 'page': 0, 'url': 'https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf', 'data_type': 'pdf_file'}, {'source': 'C:\\Users\\Dev\\AppData\\Local\\Temp\\tmp3g5mjoiz\\tmp.pdf', 'page': 2, 'url': 'https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf', 'data_type': 'pdf_file'}, {'source': 'C:\\Users\\Dev\\AppData\\Local\\Temp\\tmp3g5mjoiz\\tmp.pdf', 'page': 2, 'url': 'https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf', 'data_type': 'pdf_file'}], 'count': 7358, 'type': <DataType.PDF_FILE: 'pdf_file'>}
+
+# less items to show for readability
+'''
+```

+ 16 - 2
embedchain/embedchain.py

@@ -125,6 +125,7 @@ class EmbedChain(JSONSerializable):
         data_type: Optional[DataType] = None,
         metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
+        dry_run=False,
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -141,6 +142,8 @@ class EmbedChain(JSONSerializable):
         :param config: The `AddConfig` instance to use as configuration options., defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
+        :param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended.
defaults to False
         :return: source_id, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
@@ -176,12 +179,17 @@ class EmbedChain(JSONSerializable):
 
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
-        documents, _metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_id
+        documents, metadatas, _ids, new_chunks = self.load_and_embed(
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
 
+        if dry_run:
+            data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
+            logging.debug(f"Dry run info : {data_chunks_info}")
+            return data_chunks_info
+
         # Send anonymous telemetry
         if self.config.collect_metrics:
             # it's quicker to check the variable twice than to count words when they won't be submitted.
@@ -233,6 +241,7 @@ class EmbedChain(JSONSerializable):
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         source_id: Optional[str] = None,
+        dry_run = False
     ) -> Tuple[List[str], Dict[str, Any], List[str], int]:
         """The loader to use to load the data.
 
@@ -247,6 +256,8 @@ class EmbedChain(JSONSerializable):
         :type metadata: Dict[str, Any], optional
         :param source_id: Hexadecimal hash of the source., defaults to None
         :type source_id: str, optional
+        :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
+        :type dry_run: bool, defaults to False
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         :rtype: Tuple[List[str], Dict[str, Any], List[str], int]
         """
@@ -277,6 +288,9 @@ class EmbedChain(JSONSerializable):
             ids = list(data_dict.keys())
             documents, metadatas = zip(*data_dict.values())
 
+        if dry_run:
+            return list(documents), metadatas, ids, 0
+
         # Loop though all metadatas and add extras.
         new_metadatas = []
         for m in metadatas:

+ 27 - 1
tests/embedchain/test_add.py

@@ -3,7 +3,8 @@ import unittest
 from unittest.mock import MagicMock, patch
 
 from embedchain import App
-from embedchain.config import AppConfig
+from embedchain.config import AppConfig, AddConfig, ChunkerConfig
+from embedchain.models.data_type import DataType
 
 
 class TestApp(unittest.TestCase):
@@ -34,3 +35,28 @@ class TestApp(unittest.TestCase):
         data_type = "text"
         self.app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
         self.assertEqual(self.app.user_asks, [["https://example.com", data_type, {"meta": "meta-data"}]])
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_dry_run(self):
+        """
+        Test that if dry_run == True then data chunks are returned.
+        """
+
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
+        # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
+        text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
+
+        result = self.app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
+
+        chunks = result["chunks"]
+        metadata = result["metadata"]
+        count = result["count"]
+        data_type = result["type"]
+
+        self.assertEqual(len(chunks), len(text))
+        self.assertEqual(count, len(text))
+        self.assertEqual(data_type, DataType.TEXT)
+        for item in metadata:
+            self.assertIsInstance(item, dict)
+            self.assertIn(item["url"], "local")
+            self.assertIn(item["data_type"], "text")