فهرست منبع

Support for Audio Files (#1416)

Dev Khant 1 سال پیش
والد
کامیت
08b67b4a78

+ 1 - 1
Makefile

@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 
 
 install_es:
 	poetry install --extras elasticsearch

+ 25 - 0
docs/components/data-sources/audio.mdx

@@ -0,0 +1,25 @@
+---
+title: "🎤 Audio"
+---
+
+
+To use an audio as data source, just add `data_type` as `audio` and pass in the path of the audio (local or hosted).
+
+We use [Deepgram](https://developers.deepgram.com/docs/introduction) to transcribe the audiot to text, and then use the generated text as the data source.
+
+You would require an Deepgram API key which is available [here](https://console.deepgram.com/signup?jump=keys) to use this feature.
+
+### Without customization
+
+```python
+import os
+from embedchain import App
+
+os.environ["DEEPGRAM_API_KEY"] = "153xxx"
+
+app = App()
+app.add("introduction.wav", data_type="audio")
+response = app.query("What is my name and how old am I?")
+print(response)
+# Answer: Your name is Dave and you are 21 years old.
+```

+ 2 - 0
docs/components/data-sources/overview.mdx

@@ -9,6 +9,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
   <Card title="CSV file" href="/components/data-sources/csv"></Card>
   <Card title="JSON file" href="/components/data-sources/json"></Card>
   <Card title="Text" href="/components/data-sources/text"></Card>
+  <Card title="Text File" href="/components/data-sources/text-file"></Card>
   <Card title="Directory" href="/components/data-sources/directory"></Card>
   <Card title="Web page" href="/components/data-sources/web-page"></Card>
   <Card title="Youtube Channel" href="/components/data-sources/youtube-channel"></Card>
@@ -33,6 +34,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
   <Card title="Beehiiv" href="/components/data-sources/beehiiv"></Card>
   <Card title="Dropbox" href="/components/data-sources/dropbox"></Card>
   <Card title="Image" href="/components/data-sources/image"></Card>
+  <Card title="Audio" href="/components/data-sources/audio"></Card>
   <Card title="Custom" href="/components/data-sources/custom"></Card>
 </CardGroup>
 

+ 22 - 0
embedchain/chunkers/audio.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class AudioChunker(BaseChunker):
+    """Chunker for audio."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 2 - 0
embedchain/data_formatter/data_formatter.py

@@ -81,6 +81,7 @@ class DataFormatter(JSONSerializable):
             DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
             DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
             DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader",
+            DataType.AUDIO: "embedchain.loaders.audio.AudioLoader",
         }
 
         if data_type == DataType.CUSTOM or loader is not None:
@@ -129,6 +130,7 @@ class DataFormatter(JSONSerializable):
             DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker",
+            DataType.AUDIO: "embedchain.chunkers.audio.AudioChunker",
         }
 
         if chunker is not None:

+ 51 - 0
embedchain/loaders/audio.py

@@ -0,0 +1,51 @@
+import os
+import hashlib
+import validators
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+
+try:
+    from deepgram import DeepgramClient, PrerecordedOptions
+except ImportError:
+    raise ImportError(
+        "Audio file requires extra dependencies. Install with `pip install deepgram-sdk==3.2.7`"
+    ) from None
+
+
+@register_deserializable
+class AudioLoader(BaseLoader):
+    def __init__(self):
+        if not os.environ.get("DEEPGRAM_API_KEY"):
+            raise ValueError("DEEPGRAM_API_KEY is not set")
+
+        DG_KEY = os.environ.get("DEEPGRAM_API_KEY")
+        self.client = DeepgramClient(DG_KEY)
+
+    def load_data(self, url: str):
+        """Load data from a audio file or URL."""
+
+        options = PrerecordedOptions(
+            model="nova-2",
+            smart_format=True,
+        )
+        if validators.url(url):
+            source = {"url": url}
+            response = self.client.listen.prerecorded.v("1").transcribe_url(source, options)
+        else:
+            with open(url, "rb") as audio:
+                source = {"buffer": audio}
+                response = self.client.listen.prerecorded.v("1").transcribe_file(source, options)
+        content = response["results"]["channels"][0]["alternatives"][0]["transcript"]
+
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        metadata = {"url": url}
+
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": metadata,
+                }
+            ],
+        }

+ 2 - 0
embedchain/models/data_type.py

@@ -41,6 +41,7 @@ class IndirectDataType(Enum):
     DROPBOX = "dropbox"
     TEXT_FILE = "text_file"
     EXCEL_FILE = "excel_file"
+    AUDIO = "audio"
 
 
 class SpecialDataType(Enum):
@@ -81,3 +82,4 @@ class DataType(Enum):
     DROPBOX = IndirectDataType.DROPBOX.value
     TEXT_FILE = IndirectDataType.TEXT_FILE.value
     EXCEL_FILE = IndirectDataType.EXCEL_FILE.value
+    AUDIO = IndirectDataType.AUDIO.value

+ 6 - 0
embedchain/utils/misc.py

@@ -237,6 +237,12 @@ def detect_datatype(source: Any) -> DataType:
             logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
 
+        if url.path.endswith(
+            (".mp3", ".mp4", ".mp2", ".aac", ".wav", ".flac", ".pcm", ".m4a", ".ogg", ".opus", ".webm")
+        ):
+            logger.debug(f"Source of `{formatted_source}` detected as `audio`.")
+            return DataType.AUDIO
+
         if url.path.endswith(".yaml"):
             try:
                 response = requests.get(source)

+ 2 - 0
tests/chunkers/test_chunkers.py

@@ -19,6 +19,7 @@ from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.xml import XmlChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
+from embedchain.chunkers.audio import AudioChunker
 from embedchain.config.add_config import ChunkerConfig
 
 chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
@@ -45,6 +46,7 @@ chunker_common_config = {
     CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
     GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     ExcelFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    AudioChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
 
 

+ 98 - 0
tests/loaders/test_audio.py

@@ -0,0 +1,98 @@
+import os
+import sys
+import hashlib
+import pytest
+from unittest.mock import mock_open, patch
+
+if sys.version_info > (3, 10):  # as `match` statement was introduced in python 3.10
+    from deepgram import PrerecordedOptions
+    from embedchain.loaders.audio import AudioLoader
+
+
+@pytest.fixture
+def setup_audio_loader(mocker):
+    mock_dropbox = mocker.patch("deepgram.DeepgramClient")
+    mock_dbx = mocker.MagicMock()
+    mock_dropbox.return_value = mock_dbx
+
+    os.environ["DEEPGRAM_API_KEY"] = "test_key"
+    loader = AudioLoader()
+    loader.client = mock_dbx
+
+    yield loader, mock_dbx
+
+    if "DEEPGRAM_API_KEY" in os.environ:
+        del os.environ["DEEPGRAM_API_KEY"]
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
+)  # as `match` statement was introduced in python 3.10
+def test_initialization(setup_audio_loader):
+    """Test initialization of AudioLoader."""
+    loader, _ = setup_audio_loader
+    assert loader is not None
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
+)  # as `match` statement was introduced in python 3.10
+def test_load_data_from_url(setup_audio_loader):
+    loader, mock_dbx = setup_audio_loader
+    url = "https://example.com/audio.mp3"
+    expected_content = "This is a test audio transcript."
+
+    mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
+    mock_dbx.listen.prerecorded.v.return_value.transcribe_url.return_value = mock_response
+
+    result = loader.load_data(url)
+
+    doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
+    expected_result = {
+        "doc_id": doc_id,
+        "data": [
+            {
+                "content": expected_content,
+                "meta_data": {"url": url},
+            }
+        ],
+    }
+
+    assert result == expected_result
+    mock_dbx.listen.prerecorded.v.assert_called_once_with("1")
+    mock_dbx.listen.prerecorded.v.return_value.transcribe_url.assert_called_once_with(
+        {"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
+    )
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
+)  # as `match` statement was introduced in python 3.10
+def test_load_data_from_file(setup_audio_loader):
+    loader, mock_dbx = setup_audio_loader
+    file_path = "local_audio.mp3"
+    expected_content = "This is a test audio transcript."
+
+    mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
+    mock_dbx.listen.prerecorded.v.return_value.transcribe_file.return_value = mock_response
+
+    # Mock the file reading functionality
+    with patch("builtins.open", mock_open(read_data=b"some data")) as mock_file:
+        result = loader.load_data(file_path)
+
+    doc_id = hashlib.sha256((expected_content + file_path).encode()).hexdigest()
+    expected_result = {
+        "doc_id": doc_id,
+        "data": [
+            {
+                "content": expected_content,
+                "meta_data": {"url": file_path},
+            }
+        ],
+    }
+
+    assert result == expected_result
+    mock_dbx.listen.prerecorded.v.assert_called_once_with("1")
+    mock_dbx.listen.prerecorded.v.return_value.transcribe_file.assert_called_once_with(
+        {"buffer": mock_file.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
+    )