# YoutubeTranscriptTool / transcription.py
import os
import subprocess
import tempfile
import yt_dlp
import torch
from transformers import pipeline
from logging_config import logger, log_buffer
device = "cuda" if torch.cuda.is_available() else "cpu"


def convert_audio_to_wav(input_file: str, output_file: str) -> str:
    logger.info(f"Converting {input_file} to WAV: {output_file}")
    cmd = [
        "ffmpeg",
        "-y",
        "-i", input_file,
        "-ar", "16000",  # 16 kHz sample rate
        "-ac", "1",      # mono
        output_file
    ]
    subprocess.run(cmd, check=True)
    return output_file


def fallback_whisper_transcription(youtube_url: str):
    """Generator yielding (transcript, logs) tuples; the final yield carries the full transcription."""
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a temporary working directory
            logger.info("")
            logger.info(f"Created temporary directory: {tmpdir}")
            logger.info("")
            yield "", log_buffer.getvalue()

            # Download the best available audio stream
            logger.info("Downloading best audio via yt-dlp...")
            logger.info("")
            yield "", log_buffer.getvalue()
            download_path = os.path.join(tmpdir, "audio.%(ext)s")
            ydl_opts = {
                'format': 'bestaudio/best',
                'outtmpl': download_path,
                'quiet': True,
                'postprocessors': []
            }
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([youtube_url])
            logger.info("Audio downloaded. Locating the audio file in the temp folder...")
            logger.info("")
            yield "", log_buffer.getvalue()

            # Confirm an audio file was downloaded
            downloaded_files = os.listdir(tmpdir)
            if not downloaded_files:
                raise RuntimeError("No audio file was downloaded via yt-dlp.")
            audio_file_path = os.path.join(tmpdir, downloaded_files[0])
            logger.info(f"Found audio file: {audio_file_path}")
            logger.info("Video has downloaded!")
            logger.info("")
            yield "", log_buffer.getvalue()

            # Convert to WAV (16 kHz mono) for the ASR pipeline
            wav_file_path = os.path.join(tmpdir, "audio.wav")
            convert_audio_to_wav(audio_file_path, wav_file_path)
            logger.info("Audio converted to WAV successfully.")
            logger.info("")
            yield "", log_buffer.getvalue()

            # Run Whisper
            logger.info("Running Whisper ASR pipeline on the WAV file...")
            logger.info("")
            yield "", log_buffer.getvalue()
            asr_pipeline = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-small",
                return_timestamps=True,
                device=device,
                generate_kwargs={"task": "transcribe", "language": "<|en|>"}
            )
            result = asr_pipeline(wav_file_path)
            transcription = result["text"]
            logger.info("Whisper transcription completed successfully.")
            logger.info("")
            yield transcription, log_buffer.getvalue()
    except Exception as e:
        err_msg = f"Error in fallback transcription: {str(e)}"
        logger.error(err_msg)
        yield err_msg, log_buffer.getvalue()
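

# Minimal usage sketch: fallback_whisper_transcription is a generator, so callers iterate it
# and read the latest (transcript, logs) pair from each yield. The URL below is a hypothetical
# placeholder chosen for illustration, not a value taken from this project.
if __name__ == "__main__":
    example_url = "https://www.youtube.com/watch?v=EXAMPLE_VIDEO_ID"  # placeholder URL
    transcript = ""
    for transcript, logs in fallback_whisper_transcription(example_url):
        print(logs)       # incremental log output after each stage
    print(transcript)     # final yield holds the full transcription (or an error message)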