poe.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import argparse
  2. import logging
  3. import os
  4. from typing import List, Optional
  5. from fastapi_poe import PoeBot, run
  6. from embedchain.config import QueryConfig
  7. from .base import BaseBot
  8. class EcPoeBot(BaseBot, PoeBot):
  9. def __init__(self):
  10. self.history_length = 5
  11. super().__init__()
  12. async def get_response(self, query):
  13. last_message = query.query[-1].content
  14. try:
  15. history = (
  16. [f"{m.role}: {m.content}" for m in query.query[-(self.history_length + 1) : -1]]
  17. if len(query.query) > 0
  18. else None
  19. )
  20. except Exception as e:
  21. logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
  22. logging.warning(history)
  23. answer = self.handle_message(last_message, history)
  24. yield self.text_event(answer)
  25. def handle_message(self, message, history: Optional[List[str]] = None):
  26. if message.startswith("/add "):
  27. response = self.add_data(message)
  28. else:
  29. response = self.ask_bot(message, history)
  30. return response
  31. def add_data(self, message):
  32. data = message.split(" ")[-1]
  33. try:
  34. self.add(data)
  35. response = f"Added data from: {data}"
  36. except Exception:
  37. logging.exception(f"Failed to add data {data}.")
  38. response = "Some error occurred while adding data."
  39. return response
  40. def ask_bot(self, message, history: List[str]):
  41. try:
  42. config = QueryConfig(history=history)
  43. response = self.query(message, config)
  44. except Exception:
  45. logging.exception(f"Failed to query {message}.")
  46. response = "An error occurred. Please try again!"
  47. return response
  48. def start_command():
  49. parser = argparse.ArgumentParser(description="EmbedChain PoeBot command line interface")
  50. # parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
  51. parser.add_argument("--port", default=8080, type=int, help="Port to bind")
  52. parser.add_argument("--api-key", type=str, help="Poe API key")
  53. # parser.add_argument(
  54. # "--history-length",
  55. # default=5,
  56. # type=int,
  57. # help="Set the max size of the chat history. Multiplies cost, but improves conversation awareness.",
  58. # )
  59. args = parser.parse_args()
  60. # FIXME: Arguments are automatically loaded by Poebot's ArgumentParser which causes it to fail.
  61. # the port argument here is also just for show, it actually works because poe has the same argument.
  62. run(EcPoeBot(), api_key=args.api_key or os.environ.get("POE_API_KEY"))
  63. if __name__ == "__main__":
  64. start_command()