test_audio.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import hashlib
  2. import os
  3. import sys
  4. from unittest.mock import mock_open, patch
  5. import pytest
  6. if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
  7. from deepgram import PrerecordedOptions
  8. from embedchain.loaders.audio import AudioLoader
  9. @pytest.fixture
  10. def setup_audio_loader(mocker):
  11. mock_dropbox = mocker.patch("deepgram.DeepgramClient")
  12. mock_dbx = mocker.MagicMock()
  13. mock_dropbox.return_value = mock_dbx
  14. os.environ["DEEPGRAM_API_KEY"] = "test_key"
  15. loader = AudioLoader()
  16. loader.client = mock_dbx
  17. yield loader, mock_dbx
  18. if "DEEPGRAM_API_KEY" in os.environ:
  19. del os.environ["DEEPGRAM_API_KEY"]
  20. @pytest.mark.skipif(
  21. sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
  22. ) # as `match` statement was introduced in python 3.10
  23. def test_initialization(setup_audio_loader):
  24. """Test initialization of AudioLoader."""
  25. loader, _ = setup_audio_loader
  26. assert loader is not None
  27. @pytest.mark.skipif(
  28. sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
  29. ) # as `match` statement was introduced in python 3.10
  30. def test_load_data_from_url(setup_audio_loader):
  31. loader, mock_dbx = setup_audio_loader
  32. url = "https://example.com/audio.mp3"
  33. expected_content = "This is a test audio transcript."
  34. mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
  35. mock_dbx.listen.prerecorded.v.return_value.transcribe_url.return_value = mock_response
  36. result = loader.load_data(url)
  37. doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
  38. expected_result = {
  39. "doc_id": doc_id,
  40. "data": [
  41. {
  42. "content": expected_content,
  43. "meta_data": {"url": url},
  44. }
  45. ],
  46. }
  47. assert result == expected_result
  48. mock_dbx.listen.prerecorded.v.assert_called_once_with("1")
  49. mock_dbx.listen.prerecorded.v.return_value.transcribe_url.assert_called_once_with(
  50. {"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
  51. )
  52. @pytest.mark.skipif(
  53. sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
  54. ) # as `match` statement was introduced in python 3.10
  55. def test_load_data_from_file(setup_audio_loader):
  56. loader, mock_dbx = setup_audio_loader
  57. file_path = "local_audio.mp3"
  58. expected_content = "This is a test audio transcript."
  59. mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
  60. mock_dbx.listen.prerecorded.v.return_value.transcribe_file.return_value = mock_response
  61. # Mock the file reading functionality
  62. with patch("builtins.open", mock_open(read_data=b"some data")) as mock_file:
  63. result = loader.load_data(file_path)
  64. doc_id = hashlib.sha256((expected_content + file_path).encode()).hexdigest()
  65. expected_result = {
  66. "doc_id": doc_id,
  67. "data": [
  68. {
  69. "content": expected_content,
  70. "meta_data": {"url": file_path},
  71. }
  72. ],
  73. }
  74. assert result == expected_result
  75. mock_dbx.listen.prerecorded.v.assert_called_once_with("1")
  76. mock_dbx.listen.prerecorded.v.return_value.transcribe_file.assert_called_once_with(
  77. {"buffer": mock_file.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
  78. )