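# Gradio demo for automated audio captioning: load a trained model and
# tokenizer from ./checkpoints/audiocaps and generate a caption for an
# uploaded audio clip.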
from pathlib import Path
import argparse
from functools import partial

import gradio as gr
import torch
from torchaudio.functional import resample

import utils.train_util as train_util
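

# Build the captioning model and tokenizer from the experiment config, restore
# their weights from the checkpoint, and register the tokenizer's special
# token indices (bos/eos/pad) on the model.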
def load_model(cfg, ckpt_path, device):
    model = train_util.init_model_from_config(cfg["model"])
    ckpt = torch.load(ckpt_path, map_location="cpu")
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    model = model.to(device)
    tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
    if not tokenizer.loaded:
        tokenizer.load_state_dict(ckpt["tokenizer"])
    model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
    return model, tokenizer
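

# Caption a single clip from the Gradio audio component: normalize integer PCM
# to [-1, 1], downmix to mono, resample to the model's sample rate, then decode
# with beam search (beam size 3).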
def infer(file, device, model, tokenizer, target_sr):
    sr, wav = file
    wav = torch.as_tensor(wav)
    # gr.Audio returns integer PCM; scale int16/int32 samples to [-1, 1]
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # downmix multi-channel audio to mono
    if wav.ndim > 1:
        wav = wav.mean(1)
    wav = resample(wav, sr, target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0).to(device)
    input_dict = {
        "mode": "inference",
        "wav": wav,
        "wav_len": [wav_len],
        "specaug": False,
        "sample_method": "beam",
        "beam_size": 3,
    }
    with torch.no_grad():
        output_dict = model(input_dict)
        seq = output_dict["seq"].cpu().numpy()
        cap = tokenizer.decode(seq)[0]
    return cap
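

# Optional toggle between file upload and microphone input; kept commented out
# here, together with the matching widgets in the UI below.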
# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)
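

# Entry point: load the AudioCaps checkpoint, build the Gradio interface and
# serve it locally; pass --share (e.g. `python demo.py --share`, file name
# illustrative) to also expose a public Gradio link.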
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    exp_dir = Path("./checkpoints/audiocaps")
    cfg = train_util.load_config(exp_dir / "config.yaml")
    target_sr = cfg["target_sr"]
    model, tokenizer = load_model(cfg, exp_dir / "ckpt.pth", device)

    # UI layout: audio input and "Run" button on the left, caption output on the right
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                # radio = gr.Radio(
                #     ["file", "mic"],
                #     value="file",
                #     label="Select input type"
                # )
                file = gr.Audio(label="Input", visible=True)
                # mic = gr.Microphone(label="Input", visible=False)
                # radio.change(fn=input_toggle, inputs=radio, outputs=[file, mic])
                btn = gr.Button("Run")
            with gr.Column():
                output = gr.Textbox(label="Output")
        # bind the model-side arguments so Gradio only needs to pass the audio input
        btn.click(
            fn=partial(infer,
                       device=device,
                       model=model,
                       tokenizer=tokenizer,
                       target_sr=target_sr),
            inputs=[file],
            outputs=output
        )

    demo.launch(share=args.share)