test_poe.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import argparse
  2. import pytest
  3. from fastapi_poe.types import ProtocolMessage, QueryRequest
  4. from embedchain.bots.poe import PoeBot, start_command
  5. @pytest.fixture
  6. def poe_bot(mocker):
  7. bot = PoeBot()
  8. mocker.patch("fastapi_poe.run")
  9. return bot
  10. @pytest.mark.asyncio
  11. async def test_poe_bot_get_response(poe_bot, mocker):
  12. query = QueryRequest(
  13. version="test",
  14. type="query",
  15. query=[ProtocolMessage(role="system", content="Test content")],
  16. user_id="test_user_id",
  17. conversation_id="test_conversation_id",
  18. message_id="test_message_id",
  19. )
  20. mocker.patch.object(poe_bot.app.llm, "set_history")
  21. response_generator = poe_bot.get_response(query)
  22. await response_generator.__anext__()
  23. poe_bot.app.llm.set_history.assert_called_once()
  24. def test_poe_bot_handle_message(poe_bot, mocker):
  25. mocker.patch.object(poe_bot, "ask_bot", return_value="Answer from the bot")
  26. response_ask = poe_bot.handle_message("What is the answer?")
  27. assert response_ask == "Answer from the bot"
  28. # TODO: This test will fail because the add_data method is commented out.
  29. # mocker.patch.object(poe_bot, 'add_data', return_value="Added data from: some_data")
  30. # response_add = poe_bot.handle_message("/add some_data")
  31. # assert response_add == "Added data from: some_data"
  32. def test_start_command(mocker):
  33. mocker.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(api_key="test_api_key"))
  34. mocker.patch("embedchain.bots.poe.run")
  35. start_command()