ynhe committed
Commit 6980b93 · verified · 1 Parent(s): 732976d

Update README.md

Files changed (1):
  1. README.md (+130 -12)
README.md CHANGED
@@ -104,16 +104,128 @@ pip install flash-attn --no-build-isolation
 ```
 Then you could use our model:
 ```python
+import numpy as np
+import torch
+import torchvision.transforms as T
+from decord import VideoReader, cpu
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
 from transformers import AutoModel, AutoTokenizer
 
+
 # model setting
 model_path = 'OpenGVLab/InternVideo2_5_Chat_8B'
 
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
-image_processor = model.get_vision_tower().image_processor
 
 
+# standard ImageNet normalization statistics, referenced by build_transform below
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def build_transform(input_size):
+    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+    transform = T.Compose([T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD)])
+    return transform
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    # pick the tiling grid whose aspect ratio is closest to that of the input image
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the candidate tiling grids
+    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size)
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
+
+
+def load_image(image, input_size=448, max_num=6):
+    transform = build_transform(input_size=input_size)
+    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+    pixel_values = [transform(image) for image in images]
+    pixel_values = torch.stack(pixel_values)
+    return pixel_values
+
+
+def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+    # sample one frame index from the middle of each of num_segments equal segments
+    if bound:
+        start, end = bound[0], bound[1]
+    else:
+        start, end = -100000, 100000
+    start_idx = max(first_idx, round(start * fps))
+    end_idx = min(round(end * fps), max_frame)
+    seg_size = float(end_idx - start_idx) / num_segments
+    frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
+    return frame_indices
+
+
+def get_num_frames_by_duration(duration):
+    # roughly one frame per second, in multiples of local_num_frames, clamped to [128, 512]
+    local_num_frames = 4
+    num_segments = int(duration // local_num_frames)
+    if num_segments == 0:
+        num_frames = local_num_frames
+    else:
+        num_frames = local_num_frames * num_segments
+
+    num_frames = min(512, num_frames)
+    num_frames = max(128, num_frames)
+
+    return num_frames
+
+
+def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32, get_frame_by_duration=False):
+    # decode the video, sample frames, tile each frame and stack all tiles into one tensor
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    max_frame = len(vr) - 1
+    fps = float(vr.get_avg_fps())
+
+    pixel_values_list, num_patches_list = [], []
+    transform = build_transform(input_size=input_size)
+    if get_frame_by_duration:
+        duration = max_frame / fps
+        num_segments = get_num_frames_by_duration(duration)
+    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
+    for frame_index in frame_indices:
+        img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+        img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
+        pixel_values = [transform(tile) for tile in img]
+        pixel_values = torch.stack(pixel_values)
+        num_patches_list.append(pixel_values.shape[0])
+        pixel_values_list.append(pixel_values)
+    pixel_values = torch.cat(pixel_values_list)
+    return pixel_values, num_patches_list
+
+
 # evaluation setting
 max_num_frames = 512
 generation_config = dict(
@@ -123,20 +235,26 @@ generation_config = dict(
     top_p=0.1,
     num_beams=1
 )
-
 video_path = "your_video.mp4"
+num_segments = 128
 
-# single-turn conversation
-question1 = "Describe this video in detail."
-output1, chat_history = model.chat(video_path=video_path, tokenizer=tokenizer, user_prompt=question1, return_history=True, max_num_frames=max_num_frames, generation_config=generation_config)
-
-print(output1)
-
-# multi-turn conversation
-question2 = "How many people appear in the video?"
-output2, chat_history = model.chat(video_path=video_path, tokenizer=tokenizer, user_prompt=question2, chat_history=chat_history, return_history=True, max_num_frames=max_num_frames, generation_config=generation_config)
-
-print(output2)
+
+with torch.no_grad():
+    pixel_values, num_patches_list = load_video(video_path, num_segments=num_segments, max_num=1, get_frame_by_duration=False)
+    pixel_values = pixel_values.to(torch.bfloat16).to(model.device)
+    video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
+
+    # single-turn conversation
+    question1 = "Describe this video in detail."
+    question = video_prefix + question1
+    output1, chat_history = model.chat(tokenizer, pixel_values, question, generation_config, num_patches_list=num_patches_list, history=None, return_history=True)
+    print(output1)
+
+    # multi-turn conversation
+    question2 = "How many people appear in the video?"
+    output2, chat_history = model.chat(tokenizer, pixel_values, question2, generation_config, num_patches_list=num_patches_list, history=chat_history, return_history=True)
+    print(output2)
 ```
 
 ## ✏️ Citation