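"""accent-classifier / local.py

Gradio app that classifies a speaker's accent and transcribes speech from
an uploaded audio file, an uploaded video file, or a direct video URL.
This variant loads the Wav2Vec2 classifier from a local ``model/``
directory; a commented-out block below falls back to the Hub repo.
"""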
import os
import gradio as gr
import torch
import shutil
import requests
import subprocess
import soundfile as sf
from scipy.signal import resample
from moviepy.editor import VideoFileClip
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, pipeline
# === Constants ===
TEMP_VIDEO = "temp_video.mp4"
RAW_AUDIO = "raw_audio_input"  # no extension: ffmpeg probes the format from file content
CONVERTED_AUDIO = "converted_audio.wav"
MODEL_REPO = "ylacombe/accent-classifier"  # Hub fallback, see commented-out loader below
# === Load local model ===
MODEL_DIR = "model"
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_DIR, local_files_only=True)
# === Hub fallback (uncomment to load from the Hub instead of the local copy) ===
# model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
# feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_REPO)
whisper = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
LABELS = [model.config.id2label[i] for i in range(len(model.config.id2label))]  # class id -> accent name
model.eval()  # inference mode: disables dropout
# === Helpers ===
def convert_to_wav(input_path, output_path=CONVERTED_AUDIO):
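    """Convert any audio input to WAV with ffmpeg; the source format is probed from file content."""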
    command = ["ffmpeg", "-y", "-i", input_path, output_path]
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg failed to convert {input_path} to WAV.")
    return output_path
def extract_audio_from_video(video_path, output_path="extracted_audio.wav"):
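    """Extract the audio track of a video into a WAV file, failing if the video is silent."""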
    clip = VideoFileClip(video_path)
    try:
        if clip.audio is None:
            raise ValueError("No audio stream found in video.")
        clip.audio.write_audiofile(output_path)
    finally:
        clip.close()  # release the file handle even when extraction fails
    return output_path
def download_video(url, filename=TEMP_VIDEO):
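    """Download a direct video URL after checking its Content-Type, then
    remux it with ffmpeg into a clean MP4 at `filename`."""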
    temp_download = "raw_download.mp4"
    headers = {"User-Agent": "Mozilla/5.0"}
    r = requests.get(url, headers=headers, stream=True, timeout=15)
    r.raise_for_status()
    if not r.headers.get("Content-Type", "").startswith("video/"):
        raise RuntimeError(f"URL is not a video. Content-Type: {r.headers.get('Content-Type')}")
    with open(temp_download, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)
    # Stream-copy into a fresh MP4 container with the moov atom up front.
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", temp_download,
        "-c", "copy", "-movflags", "+faststart", filename,
    ]
    result = subprocess.run(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0 or not os.path.exists(filename) or os.path.getsize(filename) == 0:
        raise RuntimeError("FFmpeg failed to process the video.")
    os.remove(temp_download)
    return filename
def classify_accent(audio_path):
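    """Classify the accent of a speech clip.

    Returns (top_label, top_confidence, newline-separated top-5 summary).
    """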
    waveform, sr = sf.read(audio_path)
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)  # downmix multi-channel audio to mono
    if sr != 16000:
        # Resample to the 16 kHz rate the Wav2Vec2 feature extractor expects.
        num_samples = int(len(waveform) * 16000 / sr)
        waveform = resample(waveform, num_samples)
        sr = 16000
    inputs = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_idx = torch.argmax(probs).item()
    top_label = LABELS[top_idx]
    top_conf = round(probs[top_idx].item(), 4)
    top5 = torch.topk(probs, k=min(5, len(LABELS)))  # guard against models with fewer than 5 classes
    top5_labels = [LABELS[i] for i in top5.indices.tolist()]
    top5_scores = [round(p, 4) for p in top5.values.tolist()]
    top5_text = "\n".join(f"{label}: {score}" for label, score in zip(top5_labels, top5_scores))
    return top_label, top_conf, top5_text
def transcribe_audio(audio_path):
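    """Transcribe speech with the Whisper pipeline."""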
    # return_timestamps=True lets the pipeline transcribe clips longer than Whisper's 30 s window.
    result = whisper(audio_path, return_timestamps=True)
    return result.get("text", "").strip()
# === Main Handler ===
def process_input(audio_file, video_file, video_url):
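    """Route whichever input was provided (audio, video, or URL) through
    conversion, classification, and transcription.

    Returns the six values expected by the Gradio outputs; on error the
    first value carries the message and the rest are None.
    """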
    try:
        audio_path = None
        if audio_file:
            shutil.copy(audio_file, RAW_AUDIO)
            audio_path = convert_to_wav(RAW_AUDIO)
        elif video_file:
            shutil.copy(video_file, TEMP_VIDEO)
            extracted = extract_audio_from_video(TEMP_VIDEO, output_path="extracted_audio.wav")
            audio_path = convert_to_wav(extracted)
        elif video_url and video_url.strip():
            if "loom.com" in video_url:
                return "Loom links are not supported. Please upload the file or use a direct .mp4 URL.", None, None, None, None, None
            downloaded = download_video(video_url)
            extracted = extract_audio_from_video(downloaded, output_path="extracted_audio.wav")
            audio_path = convert_to_wav(extracted)
        else:
            return "Please provide an audio file, a video file, or a direct video URL.", None, None, None, None, None
        label, confidence, top5 = classify_accent(audio_path)
        transcription = transcribe_audio(audio_path)
        return f"Top prediction: {label}", confidence, label, audio_path, top5, transcription
    except Exception as e:
        return f"Error: {e}", None, None, None, None, None
    finally:
        # Clean up intermediates only; CONVERTED_AUDIO is returned to Gradio
        # as the "Processed Audio" output, so it must outlive this handler.
        for f in [TEMP_VIDEO, RAW_AUDIO, "extracted_audio.wav", "raw_download.mp4"]:
            if os.path.exists(f):
                os.remove(f)
# === Gradio Interface ===
interface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Audio(label="Upload MP3 or WAV", type="filepath"),
        gr.File(label="Upload MP4 Video", type="filepath"),
        gr.Textbox(label="Paste Direct .mp4 Video URL"),
    ],
    outputs=[
        gr.Text(label="Prediction"),
        gr.Number(label="Confidence Score"),
        gr.Text(label="Accent"),
        gr.Audio(label="Processed Audio", type="filepath"),
        gr.Text(label="Top 5 Predictions"),
        gr.Text(label="Transcription"),
    ],
    title="Accent Classifier + Transcriber",
    description="Upload an audio or video file OR paste a direct video URL to classify the accent and transcribe the speech.",
)
if __name__ == "__main__":
    interface.launch()
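
# Quick smoke test without the UI (the sample path below is hypothetical;
# point it at a real WAV/MP3 before running):
#
#   msg, conf, label, wav, top5, text = process_input("sample.wav", None, "")
#   print(msg, conf)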