File size: 2,204 Bytes
ab3f8fd
 
 
 
 
 
 
 
 
 
 
47bcf45
 
 
 
 
 
ab3f8fd
47bcf45
ab3f8fd
 
47bcf45
ab3f8fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47bcf45
 
ab3f8fd
 
47bcf45
 
 
 
 
 
 
ab3f8fd
47bcf45
 
 
ab3f8fd
47bcf45
 
 
ab3f8fd
47bcf45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab3f8fd
47bcf45
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Audio Captioning Model

This script serves a Gradio demo for audio captioning based on the Effb2-Trm
architecture. It loads a pre-trained model from local files (config.json,
pytorch_model.bin) and generates a text caption for an uploaded audio clip.

The original implementation is based on:
https://github.com/wsntxxn/Effb2-Trm-AudioCaps-Captioning/

"""

from functools import partial
import gradio as gr
import spaces
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast
from hf_wrapper import Effb2TrmConfig, Effb2TrmCaptioningModel

# Load the configuration
config = Effb2TrmConfig.from_pretrained("config.json")

# Build the model from the config; weights are loaded separately below.
model = Effb2TrmCaptioningModel(config)

# Load the state dict from the local pytorch_model.bin file.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when it is loaded.
state_dict = torch.load(
    "pytorch_model.bin", map_location="cpu", weights_only=True
)
model.load_state_dict(state_dict)

# Nothing above switches the model out of training mode, so do it explicitly:
# layers with train/eval-specific behaviour (e.g. dropout, if present) must be
# in eval mode for deterministic inference.
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the appropriate device
model = model.to(device)
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/audiocaps-simple-tokenizer"
)
# Sample rate the model expects; infer() resamples every input to this rate.
target_sr = model.config.sample_rate

@spaces.GPU
def infer(input_audio):
    """Generate a text caption for one audio clip.

    Args:
        input_audio: ``(sample_rate, samples)`` tuple as delivered by a
            ``gr.Audio`` component in numpy mode, or ``None`` when no audio
            was provided. ``samples`` is either 1-D (mono) or 2-D with
            channels on the second axis.

    Returns:
        The decoded caption string.

    Raises:
        gr.Error: if no audio was provided.
    """
    if input_audio is None:
        # Clicking "Run" with an empty input passes None; unpacking it would
        # otherwise surface as an opaque TypeError in the UI.
        raise gr.Error("Please provide an audio clip first.")
    sr, wav = input_audio
    wav = torch.as_tensor(wav)
    # Scale integer PCM samples (int16 / int32) to the range [-1, 1).
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Work in float32 from here on: mean() requires a floating dtype and
    # resampling is a float operation, so other integer dtypes would fail.
    wav = wav.float()
    if wav.ndim > 1:
        # Down-mix to mono by averaging over the channel axis.
        wav = wav.mean(1)
    wav = resample(wav, sr, target_sr)
    wav_len = len(wav)
    # Add a batch dimension and place the input on the same device as the
    # model (which was moved to `device` at load time).
    wav = wav.unsqueeze(0).to(device)
    with torch.no_grad():
        word_idx = model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
        cap = tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight Audio Captioning")

    with gr.Row():
        gr.Markdown("""
            Audio Captioning Demo
        """)
    with gr.Row():
        with gr.Column():
            # type="numpy" (Gradio's default) delivers (sample_rate, samples),
            # which is the format infer() unpacks; state it explicitly.
            file = gr.Audio(label="Input", type="numpy", visible=True)
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
    # Register infer directly as the click handler: audio in, caption out.
    btn.click(
        fn=infer,
        inputs=[file],
        outputs=output,
    )

# Launch only after the Blocks context has closed, so the layout is fully
# built before the server starts.
demo.launch()