Sfoglia il codice sorgente

[bugfix] Fix issue when llm config is not defined (#763)

Deshraj Yadav 1 anno fa
parent
commit
87d0b5c76f

+ 1 - 1
.github/workflows/ci.yml

@@ -34,4 +34,4 @@ jobs:
           file: coverage.xml
         env:
           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
-      
+

+ 3 - 0
Makefile

@@ -9,6 +9,9 @@ PROJECT_NAME := embedchain
 install:
 	poetry install
 
+install_all:
+	poetry install --all-extras
+
 install_es:
 	poetry install --extras elasticsearch
 

+ 1 - 1
embedchain/config/llm/base_llm_config.py

@@ -67,7 +67,7 @@ class BaseLlmConfig(BaseConfig):
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: Dict[str, Any] = None,
-        query_type: Optional[str] = None
+        query_type: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.

+ 1 - 1
embedchain/data_formatter/data_formatter.py

@@ -1,8 +1,8 @@
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
-from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.images import ImagesChunker
+from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker

+ 10 - 4
embedchain/embedchain.py

@@ -392,8 +392,13 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
-        self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
-                    ids=ids, skip_embedding = (chunker.data_type == DataType.IMAGES))
+        self.db.add(
+            embeddings=embeddings_data.get("embeddings", None),
+            documents=documents,
+            metadatas=metadatas,
+            ids=ids,
+            skip_embedding=(chunker.data_type == DataType.IMAGES),
+        )
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
@@ -437,17 +442,18 @@ class EmbedChain(JSONSerializable):
         # We cannot query the database with the input query in case of an image search. This is because we need
         # to bring down both the image and text to the same dimension to be able to compare them.
         db_query = input_query
-        if config.query_type == "Images":
+        if hasattr(config, "query_type") and config.query_type == "Images":
             # We import the clip processor here to make sure the package is not dependent on clip dependency even if the
             # image dataset is not being used
             from embedchain.models.clip_processor import ClipProcessor
+
             db_query = ClipProcessor.get_text_features(query=input_query)
 
         contents = self.db.query(
             input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
-            skip_embedding = (config.query_type == "Images")
+            skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
         )
 
         return contents

+ 1 - 1
embedchain/llm/gpt4all.py

@@ -22,7 +22,7 @@ class GPT4ALLLlm(BaseLlm):
             from gpt4all import GPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
-                "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
+                "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
 
         return GPT4All(model_name=model)

+ 9 - 5
embedchain/loaders/images.py

@@ -1,11 +1,11 @@
-import os
-import logging
 import hashlib
+import logging
+import os
+
 from embedchain.loaders.base_loader import BaseLoader
 
 
 class ImagesLoader(BaseLoader):
-
     def load_data(self, image_url):
         """
         Loads images from the supplied directory/file and applies CLIP model transformation to represent these images
@@ -15,6 +15,7 @@ class ImagesLoader(BaseLoader):
         """
         # load model and image preprocessing
         from embedchain.models.clip_processor import ClipProcessor
+
         model, preprocess = ClipProcessor.load_model()
         if os.path.isfile(image_url):
             data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
@@ -28,8 +29,11 @@ class ImagesLoader(BaseLoader):
                     # Log the file that was not loaded
                     logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
         # Get the metadata like Size, Last Modified and Last Created timestamps
