import os
import gradio as gr
import torch
import shutil
import requests
import subprocess
import soundfile as sf
from scipy.signal import resample
from moviepy.editor import VideoFileClip
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, pipeline
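
# Runtime requirements: the ffmpeg binary must be available on PATH (it is
# invoked via subprocess below); Python deps are gradio, torch, requests,
# soundfile, scipy, moviepy and transformers.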
# === Constants ===
TEMP_VIDEO = "temp_video.mp4"
RAW_AUDIO = "raw_audio_input"
CONVERTED_AUDIO = "converted_audio.wav"
MODEL_REPO = "ylacombe/accent-classifier"
# === Load local model (offline alternative) ===
# MODEL_DIR = "model"
# model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
# feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_DIR)
# === Load models ===
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_REPO)
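
# whisper-tiny is the smallest Whisper checkpoint; a larger one
# (e.g. openai/whisper-base) trades latency for transcription accuracy.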
whisper = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
LABELS = [model.config.id2label[i] for i in range(len(model.config.id2label))]
model.eval()  # inference mode: disables dropout

# === Helpers ===
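
# Normalize any input container/codec to WAV so soundfile can read it.
# ffmpeg detects the input format by probing the content, so even the
# extension-less RAW_AUDIO copy used in process_input works as input.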
def convert_to_wav(input_path, output_path=CONVERTED_AUDIO):
    command = ["ffmpeg", "-y", "-i", input_path, output_path]
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg failed to convert {input_path} to WAV.")
    return output_path

def extract_audio_from_video(video_path, output_path="extracted_audio.wav"):
    clip = VideoFileClip(video_path)
    try:
        if clip.audio is None:
            raise ValueError("No audio stream found in video.")
        clip.audio.write_audiofile(output_path)
    finally:
        clip.close()  # release the underlying ffmpeg reader
    return output_path
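
# Stream the remote file to disk, then remux it with ffmpeg ("-c copy" copies
# the streams without re-encoding; "+faststart" moves the moov atom to the
# front), which doubles as a cheap validity check on the download.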
def download_video(url, filename=TEMP_VIDEO):
    temp_download = "raw_download.mp4"
    headers = {"User-Agent": "Mozilla/5.0"}
    r = requests.get(url, headers=headers, stream=True, timeout=15)
    r.raise_for_status()
    if not r.headers.get("Content-Type", "").startswith("video/"):
        raise RuntimeError(f"URL is not a video. Content-Type: {r.headers.get('Content-Type')}")
    with open(temp_download, 'wb') as f:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", temp_download,
        "-c", "copy", "-movflags", "+faststart", filename
    ]
    result = subprocess.run(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0 or not os.path.exists(filename) or os.path.getsize(filename) == 0:
        raise RuntimeError("FFmpeg failed to process the video.")
    os.remove(temp_download)
    return filename
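
# Wav2Vec2 models expect mono 16 kHz input, so stereo is downmixed and any
# other sample rate is resampled before feature extraction.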
def classify_accent(audio_path):
    waveform, sr = sf.read(audio_path)
    if len(waveform.shape) > 1:
        waveform = waveform.mean(axis=1)  # downmix stereo to mono
    if sr != 16000:
        num_samples = int(len(waveform) * 16000 / sr)
        waveform = resample(waveform, num_samples)
        sr = 16000
    inputs = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_idx = torch.argmax(probs).item()
    top_label = LABELS[top_idx]
    top_conf = round(probs[top_idx].item(), 4)
    top5 = torch.topk(probs, k=min(5, len(LABELS)))  # guard in case the model has fewer than 5 labels
    top5_labels = [LABELS[i] for i in top5.indices.tolist()]
    top5_scores = [round(p, 4) for p in top5.values.tolist()]
    top5_text = "\n".join(f"{label}: {score}" for label, score in zip(top5_labels, top5_scores))
    return top_label, top_conf, top5_text
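
# return_timestamps=True enables the pipeline's long-form decoding path, so
# clips longer than Whisper's 30 s window are transcribed in full instead of
# raising an error.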
def transcribe_audio(audio_path):
    result = whisper(audio_path, return_timestamps=True)
    return result.get("text", "").strip()

# === Main Handler ===
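
# Inputs are mutually exclusive and checked in priority order: audio file >
# video file > video URL. Returns a 6-tuple matching the Gradio outputs below;
# on failure the error message lands in the first slot and the rest are None.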
def process_input(audio_file, video_file, video_url):
    try:
        audio_path = None
        if audio_file:
            shutil.copy(audio_file, RAW_AUDIO)
            audio_path = convert_to_wav(RAW_AUDIO)
        elif video_file:
            shutil.copy(video_file, TEMP_VIDEO)
            extracted = extract_audio_from_video(TEMP_VIDEO, output_path="extracted_audio.wav")
            audio_path = convert_to_wav(extracted)
        elif video_url and video_url.strip():
            if "loom.com" in video_url:
                return "Loom links are not supported. Please upload the file or use a direct .mp4 URL.", None, None, None, None, None
            downloaded = download_video(video_url)
            extracted = extract_audio_from_video(downloaded, output_path="extracted_audio.wav")
            audio_path = convert_to_wav(extracted)
        else:
            return "Please provide an audio file, a video file, or a direct video URL.", None, None, None, None, None
        label, confidence, top5 = classify_accent(audio_path)
        transcription = transcribe_audio(audio_path)
        return f"Top prediction: {label}", confidence, label, audio_path, top5, transcription
    except Exception as e:
        return f"Error: {str(e)}", None, None, None, None, None
    finally:
        # Clean up intermediates, but keep CONVERTED_AUDIO: it is the returned
        # audio_path, and finally runs before the return value reaches Gradio,
        # so deleting it here would break the "Processed Audio" output.
        for f in [TEMP_VIDEO, RAW_AUDIO, "extracted_audio.wav"]:
            if os.path.exists(f):
                os.remove(f)

# === Gradio Interface ===
interface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Audio(label="Upload MP3 or WAV", type="filepath"),
        gr.File(label="Upload MP4 Video", type="filepath"),
        gr.Textbox(label="Paste Direct .mp4 Video URL")
    ],
    outputs=[
        gr.Text(label="Prediction"),
        gr.Number(label="Confidence Score"),
        gr.Text(label="Accent"),
        gr.Audio(label="Processed Audio", type="filepath"),
        gr.Text(label="Top 5 Predictions"),
        gr.Text(label="Transcription")
    ],
    title="Accent Classifier + Transcriber",
    description="Upload an audio or video file OR paste a direct video URL to classify the accent and transcribe the speech."
)

if __name__ == "__main__":
    interface.launch()