File size: 3,074 Bytes
6065472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from pathlib import Path
import argparse
from functools import partial
import gradio as gr
import torch
from torchaudio.functional import resample

import utils.train_util as train_util


def load_model(cfg, ckpt_path, device):
    """Instantiate the captioning model and tokenizer described by ``cfg``.

    The model is built from ``cfg["model"]``. Its weights come from the
    checkpoint at ``ckpt_path``, which is read onto the CPU first and then
    moved to ``device``. The tokenizer is built from ``cfg["tokenizer"]``. If
    its ``loaded`` flag is false, its state is restored from the checkpoint's
    ``"tokenizer"`` entry.

    Returns:
        ``(model, tokenizer)``, with the model in eval mode on ``device`` and
        its bos/eos/pad indices set from the tokenizer.
    """
    net = train_util.init_model_from_config(cfg["model"])
    # NOTE(review): recent torch releases default ``weights_only=True`` in
    # torch.load; confirm this checkpoint (which also carries tokenizer
    # state) still loads under the torch version in use.
    state = torch.load(ckpt_path, "cpu")
    train_util.load_pretrained_model(net, state)
    net.eval()
    net = net.to(device)

    tok = train_util.init_obj_from_dict(cfg["tokenizer"])
    if not tok.loaded:
        # Tokenizer was not loaded at construction: restore it from the
        # checkpoint.
        tok.load_state_dict(state["tokenizer"])
    net.set_index(tok.bos, tok.eos, tok.pad)
    return net, tok


def infer(file, device, model, tokenizer, target_sr, beam_size=3):
    """Generate a caption for a single audio clip.

    Args:
        file: ``(sample_rate, samples)`` tuple, as produced by a ``gr.Audio``
            component with numpy output. It is ``None`` when no audio was
            provided.
        device: torch device the model lives on.
        model: captioning model (as returned by ``load_model``).
        tokenizer: tokenizer used to decode the generated token ids.
        target_sr: sample rate the model expects; the clip is resampled to it.
        beam_size: beam width for beam-search decoding (default 3).

    Returns:
        The first decoded caption string.

    Raises:
        ValueError: if ``file`` is ``None`` (no audio uploaded or recorded).
    """
    if file is None:
        # An empty Audio component yields None; fail with a clear message
        # instead of an opaque unpacking TypeError.
        raise ValueError("No audio input provided; upload or record a clip first.")
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Scale integer PCM (int16 / int32) to floats in [-1, 1).
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    if wav.ndim > 1:
        # Downmix to mono; assumes a (samples, channels) layout.
        wav = wav.mean(1)
    wav = resample(wav, sr, target_sr)
    wav_len = len(wav)
    # Add a batch dimension of 1 for the model.
    wav = wav.float().unsqueeze(0).to(device)
    input_dict = {
        "mode": "inference",
        "wav": wav,
        "wav_len": [wav_len],
        "specaug": False,
        "sample_method": "beam",
        "beam_size": beam_size,
    }
    with torch.no_grad():
        output_dict = model(input_dict)
        seq = output_dict["seq"].cpu().numpy()
        cap = tokenizer.decode(seq)[0]
    return cap

# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)


def main():
    """Parse CLI flags, load the checkpoint and serve the Gradio captioning demo."""
    parser = argparse.ArgumentParser(description="Audio captioning Gradio demo.")
    parser.add_argument("--share", action="store_true", default=False,
                        help="create a public Gradio share link")
    parser.add_argument("--exp_dir", type=Path,
                        default=Path("./checkpoints/audiocaps"),
                        help="directory containing config.yaml and ckpt.pth")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    exp_dir = args.exp_dir
    cfg = train_util.load_config(exp_dir / "config.yaml")
    target_sr = cfg["target_sr"]
    model, tokenizer = load_model(cfg, exp_dir / "ckpt.pth", device)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                # radio = gr.Radio(
                #     ["file", "mic"],
                #     value="file",
                #     label="Select input type"
                # )
                file = gr.Audio(label="Input", visible=True)
                # mic = gr.Microphone(label="Input", visible=False)
                # radio.change(fn=input_toggle, inputs=radio, outputs=[file, mic])
                btn = gr.Button("Run")
            with gr.Column():
                output = gr.Textbox(label="Output")
        btn.click(
            fn=partial(infer,
                       device=device,
                       model=model,
                       tokenizer=tokenizer,
                       target_sr=target_sr),
            inputs=[file],
            outputs=output,
        )

    # Launch only after the Blocks context has exited, so Blocks finishes its
    # setup first (this matches the Gradio docs' usage pattern).
    demo.launch(share=args.share)


if __name__ == "__main__":
    main()