groundedness.py

import concurrent.futures
import logging
import os
from string import Template
from typing import Optional

import numpy as np
from openai import OpenAI
from tqdm import tqdm

from embedchain.config.eval.base import GroundednessConfig
from embedchain.eval.base import BaseMetric
from embedchain.utils.eval import EvalData, EvalMetric


class Groundedness(BaseMetric):
    """
    Metric for groundedness (aka faithfulness) of an answer with respect to the given contexts.
    """

    def __init__(self, config: Optional[GroundednessConfig] = None):
        super().__init__(name=EvalMetric.GROUNDEDNESS.value)
        self.config = config or GroundednessConfig()
        # Use .get() so a missing environment variable falls through to the
        # explicit error below instead of raising a bare KeyError.
        api_key = self.config.api_key or os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("Please set the OPENAI_API_KEY environment variable or pass the `api_key` in config.")
        self.client = OpenAI(api_key=api_key)
    def _generate_answer_claim_prompt(self, data: EvalData) -> str:
        """
        Generate the claim-extraction prompt for the given data.
        """
        prompt = Template(self.config.answer_claims_prompt).substitute(question=data.question, answer=data.answer)
        return prompt
    def _get_claim_statements(self, prompt: str) -> np.ndarray:
        """
        Extract claim statements from the answer, one claim per line of model output.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
        )
        result = response.choices[0].message.content.strip()
        # The model is expected to return one claim per line; drop blank lines.
        claim_statements = np.array([statement for statement in result.split("\n") if statement])
        return claim_statements
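
    # Example (assumed) model output for the claim-extraction prompt, as parsed above:
    #
    #   "The Eiffel Tower is in Paris.\nIt was completed in 1889."
    #
    # -> array(["The Eiffel Tower is in Paris.", "It was completed in 1889."])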
    def _generate_claim_inference_prompt(self, data: EvalData, claim_statements: np.ndarray) -> str:
        """
        Generate the claim-inference prompt for the given data and claim statements.
        """
        prompt = Template(self.config.claims_inference_prompt).substitute(
            context="\n".join(data.contexts), claim_statements="\n".join(claim_statements)
        )
        return prompt
    def _get_claim_verdict_scores(self, prompt: str) -> np.ndarray:
        """
        Get a verdict score for each claim statement.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
        )
        result = response.choices[0].message.content.strip()
        claim_verdicts = [verdict.strip() for verdict in result.split("\n") if verdict.strip()]
        # Expected one verdict token per line: "1", "0", or "-1"; "-1" maps to NaN.
        verdict_score_map = {"1": 1, "0": 0, "-1": np.nan}
        # Unknown tokens are also treated as NaN rather than raising a KeyError.
        verdict_scores = np.array([verdict_score_map.get(verdict, np.nan) for verdict in claim_verdicts])
        return verdict_scores
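
    # Example (assumed) model output for three claims, as parsed above:
    #
    #   "1\n0\n-1"
    #
    # -> array([1., 0., nan])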
    def _compute_score(self, data: EvalData) -> float:
        """
        Compute the groundedness (aka faithfulness) score for a single data point.
        """
        answer_claims_prompt = self._generate_answer_claim_prompt(data)
        claim_statements = self._get_claim_statements(answer_claims_prompt)
        if claim_statements.size == 0:
            # No claims could be extracted; score 0.0 rather than dividing by zero.
            return 0.0
        claim_inference_prompt = self._generate_claim_inference_prompt(data, claim_statements)
        verdict_scores = self._get_claim_verdict_scores(claim_inference_prompt)
        # nansum keeps NaN verdicts from propagating into the score; they count
        # as 0 against the total number of claims.
        return np.nansum(verdict_scores) / claim_statements.size
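
    # Worked example: verdict scores [1, 0, nan] over 3 claims give
    # nansum = 1, so the groundedness score is 1 / 3 ≈ 0.33.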
    def evaluate(self, dataset: list[EvalData]) -> float:
        """
        Evaluate the dataset and return the average groundedness score.
        """
        results = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_data = {executor.submit(self._compute_score, data): data for data in dataset}
            for future in tqdm(
                concurrent.futures.as_completed(future_to_data),
                total=len(future_to_data),
                desc="Evaluating groundedness (aka faithfulness)",
            ):
                data = future_to_data[future]
                try:
                    score = future.result()
                    results.append(score)
                except Exception as e:
                    logging.error(f"Error while evaluating groundedness for data point {data}: {e}")
        return np.mean(results) if results else 0.0
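

# --- Usage sketch (illustrative, not part of the library API; assumes EvalData
# accepts `question`, `answer`, and `contexts` as keyword arguments, matching
# the attributes accessed above) ---
if __name__ == "__main__":
    sample = EvalData(
        question="Where is the Eiffel Tower?",
        answer="The Eiffel Tower is in Paris. It was completed in 1889.",
        contexts=["The Eiffel Tower is a wrought-iron tower in Paris, France, finished in 1889."],
    )
    metric = Groundedness()  # needs OPENAI_API_KEY in the environment or `api_key` in config
    print(metric.evaluate([sample]))  # mean groundedness score across the dataset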