
[Feature] Add support for GPTCache (#1065)

Deven Patel · 1 year ago
commit 04daa1b206
10 changed files with 157 additions and 11 deletions
  1. .gitignore (+3 −0)
  2. embedchain/app.py (+25 −1)
  3. embedchain/cache.py (+40 −0)
  4. embedchain/config/__init__.py (+1 −0)
  5. embedchain/config/cache_config.py (+16 −0)
  6. embedchain/embedchain.py (+36 −6)
  7. embedchain/embedder/base.py (+12 −0)
  8. embedchain/utils.py (+3 −0)
  9. poetry.lock (+19 −3)
  10. pyproject.toml (+2 −1)

+ 3 - 0
.gitignore

@@ -175,3 +175,6 @@ notebooks/*.yaml
 .ipynb_checkpoints/
 
 !configs/*.yaml
+
+# cache db
+*.db

+ 25 - 1
embedchain/app.py

@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
 
+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
-from embedchain.config import AppConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
         log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
+        cache_config: CacheConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
         self.chunker = None
         if chunker:
             self.chunker = ChunkerConfig(**chunker)
+        self.cache_config = cache_config
 
         self.config = config or AppConfig()
         self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()
 
+        # If cache_config is provided, initialize the cache
+        if self.cache_config is not None:
+            self._init_cache()
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -147,6 +155,15 @@ class App(EmbedChain):
         self.db._initialize()
         self.db.set_collection_name(self.db.config.collection_name)
 
+    def _init_cache(self):
+        cache.init(
+            pre_embedding_func=gptcache_pre_function,
+            embedding_func=self.embedding_model.to_embeddings,
+            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
+            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
+            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+        )
+
     def _init_client(self):
         """
         Initialize the client.
@@ -399,6 +416,7 @@ class App(EmbedChain):
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
+        cache_config_data = config_data.get("cache", None)
 
         app_config = AppConfig(**app_config_data)
 
@@ -416,6 +434,11 @@ class App(EmbedChain):
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
 
+        if cache_config_data is not None:
+            cache_config = CacheConfig(**cache_config_data)
+        else:
+            cache_config = None
+
         # Send anonymous telemetry
         event_properties = {"init_type": "config_data"}
         AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@ class App(EmbedChain):
             config_data=config_data,
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
+            cache_config=cache_config,
         )
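
With this change, `App.from_config` recognizes a top-level `cache` key and forwards it to the constructor as `cache_config`. A minimal sketch of enabling the cache from a config file; the `config_path` keyword and the 0.8 threshold are assumptions for illustration, not part of this diff:

```python
import yaml

from embedchain import App

# Hypothetical config enabling the new cache section (its schema is
# validated in embedchain/utils.py later in this diff); all other
# sections fall back to their defaults.
with open("config.yaml", "w") as f:
    yaml.dump({"cache": {"similarity_threshold": 0.8}}, f)

# Assumes from_config(config_path=...) per embedchain's API of this era.
app = App.from_config(config_path="config.yaml")
```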

+ 40 - 0
embedchain/cache.py

@@ -0,0 +1,40 @@
+import logging
+import os  # noqa: F401
+from typing import Any, Dict
+
+from gptcache import cache  # noqa: F401
+from gptcache.adapter.adapter import adapt  # noqa: F401
+from gptcache.config import Config  # noqa: F401
+from gptcache.manager import get_data_manager
+from gptcache.manager.scalar_data.base import Answer
+from gptcache.manager.scalar_data.base import DataType as CacheDataType
+from gptcache.session import Session
+from gptcache.similarity_evaluation.distance import \
+    SearchDistanceEvaluation  # noqa: F401
+
+
+def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+    return data["input_query"]
+
+
+def gptcache_data_manager(vector_dimension):
+    return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")
+
+
+def gptcache_data_convert(cache_data):
+    logging.info("[Cache] Cache hit, returning cache data...")
+    return cache_data
+
+
+def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
+    logging.info("[Cache] Cache missed, updating cache...")
+    update_cache_func(Answer(llm_data, CacheDataType.STR))
+    return llm_data
+
+
+def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
+    return cur_session_id in cache_session_ids
+
+
+def get_gptcache_session(session_id: str):
+    return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)
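
The helpers above are the hooks GPTCache calls during a lookup: `gptcache_pre_function` picks the text that gets embedded and matched, and the two callbacks handle the hit and miss paths of `adapt`. A quick behavioral sketch of the two pure-Python helpers, based only on the code above (not part of the commit):

```python
from embedchain.cache import (gptcache_pre_function,
                              gptcache_update_cache_callback)

# The pre-function keys the cache on the user's query string:
assert gptcache_pre_function({"input_query": "What is Embedchain?"}) == "What is Embedchain?"

# On a cache miss, the callback stores the fresh LLM answer through the
# update function GPTCache supplies, then returns the answer unchanged:
stored = []
answer = gptcache_update_cache_callback("an llm answer", stored.append)
assert answer == "an llm answer" and len(stored) == 1  # stored[0] is an Answer
```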

+ 1 - 0
embedchain/config/__init__.py

@@ -3,6 +3,7 @@
 from .add_config import AddConfig, ChunkerConfig
 from .app_config import AppConfig
 from .base_config import BaseConfig
+from .cache_config import CacheConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig

+ 16 - 0
embedchain/config/cache_config.py

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from embedchain.config.base_config import BaseConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class CacheConfig(BaseConfig):
+    def __init__(
+        self,
+        similarity_threshold: Optional[float] = 0.5,
+    ):
+        if similarity_threshold < 0 or similarity_threshold > 1:
+            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
+
+        self.similarity_threshold = similarity_threshold
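
`similarity_threshold` is validated at construction time, so a bad value fails fast rather than surfacing later inside GPTCache. For example:

```python
from embedchain.config import CacheConfig

CacheConfig(similarity_threshold=0.8)  # ok
CacheConfig()                          # ok, defaults to 0.5
CacheConfig(similarity_threshold=1.5)  # raises ValueError
```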

+ 36 - 6
embedchain/embedchain.py

@@ -7,6 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -52,6 +55,7 @@ class EmbedChain(JSONSerializable):
         """
         """
 
 
         self.config = config
         self.config = config
+        self.cache_config = None
         # Llm
         # Llm
         self.llm = llm
         self.llm = llm
         # Database has support for config assignment for backwards compatibility
         # Database has support for config assignment for backwards compatibility
@@ -546,9 +550,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.query(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.query,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.query(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
@@ -599,9 +616,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.chat(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.chat,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.chat(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer)
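
Both `query` and `chat` now route through GPTCache's `adapt` when a cache is configured: on a semantic hit the stored answer comes back through `gptcache_data_convert`, and on a miss the LLM handler runs and its answer is written back via `gptcache_update_cache_callback`. A hypothetical end-to-end sketch (needs a configured LLM, e.g. `OPENAI_API_KEY`; the URL and threshold are illustrative):

```python
from embedchain import App
from embedchain.config import CacheConfig

app = App(cache_config=CacheConfig(similarity_threshold=0.8))
app.add("https://www.forbes.com/profile/elon-musk")

app.query("Who founded SpaceX?")  # cache miss: calls the LLM, stores the answer
app.query("Who started SpaceX?")  # similar enough: served from the cache
```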

+ 12 - 0
embedchain/embedder/base.py

@@ -75,3 +75,15 @@ class BaseEmbedder:
         """
         """
 
 
         return EmbeddingFunc(embeddings.embed_documents)
         return EmbeddingFunc(embeddings.embed_documents)
+
+    def to_embeddings(self, data: str, **_):
+        """
+        Convert data to embeddings
+
+        :param data: data to convert to embeddings
+        :type data: str
+        :return: embeddings
+        :rtype: list[float]
+        """
+        embeddings = self.embedding_fn([data])
+        return embeddings[0]
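
`to_embeddings` adapts the batch-oriented `embedding_fn` (list of texts in, list of vectors out) to GPTCache's one-string-at-a-time `embedding_func` interface. A toy sketch with a stand-in embedding function (the constant vector is illustrative, not a real embedder):

```python
from embedchain.embedder.base import BaseEmbedder

embedder = BaseEmbedder()
# Stand-in: a concrete embedder subclass normally sets embedding_fn.
embedder.embedding_fn = lambda texts: [[0.1, 0.2, 0.3] for _ in texts]

print(embedder.to_embeddings("hello"))  # [0.1, 0.2, 0.3]
```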

+ 3 - 0
embedchain/utils.py

@@ -436,6 +436,9 @@ def validate_config(config_data):
                 Optional("length_function"): str,
                 Optional("length_function"): str,
                 Optional("min_chunk_size"): int,
                 Optional("min_chunk_size"): int,
             },
             },
+            Optional("cache"): {
+                Optional("similarity_threshold"): float,
+            },
         }
         }
     )
     )
 
 

+ 19 - 3
poetry.lock

@@ -1849,11 +1849,11 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
@@ -2176,6 +2176,22 @@ tqdm = "*"
 [package.extras]
 dev = ["black", "isort", "mkautodoc", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "pytest", "setuptools", "twine", "wheel"]
 
+[[package]]
+name = "gptcache"
+version = "0.1.43"
+description = "GPTCache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPTCache works as a memcache for AIGC applications, similar to how Redis works for traditional applications."
+optional = false
+python-versions = ">=3.8.1"
+files = [
+    {file = "gptcache-0.1.43-py3-none-any.whl", hash = "sha256:9c557ec9cc14428942a0ebf1c838520dc6d2be801d67bb6964807043fc2feaf5"},
+    {file = "gptcache-0.1.43.tar.gz", hash = "sha256:cebe7ec5e32a3347bf839e933a34e67c7fcae620deaa7cb8c6d7d276c8686f1a"},
+]
+
+[package.dependencies]
+cachetools = "*"
+numpy = "*"
+requests = "*"
+
 [[package]]
 name = "greenlet"
 version = "3.0.0"
@@ -6583,7 +6599,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"
 
 [package.extras]

+ 2 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.47"
+version = "0.1.48"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
@@ -101,6 +101,7 @@ posthog = "^3.0.2"
 rich = "^13.7.0"
 beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"
+gptcache = "^0.1.43"
 tiktoken = { version = "^0.4.0", optional = true }
 youtube-transcript-api = { version = "^0.6.1", optional = true }
 pytube = { version = "^15.0.0", optional = true }