Browse Source

fix: Elasticsearch - use correct class attributes (#566)

cachho 1 year ago
parent
commit
b0d8711b65
2 changed files with 7 additions and 6 deletions
  1. 1 0
      embedchain/embedder/base_embedder.py
  2. 6 6
      embedchain/vectordb/elasticsearch_db.py

+ 1 - 0
embedchain/embedder/base_embedder.py

@@ -24,6 +24,7 @@ class BaseEmbedder:
             self.config = BaseEmbedderConfig()
         else:
             self.config = config
+        self.vector_dimension: int
 
     def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
         if not hasattr(embedding_fn, "__call__"):

+ 6 - 6
embedchain/vectordb/elasticsearch_db.py

@@ -81,7 +81,7 @@ class ElasticsearchDB(BaseVectorDB):
         :param ids: ids of docs
         """
         docs = []
-        embeddings = self.config.embedding_fn(documents)
+        embeddings = self.embedder.embedding_fn(documents)
         for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
             docs.append(
                 {
@@ -101,7 +101,7 @@ class ElasticsearchDB(BaseVectorDB):
         :param n_results: no of similar documents to fetch from database
         :param where: Optional. to filter data
         """
-        input_query_vector = self.config.embedding_fn(input_query)
+        input_query_vector = self.embedder.embedding_fn(input_query)
         query_vector = input_query_vector[0]
         query = {
             "script_score": {
@@ -126,17 +126,17 @@ class ElasticsearchDB(BaseVectorDB):
 
     def count(self) -> int:
         query = {"match_all": {}}
-        response = self.client.count(index=self.es_index, query=query)
+        response = self.client.count(index=self._get_index(), query=query)
         doc_count = response["count"]
         return doc_count
 
     def reset(self):
         # Delete all data from the database
-        if self.client.indices.exists(index=self.es_index):
+        if self.client.indices.exists(index=self._get_index()):
             # delete index in Es
-            self.client.indices.delete(index=self.es_index)
+            self.client.indices.delete(index=self._get_index())
 
     def _get_index(self):
         # NOTE: The method is preferred to an attribute, because if collection name changes,
         # it's always up-to-date.
-        return f"{self.config.collection_name}_{self.config.vector_dim}"
+        return f"{self.config.collection_name}_{self.embedder.vector_dimension}"