# Hugging Face Space status when this file was captured: Build error
import os
import subprocess
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio as ta
import whisper

from config import CHUNK_LENGTH_S, SAMPLING_RATE
from model_utils import get_device, get_model, get_processor, get_whisper_model_small
def resample_with_ffmpeg(input_file, output_file, target_sr=16000):
    """Resample ``input_file`` to ``target_sr`` Hz with the ffmpeg CLI.

    Args:
        input_file: Path of the audio file to read.
        output_file: Path of the file to write; an existing file is overwritten.
        target_sr: Output sample rate in Hz.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    command = [
        'ffmpeg',
        '-nostdin',  # never read from (or block on) the server process's stdin
        '-y',        # overwrite output_file; without it ffmpeg prompts and, when
                     # not interactive, exits with an error if the file exists
        '-i', input_file,
        '-ar', str(target_sr),
        output_file,
    ]
    subprocess.run(command, check=True)
def _to_mono_float32(samples):
    """Return ``samples`` as a 1-D float32 NumPy array scaled to [-1, 1].

    Integer PCM (Gradio commonly delivers int16) is rescaled because Whisper
    expects float audio in [-1, 1]; float input is only cast. A 2-D array is
    treated as ``(num_samples, num_channels)``, Gradio's numpy layout, and is
    averaged over channels.
    """
    data = np.asarray(samples)
    if np.issubdtype(data.dtype, np.integer):
        info = np.iinfo(data.dtype)
        # Zero point is 0 for signed PCM and 2**(bits-1) for unsigned PCM (e.g. uint8).
        offset = (int(info.max) + int(info.min) + 1) // 2
        scale = int(info.max) - offset + 1
        data = (data.astype(np.float32) - offset) / scale
    else:
        data = data.astype(np.float32)
    if data.ndim == 2:
        data = data.mean(axis=1)
    return data.astype(np.float32, copy=False)


def detect_language(audio):
    """Detect the spoken language of a Gradio audio clip with Whisper.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by ``gr.Audio()`` in
            its default numpy mode; ``samples`` is 1-D (mono) or
            ``(num_samples, num_channels)``.

    Returns:
        The key with the highest probability in the dict returned by
        ``whisper_model.detect_language``.
    """
    whisper_model = get_whisper_model_small()
    sample_rate, samples = audio
    waveform = torch.from_numpy(_to_mono_float32(samples))

    if sample_rate != SAMPLING_RATE:
        # Use a private temp dir per call: fixed filenames in the working
        # directory collided between concurrent requests and were left behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            src_path = os.path.join(tmp_dir, "input_audio.wav")
            dst_path = os.path.join(tmp_dir, "resampled_audio.wav")
            ta.save(src_path, waveform.unsqueeze(0), int(sample_rate))
            resample_with_ffmpeg(src_path, dst_path, target_sr=SAMPLING_RATE)
            resampled, _ = ta.load(dst_path)
        # ta.load returns (channels, frames); collapse to a mono 1-D signal.
        waveform = resampled.mean(dim=0)

    # Whisper's preprocessing: pad/trim to its fixed window, then log-mel.
    audio_tensor = whisper.pad_or_trim(waveform)
    print(f"Audio length after pad/trim: {audio_tensor.shape[-1] / SAMPLING_RATE} seconds")
    mel = whisper.log_mel_spectrogram(audio_tensor).to(whisper_model.device)

    _, probs = whisper_model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)
    print(f"Audio shape: {audio_tensor.shape}")
    print(f"Mel spectrogram shape: {mel.shape}")
    print(f"Detected language: {detected_lang}")
    print("Language probabilities:", probs)
    return detected_lang
def process_long_audio(audio, task="transcribe", language=None):
    """Transcribe or translate ``audio`` for the given Whisper ``task``.

    NOTE(review): the implementation was elided from this file; only a
    "rest of the function remains the same" placeholder comment remained.
    A ``def`` with no body is a SyntaxError that prevents the whole module
    from loading. Restore the real body here. Presumably it processes long
    clips in chunks (``CHUNK_LENGTH_S`` is imported), but confirm that.
    Until the body is restored, this raises at call time instead of
    breaking import.

    Raises:
        NotImplementedError: always, until the body is restored.
    """
    raise NotImplementedError(
        "process_long_audio: implementation is missing from this file; "
        "restore the original body."
    )
def process_audio(audio):
    """Gradio callback: detect the language, then transcribe and translate ``audio``.

    Returns a 3-tuple matching the three output textboxes. When no audio was
    provided, the first slot holds a notice and the other two are empty.
    """
    if audio is not None:
        lang = detect_language(audio)
        return (
            lang,
            process_long_audio(audio, task="transcribe"),
            process_long_audio(audio, task="translate", language=lang),
        )
    return "No file uploaded", "", ""
# Gradio UI: one audio input; the three textboxes receive, in order, the
# values returned by process_audio.
# NOTE(review): ".output-textbox" may not match any element in recent Gradio
# releases (custom classes come from a component's elem_classes) -- confirm
# the Devanagari font rule actually applies.
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(),
    outputs=[
        gr.Textbox(label=box_label, **box_options)
        for box_label, box_options in (
            ("Detected Language", {}),
            ("Transcription", {"lines": 5}),
            ("Translation", {"lines": 5}),
        )
    ],
    title="Audio Transcription and Translation",
    description="Upload an audio file to detect its language, transcribe, and translate it.",
    allow_flagging="never",
    css=".output-textbox { font-family: 'Noto Sans Devanagari', sans-serif; font-size: 18px; }",
)