
Add support for image dataset (#571)

Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
Rupesh Bansal 1 year ago
parent
commit
d0af018b8d

+ 3 - 0
embedchain/chunkers/base_chunker.py

@@ -66,3 +66,6 @@ class BaseChunker(JSONSerializable):
         self.data_type = data_type
 
         # TODO: This should be done during initialization. This means it has to be done in the child classes.
+
+    def get_word_count(self, documents):
+        return sum([len(document.split(" ")) for document in documents])
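
Note: `get_word_count` is a naive whitespace split, which is all the telemetry path needs. A minimal sketch of what it computes (values hypothetical):

    documents = ["ab cd", "ef gh ij"]
    word_count = sum([len(document.split(" ")) for document in documents])
    assert word_count == 5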

+ 63 - 0
embedchain/chunkers/images.py

@@ -0,0 +1,63 @@
+import hashlib
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+
+
+class ImagesChunker(BaseChunker):
+    """Chunker for an Image."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
+        image_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(image_splitter)
+
+    def create_chunks(self, loader, src):
+        """
+        Loads the image(s) and creates their corresponding embeddings. Exactly one chunk is created per image.
+
+        :param loader: The loader whose `load_data` method is used to create
+        the raw data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
+        """
+        documents = []
+        embeddings = []
+        ids = []
+        data_result = loader.load_data(src)
+        data_records = data_result["data"]
+        doc_id = data_result["doc_id"]
+        metadatas = []
+        for data in data_records:
+            meta_data = data["meta_data"]
+            # add the data type to the metadata so queries can filter by data type
+            meta_data["data_type"] = self.data_type.value
+            chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()
+            ids.append(chunk_id)
+            documents.append(data["content"])
+            embeddings.append(data["embedding"])
+            meta_data["doc_id"] = doc_id
+            metadatas.append(meta_data)
+
+        return {
+            "documents": documents,
+            "embeddings": embeddings,
+            "ids": ids,
+            "metadatas": metadatas,
+            "doc_id": doc_id,
+        }
+
+    def get_word_count(self, documents):
+        """
+        For an image both the chunk count and the word count are fixed to 1,
+        since exactly one embedding is created per image.
+        """
+        return 1
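
Note: each image becomes exactly one chunk, and its id is the SHA-256 of the image URL/path, so re-adding the same image is de-duplicated by id. A minimal sketch (path hypothetical):

    import hashlib

    chunk_id = hashlib.sha256("./photos/cat.jpg".encode()).hexdigest()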

+ 2 - 0
embedchain/config/llm/base_llm_config.py

@@ -67,6 +67,7 @@ class BaseLlmConfig(BaseConfig):
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: Dict[str, Any] = None,
+        query_type: Optional[str] = None
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -112,6 +113,7 @@ class BaseLlmConfig(BaseConfig):
         self.top_p = top_p
         self.deployment_name = deployment_name
         self.system_prompt = system_prompt
+        self.query_type = query_type
 
         if self.validate_template(template):
             self.template = template
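
Note: `query_type` is only read at query time; any value other than "Images" keeps the existing text path. A minimal usage sketch, assuming `BaseLlmConfig` is exported from `embedchain.config`:

    from embedchain.config import BaseLlmConfig

    config = BaseLlmConfig(query_type="Images")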

+ 4 - 1
embedchain/data_formatter/data_formatter.py

@@ -2,6 +2,7 @@ from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.mdx import MdxChunker
+from embedchain.chunkers.images import ImagesChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
@@ -16,6 +17,7 @@ from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.csv import CsvLoader
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
+from embedchain.loaders.images import ImagesLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
 from embedchain.loaders.local_text import LocalTextLoader
 from embedchain.loaders.mdx import MdxLoader
@@ -68,6 +70,7 @@ class DataFormatter(JSONSerializable):
             DataType.DOCS_SITE: DocsSiteLoader,
             DataType.CSV: CsvLoader,
             DataType.MDX: MdxLoader,
+            DataType.IMAGES: ImagesLoader,
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
@@ -102,11 +105,11 @@ class DataFormatter(JSONSerializable):
             DataType.QNA_PAIR: QnaPairChunker,
             DataType.TEXT: TextChunker,
             DataType.DOCX: DocxFileChunker,
-            DataType.WEB_PAGE: WebPageChunker,
             DataType.DOCS_SITE: DocsSiteChunker,
             DataType.NOTION: NotionChunker,
             DataType.CSV: TableChunker,
             DataType.MDX: MdxChunker,
+            DataType.IMAGES: ImagesChunker,
         }
         if data_type in chunker_classes:
             chunker_class: type = chunker_classes[data_type]

+ 14 - 4
embedchain/embedchain.py

@@ -212,7 +212,7 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         if self.config.collect_metrics:
             # it's quicker to check the variable twice than to count words when they won't be submitted.
-            word_count = sum([len(document.split(" ")) for document in documents])
+            word_count = data_formatter.chunker.get_word_count(documents)
 
             extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
             thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
@@ -329,7 +329,6 @@ class EmbedChain(JSONSerializable):
 
         # Create chunks
         embeddings_data = chunker.create_chunks(loader, src)
-
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
@@ -393,7 +392,8 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
-        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+        self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
+                    ids=ids, skip_embedding=(chunker.data_type == DataType.IMAGES))
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
@@ -434,10 +434,20 @@ class EmbedChain(JSONSerializable):
         if self.config.id is not None:
             where.update({"app_id": self.config.id})
 
+        # For an image search we cannot pass the raw input query to the database:
+        # the image and the text must first be projected into the same embedding
+        # space (via CLIP) before they can be compared.
+        db_query = input_query
+        if config.query_type == "Images":
+            # Import lazily so the package does not depend on clip
+            # unless the image data type is actually used.
+            from embedchain.models.clip_processor import ClipProcessor
+            db_query = ClipProcessor.get_text_features(query=input_query)
+
         contents = self.db.query(
-            input_query=input_query,
+            input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
+            skip_embedding=(config.query_type == "Images"),
         )
 
         return contents
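
Note: this is where the feature comes together: image adds store precomputed CLIP embeddings with `skip_embedding=True`, and an "Images" query embeds the text with CLIP before hitting the database. A minimal end-to-end sketch, assuming the `images` extra is installed (paths and query are hypothetical):

    from embedchain import App
    from embedchain.config import BaseLlmConfig

    app = App()
    app.add("./photos", data_type="images")  # one CLIP embedding per image
    results = app.query("a dog on the beach",
                        config=BaseLlmConfig(query_type="Images"))
    # returns the retrieved image paths; see llm/base.py below, which
    # short-circuits before any prompt is sent to the LLM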

+ 3 - 0
embedchain/llm/base.py

@@ -191,6 +191,9 @@ class BaseLlm(JSONSerializable):
                 prev_config = self.config.serialize()
                 self.config = config
 
+            if config is not None and config.query_type == "Images":
+                return contexts
+
             if self.is_docs_site_instance:
                 self.config.template = DOCS_SITE_PROMPT_TEMPLATE
                 self.config.number_documents = 5

+ 37 - 0
embedchain/loaders/images.py

@@ -0,0 +1,37 @@
+import os
+import logging
+import hashlib
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class ImagesLoader(BaseLoader):
+
+    def load_data(self, image_url):
+        """
+        Loads image(s) from the supplied file or directory and applies the CLIP
+        model to represent each image in vector form
+
+        :param image_url: path to an image file, or to a directory of images
+        """
+        # load model and image preprocessing
+        from embedchain.models.clip_processor import ClipProcessor
+        model, preprocess = ClipProcessor.load_model()
+        if os.path.isfile(image_url):
+            data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
+        else:
+            data = []
+            for filename in os.listdir(image_url):
+                filepath = os.path.join(image_url, filename)
+                try:
+                    data.append(ClipProcessor.get_image_features(filepath, model, preprocess))
+                except Exception as e:
+                    # Log the file that was not loaded
+                    logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
+        # Get the metadata like Size, Last Modified and Last Created timestamps
+        image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
+                               str(os.path.getctime(image_url))]
+        doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": data,
+        }
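
Note: despite the parameter name, `image_url` is a local path; directory loads log and skip unreadable files instead of failing. A minimal usage sketch (path hypothetical):

    from embedchain.loaders.images import ImagesLoader

    loader = ImagesLoader()
    result = loader.load_data("./photos")  # single file or directory
    print(result["doc_id"], len(result["data"]))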

+ 64 - 0
embedchain/models/clip_processor.py

@@ -0,0 +1,64 @@
+try:
+    import torch
+    import clip
+    from PIL import Image, UnidentifiedImageError
+except ImportError:
+    raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
+
+MODEL_NAME = "ViT-B/32"
+
+
+class ClipProcessor:
+    @staticmethod
+    def load_model():
+        """Load data from a director of images."""
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # load model and image preprocessing
+        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
+        return model, preprocess
+
+    @staticmethod
+    def get_image_features(image_url, model, preprocess):
+        """
+        Applies the CLIP model to evaluate the vector representation of the supplied image
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        try:
+            # load image
+            image = Image.open(image_url)
+        except FileNotFoundError:
+            raise FileNotFoundError("The supplied file does not exist")
+        except UnidentifiedImageError:
+            raise UnidentifiedImageError("The supplied file is not an image")
+
+        # pre-process image
+        processed_image = preprocess(image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            image_features = model.encode_image(processed_image)
+            image_features /= image_features.norm(dim=-1, keepdim=True)
+
+        image_features = image_features.cpu().detach().numpy().tolist()[0]
+        meta_data = {
+            "url": image_url
+        }
+        return {
+            "content": image_url,
+            "embedding": image_features,
+            "meta_data": meta_data
+        }
+
+    @staticmethod
+    def get_text_features(query):
+        """
+        Applies the CLIP model to evaluate the vector representation of the supplied text
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        model, preprocess = ClipProcessor.load_model()
+        text = clip.tokenize(query).to(device)
+        with torch.no_grad():
+            text_features = model.encode_text(text)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+
+        return text_features.cpu().numpy().tolist()[0]
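
Note: both `get_image_features` and `get_text_features` return L2-normalized 512-dimensional lists, so a plain dot product already gives the CLIP cosine similarity. A minimal sketch (path and query hypothetical):

    from embedchain.models.clip_processor import ClipProcessor

    model, preprocess = ClipProcessor.load_model()
    image = ClipProcessor.get_image_features("./photos/cat.jpg", model, preprocess)
    text = ClipProcessor.get_text_features("a photo of a cat")
    similarity = sum(i * t for i, t in zip(image["embedding"], text))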

+ 2 - 0
embedchain/models/data_type.py

@@ -23,6 +23,7 @@ class IndirectDataType(Enum):
     NOTION = "notion"
     CSV = "csv"
     MDX = "mdx"
+    IMAGES = "images"
 
 
 class SpecialDataType(Enum):
@@ -45,3 +46,4 @@ class DataType(Enum):
     CSV = IndirectDataType.CSV.value
     MDX = IndirectDataType.MDX.value
     QNA_PAIR = SpecialDataType.QNA_PAIR.value
+    IMAGES = IndirectDataType.IMAGES.value

+ 24 - 12
embedchain/vectordb/chroma.py

@@ -115,7 +115,8 @@ class ChromaDB(BaseVectorDB):
     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
+    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+            ids: List[str], skip_embedding: bool) -> Any:
         """
         Add vectors to chroma database
 
@@ -126,7 +127,10 @@ class ChromaDB(BaseVectorDB):
         :param ids: ids
         :type ids: List[str]
         """
-        self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
+        if skip_embedding:
+            self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
+        else:
+            self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
 
     def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
         """
@@ -146,7 +150,7 @@ class ChromaDB(BaseVectorDB):
             )
         ]
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any], skip_embedding: bool) -> List[str]:
         """
         Query contents from vector data base based on vector similarity
 
@@ -161,19 +165,27 @@ class ChromaDB(BaseVectorDB):
         :rtype: List[str]
         """
         try:
-            result = self.collection.query(
-                query_texts=[
-                    input_query,
-                ],
-                n_results=n_results,
-                where=where,
-            )
+            if skip_embedding:
+                result = self.collection.query(
+                    query_embeddings=[
+                        input_query,
+                    ],
+                    n_results=n_results,
+                    where=where,
+                )
+            else:
+                result = self.collection.query(
+                    query_texts=[
+                        input_query,
+                    ],
+                    n_results=n_results,
+                    where=where,
+                )
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
-                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
+                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
             ) from None
-
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]
         return contents
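
Note: `skip_embedding` switches Chroma between caller-supplied vectors and its own embedding function. A minimal sketch against a `ChromaDB` instance `db` (vectors hypothetical and deliberately tiny):

    db.add(embeddings=[[0.1, 0.2, 0.3]], documents=["img.jpg"],
           metadatas=[{"url": "img.jpg"}], ids=["id1"], skip_embedding=True)
    hits = db.query(input_query=[0.1, 0.2, 0.3], n_results=1, where={},
                    skip_embedding=True)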

+ 15 - 8
embedchain/vectordb/elasticsearch.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set
 
 try:
     from elasticsearch import Elasticsearch
@@ -100,9 +100,10 @@ class ElasticsearchDB(BaseVectorDB):
         ids = [doc["_id"] for doc in docs]
         return {"ids": set(ids)}
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
-        """add data in vector database
-
+    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+            ids: List[str], skip_embedding: bool) -> Any:
+        """
+        Add data to the vector database
         :param documents: list of texts to add
         :type documents: List[str]
         :param metadatas: list of metadata associated with docs
@@ -112,7 +113,9 @@ class ElasticsearchDB(BaseVectorDB):
         """
 
         docs = []
-        embeddings = self.embedder.embedding_fn(documents)
+        if not skip_embedding:
+            embeddings = self.embedder.embedding_fn(documents)
+
         for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
             docs.append(
                 {
@@ -124,7 +127,7 @@ class ElasticsearchDB(BaseVectorDB):
         bulk(self.client, docs)
         self.client.indices.refresh(index=self._get_index())
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any], skip_embedding: bool) -> List[str]:
         """
         query contents from vector data base based on vector similarity
 
@@ -137,8 +140,12 @@ class ElasticsearchDB(BaseVectorDB):
         :return: Database contents that are the result of the query
         :rtype: List[str]
         """
-        input_query_vector = self.embedder.embedding_fn(input_query)
-        query_vector = input_query_vector[0]
+        if skip_embedding:
+            query_vector = input_query
+        else:
+            input_query_vector = self.embedder.embedding_fn(input_query)
+            query_vector = input_query_vector[0]
+
         query = {
             "script_score": {
                 "query": {"bool": {"must": [{"exists": {"field": "text"}}]}},


+ 4 - 2
pyproject.toml

@@ -106,8 +106,9 @@ fastapi-poe = { version = "0.0.16", optional = true }
 discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 docx2txt = "^0.8"
-
-
+clip = { git = "https://github.com/openai/CLIP.git", rev = "a1d0717", optional = true }
+ftfy = { version = "6.1.1", optional = true }
+regex = { version = "2023.8.8", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -130,6 +131,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
+images = ["torch", "ftfy", "regex", "clip"]
 
 [tool.poetry.group.docs.dependencies]
 
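Note: with these extras in place, image support stays opt-in: `pip install "embedchain[images]"` pulls in torch, ftfy, regex and the pinned CLIP checkout listed above, while the lazy imports in the loaders and query path keep the base install unaffected.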

+ 72 - 0
tests/chunkers/test_image_chunker.py

@@ -0,0 +1,72 @@
+import unittest
+
+from embedchain.chunkers.images import ImagesChunker
+from embedchain.config import ChunkerConfig
+from embedchain.models.data_type import DataType
+
+
+class TestImageChunker(unittest.TestCase):
+    def test_chunks(self):
+        """
+        Test the chunks generated by ImagesChunker.
+        # TODO: Not a very precise test.
+        """
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = ImagesChunker(config=chunker_config)
+        # Data type must be set manually in the test
+        chunker.set_data_type(DataType.IMAGES)
+
+        image_path = "./tmp/image.jpeg"
+        result = chunker.create_chunks(MockLoader(), image_path)
+
+        expected_chunks = {'doc_id': '123',
+                           'documents': [image_path],
+                           'embeddings': ['embedding'],
+                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
+                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        self.assertEqual(expected_chunks, result)
+
+    def test_chunks_with_default_config(self):
+        """
+        Test the chunks generated by ImageChunker with default config.
+        """
+        chunker = ImagesChunker()
+        # Data type must be set manually in the test
+        chunker.set_data_type(DataType.IMAGES)
+
+        image_path = "./tmp/image.jpeg"
+        result = chunker.create_chunks(MockLoader(), image_path)
+
+        expected_chunks = {'doc_id': '123',
+                           'documents': [image_path],
+                           'embeddings': ['embedding'],
+                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
+                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        self.assertEqual(expected_chunks, result)
+
+    def test_word_count(self):
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = ImagesChunker(config=chunker_config)
+        chunker.set_data_type(DataType.IMAGES)
+
+        document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
+        result = chunker.get_word_count(document)
+        self.assertEqual(result, 1)
+
+
+class MockLoader:
+    def load_data(self, src):
+        """
+        Mock loader that returns a list of data dictionaries.
+        Adjust this method to return different data for testing.
+        """
+        return {
+            "doc_id": "123",
+            "data": [
+                {
+                    "content": src,
+                    "embedding": "embedding",
+                    "meta_data": {"url": "none"},
+                }
+            ],
+        }

+ 9 - 0
tests/chunkers/test_text.py

@@ -62,6 +62,15 @@ class TestTextChunker(unittest.TestCase):
 
         self.assertEqual(len(documents), len(text))
 
+    def test_word_count(self):
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = TextChunker(config=chunker_config)
+        chunker.set_data_type(DataType.TEXT)
+
+        document = ["ab cd", "ef gh"]
+        result = chunker.get_word_count(document)
+        self.assertEqual(result, 4)
+
 
 class MockLoader:
     def load_data(self, src):

BIN
tests/models/image.jpg


+ 55 - 0
tests/models/test_clip_processor.py

@@ -0,0 +1,55 @@
+import os
+import tempfile
+import unittest
+import urllib.request
+from embedchain.models.clip_processor import ClipProcessor
+
+
+class ClipProcessorTest(unittest.TestCase):
+
+    def test_load_model(self):
+        # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
+        model, preprocess = ClipProcessor.load_model()
+
+        # Assert that the model is not None.
+        self.assertIsNotNone(model)
+
+        # Assert that the preprocess is not None.
+        self.assertIsNotNone(preprocess)
+
+    def test_get_image_features(self):
+        # Download the image straight into a temporary folder so that no stray
+        # copy is left in the working directory.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            urllib.request.urlretrieve(
+                'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
+                os.path.join(tmp_dir, "image.jpg"))
+
+            # Get the image features.
+            model, preprocess = ClipProcessor.load_model()
+            ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
+
+            # Delete the temporary file.
+            os.remove(os.path.join(tmp_dir, "image.jpg"))
+
+            # Assert that the test passes.
+            self.assertTrue(True)
+
+    def test_get_text_features(self):
+        # Test that the `get_text_features()` method returns a list containing the text embedding.
+        query = "This is a text query."
+        model, preprocess = ClipProcessor.load_model()
+
+        text_features = ClipProcessor.get_text_features(query)
+
+        # Assert that the text embedding is not None.
+        self.assertIsNotNone(text_features)
+
+        # Assert that the text embedding is a list of floats.
+        self.assertIsInstance(text_features, list)
+
+        # Assert that the text embedding has the correct length.
+        self.assertEqual(len(text_features), 512)

+ 28 - 0
tests/vectordb/test_chroma_db.py

@@ -186,6 +186,34 @@ class TestChromaDbCollection(unittest.TestCase):
         # Should still be 1, not 2.
         self.assertEqual(app.db.count(), 1)
 
+    def test_add_with_skip_embedding(self):
+        """
+        Test adding precomputed embeddings and querying them with skip_embedding=True
+        """
+        # Start with a clean app
+        self.app_with_settings.reset()
+
+        # Collection should be empty when created
+        self.assertEqual(self.app_with_settings.db.count(), 0)
+
+        self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"],
+                                      metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True)
+        # After adding, should contain one item
+        self.assertEqual(self.app_with_settings.db.count(), 1)
+
+        # Validate if the get utility of the database is working as expected
+        data = self.app_with_settings.db.get(["id"], limit=1)
+        expected_value = {'documents': ['document'],
+                          'embeddings': None,
+                          'ids': ['id'],
+                          'metadatas': [{'value': 'somevalue'}]}
+        self.assertEqual(data, expected_value)
+
+        # Validate if the query utility of the database is working as expected
+        data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
+        expected_value = ['document']
+        self.assertEqual(data, expected_value)
+
     def test_collections_are_persistent(self):
         """
         Test that a collection can be picked up later.

+ 99 - 4
tests/vectordb/test_elasticsearch_db.py

@@ -1,14 +1,109 @@
 import os
 import unittest
+from unittest.mock import patch
 
-from embedchain.config import ElasticsearchDBConfig
+from embedchain import App
+from embedchain.config import AppConfig, ElasticsearchDBConfig
 from embedchain.vectordb.elasticsearch import ElasticsearchDB
-
+from embedchain.embedder.gpt4all import GPT4AllEmbedder
 
 class TestEsDB(unittest.TestCase):
-    def setUp(self):
-        self.es_config = ElasticsearchDBConfig(es_url="http://mock-url.net")
+
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_setUp(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
         self.vector_dim = 384
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db)
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_query(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+        # Create some dummy data.
+        embeddings = [[1, 2, 3], [4, 5, 6]]
+        documents = ["This is a document.", "This is another document."]
+        metadatas = [{}, {}]
+        ids = ["doc_1", "doc_2"]
+
+        # Add the data to the database.
+        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
+
+        search_response = {"hits":
+            {"hits":
+                [
+                    {
+                        "_source": {"text": "This is a document."},
+                        "_score": 0.9
+                    },
+                    {
+                        "_source": {"text": "This is another document."},
+                        "_score": 0.8
+                    }
+                ]
+            }
+        }
+
+        # Configure the mock client to return the mocked response.
+        mock_client.return_value.search.return_value = search_response
+
+        # Query the database for the documents that are most similar to the query "This is a document".
+        query = ["This is a document"]
+        results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
+
+        # Assert that the results are correct.
+        self.assertEqual(results, ["This is a document.", "This is another document."])
+
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_query_with_skip_embedding(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db)
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+        # Create some dummy data.
+        embeddings = [[1, 2, 3], [4, 5, 6]]
+        documents = ["This is a document.", "This is another document."]
+        metadatas = [{}, {}]
+        ids = ["doc_1", "doc_2"]
+
+        # Add the data to the database.
+        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
+
+        search_response = {"hits":
+            {"hits":
+                [
+                    {
+                        "_source": {"text": "This is a document."},
+                        "_score": 0.9
+                    },
+                    {
+                        "_source": {"text": "This is another document."},
+                        "_score": 0.8
+                    }
+                ]
+            }
+        }
+
+        # Configure the mock client to return the mocked response.
+        mock_client.return_value.search.return_value = search_response
+
+        # Query the database for the documents that are most similar to the query "This is a document".
+        query = ["This is a document"]
+        results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
+
+        # Assert that the results are correct.
+        self.assertEqual(results, ["This is a document.", "This is another document."])
 
     def test_init_without_url(self):
         # Make sure it's not loaded from env