|
@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
|
|
|
import requests
|
|
|
import yaml
|
|
|
|
|
|
+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
|
|
|
+ gptcache_data_manager, gptcache_pre_function)
|
|
|
from embedchain.client import Client
|
|
|
-from embedchain.config import AppConfig, ChunkerConfig
|
|
|
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
|
|
|
from embedchain.constants import SQLITE_PATH
|
|
|
from embedchain.embedchain import EmbedChain
|
|
|
from embedchain.embedder.base import BaseEmbedder
|
|
@@ -48,6 +50,7 @@ class App(EmbedChain):
|
|
|
log_level=logging.WARN,
|
|
|
auto_deploy: bool = False,
|
|
|
chunker: ChunkerConfig = None,
|
|
|
+ cache_config: CacheConfig = None,
|
|
|
):
|
|
|
"""
|
|
|
Initialize a new `App` instance.
|
|
@@ -88,6 +91,7 @@ class App(EmbedChain):
|
|
|
self.chunker = None
|
|
|
if chunker:
|
|
|
self.chunker = ChunkerConfig(**chunker)
|
|
|
+ self.cache_config = cache_config
|
|
|
|
|
|
self.config = config or AppConfig()
|
|
|
self.name = self.config.name
|
|
@@ -109,6 +113,10 @@ class App(EmbedChain):
|
|
|
self.llm = llm or OpenAILlm()
|
|
|
self._init_db()
|
|
|
|
|
|
+ # Initialize the cache only if a cache_config was provided.
|
|
|
+ if self.cache_config is not None:
|
|
|
+ self._init_cache()
|
|
|
+
|
|
|
# Send anonymous telemetry
|
|
|
self._telemetry_props = {"class": self.__class__.__name__}
|
|
|
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
|
@@ -147,6 +155,15 @@ class App(EmbedChain):
|
|
|
self.db._initialize()
|
|
|
self.db.set_collection_name(self.db.config.collection_name)
|
|
|
|
|
|
+ def _init_cache(self):
+ """
+ Initialize the `cache` object imported from `embedchain.cache`.
+
+ Wires the cache to this app's embedding model (its `to_embeddings`
+ callable and `vector_dimension`) and takes the similarity threshold
+ from `self.cache_config`. `self.cache_config` must not be None,
+ because its `similarity_threshold` attribute is read unconditionally.
+ """
|
|
|
+ cache.init(
|
|
|
+ pre_embedding_func=gptcache_pre_function,
|
|
|
+ embedding_func=self.embedding_model.to_embeddings,
|
|
|
+ data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
|
|
|
+ # NOTE(review): max_distance is hard-coded to 1.0; only similarity_threshold comes from cache_config.
+ similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
|
|
|
+ config=Config(similarity_threshold=self.cache_config.similarity_threshold),
|
|
|
+ )
|
|
|
+
|
|
|
def _init_client(self):
|
|
|
"""
|
|
|
Initialize the client.
|
|
@@ -399,6 +416,7 @@ class App(EmbedChain):
|
|
|
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
|
|
llm_config_data = config_data.get("llm", {})
|
|
|
chunker_config_data = config_data.get("chunker", {})
|
|
|
+ cache_config_data = config_data.get("cache", None)
|
|
|
|
|
|
app_config = AppConfig(**app_config_data)
|
|
|
|
|
@@ -416,6 +434,11 @@ class App(EmbedChain):
|
|
|
embedding_model_provider, embedding_model_config_data.get("config", {})
|
|
|
)
|
|
|
|
|
|
+ if cache_config_data is not None:
|
|
|
+ cache_config = CacheConfig(**cache_config_data)
|
|
|
+ else:
|
|
|
+ cache_config = None
|
|
|
+
|
|
|
# Send anonymous telemetry
|
|
|
event_properties = {"init_type": "config_data"}
|
|
|
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
|
|
@@ -428,4 +451,5 @@ class App(EmbedChain):
|
|
|
config_data=config_data,
|
|
|
auto_deploy=auto_deploy,
|
|
|
chunker=chunker_config_data,
|
|
|
+ cache_config=cache_config,
|
|
|
)
|