# Tarsier2-7b/dataset/custom_data_parsers/standard_vision_parser.py
from typing import Dict, List
from PIL import Image
import random
from .utils import sample_video, read_image, adjust_bbox, filter_ocr_polygon


class VisionParser:
    def __init__(
        self,
        n_frames=8,
        max_n_frames=256,
        is_training=True,
        video_sampling_strategy=None,
    ):
        self.n_frames = n_frames
        self.max_n_frames = max_n_frames
        self.is_training = is_training
        # Use None as the default to avoid sharing one mutable dict across instances.
        self.video_sampling_strategy = video_sampling_strategy or {}
# fmt: off
self.data_temp = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Describe the image and the video."},
                        # Supported image formats:
{"type": "image", "image": {"image_file": "/path/to/image"}},
{"type": "image", "image": {"video_file": "/path/to/video", "frame_indices": 0}},
                        # Supported video formats:
{"type": "video", "video": {"video_file": "/path/to/video"}},
{"type": "video", "video": {"video_file": "/path/to/video", "frame_indices": [0, 1, 2]}},
{"type": "video", "video": {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100}},
{"type": "video", "video": {"video_file": "/path/to/video", "time_indices": [0, 1, 2]}},
{"type": "video", "video": {"video_file": "/path/to/video", "start_time": 0, "end_time": 100}},
{"type": "video", "video": {"image_file": ["/path/to/image"]}, "frame_indices": [0, 1, 2]},
]
},
{
"role": "assistant",
"content": [
{"type": "text","text": "xxx"}
]
}
],
"dataset": "LSMDC",
"task": "video/caption"
}
# fmt: on
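        # A minimal usage sketch (hypothetical paths/config; transform() is defined below):
        #   parser = VisionParser(n_frames=8, is_training=False)
        #   messages = parser.transform(
        #       {"messages": [...], "dataset": "LSMDC", "task": "video/caption", "extra_info": {}},
        #       image_processing_config={},
        #   )
        # Afterwards each 'image' content holds a PIL.Image and each 'video' content a List[PIL.Image].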
def check_format(self, data_dict: Dict, image_processing_config: Dict):
if image_processing_config.get('do_crop', False) and image_processing_config.get('has_coordinates', False):
            raise ValueError('do_crop and has_coordinates cannot be True at the same time!')
"""
1. 将 messages 中的 image/video 替换成相应的 PIL.Image/List[PIL.Image]
2. text 的特殊处理:调整 box;过滤面积太小的OCR
"""
def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
self.check_format(data_dict, image_processing_config)
self.set_n_frames(data_dict)
        first_image = None  # ugly! Only image tasks need bbox adjustment / tiny-OCR filtering.
for msg in data_dict['messages']:
if isinstance(msg['content'], dict):
msg['content'] = [msg['content']]
for content in msg['content']:
if content['type'] == 'image':
content['image'] = self.load_image_item(content['image'])
if first_image is None:
first_image = content['image']
elif content['type'] == 'video':
video = self.load_video_item(content['video'])
content['video'] = video.pop('frames')
                    if video:
                        # setdefault guards against a missing 'extra_info' key.
                        data_dict.setdefault('extra_info', {})['frame_disturb_info'] = video.pop('video_info', {})
elif content['type'] == 'text':
pass
else:
raise ValueError(f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']")
for msg in data_dict['messages']:
for content in msg['content']:
if content['type'] == 'text':
self.postprocess_text(content, data_dict, image_processing_config, first_image)
return data_dict['messages']
# set n_frames for each vision item.
def set_n_frames(self, data_dict):
if isinstance(self.n_frames, int):
n_frames = self.n_frames
else:
n_frames = random.choice(self.n_frames)
assert n_frames <= self.max_n_frames
curr_n_frames = 0
has_dynamic = False
for msg in data_dict['messages']:
if isinstance(msg['content'], dict):
msg['content'] = [msg['content']]
for content in msg['content']:
if content['type'] == 'image':
curr_n_frames += 1
elif content['type'] == 'video':
if 'frame_indices' in content['video']:
curr_n_frames += len(content['video']['frame_indices'])
content['video']['n_frames'] = len(content['video']['frame_indices'])
elif 'time_indices' in content['video']:
curr_n_frames += len(content['video']['time_indices'])
content['video']['n_frames'] = len(content['video']['time_indices'])
elif 'min_n_frames' in content['video']:
content['video']['min_n_frames'] = int(content['video']['min_n_frames'])
curr_n_frames += content['video']['min_n_frames']
content['video']['n_frames'] = content['video']['min_n_frames']
has_dynamic = True
elif 'fps' in content['video']:
content['video']['n_frames'] = self.max_n_frames
curr_n_frames += self.max_n_frames
has_dynamic = True
else:
content['video']['n_frames'] = 0
has_dynamic = True
        # Hand out the remaining frame budget to dynamic videos, one frame per pass.
        while curr_n_frames < n_frames and has_dynamic:
for msg in data_dict['messages']:
for content in msg['content']:
if content['type'] == 'video':
if 'frame_indices' in content['video']:
pass
elif 'time_indices' in content['video']:
pass
else:
if curr_n_frames < n_frames:
content['video']['n_frames'] += 1
curr_n_frames += 1
        # If the total overshoots max_n_frames, take frames back from dynamic videos.
        while curr_n_frames > self.max_n_frames and has_dynamic:
for msg in data_dict['messages']:
for content in msg['content']:
if content['type'] == 'video':
if 'frame_indices' in content['video']:
pass
elif 'time_indices' in content['video']:
pass
else:
if curr_n_frames > self.max_n_frames:
content['video']['n_frames'] -= 1
curr_n_frames -= 1
        # Round each dynamic video's n_frames up to a multiple of 'force_frames_n_divisible'.
        for msg in data_dict['messages']:
for content in msg['content']:
if content['type'] == 'video':
if 'frame_indices' in content['video']:
pass
elif 'time_indices' in content['video']:
pass
else:
n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
if n > 1 and content['video']['n_frames'] % n != 0:
content['video']['n_frames'] += n - content['video']['n_frames'] % n
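    # Worked example (hypothetical): with n_frames=8, one image, and one video given only
    # by 'video_file', the image takes 1 frame and the dynamic video starts at n_frames=0;
    # the first while-loop then tops the video up one frame per pass until the budget is
    # met, leaving it with n_frames=7. With force_frames_n_divisible=4, the final loop
    # would round 7 up to 8.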
def load_image_item(self, image_item) -> Image.Image:
"""
image_item:
{"image_file": {"lq": "/path/to/image"}}
{"video_file": {"lq": "/path/to/video"}, "frame_indices": 0}
"""
# check format
if ("image_file" not in image_item) and ("video_file" not in image_item):
raise KeyError(f"Key 'image_file' or 'video_file' not found in image_item")
if 'image_file' in image_item:
if not isinstance(image_item['image_file'], str):
raise ValueError(f"{image_item['image_file']} is not a str!")
        if 'video_file' in image_item:
            if not isinstance(image_item.get('frame_indices'), int):
                raise ValueError(f"{image_item.get('frame_indices')} is not an int!")
if 'image_file' in image_item:
image = read_image(image_item['image_file'])
else:
frame_indices = [image_item['frame_indices']]
            image = sample_video(image_item['video_file'], frame_indices=frame_indices)[0]
return image
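    # Example calls (hypothetical paths):
    #   parser.load_image_item({"image_file": "/path/to/image"})                      # -> PIL.Image
    #   parser.load_image_item({"video_file": "/path/to/video", "frame_indices": 0})  # -> PIL.Image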
    def load_video_item(self, video_item) -> Dict:
        """
        video_item:
            {"video_file": "/path/to/video", "n_frames": 8}
            {"video_file": "/path/to/video", "frame_indices": [0, 1, 2], "n_frames": 3}
            {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100, "n_frames": 8}
            {"video_file": "/path/to/video", "time_indices": [0, 1, 2], "n_frames": 3}
            {"video_file": "/path/to/video", "start_time": 0, "end_time": 100, "n_frames": 8}
            {"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2], "n_frames": 3}

        Returns {'frames': List[PIL.Image]}, plus 'video_info' when frame-disturbance info is available.
        """
# check format
if ("image_file" not in video_item) and ("video_file" not in video_item):
raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item")
video_path = video_item.get('video_file', video_item.get('image_file'))
n_frames = video_item.get('n_frames', None)
frame_indices = video_item.get('frame_indices', None)
start_frame = video_item.get('start_frame', None)
end_frame = video_item.get('end_frame', None)
time_indices = video_item.get('time_indices', None)
start_time = video_item.get('start_time', None)
end_time = video_item.get('end_time', None)
mask_boxes = video_item.get('mask_boxes', None)
fps = video_item.get('fps', None)
frames, frame_indices = sample_video(
video_path=video_path,
frame_indices=frame_indices,
start_frame=start_frame,
end_frame=end_frame,
n_frames=n_frames,
time_indices=time_indices,
start_time=start_time,
end_time=end_time,
sampling_fps=fps,
mask_boxes=mask_boxes,
is_training=self.is_training,
video_sampling_strategy=self.video_sampling_strategy,
return_frame_ids=True,
)
        if self.video_sampling_strategy.get('use_multi_images_for_video', False):
            # Duplicate every frame (each appears twice) when the video is consumed
            # as a sequence of separate images.
            new_frames = []
            for f in frames:
                new_frames.extend([f, f])
            frames = new_frames
        # sample_video returns a dict of frame info (rather than a list of indices)
        # when frame disturbance was applied; pass it along as video_info.
        if isinstance(frame_indices, dict):
return {
'frames': frames,
'video_info': frame_indices
}
return {'frames': frames}
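    # Example call (hypothetical path), sampling 8 frames between 0 s and 100 s:
    #   out = parser.load_video_item({"video_file": "/path/to/video",
    #                                 "start_time": 0, "end_time": 100, "n_frames": 8})
    #   frames = out['frames']  # List[PIL.Image]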
def postprocess_text(self, content, data_dict, image_processing_config, first_image):
if image_processing_config.get('has_coordinates') and image_processing_config.get('do_padding'):
content['text'] = adjust_bbox(content['text'], frame=first_image)
if data_dict.get('task') == 'image/OCR' and image_processing_config.get('has_coordinates'):
content['text'] = filter_ocr_polygon(content['text'])
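

if __name__ == "__main__":
    # Hedged smoke test: set_n_frames only rewrites the dict in place, so it runs without
    # any real media files. Assumes the module is executed as part of its package
    # (e.g. `python -m dataset.custom_data_parsers.standard_vision_parser`) so the
    # relative import above resolves.
    parser = VisionParser(n_frames=8, is_training=False)
    sample = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe the video."},
                    {"type": "video", "video": {"video_file": "/path/to/video"}},
                ],
            }
        ],
    }
    parser.set_n_frames(sample)
    print(sample["messages"][0]["content"][1]["video"]["n_frames"])  # -> 8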