import copy
import json
import logging
import os
import re
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Sequence
import numpy as np
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother
from .dataset_base import BaseDataset
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class MistralDataset(BaseDataset):
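    """Dataset that packs multiple preprocessed conversation samples from the same
    source into one fixed-length training sequence of tokens, labels, images and masks."""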
def __init__(
self,
*args,
**kwargs,
):
super().__init__(
*args,
**kwargs,
)
        # System message is disabled by default; set a string such as
        # "You are a helpful AI assistant." to prepend one.
        self.system_message = None
self.ret = defaultdict(dict)
self.is_cat = True
def maybe_init_ret(self, source, force=False):
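        """(Re)initialize the packing buffer for `source`; return True if the buffer is empty."""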
if source not in self.ret or force:
self.ret[source] = {}
self.ret[source]["tokens"] = []
self.ret[source]["labels"] = []
self.ret[source]["actual_seq_len"] = []
if self.create_position_ids:
self.ret[source]["position_ids"] = []
if self.create_attention_mask:
self.ret[source]["attention_mask"] = []
if self.create_attention_mask_2d:
self.ret[source]["attention_mask_2d"] = torch.tril(
torch.ones(
(1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
)
)
return len(self.ret[source]["tokens"]) == 0
def __getitem__(self, index):
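        """Preprocess the sample at `index`, append it to its source's packing buffer,
        and return a finished packed sequence once the buffer would overflow, else {}."""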
while True:
try:
self.processed_samples += 1
if self.processed_samples < self.skip_samples:
if self.processed_samples % 1e3 == 0:
print(
f"processed_samples {self.processed_samples} skip_samples {self.skip_samples}"
)
return {}
source = self.raw_data[index]["source"]
is_empty = self.maybe_init_ret(source)
is_begin = is_empty or self.reset_position_ids or self.reset_attention_mask
ret = preprocess(
self.raw_data[index],
self.tokenizer,
self.image_token_length,
system_message=self.system_message,
image_processor=self.processor["image"],
is_begin=is_begin,
max_num_frame=self.max_num_frame,
max_fps=self.max_fps,
)
if ret is None:
return {}
if len(ret["input_ids"]) > self.max_padding_length:
return {}
if len(ret["images"]) > self.max_num_frame:
return {}
self.unjoint_samples += 1
                if self.max_padding_length > 2**17:
                    log_interval = 1e3
                else:
                    log_interval = 1e4
if self.unjoint_samples % log_interval == 0:
print(
f"processed_samples {self.processed_samples} unjoint_samples {self.unjoint_samples}"
)
all_length = len(self.ret[source]["tokens"])
cur_length = len(ret["input_ids"])
# print("self.ret", self.ret.keys())
# print("source", source)
if all_length + cur_length > self.max_padding_length or (
"images" in self.ret[source]
and len(self.ret[source]["images"]) + len(ret["images"]) > self.max_num_image
):
# if "tokens" in self.ret[source] and len(self.ret[source]["tokens"]) > 0:
to_ret = copy.deepcopy(self.ret[source])
self.maybe_init_ret(source, force=True)
else:
to_ret = {}
all_length = len(self.ret[source]["tokens"])
cur_length = len(ret["input_ids"])
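                # Merge this sample's image tensors and shift its image indices by the
                # number of tokens already in the buffer.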
if "images" in ret and len(ret["images"]) > 0:
if "images" in self.ret[source]:
self.ret[source]["images"] = torch.cat(
[self.ret[source]["images"], ret["images"]], dim=0
)
ret["image_indices"][1, :, :] += all_length
self.ret[source]["image_indices"] = torch.cat(
[self.ret[source]["image_indices"], ret["image_indices"]], dim=1
)
else:
self.ret[source]["images"] = ret["images"]
self.ret[source]["image_indices"] = ret["image_indices"]
if self.create_attention_mask:
self.ret[source]["attention_mask"] += ret["attention_mask"]
if self.create_attention_mask_2d:
self.ret[source]["attention_mask_2d"][:, all_length:, :all_length] = 0
if self.create_position_ids:
self.ret[source]["position_ids"] += list(range(cur_length))
self.ret[source]["tokens"] += ret["input_ids"]
self.ret[source]["labels"] += ret["labels"]
self.ret[source]["actual_seq_len"] += [all_length + cur_length]
if "tokens" in to_ret:
if self.create_position_ids:
if self.reset_position_ids:
pass
else:
to_ret["position_ids"] = list(range(len(to_ret["tokens"])))
if self.create_attention_mask_2d:
if sefl.reset_attention_mask:
pass
else:
to_ret["attention_mask_2d"] = torch.tril(
torch.ones(
(1, self.max_padding_length, self.max_padding_length),
dtype=torch.bool,
)
)
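                    # When shift_token is set, drop the last input token and the first
                    # label so that tokens[i] predicts labels[i].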
if self.shift_token:
to_ret["tokens"] = to_ret["tokens"][:-1]
to_ret["labels"] = to_ret["labels"][1:]
to_ret["actual_seq_len"][-1] -= 1
if self.create_position_ids:
to_ret["position_ids"] = to_ret["position_ids"][:-1]
if self.create_attention_mask:
to_ret["attention_mask"] = to_ret["attention_mask"][:-1]
if self.create_attention_mask_2d:
to_ret["attention_mask_2d"][:, :, -1] = 0
to_ret["attention_mask_2d"][:, -1, :] = 0
assert len(to_ret["tokens"]) == len(
to_ret["labels"]
), f"{len(to_ret['tokens'])} {len(to_ret['labels'])}"
if not self.variable_length and self.max_padding_length > len(to_ret["tokens"]):
to_ret["tokens"] += [self.tokenizer.pad_token_id] * (
self.max_padding_length - len(to_ret["tokens"])
)
to_ret["labels"] += [IGNORE_TOKEN_ID] * (
self.max_padding_length - len(to_ret["labels"])
)
to_ret["actual_seq_len"][-1] = self.max_padding_length
if self.create_position_ids:
to_ret["position_ids"] += to_ret["position_ids"][-1:] * (
self.max_padding_length - len(to_ret["position_ids"])
)
if self.create_attention_mask:
to_ret["attention_mask"] += [0] * (
self.max_padding_length - len(to_ret["attention_mask"])
)
to_ret["tokens"] = to_ret["tokens"][: self.max_padding_length]
to_ret["labels"] = to_ret["labels"][: self.max_padding_length]
to_ret["actual_seq_len"][-1] = self.max_padding_length
if self.create_position_ids:
to_ret["position_ids"] = to_ret["position_ids"][: self.max_padding_length]
if self.create_attention_mask:
to_ret["attention_mask"] = to_ret["attention_mask"][
: self.max_padding_length
]
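                    # Convert the packed lists to int64 tensors.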
to_ret["tokens"] = torch.tensor(to_ret["tokens"], dtype=torch.int64)
to_ret["labels"] = torch.tensor(to_ret["labels"], dtype=torch.int64)
to_ret["actual_seq_len"] = torch.tensor(
to_ret["actual_seq_len"], dtype=torch.int64
)
if self.create_position_ids:
to_ret["position_ids"] = torch.tensor(
to_ret["position_ids"], dtype=torch.int64
)
if self.create_attention_mask:
to_ret["attention_mask"] = torch.tensor(
to_ret["attention_mask"], dtype=torch.int64
)
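                    # Fold padding into the causal 2D mask, then invert it so True marks
                    # positions that must NOT be attended to.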
if self.create_attention_mask_2d:
attention_mask_2d = to_ret.pop("attention_mask_2d")
attention_mask_2d = attention_mask_2d.masked_fill(
(to_ret["attention_mask"] < 0.5).view(1, 1, self.max_padding_length),
value=0,
)
attention_mask_2d = attention_mask_2d < 0.5
to_ret["attention_mask"] = attention_mask_2d
if self.create_loss_mask:
loss_mask = torch.where(to_ret["labels"] == IGNORE_TOKEN_ID, 0, 1)
to_ret["loss_mask"] = loss_mask.to(torch.float32)
if not self.reset_position_ids and not self.reset_attention_mask:
to_ret.pop("actual_seq_len")
# print("to_ret[tokens]", to_ret["tokens"])
# print("to_ret[labels]", to_ret["labels"])
return to_ret
            except Exception as error:
                # Log the failing sample with a traceback, then retry a neighboring index.
                with open(os.path.join(self.output_dir, "dataset_error.log"), "a") as f:
                    print(error, file=f)
                    print(traceback.format_exc(), file=f)
                    print([self.raw_data[index]], file=f)
                if index == 0:
                    index += 1
                else:
                    index -= 1
def preprocess(
source,
tokenizer: transformers.PreTrainedTokenizer,
image_token_length: int,
system_message: str = "You are a helpful assistant.",
image_processor=None,
is_begin: bool = True,
max_num_frame: int = 8,
max_fps: int = 1,
) -> Dict:
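    """Tokenize one conversation sample (text plus optional images/video frames).

    Per the call site in MistralDataset.__getitem__, this is expected to return a
    dict containing at least "input_ids", "labels", "images" and "image_indices",
    or None if the sample cannot be used.
    """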
# [INST]Hello, how are you?[/INST]I'm doing great. How can I help you today?[INST]I'd like to show off how chat templating works![/INST]
from ..constants import (
IMG_START_TOKEN,
IMG_END_TOKEN,
IMG_CONTEXT_TOKEN,
VID_START_TOKEN,
VID_END_TOKEN,
VID_CONTEXT_TOKEN,
PATCH_START_TOKEN,
PATCH_END_TOKEN,
PATCH_CONTEXT_TOKEN,
)
image_tag = ""
video_tag = "