Tarsier2-7b / dataset /custom_data_parsers /multi_images_parser.py
omni-research's picture
update to tarsier2-7b-0115
dcd4560
from typing import Dict, List
import random
import re
from PIL import Image
from .utils import sample_video, read_image
class MultiImagesParser:
def __init__(
self,
n_frames=8,
is_training=True,
):
self.n_frames = n_frames
self.is_training = is_training
# fmt: off
self.data_temp = {
"text": [
[{
"prompt": "Describe the image in short.",
"response": "A rollerblader rides high in a full pipe while others watch"
}],
[{
"prompt": "Describe the image in short.",
"response": "A woman in winter clothes is on the sidewalk with a phone."
}]
],
"image": [
{
"image_file": "/mnt/bn/videonaslq/images/flickr30k/images/3371533654.jpg"
},
{
"image_file": "/mnt/bn/videonaslq/images/coco/train2014/COCO_train2014_000000177950.jpg"
},
{
"video_file": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4",
"frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
}
],
"dataset": "coco",
"task": "multi_images",
"image_processing_config": {},
}
# fmt: on
def check_format(self, data_dict: Dict, image_processing_config: Dict):
assert data_dict['dataset'] in ['coco', 'sharegpt4v_cap100k', 'sharegpt4v_mix665k', 'webvid', 'movie'], data_dict
# 目前多图数据应该没有包含坐标的数据吧
if image_processing_config.get('has_coordinates', False):
raise ValueError(f'do_crop and has_coordinates cannot be True at the same time in MultiImagesParser!')
# 检查是否能匹配到坐标
texts = data_dict['text']
for text in texts:
match = re.search(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', text['prompt'] + text['response'])
if match:
print(f'[Warning] 疑似检测到包含坐标的数据:{data_dict}')
def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
self.check_format(data_dict, image_processing_config)
# shuffle
texts = data_dict['text']
images = data_dict['image']
images = self.load_images(images)
idxs = list(range(len(texts)))
random.shuffle(idxs)
texts = [texts[i] for i in idxs]
images = [images[i] for i in idxs]
# sample n_frames
if isinstance(self.n_frames, int):
n_frames = random.choice(list(range(1, self.n_frames + 1)))
else:
n_frames = random.choice(self.n_frames)
texts = texts[: n_frames]
images = images[: n_frames]
dataset = data_dict['dataset']
if dataset in ['coco', 'sharegpt4v_cap100k', 'webvid', 'movie']:
prompt, response = self.transform_for_caption_task(texts, dataset, images)
else:
prompt, response = self.transform_for_qa_task(texts, dataset, images)
messages = [
{
"role": "user",
"content": [
*[{"type": "image", "image": img} for img in images],
{"type": "text", "text": prompt},
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": response}
]
}
]
return messages
def transform_for_caption_task(self, texts, dataset, images):
idx = random.choice(list(range(len(texts))))
if dataset == 'coco':
if len(texts) == 1:
prompt = 'Describe the image in short.'
else:
prompt = f'Describe the images starting from frame {idx + 1} in short in order.'
elif dataset == 'sharegpt4v_cap100k':
if len(texts) == 1:
prompt = 'Describe the image in detail.'
else:
prompt = f'Describe the images starting from frame {idx + 1} in detail in order.'
else:
if len(texts) == 1:
prompt = 'Describe the image.'
else:
prompt = f'Describe the images starting from frame {idx + 1} in order.'
response = ''
for i, text in enumerate(texts):
if i < idx:
continue
if not isinstance(text, dict):
text = random.choice(text)
resp = text['response']
response += f'{resp}\n'
return prompt, response
def transform_for_qa_task(self, texts, dataset, images):
prompt, response = '', ''
for i, text in enumerate(texts):
if not isinstance(text, dict):
text = random.choice(text)
if len(texts) > 1:
prompt += f'Question for frame {i+1}:\n' + text['prompt'] + '\n'
response += f'Answer to question of frame {i+1}:\n' + text['response'] + '\n'
else:
prompt += text['prompt'] + '\n'
response += text['response'] + '\n'
return prompt, response
def load_images(self, image_items: List[Dict]) -> List[Image.Image]:
"""
image_items: List[Dict]. each item like:
{"video_file": "path/to/video", "frame_indices": [1]}
or
{"image_file": "path/to/image"}
"""
if image_items is None:
raise ValueError(f'image_items is None!')
if isinstance(image_items, dict):
image_items = [image_items]
images = []
for image_item in image_items:
if 'video_file' in image_item:
file_key = 'video_file'
elif 'image_file' in image_item:
file_key = 'image_file'
else:
raise KeyError(f'video_file or image_file not in {image_item}')
file_path = image_item[file_key]
if file_key == 'video_file':
frame_indices = image_item.get('frame_indices', None)
if frame_indices is None:
raise ValueError(f'read 0 frame: {image_item}')
if isinstance(frame_indices, int):
frame_indices = [frame_indices]
frames = sample_video(file_path, frame_indices = frame_indices)
images.extend(frames)
else:
if isinstance(file_path, str):
file_path = [file_path]
images.extend([read_image(f) for f in file_path])
return images
if __name__ == '__main__':
# python3 -m xenon_generation.data.custom_data_parsers.multi_images_parser
from tqdm import tqdm
from tools.rw_utils import read_jsonlines
lines = read_jsonlines('/mnt/bn/videonaslq/VideoCaption/datasets_1009/sharegpt4v_cap100k/part_36.jsonl')
lines = lines[:10]
parser = MultiImagesParser(n_frames=8)
for i, l in tqdm(enumerate(lines)):
l_image_processing_config = l.get('image_processing_config', {})
messages = parser.transform(l, l_image_processing_config)
print(messages)