# Author: tusker123
# Commit dcdec88 ("Upload app.py", verified)
import gradio as gr
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import requests
import os
from moviepy.editor import VideoFileClip
import tempfile
import re
from urllib.parse import urlparse
from gradio import Progress
from pathlib import Path
import torch
import shutil # Import shutil for explicit temporary directory cleanup
import soundfile as sf # Import soundfile for explicit audio loading
# --- Models: loaded once at import time and shared by every request. ---
# Accent classifier (audio-classification pipeline); called by classify_audio()
# with a WAV path. No `device` argument is passed, so it uses the pipeline's
# default placement.
pipe = pipeline("audio-classification", model="dima806/english_accents_classification")
# Language detector (text-classification pipeline) applied to the ASR transcript
# in process_video_unified(); its top label/score decide whether the audio is
# treated as English. Also created without a `device` argument.
language_detector = pipeline("text-classification", model="alexneakameni/language_detection")
# Small Whisper ASR model, used only to get text from a short audio clip for
# language detection.
# NOTE(review): "whisper-tiny.en" is an English-only checkpoint; it may render
# non-English speech as English-like text, which would weaken the language check
# downstream. A multilingual checkpoint (e.g. "openai/whisper-tiny") may be more
# reliable for that purpose -- confirm before relying on the check.
# transformers pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1
# Hugging Face model id shared by the model and its processor below.
asr_model_id = "openai/whisper-tiny.en"
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(asr_model_id)
asr_processor = AutoProcessor.from_pretrained(asr_model_id)
# asr_pipe itself is not called anywhere in this file's visible code:
# transcribe_audio() runs asr_model / asr_processor directly, so the model must
# sit on the same device as the inputs it receives.
# NOTE(review): constructing the pipeline with `device=device` presumably moves
# the shared asr_model onto that device as a side effect. Do not remove this
# statement without moving asr_model explicitly -- verify.
asr_pipe = pipeline(
"automatic-speech-recognition",
model=asr_model,
tokenizer=asr_processor.tokenizer,
feature_extractor=asr_processor.feature_extractor,
device=device
)
def is_valid_url(url):
    """
    Check whether *url* is an acceptable remote video source.

    A URL is accepted when it uses http or https, has a host, and satisfies
    one of the following:

    * its path ends with ``.mp4`` (any host, including Dropbox links);
    * its host is Google Drive and the path contains ``/file/d/`` (shared
      file page) or ``/uc`` (download endpoint);
    * its host is Loom (``loom.com`` or a subdomain such as ``www.loom.com``
      or ``cdn.loom.com``), with any path.

    Hosts are matched exactly or as a dot-separated subdomain. Look-alike
    hosts such as ``notloom.com`` or ``loom.com.example.net`` are rejected;
    a plain substring test on the netloc would accept them.

    Args:
        url (str): The URL to validate.
    Returns:
        bool: True if the URL is valid and allowed, False otherwise.
    """
    if not url:
        return False

    def _on_domain(host, domain):
        # True for the exact host or a genuine subdomain of *domain*.
        return host == domain or host.endswith('.' + domain)

    try:
        parsed = urlparse(url)
        # .hostname is lower-cased and excludes any ":port" or "user@" part.
        host = parsed.hostname
        scheme = parsed.scheme.lower()
        path = parsed.path
    except (TypeError, ValueError, AttributeError):
        # Non-string input or malformed URLs (e.g. an unbalanced "[" host).
        return False

    # Only http(s) can actually be fetched by download_file (requests).
    if scheme not in ('http', 'https') or not host:
        return False
    if path.lower().endswith('.mp4'):
        return True
    if _on_domain(host, 'drive.google.com'):
        # Typical Google Drive shared-file or direct-download paths.
        return '/file/d/' in path or '/uc' in path
    # Loom URLs are allowed even when they do not end in .mp4.
    return _on_domain(host, 'loom.com')
def is_valid_file(file_obj):
    """
    Check whether an uploaded file has a supported video extension.

    Accepts either an object exposing a ``.name`` path (as Gradio's ``gr.File``
    passes to callbacks) or a plain path string. Objects whose ``.name`` is not
    a path are rejected instead of raising.

    Args:
        file_obj (gr.File value | str | os.PathLike): The uploaded file or its path.
    Returns:
        bool: True if the file extension is .mp4, .mov, .avi or .mkv
        (case-insensitive), False otherwise.
    """
    if not file_obj:
        return False
    # Prefer the upload's .name; fall back to treating the value as a path itself.
    file_path = getattr(file_obj, 'name', file_obj)
    if not isinstance(file_path, (str, os.PathLike)):
        return False
    return Path(file_path).suffix.lower() in ('.mp4', '.mov', '.avi', '.mkv')
