LancXu 1 year ago
Parent
Commit
9a5894d25a
7 changed files with 220 additions and 5 deletions
  1. Dockerfile (+7 -0)
  2. mem0/embeddings/configs.py (+2 -2)
  3. mem0/embeddings/zhipu.py (+33 -0)
  4. mem0/llms/configs.py (+2 -2)
  5. mem0/llms/zhipu.py (+78 -0)
  6. mem0/utils/factory.py (+3 -1)
  7. tests/llms/test_zhipu.py (+95 -0)

+ 7 - 0
Dockerfile

@@ -0,0 +1,7 @@
+FROM python:3.10.12
+
+ADD . /workspace
+
+WORKDIR /workspace
+
+

+ 2 - 2
mem0/embeddings/configs.py

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
 class EmbedderConfig(BaseModel):
     provider: str = Field(
         description="Provider of the embedding model (e.g., 'ollama', 'openai')",
-        default="openai",
+        default="zhipu",
     )
     config: Optional[dict] = Field(
         description="Configuration for the specific embedding model", default=None
@@ -15,7 +15,7 @@ class EmbedderConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ["openai", "ollama"]:
+        if provider in ["openai", "ollama", "zhipu"]:
             return v
         else:
             raise ValueError(f"Unsupported embedding provider: {provider}")
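
With "zhipu" on the validator whitelist, the embedder config now accepts the new provider. A minimal sketch of selecting it (field names come from the diff above; passing config explicitly is what triggers the provider check, since the validator is attached to that field):

    from mem0.embeddings.configs import EmbedderConfig

    # "zhipu" now passes validate_config; an unknown provider still raises ValueError
    config = EmbedderConfig(provider="zhipu", config=None)
    print(config.provider)  # -> "zhipu"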

+ 33 - 0
mem0/embeddings/zhipu.py

@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# @Author: privacy
+# @Date:   2024-07-22 14:49:33
+# @Last Modified by:   privacy
+# @Last Modified time: 2024-07-22 15:24:53
+import os
+from zhipuai import ZhipuAI
+
+from mem0.embeddings.base import EmbeddingBase
+
+
+class ZhipuEmbedding(EmbeddingBase):
+    def __init__(self, model="embedding-2"):
+        self.client = ZhipuAI(api_key=os.environ.get("ZHIPU_API_KEY"))
+        self.model = model
+        self.dims = 1024
+
+    def embed(self, text):
+        """
+        Get the embedding for the given text using Zhipu AI.
+
+        Args:
+            text (str): The text to embed.
+
+        Returns:
+            list: The embedding vector.
+        """
+        text = text.replace("\n", " ")
+        return (
+            self.client.embeddings.create(input=[text], model=self.model)
+            .data[0]
+            .embedding
+        )
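
A usage sketch for the new embedder, assuming ZHIPU_API_KEY is exported (the constructor reads it from the environment) and that Zhipu's embedding-2 model returns 1024-dimensional vectors, as the hard-coded dims attribute suggests:

    from mem0.embeddings.zhipu import ZhipuEmbedding

    embedder = ZhipuEmbedding()             # defaults to model="embedding-2"
    vector = embedder.embed("Hello\nmem0")  # newlines are flattened to spaces first
    print(len(vector), embedder.dims)       # expected: 1024 1024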

+ 2 - 2
mem0/llms/configs.py

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator
 
 class LlmConfig(BaseModel):
     provider: str = Field(
-        description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
+        description="Provider of the LLM (e.g., 'ollama', 'openai')", default="zhipu"
     )
     config: Optional[dict] = Field(
         description="Configuration for the specific LLM", default={}
@@ -14,7 +14,7 @@ class LlmConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm"):
+        if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm", "zhipu"):
             return v
         else:
             raise ValueError(f"Unsupported LLM provider: {provider}")
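
The LLM config mirrors the embedder change. A sketch of selecting the new provider (whether a nested "model" key is honored is decided by code outside this diff, so treat that part as an assumption):

    from mem0.llms.configs import LlmConfig

    # passing config explicitly exercises the updated provider whitelist
    llm_config = LlmConfig(provider="zhipu", config={"model": "glm-4"})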

+ 78 - 0
mem0/llms/zhipu.py

@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# @Author: privacy
+# @Date:   2024-07-22 13:34:17
+# @Last Modified by:   privacy
+# @Last Modified time: 2024-07-22 14:10:07
+import os
+import json
+from typing import Dict, List, Optional
+
+try:
+    from zhipuai import ZhipuAI
+except ImportError:
+    raise ImportError("ZhipuAI requires extra dependencies. Install with `pip install zhipuai`") from None
+
+from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
+
+
+class ZhipuLLM(LLMBase):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="glm-4"
+
+        self.client = ZhipuAI(api_key=os.environ.get("ZHIPU_API_KEY"))
+
+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": []
+            }
+            
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append({
+                        "name": tool_call.function.name,
+                        "arguments": json.loads(tool_call.function.arguments)
+                    })
+            
+            return processed_response
+        else:
+            return response.choices[0].message.content
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        params = {
+            "model": self.config.model, 
+            "messages": messages, 
+            "temperature": self.config.temperature, 
+            "max_tokens": self.config.max_tokens, 
+            "top_p": self.config.top_p
+        }
+
+        if response_format:
+            params["response_format"] = response_format
+        if tools:
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice
+
+        response = self.client.chat.completions.create(**params)
+        return self._parse_response(response, tools)
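
Driven directly, the wrapper returns a plain string when no tools are passed and a dict with "content" and "tool_calls" otherwise. A sketch mirroring the config values used in the tests below (it needs ZHIPU_API_KEY and network access, so it is illustrative only):

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.zhipu import ZhipuLLM

    llm = ZhipuLLM(BaseLlmConfig(model="glm-4", temperature=0.7, max_tokens=100, top_p=1.0))
    reply = llm.generate_response([
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
    ])
    print(reply)  # plain string, since no tools were passed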

+ 3 - 1
mem0/utils/factory.py

@@ -17,6 +17,7 @@ class LlmFactory:
         "together": "mem0.llms.together.TogetherLLM",
         "aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM",
         "litellm": "mem0.llms.litellm.LiteLLM",
+        "zhipu": "mem0.llms.zhipu.ZhipuLLM",
     }
 
     @classmethod
@@ -33,7 +34,8 @@ class EmbedderFactory:
     provider_to_class = {
         "openai": "mem0.embeddings.openai.OpenAIEmbedding",
         "ollama": "mem0.embeddings.ollama.OllamaEmbedding",
-        "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding"
+        "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
+        "zhipu": "mem0.embeddings.zhipu.ZhipuEmbedding",
     }
 
     @classmethod
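
This hunk only shows the provider-to-class maps; assuming the create() classmethods (whose bodies sit outside the diff) resolve the dotted paths and instantiate them, the new registrations can be exercised like this (the call signatures here are an assumption):

    from mem0.utils.factory import LlmFactory, EmbedderFactory

    # Assumed signatures: the create() bodies are not part of this diff.
    llm = LlmFactory.create("zhipu", {})        # -> mem0.llms.zhipu.ZhipuLLM
    embedder = EmbedderFactory.create("zhipu")  # -> mem0.embeddings.zhipu.ZhipuEmbedding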

+ 95 - 0
tests/llms/test_zhipu.py

@@ -0,0 +1,95 @@
+# -*- coding: utf-8 -*-
+# @Author: privacy
+# @Date:   2024-07-22 13:54:29
+# @Last Modified by:   privacy
+# @Last Modified time: 2024-07-22 13:57:05
+import pytest
+from unittest.mock import Mock, patch
+from mem0.llms.zhipu import ZhipuLLM
+from mem0.configs.llms.base import BaseLlmConfig
+
+@pytest.fixture
+def mock_zhipu_client():
+    with patch('mem0.llms.zhipu.ZhipuAI') as mock_zhipu:
+        mock_client = Mock()
+        mock_zhipu.return_value = mock_client
+        yield mock_client
+
+
+def test_generate_response_without_tools(mock_zhipu_client):
+    config = BaseLlmConfig(model="glm-4", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = ZhipuLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello, how are you?"}
+    ]
+    
+    mock_response = Mock()
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
+    mock_zhipu_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages)
+
+    mock_zhipu_client.chat.completions.create.assert_called_once_with(
+        model="glm-4",
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
+    )
+    assert response == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_zhipu_client):
+    config = BaseLlmConfig(model="glm-4", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = ZhipuLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "data": {"type": "string", "description": "Data to add to memory"}
+                    },
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+    
+    mock_response = Mock()
+    mock_message = Mock()
+    mock_message.content = "I've added the memory for you."
+    
+    mock_tool_call = Mock()
+    mock_tool_call.function.name = "add_memory"
+    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+    
+    mock_message.tool_calls = [mock_tool_call]
+    mock_response.choices = [Mock(message=mock_message)]
+    mock_zhipu_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
+    mock_zhipu_client.chat.completions.create.assert_called_once_with(
+        model="glm-4",
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
+        tools=tools,
+        tool_choice="auto"
+    )
+    
+    assert response["content"] == "I've added the memory for you."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "add_memory"
+    assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
+