groundedness.py

import concurrent.futures
import logging
import os
from string import Template
from typing import Optional

import numpy as np
from openai import OpenAI
from tqdm import tqdm

from embedchain.config.evaluation.base import GroundednessConfig
from embedchain.evaluation.base import BaseMetric
from embedchain.utils.evaluation import EvalData, EvalMetric

logger = logging.getLogger(__name__)


class Groundedness(BaseMetric):
    """
    Metric for groundedness of answer from the given contexts.
    """

    def __init__(self, config: Optional[GroundednessConfig] = None):
        super().__init__(name=EvalMetric.GROUNDEDNESS.value)
        self.config = config or GroundednessConfig()
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("Please set the OPENAI_API_KEY environment variable or pass the `api_key` in config.")
        self.client = OpenAI(api_key=api_key)

    def _generate_answer_claim_prompt(self, data: EvalData) -> str:
        """
        Generate the prompt for the given data.
        """
        prompt = Template(self.config.answer_claims_prompt).substitute(question=data.question, answer=data.answer)
        return prompt

    def _get_claim_statements(self, prompt: str) -> np.ndarray:
        """
        Get claim statements from the answer.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": f"{prompt}"}],
        )
        result = response.choices[0].message.content.strip()
        # The model is expected to return one claim per line; drop empty lines.
        claim_statements = np.array([statement for statement in result.split("\n") if statement])
        return claim_statements

    def _generate_claim_inference_prompt(self, data: EvalData, claim_statements: list[str]) -> str:
        """
        Generate the claim inference prompt for the given data and claim statements.
        """
        prompt = Template(self.config.claims_inference_prompt).substitute(
            context="\n".join(data.contexts), claim_statements="\n".join(claim_statements)
        )
        return prompt

    def _get_claim_verdict_scores(self, prompt: str) -> np.ndarray:
        """
        Get verdicts for claim statements.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": f"{prompt}"}],
        )
        result = response.choices[0].message.content.strip()
        claim_verdicts = result.split("\n")
        # 1 = claim supported by the contexts, 0 = unsupported, -1 = inconclusive
        # (mapped to NaN, which propagates through the sum in `_compute_score`).
        verdict_score_map = {"1": 1, "0": 0, "-1": np.nan}
        verdict_scores = np.array([verdict_score_map[verdict] for verdict in claim_verdicts])
        return verdict_scores

    def _compute_score(self, data: EvalData) -> float:
        """
        Compute the groundedness score for a single data point.
        """
        answer_claims_prompt = self._generate_answer_claim_prompt(data)
        claim_statements = self._get_claim_statements(answer_claims_prompt)
        claim_inference_prompt = self._generate_claim_inference_prompt(data, claim_statements)
        verdict_scores = self._get_claim_verdict_scores(claim_inference_prompt)
        # Fraction of extracted claims that the contexts support.
        return np.sum(verdict_scores) / claim_statements.size

    def evaluate(self, dataset: list[EvalData]):
        """
        Evaluate the dataset and return the average groundedness score.
        """
        results = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_data = {executor.submit(self._compute_score, data): data for data in dataset}
            for future in tqdm(
                concurrent.futures.as_completed(future_to_data),
                total=len(future_to_data),
                desc="Evaluating Groundedness",
            ):
                data = future_to_data[future]
                try:
                    score = future.result()
                    results.append(score)
                except Exception as e:
                    logger.error(f"Error while evaluating groundedness for data point {data}: {e}")

        return np.mean(results) if results else 0.0
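
A minimal usage sketch, not part of groundedness.py: it builds a one-item dataset and runs the metric end to end. The `EvalData` fields (`question`, `contexts`, `answer`) are taken from the attribute accesses in the class above; the import path for `Groundedness` is an assumption and may differ in your embedchain version.

# Usage sketch -- assumes OPENAI_API_KEY is set in the environment and
# that Groundedness is importable from this module's package (assumed path).
from embedchain.evaluation.metrics import Groundedness  # assumed import path
from embedchain.utils.evaluation import EvalData

dataset = [
    EvalData(
        question="What is the boiling point of water?",
        contexts=["Water boils at 100 degrees Celsius at sea level."],
        answer="Water boils at 100 degrees Celsius at standard atmospheric pressure.",
    )
]

metric = Groundedness()  # default GroundednessConfig; API key read from the environment
average_score = metric.evaluate(dataset)
print(f"Average groundedness: {average_score:.2f}")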