test_answer_relevancy_metric.py

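"""Tests for the AnswerRelevance evaluation metric.

The OpenAI chat-completion and embedding calls are monkeypatched throughout,
so these tests run without network access or a real API key.
"""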
import numpy as np
import pytest

from embedchain.config.evaluation.base import AnswerRelevanceConfig
from embedchain.evaluation.metrics import AnswerRelevance
from embedchain.utils.evaluation import EvalData, EvalMetric


@pytest.fixture
def mock_data():
    return [
        EvalData(
            contexts=[
                "This is a test context 1.",
            ],
            question="This is a test question 1.",
            answer="This is a test answer 1.",
        ),
        EvalData(
            contexts=[
                "This is a test context 2-1.",
                "This is a test context 2-2.",
            ],
            question="This is a test question 2.",
            answer="This is a test answer 2.",
        ),
    ]


@pytest.fixture
def mock_answer_relevance_metric(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    monkeypatch.setenv("OPENAI_API_BASE", "test_api_base")
    metric = AnswerRelevance()
    return metric
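

# A minimal sketch (hypothetical helpers with assumed names, not used by the
# tests below): the repeated ``type("obj", (object,), {...})`` constructions
# that fake OpenAI responses could instead be built with ``types.SimpleNamespace``.
from types import SimpleNamespace  # local import to keep this sketch self-contained


def _fake_chat_response(content: str) -> SimpleNamespace:
    """Shape an object like a chat-completions response: .choices[0].message.content."""
    return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=content))])


def _fake_embedding_response(embedding: list) -> SimpleNamespace:
    """Shape an object like an embeddings response: .data[0].embedding."""
    return SimpleNamespace(data=[SimpleNamespace(embedding=embedding)])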


def test_answer_relevance_init(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = AnswerRelevance()
    assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
    assert metric.config.model == "gpt-4"
    assert metric.config.embedder == "text-embedding-ada-002"
    assert metric.config.api_key is None
    assert metric.config.num_gen_questions == 1
    monkeypatch.delenv("OPENAI_API_KEY")


def test_answer_relevance_init_with_config():
    metric = AnswerRelevance(config=AnswerRelevanceConfig(api_key="test_api_key"))
    assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
    assert metric.config.model == "gpt-4"
    assert metric.config.embedder == "text-embedding-ada-002"
    assert metric.config.api_key == "test_api_key"
    assert metric.config.num_gen_questions == 1


def test_answer_relevance_init_without_api_key(monkeypatch):
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    with pytest.raises(ValueError):
        AnswerRelevance()


def test_generate_prompt(mock_answer_relevance_metric, mock_data):
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
    assert "This is a test answer 1." in prompt
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
    assert "This is a test answer 2." in prompt


def test_generate_questions(mock_answer_relevance_metric, mock_data, monkeypatch):
    # Fake a chat-completions response whose single choice contains one generated question.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
    questions = mock_answer_relevance_metric._generate_questions(prompt)
    assert len(questions) == 1

    # Two newline-separated questions in the completion should yield two parsed questions.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
    questions = mock_answer_relevance_metric._generate_questions(prompt)
    assert len(questions) == 2


def test_generate_embedding(mock_answer_relevance_metric, mock_data, monkeypatch):
    # Fake an embeddings response exposing .data[0].embedding.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    embedding = mock_answer_relevance_metric._generate_embedding("This is a test question.")
    assert len(embedding) == 3


def test_compute_similarity(mock_answer_relevance_metric, mock_data):
    original = np.array([1, 2, 3])
    generated = np.array([[1, 2, 3], [1, 2, 3]])
    similarity = mock_answer_relevance_metric._compute_similarity(original, generated)
    assert len(similarity) == 2
    assert similarity[0] == 1.0
    assert similarity[1] == 1.0
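

# Why 1.0 is expected above: _compute_similarity presumably uses cosine similarity,
# cos(a, b) = (a . b) / (||a|| * ||b||), which equals 1.0 for vectors pointing in the
# same direction. A standalone reference implementation (an assumption, for illustration):
def _reference_cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))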


def test_compute_score(mock_answer_relevance_metric, mock_data, monkeypatch):
    # One generated question and identical embeddings for every call, so the
    # similarity (and hence the score) should be exactly 1.0.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric._compute_score(mock_data[0])
    assert score == 1.0

    # Same setup, but the completion now contains two generated questions.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric._compute_score(mock_data[1])
    assert score == 1.0


def test_evaluate(mock_answer_relevance_metric, mock_data, monkeypatch):
    # End-to-end evaluate() over the fixture data, with both OpenAI calls mocked;
    # identical embeddings everywhere keep the aggregate score at 1.0.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric.evaluate(mock_data)
    assert score == 1.0

    # Same check with a two-question completion.
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric.evaluate(mock_data)
    assert score == 1.0
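

# All OpenAI calls above are monkeypatched, so the suite runs offline; e.g.:
#   pytest test_answer_relevancy_metric.py -q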