def download_file(url, save_path, progress=Progress(), timeout=(10, 60)):
    """
    Download a video file from *url* to *save_path*, streaming it in chunks.

    Args:
        url (str): The URL of the video to download.
        save_path (str): The local path to save the downloaded video.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
        timeout (float or tuple): ``requests`` timeout in seconds, either one
            value or ``(connect, read)``. Without a timeout, a server that
            stops responding would block this call indefinitely.
    Raises:
        ValueError: If the URL is rejected by ``is_valid_url``, or the server
            answers with an HTML page instead of a media file.
        ConnectionError: If the server answers with a status other than 200.
        requests.RequestException: On network errors or timeouts.
    """
    if not is_valid_url(url):
        raise ValueError("Invalid URL. Only .mp4 files or Loom videos are accepted.")

    # Using the response as a context manager releases the streamed connection
    # on every exit path, including the raises below.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise ConnectionError(f"Failed to download video (HTTP {response.status_code})")

        # Share/viewer pages (presumably e.g. Loom share links or Google Drive
        # "/file/d/" pages) come back as HTML. Fail here with a clear message
        # instead of later with an opaque ffmpeg/moviepy error.
        content_type = response.headers.get('content-type', '')
        if content_type.lower().startswith('text/html'):
            raise ValueError(
                "The URL returned a web page instead of a video file. "
                "Please provide a direct video download link."
            )

        # Total size for progress tracking; 0 when missing or malformed.
        try:
            total_size = int(response.headers.get('content-length') or 0)
        except ValueError:
            total_size = 0

        downloaded = 0
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:  # skip keep-alive chunks
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    # iter_content decodes gzip/deflate, so the byte count can
                    # exceed Content-Length; clamp so the bar stops at 100%.
                    progress(min(downloaded / total_size, 1.0), desc="πŸ“₯ Downloading video...")
                else:
                    # Unknown total size: show a general downloading message.
                    progress(0, desc="πŸ“₯ Downloading video (size unknown)...")
def extract_audio_full(video_path, progress=Progress()):
    """
    Extract the full audio track of a video into a standalone 16 kHz WAV file.

    The WAV is created with ``tempfile.NamedTemporaryFile(delete=False)``, so it
    outlives this call and can be handed to Gradio for playback. If extraction
    fails, the partially written file is removed before re-raising.

    Args:
        video_path (str): Path to the input video file.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        str: The path to the extracted audio file.
    Raises:
        Exception: "Full audio extraction failed: ..." chained to the original
            error, including when the video has no audio track.
    """
    video = None
    audio_path = None
    try:
        progress(0, desc="πŸ”Š Extracting full audio for playback...")
        video = VideoFileClip(video_path)
        if video.audio is None:
            raise ValueError("the video has no audio track")
        # Reserve a unique .wav path. The handle is closed when the `with` exits
        # (delete=False keeps the file), so moviepy can write to the path.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
            audio_path = temp_audio_file.name
        video.audio.write_audiofile(audio_path, fps=16000, logger=None)
        progress(1.0)
        return audio_path
    except Exception as e:
        # Do not leak the delete=False temp file when nothing will use it.
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError:
                pass
        raise Exception(f"Full audio extraction failed: {str(e)}") from e
    finally:
        # Release the ffmpeg readers even when writing failed.
        if video is not None:
            if video.audio is not None:
                video.audio.close()
            video.close()
def extract_audio_clip(video_path, audio_path, duration, progress=Progress()):
    """
    Extract the first *duration* seconds of a video's audio into a 16 kHz WAV.

    If the video is shorter than *duration*, the whole audio track is used.

    Args:
        video_path (str): Path to the input video file.
        audio_path (str): Path to save the extracted audio WAV file.
        duration (int | float): Seconds of audio to extract; must be positive.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        str: The path to the extracted audio file (``audio_path``).
    Raises:
        Exception: "Audio clip extraction failed: ..." chained to the original
            error, including when the video has no audio track or
            ``duration`` is not positive.
    """
    video = None
    audio_clip = None
    try:
        progress(0, desc=f"πŸ”Š Extracting {duration} seconds of audio for analysis...")
        if duration <= 0:
            raise ValueError(f"duration must be positive, got {duration}")
        video = VideoFileClip(video_path)
        if video.audio is None:
            raise ValueError("the video has no audio track")
        # Never request more audio than the video actually contains.
        clip_duration = min(duration, video.duration)
        audio_clip = video.audio.subclip(0, clip_duration)
        audio_clip.write_audiofile(audio_path, fps=16000, logger=None)
        progress(1.0)
        return audio_path
    except Exception as e:
        raise Exception(f"Audio clip extraction failed: {str(e)}") from e
    finally:
        # Release the ffmpeg readers even when writing failed.
        if audio_clip is not None:
            audio_clip.close()
        if video is not None:
            video.close()
