# test_embedchain.py (1.0 KB, 29 lines in the original listing)

  1. import os
  2. import unittest
  3. from unittest.mock import MagicMock, patch
  4. from embedchain import App
  5. class TestApp(unittest.TestCase):
  6. os.environ["OPENAI_API_KEY"] = "test_key"
  7. def setUp(self):
  8. self.app = App()
  9. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  10. def test_add(self):
  11. self.app.add("web_page", "https://example.com")
  12. self.assertEqual(self.app.user_asks, [["web_page", "https://example.com"]])
  13. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  14. def test_query(self):
  15. with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
  16. mock_retrieve.return_value = "Test context"
  17. with patch.object(self.app, "get_llm_model_answer") as mock_answer:
  18. mock_answer.return_value = "Test answer"
  19. answer = self.app.query("Test query")
  20. self.assertEqual(answer, "Test answer")
  21. mock_retrieve.assert_called_once_with("Test query")
  22. mock_answer.assert_called_once()