test_qdrant.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import unittest
  2. import uuid
  3. from mock import patch
  4. from qdrant_client.http import models
  5. from qdrant_client.http.models import Batch
  6. from embedchain import App
  7. from embedchain.config import AppConfig
  8. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  9. from embedchain.embedder.base import BaseEmbedder
  10. from embedchain.vectordb.qdrant import QdrantDB
  11. def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
  12. """A mock embedding function."""
  13. return [[1, 2, 3], [4, 5, 6]]
  14. class TestQdrantDB(unittest.TestCase):
  15. TEST_UUIDS = ["abc", "def", "ghi"]
  16. def test_incorrect_config_throws_error(self):
  17. """Test the init method of the Qdrant class throws error for incorrect config"""
  18. with self.assertRaises(TypeError):
  19. QdrantDB(config=PineconeDBConfig())
  20. @patch("embedchain.vectordb.qdrant.QdrantClient")
  21. def test_initialize(self, qdrant_client_mock):
  22. # Set the embedder
  23. embedder = BaseEmbedder()
  24. embedder.set_vector_dimension(1536)
  25. embedder.set_embedding_fn(mock_embedding_fn)
  26. # Create a Qdrant instance
  27. db = QdrantDB()
  28. app_config = AppConfig(collect_metrics=False)
  29. App(config=app_config, db=db, embedding_model=embedder)
  30. self.assertEqual(db.collection_name, "embedchain-store-1536")
  31. self.assertEqual(db.client, qdrant_client_mock.return_value)
  32. qdrant_client_mock.return_value.get_collections.assert_called_once()
  33. @patch("embedchain.vectordb.qdrant.QdrantClient")
  34. def test_get(self, qdrant_client_mock):
  35. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  36. # Set the embedder
  37. embedder = BaseEmbedder()
  38. embedder.set_vector_dimension(1536)
  39. embedder.set_embedding_fn(mock_embedding_fn)
  40. # Create a Qdrant instance
  41. db = QdrantDB()
  42. app_config = AppConfig(collect_metrics=False)
  43. App(config=app_config, db=db, embedding_model=embedder)
  44. resp = db.get(ids=[], where={})
  45. self.assertEqual(resp, {"ids": [], "metadatas": []})
  46. resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
  47. self.assertEqual(resp2, {"ids": [], "metadatas": []})
  48. @patch("embedchain.vectordb.qdrant.QdrantClient")
  49. @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
  50. def test_add(self, uuid_mock, qdrant_client_mock):
  51. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  52. # Set the embedder
  53. embedder = BaseEmbedder()
  54. embedder.set_vector_dimension(1536)
  55. embedder.set_embedding_fn(mock_embedding_fn)
  56. # Create a Qdrant instance
  57. db = QdrantDB()
  58. app_config = AppConfig(collect_metrics=False)
  59. App(config=app_config, db=db, embedding_model=embedder)
  60. documents = ["This is a test document.", "This is another test document."]
  61. metadatas = [{}, {}]
  62. ids = ["123", "456"]
  63. db.add(documents, metadatas, ids)
  64. qdrant_client_mock.return_value.upsert.assert_called_once_with(
  65. collection_name="embedchain-store-1536",
  66. points=Batch(
  67. ids=["123", "456"],
  68. payloads=[
  69. {
  70. "identifier": "123",
  71. "text": "This is a test document.",
  72. "metadata": {"text": "This is a test document."},
  73. },
  74. {
  75. "identifier": "456",
  76. "text": "This is another test document.",
  77. "metadata": {"text": "This is another test document."},
  78. },
  79. ],
  80. vectors=[[1, 2, 3], [4, 5, 6]],
  81. ),
  82. )
  83. @patch("embedchain.vectordb.qdrant.QdrantClient")
  84. def test_query(self, qdrant_client_mock):
  85. # Set the embedder
  86. embedder = BaseEmbedder()
  87. embedder.set_vector_dimension(1536)
  88. embedder.set_embedding_fn(mock_embedding_fn)
  89. # Create a Qdrant instance
  90. db = QdrantDB()
  91. app_config = AppConfig(collect_metrics=False)
  92. App(config=app_config, db=db, embedding_model=embedder)
  93. # Query for the document.
  94. db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
  95. qdrant_client_mock.return_value.search.assert_called_once_with(
  96. collection_name="embedchain-store-1536",
  97. query_filter=models.Filter(
  98. must=[
  99. models.FieldCondition(
  100. key="metadata.doc_id",
  101. match=models.MatchValue(
  102. value="123",
  103. ),
  104. )
  105. ]
  106. ),
  107. query_vector=[1, 2, 3],
  108. limit=1,
  109. )
  110. @patch("embedchain.vectordb.qdrant.QdrantClient")
  111. def test_count(self, qdrant_client_mock):
  112. # Set the embedder
  113. embedder = BaseEmbedder()
  114. embedder.set_vector_dimension(1536)
  115. embedder.set_embedding_fn(mock_embedding_fn)
  116. # Create a Qdrant instance
  117. db = QdrantDB()
  118. app_config = AppConfig(collect_metrics=False)
  119. App(config=app_config, db=db, embedding_model=embedder)
  120. db.count()
  121. qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1536")
  122. @patch("embedchain.vectordb.qdrant.QdrantClient")
  123. def test_reset(self, qdrant_client_mock):
  124. # Set the embedder
  125. embedder = BaseEmbedder()
  126. embedder.set_vector_dimension(1536)
  127. embedder.set_embedding_fn(mock_embedding_fn)
  128. # Create a Qdrant instance
  129. db = QdrantDB()
  130. app_config = AppConfig(collect_metrics=False)
  131. App(config=app_config, db=db, embedding_model=embedder)
  132. db.reset()
  133. qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
  134. collection_name="embedchain-store-1536"
  135. )
  136. if __name__ == "__main__":
  137. unittest.main()