Browse Source

[Improvement] Set a default app id if not provided in the app configuration (#1300)

Deshraj Yadav 1 year ago
parent
commit
faacfeb891
4 changed files with 42 additions and 41 deletions
  1. 23 17
      docs/get-started/quickstart.mdx
  2. 6 11
      embedchain/app.py
  3. 1 1
      pyproject.toml
  4. 12 12
      tests/vectordb/test_qdrant.py

+ 23 - 17
docs/get-started/quickstart.mdx

@@ -31,41 +31,47 @@ This section gives a quickstart example of using Mistral as the Open source LLM
 We are using Mistral hosted at Hugging Face, so will you need a Hugging Face token to run this example. Its *free* and you can create one [here](https://huggingface.co/docs/hub/security-tokens).
 
 <CodeGroup>
-```python quickstart.py
+```python huggingface_demo.py
 import os
-# replace this with your HF key
+# Replace this with your HF token
 os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_xxxx"
 
 from embedchain import App
-app = App.from_config("mistral.yaml")
+
+config = {
+  'llm': {
+    'provider': 'huggingface',
+    'config': {
+      'model': 'mistralai/Mistral-7B-Instruct-v0.2',
+      'top_p': 0.5
+    }
+  },
+  'embedder': {
+    'provider': 'huggingface',
+    'config': {
+      'model': 'sentence-transformers/all-mpnet-base-v2'
+    }
+  }
+}
+app = App.from_config(config=config)
 app.add("https://www.forbes.com/profile/elon-musk")
 app.add("https://en.wikipedia.org/wiki/Elon_Musk")
 app.query("What is the net worth of Elon Musk today?")
 # Answer: The net worth of Elon Musk today is $258.7 billion.
 ```
-```yaml mistral.yaml
-llm:
-  provider: huggingface
-  config:
-    model: 'mistralai/Mistral-7B-Instruct-v0.2'
-    top_p: 0.5
-embedder:
-  provider: huggingface
-  config:
-    model: 'sentence-transformers/all-mpnet-base-v2'
-```
 </CodeGroup>
 
 ## Paid Models
 
 In this section, we will use both LLM and embedding model from OpenAI.
 
-```python quickstart.py
+```python openai_demo.py
 import os
-# replace this with your OpenAI key
+from embedchain import App
+
+# Replace this with your OpenAI key
 os.environ["OPENAI_API_KEY"] = "sk-xxxx"
 
-from embedchain import App
 app = App()
 app.add("https://www.forbes.com/profile/elon-musk")
 app.add("https://en.wikipedia.org/wiki/Elon_Musk")

+ 6 - 11
embedchain/app.py

@@ -3,21 +3,15 @@ import concurrent.futures
 import json
 import logging
 import os
-import uuid
 from typing import Any, Optional, Union
 
 import requests
 import yaml
 from tqdm import tqdm
 
-from embedchain.cache import (
-    Config,
-    ExactMatchEvaluation,
-    SearchDistanceEvaluation,
-    cache,
-    gptcache_data_manager,
-    gptcache_pre_function,
-)
+from embedchain.cache import (Config, ExactMatchEvaluation,
+                              SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
 from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.core.db.database import get_session, init_db, setup_engine
@@ -26,7 +20,8 @@ from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.evaluation.base import BaseMetric
-from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
+from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
+                                           Groundedness)
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -106,7 +101,7 @@ class App(EmbedChain):
 
         self.config = config or AppConfig()
         self.name = self.config.name
-        self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
+        self.config.id = self.local_id = "default-app-id" if self.config.id is None else self.config.id
 
         if id is not None:
             # Init client first since user is trying to fetch the pipeline

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.91"
+version = "0.1.92"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 12 - 12
tests/vectordb/test_qdrant.py

@@ -29,7 +29,7 @@ class TestQdrantDB(unittest.TestCase):
     def test_initialize(self, qdrant_client_mock):
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Qdrant instance
@@ -37,7 +37,7 @@ class TestQdrantDB(unittest.TestCase):
         app_config = AppConfig(collect_metrics=False)
         App(config=app_config, db=db, embedding_model=embedder)
 
-        self.assertEqual(db.collection_name, "embedchain-store-1526")
+        self.assertEqual(db.collection_name, "embedchain-store-1536")
         self.assertEqual(db.client, qdrant_client_mock.return_value)
         qdrant_client_mock.return_value.get_collections.assert_called_once()
 
@@ -47,7 +47,7 @@ class TestQdrantDB(unittest.TestCase):
 
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Qdrant instance
@@ -67,7 +67,7 @@ class TestQdrantDB(unittest.TestCase):
 
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Qdrant instance
@@ -80,9 +80,9 @@ class TestQdrantDB(unittest.TestCase):
         ids = ["123", "456"]
         db.add(documents, metadatas, ids)
         qdrant_client_mock.return_value.upsert.assert_called_once_with(
-            collection_name="embedchain-store-1526",
+            collection_name="embedchain-store-1536",
             points=Batch(
-                ids=["def", "ghi"],
+                ids=["abc", "def"],
                 payloads=[
                     {
                         "identifier": "123",
@@ -103,7 +103,7 @@ class TestQdrantDB(unittest.TestCase):
     def test_query(self, qdrant_client_mock):
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Qdrant instance
@@ -115,7 +115,7 @@ class TestQdrantDB(unittest.TestCase):
         db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
 
         qdrant_client_mock.return_value.search.assert_called_once_with(
-            collection_name="embedchain-store-1526",
+            collection_name="embedchain-store-1536",
             query_filter=models.Filter(
                 must=[
                     models.FieldCondition(
@@ -134,7 +134,7 @@ class TestQdrantDB(unittest.TestCase):
     def test_count(self, qdrant_client_mock):
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Qdrant instance
@@ -143,13 +143,13 @@ class TestQdrantDB(unittest.TestCase):
         App(config=app_config, db=db, embedding_model=embedder)
 
         db.count()
-        qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
+        qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1536")
 
     @patch("embedchain.vectordb.qdrant.QdrantClient")
     def test_reset(self, qdrant_client_mock):
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Qdrant instance
@@ -159,7 +159,7 @@ class TestQdrantDB(unittest.TestCase):
 
         db.reset()
         qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
-            collection_name="embedchain-store-1526"
+            collection_name="embedchain-store-1536"
         )