# Hugging Face Space page residue (status at capture time: Running)
import html
import os
import re
import shutil  # Import shutil for explicit temporary directory cleanup
import tempfile
from pathlib import Path
from urllib.parse import urlparse

import gradio as gr
import requests
import soundfile as sf  # Import soundfile for explicit audio loading
import torch
from gradio import Progress
from moviepy.editor import VideoFileClip
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
# ---------------------------------------------------------------------------
# Model setup: everything below is loaded once, when the module is imported.
# ---------------------------------------------------------------------------

# Accent classifier: an audio-classification pipeline for English accents.
pipe = pipeline(
    "audio-classification",
    model="dima806/english_accents_classification",
)

# Language detector: a text-classification pipeline, fed with the transcript
# produced by the ASR model below.
language_detector = pipeline(
    "text-classification",
    model="alexneakameni/language_detection",
)

# Device index handed to the ASR pipeline: 0 (first CUDA GPU) when CUDA is
# available, otherwise -1 (CPU).
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Small, English-focused Whisper checkpoint used to transcribe a short clip so
# that the language detector has text to work on.
asr_model_id = "openai/whisper-tiny.en"
asr_processor = AutoProcessor.from_pretrained(asr_model_id)
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(asr_model_id)

# ASR pipeline assembled from the explicitly loaded model and processor parts.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=asr_processor.tokenizer,
    feature_extractor=asr_processor.feature_extractor,
    device=device,
)
def is_valid_url(url):
    """
    Checks whether a URL points at a video source this app accepts.

    A URL is accepted when it uses http(s), has a host, and either:
      * its path ends with ``.mp4`` (any host), or
      * its host is ``drive.google.com`` (or a subdomain) and the path looks
        like a shared-file or download link (``/file/d/`` or ``/uc``), or
      * its host is ``loom.com`` or a subdomain such as ``cdn.loom.com``.

    Hosts are compared exactly or as a dot-separated suffix, never as a
    substring, so look-alike hosts such as ``evil-loom.com`` or
    ``loom.com.attacker.example`` are rejected.

    Args:
        url (str): The URL to validate.
    Returns:
        bool: True if the URL is valid and allowed, False otherwise.
    """
    if not url or not isinstance(url, str):
        return False

    try:
        parsed = urlparse(url)
        host = parsed.hostname  # lower-cased, without port or credentials
    except ValueError:
        # urlparse rejects some malformed netlocs (e.g. broken IPv6 literals).
        return False

    # The downloader uses `requests`, which only fetches http/https URLs.
    if parsed.scheme not in ("http", "https") or not host:
        return False

    host = host.rstrip(".")  # tolerate a fully-qualified trailing dot

    def host_is(domain):
        """True when `host` is `domain` itself or one of its subdomains."""
        return host == domain or host.endswith("." + domain)

    # Direct .mp4 links are accepted from any host. This also covers the
    # Dropbox hosts, which were only ever accepted for .mp4 paths.
    if parsed.path.lower().endswith(".mp4"):
        return True

    if host_is("drive.google.com"):
        # Typical Google Drive patterns for shared files or download links.
        return "/file/d/" in parsed.path or "/uc" in parsed.path

    # Loom share/CDN URLs are allowed even without an .mp4 suffix.
    return host_is("loom.com")
def is_valid_file(file_obj):
    """
    Tells whether an uploaded file has one of the supported video extensions.

    Args:
        file_obj (gr.File): The Gradio file object; its ``name`` attribute is
            read as the file path.
    Returns:
        bool: True for .mp4, .mov, .avi or .mkv (case-insensitive); False for
        anything else or when no file object is given.
    """
    if file_obj:
        # Only the extension of the uploaded file's path is inspected.
        extension = Path(file_obj.name).suffix.lower()
        supported_extensions = ['.mp4', '.mov', '.avi', '.mkv']
        return extension in supported_extensions
    return False
