context_relevancy.py

import concurrent.futures
import os
from string import Template
from typing import Optional

import numpy as np
import pysbd
from openai import OpenAI
from tqdm import tqdm

from embedchain.config.eval.base import ContextRelevanceConfig
from embedchain.eval.base import BaseMetric
from embedchain.utils.eval import EvalData, EvalMetric


class ContextRelevance(BaseMetric):
    """
    Metric for evaluating the relevance of context in a dataset.
    """

    def __init__(self, config: Optional[ContextRelevanceConfig] = ContextRelevanceConfig()):
        super().__init__(name=EvalMetric.CONTEXT_RELEVANCY.value)
        self.config = config
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("API key not found. Set 'OPENAI_API_KEY' or pass it in the config.")
        self.client = OpenAI(api_key=api_key)
        self._sbd = pysbd.Segmenter(language=self.config.language, clean=False)

    def _sentence_segmenter(self, text: str) -> list[str]:
        """
        Segments the given text into sentences.
        """
        return self._sbd.segment(text)

    def _compute_score(self, data: EvalData) -> float:
        """
        Computes the context relevance score for a given data item.
        """
        original_context = "\n".join(data.contexts)
        prompt = Template(self.config.prompt).substitute(context=original_context, question=data.question)
        response = self.client.chat.completions.create(
            model=self.config.model, messages=[{"role": "user", "content": prompt}]
        )
        useful_context = response.choices[0].message.content.strip()
        useful_context_sentences = self._sentence_segmenter(useful_context)
        original_context_sentences = self._sentence_segmenter(original_context)
        if not original_context_sentences:
            return 0.0
        return len(useful_context_sentences) / len(original_context_sentences)

    def evaluate(self, dataset: list[EvalData]) -> float:
        """
        Evaluates the dataset and returns the average context relevance score.
        """
        scores = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self._compute_score, data) for data in dataset]
            for future in tqdm(
                concurrent.futures.as_completed(futures), total=len(dataset), desc="Evaluating Context Relevancy"
            ):
                try:
                    scores.append(future.result())
                except Exception as e:
                    print(f"Error during evaluation: {e}")
        return np.mean(scores) if scores else 0.0
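
A minimal usage sketch follows. It is illustrative only: it assumes EvalData accepts question, contexts, and answer fields (as suggested by the import above), that OPENAI_API_KEY is set in the environment, and that ContextRelevance is in scope as defined in this module; the sample data is hypothetical.

# Usage sketch (assumptions noted above; field names may differ across embedchain versions).
from embedchain.utils.eval import EvalData

dataset = [
    EvalData(
        question="When was the Eiffel Tower completed?",
        contexts=[
            "The Eiffel Tower was completed in 1889.",
            "Paris is the capital of France.",
        ],
        answer="It was completed in 1889.",
    ),
]

metric = ContextRelevance()       # reads OPENAI_API_KEY from the environment unless a config supplies api_key
score = metric.evaluate(dataset)  # mean, over items, of (relevant context sentences / total context sentences)
print(f"context_relevancy: {score:.2f}")

Each item's score is the fraction of the original context sentences that the model returns as useful for answering the question, and evaluate() scores items concurrently in a thread pool before averaging.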