import copy
import json
import logging
import os
import pdb
import re
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother

from .dataset_base import BaseDataset

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class MistralDataset(BaseDataset):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(
            *args,
            **kwargs,
        )
        # The default system prompt ("You are a helpful AI assistant.") is currently disabled.
        self.system_message = None
        self.ret = defaultdict(dict)
        self.is_cat = True

    def maybe_init_ret(self, source, force=False):
        # (Re)initialize the packing buffer for `source`; returns True if the buffer is empty.
        if source not in self.ret or force:
            self.ret[source] = {}
            self.ret[source]["tokens"] = []
            self.ret[source]["labels"] = []
            self.ret[source]["actual_seq_len"] = []
            if self.create_position_ids:
                self.ret[source]["position_ids"] = []
            if self.create_attention_mask:
                self.ret[source]["attention_mask"] = []
            if self.create_attention_mask_2d:
                self.ret[source]["attention_mask_2d"] = torch.tril(
                    torch.ones(
                        (1, self.max_padding_length, self.max_padding_length),
                        dtype=torch.bool,
                    )
                )
        return len(self.ret[source]["tokens"]) == 0

    def __getitem__(self, index):
        while True:
            try:
                self.processed_samples += 1
                if self.processed_samples < self.skip_samples:
                    if self.processed_samples % 1e3 == 0:
                        print(
                            f"processed_samples {self.processed_samples} skip_samples {self.skip_samples}"
                        )
                    return {}

                source = self.raw_data[index]["source"]
                is_empty = self.maybe_init_ret(source)
                is_begin = is_empty or self.reset_position_ids or self.reset_attention_mask

                ret = preprocess(
                    self.raw_data[index],
                    self.tokenizer,
                    self.image_token_length,
                    system_message=self.system_message,
                    image_processor=self.processor["image"],
                    is_begin=is_begin,
                    max_num_frame=self.max_num_frame,
                    max_fps=self.max_fps,
                )
                if ret is None:
                    return {}
                if len(ret["input_ids"]) > self.max_padding_length:
                    return {}
                if len(ret["images"]) > self.max_num_frame:
                    return {}

                self.unjoint_samples += 1
                # Log less frequently when packing very long contexts.
                if self.max_padding_length > 2**17:
                    log_interval = 1e3
                else:
                    log_interval = 1e4
                if self.unjoint_samples % log_interval == 0:
                    print(
                        f"processed_samples {self.processed_samples} unjoint_samples {self.unjoint_samples}"
                    )

                all_length = len(self.ret[source]["tokens"])
                cur_length = len(ret["input_ids"])

                # If appending this sample would overflow the token or image budget,
                # emit the current pack and start a new one for this source.
                if all_length + cur_length > self.max_padding_length or (
                    "images" in self.ret[source]
                    and len(self.ret[source]["images"]) + len(ret["images"]) > self.max_num_image
                ):
                    to_ret = copy.deepcopy(self.ret[source])
                    self.maybe_init_ret(source, force=True)
                else:
                    to_ret = {}

                all_length = len(self.ret[source]["tokens"])
                cur_length = len(ret["input_ids"])

                if "images" in ret and len(ret["images"]) > 0:
                    if "images" in self.ret[source]:
                        self.ret[source]["images"] = torch.cat(
                            [self.ret[source]["images"], ret["images"]], dim=0
                        )
                        # Shift image indices by the number of tokens already in the pack.
                        ret["image_indices"][1, :, :] += all_length
                        self.ret[source]["image_indices"] = torch.cat(
                            [self.ret[source]["image_indices"], ret["image_indices"]], dim=1
                        )
                    else:
                        self.ret[source]["images"] = ret["images"]
                        self.ret[source]["image_indices"] = ret["image_indices"]

                if self.create_attention_mask:
                    self.ret[source]["attention_mask"] += ret["attention_mask"]
                if self.create_attention_mask_2d:
                    # Block attention from the newly appended sample back to earlier samples.
                    self.ret[source]["attention_mask_2d"][:, all_length:, :all_length] = 0
                if self.create_position_ids:
                    self.ret[source]["position_ids"] += list(range(cur_length))

                self.ret[source]["tokens"] += ret["input_ids"]
                self.ret[source]["labels"] += ret["labels"]
                self.ret[source]["actual_seq_len"] += [all_length + cur_length]

                if "tokens" in to_ret:
                    if self.create_position_ids:
                        if not self.reset_position_ids:
                            to_ret["position_ids"] = list(range(len(to_ret["tokens"])))
                    if self.create_attention_mask_2d:
                        if not self.reset_attention_mask:
                            to_ret["attention_mask_2d"] = torch.tril(
                                torch.ones(
                                    (1, self.max_padding_length, self.max_padding_length),
                                    dtype=torch.bool,
                                )
                            )

                    if self.shift_token:
                        to_ret["tokens"] = to_ret["tokens"][:-1]
                        to_ret["labels"] = to_ret["labels"][1:]
                        to_ret["actual_seq_len"][-1] -= 1
                        if self.create_position_ids:
                            to_ret["position_ids"] = to_ret["position_ids"][:-1]
                        if self.create_attention_mask:
                            to_ret["attention_mask"] = to_ret["attention_mask"][:-1]
                        if self.create_attention_mask_2d:
                            to_ret["attention_mask_2d"][:, :, -1] = 0
                            to_ret["attention_mask_2d"][:, -1, :] = 0

                    assert len(to_ret["tokens"]) == len(
                        to_ret["labels"]
                    ), f"{len(to_ret['tokens'])} {len(to_ret['labels'])}"

                    # Pad to max_padding_length when fixed-length batches are required.
                    if not self.variable_length and self.max_padding_length > len(to_ret["tokens"]):
                        to_ret["tokens"] += [self.tokenizer.pad_token_id] * (
                            self.max_padding_length - len(to_ret["tokens"])
                        )
                        to_ret["labels"] += [IGNORE_TOKEN_ID] * (
                            self.max_padding_length - len(to_ret["labels"])
                        )
                        to_ret["actual_seq_len"][-1] = self.max_padding_length
                        if self.create_position_ids:
                            to_ret["position_ids"] += to_ret["position_ids"][-1:] * (
                                self.max_padding_length - len(to_ret["position_ids"])
                            )
                        if self.create_attention_mask:
                            to_ret["attention_mask"] += [0] * (
                                self.max_padding_length - len(to_ret["attention_mask"])
                            )

                    # Truncate everything to max_padding_length.
                    to_ret["tokens"] = to_ret["tokens"][: self.max_padding_length]
                    to_ret["labels"] = to_ret["labels"][: self.max_padding_length]
                    to_ret["actual_seq_len"][-1] = self.max_padding_length
                    if self.create_position_ids:
                        to_ret["position_ids"] = to_ret["position_ids"][: self.max_padding_length]
                    if self.create_attention_mask:
                        to_ret["attention_mask"] = to_ret["attention_mask"][
                            : self.max_padding_length
                        ]

                    to_ret["tokens"] = torch.tensor(to_ret["tokens"], dtype=torch.int64)
                    to_ret["labels"] = torch.tensor(to_ret["labels"], dtype=torch.int64)
                    to_ret["actual_seq_len"] = torch.tensor(
                        to_ret["actual_seq_len"], dtype=torch.int64
                    )
                    if self.create_position_ids:
                        to_ret["position_ids"] = torch.tensor(
                            to_ret["position_ids"], dtype=torch.int64
                        )
                    if self.create_attention_mask:
                        to_ret["attention_mask"] = torch.tensor(
                            to_ret["attention_mask"], dtype=torch.int64
                        )
                    if self.create_attention_mask_2d:
                        # Zero out padded positions in the causal mask, then invert so that
                        # True marks positions that must NOT be attended to.
                        attention_mask_2d = to_ret.pop("attention_mask_2d")
                        attention_mask_2d = attention_mask_2d.masked_fill(
                            (to_ret["attention_mask"] < 0.5).view(1, 1, self.max_padding_length),
                            value=0,
                        )
                        attention_mask_2d = attention_mask_2d < 0.5
                        to_ret["attention_mask"] = attention_mask_2d
                    if self.create_loss_mask:
                        loss_mask = torch.where(to_ret["labels"] == IGNORE_TOKEN_ID, 0, 1)
                        to_ret["loss_mask"] = loss_mask.to(torch.float32)
                    if not self.reset_position_ids and not self.reset_attention_mask:
                        to_ret.pop("actual_seq_len")

                # An empty dict means the current pack is not full yet.
                return to_ret

            except Exception as error:
                with open(os.path.join(self.output_dir, "dataset_error.log"), "a") as f:
                    print(error, file=f)
                    print([self.raw_data[index]], file=f)
                # Retry with a neighboring sample.
                if index == 0:
                    index += 1
                else:
                    index -= 1
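
# Usage sketch (illustrative only; the constructor arguments come from BaseDataset
# and the names below are assumptions, not part of this module):
#
#   dataset = MistralDataset(...)   # configured via BaseDataset kwargs
#   sample = dataset[i]
#   # `sample` is {} while the pack for that sample's source is still being filled.
#   # Once adding a sample would exceed max_padding_length tokens (or max_num_image
#   # images), the completed pack is returned instead, with "tokens", "labels",
#   # optional "position_ids" / "attention_mask" / "loss_mask", and "actual_seq_len"
#   # tensors, each padded or truncated to max_padding_length.
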
def preprocess(
    source,
    tokenizer: transformers.PreTrainedTokenizer,
    image_token_length: int,
    system_message: str = "You are a helpful assistant.",
    image_processor=None,
    is_begin: bool = True,
    max_num_frame: int = 8,
    max_fps: int = 1,
) -> Dict:
    # Target Mistral chat format:
    # [INST]Hello, how are you?[/INST]I'm doing great. How can I help you today?[INST]I'd like to show off how chat templating works![/INST]
    from ..constants import (
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
    )

    image_tag = ""
    video_tag = "