answer_relevancy.py

import concurrent.futures
import logging
import os
from string import Template
from typing import Optional

import numpy as np
from openai import OpenAI
from tqdm import tqdm

from embedchain.config.eval.base import AnswerRelevanceConfig
from embedchain.eval.base import BaseMetric
from embedchain.utils.eval import EvalData, EvalMetric


class AnswerRelevance(BaseMetric):
    """
    Metric for evaluating the relevance of answers.
    """

    def __init__(self, config: Optional[AnswerRelevanceConfig] = None):
        super().__init__(name=EvalMetric.ANSWER_RELEVANCY.value)
        # Avoid a mutable default argument: fall back to a fresh default config.
        self.config = config or AnswerRelevanceConfig()
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("API key not found. Set 'OPENAI_API_KEY' or pass it in the config.")
        self.client = OpenAI(api_key=api_key)

    def _generate_prompt(self, data: EvalData) -> str:
        """
        Generates a prompt based on the provided data.
        """
        return Template(self.config.prompt).substitute(
            num_gen_questions=self.config.num_gen_questions,
            answer=data.answer,
        )
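
    # Illustration of the substitution above, assuming a hypothetical template
    # (the actual default prompt ships with AnswerRelevanceConfig):
    #   Template("Generate $num_gen_questions questions answered by: $answer")
    #       .substitute(num_gen_questions=2, answer="The Eiffel Tower is in Paris.")
    #   -> "Generate 2 questions answered by: The Eiffel Tower is in Paris."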

    def _generate_questions(self, prompt: str) -> list[str]:
        """
        Generates questions from the prompt.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
        )
        # The model is expected to return one question per line; drop blank
        # lines so they are not sent to the embedding endpoint later.
        questions = response.choices[0].message.content.strip().split("\n")
        return [q for q in questions if q.strip()]

    def _generate_embedding(self, question: str) -> np.ndarray:
        """
        Generates the embedding for a question.
        """
        response = self.client.embeddings.create(
            input=question,
            model=self.config.embedder,
        )
        return np.array(response.data[0].embedding)

    def _compute_similarity(self, original: np.ndarray, generated: np.ndarray) -> np.ndarray:
        """
        Computes the cosine similarity between the original embedding and each
        generated embedding.
        """
        original = original.reshape(1, -1)
        norm = np.linalg.norm(original) * np.linalg.norm(generated, axis=1)
        return np.dot(generated, original.T).flatten() / norm
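
    # Shape sketch for the computation above (d = embedding dimension,
    # n = number of generated questions):
    #   original:  (d,)  -> reshaped to (1, d)
    #   generated: (n, d)
    #   np.dot(generated, original.T): (n, 1) -> flattened to (n,)
    # Each entry is cos(theta) = (a . b) / (|a| * |b|) for one generated question.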

    def _compute_score(self, data: EvalData) -> float:
        """
        Computes the relevance score for a given data item.
        """
        prompt = self._generate_prompt(data)
        generated_questions = self._generate_questions(prompt)
        original_embedding = self._generate_embedding(data.question)
        generated_embeddings = np.array([self._generate_embedding(q) for q in generated_questions])
        similarities = self._compute_similarity(original_embedding, generated_embeddings)
        return np.mean(similarities)
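
    # Worked example: if the generated questions score cosine similarities of
    # 0.90, 0.80, and 0.85 against the original question, the item's relevance
    # score is their mean, (0.90 + 0.80 + 0.85) / 3 = 0.85.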

    def evaluate(self, dataset: list[EvalData]) -> float:
        """
        Evaluates the dataset and returns the average answer relevance score.
        """
        results = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_data = {executor.submit(self._compute_score, data): data for data in dataset}
            for future in tqdm(
                concurrent.futures.as_completed(future_to_data),
                total=len(dataset),
                desc="Evaluating Answer Relevancy",
            ):
                data = future_to_data[future]
                try:
                    results.append(future.result())
                except Exception as e:
                    logging.error(f"Error evaluating answer relevancy for {data}: {e}")
        return np.mean(results) if results else 0.0
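

# A minimal usage sketch, not part of the original module. The EvalData and
# AnswerRelevanceConfig field names below mirror the attributes accessed above
# (question, answer, api_key); any other construction details are assumptions
# about the embedchain API.
if __name__ == "__main__":
    metric = AnswerRelevance()  # reads OPENAI_API_KEY from the environment
    sample = EvalData(
        question="Where is the Eiffel Tower?",
        answer="The Eiffel Tower is located in Paris, France.",
    )
    print(f"answer relevancy: {metric.evaluate([sample]):.3f}")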