test_answer_relevancy_metric.py

import numpy as np
import pytest

from embedchain.config.evaluation.base import AnswerRelevanceConfig
from embedchain.evaluation.metrics import AnswerRelevance
from embedchain.utils.evaluation import EvalData, EvalMetric
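

# Two sample EvalData items: one with a single context and one with two contexts.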
@pytest.fixture
def mock_data():
    return [
        EvalData(
            contexts=[
                "This is a test context 1.",
            ],
            question="This is a test question 1.",
            answer="This is a test answer 1.",
        ),
        EvalData(
            contexts=[
                "This is a test context 2-1.",
                "This is a test context 2-2.",
            ],
            question="This is a test question 2.",
            answer="This is a test answer 2.",
        ),
    ]
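

# Metric instance built with a dummy API key so no real OpenAI credentials are needed.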
@pytest.fixture
def mock_answer_relevance_metric(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = AnswerRelevance()
    return metric
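

# Construction: default config, explicit config, and the missing-API-key failure mode.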
def test_answer_relevance_init(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = AnswerRelevance()
    assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
    assert metric.config.model == "gpt-4"
    assert metric.config.embedder == "text-embedding-ada-002"
    assert metric.config.api_key is None
    assert metric.config.num_gen_questions == 1
    monkeypatch.delenv("OPENAI_API_KEY")


def test_answer_relevance_init_with_config():
    metric = AnswerRelevance(config=AnswerRelevanceConfig(api_key="test_api_key"))
    assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
    assert metric.config.model == "gpt-4"
    assert metric.config.embedder == "text-embedding-ada-002"
    assert metric.config.api_key == "test_api_key"
    assert metric.config.num_gen_questions == 1


def test_answer_relevance_init_without_api_key(monkeypatch):
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    with pytest.raises(ValueError):
        AnswerRelevance()
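

# The generated prompt must contain the answer under evaluation.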
def test_generate_prompt(mock_answer_relevance_metric, mock_data):
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
    assert "This is a test answer 1." in prompt
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
    assert "This is a test answer 2." in prompt
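

# The tests below monkeypatch the OpenAI client on the metric so no network calls
# are made; question generation is expected to yield one question per response line.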
def test_generate_questions(mock_answer_relevance_metric, mock_data, monkeypatch):
    # Stub the chat completion with a single-line response -> one generated question.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
    questions = mock_answer_relevance_metric._generate_questions(prompt)
    assert len(questions) == 1

    # Stub a two-line response -> two generated questions.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
    questions = mock_answer_relevance_metric._generate_questions(prompt)
    assert len(questions) == 2
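

# The embedding helper should return the raw vector from the embeddings response.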
def test_generate_embedding(mock_answer_relevance_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    embedding = mock_answer_relevance_metric._generate_embedding("This is a test question.")
    assert len(embedding) == 3
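

# Identical vectors should yield a similarity of 1.0 for every generated question.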
def test_compute_similarity(mock_answer_relevance_metric, mock_data):
    original = np.array([1, 2, 3])
    generated = np.array([[1, 2, 3], [1, 2, 3]])
    similarity = mock_answer_relevance_metric._compute_similarity(original, generated)
    assert len(similarity) == 2
    assert similarity[0] == 1.0
    assert similarity[1] == 1.0
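

# With identical stub embeddings for the original answer and the generated
# questions, each item's score comes out as a perfect 1.0.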
def test_compute_score(mock_answer_relevance_metric, mock_data, monkeypatch):
    # Single generated question, identical embeddings -> score of 1.0.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric._compute_score(mock_data[0])
    assert score == 1.0

    # Two generated questions, identical embeddings -> score is still 1.0.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric._compute_score(mock_data[1])
    assert score == 1.0
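

# End to end: evaluate() should report a perfect score when every stubbed
# embedding matches the original answer's embedding.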
def test_evaluate(mock_answer_relevance_metric, mock_data, monkeypatch):
    # Single-question responses across the dataset -> overall score of 1.0.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric.evaluate(mock_data)
    assert score == 1.0

    # Multi-question responses -> overall score is still 1.0.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric.evaluate(mock_data)
    assert score == 1.0