소스 검색

[Tests] add tests for evaluation metrics (#1174)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 년 전
부모
커밋
2784bae772

+ 1 - 1
embedchain/evaluation/metrics/groundedness.py

@@ -21,7 +21,7 @@ class Groundedness(BaseMetric):
     def __init__(self, config: Optional[GroundednessConfig] = None):
         super().__init__(name=EvalMetric.GROUNDEDNESS.value)
         self.config = config or GroundednessConfig()
-        api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
+        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
         if not api_key:
             raise ValueError("Please set the OPENAI_API_KEY environment variable or pass the `api_key` in config.")
         self.client = OpenAI(api_key=api_key)

+ 223 - 0
tests/evaluation/test_answer_relevancy_metric.py

@@ -0,0 +1,223 @@
+import numpy as np
+import pytest
+
+from embedchain.config.evaluation.base import AnswerRelevanceConfig
+from embedchain.evaluation.metrics import AnswerRelevance
+from embedchain.utils.evaluation import EvalData, EvalMetric
+
+
+@pytest.fixture
+def mock_data():
+    return [
+        EvalData(
+            contexts=[
+                "This is a test context 1.",
+            ],
+            question="This is a test question 1.",
+            answer="This is a test answer 1.",
+        ),
+        EvalData(
+            contexts=[
+                "This is a test context 2-1.",
+                "This is a test context 2-2.",
+            ],
+            question="This is a test question 2.",
+            answer="This is a test answer 2.",
+        ),
+    ]
+
+
+@pytest.fixture
+def mock_answer_relevance_metric(monkeypatch):
+    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+    metric = AnswerRelevance()
+    return metric
+
+
+def test_answer_relevance_init(monkeypatch):
+    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+    metric = AnswerRelevance()
+    assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
+    assert metric.config.model == "gpt-4"
+    assert metric.config.embedder == "text-embedding-ada-002"
+    assert metric.config.api_key is None
+    assert metric.config.num_gen_questions == 1
+    monkeypatch.delenv("OPENAI_API_KEY")
+
+
+def test_answer_relevance_init_with_config():
+    metric = AnswerRelevance(config=AnswerRelevanceConfig(api_key="test_api_key"))
+    assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
+    assert metric.config.model == "gpt-4"
+    assert metric.config.embedder == "text-embedding-ada-002"
+    assert metric.config.api_key == "test_api_key"
+    assert metric.config.num_gen_questions == 1
+
+
+def test_answer_relevance_init_without_api_key(monkeypatch):
+    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+    with pytest.raises(ValueError):
+        AnswerRelevance()
+
+
+def test_generate_prompt(mock_answer_relevance_metric, mock_data):
+    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
+    assert "This is a test answer 1." in prompt
+
+    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
+    assert "This is a test answer 2." in prompt
+
+
+def test_generate_questions(mock_answer_relevance_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.chat.completions,
+        "create",
+        lambda model, messages: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type(
+                        "obj",
+                        (object,),
+                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
+                    )
+                ]
+            },
+        )(),
+    )
+    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
+    questions = mock_answer_relevance_metric._generate_questions(prompt)
+    assert len(questions) == 1
+
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.chat.completions,
+        "create",
+        lambda model, messages: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
+                ]
+            },
+        )(),
+    )
+    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
+    questions = mock_answer_relevance_metric._generate_questions(prompt)
+    assert len(questions) == 2
+
+
+def test_generate_embedding(mock_answer_relevance_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.embeddings,
+        "create",
+        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
+    )
+    embedding = mock_answer_relevance_metric._generate_embedding("This is a test question.")
+    assert len(embedding) == 3
+
+
+def test_compute_similarity(mock_answer_relevance_metric, mock_data):
+    original = np.array([1, 2, 3])
+    generated = np.array([[1, 2, 3], [1, 2, 3]])
+    similarity = mock_answer_relevance_metric._compute_similarity(original, generated)
+    assert len(similarity) == 2
+    assert similarity[0] == 1.0
+    assert similarity[1] == 1.0
+
+
+def test_compute_score(mock_answer_relevance_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.chat.completions,
+        "create",
+        lambda model, messages: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type(
+                        "obj",
+                        (object,),
+                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
+                    )
+                ]
+            },
+        )(),
+    )
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.embeddings,
+        "create",
+        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
+    )
+    score = mock_answer_relevance_metric._compute_score(mock_data[0])
+    assert score == 1.0
+
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.chat.completions,
+        "create",
+        lambda model, messages: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
+                ]
+            },
+        )(),
+    )
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.embeddings,
+        "create",
+        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
+    )
+    score = mock_answer_relevance_metric._compute_score(mock_data[1])
+    assert score == 1.0
+
+
+def test_evaluate(mock_answer_relevance_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.chat.completions,
+        "create",
+        lambda model, messages: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type(
+                        "obj",
+                        (object,),
+                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
+                    )
+                ]
+            },
+        )(),
+    )
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.embeddings,
+        "create",
+        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
+    )
+    score = mock_answer_relevance_metric.evaluate(mock_data)
+    assert score == 1.0
+
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.chat.completions,
+        "create",
+        lambda model, messages: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
+                ]
+            },
+        )(),
+    )
+    monkeypatch.setattr(
+        mock_answer_relevance_metric.client.embeddings,
+        "create",
+        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
+    )
+    score = mock_answer_relevance_metric.evaluate(mock_data)
+    assert score == 1.0

