"""Datamodule for Llava Pretraining and Finetuning""" | |
import os | |
import re | |
from PIL import Image | |
import numpy as np | |
import re | |
import tempfile | |
from typing import Dict, List, Union, Tuple | |
import traceback | |
import json | |
import torch | |
import torch.nn.functional as F | |
from transformers import DataCollatorForSeq2Seq | |
from tools.rw_utils import read_jsonlines | |
from torch.utils.data import Dataset, DataLoader | |
np_str_obj_array_pattern = re.compile(r"[SaUO]") | |
default_collate_err_msg_format = ( | |
"default_collate: batch must contain tensors, numpy arrays, numbers, " | |
"dicts or lists; found {}" | |
) | |
from .custom_data_parsers.standard_vision_parser import VisionParser | |
from .custom_data_parsers.object_tracking_parser import ObjectTrackingParser | |
from .custom_data_parsers.multi_images_parser import MultiImagesParser | |
from .custom_data_parsers.video_permutation_parser import VideoPermutationParser | |
from .custom_data_parsers.utils_visualize import visualize_image_bbox | |
from .tarsier_processor import TarsierProcessor | |
from tools.rw_utils import NumpyArrayEncoder | |
from .utils import DictToObject | |
import os | |
HF_TOKEN = os.environ.get('HF_TOKEN', '') | |
class TarsierDataProcessor:
    def __init__(
        self,
        processor: TarsierProcessor,
        n_frames: Union[int, list],
        max_n_frames=256,
        max_pixels=int(1280 * 720 // 2),
        min_pixels=0,
        max_seq_len=None,
        is_training=True,  # Affects: 1. frame sampling differs between training and inference; 2. the response is ignored at inference time.
        print_data_error=True,
        do_image_padding=False,
        do_image_crop=False,
        do_image_resize=True,
        video_sampling_strategy={},
        prompt='',
        train_task='sft',
        **kwargs
    ):
        self.kwargs = kwargs
        self.processor = processor
        self.pad_collator = DataCollatorForSeq2Seq(processor.tokenizer, padding='longest')
        self.processor.max_seq_len = processor.tokenizer.model_max_length if max_seq_len is None else max_seq_len
        self.n_frames = n_frames
        self.max_n_frames = max_n_frames
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.is_training = is_training
        self.print_data_error = print_data_error
        self.do_image_padding = do_image_padding
        self.do_image_crop = do_image_crop
        self.do_image_resize = do_image_resize
        self.video_sampling_strategy = video_sampling_strategy
        self.prompt = prompt
        self.train_task = train_task

        self.object_tracking_parser = ObjectTrackingParser(
            n_frames=self.n_frames,
            max_objects=4,
            is_training=self.is_training,
        )
        self.multi_images_parser = MultiImagesParser(
            n_frames=self.n_frames,
            is_training=self.is_training,
        )
        self.video_permutation_parser = VideoPermutationParser(
            n_frames=self.n_frames,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy,
        )
        self.vision_parser = VisionParser(
            n_frames=self.n_frames,
            max_n_frames=self.max_n_frames,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy,
        )

    def select_parser(self, data_dict):
        if data_dict.get('task', None) == 'video/object_tracking':
            return self.object_tracking_parser
        elif data_dict.get('task', None) == 'multi_images':
            return self.multi_images_parser
        elif data_dict.get('dataset', None) == 'video_permutation':
            return self.video_permutation_parser
        else:
            return self.vision_parser

    def parse_image_processing_config(self, data_dict):
        image_processing_config = data_dict.get('image_processing_config', {})
        do_padding = image_processing_config.get('do_padding', self.do_image_padding)
        do_crop = image_processing_config.get('do_crop', self.do_image_crop)
        do_resize = image_processing_config.get('do_resize', self.do_image_resize)
        max_pixels = image_processing_config.get('max_pixels', self.max_pixels)
        min_pixels = image_processing_config.get('min_pixels', self.min_pixels)
        assert min_pixels <= max_pixels
        image_processing_config['do_padding'] = do_padding
        image_processing_config['do_crop'] = do_crop
        image_processing_config['do_resize'] = do_resize
        image_processing_config['max_pixels'] = max_pixels
        image_processing_config['min_pixels'] = min_pixels
        return image_processing_config

    def _transform(self, raw_data_dict: Dict) -> List[Dict]:
        data_dict = json.loads(json.dumps(raw_data_dict, cls=NumpyArrayEncoder))
        del raw_data_dict
        if self.prompt:
            for msg in data_dict['messages']:
                if msg['role'] == 'user':
                    for content in msg['content']:
                        if content['type'] == 'text':
                            content['text'] = self.prompt
        data_dict_copy = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
        image_processing_config = self.parse_image_processing_config(data_dict)
        parser = self.select_parser(data_dict)
        messages = parser.transform(data_dict, image_processing_config)
        data_dict_copy['extra_info'] = data_dict.pop('extra_info', {})
        # visualize_image_bbox(data_dict, image_processing_config, self.processor)
        outputs = self.processor(messages, image_processing_config, is_training=self.is_training)
        # if not self.is_training:
        outputs['raw_data_dict'] = data_dict_copy
        return [outputs]

    def _split_chosen_rejected(self, data_dict: Dict):
        chosen_data_dict = data_dict
        rejected_data_dict = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
        for msg in chosen_data_dict['messages']:
            if msg['role'] == 'assistant':
                for content in msg['content']:
                    if content['type'] == 'text':
                        content['text'] = content['chosen']
        for msg in rejected_data_dict['messages']:
            if msg['role'] == 'assistant':
                for content in msg['content']:
                    if content['type'] == 'text':
                        content['text'] = content['rejected']
        return chosen_data_dict, rejected_data_dict

    def transform(self, data_dict: Dict) -> List[Dict]:
        try:
            if self.train_task == 'dpo':
                chosen_data_dict, rejected_data_dict = self._split_chosen_rejected(data_dict)
                return self._transform(chosen_data_dict) + self._transform(rejected_data_dict)
            return self._transform(data_dict)
        except Exception as e:
            if self.print_data_error:
                print(traceback.format_exc())
                print(f'Error occurs when processing: \n{data_dict}')
            return []

    def batch_transform(self, batch_data: List[Dict]) -> Dict:
        model_inputs = {}
        # if not self.is_training:
        raw_data_dict = [d.pop('raw_data_dict') for d in batch_data]
        model_inputs['raw_data_dict'] = raw_data_dict

        batch_pixel_values = [d.pop('pixel_values') for d in batch_data if 'pixel_values' in d]
        batch_image_grid_thw = [d.pop('image_grid_thw') for d in batch_data if 'image_grid_thw' in d]
        if len(batch_pixel_values) == 0:
            vision_placeholder = self.get_vision_placeholder()
            batch_pixel_values = [vision_placeholder.get('pixel_values')]
            batch_image_grid_thw = [vision_placeholder.get('image_grid_thw')] if 'image_grid_thw' in vision_placeholder else []
        model_inputs['pixel_values'] = torch.cat(batch_pixel_values, dim=0)
        if len(batch_image_grid_thw) > 0:
            model_inputs['image_grid_thw'] = torch.cat(batch_image_grid_thw, dim=0)

        batch_num_images = [d.pop('num_images') for d in batch_data]
        model_inputs['num_images'] = torch.tensor(batch_num_images)
        model_inputs.update(self.pad_collator(batch_data))
        return model_inputs

    def __call__(self, batch_data: Union[Dict, List[Dict]]) -> Dict:
        if isinstance(batch_data, dict):
            batch_data = [batch_data]
        batch = [self.transform(d)[0] for d in batch_data]
        return self.batch_transform(batch)

    def get_vision_placeholder(self):
        messages = [{"role": "user", "content": [{"type": "image", "image": Image.new(mode='RGB', size=(336, 336))}]}]
        image_processing_config = self.parse_image_processing_config({})
        return self.processor(messages, image_processing_config)

    def get_text_placeholder(self):
        messages = [
            {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "Thank you very much"}]},
        ]
        image_processing_config = self.parse_image_processing_config({})
        return self.processor(messages, image_processing_config)
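
# Note (added, illustrative): inferred from the parsing logic above, a single
# annotation passed to TarsierDataProcessor.__call__ looks roughly like the
# sketch below. Only 'messages', 'task', 'dataset', 'image_processing_config',
# and the 'chosen'/'rejected' fields are referenced in this module; the exact
# media payload format is handled by the parsers in custom_data_parsers.
#
#   {
#       "task": "...",                                  # optional; routes parser selection
#       "image_processing_config": {"max_pixels": ...}, # optional per-sample overrides
#       "messages": [
#           {"role": "user", "content": [
#               {"type": "image", "image": ...},
#               {"type": "text", "text": "Describe the image."},
#           ]},
#           {"role": "assistant", "content": [
#               {"type": "text", "text": "..."}         # for DPO, 'chosen'/'rejected' instead
#           ]},
#       ],
#   }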


def init_processor(processor: Union[TarsierProcessor, str] = None, config: Dict = None):
    config = DictToObject(config) if isinstance(config, dict) else config
    if isinstance(processor, str):
        sub_processor = TarsierProcessor.from_pretrained(
            processor,
            padding_side='left',
            trust_remote_code=True,
            token=HF_TOKEN,
        )
    else:
        sub_processor = processor
    processor = TarsierDataProcessor(
        processor=sub_processor,
        n_frames=config.n_frames,
        max_n_frames=config.max_n_frames,
        max_pixels=config.max_pixels,
        min_pixels=config.min_pixels,
        max_seq_len=config.max_seq_len,
        is_training=config.is_training,
        print_data_error=config.print_data_error,
        do_image_padding=config.do_image_padding,
        do_image_crop=config.do_image_crop,
        do_image_resize=config.do_image_resize,
        video_sampling_strategy=config.video_sampling_strategy,
        prompt=config.prompt,
        train_task=config.train_task,
    )
    return processor


class TarsierDataset(Dataset):
    def __init__(self, ann_path="", anns=None, config: Dict = None, processor: Union[TarsierDataProcessor, TarsierProcessor, str] = None):
        self.config = DictToObject(config) if isinstance(config, dict) else config
        if not isinstance(processor, TarsierDataProcessor):
            self.processor = init_processor(processor, config)
        else:
            self.processor = processor

        if anns is None:
            self.anns = []
            if isinstance(ann_path, str):
                ann_path = [ann_path]
            for path in ann_path:
                self.anns.extend(read_jsonlines(path))
        else:
            self.anns = anns

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, index):
        if index < 0 or index >= len(self.anns):
            raise IndexError("Index out of range")
        ann = self.anns[index]
        try:
            model_inputs = self.processor(ann)
        except Exception as e:
            print(f"Load data error: {e}")
            return ann, None
        return ann, model_inputs
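

# Usage sketch (added): a minimal example of wiring TarsierDataset into a
# DataLoader. The config keys mirror the arguments consumed by init_processor
# above; the processor and annotation paths below are hypothetical placeholders.
if __name__ == "__main__":
    example_config = {
        "n_frames": 8,
        "max_n_frames": 256,
        "max_pixels": int(1280 * 720 // 2),
        "min_pixels": 0,
        "max_seq_len": None,
        "is_training": True,
        "print_data_error": True,
        "do_image_padding": False,
        "do_image_crop": False,
        "do_image_resize": True,
        "video_sampling_strategy": {},
        "prompt": "",
        "train_task": "sft",
    }
    dataset = TarsierDataset(
        ann_path="path/to/annotations.jsonl",  # placeholder
        config=example_config,
        processor="path/to/processor",         # placeholder checkpoint/processor path
    )
    # __getitem__ already runs the processor on each sample, so a pass-through
    # collate that unwraps the single-item batch is enough here.
    loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch[0])
    for ann, model_inputs in loader:
        if model_inputs is None:
            continue  # sample failed to process; the error was printed above
        print({k: getattr(v, "shape", type(v)) for k, v in model_inputs.items()})
        break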