test_elasticsearch_db.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import os
  2. import unittest
  3. from unittest.mock import patch
  4. from embedchain import App
  5. from embedchain.config import AppConfig, ElasticsearchDBConfig
  6. from embedchain.embedder.gpt4all import GPT4AllEmbedder
  7. from embedchain.vectordb.elasticsearch import ElasticsearchDB
  8. class TestEsDB(unittest.TestCase):
  9. @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
  10. def test_setUp(self, mock_client):
  11. self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
  12. self.vector_dim = 384
  13. app_config = AppConfig(collection_name=False, collect_metrics=False)
  14. self.app = App(config=app_config, db=self.db)
  15. # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
  16. self.assertEqual(self.db.client, mock_client.return_value)
  17. @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
  18. def test_query(self, mock_client):
  19. self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
  20. app_config = AppConfig(collection_name=False, collect_metrics=False)
  21. self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
  22. # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
  23. self.assertEqual(self.db.client, mock_client.return_value)
  24. # Create some dummy data.
  25. embeddings = [[1, 2, 3], [4, 5, 6]]
  26. documents = ["This is a document.", "This is another document."]
  27. metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
  28. ids = ["doc_1", "doc_2"]
  29. # Add the data to the database.
  30. self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
  31. search_response = {
  32. "hits": {
  33. "hits": [
  34. {
  35. "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
  36. "_score": 0.9,
  37. },
  38. {
  39. "_source": {
  40. "text": "This is another document.",
  41. "metadata": {"url": "url_2", "doc_id": "doc_id_2"},
  42. },
  43. "_score": 0.8,
  44. },
  45. ]
  46. }
  47. }
  48. # Configure the mock client to return the mocked response.
  49. mock_client.return_value.search.return_value = search_response
  50. # Query the database for the documents that are most similar to the query "This is a document".
  51. query = ["This is a document"]
  52. results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
  53. expected_results_without_citations = ["This is a document.", "This is another document."]
  54. self.assertEqual(results_without_citations, expected_results_without_citations)
  55. results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
  56. expected_results_with_citations = [
  57. ("This is a document.", "url_1", "doc_id_1"),
  58. ("This is another document.", "url_2", "doc_id_2"),
  59. ]
  60. self.assertEqual(results_with_citations, expected_results_with_citations)
  61. @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
  62. def test_query_with_skip_embedding(self, mock_client):
  63. self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
  64. app_config = AppConfig(collection_name=False, collect_metrics=False)
  65. self.app = App(config=app_config, db=self.db)
  66. # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
  67. self.assertEqual(self.db.client, mock_client.return_value)
  68. # Create some dummy data.
  69. embeddings = [[1, 2, 3], [4, 5, 6]]
  70. documents = ["This is a document.", "This is another document."]
  71. metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
  72. ids = ["doc_1", "doc_2"]
  73. # Add the data to the database.
  74. self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
  75. search_response = {
  76. "hits": {
  77. "hits": [
  78. {
  79. "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
  80. "_score": 0.9,
  81. },
  82. {
  83. "_source": {
  84. "text": "This is another document.",
  85. "metadata": {"url": "url_2", "doc_id": "doc_id_2"},
  86. },
  87. "_score": 0.8,
  88. },
  89. ]
  90. }
  91. }
  92. # Configure the mock client to return the mocked response.
  93. mock_client.return_value.search.return_value = search_response
  94. # Query the database for the documents that are most similar to the query "This is a document".
  95. query = ["This is a document"]
  96. results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
  97. # Assert that the results are correct.
  98. self.assertEqual(results, ["This is a document.", "This is another document."])
  99. def test_init_without_url(self):
  100. # Make sure it's not loaded from env
  101. try:
  102. del os.environ["ELASTICSEARCH_URL"]
  103. except KeyError:
  104. pass
  105. # Test if an exception is raised when an invalid es_config is provided
  106. with self.assertRaises(AttributeError):
  107. ElasticsearchDB()
  108. def test_init_with_invalid_es_config(self):
  109. # Test if an exception is raised when an invalid es_config is provided
  110. with self.assertRaises(TypeError):
  111. ElasticsearchDB(es_config={"ES_URL": "some_url", "valid es_config": False})