slack.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import argparse
  2. import logging
  3. import os
  4. import signal
  5. import sys
  6. from embedchain import App
  7. from embedchain.helpers.json_serializable import register_deserializable
  8. from .base import BaseBot
  9. try:
  10. from flask import Flask, request
  11. from slack_sdk import WebClient
  12. except ModuleNotFoundError:
  13. raise ModuleNotFoundError(
  14. "The required dependencies for Slack are not installed."
  15. 'Please install with `pip install --upgrade "embedchain[slack]"`'
  16. ) from None
  # Module-level logger, named after this module's import name.
  logger = logging.getLogger(__name__)
  # Slack bot token, read once at import time (None when the variable is unset).
  SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
  @register_deserializable
  class SlackBot(BaseBot):
      """Slack bot answering ``query`` and ``add`` commands with an embedchain ``App``.

      ``start`` runs a Flask server whose ``POST /`` route passes each JSON
      request body to ``handle_message``; replies are posted through the Slack
      ``WebClient`` created in ``__init__``.
      """

      def __init__(self):
          self.client = WebClient(token=SLACK_BOT_TOKEN)
          self.chat_bot = App()
          # Timestamp/channel of the newest message handled so far. Messages whose
          # ``ts`` is not newer are skipped (presumably so redelivered events are
          # not processed twice).
          self.recent_message = {"ts": 0, "channel": ""}
          super().__init__()

      def handle_message(self, event_data):
          """Dispatch the message contained in one event payload to a command.

          Commands are matched by text prefix:
            * ``query <question>`` -- answered via ``App.chat``.
            * ``add <data_type> <url_or_text>`` -- ingested via ``App.add``.

          Args:
              event_data: Request body as a dict; the message is read from its
                  ``"event"`` key.

          Returns:
              None. Replies are posted to the message's channel as a side effect.
          """
          message = event_data.get("event")
          if not message or "text" not in message:
              return
          # Never react to bot-authored messages, including this bot's own replies
          # (which may carry ``bot_id`` without the ``bot_message`` subtype).
          if message.get("subtype") == "bot_message" or message.get("bot_id"):
              return

          text: str = message["text"]
          # Only handle messages newer than the last one handled. A missing ``ts``
          # is treated as 0, so it is skipped instead of raising TypeError.
          if float(message.get("ts", 0)) <= float(self.recent_message["ts"]):
              return
          self.recent_message["ts"] = message["ts"]
          self.recent_message["channel"] = message["channel"]

          if text.startswith("query"):
              self._handle_query(message["channel"], text)
          elif text.startswith("add"):
              self._handle_add(message["channel"], text)

      def _handle_query(self, channel, text):
          """Answer a ``query <question>`` command and post the reply to *channel*."""
          parts = text.split(" ", 1)
          if len(parts) < 2:
              # Previously an unpacking ValueError escaped and the webhook returned 500.
              self.send_slack_message(channel, "Usage: query <question>")
              return
          question = parts[1]
          try:
              response = self.chat_bot.chat(question)
              self.send_slack_message(channel, response)
              logger.info("Query answered successfully!")
          except Exception as e:
              self.send_slack_message(channel, "An error occurred. Please try again!")
              # Use a %s placeholder: passing ``e`` without one breaks log formatting.
              logger.exception("Error occurred during 'query' command: %s", e)

      def _handle_add(self, channel, text):
          """Ingest an ``add <data_type> <url_or_text>`` command and report to *channel*."""
          parts = text.split(" ", 2)
          if len(parts) < 3:
              self.send_slack_message(channel, "Usage: add <data_type> <url_or_text>")
              return
          _, data_type, url_or_text = parts
          # Strip surrounding angle brackets (presumably Slack's ``<url>`` link markup).
          if url_or_text.startswith("<") and url_or_text.endswith(">"):
              url_or_text = url_or_text[1:-1]
          try:
              self.chat_bot.add(url_or_text, data_type)
              self.send_slack_message(channel, f"Added {data_type} : {url_or_text}")
          except ValueError as e:
              self.send_slack_message(channel, f"Error: {str(e)}")
              logger.exception("Error occurred during 'add' command: %s", e)
          except Exception as e:
              self.send_slack_message(channel, f"Failed to add {data_type} : {url_or_text}")
              logger.exception("Error occurred during 'add' command: %s", e)

      def send_slack_message(self, channel, message):
          """Post *message* to *channel* and return the ``chat_postMessage`` response."""
          response = self.client.chat_postMessage(channel=channel, text=message)
          return response

      def start(self, host="0.0.0.0", port=5000, debug=True):
          """Serve the bot over HTTP with Flask until interrupted.

          Args:
              host: Interface to bind.
              port: TCP port to bind.
              debug: Passed straight to ``Flask.run``.
          """
          app = Flask(__name__)

          def signal_handler(sig, frame):
              logger.info("\nGracefully shutting down the SlackBot...")
              sys.exit(0)

          signal.signal(signal.SIGINT, signal_handler)

          # NOTE(review): requests are not authenticated here (no Slack
          # request-signature check), so anyone who can reach this endpoint can
          # trigger ``query``/``add``.
          @app.route("/", methods=["POST"])
          def chat():
              payload = request.json
              # Check if the request is a verification request: echo the challenge.
              if payload.get("challenge"):
                  return str(payload.get("challenge"))
              response = self.handle_message(payload)
              return str(response)

          # NOTE(review): ``debug`` defaults to True while ``host`` defaults to
          # 0.0.0.0; Flask warns against debug mode in production. The default is
          # kept for backward compatibility.
          app.run(host=host, port=port, debug=debug)
  def start_command():
      """Command-line entry point: parse ``--host``/``--port`` and serve a SlackBot."""
      cli = argparse.ArgumentParser(description="EmbedChain SlackBot command line interface")
      cli.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
      cli.add_argument("--port", type=int, default=5000, help="Port to bind")
      options = cli.parse_args()
      # The namespace holds exactly ``host`` and ``port``; forward them as keywords.
      SlackBot().start(**vars(options))


  # Run the CLI only when this file is executed directly, not when imported.
  if __name__ == "__main__":
      start_command()