Sfoglia il codice sorgente

[Updates] Update GPTCache configuration/docs (#1098)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 anno fa
parent
commit
295cd3fac6

+ 41 - 3
docs/api-reference/advanced/configuration.mdx

@@ -56,6 +56,14 @@ chunker:
   chunk_overlap: 100
   length_function: 'len'
   min_chunk_size: 0
+
+cache:
+  similarity_evaluation:
+    strategy: distance
+    max_distance: 1.0
+  config:
+    similarity_threshold: 0.8
+    auto_flush: 50
 ```
 
 ```json config.json
@@ -98,7 +106,17 @@ chunker:
     "chunk_overlap": 100,
     "length_function": "len",
     "min_chunk_size": 0
-  }
+  },
+  "cache": {
+    "similarity_evaluation": {
+        "strategy": "distance",
+        "max_distance": 1.0,
+    },
+    "config": {
+        "similarity_threshold": 0.8,
+        "auto_flush": 50,
+    },
+  },
 }
 ```
 
@@ -148,7 +166,17 @@ config = {
         'chunk_overlap': 100,
         'length_function': 'len',
         'min_chunk_size': 0
-    }
+    },
+    'cache': {
+      'similarity_evaluation': {
+          'strategy': 'distance',
+          'max_distance': 1.0,
+      },
+      'config': {
+          'similarity_threshold': 0.8,
+          'auto_flush': 50,
+      },
+    },
 }
 ```
 </CodeGroup>
@@ -192,7 +220,17 @@ Alright, let's dive into what each key means in the yaml config above:
     - `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
     - `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also use any function import directly as a string here.
     - `min_chunk_size` (Integer): The minimum size of each chunk of text that is sent to the language model. Must be less than `chunk_size`, and greater than `chunk_overlap`.
-
+6. `cache` Section: (Optional)
+    - `similarity_evaluation` (Optional): The config for similarity evaluation strategy. If not provided, the default `distance` based similarity evaluation strategy is used.
+      - `strategy` (String): The strategy to use for similarity evaluation. Currently, only `distance` and `exact` based similarity evaluation is supported. Defaults to `distance`.
+      - `max_distance` (Float): The bound of maximum distance. Defaults to `1.0`.
+      - `positive` (Boolean): If the larger distance indicates more similar of two entities, set it `True`, otherwise `False`. Defaults to `False`.
+    - `config` (Optional): The config for initializing the cache. If not provided, sensible default values are used as mentioned below.
+      - `similarity_threshold` (Float): The threshold for similarity evaluation. Defaults to `0.8`.
+      - `auto_flush` (Integer): The number of queries after which the cache is flushed. Defaults to `20`.
+    <Note>
+    If you provide a cache section, the app will automatically configure and use a cache to store the results of the language model. This is useful if you want to speed up the response time and save inference cost of your app.
+    </Note>
 If you have questions about the configuration above, please feel free to reach out to us using one of the following methods:
 
 <Snippet file="get-help.mdx" />

+ 13 - 4
embedchain/app.py

@@ -9,7 +9,8 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
 
-from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+from embedchain.cache import (Config, ExactMatchEvaluation,
+                              SearchDistanceEvaluation, cache,
                               gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
 from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
@@ -156,12 +157,20 @@ class App(EmbedChain):
         self.db.set_collection_name(self.db.config.collection_name)
 
     def _init_cache(self):
+        if self.cache_config.similarity_eval_config.strategy == "exact":
+            similarity_eval_func = ExactMatchEvaluation()
+        else:
+            similarity_eval_func = SearchDistanceEvaluation(
+                max_distance=self.cache_config.similarity_eval_config.max_distance,
+                positive=self.cache_config.similarity_eval_config.positive,
+            )
+
         cache.init(
             pre_embedding_func=gptcache_pre_function,
             embedding_func=self.embedding_model.to_embeddings,
             data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
-            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
-            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+            similarity_evaluation=similarity_eval_func,
+            config=Config(**self.cache_config.init_config.as_dict()),
         )
 
     def _init_client(self):
@@ -428,7 +437,7 @@ class App(EmbedChain):
         )
 
         if cache_config_data is not None:
-            cache_config = CacheConfig(**cache_config_data)
+            cache_config = CacheConfig.from_config(cache_config_data)
         else:
             cache_config = None
 

