瀏覽代碼

Support model config in LLMs (#1495)

Dev Khant 1 年之前
父節點
當前提交
40c9abe484

+ 0 - 0
mem0/configs/llms/__init__.py


+ 34 - 0
mem0/configs/llms/base.py

@@ -0,0 +1,34 @@
+from abc import ABC
+from typing import Optional
+
+class BaseLlmConfig(ABC):
+    """
+    Config for LLMs.
+    """
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        temperature: float = 0,
+        max_tokens: int = 3000,
+        top_p: float = 1
+    ):
+        """
+        Initializes a configuration class instance for the LLM.
+
+        :param model: Controls the OpenAI model used, defaults to None
+        :type model: Optional[str], optional
+        :param temperature:  Controls the randomness of the model's output.
+        Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
+        :type temperature: float, optional
+        :param max_tokens: Controls how many tokens are generated, defaults to 3000
+        :type max_tokens: int, optional
+        :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
+        defaults to 1
+        :type top_p: float, optional
+        """
+        
+        self.model = model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.top_p = top_p

+ 11 - 6
mem0/llms/aws_bedrock.py

@@ -5,12 +5,16 @@ from typing import Dict, List, Optional, Any
 import boto3
 
 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
 
+class AWSBedrockLLM(LLMBase):    
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
 
-class AWSBedrockLLM(LLMBase):
-    def __init__(self, model="cohere.command-r-v1:0"):
+        if not self.config.model:
+            self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
         self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
-        self.model = model
+        self.model_kwargs = {"temperature": self.config.temperature, "max_tokens_to_sample": self.config.max_tokens, "top_p": self.config.top_p}
 
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
         """
@@ -171,19 +175,20 @@ class AWSBedrockLLM(LLMBase):
         if tools:
             # Use converse method when tools are provided
             messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
+            inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
             tools_config = {"tools": self._convert_tool_format(tools)}
 
             response = self.client.converse(
-                modelId=self.model,
+                modelId=self.config.model,
                 messages=messages,
+                inferenceConfig=inference_config,
                 toolConfig=tools_config
             )
-            print("Tools response: ", response)
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
             provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.model, prompt)
+            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
             body = json.dumps(input_body)
 
             response = self.client.invoke_model(

+ 14 - 0
mem0/llms/base.py

@@ -1,7 +1,21 @@
+from typing import Optional
 from abc import ABC, abstractmethod
 
+from mem0.configs.llms.base import BaseLlmConfig
+
 
 class LLMBase(ABC):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """Initialize a base LLM class
+
+        :param config: LLM configuration option class, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        """
+        if config is None:
+            self.config = BaseLlmConfig()
+        else:
+            self.config = config
+
     @abstractmethod
     def generate_response(self, messages):
         """

+ 1 - 1
mem0/llms/configs.py

@@ -8,7 +8,7 @@ class LlmConfig(BaseModel):
         description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
     )
     config: Optional[dict] = Field(
-        description="Configuration for the specific LLM", default=None
+        description="Configuration for the specific LLM", default={}
     )
 
     @field_validator("config")

+ 13 - 3
mem0/llms/groq.py

@@ -4,12 +4,16 @@ from typing import Dict, List, Optional
 from groq import Groq
 
 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
 
 
 class GroqLLM(LLMBase):
