Browse Source

[Feature] Add support for GPTCache (#1065)

Deven Patel 1 year ago
parent
commit
04daa1b206

+ 3 - 0
.gitignore

@@ -175,3 +175,6 @@ notebooks/*.yaml
 .ipynb_checkpoints/
 
 !configs/*.yaml
+
+# cache db
+*.db

+ 25 - 1
embedchain/app.py

@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
 
+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
-from embedchain.config import AppConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
         log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
+        cache_config: CacheConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
         self.chunker = None
         if chunker:
             self.chunker = ChunkerConfig(**chunker)
+        self.cache_config = cache_config
 
         self.config = config or AppConfig()
         self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()
 
+        # If cache_config is provided, initializing the cache ...
+        if self.cache_config is not None:
+            self._init_cache()
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -147,6 +155,15 @@ class App(EmbedChain):
         self.db._initialize()
         self.db.set_collection_name(self.db.config.collection_name)
 
+    def _init_cache(self):
+        cache.init(
+            pre_embedding_func=gptcache_pre_function,
+            embedding_func=self.embedding_model.to_embeddings,
+            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
+            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
+            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+        )
+
     def _init_client(self):
         """
         Initialize the client.
@@ -399,6 +416,7 @@ class App(EmbedChain):
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
+        cache_config_data = config_data.get("cache", None)
 
         app_config = AppConfig(**app_config_data)
 
@@ -416,6 +434,11 @@ class App(EmbedChain):
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
 
+        if cache_config_data is not None:
+            cache_config = CacheConfig(**cache_config_data)
+        else:
+            cache_config = None
+
         # Send anonymous telemetry
         event_properties = {"init_type": "config_data"}
         AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@ class App(EmbedChain):
             config_data=config_data,
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
+            cache_config=cache_config,
         )

+ 40 - 0
embedchain/cache.py

@@ -0,0 +1,40 @@
+import logging
+import os  # noqa: F401
+from typing import Any, Dict
+
+from gptcache import cache  # noqa: F401
+from gptcache.adapter.adapter import adapt  # noqa: F401
+from gptcache.config import Config  # noqa: F401
+from gptcache.manager import get_data_manager
+from gptcache.manager.scalar_data.base import Answer
+from gptcache.manager.scalar_data.base import DataType as CacheDataType
+from gptcache.session import Session
+from gptcache.similarity_evaluation.distance import \
+    SearchDistanceEvaluation  # noqa: F401
+
+
+def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+    return data["input_query"]
+
+
+def gptcache_data_manager(vector_dimension):
+    return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")
+
+
+def gptcache_data_convert(cache_data):
+    logging.info("[Cache] Cache hit, returning cache data...")
+    return cache_data
+
+
+def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
+    logging.info("[Cache] Cache missed, updating cache...")
+    update_cache_func(Answer(llm_data, CacheDataType.STR))
+    return llm_data
+
+
+def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
+    return cur_session_id in cache_session_ids
+
+
+def get_gptcache_session(session_id: str):
+    return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)

+ 1 - 0
embedchain/config/__init__.py

@@ -3,6 +3,7 @@
 from .add_config import AddConfig, ChunkerConfig
 from .app_config import AppConfig
 from .base_config import BaseConfig
+from .cache_config import CacheConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig

+ 16 - 0
embedchain/config/cache_config.py

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from embedchain.config.base_config import BaseConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class CacheConfig(BaseConfig):
+    def __init__(
+        self,
+        similarity_threshold: Optional[float] = 0.5,
+    ):
+        if similarity_threshold < 0 or similarity_threshold > 1:
+            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
+
+        self.similarity_threshold = similarity_threshold

+ 36 - 6
embedchain/embedchain.py

@@ -7,6 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -52,6 +55,7 @@ class EmbedChain(JSONSerializable):
         """
 
         self.config = config
+        self.cache_config = None
         # Llm
         self.llm = llm
         # Database has support for config assignment for backwards compatibility
@@ -546,9 +550,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.query(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.query,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.query(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
@@ -599,9 +616,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.chat(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.chat,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.chat(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer)

+ 12 - 0
embedchain/embedder/base.py

@@ -75,3 +75,15 @@ class BaseEmbedder:
         """
 
         return EmbeddingFunc(embeddings.embed_documents)
+
+    def to_embeddings(self, data: str, **_):
+        """
+        Convert data to embeddings
+
+        :param data: data to convert to embeddings
+        :type data: str
+        :return: embeddings
+        :rtype: list[float]
+        """
+        embeddings = self.embedding_fn([data])
+        return embeddings[0]

+ 3 - 0
embedchain/utils.py

@@ -436,6 +436,9 @@ def validate_config(config_data):
                 Optional("length_function"): str,
                 Optional("min_chunk_size"): int,
             },
+            Optional("cache"): {
+                Optional("similarity_threshold"): float,
+            },
         }
     )
 

+ 19 - 3
poetry.lock

@@ -1849,11 +1849,11 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
@@ -2176,6 +2176,22 @@ tqdm = "*"
 [package.extras]
 dev = ["black", "isort", "mkautodoc", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "pytest", "setuptools", "twine", "wheel"]
 
+[[package]]
+name = "gptcache"
+version = "0.1.43"
+description = "GPTCache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPTCache works as a memcache for AIGC applications, similar to how Redis works for traditional applications."
+optional = false
+python-versions = ">=3.8.1"
+files = [
+    {file = "gptcache-0.1.43-py3-none-any.whl", hash = "sha256:9c557ec9cc14428942a0ebf1c838520dc6d2be801d67bb6964807043fc2feaf5"},
+    {file = "gptcache-0.1.43.tar.gz", hash = "sha256:cebe7ec5e32a3347bf839e933a34e67c7fcae620deaa7cb8c6d7d276c8686f1a"},
+]
+
+[package.dependencies]
+cachetools = "*"
+numpy = "*"
+requests = "*"
+
 [[package]]
 name = "greenlet"
 version = "3.0.0"
@@ -6583,7 +6599,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"
 
 [package.extras]

+ 2 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.47"
+version = "0.1.48"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
@@ -101,6 +101,7 @@ posthog = "^3.0.2"
 rich = "^13.7.0"
 beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"
+gptcache = "^0.1.43"
 tiktoken = { version = "^0.4.0", optional = true }
 youtube-transcript-api = { version = "^0.6.1", optional = true }
 pytube = { version = "^15.0.0", optional = true }