# test_elasticsearch_db.py (3.8 KB)
  1. import os
  2. import unittest
  3. from unittest.mock import patch
  4. from embedchain import App
  5. from embedchain.config import AppConfig, ElasticsearchDBConfig
  6. from embedchain.embedder.gpt4all import GPT4AllEmbedder
  7. from embedchain.vectordb.elasticsearch import ElasticsearchDB
  class TestEsDB(unittest.TestCase):
      """Unit tests for ``ElasticsearchDB``.

      The tests that build a database patch
      ``embedchain.vectordb.elasticsearch.Elasticsearch`` so no real
      Elasticsearch connection is attempted.
      """

      @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
      def test_setUp(self, mock_client):
          """The DB exposes the object returned by the patched client constructor."""
          self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
          self.vector_dim = 384
          app_config = AppConfig(collect_metrics=False)
          self.app = App(config=app_config, db=self.db)

          # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
          self.assertEqual(self.db.client, mock_client.return_value)

      @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
      def test_query(self, mock_client):
          """``query`` returns the hit texts, plus metadata and score when citations are requested."""
          self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
          app_config = AppConfig(collect_metrics=False)
          self.app = App(config=app_config, db=self.db, embedding_model=GPT4AllEmbedder())

          # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
          self.assertEqual(self.db.client, mock_client.return_value)

          # Create some dummy data.
          documents = ["This is a document.", "This is another document."]
          metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
          ids = ["doc_1", "doc_2"]

          # Add the data to the database.
          self.db.add(documents, metadatas, ids)

          search_response = {
              "hits": {
                  "hits": [
                      {
                          "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
                          "_score": 0.9,
                      },
                      {
                          "_source": {
                              "text": "This is another document.",
                              "metadata": {"url": "url_2", "doc_id": "doc_id_2"},
                          },
                          "_score": 0.8,
                      },
                  ]
              }
          }

          # Configure the mock client to return the mocked response.
          mock_client.return_value.search.return_value = search_response

          # Query the database for the documents that are most similar to the query "This is a document".
          query = "This is a document"
          results_without_citations = self.db.query(query, n_results=2, where={})
          expected_results_without_citations = ["This is a document.", "This is another document."]
          self.assertEqual(results_without_citations, expected_results_without_citations)

          # With citations=True each result is expected as a (text, metadata) pair,
          # where the metadata also carries the hit's score.
          results_with_citations = self.db.query(query, n_results=2, where={}, citations=True)
          expected_results_with_citations = [
              ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
              ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
          ]
          self.assertEqual(results_with_citations, expected_results_with_citations)

      def test_init_without_url(self):
          """Constructing the DB with no configuration and no env URL raises ``AttributeError``."""
          # Make sure the URL is not loaded from the environment. patch.dict
          # snapshots os.environ and restores it when the block exits, so the
          # removal cannot leak into other tests running in the same process.
          with patch.dict(os.environ):
              os.environ.pop("ELASTICSEARCH_URL", None)

              # Test if an exception is raised when no Elasticsearch URL is available.
              with self.assertRaises(AttributeError):
                  ElasticsearchDB()

      def test_init_with_invalid_es_config(self):
          """Passing an unsupported ``es_config`` keyword raises ``TypeError``."""
          # Test if an exception is raised when an invalid es_config is provided
          with self.assertRaises(TypeError):
              ElasticsearchDB(es_config={"ES_URL": "some_url", "valid es_config": False})