Browse Source

Lint and formatting fixes (#554)

Co-authored-by: cachho <admin@ch-webdev.com>
Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
Dev Khant 1 year ago
parent
commit
129242534d

+ 4 - 3
embedchain/bots/__init__.py

@@ -1,4 +1,5 @@
-from embedchain.bots.poe import PoeBot
-from embedchain.bots.whatsapp import WhatsAppBot
+from embedchain.bots.poe import PoeBot  # noqa: F401
+from embedchain.bots.whatsapp import WhatsAppBot  # noqa: F401
+
 # TODO: fix discord import
-# from embedchain.bots.discord import DiscordBot
+# from embedchain.bots.discord import DiscordBot

+ 12 - 5
embedchain/bots/whatsapp.py

@@ -1,4 +1,5 @@
 import argparse
+import importlib
 import logging
 import signal
 import sys
@@ -11,8 +12,14 @@ from .base import BaseBot
 @register_deserializable
 class WhatsAppBot(BaseBot):
     def __init__(self):
-        from flask import Flask, request
-        from twilio.twiml.messaging_response import MessagingResponse
+        try:
+            self.flask = importlib.import_module("flask")
+            self.twilio = importlib.import_module("twilio")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for WhatsApp are not installed. "
+                'Please install with `pip install --upgrade "embedchain[whatsapp]"`'
+            ) from None
         super().__init__()

     def handle_message(self, message):
@@ -41,7 +48,7 @@ class WhatsAppBot(BaseBot):
         return response

     def start(self, host="0.0.0.0", port=5000, debug=True):
-        app = Flask(__name__)
+        app = self.flask.Flask(__name__)

         def signal_handler(sig, frame):
             logging.info("\nGracefully shutting down the WhatsAppBot...")
@@ -51,9 +58,9 @@ class WhatsAppBot(BaseBot):
 
 
         @app.route("/chat", methods=["POST"])
         def chat():
-            incoming_message = request.values.get("Body", "").lower()
+            incoming_message = self.flask.request.values.get("Body", "").lower()
             response = self.handle_message(incoming_message)
-            twilio_response = MessagingResponse()
+            twilio_response = self.twilio.twiml.messaging_response.MessagingResponse()
             twilio_response.message(response)
             return str(twilio_response)


+ 1 - 1
embedchain/embedchain.py

@@ -82,7 +82,7 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()
-        # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event. 
+        # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
         # if (self.config.collect_metrics):
         #     raise ConnectionRefusedError("Collection of metrics should not be allowed.")
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))

+ 1 - 2
embedchain/llm/antrophic_llm.py

@@ -2,9 +2,8 @@ import logging
 from typing import Optional

 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


 @register_deserializable

+ 1 - 2
embedchain/llm/azure_openai_llm.py

@@ -2,9 +2,8 @@ import logging
 from typing import Optional

 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


 @register_deserializable

+ 1 - 1
embedchain/llm/base_llm.py

@@ -3,12 +3,12 @@ from typing import List, Optional
 
 
 from langchain.memory import ConversationBufferMemory
 from langchain.schema import BaseMessage
-from embedchain.helper_classes.json_serializable import JSONSerializable

 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base_llm_config import (
     DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
     DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.helper_classes.json_serializable import JSONSerializable


 class BaseLlm(JSONSerializable):

+ 1 - 2
embedchain/llm/gpt4all_llm.py

@@ -1,9 +1,8 @@
 from typing import Iterable, Optional, Union

 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


 @register_deserializable

+ 1 - 2
embedchain/llm/llama2_llm.py

@@ -4,9 +4,8 @@ from typing import Optional
 from langchain.llms import Replicate

 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


 @register_deserializable

+ 1 - 2
embedchain/llm/openai_llm.py

@@ -3,9 +3,8 @@ from typing import Optional
 import openai

 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


 @register_deserializable

+ 1 - 2
embedchain/llm/vertex_ai_llm.py

@@ -2,9 +2,8 @@ import logging
 from typing import Optional

 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


 @register_deserializable

+ 5 - 6
tests/llm/test_chat.py

@@ -1,7 +1,6 @@
-
 import os
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch

 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
@@ -88,8 +87,8 @@ class TestApp(unittest.TestCase):
 
 
         self.assertEqual(answer, "Test answer")
         _args, kwargs = mock_retrieve.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
         mock_answer.assert_called_once()

     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -120,6 +119,6 @@ class TestApp(unittest.TestCase):
 
 
         self.assertEqual(answer, "Test answer")
         _args, kwargs = mock_database_query.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
         mock_answer.assert_called_once()

+ 4 - 4
tests/llm/test_query.py

@@ -109,8 +109,8 @@ class TestApp(unittest.TestCase):
 
 
         self.assertEqual(answer, "Test answer")
         _args, kwargs = mock_retrieve.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
         mock_answer.assert_called_once()

     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -142,6 +142,6 @@ class TestApp(unittest.TestCase):
 
 
         self.assertEqual(answer, "Test answer")
         _args, kwargs = mock_database_query.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
         mock_answer.assert_called_once()

+ 0 - 2
tests/vectordb/test_chroma_db.py

@@ -7,7 +7,6 @@ from chromadb.config import Settings
 
 
 from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
-from embedchain.models import EmbeddingFunctions, Providers
 from embedchain.vectordb.chroma_db import ChromaDB


@@ -86,7 +85,6 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
         """
         Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
         """
-        config = AppConfig(log_level="DEBUG")

         _app = App(config=AppConfig(collect_metrics=False))