poe.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import argparse
  2. import logging
  3. import os
  4. from typing import Optional
  5. from embedchain.helpers.json_serializable import register_deserializable
  6. from .base import BaseBot
  7. try:
  8. from fastapi_poe import PoeBot, run
  9. except ModuleNotFoundError:
  10. raise ModuleNotFoundError(
  11. "The required dependencies for Poe are not installed." "Please install with `pip install fastapi-poe==0.0.16`"
  12. ) from None
  13. def start_command():
  14. parser = argparse.ArgumentParser(description="EmbedChain PoeBot command line interface")
  15. # parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
  16. parser.add_argument("--port", default=8080, type=int, help="Port to bind")
  17. parser.add_argument("--api-key", type=str, help="Poe API key")
  18. # parser.add_argument(
  19. # "--history-length",
  20. # default=5,
  21. # type=int,
  22. # help="Set the max size of the chat history. Multiplies cost, but improves conversation awareness.",
  23. # )
  24. args = parser.parse_args()
  25. # FIXME: Arguments are automatically loaded by Poebot's ArgumentParser which causes it to fail.
  26. # the port argument here is also just for show, it actually works because poe has the same argument.
  27. run(PoeBot(), api_key=args.api_key or os.environ.get("POE_API_KEY"))
  28. @register_deserializable
  29. class PoeBot(BaseBot, PoeBot):
  30. def __init__(self):
  31. self.history_length = 5
  32. super().__init__()
  33. async def get_response(self, query):
  34. last_message = query.query[-1].content
  35. try:
  36. history = (
  37. [f"{m.role}: {m.content}" for m in query.query[-(self.history_length + 1) : -1]]
  38. if len(query.query) > 0
  39. else None
  40. )
  41. except Exception as e:
  42. logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
  43. answer = self.handle_message(last_message, history)
  44. yield self.text_event(answer)
  45. def handle_message(self, message, history: Optional[list[str]] = None):
  46. if message.startswith("/add "):
  47. response = self.add_data(message)
  48. else:
  49. response = self.ask_bot(message, history)
  50. return response
  51. # def add_data(self, message):
  52. # data = message.split(" ")[-1]
  53. # try:
  54. # self.add(data)
  55. # response = f"Added data from: {data}"
  56. # except Exception:
  57. # logging.exception(f"Failed to add data {data}.")
  58. # response = "Some error occurred while adding data."
  59. # return response
  60. def ask_bot(self, message, history: list[str]):
  61. try:
  62. self.app.llm.set_history(history=history)
  63. response = self.query(message)
  64. except Exception:
  65. logging.exception(f"Failed to query {message}.")
  66. response = "An error occurred. Please try again!"
  67. return response
  68. def start(self):
  69. start_command()
  70. if __name__ == "__main__":
  71. start_command()