Explorar o código

[Feature] Add support for custom streaming callback (#971)

Deshraj Yadav hai 1 ano
pai
achega
f6b80e01a1
Modificáronse 82 ficheiros con 162 adicións e 84 borrados
  1. 1 1
      embedchain/apps/app.py
  2. 2 2
      embedchain/bots/base.py
  3. 1 1
      embedchain/bots/discord.py
  4. 1 1
      embedchain/bots/poe.py
  5. 1 1
      embedchain/bots/slack.py
  6. 1 1
      embedchain/bots/whatsapp.py
  7. 1 1
      embedchain/chunkers/base_chunker.py
  8. 1 1
      embedchain/chunkers/common_chunker.py
  9. 1 1
      embedchain/chunkers/discourse.py
  10. 1 1
      embedchain/chunkers/docs_site.py
  11. 1 1
      embedchain/chunkers/docx_file.py
  12. 1 1
      embedchain/chunkers/gmail.py
  13. 1 1
      embedchain/chunkers/json.py
  14. 1 1
      embedchain/chunkers/mdx.py
  15. 1 1
      embedchain/chunkers/mysql.py
  16. 1 1
      embedchain/chunkers/notion.py
  17. 1 1
      embedchain/chunkers/pdf_file.py
  18. 1 1
      embedchain/chunkers/postgres.py
  19. 1 1
      embedchain/chunkers/qna_pair.py
  20. 1 1
      embedchain/chunkers/sitemap.py
  21. 1 1
      embedchain/chunkers/slack.py
  22. 1 1
      embedchain/chunkers/substack.py
  23. 1 1
      embedchain/chunkers/text.py
  24. 1 1
      embedchain/chunkers/unstructured_file.py
  25. 1 1
      embedchain/chunkers/web_page.py
  26. 1 1
      embedchain/chunkers/xml.py
  27. 1 1
      embedchain/chunkers/youtube_video.py
  28. 1 1
      embedchain/config/add_config.py
  29. 1 1
      embedchain/config/apps/app_config.py
  30. 1 1
      embedchain/config/apps/base_app_config.py
  31. 1 1
      embedchain/config/base_config.py
  32. 1 1
      embedchain/config/embedder/base.py
  33. 6 2
      embedchain/config/llm/base.py
  34. 1 1
      embedchain/config/pipeline_config.py
  35. 1 1
      embedchain/config/vectordb/chroma.py
  36. 1 1
      embedchain/config/vectordb/elasticsearch.py
  37. 1 1
      embedchain/config/vectordb/opensearch.py
  38. 1 1
      embedchain/config/vectordb/pinecone.py
  39. 1 1
      embedchain/config/vectordb/qdrant.py
  40. 1 1
      embedchain/config/vectordb/weaviate.py
  41. 1 1
      embedchain/config/vectordb/zilliz.py
  42. 1 1
      embedchain/data_formatter/data_formatter.py
  43. 1 1
      embedchain/embedchain.py
  44. 2 2
      embedchain/embedder/base.py
  45. 0 0
      embedchain/helpers/__init__.py
  46. 73 0
      embedchain/helpers/callbacks.py
  47. 0 0
      embedchain/helpers/json_serializable.py
  48. 1 1
      embedchain/llm/anthropic.py
  49. 1 1
      embedchain/llm/azure_openai.py
  50. 1 1
      embedchain/llm/base.py
  51. 1 1
      embedchain/llm/cohere.py
  52. 1 1
      embedchain/llm/gpt4all.py
  53. 1 1
      embedchain/llm/huggingface.py
  54. 1 1
      embedchain/llm/jina.py
  55. 1 1
      embedchain/llm/llama2.py
  56. 3 2
      embedchain/llm/openai.py
  57. 1 1
      embedchain/llm/vertex_ai.py
  58. 1 1
      embedchain/loaders/base_loader.py
  59. 1 1
      embedchain/loaders/docs_site_loader.py
  60. 1 1
      embedchain/loaders/docx_file.py
  61. 1 1
      embedchain/loaders/local_qna_pair.py
  62. 1 1
      embedchain/loaders/local_text.py
  63. 1 1
      embedchain/loaders/mdx.py
  64. 1 1
      embedchain/loaders/notion.py
  65. 1 1
      embedchain/loaders/pdf_file.py
  66. 1 1
      embedchain/loaders/sitemap.py
  67. 1 1
      embedchain/loaders/substack.py
  68. 1 1
      embedchain/loaders/unstructured_file.py
  69. 1 1
      embedchain/loaders/web_page.py
  70. 1 1
      embedchain/loaders/xml.py
  71. 1 1
      embedchain/loaders/youtube_video.py
  72. 1 1
      embedchain/memory/message.py
  73. 1 1
      embedchain/pipeline.py
  74. 1 1
      embedchain/vectordb/base.py
  75. 1 1
      embedchain/vectordb/chroma.py
  76. 1 1
      embedchain/vectordb/elasticsearch.py
  77. 1 1
      embedchain/vectordb/opensearch.py
  78. 1 1
      embedchain/vectordb/pinecone.py
  79. 1 1
      embedchain/vectordb/weaviate.py
  80. 1 1
      embedchain/vectordb/zilliz.py
  81. 1 1
      pyproject.toml
  82. 2 2
      tests/helper_classes/test_json_serializable.py

+ 1 - 1
embedchain/apps/app.py

@@ -10,7 +10,7 @@ from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.openai import OpenAILlm
 from embedchain.utils import validate_yaml_config

+ 2 - 2
embedchain/bots/base.py

@@ -3,8 +3,8 @@ from typing import Any
 from embedchain import Pipeline as App
 from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
 from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.helper.json_serializable import (JSONSerializable,
-                                                 register_deserializable)
+from embedchain.helpers.json_serializable import (JSONSerializable,
+                                                  register_deserializable)
 from embedchain.llm.openai import OpenAILlm
 from embedchain.vectordb.chroma import ChromaDB
 

+ 1 - 1
embedchain/bots/discord.py

@@ -2,7 +2,7 @@ import argparse
 import logging
 import os
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 from .base import BaseBot
 

+ 1 - 1
embedchain/bots/poe.py

@@ -3,7 +3,7 @@ import logging
 import os
 from typing import List, Optional
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 from .base import BaseBot
 

+ 1 - 1
embedchain/bots/slack.py

@@ -5,7 +5,7 @@ import signal
 import sys
 
 from embedchain import App
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 from .base import BaseBot
 

+ 1 - 1
embedchain/bots/whatsapp.py

@@ -4,7 +4,7 @@ import logging
 import signal
 import sys
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 from .base import BaseBot
 

+ 1 - 1
embedchain/chunkers/base_chunker.py

@@ -1,6 +1,6 @@
 import hashlib
 
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.models.data_type import DataType
 
 

+ 1 - 1
embedchain/chunkers/common_chunker.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/discourse.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/docs_site.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/docx_file.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/gmail.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/json.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/mdx.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/mysql.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/notion.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/pdf_file.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/postgres.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/qna_pair.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/sitemap.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/slack.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/substack.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/text.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/unstructured_file.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/web_page.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/xml.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/chunkers/youtube_video.py

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.add_config import ChunkerConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/config/add_config.py

@@ -3,7 +3,7 @@ from importlib import import_module
 from typing import Callable, Optional
 
 from embedchain.config.base_config import BaseConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/config/apps/app_config.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 from .base_app_config import BaseAppConfig
 

+ 1 - 1
embedchain/config/apps/base_app_config.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 
 from embedchain.config.base_config import BaseConfig
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.vectordb.base import BaseVectorDB
 
 

+ 1 - 1
embedchain/config/base_config.py

@@ -1,6 +1,6 @@
 from typing import Any, Dict
 
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 
 
 class BaseConfig(JSONSerializable):

+ 1 - 1
embedchain/config/embedder/base.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 6 - 2
embedchain/config/llm/base.py

@@ -1,9 +1,9 @@
 import re
 from string import Template
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from embedchain.config.base_config import BaseConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 DEFAULT_PROMPT = """
   Use the following pieces of context to answer the query at the end.
@@ -68,6 +68,7 @@ class BaseLlmConfig(BaseConfig):
         system_prompt: Optional[str] = None,
         where: Dict[str, Any] = None,
         query_type: Optional[str] = None,
+        callbacks: Optional[List] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -98,6 +99,8 @@ class BaseLlmConfig(BaseConfig):
         :type system_prompt: Optional[str], optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Dict[str, Any], optional
+        :param callbacks: Langchain callback functions to use, defaults to None
+        :type callbacks: Optional[List], optional
         :raises ValueError: If the template is not valid as template should
         contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
@@ -113,6 +116,7 @@ class BaseLlmConfig(BaseConfig):
         self.deployment_name = deployment_name
         self.system_prompt = system_prompt
         self.query_type = query_type
+        self.callbacks = callbacks
 
         if type(template) is str:
             template = Template(template)

+ 1 - 1
embedchain/config/pipeline_config.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 from .apps.base_app_config import BaseAppConfig
 

+ 1 - 1
embedchain/config/vectordb/chroma.py

@@ -1,7 +1,7 @@
 from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/config/vectordb/elasticsearch.py

@@ -2,7 +2,7 @@ import os
 from typing import Dict, List, Optional, Union
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/config/vectordb/opensearch.py

@@ -1,7 +1,7 @@
 from typing import Dict, Optional, Tuple
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/config/vectordb/pinecone.py

@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/config/vectordb/qdrant.py

@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/config/vectordb/weaviate.py

@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/config/vectordb/zilliz.py

@@ -2,7 +2,7 @@ import os
 from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 
 
 @register_deserializable

+ 1 - 1
embedchain/data_formatter/data_formatter.py

@@ -4,7 +4,7 @@ from typing import Any, Dict
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig
 from embedchain.config.add_config import ChunkerConfig, LoaderConfig
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.models.data_type import DataType
 

+ 1 - 1
embedchain/embedchain.py

@@ -13,7 +13,7 @@ from embedchain.config.apps.base_app_config import BaseAppConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.models.data_type import (DataType, DirectDataType,

+ 2 - 2
embedchain/embedder/base.py

@@ -3,12 +3,12 @@ from typing import Any, Callable, Optional
 from embedchain.config.embedder.base import BaseEmbedderConfig
 
 try:
-    from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
+    from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
 except RuntimeError:
     from embedchain.utils import use_pysqlite3
 
     use_pysqlite3()
-    from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
+    from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
 
 
 class EmbeddingFunc(EmbeddingFunction):

+ 0 - 0
embedchain/helpers/__init__.py


+ 73 - 0
embedchain/helpers/callbacks.py

@@ -0,0 +1,73 @@
+import queue
+from typing import Any, Dict, List, Union
+
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.schema import LLMResult
+
+# Sentinel value: the callback handler below puts it on the queue when an LLM run
+# ends or errors, and `generate` stops iterating when it reads it.
+STOP_ITEM = "[END]"
+"""
+This is a special item that is used to signal the end of the stream.
+"""
+
+
+class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
+    """
+    Callback handler that forwards streamed LLM tokens to a queue.
+
+    Every new token is put on ``q`` as it is generated, and ``STOP_ITEM`` is put
+    on ``q`` when the run ends or errors so that a consumer knows when to stop.
+    For a usage example, see the :func:`generate` function below.
+    """
+
+    q: queue.Queue
+    """
+    The queue to write the tokens to as they are generated.
+    """
+
+    def __init__(self, q: queue.Queue) -> None:
+        """
+        Initialize the callback handler.
+
+        :param q: The queue to write the tokens to as they are generated.
+        :type q: queue.Queue
+        """
+        super().__init__()
+        self.q = q
+
+    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+        """Run when LLM starts running: discard any items still in the queue."""
+        # Hold the queue's own lock while emptying its underlying container so
+        # concurrent put()/get() calls cannot interleave with the clear.
+        with self.q.mutex:
+            self.q.queue.clear()
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token: put the token on the queue unchanged."""
+        self.q.put(token)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when LLM ends running: put the stop item on the queue."""
+        self.q.put(STOP_ITEM)
+
+    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
+        """Run when LLM errors: put the error text, then the stop item, on the queue."""
+        # The error is delivered to the consumer as a regular stream item of the
+        # form "<ExceptionClassName>: <message>", followed by the end marker.
+        self.q.put(f"{type(error).__name__}: {error}")
+        self.q.put(STOP_ITEM)
+
+
+def generate(rq: queue.Queue, timeout: Union[float, None] = None):
+    """
+    Generator that yields the items in the queue until it reaches the stop item.
+
+    Iteration ends when an item equal to ``STOP_ITEM`` (or ``None``) is read.
+    Note that a regular item that happens to equal ``STOP_ITEM`` also ends the
+    stream, because items are compared by value.
+
+    :param rq: The queue to read items from.
+    :type rq: queue.Queue
+    :param timeout: Maximum number of seconds to wait for each item. ``None``
+        waits indefinitely, defaults to None
+    :type timeout: Union[float, None], optional
+    :raises TimeoutError: If ``timeout`` is set and no item arrives in time.
+
+    Usage example:
+    ```
+    def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
+        llm = OpenAI(streaming=True, callbacks=[callback_fn])
+        return llm(prompt="Write a poem about a tree.")
+
+    @app.route("/", methods=["GET"])
+    def generate_output():
+        q = queue.Queue()
+        callback_fn = StreamingStdOutCallbackHandlerYield(q)
+        threading.Thread(target=askQuestion, args=(callback_fn,)).start()
+        return Response(generate(q), mimetype="text/event-stream")
+    ```
+    """
+    while True:
+        try:
+            # With timeout=None this blocks until an item is available.
+            result = rq.get(timeout=timeout)
+        except queue.Empty:
+            # Without this, a producer that stops before enqueuing the stop
+            # item would leave the consumer blocked forever.
+            raise TimeoutError(f"No item received from the queue within {timeout} seconds") from None
+        if result is None or result == STOP_ITEM:
+            break
+        yield result

embedchain/helper/json_serializable.py → embedchain/helpers/json_serializable.py


+ 1 - 1
embedchain/llm/anthropic.py

@@ -3,7 +3,7 @@ import os
 from typing import Optional
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 

+ 1 - 1
embedchain/llm/azure_openai.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 

+ 1 - 1
embedchain/llm/base.py

@@ -7,7 +7,7 @@ from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                         DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                         DOCS_SITE_PROMPT_TEMPLATE)
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ECChatMemory
 from embedchain.memory.message import ChatMessage
 

+ 1 - 1
embedchain/llm/cohere.py

@@ -5,7 +5,7 @@ from typing import Optional
 from langchain.llms import Cohere
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 

+ 1 - 1
embedchain/llm/gpt4all.py

@@ -4,7 +4,7 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 

+ 1 - 1
embedchain/llm/huggingface.py

@@ -5,7 +5,7 @@ from typing import Optional
 from langchain.llms import HuggingFaceHub
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 

+ 1 - 1
embedchain/llm/jina.py

@@ -5,7 +5,7 @@ from langchain.chat_models import JinaChat
 from langchain.schema import HumanMessage, SystemMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 

+ 1 - 1
embedchain/llm/llama2.py

@@ -5,7 +5,7 @@ from typing import Optional
 from langchain.llms import Replicate
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 

+ 3 - 2
embedchain/llm/openai.py

@@ -4,7 +4,7 @@ from langchain.chat_models import ChatOpenAI
 from langchain.schema import HumanMessage, SystemMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 
@@ -34,7 +34,8 @@ class OpenAILlm(BaseLlm):
             from langchain.callbacks.streaming_stdout import \
                 StreamingStdOutCallbackHandler
 
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
+            callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
+            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
         else:
             chat = ChatOpenAI(**kwargs)
         return chat(messages).content

+ 1 - 1
embedchain/llm/vertex_ai.py

@@ -3,7 +3,7 @@ import logging
 from typing import Optional
 
 from embedchain.config import BaseLlmConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
 

+ 1 - 1
embedchain/loaders/base_loader.py

@@ -1,4 +1,4 @@
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 
 
 class BaseLoader(JSONSerializable):

+ 1 - 1
embedchain/loaders/docs_site_loader.py

@@ -12,7 +12,7 @@ except ImportError:
     ) from None
 
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 

+ 1 - 1
embedchain/loaders/docx_file.py

@@ -6,7 +6,7 @@ except ImportError:
     raise ImportError(
         'Docx file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
     ) from None
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 

+ 1 - 1
embedchain/loaders/local_qna_pair.py

@@ -1,6 +1,6 @@
 import hashlib
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 

+ 1 - 1
embedchain/loaders/local_text.py

@@ -1,6 +1,6 @@
 import hashlib
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 

+ 1 - 1
embedchain/loaders/mdx.py

@@ -1,6 +1,6 @@
 import hashlib
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 

+ 1 - 1
embedchain/loaders/notion.py

@@ -10,7 +10,7 @@ except ImportError:
     ) from None
 
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 

+ 1 - 1
embedchain/loaders/pdf_file.py

@@ -6,7 +6,7 @@ except ImportError:
     raise ImportError(
         'PDF File requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
     ) from None
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 

+ 1 - 1
embedchain/loaders/sitemap.py

@@ -13,7 +13,7 @@ except ImportError:
         'Sitemap requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
     ) from None
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.web_page import WebPageLoader
 from embedchain.utils import is_readable

+ 1 - 1
embedchain/loaders/substack.py

@@ -4,7 +4,7 @@ import time
 
 import requests
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import is_readable
 

+ 1 - 1
embedchain/loaders/unstructured_file.py

@@ -1,6 +1,6 @@
 import hashlib
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 

+ 1 - 1
embedchain/loaders/web_page.py

@@ -10,7 +10,7 @@ except ImportError:
         'Webpage requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
     ) from None
 
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 

+ 1 - 1
embedchain/loaders/xml.py

@@ -6,7 +6,7 @@ except ImportError:
     raise ImportError(
         'XML file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
     ) from None
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 

+ 1 - 1
embedchain/loaders/youtube_video.py

@@ -6,7 +6,7 @@ except ImportError:
     raise ImportError(
         'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
     ) from None
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 

+ 1 - 1
embedchain/memory/message.py

@@ -1,7 +1,7 @@
 import logging
 from typing import Any, Dict, Optional
 
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 
 
 class BaseMessage(JSONSerializable):

+ 1 - 1
embedchain/pipeline.py

@@ -15,7 +15,7 @@ from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.openai import OpenAILlm
 from embedchain.telemetry.posthog import AnonymousTelemetry

+ 1 - 1
embedchain/vectordb/base.py

@@ -1,6 +1,6 @@
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.embedder.base import BaseEmbedder
-from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.helpers.json_serializable import JSONSerializable
 
 
 class BaseVectorDB(JSONSerializable):

+ 1 - 1
embedchain/vectordb/chroma.py

@@ -6,7 +6,7 @@ from langchain.docstore.document import Document
 from tqdm import tqdm
 
 from embedchain.config import ChromaDbConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
 try:

+ 1 - 1
embedchain/vectordb/elasticsearch.py

@@ -10,7 +10,7 @@ except ImportError:
     ) from None
 
 from embedchain.config import ElasticsearchDBConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
 

+ 1 - 1
embedchain/vectordb/opensearch.py

@@ -16,7 +16,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.vectorstores import OpenSearchVectorSearch
 
 from embedchain.config import OpenSearchDBConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
 

+ 1 - 1
embedchain/vectordb/pinecone.py

@@ -9,7 +9,7 @@ except ImportError:
     ) from None
 
 from embedchain.config.vectordb.pinecone import PineconeDBConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
 

+ 1 - 1
embedchain/vectordb/weaviate.py

@@ -10,7 +10,7 @@ except ImportError:
     ) from None
 
 from embedchain.config.vectordb.weaviate import WeaviateDBConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
 

+ 1 - 1
embedchain/vectordb/zilliz.py

@@ -2,7 +2,7 @@ import logging
 from typing import Dict, List, Optional, Tuple, Union
 
 from embedchain.config import ZillizDBConfig
-from embedchain.helper.json_serializable import register_deserializable
+from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
 try:

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.18"
+version = "0.1.19"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 2 - 2
tests/helper_classes/test_json_serializable.py

@@ -4,8 +4,8 @@ from string import Template
 
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
-from embedchain.helper.json_serializable import (JSONSerializable,
-                                                 register_deserializable)
+from embedchain.helpers.json_serializable import (JSONSerializable,
+                                                  register_deserializable)
 
 
 class TestJsonSerializable(unittest.TestCase):