+ 100 - 0
tests/evaluation/test_context_relevancy_metric.py

@@ -0,0 +1,100 @@
+import pytest
+
+from embedchain.config.evaluation.base import ContextRelevanceConfig
+from embedchain.evaluation.metrics import ContextRelevance
+from embedchain.utils.evaluation import EvalData, EvalMetric
+
+
+@pytest.fixture
+def mock_data():
+    return [
+        EvalData(
+            contexts=[
+                "This is a test context 1.",
+            ],
+            question="This is a test question 1.",
+            answer="This is a test answer 1.",
+        ),
+        EvalData(
+            contexts=[
+                "This is a test context 2-1.",
+                "This is a test context 2-2.",
+            ],
+            question="This is a test question 2.",
+            answer="This is a test answer 2.",
+        ),
+    ]
+
+
+@pytest.fixture
+def mock_context_relevance_metric(monkeypatch):
+    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+    metric = ContextRelevance()
+    return metric
+
+
+def test_context_relevance_init(monkeypatch):
+    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+    metric = ContextRelevance()
+    assert metric.name == EvalMetric.CONTEXT_RELEVANCY.value
+    assert metric.config.model == "gpt-4"
+    assert metric.config.api_key is None
+    assert metric.config.language == "en"
+    monkeypatch.delenv("OPENAI_API_KEY")
+
+
+def test_context_relevance_init_with_config():
+    metric = ContextRelevance(config=ContextRelevanceConfig(api_key="test_api_key"))
+    assert metric.name == EvalMetric.CONTEXT_RELEVANCY.value
+    assert metric.config.model == "gpt-4"
+    assert metric.config.api_key == "test_api_key"
+    assert metric.config.language == "en"
+
+
+def test_context_relevance_init_without_api_key(monkeypatch):
+    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+    with pytest.raises(ValueError):
+        ContextRelevance()
+
+
+def test_sentence_segmenter(mock_context_relevance_metric):
+    text = "This is a test sentence. This is another sentence."
+    assert mock_context_relevance_metric._sentence_segmenter(text) == [
+        "This is a test sentence. ",
+        "This is another sentence.",
+    ]
+
+
+def test_compute_score(mock_context_relevance_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_context_relevance_metric.client.chat.completions,
+        "create",
+        lambda model, messages: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type("obj", (object,), {"message": type("obj", (object,), {"content": "This is a test reponse."})})
+                ]
+            },
+        )(),
+    )
+    assert mock_context_relevance_metric._compute_score(mock_data[0]) == 1.0
+    assert mock_context_relevance_metric._compute_score(mock_data[1]) == 0.5
+
+
+def test_evaluate(mock_context_relevance_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_context_relevance_metric.client.chat.completions,
+        "create",
+        lambda model, messages: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type("obj", (object,), {"message": type("obj", (object,), {"content": "This is a test reponse."})})
+                ]
+            },
+        )(),
+    )
+    assert mock_context_relevance_metric.evaluate(mock_data) == 0.75

+ 152 - 0
tests/evaluation/test_groundedness_metric.py