-    def __init__(self, model="llama3-70b-8192"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="llama3-70b-8192"
         self.client = Groq()
-        self.model = model
 
     def _parse_response(self, response, tools):
         """
@@ -58,7 +62,13 @@ class GroqLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model, 
+            "messages": messages, 
+            "temperature": self.config.temperature, 
+            "max_tokens": self.config.max_tokens, 
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:

+ 15 - 5
mem0/llms/litellm.py

@@ -4,11 +4,15 @@ from typing import Dict, List, Optional
 import litellm
 
 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
 
 
 class LiteLLM(LLMBase):
-    def __init__(self, model="gpt-4o"):
-        self.model = model
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="gpt-4o"
     
     def _parse_response(self, response, tools):
         """
@@ -57,10 +61,16 @@ class LiteLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        if not litellm.supports_function_calling(self.model):
-            raise ValueError(f"Model '{self.model}' in litellm does not support function calling.")
+        if not litellm.supports_function_calling(self.config.model):
+            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
 
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model, 
+            "messages": messages, 
+            "temperature": self.config.temperature, 
+            "max_tokens": self.config.max_tokens, 
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:

+ 13 - 4
mem0/llms/openai.py

@@ -4,12 +4,15 @@ from typing import Dict, List, Optional
 from openai import OpenAI
 
 from mem0.llms.base import LLMBase
-
+from mem0.configs.llms.base import BaseLlmConfig
 
 class OpenAILLM(LLMBase):
-    def __init__(self, model="gpt-4o"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="gpt-4o"
         self.client = OpenAI()
-        self.model = model
     
     def _parse_response(self, response, tools):
         """
@@ -58,7 +61,13 @@ class OpenAILLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model, 
+            "messages": messages, 
+            "temperature": self.config.temperature, 
+            "max_tokens": self.config.max_tokens, 
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:

+ 13 - 4
mem0/llms/together.py

@@ -4,12 +4,15 @@ from typing import Dict, List, Optional
 from together import Together
 
 from mem0.llms.base import LLMBase
-
+from mem0.configs.llms.base import BaseLlmConfig
 
 class TogetherLLM(LLMBase):
-    def __init__(self, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1"
         self.client = Together()
-        self.model = model
     
     def _parse_response(self, response, tools):
         """
@@ -58,7 +61,13 @@ class TogetherLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model, 
+            "messages": messages, 
+            "temperature": self.config.temperature, 
+            "max_tokens": self.config.max_tokens, 
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:

+ 1 - 1
mem0/memory/main.py

@@ -82,7 +82,7 @@ class Memory(MemoryBase):
                 f"Unsupported vector store type: {self.config.vector_store_type}"
             )
 
-        self.llm = LlmFactory.create(self.config.llm.provider)
+        self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
         self.db = SQLiteManager(self.config.history_db_path)
         self.collection_name = self.config.collection_name
         self.vector_store.create_col(

+ 6 - 3
mem0/utils/factory.py

@@ -1,5 +1,7 @@
 import importlib
 
+from mem0.configs.llms.base import BaseLlmConfig
+
 
 def load_class(class_type):
     module_path, class_name = class_type.rsplit(".", 1)
@@ -18,11 +20,12 @@ class LlmFactory:
     }
 
     @classmethod
-    def create(cls, provider_name):
+    def create(cls, provider_name, config):
         class_type = cls.provider_to_class.get(provider_name)
         if class_type:
-            llm_instance = load_class(class_type)()
-            return llm_instance
+            llm_instance = load_class(class_type)
+            base_config = BaseLlmConfig(**config)
+            return llm_instance(base_config)
         else:
             raise ValueError(f"Unsupported Llm provider: {provider_name}")
         

+ 12 - 3
tests/llms/test_groq.py

@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import Mock, patch
 from mem0.llms.groq import GroqLLM
+from mem0.configs.llms.base import BaseLlmConfig
 
 @pytest.fixture
 def mock_groq_client():
@@ -11,7 +12,8 @@ def mock_groq_client():
 
 
 def test_generate_response_without_tools(mock_groq_client):
-    llm = GroqLLM()
+    config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = GroqLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_groq_client):
 
     mock_groq_client.chat.completions.create.assert_called_once_with(
         model="llama3-70b-8192",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
 
 
 def test_generate_response_with_tools(mock_groq_client):
-    llm = GroqLLM()
+    config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = GroqLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_groq_client):
     mock_groq_client.chat.completions.create.assert_called_once_with(
         model="llama3-70b-8192",
         messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
         tools=tools,
         tool_choice="auto"
     )

+ 14 - 4
tests/llms/test_litellm.py

@@ -2,6 +2,7 @@ import pytest
 from unittest.mock import Mock, patch
 
 from mem0.llms import litellm
+from mem0.configs.llms.base import BaseLlmConfig
 
 @pytest.fixture
 def mock_litellm():
@@ -9,7 +10,8 @@ def mock_litellm():
         yield mock_litellm
 
 def test_generate_response_with_unsupported_model(mock_litellm):
-    llm = litellm.LiteLLM(model="unsupported-model")
+    config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
     messages = [{"role": "user", "content": "Hello"}]
     
     mock_litellm.supports_function_calling.return_value = False
@@ -19,7 +21,8 @@ def test_generate_response_with_unsupported_model(mock_litellm):
 
 
 def test_generate_response_without_tools(mock_litellm):
-    llm = litellm.LiteLLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"}
@@ -34,13 +37,17 @@ def test_generate_response_without_tools(mock_litellm):
 
     mock_litellm.completion.assert_called_once_with(
         model="gpt-4o",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
 
 
 def test_generate_response_with_tools(mock_litellm):
-    llm = litellm.LiteLLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -80,6 +87,9 @@ def test_generate_response_with_tools(mock_litellm):
     mock_litellm.completion.assert_called_once_with(
         model="gpt-4o",
         messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1,
         tools=tools,
         tool_choice="auto"
     )

+ 13 - 4
tests/llms/test_openai.py

@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import Mock, patch
 from mem0.llms.openai import OpenAILLM
+from mem0.configs.llms.base import BaseLlmConfig
 
 @pytest.fixture
 def mock_openai_client():
@@ -11,7 +12,8 @@ def mock_openai_client():
 
 
 def test_generate_response_without_tools(mock_openai_client):
-    llm = OpenAILLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = OpenAILLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_openai_client):
 
     mock_openai_client.chat.completions.create.assert_called_once_with(
         model="gpt-4o",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
-    
+
 
 def test_generate_response_with_tools(mock_openai_client):
-    llm = OpenAILLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = OpenAILLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_openai_client):
     mock_openai_client.chat.completions.create.assert_called_once_with(
         model="gpt-4o",
         messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
         tools=tools,
         tool_choice="auto"
     )

+ 12 - 3
tests/llms/test_together.py

@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import Mock, patch
 from mem0.llms.together import TogetherLLM
+from mem0.configs.llms.base import BaseLlmConfig
 
 @pytest.fixture
 def mock_together_client():
@@ -11,7 +12,8 @@ def mock_together_client():
 
 
 def test_generate_response_without_tools(mock_together_client):
-    llm = TogetherLLM()
+    config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = TogetherLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_together_client):
 
     mock_together_client.chat.completions.create.assert_called_once_with(
         model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
 
 
 def test_generate_response_with_tools(mock_together_client):
-    llm = TogetherLLM()
+    config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = TogetherLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_together_client):
     mock_together_client.chat.completions.create.assert_called_once_with(
         model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
         tools=tools,
         tool_choice="auto"
     )