|
@@ -9,19 +9,24 @@ import requests
|
|
import yaml
|
|
import yaml
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
|
|
|
-from embedchain.cache import (Config, ExactMatchEvaluation,
|
|
|
|
- SearchDistanceEvaluation, cache,
|
|
|
|
- gptcache_data_manager, gptcache_pre_function)
|
|
|
|
|
|
+from mem0 import Mem0
|
|
|
|
+from embedchain.cache import (
|
|
|
|
+ Config,
|
|
|
|
+ ExactMatchEvaluation,
|
|
|
|
+ SearchDistanceEvaluation,
|
|
|
|
+ cache,
|
|
|
|
+ gptcache_data_manager,
|
|
|
|
+ gptcache_pre_function,
|
|
|
|
+)
|
|
from embedchain.client import Client
|
|
from embedchain.client import Client
|
|
-from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
|
|
|
|
|
|
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
|
|
from embedchain.core.db.database import get_session, init_db, setup_engine
|
|
from embedchain.core.db.database import get_session, init_db, setup_engine
|
|
from embedchain.core.db.models import DataSource
|
|
from embedchain.core.db.models import DataSource
|
|
from embedchain.embedchain import EmbedChain
|
|
from embedchain.embedchain import EmbedChain
|
|
from embedchain.embedder.base import BaseEmbedder
|
|
from embedchain.embedder.base import BaseEmbedder
|
|
from embedchain.embedder.openai import OpenAIEmbedder
|
|
from embedchain.embedder.openai import OpenAIEmbedder
|
|
from embedchain.evaluation.base import BaseMetric
|
|
from embedchain.evaluation.base import BaseMetric
|
|
-from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
|
|
|
|
- Groundedness)
|
|
|
|
|
|
+from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
|
|
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
|
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
|
from embedchain.helpers.json_serializable import register_deserializable
|
|
from embedchain.helpers.json_serializable import register_deserializable
|
|
from embedchain.llm.base import BaseLlm
|
|
from embedchain.llm.base import BaseLlm
|
|
@@ -55,6 +60,7 @@ class App(EmbedChain):
|
|
auto_deploy: bool = False,
|
|
auto_deploy: bool = False,
|
|
chunker: ChunkerConfig = None,
|
|
chunker: ChunkerConfig = None,
|
|
cache_config: CacheConfig = None,
|
|
cache_config: CacheConfig = None,
|
|
|
|
+ memory_config: Mem0Config = None,
|
|
log_level: int = logging.WARN,
|
|
log_level: int = logging.WARN,
|
|
):
|
|
):
|
|
"""
|
|
"""
|
|
@@ -95,6 +101,7 @@ class App(EmbedChain):
|
|
self.id = None
|
|
self.id = None
|
|
self.chunker = ChunkerConfig(**chunker) if chunker else None
|
|
self.chunker = ChunkerConfig(**chunker) if chunker else None
|
|
self.cache_config = cache_config
|
|
self.cache_config = cache_config
|
|
|
|
+ self.memory_config = memory_config
|
|
|
|
|
|
self.config = config or AppConfig()
|
|
self.config = config or AppConfig()
|
|
self.name = self.config.name
|
|
self.name = self.config.name
|
|
@@ -123,6 +130,11 @@ class App(EmbedChain):
|
|
if self.cache_config is not None:
|
|
if self.cache_config is not None:
|
|
self._init_cache()
|
|
self._init_cache()
|
|
|
|
|
|
|
|
+ # If memory_config is provided, initializing the memory ...
|
|
|
|
+ self.mem0_client = None
|
|
|
|
+ if self.memory_config is not None:
|
|
|
|
+ self.mem0_client = Mem0(api_key=self.memory_config.api_key)
|
|
|
|
+
|
|
# Send anonymous telemetry
|
|
# Send anonymous telemetry
|
|
self._telemetry_props = {"class": self.__class__.__name__}
|
|
self._telemetry_props = {"class": self.__class__.__name__}
|
|
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
|
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
|
@@ -365,11 +377,13 @@ class App(EmbedChain):
|
|
app_config_data = config_data.get("app", {}).get("config", {})
|
|
app_config_data = config_data.get("app", {}).get("config", {})
|
|
vector_db_config_data = config_data.get("vectordb", {})
|
|
vector_db_config_data = config_data.get("vectordb", {})
|
|
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
|
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
|
|
|
+ memory_config_data = config_data.get("memory", {})
|
|
llm_config_data = config_data.get("llm", {})
|
|
llm_config_data = config_data.get("llm", {})
|
|
chunker_config_data = config_data.get("chunker", {})
|
|
chunker_config_data = config_data.get("chunker", {})
|
|
cache_config_data = config_data.get("cache", None)
|
|
cache_config_data = config_data.get("cache", None)
|
|
|
|
|
|
app_config = AppConfig(**app_config_data)
|
|
app_config = AppConfig(**app_config_data)
|
|
|
|
+ memory_config = Mem0Config(**memory_config_data) if memory_config_data else None
|
|
|
|
|
|
vector_db_provider = vector_db_config_data.get("provider", "chroma")
|
|
vector_db_provider = vector_db_config_data.get("provider", "chroma")
|
|
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
|
|
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
|
|
@@ -403,6 +417,7 @@ class App(EmbedChain):
|
|
auto_deploy=auto_deploy,
|
|
auto_deploy=auto_deploy,
|
|
chunker=chunker_config_data,
|
|
chunker=chunker_config_data,
|
|
cache_config=cache_config,
|
|
cache_config=cache_config,
|
|
|
|
+ memory_config=memory_config,
|
|
)
|
|
)
|
|
|
|
|
|
def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
|
|
def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
|