@@ -0,0 +1,152 @@
+import numpy as np
+import pytest
+
+from embedchain.config.evaluation.base import GroundednessConfig
+from embedchain.evaluation.metrics import Groundedness
+from embedchain.utils.evaluation import EvalData, EvalMetric
+
+
+@pytest.fixture
+def mock_data():
+    return [
+        EvalData(
+            contexts=[
+                "This is a test context 1.",
+            ],
+            question="This is a test question 1.",
+            answer="This is a test answer 1.",
+        ),
+        EvalData(
+            contexts=[
+                "This is a test context 2-1.",
+                "This is a test context 2-2.",
+            ],
+            question="This is a test question 2.",
+            answer="This is a test answer 2.",
+        ),
+    ]
+
+
+@pytest.fixture
+def mock_groundedness_metric(monkeypatch):
+    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+    metric = Groundedness()
+    return metric
+
+
+def test_groundedness_init(monkeypatch):
+    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+    metric = Groundedness()
+    assert metric.name == EvalMetric.GROUNDEDNESS.value
+    assert metric.config.model == "gpt-4"
+    assert metric.config.api_key is None
+    monkeypatch.delenv("OPENAI_API_KEY")
+
+
+def test_groundedness_init_with_config():
+    metric = Groundedness(config=GroundednessConfig(api_key="test_api_key"))
+    assert metric.name == EvalMetric.GROUNDEDNESS.value
+    assert metric.config.model == "gpt-4"
+    assert metric.config.api_key == "test_api_key"
+
+
+def test_groundedness_init_without_api_key(monkeypatch):
+    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+    with pytest.raises(ValueError):
+        Groundedness()
+
+
+def test_generate_answer_claim_prompt(mock_groundedness_metric, mock_data):
+    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
+    assert "This is a test question 1." in prompt
+    assert "This is a test answer 1." in prompt
+
+
+def test_get_claim_statements(mock_groundedness_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_groundedness_metric.client.chat.completions,
+        "create",
+        lambda *args, **kwargs: type(
+            "obj",
+            (object,),
+            {
+                "choices": [
+                    type(
+                        "obj",
+                        (object,),
+                        {
+                            "message": type(
+                                "obj",
+                                (object,),
+                                {
+                                    "content": """This is a test answer 1.
+                                                                                        This is a test answer 2.
+                                                                                        This is a test answer 3."""
+                                },
+                            )
+                        },
+                    )
+                ]
+            },
+        )(),
+    )
+    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
+    claim_statements = mock_groundedness_metric._get_claim_statements(prompt=prompt)
+    assert len(claim_statements) == 3
+    assert "This is a test answer 1." in claim_statements
+
+
+def test_generate_claim_inference_prompt(mock_groundedness_metric, mock_data):
+    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
+    claim_statements = [
+        "This is a test claim 1.",
+        "This is a test claim 2.",
+    ]
+    prompt = mock_groundedness_metric._generate_claim_inference_prompt(
+        data=mock_data[0], claim_statements=claim_statements
+    )
+    assert "This is a test context 1." in prompt
+    assert "This is a test claim 1." in prompt
+
+
+def test_get_claim_verdict_scores(mock_groundedness_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_groundedness_metric.client.chat.completions,
+        "create",
+        lambda *args, **kwargs: type(
+            "obj",
+            (object,),
+            {"choices": [type("obj", (object,), {"message": type("obj", (object,), {"content": "1\n0\n-1"})})]},
+        )(),
+    )
+    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
+    claim_statements = mock_groundedness_metric._get_claim_statements(prompt=prompt)
+    prompt = mock_groundedness_metric._generate_claim_inference_prompt(
+        data=mock_data[0], claim_statements=claim_statements
+    )
+    claim_verdict_scores = mock_groundedness_metric._get_claim_verdict_scores(prompt=prompt)
+    assert len(claim_verdict_scores) == 3
+    assert claim_verdict_scores[0] == 1
+    assert claim_verdict_scores[1] == 0
+
+
+def test_compute_score(mock_groundedness_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(
+        mock_groundedness_metric,
+        "_get_claim_statements",
+        lambda *args, **kwargs: np.array(
+            [
+                "This is a test claim 1.",
+                "This is a test claim 2.",
+            ]
+        ),
+    )
+    monkeypatch.setattr(mock_groundedness_metric, "_get_claim_verdict_scores", lambda *args, **kwargs: np.array([1, 0]))
+    score = mock_groundedness_metric._compute_score(data=mock_data[0])
+    assert score == 0.5
+
+
+def test_evaluate(mock_groundedness_metric, mock_data, monkeypatch):
+    monkeypatch.setattr(mock_groundedness_metric, "_compute_score", lambda *args, **kwargs: 0.5)
+    score = mock_groundedness_metric.evaluate(dataset=mock_data)
+    assert score == 0.5