Explorar o código

[Improvements] allow setting up the elasticsearch cloud instance (#997)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel hai 1 ano
pai
achega
0ea8ab228c

+ 10 - 5
docs/components/vector-databases.mdx

@@ -58,6 +58,12 @@ Install related dependencies using the following command:
 pip install --upgrade 'embedchain[elasticsearch]'
 ```
 
+<Note>
+You can configure the Elasticsearch connection by providing either `es_url` or `cloud_id`. If you are using the Elasticsearch Service on Elastic Cloud, you can find the `cloud_id` on the [Elastic Cloud dashboard](https://cloud.elastic.co/deployments).
+</Note>
+
+You can authorize the connection to Elasticsearch by providing either `basic_auth`, `api_key`, or `bearer_auth`.
+
 <CodeGroup>
 
 ```python main.py
@@ -72,11 +78,10 @@ vectordb:
   provider: elasticsearch
   config:
     collection_name: 'es-index'
-    es_url: http://localhost:9200
-    http_auth:
-      - admin
-      - admin
-    api_key: xxx
+    cloud_id: 'deployment-name:xxxx'
+    basic_auth:
+      - elastic
+      - <your_password>
     verify_certs: false
 ```
 </CodeGroup>

+ 1 - 1
embedchain/config/llm/base.py

@@ -57,7 +57,7 @@ class BaseLlmConfig(BaseConfig):
 
     def __init__(
         self,
-        number_documents: int = 1,
+        number_documents: int = 3,
         template: Optional[Template] = None,
         model: Optional[str] = None,
         temperature: float = 0,

+ 7 - 4
embedchain/config/vectordb/elasticsearch.py

@@ -12,6 +12,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
         es_url: Union[str, List[str]] = None,
+        cloud_id: Optional[str] = None,
         **ES_EXTRA_PARAMS: Dict[str, any],
     ):
         """
@@ -26,12 +27,15 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
         :type ES_EXTRA_PARAMS: Dict[str, Any], optional
         """
+        if es_url and cloud_id:
+            raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
         # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
         self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
-        if not self.ES_URL:
+        self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
+        if not self.ES_URL and not self.CLOUD_ID:
             raise AttributeError(
-                "Elasticsearch needs a URL attribute, "
-                "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` in `.env`"
+                "Elasticsearch needs a URL or CLOUD_ID attribute, "
+                "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`"  # noqa: E501
             )
         self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
         # Load API key from .env if it's not explicitly passed.
@@ -40,7 +44,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
             not self.ES_EXTRA_PARAMS.get("api_key")
             and not self.ES_EXTRA_PARAMS.get("basic_auth")
             and not self.ES_EXTRA_PARAMS.get("bearer_auth")
-            and not self.ES_EXTRA_PARAMS.get("http_auth")
         ):
             self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
         super().__init__(collection_name=collection_name, dir=dir)

+ 31 - 11
embedchain/vectordb/elasticsearch.py

@@ -11,6 +11,7 @@ except ImportError:
 
 from embedchain.config import ElasticsearchDBConfig
 from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.utils import chunks
 from embedchain.vectordb.base import BaseVectorDB
 
 
@@ -20,6 +21,8 @@ class ElasticsearchDB(BaseVectorDB):
     Elasticsearch as vector database
     """
 
+    BATCH_SIZE = 100
+
     def __init__(
         self,
         config: Optional[ElasticsearchDBConfig] = None,
@@ -43,7 +46,14 @@ class ElasticsearchDB(BaseVectorDB):
                     "Please make sure the type is right and that you are passing an instance."
                 )
             self.config = config or es_config
-        self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
+        if self.config.ES_URL:
+            self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
+        elif self.config.CLOUD_ID:
+            self.client = Elasticsearch(cloud_id=self.config.CLOUD_ID, **self.config.ES_EXTRA_PARAMS)
+        else:
+            raise ValueError(
+                "Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`"  # noqa: E501
+            )
 
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
@@ -121,19 +131,29 @@ class ElasticsearchDB(BaseVectorDB):
         :type skip_embedding: bool
         """
 
-        docs = []
         if not skip_embedding:
             embeddings = self.embedder.embedding_fn(documents)
 
-        for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
-            docs.append(
-                {
-                    "_index": self._get_index(),
-                    "_id": id,
-                    "_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
-                }
-            )
-        bulk(self.client, docs)
+        for chunk in chunks(
+            list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
+        ):  # noqa: E501
+            ids, docs, metadatas, embeddings = [], [], [], []
+            for id, text, metadata, embedding in chunk:
+                ids.append(id)
+                docs.append(text)
+                metadatas.append(metadata)
+                embeddings.append(embedding)
+
+            batch_docs = []
+            for id, text, metadata, embedding in zip(ids, docs, metadatas, embeddings):
+                batch_docs.append(
+                    {
+                        "_index": self._get_index(),
+                        "_id": id,
+                        "_source": {"text": text, "metadata": metadata, "embeddings": embedding},
+                    }
+                )
+            bulk(self.client, batch_docs, **kwargs)
         self.client.indices.refresh(index=self._get_index())
 
     def query(