test_qdrant.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import unittest
  2. import uuid
  3. from mock import patch
  4. from qdrant_client.http import models
  5. from qdrant_client.http.models import Batch
  6. from embedchain import App
  7. from embedchain.config import AppConfig
  8. from embedchain.config.vector_db.pinecone import PineconeDBConfig
  9. from embedchain.embedder.base import BaseEmbedder
  10. from embedchain.vectordb.qdrant import QdrantDB
  11. def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
  12. """A mock embedding function."""
  13. return [[1, 2, 3], [4, 5, 6]]
  14. class TestQdrantDB(unittest.TestCase):
  15. TEST_UUIDS = ["abc", "def", "ghi"]
  16. def test_incorrect_config_throws_error(self):
  17. """Test the init method of the Qdrant class throws error for incorrect config"""
  18. with self.assertRaises(TypeError):
  19. QdrantDB(config=PineconeDBConfig())
  20. @patch("embedchain.vectordb.qdrant.QdrantClient")
  21. def test_initialize(self, qdrant_client_mock):
  22. # Set the embedder
  23. embedder = BaseEmbedder()
  24. embedder.set_vector_dimension(1536)
  25. embedder.set_embedding_fn(mock_embedding_fn)
  26. # Create a Qdrant instance
  27. db = QdrantDB()
  28. app_config = AppConfig(collect_metrics=False)
  29. App(config=app_config, db=db, embedding_model=embedder)
  30. self.assertEqual(db.collection_name, "embedchain-store-1536")
  31. self.assertEqual(db.client, qdrant_client_mock.return_value)
  32. qdrant_client_mock.return_value.get_collections.assert_called_once()
  33. @patch("embedchain.vectordb.qdrant.QdrantClient")
  34. def test_get(self, qdrant_client_mock):
  35. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  36. # Set the embedder
  37. embedder = BaseEmbedder()
  38. embedder.set_vector_dimension(1536)
  39. embedder.set_embedding_fn(mock_embedding_fn)
  40. # Create a Qdrant instance
  41. db = QdrantDB()
  42. app_config = AppConfig(collect_metrics=False)
  43. App(config=app_config, db=db, embedding_model=embedder)
  44. resp = db.get(ids=[], where={})
  45. self.assertEqual(resp, {"ids": [], "metadatas": []})
  46. resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
  47. self.assertEqual(resp2, {"ids": [], "metadatas": []})
  48. @patch("embedchain.vectordb.qdrant.QdrantClient")
  49. @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
  50. def test_add(self, uuid_mock, qdrant_client_mock):
  51. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  52. # Set the embedder
  53. embedder = BaseEmbedder()
  54. embedder.set_vector_dimension(1536)
  55. embedder.set_embedding_fn(mock_embedding_fn)
  56. # Create a Qdrant instance
  57. db = QdrantDB()
  58. app_config = AppConfig(collect_metrics=False)
  59. App(config=app_config, db=db, embedding_model=embedder)
  60. documents = ["This is a test document.", "This is another test document."]
  61. metadatas = [{}, {}]
  62. ids = ["123", "456"]
  63. db.add(documents, metadatas, ids)
  64. qdrant_client_mock.return_value.upsert.assert_called_once_with(
  65. collection_name="embedchain-store-1536",
  66. points=Batch(
  67. ids=["123", "456"],
  68. payloads=[
  69. {
  70. "identifier": "123",
  71. "text": "This is a test document.",
  72. "metadata": {"text": "This is a test document."},
  73. },
  74. {
  75. "identifier": "456",
  76. "text": "This is another test document.",
  77. "metadata": {"text": "This is another test document."},
  78. },
  79. ],
  80. vectors=[[1, 2, 3], [4, 5, 6]],
  81. ),
  82. )
  83. @patch("embedchain.vectordb.qdrant.QdrantClient")
  84. def test_query(self, qdrant_client_mock):
  85. # Set the embedder
  86. embedder = BaseEmbedder()
  87. embedder.set_vector_dimension(1536)
  88. embedder.set_embedding_fn(mock_embedding_fn)
  89. # Create a Qdrant instance
  90. db = QdrantDB()
  91. app_config = AppConfig(collect_metrics=False)
  92. App(config=app_config, db=db, embedding_model=embedder)
  93. # Query for the document.
  94. db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
  95. qdrant_client_mock.return_value.search.assert_called_once_with(
  96. collection_name="embedchain-store-1536",
  97. query_filter=models.Filter(
  98. must=[
  99. models.FieldCondition(
  100. key="metadata.doc_id",
  101. match=models.MatchValue(
  102. value="123",
  103. ),
  104. )
  105. ]
  106. ),
  107. query_vector=[1, 2, 3],
  108. limit=1,
  109. )
  110. @patch("embedchain.vectordb.qdrant.QdrantClient")
  111. def test_count(self, qdrant_client_mock):
  112. # Set the embedder
  113. embedder = BaseEmbedder()
  114. embedder.set_vector_dimension(1536)
  115. embedder.set_embedding_fn(mock_embedding_fn)
  116. # Create a Qdrant instance
  117. db = QdrantDB()
  118. app_config = AppConfig(collect_metrics=False)
  119. App(config=app_config, db=db, embedding_model=embedder)
  120. db.count()
  121. qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1536")
  122. @patch("embedchain.vectordb.qdrant.QdrantClient")
  123. def test_reset(self, qdrant_client_mock):
  124. # Set the embedder
  125. embedder = BaseEmbedder()
  126. embedder.set_vector_dimension(1536)
  127. embedder.set_embedding_fn(mock_embedding_fn)
  128. # Create a Qdrant instance
  129. db = QdrantDB()
  130. app_config = AppConfig(collect_metrics=False)
  131. App(config=app_config, db=db, embedding_model=embedder)
  132. db.reset()
  133. qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
  134. collection_name="embedchain-store-1536"
  135. )
  136. if __name__ == "__main__":
  137. unittest.main()