"""
Audio Captioning Model
This script implements an audio captioning model based on the Effb2-Trm architecture.
It uses a pre-trained model to generate captions for audio inputs.
The original implementation is based on:
https://github.com/wsntxxn/Effb2-Trm-AudioCaps-Captioning/
"""
from functools import partial
import gradio as gr
import spaces
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast
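# Local module defining the Effb2-Trm config and captioning model wrappers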
from hf_wrapper import Effb2TrmConfig, Effb2TrmCaptioningModel
# Load the configuration
config = Effb2TrmConfig.from_pretrained("config.json")
# Load the model
model = Effb2TrmCaptioningModel(config)
# Load the state dict from the local pytorch_model.bin file
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
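# Use a GPU if one is available, otherwise fall back to CPU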
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the appropriate device
model = model.to(device)
# Load the tokenizer used to decode the model's output token ids
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/audiocaps-simple-tokenizer"
)
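# Target sample rate expected by the pre-trained model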
target_sr = model.config.sample_rate
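# spaces.GPU requests a GPU for this function when running on ZeroGPU hardware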
@spaces.GPU
def infer(input_audio):
    sr, wav = input_audio
    wav = torch.as_tensor(wav)
    # Normalize integer PCM samples to the [-1, 1] float range
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono
    if wav.ndim > 1:
        wav = wav.mean(1)
    # Resample to the model's expected sample rate
    wav = resample(wav, sr, target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0)  # ensure float32 and add a batch dimension
    with torch.no_grad():
        word_idx = model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
    # Decode the predicted token ids into a caption string
    cap = tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight Audio Captioning")
    with gr.Row():
        gr.Markdown("""
        Audio Captioning Demo
        """)
    with gr.Row():
        with gr.Column():
            file = gr.Audio(label="Input", visible=True)
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
    btn.click(
        fn=partial(infer),
        inputs=[file,],
        outputs=output
    )

demo.launch()