# test_context_relevancy_metric.py
  1. import pytest
  2. from embedchain.config.evaluation.base import ContextRelevanceConfig
  3. from embedchain.evaluation.metrics import ContextRelevance
  4. from embedchain.utils.evaluation import EvalData, EvalMetric
  5. @pytest.fixture
  6. def mock_data():
  7. return [
  8. EvalData(
  9. contexts=[
  10. "This is a test context 1.",
  11. ],
  12. question="This is a test question 1.",
  13. answer="This is a test answer 1.",
  14. ),
  15. EvalData(
  16. contexts=[
  17. "This is a test context 2-1.",
  18. "This is a test context 2-2.",
  19. ],
  20. question="This is a test question 2.",
  21. answer="This is a test answer 2.",
  22. ),
  23. ]
  24. @pytest.fixture
  25. def mock_context_relevance_metric(monkeypatch):
  26. monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
  27. metric = ContextRelevance()
  28. return metric
  29. def test_context_relevance_init(monkeypatch):
  30. monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
  31. metric = ContextRelevance()
  32. assert metric.name == EvalMetric.CONTEXT_RELEVANCY.value
  33. assert metric.config.model == "gpt-4"
  34. assert metric.config.api_key is None
  35. assert metric.config.language == "en"
  36. monkeypatch.delenv("OPENAI_API_KEY")
  37. def test_context_relevance_init_with_config():
  38. metric = ContextRelevance(config=ContextRelevanceConfig(api_key="test_api_key"))
  39. assert metric.name == EvalMetric.CONTEXT_RELEVANCY.value
  40. assert metric.config.model == "gpt-4"
  41. assert metric.config.api_key == "test_api_key"
  42. assert metric.config.language == "en"
  43. def test_context_relevance_init_without_api_key(monkeypatch):
  44. monkeypatch.delenv("OPENAI_API_KEY", raising=False)
  45. with pytest.raises(ValueError):
  46. ContextRelevance()
  47. def test_sentence_segmenter(mock_context_relevance_metric):
  48. text = "This is a test sentence. This is another sentence."
  49. assert mock_context_relevance_metric._sentence_segmenter(text) == [
  50. "This is a test sentence. ",
  51. "This is another sentence.",
  52. ]
  53. def test_compute_score(mock_context_relevance_metric, mock_data, monkeypatch):
  54. monkeypatch.setattr(
  55. mock_context_relevance_metric.client.chat.completions,
  56. "create",
  57. lambda model, messages: type(
  58. "obj",
  59. (object,),
  60. {
  61. "choices": [
  62. type("obj", (object,), {"message": type("obj", (object,), {"content": "This is a test reponse."})})
  63. ]
  64. },
  65. )(),
  66. )
  67. assert mock_context_relevance_metric._compute_score(mock_data[0]) == 1.0
  68. assert mock_context_relevance_metric._compute_score(mock_data[1]) == 0.5
  69. def test_evaluate(mock_context_relevance_metric, mock_data, monkeypatch):
  70. monkeypatch.setattr(
  71. mock_context_relevance_metric.client.chat.completions,
  72. "create",
  73. lambda model, messages: type(
  74. "obj",
  75. (object,),
  76. {
  77. "choices": [
  78. type("obj", (object,), {"message": type("obj", (object,), {"content": "This is a test reponse."})})
  79. ]
  80. },
  81. )(),
  82. )
  83. assert mock_context_relevance_metric.evaluate(mock_data) == 0.75