# cache_config.py (3.7 KB)
from typing import Any, Optional

from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
  4. @register_deserializable
  5. class CacheSimilarityEvalConfig(BaseConfig):
  6. """
  7. This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage.
  8. In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and have been
  9. put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
  10. `positive` is used to indicate this distance is directly proportional to the similarity of two entities.
  11. If `positive` is set `False`, `max_distance` will be used to subtract this distance to get the final score.
  12. :param max_distance: the bound of maximum distance.
  13. :type max_distance: float
  14. :param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise, it is False.
  15. :type positive: bool
  16. """
  17. def __init__(
  18. self,
  19. strategy: Optional[str] = "distance",
  20. max_distance: Optional[float] = 1.0,
  21. positive: Optional[bool] = False,
  22. ):
  23. self.strategy = strategy
  24. self.max_distance = max_distance
  25. self.positive = positive
  26. @staticmethod
  27. def from_config(config: Optional[dict[str, Any]]):
  28. if config is None:
  29. return CacheSimilarityEvalConfig()
  30. else:
  31. return CacheSimilarityEvalConfig(
  32. strategy=config.get("strategy", "distance"),
  33. max_distance=config.get("max_distance", 1.0),
  34. positive=config.get("positive", False),
  35. )
  36. @register_deserializable
  37. class CacheInitConfig(BaseConfig):
  38. """
  39. This is a cache init config. Used to initialize a cache.
  40. :param similarity_threshold: a threshold ranged from 0 to 1 to filter search results with similarity score higher \
  41. than the threshold. When it is 0, there is no hits. When it is 1, all search results will be returned as hits.
  42. :type similarity_threshold: float
  43. :param auto_flush: it will be automatically flushed every time xx pieces of data are added, default to 20
  44. :type auto_flush: int
  45. """
  46. def __init__(
  47. self,
  48. similarity_threshold: Optional[float] = 0.8,
  49. auto_flush: Optional[int] = 20,
  50. ):
  51. if similarity_threshold < 0 or similarity_threshold > 1:
  52. raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
  53. self.similarity_threshold = similarity_threshold
  54. self.auto_flush = auto_flush
  55. @staticmethod
  56. def from_config(config: Optional[dict[str, Any]]):
  57. if config is None:
  58. return CacheInitConfig()
  59. else:
  60. return CacheInitConfig(
  61. similarity_threshold=config.get("similarity_threshold", 0.8),
  62. auto_flush=config.get("auto_flush", 20),
  63. )
  64. @register_deserializable
  65. class CacheConfig(BaseConfig):
  66. def __init__(
  67. self,
  68. similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(),
  69. init_config: Optional[CacheInitConfig] = CacheInitConfig(),
  70. ):
  71. self.similarity_eval_config = similarity_eval_config
  72. self.init_config = init_config
  73. @staticmethod
  74. def from_config(config: Optional[dict[str, Any]]):
  75. if config is None:
  76. return CacheConfig()
  77. else:
  78. return CacheConfig(
  79. similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
  80. init_config=CacheInitConfig.from_config(config.get("init_config", {})),
  81. )