-        image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
-                               str(os.path.getctime(image_url))]
+        image_path_metadata = [
+            str(os.path.getsize(image_url)),
+            str(os.path.getmtime(image_url)),
+            str(os.path.getctime(image_url)),
+        ]
         doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
         return {
             "doc_id": doc_id,

+ 3 - 9
embedchain/models/clip_processor.py

@@ -1,6 +1,6 @@
 try:
-    import torch
     import clip
+    import torch
     from PIL import Image, UnidentifiedImageError
 except ImportError:
     raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
@@ -39,14 +39,8 @@ class ClipProcessor:
             image_features /= image_features.norm(dim=-1, keepdim=True)
 
         image_features = image_features.cpu().detach().numpy().tolist()[0]
-        meta_data = {
-            "url": image_url
-        }
-        return {
-            "content": image_url,
-            "embedding": image_features,
-            "meta_data": meta_data
-        }
+        meta_data = {"url": image_url}
+        return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
 
     @staticmethod
     def get_text_features(query):

+ 9 - 3
embedchain/vectordb/chroma.py

@@ -115,8 +115,14 @@ class ChromaDB(BaseVectorDB):
     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
 
-    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
-            ids: List[str], skip_embedding: bool) -> Any:
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ) -> Any:
         """
         Add vectors to chroma database
 
@@ -184,7 +190,7 @@ class ChromaDB(BaseVectorDB):
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
-                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
+                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
             ) from None
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]

+ 8 - 2
embedchain/vectordb/elasticsearch.py

@@ -100,8 +100,14 @@ class ElasticsearchDB(BaseVectorDB):
         ids = [doc["_id"] for doc in docs]
         return {"ids": set(ids)}
 
-    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
-            ids: List[str], skip_embedding: bool) -> Any:
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ) -> Any:
         """
         add data in vector database
         :param documents: list of texts to add

+ 4 - 2
pyproject.toml

@@ -94,7 +94,7 @@ pytube = "^15.0.0"
 duckduckgo-search = "^3.8.5"
 llama-hub = { version = "^0.0.29", optional = true }
 sentence-transformers = { version = "^2.2.2", optional = true }
-torch = { version = ">=2.0.0, !=2.0.1", optional = true }
+torch = { version = "2.0.0", optional = true }
 # Torch 2.0.1 is not compatible with poetry (https://github.com/pytorch/pytorch/issues/100974)
 gpt4all = { version = "1.0.8", optional = true }
 # 1.0.9 is not working for some users (https://github.com/nomic-ai/gpt4all/issues/1394)
@@ -107,6 +107,8 @@ discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 docx2txt = "^0.8"
 clip = {git = "https://github.com/openai/CLIP.git#a1d0717", optional = true}
+pillow = { version = "10.0.1", optional = true }
+torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
 ftfy = { version = "6.1.1", optional = true }
 regex = { version = "2023.8.8", optional = true }
 
@@ -131,7 +133,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
-images = ["torch", "ftfy", "regex", "clip"]
+images = ["torch", "ftfy", "regex", "clip", "pillow", "torchvision"]
 
 [tool.poetry.group.docs.dependencies]
 

+ 14 - 10
tests/chunkers/test_image_chunker.py

@@ -19,11 +19,13 @@ class TestImageChunker(unittest.TestCase):
         image_path = "./tmp/image.jpeg"
         result = chunker.create_chunks(MockLoader(), image_path)
 