+ 2 - 0
embedchain/cache.py

@@ -11,6 +11,8 @@ from gptcache.manager.scalar_data.base import DataType as CacheDataType
 from gptcache.session import Session
 from gptcache.similarity_evaluation.distance import \
     SearchDistanceEvaluation  # noqa: F401
+from gptcache.similarity_evaluation.exact_match import \
+    ExactMatchEvaluation  # noqa: F401
 
 
 def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):

+ 80 - 3
embedchain/config/cache_config.py

@@ -1,16 +1,93 @@
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable
-class CacheConfig(BaseConfig):
+class CacheSimilarityEvalConfig(BaseConfig):
+    """
+    This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage.
+    In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and have been
+    put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
+    `positive` is used to indicate this distance is directly proportional to the similarity of two entites.
+    If `positive` is set `False`, `max_distance` will be used to substract this distance to get the final score.
+
+    :param max_distance: the bound of maximum distance.
+    :type max_distance: float
+    :param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise it is False.
+    :type positive: bool
+    """
+
     def __init__(
         self,
-        similarity_threshold: Optional[float] = 0.5,
+        strategy: Optional[str] = "distance",
+        max_distance: Optional[float] = 1.0,
+        positive: Optional[bool] = False,
+    ):
+        self.strategy = strategy
+        self.max_distance = max_distance
+        self.positive = positive
+
+    def from_config(config: Optional[Dict[str, Any]]):
+        if config is None:
+            return CacheSimilarityEvalConfig()
+        else:
+            return CacheSimilarityEvalConfig(
+                strategy=config.get("strategy", "distance"),
+                max_distance=config.get("max_distance", 1.0),
+                positive=config.get("positive", False),
+            )
+
+
+@register_deserializable
+class CacheInitConfig(BaseConfig):
+    """
+    This is a cache init config. Used to initialize a cache.
+
+    :param similarity_threshold: a threshold ranged from 0 to 1 to filter search results with similarity score higher \
+     than the threshold. When it is 0, there is no hits. When it is 1, all search results will be returned as hits.
+    :type similarity_threshold: float
+    :param auto_flush: it will be automatically flushed every time xx pieces of data are added, default to 20
+    :type auto_flush: int
+    """
+
+    def __init__(
+        self,
+        similarity_threshold: Optional[float] = 0.8,
+        auto_flush: Optional[int] = 20,
     ):
         if similarity_threshold < 0 or similarity_threshold > 1:
             raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
 
         self.similarity_threshold = similarity_threshold
+        self.auto_flush = auto_flush
+
+    def from_config(config: Optional[Dict[str, Any]]):
+        if config is None:
+            return CacheInitConfig()
+        else:
+            return CacheInitConfig(
+                similarity_threshold=config.get("similarity_threshold", 0.8),
+                auto_flush=config.get("auto_flush", 20),
+            )
+
+
+@register_deserializable
+class CacheConfig(BaseConfig):
+    def __init__(
+        self,
+        similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(),
+        init_config: Optional[CacheInitConfig] = CacheInitConfig(),
+    ):
+        self.similarity_eval_config = similarity_eval_config
+        self.init_config = init_config
+
+    def from_config(config: Optional[Dict[str, Any]]):
+        if config is None:
+            return CacheConfig()
+        else:
+            return CacheConfig(
+                similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
+                init_config=CacheInitConfig.from_config(config.get("init_config", {})),
+            )

+ 1 - 0
embedchain/llm/openai.py

@@ -42,6 +42,7 @@ class OpenAILlm(BaseLlm):
             chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
             chat = ChatOpenAI(**kwargs, api_key=api_key)
+
         if self.functions is not None:
             from langchain.chains.openai_functions import \
                 create_openai_fn_runnable

+ 9 - 1
embedchain/utils.py

@@ -441,7 +441,15 @@ def validate_config(config_data):
                 Optional("min_chunk_size"): int,
             },
             Optional("cache"): {
-                Optional("similarity_threshold"): float,
+                Optional("similarity_evaluation"): {
+                    Optional("strategy"): Or("distance", "exact"),
+                    Optional("max_distance"): float,
+                    Optional("positive"): bool,
+                },
+                Optional("config"): {
+                    Optional("similarity_threshold"): float,
+                    Optional("auto_flush"): int,
+                },
             },
         }
     )