test_groundedness_metric.py

import numpy as np
import pytest

from embedchain.config.evaluation.base import GroundednessConfig
from embedchain.evaluation.metrics import Groundedness
from embedchain.utils.evaluation import EvalData, EvalMetric
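
# These tests exercise the Groundedness metric without touching the OpenAI API:
# the client is given a fake key via the environment, and network calls are
# replaced with monkeypatched stand-ins that return canned responses.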

@pytest.fixture
def mock_data():
    return [
        EvalData(
            contexts=[
                "This is a test context 1.",
            ],
            question="This is a test question 1.",
            answer="This is a test answer 1.",
        ),
        EvalData(
            contexts=[
                "This is a test context 2-1.",
                "This is a test context 2-2.",
            ],
            question="This is a test question 2.",
            answer="This is a test answer 2.",
        ),
    ]


@pytest.fixture
def mock_groundedness_metric(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = Groundedness()
    return metric
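
# Construction defaults: the metric should report its EvalMetric name, use the
# "gpt-4" model, and fall back to OPENAI_API_KEY from the environment when no
# explicit api_key is supplied via GroundednessConfig.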

def test_groundedness_init(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = Groundedness()
    assert metric.name == EvalMetric.GROUNDEDNESS.value
    assert metric.config.model == "gpt-4"
    assert metric.config.api_key is None
    monkeypatch.delenv("OPENAI_API_KEY")


def test_groundedness_init_with_config():
    metric = Groundedness(config=GroundednessConfig(api_key="test_api_key"))
    assert metric.name == EvalMetric.GROUNDEDNESS.value
    assert metric.config.model == "gpt-4"
    assert metric.config.api_key == "test_api_key"


def test_groundedness_init_without_api_key(monkeypatch):
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    with pytest.raises(ValueError):
        Groundedness()


def test_generate_answer_claim_prompt(mock_groundedness_metric, mock_data):
    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
    assert "This is a test question 1." in prompt
    assert "This is a test answer 1." in prompt
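
# The nested type() calls below build a minimal stand-in for the OpenAI chat
# completion response, so that response.choices[0].message.content yields three
# newline-separated claim statements without any real API call.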

def test_get_claim_statements(mock_groundedness_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_groundedness_metric.client.chat.completions,
        "create",
        lambda *args, **kwargs: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {
                            "message": type(
                                "obj",
                                (object,),
                                {
                                    "content": """This is a test answer 1.
This is a test answer 2.
This is a test answer 3."""
                                },
                            )
                        },
                    )
                ]
            },
        )(),
    )
    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
    claim_statements = mock_groundedness_metric._get_claim_statements(prompt=prompt)
    assert len(claim_statements) == 3
    assert "This is a test answer 1." in claim_statements


def test_generate_claim_inference_prompt(mock_groundedness_metric, mock_data):
    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
    claim_statements = [
        "This is a test claim 1.",
        "This is a test claim 2.",
    ]
    prompt = mock_groundedness_metric._generate_claim_inference_prompt(
        data=mock_data[0], claim_statements=claim_statements
    )
    assert "This is a test context 1." in prompt
    assert "This is a test claim 1." in prompt
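
# The stubbed completion returns one verdict per line ("1", "0", "-1"); the
# exact meaning of each value is an implementation detail of Groundedness, but
# the parser is expected to return them as numbers, in order.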

def test_get_claim_verdict_scores(mock_groundedness_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_groundedness_metric.client.chat.completions,
        "create",
        lambda *args, **kwargs: type(
            "obj",
            (object,),
            {"choices": [type("obj", (object,), {"message": type("obj", (object,), {"content": "1\n0\n-1"})})]},
        )(),
    )
    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
    claim_statements = mock_groundedness_metric._get_claim_statements(prompt=prompt)
    prompt = mock_groundedness_metric._generate_claim_inference_prompt(
        data=mock_data[0], claim_statements=claim_statements
    )
    claim_verdict_scores = mock_groundedness_metric._get_claim_verdict_scores(prompt=prompt)
    assert len(claim_verdict_scores) == 3
    assert claim_verdict_scores[0] == 1
    assert claim_verdict_scores[1] == 0
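
# With two claims and verdict scores [1, 0], the expected score of 0.5 is
# consistent with the metric averaging the per-claim verdicts (an assumption
# about _compute_score's internals; the test only pins the final number).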

def test_compute_score(mock_groundedness_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_groundedness_metric,
        "_get_claim_statements",
        lambda *args, **kwargs: np.array(
            [
                "This is a test claim 1.",
                "This is a test claim 2.",
            ]
        ),
    )
    monkeypatch.setattr(
        mock_groundedness_metric, "_get_claim_verdict_scores", lambda *args, **kwargs: np.array([1, 0])
    )
    score = mock_groundedness_metric._compute_score(data=mock_data[0])
    assert score == 0.5
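
# evaluate() aggregates per-example scores across the dataset; with
# _compute_score stubbed to return 0.5 for every example, the aggregate is
# expected to come out as 0.5 (consistent with a simple mean).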

def test_evaluate(mock_groundedness_metric, mock_data, monkeypatch):
    monkeypatch.setattr(mock_groundedness_metric, "_compute_score", lambda *args, **kwargs: 0.5)
    score = mock_groundedness_metric.evaluate(dataset=mock_data)
    assert score == 0.5