mshukor committed
Commit 0173bf0 · 1 Parent(s): 33f1db4
Files changed (1):
  1. app.py +16 -115
app.py CHANGED
@@ -50,20 +50,17 @@ device_type = 'cuda' if use_cuda else 'cpu'
 ## Load model
 
 ### Captioning
-config = 'configs/image/ePALM_caption.yaml'
-# config = yaml.load(open(config, 'r'), Loader=yaml.Loader)
+config = 'configs/audio/ePALM_audio_caption.yaml'
 config = yaml.load(open(config, 'r'))
 
 text_model = 'facebook/opt-2.7b'
 vision_model_name = 'vit_base_patch16_224'
 
-# text_model = 'facebook/opt-6.7b'
-# vision_model_name = 'vit_large_patch16_224'
 
 start_layer_idx = 19
 end_layer_idx = 31
 low_cpu = True
-model = ePALM(opt_model_name=text_model,
+MODEL = ePALM(opt_model_name=text_model,
               vision_model_name=vision_model_name,
               use_vis_prefix=True,
               start_layer_idx=start_layer_idx,
@@ -73,62 +70,20 @@ model = ePALM(opt_model_name=text_model,
               low_cpu=low_cpu
               )
 print("Model Built")
-model.to(device)
+MODEL.to(device)
 
 checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
 state_dict = checkpoint['model']
-msg = model.load_state_dict(state_dict,strict=False)
-
-model.bfloat16()
-
-# ###### VQA
-# config = 'configs/image/ePALM_vqa.yaml'
-# config = yaml.load(open(config, 'r'))
-
-# start_layer_idx = 19
-# end_layer_idx = 31
-# low_cpu = True
-# model_vqa = ePALM(opt_model_name=text_model,
-#               vision_model_name=vision_model_name,
-#               use_vis_prefix=True,
-#               start_layer_idx=start_layer_idx,
-#               end_layer_idx=end_layer_idx,
-#               return_hidden_state_vision=True,
-#               config=config,
-#               low_cpu=low_cpu
-#               )
-# print("Model Built")
-# model_vqa.to(device)
-
-
-checkpoint_path = 'checkpoints/float32/ePALM_vqa/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_vqa = checkpoint['model']
-# msg = model_vqa.load_state_dict(state_dict,strict=False)
-
+msg = MODEL.load_state_dict(state_dict,strict=False)
 
-# model_vqa.bfloat16()
+MODEL.bfloat16()
 
 
 
-# Video Captioning
-checkpoint_path = 'checkpoints/float32/ePALM_video_caption_msrvtt/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_video_caption = checkpoint['model']
-
-# Video QA
-checkpoint_path = 'checkpoints/float32/ePALM_video_qa_msrvtt/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_video_qa = checkpoint['model']
-
 
 # Audio Captioning
 checkpoint_path = 'checkpoints/float32/ePALM_audio_caption/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
 state_dict_audio_caption = checkpoint['model']
 
@@ -146,33 +101,8 @@ special_tokens_dict = {'additional_special_tokens': [special_answer_token]}
 tokenizer.add_special_tokens(special_tokens_dict)
 
 
-image_size = 224
-normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-
-transform = transforms.Compose([
-    transforms.Resize((image_size,image_size),interpolation=Image.BICUBIC),
-    transforms.ToTensor(),
-    normalize,
-    ])
 
-type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
-test_transform = transforms.Compose([
-    transforms.Resize((image_size,image_size),interpolation=Image.BICUBIC),
-    type_transform,
-    normalize,
-    ])
-from dataset.video_utils import VIDEO_READER_FUNCS
-video_reader = VIDEO_READER_FUNCS['decord']
-
-def read_video(path, num_frames=16):
-
 
-    frames, frame_indices, video_duration = video_reader(
-        path, num_frames, 'rand', max_num_frames=-1
-    )
-    video = test_transform(frames)
-
-    return video
 
 def read_audio(path):
 
@@ -237,37 +167,18 @@ max_length=30
 
 
 
-def inference(image, audio, video, task_type, instruction):
+def inference(image, task_type):
 
-    if task_type == 'Image Captioning':
-        text = ['']
-        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-    elif task_type == 'Video Captioning':
-        text = ['']
-        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_video_caption,strict=False)
-    elif task_type == 'Audio Captioning':
+    if task_type == 'Audio Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_audio_caption,strict=False)
-    elif task_type == 'Visual Question Answering':
-        question = instruction+'?'+special_answer_token
-        text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_vqa,strict=False)
-    elif task_type == 'Visual Question Answering':
-        question = instruction+'?'+special_answer_token
-        text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_video_qa,strict=False)
+        model = MODEL
     else:
         raise NotImplemented
 
-    if "Video" in task_type:
-        image = read_video(image)
-    elif "Audio" in task_type:
-        image = read_audio(image)
-    else:
-        image = transform(image)
-        image = image.to(device,non_blocking=True).unsqueeze(0)
+
+    image = read_audio(image)
+
 
 
 
@@ -290,25 +201,15 @@ def inference(image, audio, video, task_type, instruction):
     return response
 
 
-inputs = [gr.inputs.Image(type='pil'), gr.Audio(source="upload", type="filepath"), gr.Video(source="upload", type="filepath"), gr.inputs.Radio(choices=['Image Captioning', 'Video Captioning', 'Audio Captioning', "Visual Question Answering", "Visual Grounding", "General", "General Video"], type="value", default="Image Captioning", label="Task"), gr.inputs.Textbox(lines=1, label="Instruction")]
+inputs = [gr.Audio(source="upload", type="filepath"), gr.inputs.Radio(choices=['Audio Captioning'], type="value", default="Image Captioning", label="Task")]
 outputs = ['text']
 examples = [
-    ['examples/images/soccer.jpg', None, None, 'Image Captioning', None],
-    ['examples/images/ski.jpg', None, None, 'Visual Question Answering', 'what does the woman wearing black do?'],
-    ['examples/images/banana.jpg', None, None, 'Image Captioning', None],
-    ['examples/images/skateboard.jpg', None, None, 'Visual Question Answering', 'what is on top of the skateboard?'],
-    ['examples/images/baseball.jpg', None, None, 'Image Captioning', None],
-    [None, None, 'examples/videos/video7014.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7017.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7019.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
-    [None, 'examples/audios/6cS0FsUM-cQ.wav', None, 'Audio Captioning', None],
-    [None, 'examples/audios/AJtNitYMa1I.wav', None, 'Audio Captioning', None],
+    ['examples/audios/6cS0FsUM-cQ.wav', 'Audio Captioning', None],
+    ['examples/audios/AJtNitYMa1I.wav', 'Audio Captioning', None],
 ]
 
-title = "eP-ALM"
-description = "Gradio Demo for eP-ALM: "
+title = "eP-ALM for Audio-Text tasks"
+description = "Gradio Demo for eP-ALM. For this demo, we use 2.7B OPT. As the model runs on CPUs and float16 mixed precision is not supported on CPUs, the generation can take up to 2 mins."
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2303.11403' target='_blank'>Paper</a> | <a href='https://github.com/mshukor/eP-ALM' target='_blank'>Github Repo</a></p>"
 
 io = gr.Interface(fn=inference, inputs=inputs, outputs=outputs,
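
For orientation, below is a minimal sketch of the audio-only Gradio wiring that app.py converges to after this commit. It mirrors only the added lines above; caption_audio is a hypothetical placeholder for the unchanged read_audio/ePALM generation code that this diff does not show, and the Radio default is set to the single remaining choice (the committed file itself keeps default="Image Captioning" and a trailing None column in the example rows inherited from the old five-input layout).

# Minimal sketch, not the committed file: audio-captioning-only demo wiring.
import gradio as gr

def caption_audio(audio_path):
    # Placeholder for the unchanged pipeline: read_audio(audio_path),
    # tokenize an empty prompt, run MODEL generation, decode the caption.
    raise NotImplementedError

def inference(audio, task_type):
    if task_type == 'Audio Captioning':
        return caption_audio(audio)
    raise NotImplementedError(f"unsupported task: {task_type}")

inputs = [
    gr.Audio(source="upload", type="filepath"),
    # 'Audio Captioning' is the only choice left, so it is also used as the default here.
    gr.inputs.Radio(choices=['Audio Captioning'], type="value",
                    default="Audio Captioning", label="Task"),
]
examples = [
    ['examples/audios/6cS0FsUM-cQ.wav', 'Audio Captioning'],
    ['examples/audios/AJtNitYMa1I.wav', 'Audio Captioning'],
]

io = gr.Interface(fn=inference, inputs=inputs, outputs=['text'],
                  title="eP-ALM for Audio-Text tasks", examples=examples)
io.launch()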