Spaces:
Running
on
Zero
Running
on
Zero
from typing import Dict, List | |
from PIL import Image | |
import random | |
from .utils import sample_video, read_image, adjust_bbox, filter_ocr_polygon | |
class VisionParser:
    """Replace image/video references in chat ``messages`` with loaded frames.

    ``transform`` walks ``data_dict['messages']`` in place: each image item is
    replaced by the frame returned from ``read_image`` / ``sample_video``, each
    video item by its list of sampled frames, and text items are optionally
    post-processed (bbox adjustment, OCR polygon filtering). Before loading,
    ``set_n_frames`` distributes a per-sample frame budget across video items.
    """

    def __init__(
        self,
        n_frames=8,
        max_n_frames=256,
        is_training=True,
        video_sampling_strategy=None,
    ):
        """
        Args:
            n_frames: Target total number of frames per sample. Either an int,
                or a sequence from which one value is drawn with
                ``random.choice`` for every sample.
            max_n_frames: Upper bound on the total number of frames per sample.
            is_training: Forwarded to ``sample_video``.
            video_sampling_strategy: Optional dict of sampling options. This
                class reads ``'force_frames_n_divisible'`` and
                ``'use_multi_images_for_video'`` and forwards the whole dict to
                ``sample_video``. ``None`` (the default) means a fresh ``{}``.
        """
        self.n_frames = n_frames
        self.max_n_frames = max_n_frames
        self.is_training = is_training
        # A fresh dict per instance; a shared `{}` default would leak
        # mutations between parser instances.
        self.video_sampling_strategy = (
            {} if video_sampling_strategy is None else video_sampling_strategy
        )
        # Reference example of the expected `data_dict` layout; not used by
        # the methods of this class.
        # fmt: off
        self.data_temp = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the image and the video."},
                        # Supported image formats:
                        {"type": "image", "image": {"image_file": "/path/to/image"}},
                        {"type": "image", "image": {"video_file": "/path/to/video", "frame_indices": 0}},
                        # Supported video formats:
                        {"type": "video", "video": {"video_file": "/path/to/video"}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "frame_indices": [0, 1, 2]}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "time_indices": [0, 1, 2]}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "start_time": 0, "end_time": 100}},
                        {"type": "video", "video": {"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2]}},
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text","text": "xxx"}
                    ]
                }
            ],
            "dataset": "LSMDC",
            "task": "video/caption"
        }
        # fmt: on

    def check_format(self, data_dict: Dict, image_processing_config: Dict):
        """Validate ``image_processing_config`` before anything is loaded.

        ``data_dict`` is currently not inspected. ``None`` config is treated
        as ``{}``.

        Raises:
            ValueError: if ``do_crop`` and ``has_coordinates`` are both truthy.
        """
        if image_processing_config is None:
            image_processing_config = {}
        if image_processing_config.get('do_crop', False) and image_processing_config.get('has_coordinates', False):
            raise ValueError('do_crop and has_coordinates cannot be True at the same time!')

    def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> List[Dict]:
        """Load every vision item of ``data_dict['messages']`` in place.

        1. Replace each image/video entry in ``messages`` with the loaded
           frame (image) or list of frames (video).
        2. Post-process text entries: adjust bounding boxes and filter OCR
           polygons whose area is too small (see ``postprocess_text``).

        Args:
            data_dict: Sample with a ``'messages'`` list. Each message's
                ``'content'`` is a dict or a list of dicts whose ``'type'`` is
                ``'image'``, ``'video'`` or ``'text'``. ``'task'`` is read by
                ``postprocess_text``; ``'extra_info'`` is created if needed.
            image_processing_config: Optional flags (``do_crop``,
                ``has_coordinates``, ``do_padding``); ``None`` means ``{}``.

        Returns:
            ``data_dict['messages']`` (the same list object, mutated in place).

        Raises:
            ValueError: on an unknown content type or conflicting config flags.
        """
        if image_processing_config is None:
            image_processing_config = {}
        self.check_format(data_dict, image_processing_config)
        self.set_n_frames(data_dict)
        # Only image tasks need bbox adjustment / small-OCR filtering, so the
        # first loaded image is all postprocess_text needs.
        first_image = None
        for msg in data_dict['messages']:
            if isinstance(msg['content'], dict):
                msg['content'] = [msg['content']]
            for content in msg['content']:
                content_type = content['type']
                if content_type == 'image':
                    content['image'] = self.load_image_item(content['image'])
                    if first_image is None:
                        first_image = content['image']
                elif content_type == 'video':
                    video = self.load_video_item(content['video'])
                    content['video'] = video.pop('frames')
                    if video:
                        # The sample may not carry `extra_info`; create it
                        # rather than failing with a KeyError.
                        extra_info = data_dict.setdefault('extra_info', {})
                        extra_info['frame_disturb_info'] = video.pop('video_info', {})
                elif content_type == 'text':
                    pass
                else:
                    raise ValueError(f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']")
        for msg in data_dict['messages']:
            for content in msg['content']:
                if content['type'] == 'text':
                    self.postprocess_text(content, data_dict, image_processing_config, first_image)
        return data_dict['messages']

    def set_n_frames(self, data_dict):
        """Assign ``n_frames`` to every video item of ``data_dict`` in place.

        Every image counts as one frame. Videos with ``frame_indices`` or
        ``time_indices`` are fixed at that many frames. The remaining
        ("dynamic") videos start at ``min_n_frames`` if given, else at
        ``max_n_frames`` if ``fps`` is given, else at 0. Dynamic videos are
        then grown round-robin, one frame at a time, until the total reaches
        the target drawn from ``self.n_frames``. After that they are shrunk
        round-robin, never below 0, while the total exceeds
        ``self.max_n_frames``. Finally, if
        ``video_sampling_strategy['force_frames_n_divisible']`` is > 1, each
        dynamic count is rounded up to a multiple of it; this may push the
        total above ``max_n_frames`` again.

        Also normalizes a dict ``msg['content']`` into a one-element list.
        """
        if isinstance(self.n_frames, int):
            target = self.n_frames
        else:
            target = random.choice(self.n_frames)
        assert target <= self.max_n_frames

        total = 0
        dynamic_videos = []
        for msg in data_dict['messages']:
            if isinstance(msg['content'], dict):
                msg['content'] = [msg['content']]
            for content in msg['content']:
                if content['type'] == 'image':
                    total += 1
                elif content['type'] == 'video':
                    video = content['video']
                    if 'frame_indices' in video:
                        video['n_frames'] = len(video['frame_indices'])
                    elif 'time_indices' in video:
                        video['n_frames'] = len(video['time_indices'])
                    else:
                        if 'min_n_frames' in video:
                            video['min_n_frames'] = int(video['min_n_frames'])
                            video['n_frames'] = video['min_n_frames']
                        elif 'fps' in video:
                            video['n_frames'] = self.max_n_frames
                        else:
                            video['n_frames'] = 0
                        dynamic_videos.append(video)
                    total += video['n_frames']

        # Grow dynamic videos round-robin until the target is reached.
        while dynamic_videos and total < target:
            for video in dynamic_videos:
                if total >= target:
                    break
                video['n_frames'] += 1
                total += 1

        # Shrink dynamic videos round-robin while over budget. Counts never go
        # negative; stop once no dynamic video can give up a frame (the fixed
        # items alone exceed max_n_frames).
        while dynamic_videos and total > self.max_n_frames:
            shrunk = False
            for video in dynamic_videos:
                if total > self.max_n_frames and video['n_frames'] > 0:
                    video['n_frames'] -= 1
                    total -= 1
                    shrunk = True
            if not shrunk:
                break

        if dynamic_videos:
            divisor = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
            if divisor > 1:
                for video in dynamic_videos:
                    remainder = video['n_frames'] % divisor
                    if remainder:
                        video['n_frames'] += divisor - remainder

    def load_image_item(self, image_item) -> "Image.Image":
        """Load a single image described by ``image_item``.

        Accepted layouts:
            {"image_file": "/path/to/image"}
            {"video_file": "/path/to/video", "frame_indices": 0}

        If ``image_file`` is present it wins; otherwise the single frame at
        ``frame_indices`` is taken from ``sample_video``.

        Raises:
            KeyError: if neither ``image_file`` nor ``video_file`` is present,
                or ``video_file`` is given without ``frame_indices``.
            ValueError: if ``image_file`` is not a str, or ``frame_indices``
                is not an int when ``video_file`` is given.
        """
        if ("image_file" not in image_item) and ("video_file" not in image_item):
            raise KeyError("Key 'image_file' or 'video_file' not found in image_item")
        if 'image_file' in image_item:
            if not isinstance(image_item['image_file'], str):
                raise ValueError(f"{image_item['image_file']} is not a str!")
        if 'video_file' in image_item:
            if not isinstance(image_item['frame_indices'], int):
                raise ValueError(f"{image_item['frame_indices']} is not a int!")
        if 'image_file' in image_item:
            return read_image(image_item['image_file'])
        frame_indices = [image_item['frame_indices']]
        return sample_video(image_item['video_file'], frame_indices=frame_indices)[0]

    def load_video_item(self, video_item) -> Dict:
        """Sample the frames of one video item via ``sample_video``.

        Accepted layouts (``n_frames`` is normally filled in by
        ``set_n_frames``):
            {"video_file": "/path/to/video", "n_frames": 8}
            {"video_file": "/path/to/video", "frame_indices": [0, 1, 2], "n_frames": 3}
            {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100, "n_frames": 8}
            {"video_file": "/path/to/video", "time_indices": [0, 1, 2], "n_frames": 3}
            {"video_file": "/path/to/video", "start_time": 0, "end_time": 100, "n_frames": 8}
            {"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2], "n_frames": 3}

        The ``video_file`` (or ``image_file``) value is passed to
        ``sample_video`` unchanged; optional ``fps`` and ``mask_boxes`` are
        forwarded too.

        Returns:
            ``{'frames': frames}``, plus ``'video_info'`` when ``sample_video``
            returns a dict instead of frame ids. With
            ``use_multi_images_for_video`` every frame appears twice in a row.

        Raises:
            KeyError: if neither ``image_file`` nor ``video_file`` is present.
        """
        if ("image_file" not in video_item) and ("video_file" not in video_item):
            raise KeyError("Key 'image_file' or 'video_file' not found in video_item")
        video_path = video_item.get('video_file', video_item.get('image_file'))
        frames, frame_indices = sample_video(
            video_path=video_path,
            frame_indices=video_item.get('frame_indices', None),
            start_frame=video_item.get('start_frame', None),
            end_frame=video_item.get('end_frame', None),
            n_frames=video_item.get('n_frames', None),
            time_indices=video_item.get('time_indices', None),
            start_time=video_item.get('start_time', None),
            end_time=video_item.get('end_time', None),
            sampling_fps=video_item.get('fps', None),
            mask_boxes=video_item.get('mask_boxes', None),
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy,
            return_frame_ids=True,
        )
        if self.video_sampling_strategy.get('use_multi_images_for_video', False):
            frames = [frame for frame in frames for _ in range(2)]
        if isinstance(frame_indices, dict):
            return {
                'frames': frames,
                'video_info': frame_indices
            }
        return {'frames': frames}

    def postprocess_text(self, content, data_dict, image_processing_config, first_image):
        """Rewrite ``content['text']`` in place for coordinate-bearing tasks.

        - ``has_coordinates`` and ``do_padding``: boxes are adjusted against
          ``first_image`` via ``adjust_bbox``.
        - task ``'image/OCR'`` with ``has_coordinates``: the text is passed
          through ``filter_ocr_polygon`` (drops too-small OCR regions).

        ``None`` config is treated as ``{}``.
        """
        if image_processing_config is None:
            image_processing_config = {}
        if image_processing_config.get('has_coordinates') and image_processing_config.get('do_padding'):
            content['text'] = adjust_bbox(content['text'], frame=first_image)
        if data_dict.get('task') == 'image/OCR' and image_processing_config.get('has_coordinates'):
            content['text'] = filter_ocr_polygon(content['text'])