poe.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import argparse
  2. import logging
  3. import os
  4. from typing import List, Optional
  5. from fastapi_poe import PoeBot, run
  6. from embedchain.config import QueryConfig
  7. from embedchain.helper_classes.json_serializable import register_deserializable
  8. from .base import BaseBot
  9. def start_command():
  10. parser = argparse.ArgumentParser(description="EmbedChain PoeBot command line interface")
  11. # parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
  12. parser.add_argument("--port", default=8080, type=int, help="Port to bind")
  13. parser.add_argument("--api-key", type=str, help="Poe API key")
  14. # parser.add_argument(
  15. # "--history-length",
  16. # default=5,
  17. # type=int,
  18. # help="Set the max size of the chat history. Multiplies cost, but improves conversation awareness.",
  19. # )
  20. args = parser.parse_args()
  21. # FIXME: Arguments are automatically loaded by Poebot's ArgumentParser which causes it to fail.
  22. # the port argument here is also just for show, it actually works because poe has the same argument.
  23. run(PoeBot(), api_key=args.api_key or os.environ.get("POE_API_KEY"))
  24. @register_deserializable
  25. class PoeBot(BaseBot, PoeBot):
  26. def __init__(self):
  27. self.history_length = 5
  28. super().__init__()
  29. async def get_response(self, query):
  30. last_message = query.query[-1].content
  31. try:
  32. history = (
  33. [f"{m.role}: {m.content}" for m in query.query[-(self.history_length + 1) : -1]]
  34. if len(query.query) > 0
  35. else None
  36. )
  37. except Exception as e:
  38. logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
  39. logging.warning(history)
  40. answer = self.handle_message(last_message, history)
  41. yield self.text_event(answer)
  42. def handle_message(self, message, history: Optional[List[str]] = None):
  43. if message.startswith("/add "):
  44. response = self.add_data(message)
  45. else:
  46. response = self.ask_bot(message, history)
  47. return response
  48. # def add_data(self, message):
  49. # data = message.split(" ")[-1]
  50. # try:
  51. # self.add(data)
  52. # response = f"Added data from: {data}"
  53. # except Exception:
  54. # logging.exception(f"Failed to add data {data}.")
  55. # response = "Some error occurred while adding data."
  56. # return response
  57. def ask_bot(self, message, history: List[str]):
  58. try:
  59. config = QueryConfig(history=history)
  60. response = self.query(message, config)
  61. except Exception:
  62. logging.exception(f"Failed to query {message}.")
  63. response = "An error occurred. Please try again!"
  64. return response
  65. def start(self):
  66. start_command()
  67. if __name__ == "__main__":
  68. start_command()