# test_base.py: unit tests for embedchain.bots.base.BaseBot.
import os
from unittest.mock import patch

import pytest

from embedchain.bots.base import BaseBot
from embedchain.config import AddConfig, BaseLlmConfig
  6. @pytest.fixture
  7. def base_bot():
  8. os.environ["OPENAI_API_KEY"] = "test_api_key" # needed by App
  9. return BaseBot()
  10. def test_add(base_bot):
  11. data = "Test data"
  12. config = AddConfig()
  13. with patch.object(base_bot.app, "add") as mock_add:
  14. base_bot.add(data, config)
  15. mock_add.assert_called_with(data, config=config)
  16. def test_query(base_bot):
  17. query = "Test query"
  18. config = BaseLlmConfig()
  19. with patch.object(base_bot.app, "query") as mock_query:
  20. mock_query.return_value = "Query result"
  21. result = base_bot.query(query, config)
  22. assert isinstance(result, str)
  23. assert result == "Query result"
  24. def test_start():
  25. class TestBot(BaseBot):
  26. def start(self):
  27. return "Bot started"
  28. bot = TestBot()
  29. result = bot.start()
  30. assert result == "Bot started"
  31. def test_start_not_implemented():
  32. bot = BaseBot()
  33. with pytest.raises(NotImplementedError):
  34. bot.start()