import os
import shutil
import subprocess

import gradio as gr
import requests
import soundfile as sf
import torch
from moviepy.editor import VideoFileClip
from scipy.signal import resample
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, pipeline

# === Constants ===
TEMP_VIDEO = "temp_video.mp4"
RAW_AUDIO = "raw_audio_input"
CONVERTED_AUDIO = "converted_audio.wav"
EXTRACTED_AUDIO = "extracted_audio.wav"
MODEL_REPO = "ylacombe/accent-classifier"

# === Load local model ===
MODEL_DIR = "model"
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_DIR, local_files_only=True)

# === Or load from the Hub (uncomment to use) ===
# model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
# feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_REPO)

whisper = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

LABELS = [model.config.id2label[i] for i in range(len(model.config.id2label))]
model.eval()


# === Helpers ===
def convert_to_wav(input_path, output_path=CONVERTED_AUDIO):
    """Convert any audio (or video) file to WAV via ffmpeg."""
    command = ["ffmpeg", "-y", "-i", input_path, output_path]
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError("FFmpeg failed to convert the file to WAV.")
    return output_path


def extract_audio_from_video(video_path, output_path=EXTRACTED_AUDIO):
    """Pull the audio track out of a video file."""
    clip = VideoFileClip(video_path)
    if clip.audio is None:
        raise ValueError("No audio stream found in video.")
    clip.audio.write_audiofile(output_path)
    return output_path


def download_video(url, filename=TEMP_VIDEO):
    """Download a direct video URL, then remux it with ffmpeg into a well-formed MP4."""
    temp_download = "raw_download.mp4"
    headers = {"User-Agent": "Mozilla/5.0"}
    r = requests.get(url, headers=headers, stream=True, timeout=15)
    r.raise_for_status()
    if not r.headers.get("Content-Type", "").startswith("video/"):
        raise RuntimeError(f"URL is not a video. Content-Type: {r.headers.get('Content-Type')}")
    with open(temp_download, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)
    # Stream-copy (no re-encode) and move the moov atom to the front of the file.
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", temp_download,
        "-c", "copy", "-movflags", "+faststart", filename,
    ]
    result = subprocess.run(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0 or not os.path.exists(filename) or os.path.getsize(filename) == 0:
        raise RuntimeError("FFmpeg failed to process the video.")
    os.remove(temp_download)
    return filename


def classify_accent(audio_path):
    """Classify the accent; returns (top label, confidence, formatted top-5 list)."""
    waveform, sr = sf.read(audio_path)
    # Downmix multi-channel audio to mono.
    if len(waveform.shape) > 1:
        waveform = waveform.mean(axis=1)
    # Resample to the 16 kHz rate the model expects.
    if sr != 16000:
        num_samples = int(len(waveform) * 16000 / sr)
        waveform = resample(waveform, num_samples)
        sr = 16000

    inputs = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.nn.functional.softmax(logits, dim=-1)

    top_idx = torch.argmax(probs).item()
    top_label = LABELS[top_idx]
    top_conf = round(probs[top_idx].item(), 4)

    top5 = torch.topk(probs, k=5)
    top5_labels = [LABELS[i] for i in top5.indices.tolist()]
    top5_scores = [round(p, 4) for p in top5.values.tolist()]
    top5_text = "\n".join(f"{label}: {score}" for label, score in zip(top5_labels, top5_scores))
    return top_label, top_conf, top5_text


def transcribe_audio(audio_path):
    """Transcribe speech with Whisper; timestamps enable long-form (>30 s) inputs."""
    result = whisper(audio_path, return_timestamps=True)
    return result.get("text", "").strip()


# === Main Handler ===
def process_input(audio_file, video_file, video_url):
    try:
        audio_path = None
        if audio_file:
            shutil.copy(audio_file, RAW_AUDIO)
            audio_path = convert_to_wav(RAW_AUDIO)
        elif video_file:
            shutil.copy(video_file, TEMP_VIDEO)
            extracted = extract_audio_from_video(TEMP_VIDEO)
            audio_path = convert_to_wav(extracted)
        elif video_url and video_url.strip():
            if "loom.com" in video_url:
                return ("Loom links are not supported. Please upload the file or use a direct .mp4 URL.",
                        None, None, None, None, None)
            downloaded = download_video(video_url)
            extracted = extract_audio_from_video(downloaded)
            audio_path = convert_to_wav(extracted)
        else:
            return ("Please provide an audio file, a video file, or a direct video URL.",
                    None, None, None, None, None)

        label, confidence, top5 = classify_accent(audio_path)
        transcription = transcribe_audio(audio_path)
        return f"Top prediction: {label}", confidence, label, audio_path, top5, transcription
    except Exception as e:
        return f"Error: {str(e)}", None, None, None, None, None
    finally:
        # Remove intermediates. CONVERTED_AUDIO is returned to the gr.Audio output,
        # so it must survive until Gradio has copied it into its cache.
        for f in [TEMP_VIDEO, RAW_AUDIO, EXTRACTED_AUDIO]:
            if os.path.exists(f):
                os.remove(f)


# === Gradio Interface ===
interface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Audio(label="Upload MP3 or WAV", type="filepath"),
        gr.File(label="Upload MP4 Video", type="filepath"),
        gr.Textbox(label="Paste Direct .mp4 Video URL"),
    ],
    outputs=[
        gr.Text(label="Prediction"),
        gr.Number(label="Confidence Score"),
        gr.Text(label="Accent"),
        gr.Audio(label="Processed Audio", type="filepath"),
        gr.Text(label="Top 5 Predictions"),
        gr.Text(label="Transcription"),
    ],
    title="Accent Classifier + Transcriber",
    description="Upload an audio or video file OR paste a direct video URL to classify the accent and transcribe the speech.",
)

if __name__ == "__main__":
    interface.launch()
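
# --- Optional smoke test (a minimal sketch; "sample.wav" is a hypothetical local
# file — any sample rate works, since classify_accent resamples to 16 kHz).
# Uncomment to exercise the classifier and transcriber without launching the UI:
#
# label, conf, top5 = classify_accent("sample.wav")
# print(f"{label} ({conf})\n{top5}")
# print(transcribe_audio("sample.wav"))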