123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- from typing import Any, Optional
- from embedchain.config.base_config import BaseConfig
- from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class CacheSimilarityEvalConfig(BaseConfig):
    """
    Settings for the evaluator that scores two embeddings by the distance found during retrieval.

    During the embedding-retrieval stage, the approximate-nearest-neighbour search produces a
    distance (`search_result`) that is stored in `cache_dict`. `max_distance` bounds that distance
    to the range [0, `max_distance`]. When `positive` is True, a larger distance means the two
    entities are more similar; when it is False, the distance is subtracted from `max_distance`
    to obtain the final score.

    :param strategy: identifier of the evaluation strategy, defaults to "distance"
        (presumably used by the consumer to pick the evaluator — confirm there)
    :type strategy: str
    :param max_distance: upper bound of the retrieval distance, defaults to 1.0
    :type max_distance: float
    :param positive: True if a larger distance indicates more similar entities, False otherwise;
        defaults to False
    :type positive: bool
    """

    def __init__(
        self,
        strategy: Optional[str] = "distance",
        max_distance: Optional[float] = 1.0,
        positive: Optional[bool] = False,
    ):
        self.strategy, self.max_distance, self.positive = strategy, max_distance, positive

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        """
        Build an evaluator config from a plain mapping.

        :param config: mapping that may contain "strategy", "max_distance" and "positive";
            absent keys fall back to the constructor defaults. ``None`` yields an all-default config.
        :return: a new CacheSimilarityEvalConfig
        """
        if config is None:
            return CacheSimilarityEvalConfig()
        settings = {
            "strategy": config.get("strategy", "distance"),
            "max_distance": config.get("max_distance", 1.0),
            "positive": config.get("positive", False),
        }
        return CacheSimilarityEvalConfig(**settings)
@register_deserializable
class CacheInitConfig(BaseConfig):
    """
    Configuration used to initialize a cache.

    :param similarity_threshold: a value in [0, 1] against which search results' similarity
        scores are compared to decide which results are cache hits; defaults to 0.8, and ``None``
        is treated as that default.
        NOTE(review): an earlier wording both said results scoring *higher* than the threshold are
        kept and that 0 gives no hits while 1 makes every result a hit — those two statements
        contradict each other. The code that applies the threshold is not visible here; confirm
        the direction there.
    :type similarity_threshold: float
    :param auto_flush: the cache is flushed automatically every time this many pieces of data
        are added, defaults to 20
    :type auto_flush: int
    :raises ValueError: if ``similarity_threshold`` is not within [0, 1] (NaN is rejected as well)
    """

    def __init__(
        self,
        similarity_threshold: Optional[float] = 0.8,
        auto_flush: Optional[int] = 20,
    ):
        if similarity_threshold is None:
            # The parameter is annotated Optional; comparing None with a number would raise an
            # opaque TypeError, so fall back to the default threshold instead.
            similarity_threshold = 0.8
        # The chained comparison under `not` also rejects NaN: with `x < 0 or x > 1` a NaN
        # would slip through, because every ordered comparison involving NaN is False.
        if not 0 <= similarity_threshold <= 1:
            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
        self.similarity_threshold = similarity_threshold
        self.auto_flush = auto_flush

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        """
        Build a cache-init config from a plain mapping.

        :param config: mapping that may contain "similarity_threshold" and "auto_flush"; absent
            keys fall back to the defaults. ``None`` yields an all-default config.
        :return: a new CacheInitConfig
        :raises ValueError: propagated from ``__init__`` for an out-of-range threshold
        """
        if config is None:
            return CacheInitConfig()
        return CacheInitConfig(
            similarity_threshold=config.get("similarity_threshold", 0.8),
            auto_flush=config.get("auto_flush", 20),
        )
@register_deserializable
class CacheConfig(BaseConfig):
    """
    Top-level cache configuration bundling similarity-evaluation and cache-init settings.

    :param similarity_eval_config: settings for the similarity evaluator; a fresh
        ``CacheSimilarityEvalConfig()`` is created when omitted or ``None``
    :type similarity_eval_config: CacheSimilarityEvalConfig
    :param init_config: settings used to initialize the cache; a fresh ``CacheInitConfig()`` is
        created when omitted or ``None``
    :type init_config: CacheInitConfig
    """

    def __init__(
        self,
        similarity_eval_config: Optional[CacheSimilarityEvalConfig] = None,
        init_config: Optional[CacheInitConfig] = None,
    ):
        # Build defaults per instance. Default-argument objects are evaluated once, when the
        # function is defined, so using `CacheInitConfig()` etc. as defaults would make every
        # default CacheConfig share — and mutate — the very same sub-config objects.
        self.similarity_eval_config = (
            similarity_eval_config if similarity_eval_config is not None else CacheSimilarityEvalConfig()
        )
        self.init_config = init_config if init_config is not None else CacheInitConfig()

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        """
        Build a CacheConfig from a plain mapping.

        :param config: mapping whose "similarity_evaluation" and "init_config" sections are passed
            to ``CacheSimilarityEvalConfig.from_config`` and ``CacheInitConfig.from_config``
            respectively; an absent section yields a default sub-config. ``None`` yields an
            all-default config.
        :return: a new CacheConfig
        """
        if config is None:
            return CacheConfig()
        return CacheConfig(
            similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
            init_config=CacheInitConfig.from_config(config.get("init_config", {})),
        )
|