Quellcode durchsuchen

Added language detection for non-English YouTube videos (#1362)

Ananto Joyoadikusumo vor 1 Jahr
Ursprung
Commit
4800e0344c
3 geänderte Dateien mit 19 neuen und 18 gelöschten Zeilen
  1. 2 1
      embedchain/llm/openai.py
  2. 16 11
      embedchain/loaders/youtube_video.py
  3. 1 6
      tests/loaders/test_youtube_video.py

+ 2 - 1
embedchain/llm/openai.py

@@ -69,7 +69,8 @@ class OpenAILlm(BaseLlm):
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import convert_to_openai_tool
+        from langchain_core.utils.function_calling import \
+            convert_to_openai_tool
 
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())

+ 16 - 11
embedchain/loaders/youtube_video.py

@@ -8,6 +8,7 @@ except ImportError:
     raise ImportError('YouTube video requires extra dependencies. Install with `pip install youtube-transcript-api "`')
 try:
     from langchain_community.document_loaders import YoutubeLoader
+    from langchain_community.document_loaders.youtube import _parse_video_id
 except ImportError:
     raise ImportError(
         'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
@@ -21,7 +22,20 @@ from embedchain.utils.misc import clean_string
 class YoutubeVideoLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a Youtube video."""
-        loader = YoutubeLoader.from_youtube_url(url, add_video_info=True)
+        video_id = _parse_video_id(url)
+
+        languages = ["en"]
+        try:
+            # Fetching transcript data
+            languages = [transcript.language_code for transcript in YouTubeTranscriptApi.list_transcripts(video_id)]
+            transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=languages)
+            # convert transcript to json to avoid unicode symboles
+            transcript = json.dumps(transcript, ensure_ascii=True)
+        except Exception:
+            logging.exception(f"Failed to fetch transcript for video {url}")
+            transcript = "Unavailable"
+
+        loader = YoutubeLoader.from_youtube_url(url, add_video_info=True, language=languages)
         doc = loader.load()
         output = []
         if not len(doc):
@@ -30,16 +44,7 @@ class YoutubeVideoLoader(BaseLoader):
         content = clean_string(content)
         metadata = doc[0].metadata
         metadata["url"] = url
-
-        video_id = url.split("v=")[1].split("&")[0]
-        try:
-            # Fetching transcript data
-            transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=["en"])
-            # convert transcript to json to avoid unicode symboles
-            metadata["transcript"] = json.dumps(transcript, ensure_ascii=True)
-        except Exception:
-            logging.exception(f"Failed to fetch transcript for video {url}")
-            metadata["transcript"] = "Unavailable"
+        metadata["transcript"] = transcript
 
         output.append(
             {

+ 1 - 6
tests/loaders/test_youtube_video.py

@@ -1,5 +1,4 @@
 import hashlib
-import json
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
@@ -37,11 +36,7 @@ def test_load_data(youtube_video_loader):
     expected_data = [
         {
             "content": "This is a YouTube video content.",
-            "meta_data": {
-                "url": video_url,
-                "title": "Test Video",
-                "transcript": json.dumps(mock_transcript, ensure_ascii=True),
-            },
+            "meta_data": {"url": video_url, "title": "Test Video", "transcript": "Unavailable"},
         }
     ]