def transcribe_audio(audio_path_clip, progress=Progress()):
    """
    Transcribe a short audio clip to text with the module-level Whisper model.

    This is best-effort: any failure is printed and an empty string is
    returned, so the caller can continue without language detection.

    Args:
        audio_path_clip (str): Path to the short audio clip (WAV).
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        str: The transcribed text, or "" if transcription failed.
    """
    try:
        progress(0, desc="πŸ“ Transcribing audio for language detection...")
        audio_input, sampling_rate = sf.read(audio_path_clip)
        # soundfile returns (frames, channels) for multi-channel audio; average
        # the channels down to mono.
        if audio_input.ndim > 1:
            audio_input = audio_input.mean(axis=1)
        # The Whisper feature extractor does NOT resample: it expects 16 kHz
        # input and raises on other rates. The extraction helpers write WAVs at
        # fps=16000 for exactly this reason.
        inputs = asr_processor(audio_input, sampling_rate=sampling_rate, return_tensors="pt")
        # Follow the model's actual device rather than the global `device` flag,
        # so inputs and weights can never end up on different devices.
        inputs = inputs.to(asr_model.device)
        with torch.no_grad():
            # 128 new tokens is ample for the ~15 s clips transcribed here.
            output_tokens = asr_model.generate(**inputs, max_new_tokens=128)
        text = asr_processor.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
        progress(1.0)
        return text
    except Exception as e:
        # Deliberate best-effort: the caller treats "" as "could not transcribe".
        print(f"Transcription failed: {e}")
        return ""
def classify_audio(audio_path, progress=Progress()):
    """
    Predict the speaker's English accent for a single audio file.

    Args:
        audio_path (str): Path to the WAV file to classify.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        list: The accent pipeline's output, a list of dicts with 'label' and
        'score' keys (callers treat the first entry as the top prediction).
    Raises:
        Exception: "Classification failed: ..." wrapping any pipeline error.
    """
    try:
        progress(0, desc="πŸ” Analyzing accent - please be patient...")
        predictions = pipe(audio_path)
        progress(1.0)  # signal completion to the UI
    except Exception as err:
        raise Exception(f"Classification failed: {str(err)}")
    return predictions
def _language_status_html(transcribed_text):
    """
    Run language detection on a transcript and build the status banner HTML.

    Args:
        transcribed_text (str): Non-empty transcript of the analysed clip.
    Returns:
        tuple: (status_html, is_english). is_english is True only when the
        detector's top label is 'english' or 'eng_latn' (case-insensitive)
        with a score above 0.7.
    """
    from html import escape

    top = language_detector(transcribed_text)[0]
    detected_language = top['label']
    lang_confidence = top['score']
    if detected_language.lower() in ('english', 'eng_latn') and lang_confidence > 0.7:
        return (
            f"<p style='color: green; font-weight: bold;'>βœ… Verified English Language (Confidence: {lang_confidence*100:.2f}%)</p>",
            True,
        )
    return (
        f"<p style='color: red; font-weight: bold;'>⚠️ Detected language: {escape(detected_language.capitalize())} (Confidence: {lang_confidence*100:.2f}%). Please provide English audio for accent classification.</p>",
        False,
    )


