import torch
import whisper
import torchaudio as ta
import gradio as gr
from model_utils import get_processor, get_model, get_whisper_model_small, get_device
from config import SAMPLING_RATE, CHUNK_LENGTH_S
import spaces
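
# Gradio Space: detect the spoken language of an uploaded clip with
# whisper-small, then transcribe and translate it chunk-by-chunk with the
# Hugging Face Whisper processor/model provided by model_utils.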


@spaces.GPU
def load_and_resample_audio(audio):
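    """Load audio from a file path or Gradio's (sample_rate, data) tuple and
    return a mono float32 tensor of shape (1, n) resampled to SAMPLING_RATE."""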
    if isinstance(audio, str):  # If audio is a file path
        waveform, sample_rate = ta.load(audio)
    else:  # If audio arrives as Gradio's (sample_rate, data) tuple
        sample_rate, waveform = audio
        waveform = torch.tensor(waveform).float()
        if waveform.abs().max() > 1.0:  # Gradio delivers int16 PCM; scale to [-1, 1]
            waveform = waveform / 32768.0
        if waveform.dim() > 1:  # Gradio is (samples, channels); torchaudio wants (channels, samples)
            waveform = waveform.t()

    if sample_rate != SAMPLING_RATE:
        waveform = ta.functional.resample(waveform, sample_rate, SAMPLING_RATE)

    # Ensure the audio is in the correct shape (mono)
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    elif waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    return waveform, SAMPLING_RATE
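
# Usage sketch (hypothetical inputs, not part of the app flow): both call forms
# normalize to a mono float tensor of shape (1, num_samples) at SAMPLING_RATE:
#   load_and_resample_audio("clip.wav")              # path from gr.Audio(type="filepath")
#   load_and_resample_audio((44100, int16_samples))  # Gradio's (sample_rate, numpy) tuple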


@spaces.GPU
def detect_language(waveform):
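    """Identify the spoken language by running whisper-small on a 30-second
    pad/trim of the input."""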
    whisper_model = get_whisper_model_small()

    # Use Whisper's preprocessing
    audio_tensor = whisper.pad_or_trim(waveform.squeeze())
    mel = whisper.log_mel_spectrogram(audio_tensor).to(whisper_model.device)

    # Detect language
    _, probs = whisper_model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)

    print(f"Audio shape: {audio_tensor.shape}")
    print(f"Mel spectrogram shape: {mel.shape}")
    print(f"Detected language: {detected_lang}")
    print("Language probabilities:", probs)

    return detected_lang


@spaces.GPU
def process_long_audio(waveform, sample_rate, task="transcribe", language=None):
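    """Transcribe or translate audio in CHUNK_LENGTH_S-second chunks and join
    the per-chunk outputs into a single string."""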
    input_length = waveform.shape[1]
    chunk_length = int(CHUNK_LENGTH_S * sample_rate)

    chunks = [waveform[:, i:i + chunk_length] for i in range(0, input_length, chunk_length)]

    processor = get_processor()
    model = get_model()
    device = get_device()

    results = []
    for chunk in chunks:
        inputs = processor(chunk.squeeze(), sampling_rate=sample_rate, return_tensors="pt")
        input_features = inputs.input_features.to(device)

        with torch.no_grad():
            if task == "translate":
                forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
            else:
                generated_ids = model.generate(input_features)

        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
        results.extend(transcription)

        # Release cached GPU memory between chunks (no-op when CUDA is unavailable)
        torch.cuda.empty_cache()

    return " ".join(results)


@spaces.GPU
def process_audio(audio):
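    """Gradio callback: return (detected language, transcription, translation)
    for the uploaded audio."""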
    if audio is None:
        return "No file uploaded", "", ""

    waveform, sample_rate = load_and_resample_audio(audio)

    detected_lang = detect_language(waveform)
    transcription = process_long_audio(waveform, sample_rate, task="transcribe")
    translation = process_long_audio(waveform, sample_rate, task="translate", language=detected_lang)

    return detected_lang, transcription, translation


# Gradio interface
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(),
    outputs=[
        gr.Textbox(label="Detected Language"),
        gr.Textbox(label="Transcription", lines=5, elem_classes=["output-textbox"]),
        gr.Textbox(label="Translation", lines=5, elem_classes=["output-textbox"])
    ],
    title="Audio Transcription and Translation",
    description="Upload an audio file to detect its language, transcribe, and translate it.",
    allow_flagging="never",
    css=".output-textbox { font-family: 'Noto Sans Devanagari', sans-serif; font-size: 18px; }"
)

if __name__ == "__main__":
    iface.launch()
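
# Local usage sketch (assuming this file is saved as app.py next to
# model_utils.py and config.py): run `python app.py` and open the printed URL.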