import itertools
from dataclasses import dataclass
from typing import Dict, Sequence

import torch
import transformers
from torch.utils.data import default_collate
from transformers.trainer_pt_utils import LabelSmoother

# Label positions set to this value are ignored by the loss.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids`` with the tokenizer's pad token and ``labels`` with
    ``IGNORE_TOKEN_ID``, truncates both to ``model_max_length``, and passes
    optional vision fields through to the batch.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_TOKEN_ID,
        )
        # Truncate to the model's maximum sequence length.
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
        # Optional multimodal fields: concatenate image tensors along the
        # batch dimension; keep paths as a plain list.
        if "images" in instances[0]:
            images = [instance["images"] for instance in instances]
            batch["images"] = torch.cat(images, dim=0)
        if "doclm_images" in instances[0]:
            images = [instance["doclm_images"] for instance in instances]
            batch["doclm_images"] = torch.cat(images, dim=0)
        if "image_paths" in instances[0]:
            batch["image_paths"] = [instance["image_paths"] for instance in instances]
        if "pixel_values" in instances[0]:
            batch["pixel_values"] = torch.cat(
                [instance["pixel_values"] for instance in instances]
            )
        if "image_flags" in instances[0]:
            batch["image_flags"] = torch.cat(
                [instance["image_flags"] for instance in instances]
            )
        return batch
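# Usage sketch (illustrative only, not part of the training pipeline): how the
# collator above would typically be wired into a PyTorch DataLoader. The
# helper name and batch size are assumptions for the example; any map-style
# dataset yielding {"input_ids", "labels"} tensors works.
def _example_dataloader(dataset, tokenizer, batch_size: int = 8):
    from torch.utils.data import DataLoader

    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return DataLoader(dataset, batch_size=batch_size, collate_fn=collator)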
def collate_fn_deepspeed_old(batch):
    """Legacy DeepSpeed collator: concatenates image tensors across samples
    and rewrites ``image_indices`` so each entry records which batch element
    its images belong to."""
    keys = list(set().union(*[set(x.keys()) for x in batch]))
    tmp_batch = [{} for _ in range(len(batch))]
    if "actual_seq_len" in batch[0]:
        actual_seq_len = [x["actual_seq_len"] for x in batch]
    else:
        actual_seq_len = None
    # Pull image-related fields out before default_collate, which would
    # otherwise try to stack ragged tensors.
    for k in keys:
        if "images" in k or k == "image_indices":
            for x, y in zip(tmp_batch, batch):
                if k in y:
                    x[k] = y.pop(k)
    new_batch = default_collate(batch)
    for k in keys:
        if "images" in k or k == "image_indices":
            cat_dim = 1 if k == "image_indices" else 0
            if k == "image_indices":
                # Stamp a running sample counter into the first slot so the
                # indices stay valid after concatenation.
                cnt = 0
                for sample in tmp_batch:
                    if k in sample:
                        sample[k][0] = cnt
                        cnt += 1
            new_batch[k] = torch.cat([x[k] for x in tmp_batch if k in x], dim=cat_dim)
    if actual_seq_len is not None:
        # Offset each sample's cumulative sequence lengths by its position in
        # the packed batch (assumes every sample packs to the same length).
        seq_len = actual_seq_len[0][-1]
        actual_seq_len = [elem + i * seq_len for i, elem in enumerate(actual_seq_len)]
        new_batch["actual_seq_len"] = torch.cat(actual_seq_len)
    return new_batch


def collate_fn_deepspeed(batch):
    """DeepSpeed collator for multimodal batches: handles image and audio
    fields separately from ``default_collate`` and rewrites their index
    tensors to point at the correct batch element."""
    tmp_batch = [{} for _ in range(len(batch))]
    if "actual_seq_len" in batch[0]:
        actual_seq_len = [x["actual_seq_len"] for x in batch]
    else:
        actual_seq_len = None
    # Move variable-sized multimodal fields out of the samples so that
    # default_collate only sees stackable tensors. Assumes that if the first
    # sample carries a modality, every sample does.
    if "images" in batch[0]:
        for new_x, x in zip(tmp_batch, batch):
            new_x["images"] = x.pop("images")
            new_x["image_indices"] = x.pop("image_indices")
    if "audios" in batch[0]:
        for new_x, x in zip(tmp_batch, batch):
            new_x["audios"] = x.pop("audios")
            new_x["audio_indices"] = x.pop("audio_indices")
    new_batch = default_collate(batch)
    if "images" in tmp_batch[0]:
        new_batch["images"] = torch.cat([x["images"] for x in tmp_batch], dim=0)
        # Channel 0 of each index tensor holds the batch index of its sample.
        for sample_idx, sample in enumerate(tmp_batch):
            for j in range(len(sample["image_indices"])):
                sample["image_indices"][j][0, :, :] = sample_idx
        new_batch["image_indices"] = torch.cat(
            [x["image_indices"] for x in tmp_batch], dim=1
        )
    if "audios" in tmp_batch[0]:
        # Audio clips vary in length, so keep them as flat Python lists
        # rather than concatenated tensors.
        new_batch["audios"] = list(
            itertools.chain.from_iterable([x["audios"] for x in tmp_batch])
        )
        for sample_idx, sample in enumerate(tmp_batch):
            for j in range(len(sample["audio_indices"])):
                sample["audio_indices"][j][0, :, :] = sample_idx
        new_batch["audio_indices"] = list(
            itertools.chain.from_iterable([x["audio_indices"] for x in tmp_batch])
        )
    if actual_seq_len is not None:
        # Offset per-sample cumulative sequence lengths into batch-wide ones
        # (assumes every sample packs to the same total length).
        seq_len = actual_seq_len[0][-1]
        actual_seq_len = [elem + i * seq_len for i, elem in enumerate(actual_seq_len)]
        new_batch["actual_seq_len"] = torch.cat(actual_seq_len)
    return new_batch
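# Sanity-check sketch (illustrative): demonstrates the "actual_seq_len"
# offsetting performed by collate_fn_deepspeed. The sample contents below are
# assumptions for the demo; each sample provides cumulative sequence lengths
# for a packed sequence of total length 8.
if __name__ == "__main__":
    sample_a = {
        "input_ids": torch.zeros(8, dtype=torch.long),
        "actual_seq_len": torch.tensor([3, 8]),
    }
    sample_b = {
        "input_ids": torch.ones(8, dtype=torch.long),
        "actual_seq_len": torch.tensor([5, 8]),
    }
    demo = collate_fn_deepspeed([sample_a, sample_b])
    # Offsets accumulate across samples: tensor([3, 8, 13, 16]).
    print(demo["actual_seq_len"])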