from typing import Dict, List, Optional

import random
import re

from PIL import Image

from .utils import sample_video, read_image


class MultiImagesParser:
    """Parses multi-image / multi-frame samples into chat-style messages.

    Each sample pairs a list of text entries (prompt/response dicts, or lists
    of candidate dicts) with a list of image/video items.  ``transform``
    shuffles the pairs, keeps a random number of them (up to ``n_frames``)
    and emits a two-turn message list: a user turn carrying the images plus
    the prompt, and an assistant turn carrying the response.
    """

    def __init__(
        self,
        n_frames=8,
        is_training=True,
    ):
        # n_frames: either an int (keep a random 1..n_frames pairs per sample)
        # or a sequence of candidate counts to choose from.
        self.n_frames = n_frames
        self.is_training = is_training
        # data_temp documents the expected input schema (not used at runtime).
        # fmt: off
        self.data_temp = {
            "text": [
                [{
                    "prompt": "Describe the image in short.",
                    "response": "A rollerblader rides high in a full pipe while others watch"
                }],
                [{
                    "prompt": "Describe the image in short.",
                    "response": "A woman in winter clothes is on the sidewalk with a phone."
                }]
            ],
            "image": [
                {
                    "image_file": "/mnt/bn/videonaslq/images/flickr30k/images/3371533654.jpg"
                },
                {
                    "image_file": "/mnt/bn/videonaslq/images/coco/train2014/COCO_train2014_000000177950.jpg"
                },
                {
                    "video_file": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4",
                    "frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
                }
            ],
            "dataset": "coco",
            "task": "multi_images",
            "image_processing_config": {},
        }
        # fmt: on

    def check_format(self, data_dict: Dict, image_processing_config: Dict) -> None:
        """Sanity-check one sample: known dataset, no coordinate annotations.

        Raises:
            AssertionError: if the dataset name is not one of the known sets.
            ValueError: if the processing config declares coordinates.
        """
        assert data_dict['dataset'] in ['coco', 'sharegpt4v_cap100k', 'sharegpt4v_mix665k', 'webvid', 'movie'], data_dict

        # Multi-image data is not expected to contain coordinate annotations.
        if image_processing_config.get('has_coordinates', False):
            raise ValueError('do_crop and has_coordinates cannot be True at the same time in MultiImagesParser!')

        # Heuristically detect coordinate-like substrings such as "[12, 34.5]"
        # in the prompt/response text and warn (do not raise).
        for text in data_dict['text']:
            # A text entry may be a single prompt/response dict or a list of
            # candidate dicts (see data_temp); inspect every candidate.
            candidates = [text] if isinstance(text, dict) else text
            for cand in candidates:
                match = re.search(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', cand['prompt'] + cand['response'])
                if match:
                    print(f'[Warning] 疑似检测到包含坐标的数据:{data_dict}')

    def transform(self, data_dict: Dict, image_processing_config: Optional[Dict] = None) -> List[Dict]:
        """Convert one raw sample into a two-turn chat message list.

        Returns a list of two message dicts: a "user" turn whose content is
        the selected images followed by the prompt text, and an "assistant"
        turn with the response text.
        """
        # Guard the documented default: check_format calls .get() on it.
        if image_processing_config is None:
            image_processing_config = {}
        self.check_format(data_dict, image_processing_config)

        # Shuffle texts and images with one shared permutation so the
        # (text, image) pairing is preserved.
        texts = data_dict['text']
        images = self.load_images(data_dict['image'])
        idxs = list(range(len(texts)))
        random.shuffle(idxs)
        texts = [texts[i] for i in idxs]
        # NOTE(review): len(images) can exceed len(texts) when a video item
        # expands into several frames; the extra images are silently dropped
        # here — confirm upstream data keeps the two lists aligned.
        images = [images[i] for i in idxs]

        # Randomly choose how many (text, image) pairs to keep.
        if isinstance(self.n_frames, int):
            n_frames = random.randint(1, self.n_frames)
        else:
            n_frames = random.choice(self.n_frames)
        texts = texts[:n_frames]
        images = images[:n_frames]

        dataset = data_dict['dataset']
        if dataset in ['coco', 'sharegpt4v_cap100k', 'webvid', 'movie']:
            prompt, response = self.transform_for_caption_task(texts, dataset, images)
        else:
            prompt, response = self.transform_for_qa_task(texts, dataset, images)

        messages = [
            {
                "role": "user",
                "content": [
                    *[{"type": "image", "image": img} for img in images],
                    {"type": "text", "text": prompt},
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": response}
                ]
            }
        ]
        return messages

    def transform_for_caption_task(self, texts, dataset, images):
        """Build one captioning prompt covering frames idx..end, in order.

        A random start index is drawn; the response concatenates the per-frame
        responses (one random candidate per list-type entry) from that index on.
        """
        idx = random.randrange(len(texts))
        if dataset == 'coco':
            if len(texts) == 1:
                prompt = 'Describe the image in short.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in short in order.'
        elif dataset == 'sharegpt4v_cap100k':
            if len(texts) == 1:
                prompt = 'Describe the image in detail.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in detail in order.'
        else:
            if len(texts) == 1:
                prompt = 'Describe the image.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in order.'

        response = ''
        for i, text in enumerate(texts):
            if i < idx:
                continue
            # A list-type entry holds candidate dicts; pick one at random.
            if not isinstance(text, dict):
                text = random.choice(text)
            resp = text['response']
            response += f'{resp}\n'
        return prompt, response

    def transform_for_qa_task(self, texts, dataset, images):
        """Concatenate per-frame QA pairs into one prompt and one response.

        With more than one frame, each pair is labeled with its frame number;
        with a single frame the prompt/response are used verbatim.
        """
        prompt, response = '', ''
        for i, text in enumerate(texts):
            # A list-type entry holds candidate dicts; pick one at random.
            if not isinstance(text, dict):
                text = random.choice(text)
            if len(texts) > 1:
                prompt += f'Question for frame {i+1}:\n' + text['prompt'] + '\n'
                response += f'Answer to question of frame {i+1}:\n' + text['response'] + '\n'
            else:
                prompt += text['prompt'] + '\n'
                response += text['response'] + '\n'
        return prompt, response

    def load_images(self, image_items: List[Dict]) -> List[Image.Image]:
        """Load PIL images from a list of image/video items.

        image_items: List[Dict].
        each item like: {"video_file": "path/to/video", "frame_indices": [1]}
        or {"image_file": "path/to/image"}

        Raises:
            ValueError: if image_items is None, or a video item has no frames.
            KeyError: if an item has neither video_file nor image_file.
        """
        if image_items is None:
            raise ValueError('image_items is None!')
        if isinstance(image_items, dict):
            # Accept a single item as a convenience.
            image_items = [image_items]

        images = []
        for image_item in image_items:
            if 'video_file' in image_item:
                frame_indices = image_item.get('frame_indices', None)
                if frame_indices is None:
                    raise ValueError(f'read 0 frame: {image_item}')
                if isinstance(frame_indices, int):
                    frame_indices = [frame_indices]
                # One video item expands into one image per frame index.
                frames = sample_video(image_item['video_file'], frame_indices=frame_indices)
                images.extend(frames)
            elif 'image_file' in image_item:
                file_path = image_item['image_file']
                if isinstance(file_path, str):
                    file_path = [file_path]
                images.extend([read_image(f) for f in file_path])
            else:
                raise KeyError(f'video_file or image_file not in {image_item}')
        return images


if __name__ == '__main__':
    # python3 -m xenon_generation.data.custom_data_parsers.multi_images_parser
    from tqdm import tqdm
    from tools.rw_utils import read_jsonlines

    lines = read_jsonlines('/mnt/bn/videonaslq/VideoCaption/datasets_1009/sharegpt4v_cap100k/part_36.jsonl')
    lines = lines[:10]
    parser = MultiImagesParser(n_frames=8)
    for i, l in tqdm(enumerate(lines)):
        l_image_processing_config = l.get('image_processing_config', {})
        messages = parser.transform(l, l_image_processing_config)
        print(messages)