import torch
import whisper
import torchaudio as ta
import gradio as gr
from model_utils import get_processor, get_model, get_whisper_model_small, get_device
from config import SAMPLING_RATE, CHUNK_LENGTH_S
import spaces
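
# NOTE: `config.py` and `model_utils.py` are separate files in this Space and are
# not shown here. A minimal sketch of what they are assumed to provide (the exact
# checkpoint names and values below are assumptions, not confirmed by this file):
#
#   # config.py
#   SAMPLING_RATE = 16000   # Whisper models expect 16 kHz mono input
#   CHUNK_LENGTH_S = 30     # Whisper's 30-second context window
#
#   # model_utils.py
#   import torch, whisper
#   from transformers import WhisperProcessor, WhisperForConditionalGeneration
#
#   def get_device():
#       return "cuda" if torch.cuda.is_available() else "cpu"
#
#   def get_whisper_model_small():
#       return whisper.load_model("small")  # openai-whisper model, used for language ID
#
#   def get_processor():
#       return WhisperProcessor.from_pretrained("openai/whisper-small")
#
#   def get_model():
#       return WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(get_device())
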
@spaces.GPU
def load_and_resample_audio(audio):
    """Load audio from a file path or a Gradio (sample_rate, ndarray) tuple, resample to SAMPLING_RATE, and return mono audio."""
    if isinstance(audio, str):  # audio is a file path
        waveform, sample_rate = ta.load(audio)
    else:  # audio is already loaded as (sample_rate, waveform)
        sample_rate, waveform = audio
        waveform = torch.tensor(waveform).float()
        if waveform.abs().max() > 1.0:  # Gradio delivers integer PCM; scale to [-1, 1]
            waveform = waveform / 32768.0
    if sample_rate != SAMPLING_RATE:
        waveform = ta.functional.resample(waveform, sample_rate, SAMPLING_RATE)
    # Ensure the audio is in the correct shape: mono, with a leading channel dimension
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    elif waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    return waveform, SAMPLING_RATE

@spaces.GPU
def detect_language(waveform):
    """Run Whisper's language identification on the first 30 seconds of audio."""
    whisper_model = get_whisper_model_small()
    # Use Whisper's own preprocessing: pad/trim to 30 s, then compute log-Mel features
    audio_tensor = whisper.pad_or_trim(waveform.squeeze())
    mel = whisper.log_mel_spectrogram(audio_tensor).to(whisper_model.device)
    # Detect language
    _, probs = whisper_model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)
    print(f"Audio shape: {audio_tensor.shape}")
    print(f"Mel spectrogram shape: {mel.shape}")
    print(f"Detected language: {detected_lang}")
    print("Language probabilities:", probs)
    return detected_lang

@spaces.GPU
def process_long_audio(waveform, sample_rate, task="transcribe", language=None):
    """Split audio into CHUNK_LENGTH_S-second chunks, run Whisper on each, and join the results."""
    input_length = waveform.shape[1]
    chunk_length = int(CHUNK_LENGTH_S * sample_rate)
    chunks = [waveform[:, i:i + chunk_length] for i in range(0, input_length, chunk_length)]

    processor = get_processor()
    model = get_model()
    device = get_device()

    results = []
    for chunk in chunks:
        input_features = processor(
            chunk.squeeze(), sampling_rate=sample_rate, return_tensors="pt"
        ).input_features.to(device)
        with torch.no_grad():
            if task == "translate":
                forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
            else:
                generated_ids = model.generate(input_features)
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
        results.extend(transcription)
        # Clear GPU cache between chunks
        torch.cuda.empty_cache()

    return " ".join(results)

@spaces.GPU
def process_audio(audio):
    """Full pipeline: load and resample, detect language, transcribe, then translate to English."""
    if audio is None:
        return "No file uploaded", "", ""
    waveform, sample_rate = load_and_resample_audio(audio)
    detected_lang = detect_language(waveform)
    transcription = process_long_audio(waveform, sample_rate, task="transcribe")
    translation = process_long_audio(waveform, sample_rate, task="translate", language=detected_lang)
    return detected_lang, transcription, translation

# Gradio interface
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(),
    outputs=[
        gr.Textbox(label="Detected Language"),
        gr.Textbox(label="Transcription", lines=5),
        gr.Textbox(label="Translation", lines=5),
    ],
    title="Audio Transcription and Translation",
    description="Upload an audio file to detect its language, transcribe it, and translate it to English.",
    allow_flagging="never",
    css=".output-textbox { font-family: 'Noto Sans Devanagari', sans-serif; font-size: 18px; }"
)

if __name__ == "__main__":
    iface.launch()
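
# A quick way to exercise the pipeline without the UI (hypothetical file path;
# assumes a local audio file exists):
#
#   detected_lang, transcription, translation = process_audio("sample.wav")
#   print(detected_lang)
#   print(transcription)
#   print(translation)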