test_qdrant.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import unittest
  2. import uuid
  3. from mock import patch
  4. from qdrant_client.http import models
  5. from qdrant_client.http.models import Batch
  6. from embedchain import App
  7. from embedchain.config import AppConfig
  8. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  9. from embedchain.embedder.base import BaseEmbedder
  10. from embedchain.vectordb.qdrant import QdrantDB
  11. class TestQdrantDB(unittest.TestCase):
  12. TEST_UUIDS = ["abc", "def", "ghi"]
  13. def test_incorrect_config_throws_error(self):
  14. """Test the init method of the Qdrant class throws error for incorrect config"""
  15. with self.assertRaises(TypeError):
  16. QdrantDB(config=PineconeDBConfig())
  17. @patch("embedchain.vectordb.qdrant.QdrantClient")
  18. def test_initialize(self, qdrant_client_mock):
  19. # Set the embedder
  20. embedder = BaseEmbedder()
  21. embedder.set_vector_dimension(1526)
  22. # Create a Qdrant instance
  23. db = QdrantDB()
  24. app_config = AppConfig(collect_metrics=False)
  25. App(config=app_config, db=db, embedder=embedder)
  26. self.assertEqual(db.collection_name, "embedchain-store-1526")
  27. self.assertEqual(db.client, qdrant_client_mock.return_value)
  28. qdrant_client_mock.return_value.get_collections.assert_called_once()
  29. @patch("embedchain.vectordb.qdrant.QdrantClient")
  30. def test_get(self, qdrant_client_mock):
  31. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  32. # Set the embedder
  33. embedder = BaseEmbedder()
  34. embedder.set_vector_dimension(1526)
  35. # Create a Qdrant instance
  36. db = QdrantDB()
  37. app_config = AppConfig(collect_metrics=False)
  38. App(config=app_config, db=db, embedder=embedder)
  39. resp = db.get(ids=[], where={})
  40. self.assertEqual(resp, {"ids": []})
  41. resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
  42. self.assertEqual(resp2, {"ids": []})
  43. @patch("embedchain.vectordb.qdrant.QdrantClient")
  44. @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
  45. def test_add(self, uuid_mock, qdrant_client_mock):
  46. qdrant_client_mock.return_value.scroll.return_value = ([], None)
  47. # Set the embedder
  48. embedder = BaseEmbedder()
  49. embedder.set_vector_dimension(1526)
  50. # Create a Qdrant instance
  51. db = QdrantDB()
  52. app_config = AppConfig(collect_metrics=False)
  53. App(config=app_config, db=db, embedder=embedder)
  54. embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
  55. documents = ["This is a test document.", "This is another test document."]
  56. metadatas = [{}, {}]
  57. ids = ["123", "456"]
  58. skip_embedding = True
  59. db.add(embeddings, documents, metadatas, ids, skip_embedding)
  60. qdrant_client_mock.return_value.upsert.assert_called_once_with(
  61. collection_name="embedchain-store-1526",
  62. points=Batch(
  63. ids=["def", "ghi"],
  64. payloads=[
  65. {
  66. "identifier": "123",
  67. "text": "This is a test document.",
  68. "metadata": {"text": "This is a test document."},
  69. },
  70. {
  71. "identifier": "456",
  72. "text": "This is another test document.",
  73. "metadata": {"text": "This is another test document."},
  74. },
  75. ],
  76. vectors=embeddings,
  77. ),
  78. )
  79. @patch("embedchain.vectordb.qdrant.QdrantClient")
  80. def test_query(self, qdrant_client_mock):
  81. # Set the embedder
  82. embedder = BaseEmbedder()
  83. embedder.set_vector_dimension(1526)
  84. # Create a Qdrant instance
  85. db = QdrantDB()
  86. app_config = AppConfig(collect_metrics=False)
  87. App(config=app_config, db=db, embedder=embedder)
  88. # Query for the document.
  89. db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
  90. qdrant_client_mock.return_value.search.assert_called_once_with(
  91. collection_name="embedchain-store-1526",
  92. query_filter=models.Filter(
  93. must=[
  94. models.FieldCondition(
  95. key="payload.metadata.doc_id",
  96. match=models.MatchValue(
  97. value="123",
  98. ),
  99. )
  100. ]
  101. ),
  102. query_vector=["This is a test document."],
  103. limit=1,
  104. )
  105. @patch("embedchain.vectordb.qdrant.QdrantClient")
  106. def test_count(self, qdrant_client_mock):
  107. # Set the embedder
  108. embedder = BaseEmbedder()
  109. embedder.set_vector_dimension(1526)
  110. # Create a Qdrant instance
  111. db = QdrantDB()
  112. app_config = AppConfig(collect_metrics=False)
  113. App(config=app_config, db=db, embedder=embedder)
  114. db.count()
  115. qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
  116. @patch("embedchain.vectordb.qdrant.QdrantClient")
  117. def test_reset(self, qdrant_client_mock):
  118. # Set the embedder
  119. embedder = BaseEmbedder()
  120. embedder.set_vector_dimension(1526)
  121. # Create a Qdrant instance
  122. db = QdrantDB()
  123. app_config = AppConfig(collect_metrics=False)
  124. App(config=app_config, db=db, embedder=embedder)
  125. db.reset()
  126. qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
  127. collection_name="embedchain-store-1526"
  128. )
  129. if __name__ == "__main__":
  130. unittest.main()