瀏覽代碼

[feat]: add support for sending anonymous user_id in telemetry (#491)

Deshraj Yadav 2 年之前
父節點
當前提交
70df373807
共有 2 個文件被更改,包括 34 次插入18 次删除
  1. 29 1
      embedchain/embedchain.py
  2. 5 17
      tests/embedchain/test_chat.py

+ 29 - 1
embedchain/embedchain.py

@@ -1,9 +1,11 @@
 import hashlib
 import hashlib
 import importlib.metadata
 import importlib.metadata
+import json
 import logging
 import logging
 import os
 import os
 import threading
 import threading
 import uuid
 import uuid
+from pathlib import Path
 from typing import Dict, Optional
 from typing import Dict, Optional
 
 
 import requests
 import requests
@@ -25,6 +27,9 @@ load_dotenv()
 
 
 ABS_PATH = os.getcwd()
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 DB_DIR = os.path.join(ABS_PATH, "db")
+HOME_DIR = str(Path.home())
+CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
+CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
 
 
 
 
 class EmbedChain:
 class EmbedChain:
@@ -46,9 +51,30 @@ class EmbedChain:
 
 
         # Send anonymous telemetry
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
+        self.u_id = self._load_or_generate_user_id()
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
         thread_telemetry.start()
         thread_telemetry.start()
 
 
+    def _load_or_generate_user_id(self):
+        """
+        Loads the user id from the config file if it exists, otherwise generates a new
+        one and saves it to the config file.
+        """
+        if not os.path.exists(CONFIG_DIR):
+            os.makedirs(CONFIG_DIR)
+
+        if os.path.exists(CONFIG_FILE):
+            with open(CONFIG_FILE, "r") as f:
+                data = json.load(f)
+                if "user_id" in data:
+                    return data["user_id"]
+
+        u_id = str(uuid.uuid4())
+        with open(CONFIG_FILE, "w") as f:
+            json.dump({"user_id": u_id}, f)
+
+        return u_id
+
     def add(
     def add(
         self,
         self,
         source,
         source,
@@ -443,9 +469,11 @@ class EmbedChain:
                 "version": importlib.metadata.version(__package__ or __name__),
                 "version": importlib.metadata.version(__package__ or __name__),
                 "method": method,
                 "method": method,
                 "language": "py",
                 "language": "py",
+                "u_id": self.u_id,
             }
             }
             if extra_metadata:
             if extra_metadata:
                 metadata.update(extra_metadata)
                 metadata.update(extra_metadata)
 
 
             response = requests.post(url, json={"metadata": metadata})
             response = requests.post(url, json={"metadata": metadata})
-            response.raise_for_status()
+            if response.status_code != 200:
+                logging.warning(f"Telemetry event failed with status code {response.status_code}")

+ 5 - 17
tests/embedchain/test_chat.py

@@ -7,15 +7,13 @@ from embedchain.config import AppConfig
 
 
 
 
 class TestApp(unittest.TestCase):
 class TestApp(unittest.TestCase):
-    os.environ["OPENAI_API_KEY"] = "test_key"
-
     def setUp(self):
     def setUp(self):
+        os.environ["OPENAI_API_KEY"] = "test_key"
         self.app = App(config=AppConfig(collect_metrics=False))
         self.app = App(config=AppConfig(collect_metrics=False))
 
 
-    @patch("embedchain.embedchain.memory", autospec=True)
     @patch.object(App, "retrieve_from_database", return_value=["Test context"])
     @patch.object(App, "retrieve_from_database", return_value=["Test context"])
     @patch.object(App, "get_answer_from_llm", return_value="Test answer")
     @patch.object(App, "get_answer_from_llm", return_value="Test answer")
-    def test_chat_with_memory(self, mock_answer, mock_retrieve, mock_memory):
+    def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
         """
         """
         This test checks the functionality of the 'chat' method in the App class with respect to the chat history
         This test checks the functionality of the 'chat' method in the App class with respect to the chat history
         memory.
         memory.
@@ -23,27 +21,17 @@ class TestApp(unittest.TestCase):
         The second call is expected to use the chat history from the first call.
         The second call is expected to use the chat history from the first call.
 
 
         Key assumptions tested:
         Key assumptions tested:
-        - After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
             called with correct arguments, adding the correct chat history.
             called with correct arguments, adding the correct chat history.
+        - After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
         - During the second call, the 'chat' method uses the chat history from the first call.
         - During the second call, the 'chat' method uses the chat history from the first call.
 
 
         The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
         The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
         'memory' methods.
         'memory' methods.
         """
         """
-        mock_memory.load_memory_variables.return_value = {"history": []}
         app = App()
         app = App()
-
-        # First call to chat
         first_answer = app.chat("Test query 1")
         first_answer = app.chat("Test query 1")
         self.assertEqual(first_answer, "Test answer")
         self.assertEqual(first_answer, "Test answer")
-        mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 1")
-        mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")
-
-        mock_memory.chat_memory.add_user_message.reset_mock()
-        mock_memory.chat_memory.add_ai_message.reset_mock()
-
-        # Second call to chat
+        self.assertEqual(len(app.memory.chat_memory.messages), 2)
         second_answer = app.chat("Test query 2")
         second_answer = app.chat("Test query 2")
         self.assertEqual(second_answer, "Test answer")
         self.assertEqual(second_answer, "Test answer")
-        mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 2")
-        mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")
+        self.assertEqual(len(app.memory.chat_memory.messages), 4)