import os

# Install the bundled TimeSformer package at startup so it can be imported below.
os.system('cd TimeSformer; pip install .; cd ..')
# Debug: show the working directory and its contents in the build logs.
os.system('ls -l')
os.system('pwd')
import sys
sys.path.append("/home/user/app/TimeSformer/")
import timesformer

import json

import torch
import torchaudio
from torchvision import transforms
from transformers import AutoTokenizer
from PIL import Image
from ruamel.yaml import YAML
import gradio as gr

from models.epalm import ePALM

yaml = YAML(typ='safe')
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
device_type = 'cuda' if use_cuda else 'cpu'
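# The Space runs on CPU and float16 mixed precision is not supported there, so the
# model is cast to bfloat16 below and generation runs under a bfloat16 autocast.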
## Load model

### Captioning
config = 'configs/audio/ePALM_audio_caption.yaml'
with open(config, 'r') as f:
    config = yaml.load(f)

text_model = 'facebook/opt-2.7b'
vision_model_name = 'vit_base_patch16_224'

start_layer_idx = 19
end_layer_idx = 31
low_cpu = True
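# Assumption (based on the eP-ALM design): start_layer_idx/end_layer_idx select the
# OPT-2.7B decoder layers (19 through 31, out of 32) into which the perceptual
# prompt is injected.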
MODEL = ePALM(
    opt_model_name=text_model,
    vision_model_name=vision_model_name,
    use_vis_prefix=True,
    start_layer_idx=start_layer_idx,
    end_layer_idx=end_layer_idx,
    return_hidden_state_vision=True,
    config=config,
    low_cpu=low_cpu,
)
print("Model Built") | |
MODEL.to(device) | |
checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth' | |
checkpoint = torch.load(checkpoint_path, map_location='cpu') | |
state_dict = checkpoint['model'] | |
msg = MODEL.load_state_dict(state_dict,strict=False) | |
MODEL.bfloat16() | |
# Audio Captioning
checkpoint_path = 'checkpoints/float32/ePALM_audio_caption/checkpoint_best.pth'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_audio_caption = checkpoint['model']
# Apply the audio-captioning weights, since this demo only serves the Audio Captioning task.
msg = MODEL.load_state_dict(state_dict_audio_caption, strict=False)
## Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(text_model, use_fast=False)
eos_token = tokenizer.eos_token
pad_token = tokenizer.pad_token

special_answer_token = '</a>'
special_tokens_dict = {'additional_special_tokens': [special_answer_token]}
tokenizer.add_special_tokens(special_tokens_dict)
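# The added '</a>' token marks the start of the answer for QA-style tasks; for
# captioning, inference() below splits the decoded output on the OPT '</s>'
# separator instead.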
def read_audio(path):
    melbins = 128
    target_length = 1024
    skip_norm = False
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform, sr = torchaudio.load(path)
    waveform = waveform - waveform.mean()

    # Kaldi-style log-mel filter bank features
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=melbins, dither=0.0,
        frame_shift=10)

    # Pad short clips with zeros, truncate long ones to target_length frames.
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    # SpecAug is skipped at inference time.
    fbank = torch.transpose(fbank, 0, 1)
    # this is just to satisfy new torchaudio version, which only accepts [1, freq, time]
    fbank = fbank.unsqueeze(0)
    # squeeze it back, it is just a trick to satisfy new torchaudio version
    fbank = fbank.squeeze(0)
    fbank = torch.transpose(fbank, 0, 1)

    # Normalize with the dataset statistics (same for training and test); set
    # skip_norm=True only when computing the normalization stats themselves.
    if not skip_norm:
        fbank = (fbank - norm_mean) / (norm_std * 2)

    audio = fbank
    return audio
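# Example (hypothetical): for a 16 kHz clip, read_audio returns a tensor of shape
# [target_length, melbins] == [1024, 128]; with a 10 ms frame shift this covers
# roughly 10 s of audio, shorter clips are zero-padded and longer ones truncated.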
# Decoding settings: beam search (3 beams), no sampling, max_length=30.
do_sample = False
num_beams = 3
max_length = 30
def inference(image, task_type):
    if task_type == 'Audio Captioning':
        text = ['']
        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
        model = MODEL
    else:
        raise NotImplementedError(task_type)

    # Convert the uploaded file to log-mel features, add a batch dimension and
    # move them to the model device.
    image = read_audio(image)
    image = image.unsqueeze(0).to(device)

    with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=True):
        out = model(image=image, text=text_input, mode='generate', return_dict=True,
                    max_length=max_length, do_sample=do_sample, num_beams=num_beams)

    if 'Captioning' in task_type:
        for i, o in enumerate(out):
            res = tokenizer.decode(o)
            response = res.split('</s>')[1].replace(pad_token, '').replace('</s>', '').replace(eos_token, '')  # skip_special_tokens=True
    else:
        for o in out:
            o_list = o.tolist()
            response = tokenizer.decode(o_list).split(special_answer_token)[1].replace(pad_token, '').replace('</s>', '').replace(eos_token, '')  # skip_special_tokens=True

    return response
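# Standalone usage sketch (hypothetical, bypassing the Gradio UI):
#   caption = inference('examples/audios/6cS0FsUM-cQ.wav', 'Audio Captioning')
#   print(caption)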
inputs = [
    gr.Audio(source="upload", type="filepath"),
    gr.Radio(choices=['Audio Captioning'], type="value", value="Audio Captioning", label="Task"),
]
outputs = ['text']

examples = [
    ['examples/audios/6cS0FsUM-cQ.wav', 'Audio Captioning'],
    ['examples/audios/AJtNitYMa1I.wav', 'Audio Captioning'],
]
title = "eP-ALM for Audio-Text tasks"
description = ("Gradio demo for eP-ALM with OPT-2.7B. The model runs on CPU and float16 mixed "
               "precision is not supported there, so generation can take up to 2 minutes.")
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2303.11403' target='_blank'>Paper</a> | <a href='https://github.com/mshukor/eP-ALM' target='_blank'>Github Repo</a></p>"

io = gr.Interface(fn=inference, inputs=inputs, outputs=outputs,
                  title=title, description=description, article=article,
                  examples=examples, cache_examples=False)
io.launch()