poe.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import argparse
  2. import logging
  3. import os
  4. from typing import List, Optional
  5. from fastapi_poe import PoeBot, run
  6. from embedchain.config import QueryConfig
  7. from embedchain.helper_classes.json_serializable import register_deserializable
  8. from .base import BaseBot
  9. @register_deserializable
  10. class EcPoeBot(BaseBot, PoeBot):
  11. def __init__(self):
  12. self.history_length = 5
  13. super().__init__()
  14. async def get_response(self, query):
  15. last_message = query.query[-1].content
  16. try:
  17. history = (
  18. [f"{m.role}: {m.content}" for m in query.query[-(self.history_length + 1) : -1]]
  19. if len(query.query) > 0
  20. else None
  21. )
  22. except Exception as e:
  23. logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
  24. logging.warning(history)
  25. answer = self.handle_message(last_message, history)
  26. yield self.text_event(answer)
  27. def handle_message(self, message, history: Optional[List[str]] = None):
  28. if message.startswith("/add "):
  29. response = self.add_data(message)
  30. else:
  31. response = self.ask_bot(message, history)
  32. return response
  33. def add_data(self, message):
  34. data = message.split(" ")[-1]
  35. try:
  36. self.add(data)
  37. response = f"Added data from: {data}"
  38. except Exception:
  39. logging.exception(f"Failed to add data {data}.")
  40. response = "Some error occurred while adding data."
  41. return response
  42. def ask_bot(self, message, history: List[str]):
  43. try:
  44. config = QueryConfig(history=history)
  45. response = self.query(message, config)
  46. except Exception:
  47. logging.exception(f"Failed to query {message}.")
  48. response = "An error occurred. Please try again!"
  49. return response
  50. def start_command():
  51. parser = argparse.ArgumentParser(description="EmbedChain PoeBot command line interface")
  52. # parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
  53. parser.add_argument("--port", default=8080, type=int, help="Port to bind")
  54. parser.add_argument("--api-key", type=str, help="Poe API key")
  55. # parser.add_argument(
  56. # "--history-length",
  57. # default=5,
  58. # type=int,
  59. # help="Set the max size of the chat history. Multiplies cost, but improves conversation awareness.",
  60. # )
  61. args = parser.parse_args()
  62. # FIXME: Arguments are automatically loaded by Poebot's ArgumentParser which causes it to fail.
  63. # the port argument here is also just for show, it actually works because poe has the same argument.
  64. run(EcPoeBot(), api_key=args.api_key or os.environ.get("POE_API_KEY"))
  65. if __name__ == "__main__":
  66. start_command()