def download_file(url, save_path, progress=Progress()):
    """
    Downloads a video file from a given URL to a specified path.

    The response is streamed to disk in chunks and is always closed, even
    when the download fails part-way.

    Args:
        url (str): The URL of the video to download.
        save_path (str): The local path to save the downloaded video.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Raises:
        ValueError: If the URL is rejected by ``is_valid_url``.
        ConnectionError: If the server answers with a non-200 status, or the
            request fails or times out at the network level.
    """
    if not is_valid_url(url):
        raise ValueError("Invalid URL. Only .mp4 files or Loom videos are accepted.")

    try:
        # (connect timeout, read timeout) so a dead server cannot hang the app.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Check if the download was successful (HTTP status code 200)
            if response.status_code != 200:
                raise ConnectionError(f"Failed to download video (HTTP {response.status_code})")

            # Total size for progress tracking; treat a missing or malformed
            # Content-Length header as "unknown".
            try:
                total_size = int(response.headers.get('content-length', 0))
            except (TypeError, ValueError):
                total_size = 0

            downloaded = 0
            with open(save_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:  # Filter out keep-alive chunks
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        # Clamp: decoded bytes can exceed Content-Length when
                        # the server compresses the body.
                        progress(min(downloaded / total_size, 1.0), desc="π₯ Downloading video...")
                    else:
                        # If total size is unknown, just show a general message
                        progress(0, desc="π₯ Downloading video (size unknown)...")
    except requests.RequestException as exc:
        # Surface transport-level failures as the documented exception type.
        raise ConnectionError(f"Failed to download video: {exc}") from exc
def extract_audio_full(video_path, progress=Progress()):
    """
    Extracts the full duration of audio from a video file and saves it as a WAV file.

    The WAV file is created with ``tempfile.NamedTemporaryFile(delete=False)``
    so it still exists after this function returns. If extraction fails, the
    partially written file is removed again.

    Args:
        video_path (str): Path to the input video file.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        str: The path to the extracted audio file.
    Raises:
        Exception: If the video cannot be opened, has no audio track, or the
            audio cannot be written. The original error is chained as the cause.
    """
    video = None
    audio_clip = None
    audio_path = None
    try:
        progress(0, desc="π Extracting full audio for playback...")
        video = VideoFileClip(video_path)
        audio_clip = video.audio
        if audio_clip is None:
            raise ValueError("the video has no audio track")

        # Reserve a temp file name, then close the handle immediately so
        # moviepy can write to that path itself.
        temp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        audio_path = temp_audio_file.name
        temp_audio_file.close()

        audio_clip.write_audiofile(audio_path, fps=16000, logger=None)
        progress(1.0)
        return audio_path
    except Exception as e:
        # Do not leave a half-written temp file behind on failure.
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError:
                pass
        raise Exception(f"Full audio extraction failed: {str(e)}") from e
    finally:
        # Release the clips on success and on failure alike.
        if video is not None:
            video.close()
        if audio_clip is not None:
            audio_clip.close()
def extract_audio_clip(video_path, audio_path, duration, progress=Progress()):
    """
    Extracts a specified duration of audio from a video file and saves it as a WAV file.

    The clip starts at 0 seconds and is at most ``duration`` seconds long; it
    is shortened to the video's own duration when the video is shorter.

    Args:
        video_path (str): Path to the input video file.
        audio_path (str): Path to save the extracted audio WAV file.
        duration (int): The duration of audio to extract in seconds.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        str: The path to the extracted audio file.
    Raises:
        Exception: If the video cannot be opened, has no audio track, or the
            audio cannot be written. The original error is chained as the cause.
    """
    video = None
    audio_clip = None
    try:
        progress(0, desc=f"π Extracting {duration} seconds of audio for analysis...")
        video = VideoFileClip(video_path)
        if video.audio is None:
            raise ValueError("the video has no audio track")

        # Ensure the subclip duration does not exceed the video's actual duration
        clip_duration = min(duration, video.duration)
        audio_clip = video.audio.subclip(0, clip_duration)
        audio_clip.write_audiofile(audio_path, fps=16000, logger=None)
        progress(1.0)
        return audio_path
    except Exception as e:
        raise Exception(f"Audio clip extraction failed: {str(e)}") from e
    finally:
        # Release the clips on success and on failure alike.
        if video is not None:
            video.close()
        if audio_clip is not None:
            audio_clip.close()
def transcribe_audio(audio_path_clip, progress=Progress()):
    """
    Produces a text transcript of a short audio clip with the Whisper model.

    Failures are not propagated: they are printed and an empty string is
    returned, which callers treat as "could not transcribe".

    Args:
        audio_path_clip (str): Path to the short audio clip.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        str: The transcribed text, or "" when transcription fails.
    """
    try:
        progress(0, desc="π Transcribing audio for language detection...")

        # Read the samples and their sampling rate from disk.
        waveform, sample_rate = sf.read(audio_path_clip)

        # Collapse multi-channel audio to a single channel by averaging.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)

        # Turn the raw samples into PyTorch model inputs; the file's own
        # sampling rate is passed along to the processor.
        model_inputs = asr_processor(waveform, sampling_rate=sample_rate, return_tensors="pt")

        # When a GPU index is configured, move every input tensor onto it.
        if device != -1:
            model_inputs = {name: tensor.to(device) for name, tensor in model_inputs.items()}

        # Inference only, so gradient tracking is switched off. The output is
        # capped at 128 new tokens, which is meant for short clips.
        with torch.no_grad():
            generated_ids = asr_model.generate(**model_inputs, max_new_tokens=128)

        decoded = asr_processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        transcript = decoded[0]
        progress(1.0)
        return transcript
    except Exception as exc:
        print(f"Transcription failed: {exc}")
        return ""  # Empty transcript signals failure to the caller
def classify_audio(audio_path, progress=Progress()):
    """
    Runs the pre-loaded accent classification pipeline on an audio file.

    Args:
        audio_path (str): Path to the input audio file.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        list: A list of dictionaries containing accent labels and confidence scores.
    Raises:
        Exception: Any failure, re-raised with a "Classification failed" prefix.
    """
    try:
        progress(0, desc="π Analyzing accent - please be patient...")
        predictions = pipe(audio_path)
        progress(1.0)  # Classification finished
    except Exception as exc:
        raise Exception(f"Classification failed: {str(exc)}")
    return predictions
def _accent_row_html(rank, label, score, highlight):
    """Renders one <tr> of the accent table; `label` must already be HTML-escaped."""
    if highlight:
        # The top prediction gets a green, bold row.
        row_style = "background-color:#d4edda; font-weight: bold; color: #155724;"
        border_color = "#c3e6cb"
    else:
        row_style = "color: #333;"
        border_color = "#ddd"
    cell_style = f"padding: 8px; border-bottom: 1px solid {border_color}; width: auto;"
    return f"""
    <tr style='{row_style}'>
      <td style='{cell_style} min-width: 50px;'>#{rank}</td>
      <td style='{cell_style} min-width: 100px;'>{label}</td>
      <td style='{cell_style} min-width: 180px;'>
        <div style='display: flex; align-items: center;'>
          <span style='width: auto; display: inline-block;'>{score * 100:.2f}%</span>
          <progress value='{score * 100}' max='100' style='width: 100%; margin-left: 10px;'></progress>
        </div>
      </td>
      <td style='{cell_style} min-width: 80px;'>
        <span style='width: auto; display: inline-block;'>{score:.4f}</span>
      </td>
    </tr>
    """


def _accent_table_html(result):
    """Renders the ranked accent predictions as an HTML table."""
    header_cell = "text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto;"
    parts = [f"""
    <table style='width: fit-content; max-width: 100%; border-collapse: collapse; font-family: Arial, sans-serif; margin-top: 1em;'>
      <thead>
        <tr style='border-bottom: 2px solid #4CAF50; background-color: #f2f2f2;'>
          <th style='{header_cell} min-width: 50px;'>Rank</th>
          <th style='{header_cell} min-width: 100px;'>Accent</th>
          <th style='{header_cell} min-width: 180px;'>Confidence (%)</th>
          <th style='{header_cell} min-width: 80px;'>Score</th>
        </tr>
      </thead>
      <tbody>
    """]
    for i, r in enumerate(result):
        label = html.escape(r['label'].capitalize())
        parts.append(_accent_row_html(i + 1, label, r['score'], highlight=(i == 0)))
    parts.append("</tbody></table>")
    return "".join(parts)


def process_video_unified(video_source, analysis_duration, progress=Progress()):
    """
    Processes either a video URL or an uploaded video file to classify accent.

    Steps: fetch or copy the video into a temporary directory, extract the
    full audio for playback, transcribe the first 15 seconds, run language
    detection on the transcript, and, unless a non-English language is
    detected, classify the accent on the first ``analysis_duration`` seconds.
    An empty transcript yields a warning but classification still runs.

    Errors are never raised to the caller; they are reported through the
    returned HTML and the error flag. All model- and exception-derived text is
    HTML-escaped before being embedded in the output.

    Args:
        video_source (str or gr.File): The input, either a URL string or a Gradio File object.
        analysis_duration (int): The duration of audio to analyze for accent classification in seconds.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        tuple: (language_status_html, html_output, audio_path, error_flag)
            language_status_html (str): HTML string displaying language detection status.
            html_output (str): HTML string displaying accent results or error.
            audio_path (str or None): Path to extracted full audio if successful, else None.
            error_flag (bool): True if an error occurred, False otherwise.
    """
    temp_dir = None
    full_audio_path = None
    try:
        # Temp dir for intermediate files (video, clipped audio); removed in `finally`.
        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, "video.mp4")

        # Determine if input is a URL string or an uploaded Gradio File object
        if isinstance(video_source, str) and video_source.startswith(('http://', 'https://')):
            if not is_valid_url(video_source):
                raise ValueError("Invalid URL. Only .mp4 files or Loom videos are accepted.")
            download_file(video_source, video_path, progress)
        elif hasattr(video_source, 'name'):
            if not is_valid_file(video_source):
                raise ValueError("Invalid file format. Please upload a video file (MP4)")
            # Copy in chunks rather than reading the whole upload into memory.
            shutil.copyfile(video_source.name, video_path)
        else:
            raise ValueError("Unsupported input type. Please provide a video URL or upload a file.")

        # Verify that the video file exists after download/upload
        if not os.path.exists(video_path):
            raise Exception("Video processing failed: Video file not found after download/upload.")

        # Full audio for playback; lives outside temp_dir so it survives the cleanup below.
        full_audio_path = extract_audio_full(video_path, progress)

        # Short clip for transcription and language detection (first 15 seconds)
        transcription_clip_duration = 15
        audio_for_transcription_path = os.path.join(temp_dir, "audio_for_transcription.wav")
        extract_audio_clip(video_path, audio_for_transcription_path, transcription_clip_duration, progress)

        if not os.path.exists(full_audio_path):
            raise Exception("Audio extraction failed: Full audio file not found.")
        if not os.path.exists(audio_for_transcription_path):
            raise Exception("Audio extraction failed: Clipped audio for transcription not found.")

        # Transcribe the short audio clip
        transcribed_text = transcribe_audio(audio_for_transcription_path, progress)

        if not transcribed_text.strip():
            # Without a transcript the language cannot be detected; warn and
            # continue with accent classification anyway.
            language_status_html = "<p style='color: orange; font-weight: bold;'>β οΈ Could not transcribe audio for language detection. Please ensure audio is clear.</p>"
        else:
            lang_detection_result = language_detector(transcribed_text)
            detected_language = lang_detection_result[0]['label']
            lang_confidence = lang_detection_result[0]['score']

            # Accept either label spelling for English, with a reasonable confidence
            if detected_language.lower() in ('english', 'eng_latn') and lang_confidence > 0.7:
                language_status_html = f"<p style='color: green; font-weight: bold;'>β Verified English Language (Confidence: {lang_confidence*100:.2f}%)</p>"
            else:
                safe_language = html.escape(detected_language.capitalize())
                language_status_html = f"<p style='color: red; font-weight: bold;'>β οΈ Detected language: {safe_language} (Confidence: {lang_confidence*100:.2f}%). Please provide English audio for accent classification.</p>"
                # Not English: skip accent classification but still return the audio.
                return language_status_html, "", full_audio_path, True

        # Extract audio clip for accent classification (based on analysis_duration slider)
        audio_for_classification_path = os.path.join(temp_dir, "audio_for_classification.wav")
        extract_audio_clip(video_path, audio_for_classification_path, analysis_duration, progress)
        if not os.path.exists(audio_for_classification_path):
            raise Exception("Audio extraction failed: Clipped audio for classification not found.")

        # Classify the extracted audio for accent
        result = classify_audio(audio_for_classification_path, progress)
        if not result:
            return language_status_html, "<p style='color: red; font-weight: bold;'>β οΈ No accent prediction returned</p>", full_audio_path, True

        table = _accent_table_html(result)
        top_result = result[0]
        top_label = html.escape(top_result['label'].capitalize())
        html_output = f"""
        <div style='font-family: Arial, sans-serif;'>
          <h2 style='color: #2E7D32; margin-bottom: 0.5em;'>
            π€ Predicted Accent: <span style='font-weight:bold'>{top_label}</span>
            <span style='font-size: 0.8em; color: #555; font-weight: normal;'>
              (Confidence: {top_result['score']*100:.2f}%)
            </span>
          </h2>
          {table}
        </div>
        """
        # Return language status, accent results HTML, full audio path, and no error flag
        return language_status_html, html_output, full_audio_path, False
    except Exception as e:
        # No audio path is returned on this route, so remove the playback
        # file (if it was already created) instead of leaking it.
        if full_audio_path and os.path.exists(full_audio_path):
            try:
                os.remove(full_audio_path)
            except OSError:
                pass
        # Escape the message: it may contain paths or markup-like characters.
        return "", f"<p style='color: red; font-weight: bold;'>β οΈ Error: {html.escape(str(e))}</p>", None, True
    finally:
        # Explicitly clean up the temporary directory created for intermediate files.
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
# Custom look for the app: start from Gradio's default theme and override a
# handful of colour variables, collected here as keyword arguments for `.set()`.
_theme_overrides = {
    # Page background (light grey) and inner block background (white)
    "background_fill_primary": "#f0f2f5",
    "background_fill_secondary": "#ffffff",
    # Subtle border colour for a cleaner look
    "border_color_primary": "#e0e0e0",
    # Primary buttons: green fill, lighter green on hover, white text
    "button_primary_background_fill": "#4CAF50",
    "button_primary_background_fill_hover": "#66BB6A",
    "button_primary_text_color": "#ffffff",
    # Secondary buttons use the same green scheme as primary buttons
    "button_secondary_background_fill": "#4CAF50",
    "button_secondary_background_fill_hover": "#66BB6A",
    "button_secondary_text_color": "#ffffff",
    # Accent colours (blue and a lighter blue) for elements such as sliders
    "color_accent": "#2196F3",
    "color_accent_soft": "#BBDEFB",
}
my_theme = gr.themes.Default().set(**_theme_overrides)
# Gradio app interface definition: layout, event handlers and their wiring.
with gr.Blocks(theme=my_theme) as app:  # Apply the custom theme here
    # The Markdown body is kept flush-left so that no line is indented far
    # enough to be treated as a Markdown code block.
    gr.Markdown("""
<div style='font-family: Arial, sans-serif;'>
<h1 style='color: #2E7D32;'>π€ English Accent Classifier</h1>
<p>Analyze English accents from either:</p>
<ul>
<li>A video URL (MP4 or Loom videos)</li>
<li>Or upload a video file from your computer</li>
</ul>
<p>The accent analysis will be performed on the first <strong>60 seconds</strong> of audio by default, after language detection.</p>
<p>The analysis may take some time depending on the video size and your chosen analysis duration. Please be patient while we process your video.</p>
<p><strong>Supported file formats:</strong> MP4 </p>
<p style='font-size: 0.9em; color: #666;'>
<strong>Note:</strong> This application requires <a href='https://ffmpeg.org/download.html' target='_blank' style='color: #2E7D32;'>FFmpeg</a> to be installed on your system to process video and audio files.
</p>
</div>
""")

    # --- Inputs -----------------------------------------------------------
    with gr.Row():
        with gr.Column(scale=1):
            url_input = gr.Textbox(
                label="π Video URL (MP4 or Loom)",
                placeholder="Paste URL here..."
            )
            video_input = gr.File(
                label="π Upload Video File",
                file_types=["video"],
                interactive=True
            )
        with gr.Column(scale=1):
            analysis_duration = gr.Slider(
                minimum=5,
                maximum=120,
                step=5,
                value=60,
                label="Accent Analysis Duration (seconds)",
                info="Analyze the first N seconds of audio for accent classification."
            )

    with gr.Row():
        submit_btn = gr.Button("Analyze Video", variant="primary")
        clear_btn = gr.Button("Clear Input")

    # --- Status -----------------------------------------------------------
    status_box = gr.Textbox(
        label="Status",
        placeholder="Waiting for video input...",
        interactive=False,
        visible=True
    )
    progress_bar = gr.Slider(
        visible=False,
        label="Processing Progress",
        interactive=False
    )

    # --- Outputs ----------------------------------------------------------
    # Outputs share one row: language status and audio player in one column,
    # the results table and error output in the other.
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            language_status_html = gr.HTML(label="Language Detection Status", visible=True)
            audio_player = gr.Audio(label="Extracted Audio (Full Duration)", visible=True)
        with gr.Column(scale=2, min_width=400):
            output_html = gr.HTML()
            error_output = gr.HTML(visible=False)

    def unified_processing_fn(video_url, video_file, analysis_duration, progress=Progress()):
        """
        Click handler for "Analyze Video".

        Yields UI updates twice: first a "processing started" state, then the
        final state (success, reported failure, or unexpected error). Each
        yielded tuple matches the order of the `outputs` list wired below:
        status box, progress bar, language status, results HTML, audio
        player, error HTML.
        """
        # A non-blank URL takes precedence over an uploaded file. Stripping
        # keeps a whitespace-only URL box from masking the upload.
        cleaned_url = video_url.strip() if isinstance(video_url, str) else video_url
        video_source = cleaned_url if cleaned_url else video_file

        yield (
            gr.Textbox(value="β³ Processing started - please be patient...", visible=True),
            gr.Slider(visible=True, value=0),
            gr.HTML(value="", visible=True),  # Clear language status
            gr.HTML(value="", visible=False),  # Hide previous HTML output
            gr.Audio(value=None, visible=True, label="Extracted Audio (Full Duration)"),
            gr.HTML(value="", visible=False)  # Hide previous error output
        )
        try:
            lang_status, result_html, audio_path, error = process_video_unified(
                video_source, analysis_duration, progress
            )
            if error:
                # On failure the returned HTML is an error message: show it in
                # the error slot and keep the results slot hidden.
                yield (
                    gr.Textbox(value="β Processing failed", visible=True),
                    gr.Slider(visible=False),
                    gr.HTML(value=lang_status, visible=True),
                    gr.HTML(value="", visible=False),
                    gr.Audio(value=audio_path, visible=True, label="Extracted Audio (Full Duration)"),
                    gr.HTML(value=result_html, visible=True)
                )
            else:
                yield (
                    gr.Textbox(value="β Analysis complete!", visible=True),
                    gr.Slider(value=1.0, visible=False),
                    gr.HTML(value=lang_status, visible=True),
                    gr.HTML(value=result_html, visible=True),
                    gr.Audio(value=audio_path, visible=True, label="Extracted Audio (Full Duration)"),
                    gr.HTML(visible=False)
                )
        except Exception as e:
            # Escape the message before embedding it in HTML.
            yield (
                gr.Textbox(value="β An unexpected error occurred!", visible=True),
                gr.Slider(visible=False),
                gr.HTML(value="", visible=True),
                gr.HTML(value="", visible=False),
                gr.Audio(value=None, visible=True, label="Extracted Audio (Full Duration)"),
                gr.HTML(value=f"<p style='color: red; font-weight: bold;'>β οΈ Unexpected Error: {html.escape(str(e))}</p>", visible=True)
            )

    def clear_inputs():
        """Click handler for "Clear Input": resets every input and output to its initial state."""
        return (
            "",  # url_input
            None,  # video_input
            60,  # analysis_duration (reset to default)
            "Waiting for video input...",  # status_box
            gr.Slider(visible=False, value=0),  # progress_bar (hidden and reset)
            "",  # language_status_html (clear)
            "",  # output_html (clear)
            gr.Audio(visible=True, value=None, label="Extracted Audio (Full Duration)"),
            ""  # error_output (clear)
        )

    # --- Event wiring -------------------------------------------------------
    submit_btn.click(
        fn=unified_processing_fn,
        inputs=[url_input, video_input, analysis_duration],
        outputs=[status_box, progress_bar, language_status_html, output_html, audio_player, error_output],
        api_name="classify_video"
    )
    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[url_input, video_input, analysis_duration, status_box, progress_bar, language_status_html, output_html, audio_player, error_output],
    )
# Script entry point: the app is launched only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    app.launch(share=True)