test_qdrant.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import unittest
  2. import uuid
  3. import pytest
  4. from mock import patch
  5. from qdrant_client.http import models
  6. from qdrant_client.http.models import Batch
  7. from embedchain import App
  8. from embedchain.config import AppConfig
  9. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  10. from embedchain.embedder.base import BaseEmbedder
  11. from embedchain.vectordb.qdrant import QdrantDB
  12. def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
  13. """A mock embedding function."""
  14. return [[1, 2, 3], [4, 5, 6]]
  15. class TestQdrantDB(unittest.TestCase):
  16. TEST_UUIDS = ["abc", "def", "ghi"]
  17. def test_incorrect_config_throws_error(self):
  18. """Test the init method of the Qdrant class throws error for incorrect config"""
  19. with self.assertRaises(TypeError):
  20. QdrantDB(config=PineconeDBConfig())
  21. @patch("embedchain.vectordb.qdrant.QdrantClient")
  22. def test_initialize(self, qdrant_client_mock):
  23. # Set the embedder
  24. embedder = BaseEmbedder()
  25. embedder.set_vector_dimension(1536)
  26. embedder.set_embedding_fn(mock_embedding_fn)
  27. # Create a Qdrant instance
  28. db = QdrantDB()
  29. app_config = AppConfig(collect_metrics=False)
  30. App(config=app_config, db=db, embedding_model=embedder)
  31. self.assertEqual(db.collection_name, "embedchain-store-1536")
  32. self.assertEqual(db.client, qdrant_client_mock.return_value)
  33. qdrant_client_mock.return_value.get_collections.assert_called_once()
  34. @patch("embedchain.vectordb.qdrant.QdrantClient")
  35. def test_get(self, qdrant_client_mock):
  36. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  37. # Set the embedder
  38. embedder = BaseEmbedder()
  39. embedder.set_vector_dimension(1536)
  40. embedder.set_embedding_fn(mock_embedding_fn)
  41. # Create a Qdrant instance
  42. db = QdrantDB()
  43. app_config = AppConfig(collect_metrics=False)
  44. App(config=app_config, db=db, embedding_model=embedder)
  45. resp = db.get(ids=[], where={})
  46. self.assertEqual(resp, {"ids": [], "metadatas": []})
  47. resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
  48. self.assertEqual(resp2, {"ids": [], "metadatas": []})
  49. @pytest.mark.skip(reason="Investigate the issue with the test case.")
  50. @patch("embedchain.vectordb.qdrant.QdrantClient")
  51. @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
  52. def test_add(self, uuid_mock, qdrant_client_mock):
  53. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  54. # Set the embedder
  55. embedder = BaseEmbedder()
  56. embedder.set_vector_dimension(1536)
  57. embedder.set_embedding_fn(mock_embedding_fn)
  58. # Create a Qdrant instance
  59. db = QdrantDB()
  60. app_config = AppConfig(collect_metrics=False)
  61. App(config=app_config, db=db, embedding_model=embedder)
  62. documents = ["This is a test document.", "This is another test document."]
  63. metadatas = [{}, {}]
  64. ids = ["123", "456"]
  65. db.add(documents, metadatas, ids)
  66. qdrant_client_mock.return_value.upsert.assert_called_once_with(
  67. collection_name="embedchain-store-1536",
  68. points=Batch(
  69. ids=["abc", "def"],
  70. payloads=[
  71. {
  72. "identifier": "123",
  73. "text": "This is a test document.",
  74. "metadata": {"text": "This is a test document."},
  75. },
  76. {
  77. "identifier": "456",
  78. "text": "This is another test document.",
  79. "metadata": {"text": "This is another test document."},
  80. },
  81. ],
  82. vectors=[[1, 2, 3], [4, 5, 6]],
  83. ),
  84. )
  85. @patch("embedchain.vectordb.qdrant.QdrantClient")
  86. def test_query(self, qdrant_client_mock):
  87. # Set the embedder
  88. embedder = BaseEmbedder()
  89. embedder.set_vector_dimension(1536)
  90. embedder.set_embedding_fn(mock_embedding_fn)
  91. # Create a Qdrant instance
  92. db = QdrantDB()
  93. app_config = AppConfig(collect_metrics=False)
  94. App(config=app_config, db=db, embedding_model=embedder)
  95. # Query for the document.
  96. db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
  97. qdrant_client_mock.return_value.search.assert_called_once_with(
  98. collection_name="embedchain-store-1536",
  99. query_filter=models.Filter(
  100. must=[
  101. models.FieldCondition(
  102. key="metadata.doc_id",
  103. match=models.MatchValue(
  104. value="123",
  105. ),
  106. )
  107. ]
  108. ),
  109. query_vector=[1, 2, 3],
  110. limit=1,
  111. )
  112. @patch("embedchain.vectordb.qdrant.QdrantClient")
  113. def test_count(self, qdrant_client_mock):
  114. # Set the embedder
  115. embedder = BaseEmbedder()
  116. embedder.set_vector_dimension(1536)
  117. embedder.set_embedding_fn(mock_embedding_fn)
  118. # Create a Qdrant instance
  119. db = QdrantDB()
  120. app_config = AppConfig(collect_metrics=False)
  121. App(config=app_config, db=db, embedding_model=embedder)
  122. db.count()
  123. qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1536")
  124. @patch("embedchain.vectordb.qdrant.QdrantClient")
  125. def test_reset(self, qdrant_client_mock):
  126. # Set the embedder
  127. embedder = BaseEmbedder()
  128. embedder.set_vector_dimension(1536)
  129. embedder.set_embedding_fn(mock_embedding_fn)
  130. # Create a Qdrant instance
  131. db = QdrantDB()
  132. app_config = AppConfig(collect_metrics=False)
  133. App(config=app_config, db=db, embedding_model=embedder)
  134. db.reset()
  135. qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
  136. collection_name="embedchain-store-1536"
  137. )
  138. if __name__ == "__main__":
  139. unittest.main()