|
@@ -11,6 +11,7 @@ except ImportError:
|
|
|
|
|
|
from embedchain.config import ElasticsearchDBConfig
|
|
|
from embedchain.helpers.json_serializable import register_deserializable
|
|
|
+from embedchain.utils import chunks
|
|
|
from embedchain.vectordb.base import BaseVectorDB
|
|
|
|
|
|
|
|
@@ -20,6 +21,8 @@ class ElasticsearchDB(BaseVectorDB):
|
|
|
Elasticsearch as vector database
|
|
|
"""
|
|
|
|
|
|
+ BATCH_SIZE = 100
|
|
|
+
|
|
|
def __init__(
|
|
|
self,
|
|
|
config: Optional[ElasticsearchDBConfig] = None,
|
|
@@ -43,7 +46,14 @@ class ElasticsearchDB(BaseVectorDB):
|
|
|
"Please make sure the type is right and that you are passing an instance."
|
|
|
)
|
|
|
self.config = config or es_config
|
|
|
- self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
|
|
|
+ if self.config.ES_URL:
|
|
|
+ self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
|
|
|
+ elif self.config.CLOUD_ID:
|
|
|
+ self.client = Elasticsearch(cloud_id=self.config.CLOUD_ID, **self.config.ES_EXTRA_PARAMS)
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ "Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`" # noqa: E501
|
|
|
+ )
|
|
|
|
|
|
# Call parent init here because embedder is needed
|
|
|
super().__init__(config=self.config)
|
|
@@ -121,19 +131,29 @@ class ElasticsearchDB(BaseVectorDB):
|
|
|
:type skip_embedding: bool
|
|
|
"""
|
|
|
|
|
|
- docs = []
|
|
|
if not skip_embedding:
|
|
|
embeddings = self.embedder.embedding_fn(documents)
|
|
|
|
|
|
- for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
|
|
|
- docs.append(
|
|
|
- {
|
|
|
- "_index": self._get_index(),
|
|
|
- "_id": id,
|
|
|
- "_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
|
|
|
- }
|
|
|
- )
|
|
|
- bulk(self.client, docs)
|
|
|
+ for chunk in chunks(
|
|
|
+ list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
|
|
|
+ ): # noqa: E501
|
|
|
+ ids, docs, metadatas, embeddings = [], [], [], []
|
|
|
+ for id, text, metadata, embedding in chunk:
|
|
|
+ ids.append(id)
|
|
|
+ docs.append(text)
|
|
|
+ metadatas.append(metadata)
|
|
|
+ embeddings.append(embedding)
|
|
|
+
|
|
|
+ batch_docs = []
|
|
|
+ for id, text, metadata, embedding in zip(ids, docs, metadatas, embeddings):
|
|
|
+ batch_docs.append(
|
|
|
+ {
|
|
|
+ "_index": self._get_index(),
|
|
|
+ "_id": id,
|
|
|
+ "_source": {"text": text, "metadata": metadata, "embeddings": embedding},
|
|
|
+ }
|
|
|
+ )
|
|
|
+ bulk(self.client, batch_docs, **kwargs)
|
|
|
self.client.indices.refresh(index=self._get_index())
|
|
|
|
|
|
def query(
|