callbacks.py

import queue
from typing import Any, Iterator, Union

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult

STOP_ITEM = "[END]"
"""
Sentinel item that signals the end of the stream.
"""

class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
    """
    A callback handler that writes tokens to a queue as they are generated,
    so that another thread can yield them to the client.
    For a usage example, see the :func:`generate` function below.
    """

    q: queue.Queue
    """
    The queue to write the tokens to as they are generated.
    """

    def __init__(self, q: queue.Queue) -> None:
        """
        Initialize the callback handler.

        q: The queue to write the tokens to as they are generated.
        """
        super().__init__()
        self.q = q

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
        """Run when LLM starts running."""
        # Drain any stale tokens left over from a previous run; holding the
        # queue's internal mutex makes the clear atomic with respect to
        # concurrent put()/get() calls.
        with self.q.mutex:
            self.q.queue.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.q.put(STOP_ITEM)

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        """Run when LLM errors."""
        # Surface the error to the consumer, then terminate the stream.
        self.q.put(f"{type(error).__name__}: {error}")
        self.q.put(STOP_ITEM)

def generate(rq: queue.Queue) -> Iterator[str]:
    """
    A generator that yields the items in the queue until it reaches the stop item.

    Usage example:
    ```
    def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
        llm = OpenAI(streaming=True, callbacks=[callback_fn])
        return llm.invoke("Write a poem about a tree.")

    @app.route("/", methods=["GET"])
    def generate_output():
        q = Queue()
        callback_fn = StreamingStdOutCallbackHandlerYield(q)
        threading.Thread(target=askQuestion, args=(callback_fn,)).start()
        return Response(generate(q), mimetype="text/event-stream")
    ```
    """
    while True:
        result: str = rq.get()
        if result == STOP_ITEM or result is None:
            break
        yield result
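
For reference, a complete, runnable version of the docstring example might look like the sketch below. It assumes Flask and a langchain version exposing `langchain.llms.OpenAI` are installed and that `OPENAI_API_KEY` is set; the module name `callbacks`, the route, and the prompt are illustrative assumptions, not part of this module.

```
import threading
from queue import Queue

from flask import Flask, Response
from langchain.llms import OpenAI  # or langchain_openai.OpenAI, depending on version

from callbacks import StreamingStdOutCallbackHandlerYield, generate

app = Flask(__name__)


def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield) -> str:
    # streaming=True is required so that on_llm_new_token fires per token.
    llm = OpenAI(streaming=True, callbacks=[callback_fn])
    return llm.invoke("Write a poem about a tree.")


@app.route("/", methods=["GET"])
def generate_output():
    q = Queue()
    callback_fn = StreamingStdOutCallbackHandlerYield(q)
    # Run the LLM call in a background thread so this handler can start
    # streaming items off the queue immediately.
    threading.Thread(target=askQuestion, args=(callback_fn,), daemon=True).start()
    return Response(generate(q), mimetype="text/event-stream")


if __name__ == "__main__":
    app.run(threaded=True)
```

The background thread is the key design point: the callback handler produces tokens on the queue from the worker thread, while `generate(q)` consumes them in the request thread until it sees `STOP_ITEM`.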