|
@@ -1,16 +1,93 @@
|
|
-from typing import Optional
|
|
|
|
|
|
+from typing import Any, Dict, Optional
|
|
|
|
|
|
from embedchain.config.base_config import BaseConfig
|
|
from embedchain.config.base_config import BaseConfig
|
|
from embedchain.helpers.json_serializable import register_deserializable
|
|
from embedchain.helpers.json_serializable import register_deserializable
|
|
|
|
|
|
|
|
|
|
@register_deserializable
class CacheSimilarityEvalConfig(BaseConfig):
    """
    This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage.
    In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and has been
    put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
    `positive` is used to indicate this distance is directly proportional to the similarity of two entities.
    If `positive` is set `False`, `max_distance` will be used to subtract this distance to get the final score.

    :param strategy: name of the similarity evaluation strategy, defaults to "distance".
    :type strategy: str
    :param max_distance: the bound of maximum distance.
    :type max_distance: float
    :param positive: True if a larger distance indicates that two entities are more similar, otherwise False.
    :type positive: bool
    """

    def __init__(
        self,
        strategy: Optional[str] = "distance",
        max_distance: Optional[float] = 1.0,
        positive: Optional[bool] = False,
    ):
        self.strategy = strategy
        self.max_distance = max_distance
        self.positive = positive

    # Declared static: the function takes no `self`, so without the decorator
    # calling it on an instance would pass the instance itself as `config`.
    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]) -> "CacheSimilarityEvalConfig":
        """
        Build a config object from a plain dictionary.

        :param config: mapping that may contain the keys "strategy", "max_distance" and "positive";
            missing keys fall back to the constructor defaults. `None` yields an all-default config.
        :type config: Optional[Dict[str, Any]]
        :return: the populated configuration object.
        :rtype: CacheSimilarityEvalConfig
        """
        if config is None:
            return CacheSimilarityEvalConfig()
        return CacheSimilarityEvalConfig(
            strategy=config.get("strategy", "distance"),
            max_distance=config.get("max_distance", 1.0),
            positive=config.get("positive", False),
        )
|
|
|
|
+
|
|
|
|
+
|
|
|
|
@register_deserializable
class CacheInitConfig(BaseConfig):
    """
    This is a cache init config. Used to initialize a cache.

    :param similarity_threshold: a threshold ranging from 0 to 1 used to filter search results with a similarity
        score higher than the threshold. When it is 0, there are no hits. When it is 1, all search results will be
        returned as hits.
    :type similarity_threshold: float
    :param auto_flush: the cache is flushed automatically each time this many pieces of data have been added,
        defaults to 20.
    :type auto_flush: int
    :raises ValueError: if `similarity_threshold` is outside the range [0, 1].
    """

    def __init__(
        self,
        similarity_threshold: Optional[float] = 0.8,
        auto_flush: Optional[int] = 20,
    ):
        # Reject out-of-range thresholds up front so a bad config fails at construction time.
        if similarity_threshold < 0 or similarity_threshold > 1:
            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")

        self.similarity_threshold = similarity_threshold
        self.auto_flush = auto_flush

    # Declared static: the function takes no `self`, so without the decorator
    # calling it on an instance would pass the instance itself as `config`.
    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]) -> "CacheInitConfig":
        """
        Build a config object from a plain dictionary.

        :param config: mapping that may contain the keys "similarity_threshold" and "auto_flush";
            missing keys fall back to the constructor defaults. `None` yields an all-default config.
        :type config: Optional[Dict[str, Any]]
        :return: the populated configuration object.
        :rtype: CacheInitConfig
        :raises ValueError: if the supplied `similarity_threshold` is outside the range [0, 1].
        """
        if config is None:
            return CacheInitConfig()
        return CacheInitConfig(
            similarity_threshold=config.get("similarity_threshold", 0.8),
            auto_flush=config.get("auto_flush", 20),
        )
|
|
|
|
+
|
|
|
|
+
|
|
|
|
@register_deserializable
class CacheConfig(BaseConfig):
    """
    Top-level cache configuration that bundles the similarity evaluation settings and the cache init settings.

    :param similarity_eval_config: similarity evaluation settings; a default `CacheSimilarityEvalConfig` is
        created when omitted or `None`.
    :type similarity_eval_config: Optional[CacheSimilarityEvalConfig]
    :param init_config: cache initialization settings; a default `CacheInitConfig` is created when omitted
        or `None`.
    :type init_config: Optional[CacheInitConfig]
    """

    def __init__(
        self,
        similarity_eval_config: Optional[CacheSimilarityEvalConfig] = None,
        init_config: Optional[CacheInitConfig] = None,
    ):
        # Defaults are built per call rather than in the signature: a default
        # instance in the signature is evaluated once and would be shared
        # (and mutated) across every CacheConfig created without arguments.
        if similarity_eval_config is None:
            similarity_eval_config = CacheSimilarityEvalConfig()
        if init_config is None:
            init_config = CacheInitConfig()

        self.similarity_eval_config = similarity_eval_config
        self.init_config = init_config

    # Declared static: the function takes no `self`, so without the decorator
    # calling it on an instance would pass the instance itself as `config`.
    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]) -> "CacheConfig":
        """
        Build a config object from a plain dictionary.

        :param config: mapping whose "similarity_evaluation" and "init_config" entries (each a dictionary,
            empty when absent) are forwarded to the nested configs' `from_config`. `None` yields an
            all-default config.
        :type config: Optional[Dict[str, Any]]
        :return: the populated configuration object.
        :rtype: CacheConfig
        """
        if config is None:
            return CacheConfig()
        return CacheConfig(
            similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
            init_config=CacheInitConfig.from_config(config.get("init_config", {})),
        )
|