def _accent_results_html(result):
    """
    Render accent predictions as a heading plus a ranked results table.

    Args:
        result (list): Non-empty list of {'label': str, 'score': float} dicts.
            The first entry is shown as the predicted accent and highlighted.
    Returns:
        str: HTML fragment for the results panel.
    """
    from html import escape

    header = """
<table style='width: fit-content; max-width: 100%; border-collapse: collapse; font-family: Arial, sans-serif; margin-top: 1em;'>
<thead>
<tr style='border-bottom: 2px solid #4CAF50; background-color: #f2f2f2;'>
<th style='text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto; min-width: 50px;'>Rank</th>
<th style='text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto; min-width: 100px;'>Accent</th>
<th style='text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto; min-width: 180px;'>Confidence (%)</th>
<th style='text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto; min-width: 80px;'>Score</th>
</tr>
</thead>
<tbody>
"""
    rows = []
    for rank, prediction in enumerate(result, start=1):
        label = escape(prediction['label'].capitalize())
        score = prediction['score']
        # The top prediction gets a highlighted (green, bold) row.
        if rank == 1:
            row_style = "background-color:#d4edda; font-weight: bold; color: #155724;"
            border = "1px solid #c3e6cb"
        else:
            row_style = "color: #333;"
            border = "1px solid #ddd"
        cell = f"padding: 8px; border-bottom: {border}; width: auto;"
        rows.append(f"""
<tr style='{row_style}'>
<td style='{cell} min-width: 50px;'>#{rank}</td>
<td style='{cell} min-width: 100px;'>{label}</td>
<td style='{cell} min-width: 180px;'>
<div style='display: flex; align-items: center;'>
<span style='width: auto; display: inline-block;'>{score * 100:.2f}%</span>
<progress value='{score * 100}' max='100' style='width: 100%; margin-left: 10px;'></progress>
</div>
</td>
<td style='{cell} min-width: 80px;'>
<span style='width: auto; display: inline-block;'>{score:.4f}</span>
</td>
</tr>
""")
    table = header + "".join(rows) + "</tbody></table>"

    top_result = result[0]
    return f"""
<div style='font-family: Arial, sans-serif;'>
<h2 style='color: #2E7D32; margin-bottom: 0.5em;'>
🎀 Predicted Accent: <span style='font-weight:bold'>{escape(top_result['label'].capitalize())}</span>
<span style='font-size: 0.8em; color: #555; font-weight: normal;'>
(Confidence: {top_result['score']*100:.2f}%)
</span>
</h2>
{table}
</div>
"""


