Kaynağa Gözat

Integrate Mem0 (#1462)

Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
Dev Khant 1 yıl önce
ebeveyn
işleme
bbe56107fb

+ 3 - 0
docs/api-reference/advanced/configuration.mdx

@@ -249,6 +249,9 @@ Alright, let's dive into what each key means in the yaml config above:
     - `config` (Optional): The config for initializing the cache. If not provided, sensible default values are used as mentioned below.
       - `similarity_threshold` (Float): The threshold for similarity evaluation. Defaults to `0.8`.
       - `auto_flush` (Integer): The number of queries after which the cache is flushed. Defaults to `20`.
+7. `memory` Section: (Optional)
+    - `api_key` (String): The API key of mem0.
+    - `top_k` (Integer): The number of top-k results to return. Defaults to `10`.
     <Note>
     If you provide a cache section, the app will automatically configure and use a cache to store the results of the language model. This is useful if you want to speed up the response time and save inference cost of your app.
     </Note>

+ 25 - 0
docs/api-reference/app/chat.mdx

@@ -144,3 +144,28 @@ app.add("https://www.forbes.com/profile/elon-musk")
 query_config = BaseLlmConfig(number_documents=5)
 app.chat("What is the net worth of Elon Musk?", config=query_config)
 ```
+
+### With Mem0 to store chat history
+
+Mem0 is a cutting-edge long-term memory for LLMs to enable personalization for the GenAI stack. It enables LLMs to remember past interactions and provide more personalized responses.
+
+Follow these steps to use Mem0  to enable memory for personalization in your apps:
+- Install the [`mem0`](https://docs.mem0.ai/) package using `pip install memzero`. 
+- Get the api_key from [Mem0 Platform](https://app.mem0.ai/).
+- Provide api_key in config under `memory`, refer [Configurations](docs/api-reference/advanced/configuration.mdx).
+
+```python with mem0
+from embedchain import App
+
+config = {
+  "memory": {
+    "api_key": "m0-xxx",
+    "top_k": 5
+  }
+}
+
+app = App.from_config(config=config)
+app.add("https://www.forbes.com/profile/elon-musk")
+
+app.chat("What is the net worth of Elon Musk?")
+```

+ 21 - 6
embedchain/app.py

@@ -9,19 +9,24 @@ import requests
 import yaml
 from tqdm import tqdm
 
-from embedchain.cache import (Config, ExactMatchEvaluation,
-                              SearchDistanceEvaluation, cache,
-                              gptcache_data_manager, gptcache_pre_function)
+from mem0 import Mem0
+from embedchain.cache import (
+    Config,
+    ExactMatchEvaluation,
+    SearchDistanceEvaluation,
+    cache,
+    gptcache_data_manager,
+    gptcache_pre_function,
+)
 from embedchain.client import Client
-from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
 from embedchain.core.db.database import get_session, init_db, setup_engine
 from embedchain.core.db.models import DataSource
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.evaluation.base import BaseMetric
-from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
-                                           Groundedness)
+from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -55,6 +60,7 @@ class App(EmbedChain):
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
         cache_config: CacheConfig = None,
+        memory_config: Mem0Config = None,
         log_level: int = logging.WARN,
     ):
         """
@@ -95,6 +101,7 @@ class App(EmbedChain):
         self.id = None
         self.chunker = ChunkerConfig(**chunker) if chunker else None
         self.cache_config = cache_config
+        self.memory_config = memory_config
 
         self.config = config or AppConfig()
         self.name = self.config.name
@@ -123,6 +130,11 @@ class App(EmbedChain):
         if self.cache_config is not None:
             self._init_cache()
 
+        # If memory_config is provided, initializing the memory ...
+        self.mem0_client = None
+        if self.memory_config is not None:
+            self.mem0_client = Mem0(api_key=self.memory_config.api_key)
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -365,11 +377,13 @@ class App(EmbedChain):
         app_config_data = config_data.get("app", {}).get("config", {})
         vector_db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
+        memory_config_data = config_data.get("memory", {})
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
         cache_config_data = config_data.get("cache", None)
 
         app_config = AppConfig(**app_config_data)
+        memory_config = Mem0Config(**memory_config_data) if memory_config_data else None
 
         vector_db_provider = vector_db_config_data.get("provider", "chroma")
         vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
@@ -403,6 +417,7 @@ class App(EmbedChain):
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
             cache_config=cache_config,
+            memory_config=memory_config,
         )
 
     def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):

+ 1 - 0
embedchain/config/__init__.py

@@ -12,3 +12,4 @@ from .vectordb.chroma import ChromaDbConfig
 from .vectordb.elasticsearch import ElasticsearchDBConfig
 from .vectordb.opensearch import OpenSearchDBConfig
 from .vectordb.zilliz import ZillizDBConfig
+from .mem0_config import Mem0Config

+ 30 - 0
embedchain/config/llm/base.py

@@ -50,6 +50,35 @@ Query: $query
 Answer:
 """  # noqa:E501
 
+DEFAULT_PROMPT_WITH_MEM0_MEMORY = """
+You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the conversation history and memories with the user. Make sure to use relevant context from conversation history and memories as needed.
+
+Here are some guidelines to follow:
+
+1. Refrain from explicitly mentioning the context provided in your response.
+2. Take into consideration the conversation history and memories provided.
+3. The context should silently guide your answers without being directly acknowledged.
+4. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
+
+Context information:
+----------------------
+$context
+----------------------
+
+Conversation history:
+----------------------
+$history
+----------------------
+
+Memories/Preferences:
+----------------------
+$memories
+----------------------
+
+Query: $query
+Answer:
+"""  # noqa:E501
+
 DOCS_SITE_DEFAULT_PROMPT = """
 You are an expert AI assistant for developer support product. Your responses must always be rooted in the context provided for each query. Wherever possible, give complete code snippet. Dont make up any code snippet on your own.
 
@@ -70,6 +99,7 @@ Answer:
 
 DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
 DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
+DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_MEM0_MEMORY)
 DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
 query_re = re.compile(r"\$\{*query\}*")
 context_re = re.compile(r"\$\{*context\}*")

+ 21 - 0
embedchain/config/mem0_config.py

@@ -0,0 +1,21 @@
+from typing import Any, Optional
+
+from embedchain.config.base_config import BaseConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class Mem0Config(BaseConfig):
+    def __init__(self, api_key: str, top_k: Optional[int] = 10):
+        self.api_key = api_key
+        self.top_k = top_k
+
+    @staticmethod
+    def from_config(config: Optional[dict[str, Any]]):
+        if config is None:
+            return Mem0Config()
+        else:
+            return Mem0Config(
+                api_key=config.get("api_key", ""),
+                init_config=config.get("top_k", 10),
+            )

+ 24 - 2
embedchain/embedchain.py

@@ -52,6 +52,8 @@ class EmbedChain(JSONSerializable):
         """
         self.config = config
         self.cache_config = None
+        self.memory_config = None
+        self.mem0_client = None
         # Llm
         self.llm = llm
         # Database has support for config assignment for backwards compatibility
@@ -595,6 +597,12 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
+        memories = None
+        if self.mem0_client:
+            memories = self.mem0_client.search(
+                query=input_query, agent_id=self.config.id, session_id=session_id, limit=self.memory_config.top_k
+            )
+
         # Update the history beforehand so that we can handle multiple chat sessions in the same python session
         self.llm.update_history(app_id=self.config.id, session_id=session_id)
 
@@ -615,13 +623,27 @@ class EmbedChain(JSONSerializable):
             logger.debug("Cache disabled. Running chat without cache.")
             if self.llm.config.token_usage:
                 answer, token_info = self.llm.query(
-                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                    input_query=input_query,
+                    contexts=contexts_data_for_llm_query,
+                    config=config,
+                    dry_run=dry_run,
+                    memories=memories,
                 )
             else:
                 answer = self.llm.query(
-                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                    input_query=input_query,
+                    contexts=contexts_data_for_llm_query,
+                    config=config,
+                    dry_run=dry_run,
+                    memories=memories,
                 )
 
+        # Add to Mem0 memory if enabled
+        # TODO: Might need to prepend with some text like: 
+        # "Remember user preferences from following user query: {input_query}"
+        if self.mem0_client:
+            self.mem0_client.add(data=input_query, agent_id=self.config.id, session_id=session_id)
+
         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
 

+ 32 - 8
embedchain/llm/base.py

@@ -5,9 +5,12 @@ from typing import Any, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import (
+    DEFAULT_PROMPT,
+    DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
+    DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
+    DOCS_SITE_PROMPT_TEMPLATE,
+)
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -74,6 +77,16 @@ class BaseLlm(JSONSerializable):
         """
         return "\n".join(self.history)
 
+    def _format_memories(self, memories: list[dict]) -> str:
+        """Format memories to be used in prompt
+
+        :param memories: Memories to format
+        :type memories: list[dict]
+        :return: Formatted memories
+        :rtype: str
+        """
+        return "\n".join([memory["text"] for memory in memories])
+
     def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
@@ -88,6 +101,7 @@ class BaseLlm(JSONSerializable):
         """
         context_string = " | ".join(contexts)
         web_search_result = kwargs.get("web_search_result", "")
+        memories = kwargs.get("memories", None)
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)
 
@@ -103,10 +117,19 @@ class BaseLlm(JSONSerializable):
                 not self.config._validate_prompt_history(self.config.prompt)
                 and self.config.prompt.template == DEFAULT_PROMPT
             ):
-                # swap in the template with history
-                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
-                    context=context_string, query=input_query, history=self._format_history()
-                )
+                if memories:
+                    # swap in the template with Mem0 memory template
+                    prompt = DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE.substitute(
+                        context=context_string,
+                        query=input_query,
+                        history=self._format_history(),
+                        memories=self._format_memories(memories),
+                    )
+                else:
+                    # swap in the template with history
+                    prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
+                        context=context_string, query=input_query, history=self._format_history()
+                    )
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                 logger.warning(
@@ -180,7 +203,7 @@ class BaseLlm(JSONSerializable):
         if token_info:
             logger.info(f"Token Info: {token_info}")
 
-    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
+    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, memories=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -216,6 +239,7 @@ class BaseLlm(JSONSerializable):
             k = {}
             if self.config.online:
                 k["web_search_result"] = self.access_search_and_get_results(input_query)
+            k["memories"] = memories
             prompt = self.generate_prompt(input_query, contexts, **k)
             logger.info(f"Prompt: {prompt}")
             if dry_run:

+ 4 - 0
embedchain/utils/misc.py

@@ -520,6 +520,10 @@ def validate_config(config_data):
                     Optional("auto_flush"): int,
                 },
             },
+            Optional("memory"): {
+                "api_key": str,
+                Optional("top_k"): int,
+            },
         }
     )
 

+ 33 - 18
poetry.lock

@@ -385,17 +385,17 @@ files = [
 
 [[package]]
 name = "boto3"
-version = "1.34.139"
+version = "1.34.140"
 description = "The AWS SDK for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "boto3-1.34.139-py3-none-any.whl", hash = "sha256:98b2a12bcb30e679fa9f60fc74145a39db5ec2ca7b7c763f42896e3bd9b3a38d"},
-    {file = "boto3-1.34.139.tar.gz", hash = "sha256:32b99f0d76ec81fdca287ace2c9744a2eb8b92cb62bf4d26d52a4f516b63a6bf"},
+    {file = "boto3-1.34.140-py3-none-any.whl", hash = "sha256:23ca8d8f7a30c3bbd989808056b5fc5d68ff5121c02c722c6167b6b1bb7f8726"},
+    {file = "boto3-1.34.140.tar.gz", hash = "sha256:578bbd5e356005719b6b610d03edff7ea1b0824d078afe62d3fb8bea72f83a87"},
 ]
 
 [package.dependencies]
-botocore = ">=1.34.139,<1.35.0"
+botocore = ">=1.34.140,<1.35.0"
 jmespath = ">=0.7.1,<2.0.0"
 s3transfer = ">=0.10.0,<0.11.0"
 
@@ -404,13 +404,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
 
 [[package]]
 name = "botocore"
-version = "1.34.139"
+version = "1.34.140"
 description = "Low-level, data-driven core of boto 3."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "botocore-1.34.139-py3-none-any.whl", hash = "sha256:dd1e085d4caa2a4c1b7d83e3bc51416111c8238a35d498e9d3b04f3b63b086ba"},
-    {file = "botocore-1.34.139.tar.gz", hash = "sha256:df023d8cf8999d574214dad4645cb90f9d2ccd1494f6ee2b57b1ab7522f6be77"},
+    {file = "botocore-1.34.140-py3-none-any.whl", hash = "sha256:43940d3a67d946ba3301631ba4078476a75f1015d4fb0fb0272d0b754b2cf9de"},
+    {file = "botocore-1.34.140.tar.gz", hash = "sha256:86302b2226c743b9eec7915a4c6cfaffd338ae03989cd9ee181078ef39d1ab39"},
 ]
 
 [package.dependencies]
@@ -882,13 +882,13 @@ all = ["pycocotools (==2.0.6)"]
 
 [[package]]
 name = "clarifai-grpc"
-version = "10.5.4"
+version = "10.6.1"
 description = "Clarifai gRPC API Client"
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "clarifai_grpc-10.5.4-py3-none-any.whl", hash = "sha256:ae4c4d8985fdd2bf326cec27ee834571e44d0e989fb12686dd681f9b553ae218"},
-    {file = "clarifai_grpc-10.5.4.tar.gz", hash = "sha256:c67ce0dde186e8bab0d42a9923d28ddb4a05017b826c8e52ac7a86ec6df5f12a"},
+    {file = "clarifai_grpc-10.6.1-py3-none-any.whl", hash = "sha256:7f07c262f46042995b11af10cdd552718c4487e955db1b3f1253fcb0c2ab1ce1"},
+    {file = "clarifai_grpc-10.6.1.tar.gz", hash = "sha256:f692e3d6a051a1228ca371c3a9dc705cc9a61334eecc454d056f7af0b6f4dbad"},
 ]
 
 [package.dependencies]
@@ -1280,18 +1280,17 @@ stone = ">=2"
 
 [[package]]
 name = "duckduckgo-search"
-version = "6.1.8"
+version = "6.1.9"
 description = "Search for words, documents, images, news, maps and text translation using the DuckDuckGo.com search engine."
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "duckduckgo_search-6.1.8-py3-none-any.whl", hash = "sha256:fb67f6ae8df4f291462010018342aeaaa4f259b54667dc48de37c31d8ecab027"},
-    {file = "duckduckgo_search-6.1.8.tar.gz", hash = "sha256:e38fa695f598b0b2bd779fffde1fef2eeff1d6a3f218772e50f8b4f381f63279"},
+    {file = "duckduckgo_search-6.1.9-py3-none-any.whl", hash = "sha256:a208babf87b971290b1afed9908bc5ab6ac6c1738b90b48ad613267f7630cb77"},
+    {file = "duckduckgo_search-6.1.9.tar.gz", hash = "sha256:0d7d746e003d6b3bcd0d0dc11927c9a69b6fa271f3b3f65df6f01ea4d9d2689d"},
 ]
 
 [package.dependencies]
 click = ">=8.1.7"
-orjson = ">=3.10.6"
 pyreqwest-impersonate = ">=0.4.9"
 
 [package.extras]
@@ -3337,6 +3336,22 @@ files = [
     {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
 ]
 
+[[package]]
+name = "memzero"
+version = "0.0.7"
+description = "Long-term memory for AI Agents"
+optional = false
+python-versions = "<4.0,>=3.9"
+files = [
+    {file = "memzero-0.0.7-py3-none-any.whl", hash = "sha256:65f6da88d46263dbc05621fcd01bd09616d0e7f082d55ed9899dc2152491ffd2"},
+    {file = "memzero-0.0.7.tar.gz", hash = "sha256:0c1f413d8ee0ade955fe9f8b8f5aff2cf58bc94869537aca62139db3d9f50725"},
+]
+
+[package.dependencies]
+httpx = ">=0.27.0,<0.28.0"
+posthog = ">=3.5.0,<4.0.0"
+pydantic = ">=2.7.3,<3.0.0"
+
 [[package]]
 name = "milvus-lite"
 version = "2.4.8"
@@ -6603,13 +6618,13 @@ files = [
 
 [[package]]
 name = "tenacity"
-version = "8.4.2"
+version = "8.5.0"
 description = "Retry code until it succeeds"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "tenacity-8.4.2-py3-none-any.whl", hash = "sha256:9e6f7cf7da729125c7437222f8a522279751cdfbe6b67bfe64f75d3a348661b2"},
-    {file = "tenacity-8.4.2.tar.gz", hash = "sha256:cd80a53a79336edba8489e767f729e4f391c896956b57140b5d7511a64bbd3ef"},
+    {file = "tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687"},
+    {file = "tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78"},
 ]
 
 [package.extras]
@@ -7914,4 +7929,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<=3.13"
-content-hash = "afc88f00bafd2b76a954c758a0556cba6d3854e98c444bc5e720319bf472caa8"
+content-hash = "22f5fb8700344234abb1d98a097a55c35162d2475010f3c0c3a97e37dc72c545"

+ 1 - 0
pyproject.toml

@@ -103,6 +103,7 @@ beautifulsoup4 = "^4.12.2"
 pypdf = "^4.0.1"
 gptcache = "^0.1.43"
 pysbd = "^0.3.4"
+memzero = "^0.0.7"
 tiktoken = { version = "^0.7.0", optional = true }
 youtube-transcript-api = { version = "^0.6.1", optional = true }
 pytube = { version = "^15.0.0", optional = true }