# test_qdrant.py (5.9 KB)
  1. import unittest
  2. import uuid
  3. from mock import patch
  4. from qdrant_client.http import models
  5. from qdrant_client.http.models import Batch
  6. from embedchain import App
  7. from embedchain.config import AppConfig
  8. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  9. from embedchain.embedder.base import BaseEmbedder
  10. from embedchain.vectordb.qdrant import QdrantDB
  11. def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
  12. """A mock embedding function."""
  13. return [[1, 2, 3], [4, 5, 6]]
  14. class TestQdrantDB(unittest.TestCase):
  15. TEST_UUIDS = ["abc", "def", "ghi"]
  16. def test_incorrect_config_throws_error(self):
  17. """Test the init method of the Qdrant class throws error for incorrect config"""
  18. with self.assertRaises(TypeError):
  19. QdrantDB(config=PineconeDBConfig())
  20. @patch("embedchain.vectordb.qdrant.QdrantClient")
  21. def test_initialize(self, qdrant_client_mock):
  22. # Set the embedder
  23. embedder = BaseEmbedder()
  24. embedder.set_vector_dimension(1526)
  25. embedder.set_embedding_fn(mock_embedding_fn)
  26. # Create a Qdrant instance
  27. db = QdrantDB()
  28. app_config = AppConfig(collect_metrics=False)
  29. App(config=app_config, db=db, embedding_model=embedder)
  30. self.assertEqual(db.collection_name, "embedchain-store-1526")
  31. self.assertEqual(db.client, qdrant_client_mock.return_value)
  32. qdrant_client_mock.return_value.get_collections.assert_called_once()
  33. @patch("embedchain.vectordb.qdrant.QdrantClient")
  34. def test_get(self, qdrant_client_mock):
  35. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  36. # Set the embedder
  37. embedder = BaseEmbedder()
  38. embedder.set_vector_dimension(1526)
  39. embedder.set_embedding_fn(mock_embedding_fn)
  40. # Create a Qdrant instance
  41. db = QdrantDB()
  42. app_config = AppConfig(collect_metrics=False)
  43. App(config=app_config, db=db, embedding_model=embedder)
  44. resp = db.get(ids=[], where={})
  45. self.assertEqual(resp, {"ids": []})
  46. resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
  47. self.assertEqual(resp2, {"ids": []})
  48. @patch("embedchain.vectordb.qdrant.QdrantClient")
  49. @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
  50. def test_add(self, uuid_mock, qdrant_client_mock):
  51. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  52. # Set the embedder
  53. embedder = BaseEmbedder()
  54. embedder.set_vector_dimension(1526)
  55. embedder.set_embedding_fn(mock_embedding_fn)
  56. # Create a Qdrant instance
  57. db = QdrantDB()
  58. app_config = AppConfig(collect_metrics=False)
  59. App(config=app_config, db=db, embedding_model=embedder)
  60. embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
  61. documents = ["This is a test document.", "This is another test document."]
  62. metadatas = [{}, {}]
  63. ids = ["123", "456"]
  64. db.add(embeddings, documents, metadatas, ids)
  65. qdrant_client_mock.return_value.upsert.assert_called_once_with(
  66. collection_name="embedchain-store-1526",
  67. points=Batch(
  68. ids=["def", "ghi"],
  69. payloads=[
  70. {
  71. "identifier": "123",
  72. "text": "This is a test document.",
  73. "metadata": {"text": "This is a test document."},
  74. },
  75. {
  76. "identifier": "456",
  77. "text": "This is another test document.",
  78. "metadata": {"text": "This is another test document."},
  79. },
  80. ],
  81. vectors=embeddings,
  82. ),
  83. )
  84. @patch("embedchain.vectordb.qdrant.QdrantClient")
  85. def test_query(self, qdrant_client_mock):
  86. # Set the embedder
  87. embedder = BaseEmbedder()
  88. embedder.set_vector_dimension(1526)
  89. embedder.set_embedding_fn(mock_embedding_fn)
  90. # Create a Qdrant instance
  91. db = QdrantDB()
  92. app_config = AppConfig(collect_metrics=False)
  93. App(config=app_config, db=db, embedding_model=embedder)
  94. # Query for the document.
  95. db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
  96. qdrant_client_mock.return_value.search.assert_called_once_with(
  97. collection_name="embedchain-store-1526",
  98. query_filter=models.Filter(
  99. must=[
  100. models.FieldCondition(
  101. key="payload.metadata.doc_id",
  102. match=models.MatchValue(
  103. value="123",
  104. ),
  105. )
  106. ]
  107. ),
  108. query_vector=[1, 2, 3],
  109. limit=1,
  110. )
  111. @patch("embedchain.vectordb.qdrant.QdrantClient")
  112. def test_count(self, qdrant_client_mock):
  113. # Set the embedder
  114. embedder = BaseEmbedder()
  115. embedder.set_vector_dimension(1526)
  116. embedder.set_embedding_fn(mock_embedding_fn)
  117. # Create a Qdrant instance
  118. db = QdrantDB()
  119. app_config = AppConfig(collect_metrics=False)
  120. App(config=app_config, db=db, embedding_model=embedder)
  121. db.count()
  122. qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
  123. @patch("embedchain.vectordb.qdrant.QdrantClient")
  124. def test_reset(self, qdrant_client_mock):
  125. # Set the embedder
  126. embedder = BaseEmbedder()
  127. embedder.set_vector_dimension(1526)
  128. embedder.set_embedding_fn(mock_embedding_fn)
  129. # Create a Qdrant instance
  130. db = QdrantDB()
  131. app_config = AppConfig(collect_metrics=False)
  132. App(config=app_config, db=db, embedding_model=embedder)
  133. db.reset()
  134. qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
  135. collection_name="embedchain-store-1526"
  136. )
  137. if __name__ == "__main__":
  138. unittest.main()