import os
import shutil
import subprocess

import gradio as gr
import requests
import soundfile as sf
import torch
from moviepy.editor import VideoFileClip
from scipy.signal import resample
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, pipeline
# === Constants ===
TEMP_VIDEO = "temp_video.mp4"
RAW_AUDIO = "raw_audio_input"
EXTRACTED_AUDIO = "extracted_audio.wav"
CONVERTED_AUDIO = "converted_audio.wav"
MODEL_REPO = "ylacombe/accent-classifier"
# === Load models ===
# Load the fine-tuned classifier from the local "model" directory.
MODEL_DIR = "model"
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_DIR)

# To load from the Hugging Face Hub instead:
# model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
# feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_REPO)
whisper = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Build the label list in logit order so index i of the output maps to LABELS[i].
LABELS = [model.config.id2label[i] for i in range(len(model.config.id2label))]
model.eval()
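# Optional sanity check: the label set depends on whatever checkpoint sits in
# "model" (e.g. accent names such as "us" or "england" for an accent classifier).
# print(LABELS)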
# === Helpers ===
def convert_to_wav(input_path, output_path=CONVERTED_AUDIO):
    """Convert any audio/video file ffmpeg understands to WAV."""
    command = ["ffmpeg", "-y", "-i", input_path, output_path]
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed to convert {input_path} to WAV.")
    return output_path
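# Optional variant (a sketch, not wired in): ffmpeg can downmix and resample in
# the same pass, so classify_accent would receive 16 kHz mono directly and the
# scipy resampling step below would become a no-op:
#   command = ["ffmpeg", "-y", "-i", input_path, "-ac", "1", "-ar", "16000", output_path]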
def extract_audio_from_video(video_path, output_path=EXTRACTED_AUDIO):
    clip = VideoFileClip(video_path)
    try:
        if clip.audio is None:
            raise ValueError("No audio stream found in video.")
        clip.audio.write_audiofile(output_path)
    finally:
        clip.close()  # release the file handle even if extraction fails
    return output_path
def download_video(url, filename=TEMP_VIDEO):
    temp_download = "raw_download.mp4"
    headers = {"User-Agent": "Mozilla/5.0"}
    r = requests.get(url, headers=headers, stream=True, timeout=15)
    r.raise_for_status()
    if not r.headers.get("Content-Type", "").startswith("video/"):
        raise RuntimeError(f"URL is not a video. Content-Type: {r.headers.get('Content-Type')}")
    with open(temp_download, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)
    # Remux without re-encoding to repair containers whose metadata sits at the
    # end of the file, and to verify ffmpeg can actually read the download.
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", temp_download,
        "-c", "copy", "-movflags", "+faststart", filename
    ]
    result = subprocess.run(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0 or not os.path.exists(filename) or os.path.getsize(filename) == 0:
        raise RuntimeError("FFmpeg failed to process the video.")
    os.remove(temp_download)
    return filename
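# Note: download_video only handles direct file URLs. Page URLs (YouTube, Loom,
# etc.) serve HTML, which the Content-Type check above rejects; supporting them
# would require a dedicated downloader such as yt-dlp.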
def classify_accent(audio_path):
    waveform, sr = sf.read(audio_path)
    if len(waveform.shape) > 1:
        waveform = waveform.mean(axis=1)  # downmix stereo to mono
    if sr != 16000:
        # Wav2Vec2 expects 16 kHz input; resample if needed.
        num_samples = int(len(waveform) * 16000 / sr)
        waveform = resample(waveform, num_samples)
        sr = 16000
    inputs = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_idx = torch.argmax(probs).item()
    top_label = LABELS[top_idx]
    top_conf = round(probs[top_idx].item(), 4)
    topk = torch.topk(probs, k=min(5, len(LABELS)))  # guard against <5 classes
    topk_labels = [LABELS[i] for i in topk.indices.tolist()]
    topk_scores = [round(p, 4) for p in topk.values.tolist()]
    top5_text = "\n".join(f"{label}: {score}" for label, score in zip(topk_labels, topk_scores))
    return top_label, top_conf, top5_text
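# Example usage (hypothetical local file):
#   label, conf, top5_text = classify_accent("sample.wav")
#   print(label, conf)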
def transcribe_audio(audio_path):
    # return_timestamps=True lets the pipeline chunk and transcribe audio
    # longer than Whisper's 30-second window.
    result = whisper(audio_path, return_timestamps=True)
    return result.get("text", "").strip()
# === Main Handler ===
def process_input(audio_file, video_file, video_url):
    try:
        audio_path = None
        if audio_file:
            shutil.copy(audio_file, RAW_AUDIO)
            audio_path = convert_to_wav(RAW_AUDIO)
        elif video_file:
            shutil.copy(video_file, TEMP_VIDEO)
            extracted = extract_audio_from_video(TEMP_VIDEO)
            audio_path = convert_to_wav(extracted)
        elif video_url and video_url.strip():
            if "loom.com" in video_url:
                return "Loom links are not supported. Please upload the file or use a direct .mp4 URL.", None, None, None, None, None
            downloaded = download_video(video_url)
            extracted = extract_audio_from_video(downloaded)
            audio_path = convert_to_wav(extracted)
        else:
            return "Please provide an audio file, a video file, or a direct video URL.", None, None, None, None, None
        label, confidence, top5 = classify_accent(audio_path)
        transcription = transcribe_audio(audio_path)
        return f"Top prediction: {label}", confidence, label, audio_path, top5, transcription
    except Exception as e:
        return f"Error: {str(e)}", None, None, None, None, None
    finally:
        # Remove intermediates, but keep CONVERTED_AUDIO: it is returned as the
        # "Processed Audio" output, and Gradio reads it after this function exits.
        for f in [TEMP_VIDEO, RAW_AUDIO, EXTRACTED_AUDIO, "raw_download.mp4"]:
            if os.path.exists(f):
                os.remove(f)
# === Gradio Interface ===
interface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Audio(label="Upload MP3 or WAV", type="filepath"),
        gr.File(label="Upload MP4 Video", type="filepath"),
        gr.Textbox(label="Paste Direct .mp4 Video URL")
    ],
    outputs=[
        gr.Text(label="Prediction"),
        gr.Number(label="Confidence Score"),
        gr.Text(label="Accent"),
        gr.Audio(label="Processed Audio", type="filepath"),
        gr.Text(label="Top 5 Predictions"),
        gr.Text(label="Transcription")
    ],
    title="Accent Classifier + Transcriber",
    description="Upload an audio or video file OR paste a direct video URL to classify the accent and transcribe the speech."
)
if __name__ == "__main__":
    interface.launch()