Browse Source

[Feature/Improvements] Delete data sources from metadata db when using `app.delete()` (#1286)

Deshraj Yadav 1 year ago
parent
commit
752f638cfc

+ 2 - 0
embedchain/config/vectordb/pinecone.py

@@ -16,6 +16,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         pod_config: Optional[dict[str, any]] = None,
         serverless_config: Optional[dict[str, any]] = None,
         hybrid_search: bool = False,
+        bm25_encoder: any = None,
         **extra_params: dict[str, any],
     ):
         self.metric = metric
@@ -24,6 +25,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         self.vector_dimension = vector_dimension
         self.extra_params = extra_params
         self.hybrid_search = hybrid_search
+        self.bm25_encoder = bm25_encoder
         if pod_config is None and serverless_config is None:
             # If no config is provided, use the default pod spec config
             pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")

+ 15 - 4
embedchain/embedchain.py

@@ -6,17 +6,20 @@ from typing import Any, Optional, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
-from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
-from embedchain.core.db.models import DataSource
+from embedchain.core.db.models import ChatHistory, DataSource
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
 
@@ -642,9 +645,10 @@ class EmbedChain(JSONSerializable):
         """
         try:
             self.db_session.query(DataSource).filter_by(app_id=self.config.id).delete()
+            self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete()
             self.db_session.commit()
         except Exception as e:
-            logging.error(f"Error deleting chat history: {e}")
+            logging.error(f"Error deleting data sources: {e}")
             self.db_session.rollback()
             return None
         self.db.reset()
@@ -682,6 +686,13 @@ class EmbedChain(JSONSerializable):
         :param source_hash: The hash of the source.
         :type source_hash: str
         """
+        try:
+            self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete()
+            self.db_session.commit()
+        except Exception as e:
+            logging.error(f"Error deleting data sources: {e}")
+            self.db_session.rollback()
+            return None
         self.db.delete(where={"hash": source_id})
         logging.info(f"Successfully deleted {source_id}")
         # Send anonymous telemetry

+ 1 - 2
embedchain/vectordb/pinecone.py

@@ -49,9 +49,8 @@ class PineconeDB(BaseVectorDB):
         # Setup BM25Encoder if sparse vectors are to be used
         self.bm25_encoder = None
         if self.config.hybrid_search:
-            # TODO: Add support for fitting BM25Encoder on any corpus
             logging.info("Initializing BM25Encoder for sparse vectors..")
-            self.bm25_encoder = BM25Encoder.default()
+            self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
 
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.86"
+version = "0.1.87"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",