def process_video_unified(video_source, analysis_duration, progress=Progress()):
    """
    Process a video URL or an uploaded file and classify the speaker's accent.

    Steps: fetch/copy the video, extract the full audio (for playback) and a
    15 s clip, transcribe the clip and check that it is English, then extract
    the first `analysis_duration` seconds and classify the accent.

    Args:
        video_source (str or gr.File): An http(s) URL string, or an uploaded
            file object exposing `.name`.
        analysis_duration (int): Seconds of audio to analyse for accent classification.
        progress (gradio.Progress): Gradio progress tracker for UI updates.
    Returns:
        tuple: (language_status_html, html_output, audio_path, error_flag)
            language_status_html (str): HTML string displaying language detection status.
            html_output (str): HTML string displaying accent results or error.
            audio_path (str or None): Path to the extracted full audio, or None
                on unexpected errors (the file is deleted in that case).
            error_flag (bool): True if an error occurred, False otherwise.
    """
    from html import escape

    temp_dir = None
    full_audio_path = None
    try:
        # Scratch dir for the video copy and the clipped WAVs; removed in `finally`.
        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, "video.mp4")

        # Pasted URLs often carry stray whitespace. Strip only for the URL check
        # so an uploaded-file value (which may be a str subclass carrying .name)
        # is left untouched.
        url = video_source.strip() if isinstance(video_source, str) else None
        if url and url.startswith(('http://', 'https://')):
            if not is_valid_url(url):
                raise ValueError("Invalid URL. Only .mp4 files or Loom videos are accepted.")
            download_file(url, video_path, progress)
        elif hasattr(video_source, 'name'):
            if not is_valid_file(video_source):
                raise ValueError("Invalid file format. Please upload a video file (MP4)")
            # Copy in chunks instead of reading the whole upload into memory.
            shutil.copyfile(video_source.name, video_path)
        else:
            raise ValueError("Unsupported input type. Please provide a video URL or upload a file.")

        if not os.path.exists(video_path):
            raise Exception("Video processing failed: Video file not found after download/upload.")

        # Full-length audio for the player; lives outside temp_dir so Gradio can serve it.
        full_audio_path = extract_audio_full(video_path, progress)

        # Short clip used only for transcription + language detection.
        transcription_clip_duration = 15
        audio_for_transcription_path = os.path.join(temp_dir, "audio_for_transcription.wav")
        extract_audio_clip(video_path, audio_for_transcription_path, transcription_clip_duration, progress)
        if not os.path.exists(full_audio_path):
            raise Exception("Audio extraction failed: Full audio file not found.")
        if not os.path.exists(audio_for_transcription_path):
            raise Exception("Audio extraction failed: Clipped audio for transcription not found.")

        transcribed_text = transcribe_audio(audio_for_transcription_path, progress)
        if not transcribed_text.strip():
            # Without a transcript the language cannot be checked; warn and continue.
            language_status_html = "<p style='color: orange; font-weight: bold;'>⚠️ Could not transcribe audio for language detection. Please ensure audio is clear.</p>"
        else:
            language_status_html, is_english = _language_status_html(transcribed_text)
            if not is_english:
                # Skip accent classification for non-English audio.
                return language_status_html, "", full_audio_path, True

        # Clip for accent classification, sized by the analysis-duration slider.
        audio_for_classification_path = os.path.join(temp_dir, "audio_for_classification.wav")
        extract_audio_clip(video_path, audio_for_classification_path, analysis_duration, progress)
        if not os.path.exists(audio_for_classification_path):
            raise Exception("Audio extraction failed: Clipped audio for classification not found.")

        result = classify_audio(audio_for_classification_path, progress)
        if not result:
            return language_status_html, "<p style='color: red; font-weight: bold;'>⚠️ No accent prediction returned</p>", full_audio_path, True

        return language_status_html, _accent_results_html(result), full_audio_path, False
    except Exception as e:
        # The full WAV was created with delete=False for Gradio to serve. It is
        # not returned on this path, so remove it instead of leaking it.
        if full_audio_path and os.path.exists(full_audio_path):
            try:
                os.remove(full_audio_path)
            except OSError:
                pass
        # Exception text can echo user input (e.g. the URL), so escape it for HTML.
        return "", f"<p style='color: red; font-weight: bold;'>⚠️ Error: {escape(str(e))}</p>", None, True
    finally:
        # Best-effort cleanup: a failure here must not replace the return value.
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
# Custom Gradio theme: Gradio's Default theme with colour overrides applied via
# Theme.set(). Each key below is a theme variable name passed to .set().
_THEME_OVERRIDES = {
    # Page background light grey; inner blocks/panels white.
    "background_fill_primary": "#f0f2f5",
    "background_fill_secondary": "#ffffff",
    # Light grey borders for a cleaner look.
    "border_color_primary": "#e0e0e0",
    # Primary and secondary buttons share one palette: green (#4CAF50), a
    # lighter green (#66BB6A) on hover, and white text.
    "button_primary_background_fill": "#4CAF50",
    "button_primary_background_fill_hover": "#66BB6A",
    "button_primary_text_color": "#ffffff",
    "button_secondary_background_fill": "#4CAF50",
    "button_secondary_background_fill_hover": "#66BB6A",
    "button_secondary_text_color": "#ffffff",
    # Blue accents for sliders and other accent-coloured elements.
    "color_accent": "#2196F3",
    "color_accent_soft": "#BBDEFB",
}
my_theme = gr.themes.Default().set(**_THEME_OVERRIDES)
# Gradio app interface definition
with gr.Blocks(theme=my_theme) as app:  # Apply the custom theme here
    gr.Markdown("""
<div style='font-family: Arial, sans-serif;'>
<h1 style='color: #2E7D32;'>🎀 English Accent Classifier</h1>
<p>Analyze English accents from either:</p>
<ul>
<li>A video URL (MP4 or Loom videos)</li>
<li>Or upload a video file from your computer</li>
</ul>
<p>The accent analysis will be performed on the first <strong>60 seconds</strong> of audio by default, after language detection.</p>
<p>The analysis may take some time depending on the video size and your chosen analysis duration. Please be patient while we process your video.</p>
<p><strong>Supported file formats:</strong> MP4, MOV, AVI, MKV </p>
<p style='font-size: 0.9em; color: #666;'>
<strong>Note:</strong> This application requires <a href='https://ffmpeg.org/download.html' target='_blank' style='color: #2E7D32;'>FFmpeg</a> to be installed on your system to process video and audio files.
</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            url_input = gr.Textbox(
                label="πŸ”— Video URL (MP4 or Loom)",
                placeholder="Paste URL here..."
            )
            video_input = gr.File(
                label="πŸ“ Upload Video File",
                # The same extensions is_valid_file() accepts, so the picker does
                # not offer formats that would be rejected after upload.
                file_types=[".mp4", ".mov", ".avi", ".mkv"],
                interactive=True
            )
        with gr.Column(scale=1):
            analysis_duration = gr.Slider(
                minimum=5,
                maximum=120,
                step=5,
                value=60,
                label="Accent Analysis Duration (seconds)",
                info="Analyze the first N seconds of audio for accent classification."
            )
    with gr.Row():
        submit_btn = gr.Button("Analyze Video", variant="primary")
        clear_btn = gr.Button("Clear Input")
    status_box = gr.Textbox(
        label="Status",
        placeholder="Waiting for video input...",
        interactive=False,
        visible=True
    )
    progress_bar = gr.Slider(
        visible=False,
        label="Processing Progress",
        interactive=False
    )
    # Outputs sit in their own row: stacked vertically on small screens,
    # side by side on larger ones.
    with gr.Row():
        # Language status and the audio player.
        with gr.Column(scale=1, min_width=300):
            language_status_html = gr.HTML(label="Language Detection Status", visible=True)
            audio_player = gr.Audio(label="Extracted Audio (Full Duration)", visible=True)
        # Main results table and error output.
        with gr.Column(scale=2, min_width=400):
            output_html = gr.HTML()
            error_output = gr.HTML(visible=False)

    def unified_processing_fn(video_url, video_file, analysis_duration, progress=Progress()):
        """
        Gradio callback: pick the input source, run the pipeline, stream UI updates.

        Yields tuples matching the outputs (status_box, progress_bar,
        language_status_html, output_html, audio_player, error_output): first a
        "processing started" reset, then the final result or error view.
        """
        from html import escape

        # A whitespace-only URL box must not shadow an uploaded file, and pasted
        # URLs often carry stray spaces that defeat the http(s) prefix check.
        video_url = (video_url or "").strip()
        video_source = video_url if video_url else video_file
        yield (
            gr.Textbox(value="⏳ Processing started - please be patient...", visible=True),
            gr.Slider(visible=True, value=0),
            gr.HTML(value="", visible=True),  # Clear language status
            gr.HTML(value="", visible=False),  # Hide previous HTML output
            gr.Audio(value=None, visible=True, label="Extracted Audio (Full Duration)"),
            gr.HTML(value="", visible=False)  # Hide previous error output
        )
        try:
            lang_status, result_html, audio_path, error = process_video_unified(video_source, analysis_duration, progress)
            if error:
                yield (
                    gr.Textbox(value="❌ Processing failed", visible=True),
                    gr.Slider(visible=False),
                    gr.HTML(value=lang_status, visible=True),
                    gr.HTML(value="", visible=False),
                    gr.Audio(value=audio_path, visible=True, label="Extracted Audio (Full Duration)"),
                    gr.HTML(value=result_html, visible=True)
                )
            else:
                yield (
                    gr.Textbox(value="βœ… Analysis complete!", visible=True),
                    gr.Slider(value=1.0, visible=False),
                    gr.HTML(value=lang_status, visible=True),
                    gr.HTML(value=result_html, visible=True),
                    gr.Audio(value=audio_path, visible=True, label="Extracted Audio (Full Duration)"),
                    gr.HTML(visible=False)
                )
        except Exception as e:
            yield (
                gr.Textbox(value="❌ An unexpected error occurred!", visible=True),
                gr.Slider(visible=False),
                gr.HTML(value="", visible=True),
                gr.HTML(value="", visible=False),
                gr.Audio(value=None, visible=True, label="Extracted Audio (Full Duration)"),
                # Escape: the message may echo user-supplied input.
                gr.HTML(value=f"<p style='color: red; font-weight: bold;'>⚠️ Unexpected Error: {escape(str(e))}</p>", visible=True)
            )

    def clear_inputs():
        """Reset every input and output component to its initial state."""
        return (
            "",  # url_input
            None,  # video_input
            60,  # analysis_duration (reset to default)
            "Waiting for video input...",  # status_box
            gr.Slider(visible=False, value=0),  # progress_bar (hidden and reset)
            "",  # language_status_html (clear)
            "",  # output_html (clear)
            gr.Audio(visible=True, value=None, label="Extracted Audio (Full Duration)"),
            ""  # error_output (clear)
        )

    submit_btn.click(
        fn=unified_processing_fn,
        inputs=[url_input, video_input, analysis_duration],
        outputs=[status_box, progress_bar, language_status_html, output_html, audio_player, error_output],
        api_name="classify_video"
    )
    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[url_input, video_input, analysis_duration, status_box, progress_bar, language_status_html, output_html, audio_player, error_output],
    )
if __name__ == "__main__":
    # share=True also asks Gradio for a temporary public link, so anyone holding
    # that URL can reach the app (including its download-by-URL feature).
    launch_options = {"share": True}
    app.launch(**launch_options)