from typing import Dict, List from PIL import Image import random from .utils import sample_video, read_image, adjust_bbox, filter_ocr_polygon class VisionParser: def __init__( self, n_frames=8, max_n_frames=256, is_training=True, video_sampling_strategy={}, ): self.n_frames = n_frames self.max_n_frames = max_n_frames self.is_training = is_training self.video_sampling_strategy = video_sampling_strategy # fmt: off self.data_temp = { "messages": [ { "role": "user", "content": [ {"type": "text", "text": "Describe the image and the video."}, # 支持的 image 格式: {"type": "image", "image": {"image_file": "/path/to/image"}}, {"type": "image", "image": {"video_file": "/path/to/video", "frame_indices": 0}}, # 支持的 video 格式: {"type": "video", "video": {"video_file": "/path/to/video"}}, {"type": "video", "video": {"video_file": "/path/to/video", "frame_indices": [0, 1, 2]}}, {"type": "video", "video": {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100}}, {"type": "video", "video": {"video_file": "/path/to/video", "time_indices": [0, 1, 2]}}, {"type": "video", "video": {"video_file": "/path/to/video", "start_time": 0, "end_time": 100}}, {"type": "video", "video": {"image_file": ["/path/to/image"]}, "frame_indices": [0, 1, 2]}, ] }, { "role": "assistant", "content": [ {"type": "text","text": "xxx"} ] } ], "dataset": "LSMDC", "task": "video/caption" } # fmt: on def check_format(self, data_dict: Dict, image_processing_config: Dict): if image_processing_config.get('do_crop', False) and image_processing_config.get('has_coordinates', False): raise ValueError(f'do_crop and has_coordinates cannot be True at the same time!') """ 1. 将 messages 中的 image/video 替换成相应的 PIL.Image/List[PIL.Image] 2. text 的特殊处理:调整 box;过滤面积太小的OCR """ def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict: self.check_format(data_dict, image_processing_config) self.set_n_frames(data_dict) first_image = None # ugly! 需要调整box/过滤面积太小的OCR的数据只有图片任务 for msg in data_dict['messages']: if isinstance(msg['content'], dict): msg['content'] = [msg['content']] for content in msg['content']: if content['type'] == 'image': content['image'] = self.load_image_item(content['image']) if first_image is None: first_image = content['image'] elif content['type'] == 'video': video = self.load_video_item(content['video']) content['video'] = video.pop('frames') if video: data_dict['extra_info']['frame_disturb_info'] = video.pop('video_info', {}) elif content['type'] == 'text': pass else: raise ValueError(f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']") for msg in data_dict['messages']: for content in msg['content']: if content['type'] == 'text': self.postprocess_text(content, data_dict, image_processing_config, first_image) return data_dict['messages'] # set n_frames for each vision item. def set_n_frames(self, data_dict): if isinstance(self.n_frames, int): n_frames = self.n_frames else: n_frames = random.choice(self.n_frames) assert n_frames <= self.max_n_frames curr_n_frames = 0 has_dynamic = False for msg in data_dict['messages']: if isinstance(msg['content'], dict): msg['content'] = [msg['content']] for content in msg['content']: if content['type'] == 'image': curr_n_frames += 1 elif content['type'] == 'video': if 'frame_indices' in content['video']: curr_n_frames += len(content['video']['frame_indices']) content['video']['n_frames'] = len(content['video']['frame_indices']) elif 'time_indices' in content['video']: curr_n_frames += len(content['video']['time_indices']) content['video']['n_frames'] = len(content['video']['time_indices']) elif 'min_n_frames' in content['video']: content['video']['min_n_frames'] = int(content['video']['min_n_frames']) curr_n_frames += content['video']['min_n_frames'] content['video']['n_frames'] = content['video']['min_n_frames'] has_dynamic = True elif 'fps' in content['video']: content['video']['n_frames'] = self.max_n_frames curr_n_frames += self.max_n_frames has_dynamic = True else: content['video']['n_frames'] = 0 has_dynamic = True while curr_n_frames < n_frames and has_dynamic: for msg in data_dict['messages']: for content in msg['content']: if content['type'] == 'video': if 'frame_indices' in content['video']: pass elif 'time_indices' in content['video']: pass else: if curr_n_frames < n_frames: content['video']['n_frames'] += 1 curr_n_frames += 1 while curr_n_frames > self.max_n_frames and has_dynamic: for msg in data_dict['messages']: for content in msg['content']: if content['type'] == 'video': if 'frame_indices' in content['video']: pass elif 'time_indices' in content['video']: pass else: if curr_n_frames > self.max_n_frames: content['video']['n_frames'] -= 1 curr_n_frames -= 1 for msg in data_dict['messages']: for content in msg['content']: if content['type'] == 'video': if 'frame_indices' in content['video']: pass elif 'time_indices' in content['video']: pass else: n = self.video_sampling_strategy.get('force_frames_n_divisible', 1) if n > 1 and content['video']['n_frames'] % n != 0: content['video']['n_frames'] += n - content['video']['n_frames'] % n def load_image_item(self, image_item) -> Image.Image: """ image_item: {"image_file": {"lq": "/path/to/image"}} {"video_file": {"lq": "/path/to/video"}, "frame_indices": 0} """ # check format if ("image_file" not in image_item) and ("video_file" not in image_item): raise KeyError(f"Key 'image_file' or 'video_file' not found in image_item") if 'image_file' in image_item: if not isinstance(image_item['image_file'], str): raise ValueError(f"{image_item['image_file']} is not a str!") if 'video_file' in image_item: if not isinstance(image_item['frame_indices'], int): raise ValueError(f"{image_item['frame_indices']} is not a int!") if 'image_file' in image_item: image = read_image(image_item['image_file']) else: frame_indices = [image_item['frame_indices']] image = sample_video(image_item['video_file'], frame_indices = frame_indices)[0] return image def load_video_item(self, video_item) -> List[Image.Image]: """ video_item: {"video_file": {"lq": "/path/to/video"}, "n_frames": 8} {"video_file": {"lq": "/path/to/video"}, "frame_indices": [0, 1, 2], "n_frames": 3} {"video_file": {"lq": "/path/to/video"}, "start_frame": 0, "end_frame": 100, "n_frames": 8} {"video_file": {"lq": "/path/to/video"}, "time_indices": [0, 1, 2], "n_frames": 3} {"video_file": {"lq": "/path/to/video"}, "start_time": 0, "end_time": 100, "n_frames": 8} {"image_file": {"lq": ["/path/to/image"]}, "frame_indices": [0, 1, 2], "n_frames": 3} """ # check format if ("image_file" not in video_item) and ("video_file" not in video_item): raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item") video_path = video_item.get('video_file', video_item.get('image_file')) n_frames = video_item.get('n_frames', None) frame_indices = video_item.get('frame_indices', None) start_frame = video_item.get('start_frame', None) end_frame = video_item.get('end_frame', None) time_indices = video_item.get('time_indices', None) start_time = video_item.get('start_time', None) end_time = video_item.get('end_time', None) mask_boxes = video_item.get('mask_boxes', None) fps = video_item.get('fps', None) frames, frame_indices = sample_video( video_path=video_path, frame_indices=frame_indices, start_frame=start_frame, end_frame=end_frame, n_frames=n_frames, time_indices=time_indices, start_time=start_time, end_time=end_time, sampling_fps=fps, mask_boxes=mask_boxes, is_training=self.is_training, video_sampling_strategy=self.video_sampling_strategy, return_frame_ids=True, ) if self.video_sampling_strategy.get('use_multi_images_for_video', False): new_frames = [] for f in frames: new_frames.extend([f, f]) frames = new_frames if isinstance(frame_indices, dict): return { 'frames': frames, 'video_info': frame_indices } return {'frames': frames} def postprocess_text(self, content, data_dict, image_processing_config, first_image): if image_processing_config.get('has_coordinates') and image_processing_config.get('do_padding'): content['text'] = adjust_bbox(content['text'], frame=first_image) if data_dict.get('task') == 'image/OCR' and image_processing_config.get('has_coordinates'): content['text'] = filter_ocr_polygon(content['text'])