import random
import re
from typing import Dict, List

from PIL import Image

from .utils import sample_video, read_image


class MultiImagesParser:
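    """Parser that converts a multi-image/video sample into chat-style messages.

    Each sample pairs a list of text prompt/response items with a list of media
    items (image files, or a video file plus frame indices) and is transformed
    into a single user/assistant message pair.
    """
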
    def __init__(
        self,
        n_frames=8,
        is_training=True,
    ):
        self.n_frames = n_frames
        self.is_training = is_training
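        # Reference example of the expected input schema. The media paths below
        # point to internal datasets and are kept verbatim for illustration.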
        # fmt: off
        self.data_temp = {
            "text": [
                [{
                    "prompt": "Describe the image in short.",
                    "response": "A rollerblader rides high in a full pipe while others watch"
                }],
                [{
                    "prompt": "Describe the image in short.",
                    "response": "A woman in winter clothes is on the sidewalk with a phone."
                }]
            ],
            "image": [
                {
                    "image_file": "/mnt/bn/videonaslq/images/flickr30k/images/3371533654.jpg"
                },
                {
                    "image_file": "/mnt/bn/videonaslq/images/coco/train2014/COCO_train2014_000000177950.jpg"
                },
                {
                    "video_file": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4",
                    "frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
                }
            ],
            "dataset": "coco",
            "task": "multi_images",
            "image_processing_config": {},
        }
        # fmt: on

    def check_format(self, data_dict: Dict, image_processing_config: Dict):
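        """Validate the dataset name and reject samples that appear to contain coordinates."""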
        assert data_dict['dataset'] in ['coco', 'sharegpt4v_cap100k', 'sharegpt4v_mix665k', 'webvid', 'movie'], data_dict
        # Multi-image data is not expected to contain coordinate annotations.
        if image_processing_config.get('has_coordinates', False):
            raise ValueError('has_coordinates is not supported in MultiImagesParser!')
        # Heuristically check whether any prompt/response contains coordinate-like text.
        texts = data_dict['text']
        for text in texts:
            # Each text item may be a single dict or a list of candidate dicts.
            candidates = text if isinstance(text, list) else [text]
            for cand in candidates:
                match = re.search(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', cand['prompt'] + cand['response'])
                if match:
                    print(f'[Warning] data appears to contain coordinates: {data_dict}')

    def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> List[Dict]:
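        """Convert one raw sample into a [user, assistant] message pair."""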
        if image_processing_config is None:
            image_processing_config = {}
        self.check_format(data_dict, image_processing_config)
        # Shuffle texts and their corresponding media items together.
        texts = data_dict['text']
        images = data_dict['image']
        images = self.load_images(images)
        idxs = list(range(len(texts)))
        random.shuffle(idxs)
        texts = [texts[i] for i in idxs]
        images = [images[i] for i in idxs]
        # Sample how many text/image pairs to keep for this sample.
        if isinstance(self.n_frames, int):
            n_frames = random.choice(list(range(1, self.n_frames + 1)))
        else:
            n_frames = random.choice(self.n_frames)
        texts = texts[:n_frames]
        images = images[:n_frames]
        dataset = data_dict['dataset']
        if dataset in ['coco', 'sharegpt4v_cap100k', 'webvid', 'movie']:
            prompt, response = self.transform_for_caption_task(texts, dataset, images)
        else:
            prompt, response = self.transform_for_qa_task(texts, dataset, images)
        messages = [
            {
                "role": "user",
                "content": [
                    *[{"type": "image", "image": img} for img in images],
                    {"type": "text", "text": prompt},
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": response}
                ]
            }
        ]
        return messages

    def transform_for_caption_task(self, texts, dataset, images):
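        """Build a captioning prompt/response: pick a random start frame and
        concatenate the captions from that frame onward, in order."""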
        idx = random.choice(list(range(len(texts))))
        if dataset == 'coco':
            if len(texts) == 1:
                prompt = 'Describe the image in short.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in short in order.'
        elif dataset == 'sharegpt4v_cap100k':
            if len(texts) == 1:
                prompt = 'Describe the image in detail.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in detail in order.'
        else:
            if len(texts) == 1:
                prompt = 'Describe the image.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in order.'
        response = ''
        for i, text in enumerate(texts):
            if i < idx:
                continue
            # A text item may be a list of candidate captions; pick one at random.
            if not isinstance(text, dict):
                text = random.choice(text)
            resp = text['response']
            response += f'{resp}\n'
        return prompt, response

    def transform_for_qa_task(self, texts, dataset, images):
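        """Build a QA prompt/response; with multiple frames, each question and
        answer is tagged with its frame index."""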
        prompt, response = '', ''
        for i, text in enumerate(texts):
            if not isinstance(text, dict):
                text = random.choice(text)
            if len(texts) > 1:
                prompt += f'Question for frame {i + 1}:\n' + text['prompt'] + '\n'
                response += f'Answer to question of frame {i + 1}:\n' + text['response'] + '\n'
            else:
                prompt += text['prompt'] + '\n'
                response += text['response'] + '\n'
        return prompt, response

    def load_images(self, image_items: List[Dict]) -> List[Image.Image]:
        """
        image_items: List[Dict]. Each item looks like:
            {"video_file": "path/to/video", "frame_indices": [1]}
        or
            {"image_file": "path/to/image"}
        """
        if image_items is None:
            raise ValueError('image_items is None!')
        if isinstance(image_items, dict):
            image_items = [image_items]
        images = []
        for image_item in image_items:
            if 'video_file' in image_item:
                file_key = 'video_file'
            elif 'image_file' in image_item:
                file_key = 'image_file'
            else:
                raise KeyError(f'video_file or image_file not in {image_item}')
            file_path = image_item[file_key]
            if file_key == 'video_file':
                frame_indices = image_item.get('frame_indices', None)
                if frame_indices is None:
                    raise ValueError(f'no frame_indices given, would read 0 frames: {image_item}')
                if isinstance(frame_indices, int):
                    frame_indices = [frame_indices]
                frames = sample_video(file_path, frame_indices=frame_indices)
                images.extend(frames)
            else:
                # A single path or a list of image paths.
                if isinstance(file_path, str):
                    file_path = [file_path]
                images.extend([read_image(f) for f in file_path])
        return images


if __name__ == '__main__':
    # python3 -m xenon_generation.data.custom_data_parsers.multi_images_parser
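    # NOTE: the jsonl path below points to an internal dataset; point it at
    # your own file of samples matching data_temp to run this smoke test.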
    from tqdm import tqdm

    from tools.rw_utils import read_jsonlines

    lines = read_jsonlines('/mnt/bn/videonaslq/VideoCaption/datasets_1009/sharegpt4v_cap100k/part_36.jsonl')
    lines = lines[:10]
    parser = MultiImagesParser(n_frames=8)
    for i, l in tqdm(enumerate(lines)):
        l_image_processing_config = l.get('image_processing_config', {})
        messages = parser.transform(l, l_image_processing_config)
        print(messages)