-        expected_chunks = {'doc_id': '123',
-                           'documents': [image_path],
-                           'embeddings': ['embedding'],
-                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
-                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        expected_chunks = {
+            "doc_id": "123",
+            "documents": [image_path],
+            "embeddings": ["embedding"],
+            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
+            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+        }
         self.assertEqual(expected_chunks, result)
 
     def test_chunks_with_default_config(self):
@@ -37,11 +39,13 @@ class TestImageChunker(unittest.TestCase):
         image_path = "./tmp/image.jpeg"
         result = chunker.create_chunks(MockLoader(), image_path)
 
-        expected_chunks = {'doc_id': '123',
-                           'documents': [image_path],
-                           'embeddings': ['embedding'],
-                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
-                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        expected_chunks = {
+            "doc_id": "123",
+            "documents": [image_path],
+            "embeddings": ["embedding"],
+            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
+            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+        }
         self.assertEqual(expected_chunks, result)
 
     def test_word_count(self):

+ 10 - 19
tests/models/test_clip_processor.py

@@ -1,29 +1,23 @@
-import tempfile
-import unittest
 import os
+import tempfile
 import urllib
+
 from PIL import Image
-from embedchain.models.clip_processor import ClipProcessor
 
+from embedchain.models.clip_processor import ClipProcessor
 
-class ClipProcessorTest(unittest.TestCase):
 
+class TestClipProcessor:
     def test_load_model(self):
         # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
         model, preprocess = ClipProcessor.load_model()
-
-        # Assert that the model is not None.
-        self.assertIsNotNone(model)
-
-        # Assert that the preprocess is not None.
-        self.assertIsNotNone(preprocess)
+        assert model is not None
+        assert preprocess is not None
 
     def test_get_image_features(self):
         # Clone the image to a temporary folder.
         with tempfile.TemporaryDirectory() as tmp_dir:
-            urllib.request.urlretrieve(
-                'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
-                "image.jpg")
+            urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
 
             image = Image.open("image.jpg")
             image.save(os.path.join(tmp_dir, "image.jpg"))
@@ -35,9 +29,6 @@ class ClipProcessorTest(unittest.TestCase):
             # Delete the temporary file.
             os.remove(os.path.join(tmp_dir, "image.jpg"))
 
-            # Assert that the test passes.
-            self.assertTrue(True)
-
     def test_get_text_features(self):
         # Test that the `get_text_features()` method returns a list containing the text embedding.
         query = "This is a text query."
@@ -46,10 +37,10 @@ class ClipProcessorTest(unittest.TestCase):
         text_features = ClipProcessor.get_text_features(query)
 
         # Assert that the text embedding is not None.
-        self.assertIsNotNone(text_features)
+        assert text_features is not None
 
         # Assert that the text embedding is a list of floats.
-        self.assertIsInstance(text_features, list)
+        assert isinstance(text_features, list)
 
         # Assert that the text embedding has the correct length.
-        self.assertEqual(len(text_features), 512)
+        assert len(text_features) == 512

+ 14 - 6
tests/vectordb/test_chroma_db.py

@@ -197,21 +197,29 @@ class TestChromaDbCollection(unittest.TestCase):
         # Collection should be empty when created
         self.assertEqual(self.app_with_settings.db.count(), 0)
 
-        self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"], metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True)
+        self.app_with_settings.db.add(
+            embeddings=[[0, 0, 0]],
+            documents=["document"],
+            metadatas=[{"value": "somevalue"}],
+            ids=["id"],
+            skip_embedding=True,
+        )
         # After adding, should contain one item
         self.assertEqual(self.app_with_settings.db.count(), 1)
 
         # Validate if the get utility of the database is working as expected
         data = self.app_with_settings.db.get(["id"], limit=1)
-        expected_value = {'documents': ['document'],
-                          'embeddings': None,
-                          'ids': ['id'],
-                          'metadatas': [{'value': 'somevalue'}]}
+        expected_value = {
+            "documents": ["document"],
+            "embeddings": None,
+            "ids": ["id"],
+            "metadatas": [{"value": "somevalue"}],
+        }
         self.assertEqual(data, expected_value)
 
         # Validate if the query utility of the database is working as expected
         data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
-        expected_value = ['document']
+        expected_value = ["document"]
         self.assertEqual(data, expected_value)
 
     def test_collections_are_persistent(self):

+ 12 - 24
tests/vectordb/test_elasticsearch_db.py

@@ -4,11 +4,11 @@ from unittest.mock import patch
 
 from embedchain import App
 from embedchain.config import AppConfig, ElasticsearchDBConfig
-from embedchain.vectordb.elasticsearch import ElasticsearchDB
 from embedchain.embedder.gpt4all import GPT4AllEmbedder
+from embedchain.vectordb.elasticsearch import ElasticsearchDB
 
-class TestEsDB(unittest.TestCase):
 
+class TestEsDB(unittest.TestCase):
     @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
     def test_setUp(self, mock_client):
         self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
@@ -37,17 +37,11 @@ class TestEsDB(unittest.TestCase):
         # Add the data to the database.
         self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
 
-        search_response = {"hits":
-            {"hits":
-                [
-                    {
-                        "_source": {"text": "This is a document."},
-                        "_score": 0.9
-                    },
-                    {
-                        "_source": {"text": "This is another document."},
-                        "_score": 0.8
-                    }
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
                 ]
             }
         }
@@ -80,17 +74,11 @@ class TestEsDB(unittest.TestCase):
         # Add the data to the database.
         self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
 
-        search_response = {"hits":
-            {"hits":
-                [
-                    {
-                        "_source": {"text": "This is a document."},
-                        "_score": 0.9
-                    },
-                    {
-                        "_source": {"text": "This is another document."},
-                        "_score": 0.8
-                    }
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
                 ]
             }
         }