# vita_audio/data/dataset_mistral.py
import copy
import json
import logging
import os
import pdb
import re
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Sequence
import numpy as np
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother
from .dataset_base import BaseDataset
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class MistralDataset(BaseDataset):
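    """Dataset that renders conversations with the Mistral [INST] ... [/INST]
    template and packs samples from the same source into fixed-length sequences.
    """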
def __init__(
self,
*args,
**kwargs,
):
super().__init__(
*args,
**kwargs,
)
        # System message is disabled (None) by default; when set to a string it is
        # prepended to each user turn in `preprocess`, e.g.:
        # self.system_message = "You are a helpful AI assistant."
        self.system_message = None
self.ret = defaultdict(dict)
self.is_cat = True
def maybe_init_ret(self, source, force=False):
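        """(Re)create the packing buffer for `source` when it is missing or `force` is set.

        Returns True when the buffer is empty after the call, i.e. the next sample
        starts a new packed sequence.
        """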
if source not in self.ret or force:
self.ret[source] = {}
self.ret[source]["tokens"] = []
self.ret[source]["labels"] = []
self.ret[source]["actual_seq_len"] = []
if self.create_position_ids:
self.ret[source]["position_ids"] = []
if self.create_attention_mask:
self.ret[source]["attention_mask"] = []
if self.create_attention_mask_2d:
self.ret[source]["attention_mask_2d"] = torch.tril(
torch.ones(
(1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
)
)
return len(self.ret[source]["tokens"]) == 0
def __getitem__(self, index):
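        """Pack same-source samples into one sequence of up to `max_padding_length` tokens.

        Returns an empty dict while the current packed sequence is still being filled
        (or the sample is skipped), and returns the finished, padded sequence once
        adding the next sample would overflow the token or image budget.
        """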
while True:
try:
# if True:
self.processed_samples += 1
if self.processed_samples < self.skip_samples:
if self.processed_samples % 1e3 == 0:
print(
f"processed_samples {self.processed_samples} skip_samples {self.skip_samples}"
)
return {}
source = self.raw_data[index]["source"]
is_empty = self.maybe_init_ret(source)
is_begin = is_empty or self.reset_position_ids or self.reset_attention_mask
ret = preprocess(
self.raw_data[index],
self.tokenizer,
self.image_token_length,
system_message=self.system_message,
image_processor=self.processor["image"],
is_begin=is_begin,
max_num_frame=self.max_num_frame,
max_fps=self.max_fps,
)
if ret is None:
return {}
if len(ret["input_ids"]) > self.max_padding_length:
return {}
if len(ret["images"]) > self.max_num_frame:
return {}
self.unjoint_samples += 1
                if self.max_padding_length > 2**17:
                    log_interval = 1e3
                else:
                    log_interval = 1e4
if self.unjoint_samples % log_interval == 0:
print(
f"processed_samples {self.processed_samples} unjoint_samples {self.unjoint_samples}"
)
all_length = len(self.ret[source]["tokens"])
cur_length = len(ret["input_ids"])
# print("self.ret", self.ret.keys())
# print("source", source)
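                # If appending this sample would exceed the token budget
                # (max_padding_length) or the image budget (max_num_image), emit the
                # sequence packed so far and start a fresh buffer for this source.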
if all_length + cur_length > self.max_padding_length or (
"images" in self.ret[source]
and len(self.ret[source]["images"]) + len(ret["images"]) > self.max_num_image
):
# if "tokens" in self.ret[source] and len(self.ret[source]["tokens"]) > 0:
to_ret = copy.deepcopy(self.ret[source])
self.maybe_init_ret(source, force=True)
else:
to_ret = {}
all_length = len(self.ret[source]["tokens"])
cur_length = len(ret["input_ids"])
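                # Append this sample to the packing buffer, offsetting its image
                # indices by the number of tokens already buffered.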
if "images" in ret and len(ret["images"]) > 0:
if "images" in self.ret[source]:
self.ret[source]["images"] = torch.cat(
[self.ret[source]["images"], ret["images"]], dim=0
)
ret["image_indices"][1, :, :] += all_length
self.ret[source]["image_indices"] = torch.cat(
[self.ret[source]["image_indices"], ret["image_indices"]], dim=1
)
else:
self.ret[source]["images"] = ret["images"]
self.ret[source]["image_indices"] = ret["image_indices"]
if self.create_attention_mask:
self.ret[source]["attention_mask"] += ret["attention_mask"]
if self.create_attention_mask_2d:
self.ret[source]["attention_mask_2d"][:, all_length:, :all_length] = 0
if self.create_position_ids:
self.ret[source]["position_ids"] += list(range(cur_length))
self.ret[source]["tokens"] += ret["input_ids"]
self.ret[source]["labels"] += ret["labels"]
self.ret[source]["actual_seq_len"] += [all_length + cur_length]
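                # Finalize the packed sequence closed above (if any): position ids,
                # 2D attention mask, optional shift, padding/truncation, tensors.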
if "tokens" in to_ret:
if self.create_position_ids:
if self.reset_position_ids:
pass
else:
to_ret["position_ids"] = list(range(len(to_ret["tokens"])))
if self.create_attention_mask_2d:
                        if self.reset_attention_mask:
pass
else:
to_ret["attention_mask_2d"] = torch.tril(
torch.ones(
(1, self.max_padding_length, self.max_padding_length),
dtype=torch.bool,
)
)
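                    # Optionally drop the last token and the first label so that
                    # tokens and labels are pre-shifted for next-token prediction.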
if self.shift_token:
to_ret["tokens"] = to_ret["tokens"][:-1]
to_ret["labels"] = to_ret["labels"][1:]
to_ret["actual_seq_len"][-1] -= 1
if self.create_position_ids:
to_ret["position_ids"] = to_ret["position_ids"][:-1]
if self.create_attention_mask:
to_ret["attention_mask"] = to_ret["attention_mask"][:-1]
if self.create_attention_mask_2d:
to_ret["attention_mask_2d"][:, :, -1] = 0
to_ret["attention_mask_2d"][:, -1, :] = 0
assert len(to_ret["tokens"]) == len(
to_ret["labels"]
), f"{len(to_ret['tokens'])} {len(to_ret['labels'])}"
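                    # Pad to max_padding_length when fixed-length batches are required,
                    # then truncate and convert everything to tensors.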
if not self.variable_length and self.max_padding_length > len(to_ret["tokens"]):
to_ret["tokens"] += [self.tokenizer.pad_token_id] * (
self.max_padding_length - len(to_ret["tokens"])
)
to_ret["labels"] += [IGNORE_TOKEN_ID] * (
self.max_padding_length - len(to_ret["labels"])
)
to_ret["actual_seq_len"][-1] = self.max_padding_length
if self.create_position_ids:
to_ret["position_ids"] += to_ret["position_ids"][-1:] * (
self.max_padding_length - len(to_ret["position_ids"])
)
if self.create_attention_mask:
to_ret["attention_mask"] += [0] * (
self.max_padding_length - len(to_ret["attention_mask"])
)
to_ret["tokens"] = to_ret["tokens"][: self.max_padding_length]
to_ret["labels"] = to_ret["labels"][: self.max_padding_length]
to_ret["actual_seq_len"][-1] = self.max_padding_length
if self.create_position_ids:
to_ret["position_ids"] = to_ret["position_ids"][: self.max_padding_length]
if self.create_attention_mask:
to_ret["attention_mask"] = to_ret["attention_mask"][
: self.max_padding_length
]
to_ret["tokens"] = torch.tensor(to_ret["tokens"], dtype=torch.int64)
to_ret["labels"] = torch.tensor(to_ret["labels"], dtype=torch.int64)
to_ret["actual_seq_len"] = torch.tensor(
to_ret["actual_seq_len"], dtype=torch.int64
)
if self.create_position_ids:
to_ret["position_ids"] = torch.tensor(
to_ret["position_ids"], dtype=torch.int64
)
if self.create_attention_mask:
to_ret["attention_mask"] = torch.tensor(
to_ret["attention_mask"], dtype=torch.int64
)
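                    # Merge the 2D causal/packing mask with the 1D padding mask and
                    # invert it so that True marks masked-out positions; the result
                    # replaces "attention_mask".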
if self.create_attention_mask_2d:
attention_mask_2d = to_ret.pop("attention_mask_2d")
attention_mask_2d = attention_mask_2d.masked_fill(
(to_ret["attention_mask"] < 0.5).view(1, 1, self.max_padding_length),
value=0,
)
attention_mask_2d = attention_mask_2d < 0.5
to_ret["attention_mask"] = attention_mask_2d
if self.create_loss_mask:
loss_mask = torch.where(to_ret["labels"] == IGNORE_TOKEN_ID, 0, 1)
to_ret["loss_mask"] = loss_mask.to(torch.float32)
if not self.reset_position_ids and not self.reset_attention_mask:
to_ret.pop("actual_seq_len")
# print("to_ret[tokens]", to_ret["tokens"])
# print("to_ret[labels]", to_ret["labels"])
return to_ret
except Exception as error:
with open(os.path.join(self.output_dir, "dataset_error.log"), "a") as f:
print(error, file=f)
print([self.raw_data[index]], file=f)
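                # Retry with a neighboring index instead of raising.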
if index == 0:
index += 1
else:
index -= 1
def preprocess(
source,
tokenizer: transformers.PreTrainedTokenizer,
image_token_length: int,
system_message: str = "You are a helpful assistant.",
image_processor=None,
is_begin: bool = True,
max_num_frame: int = 8,
max_fps: int = 1,
) -> Dict:
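    """Tokenize one conversation with the Mistral [INST] ... [/INST] template and
    expand image/video tags into placeholder token spans.

    Returns a dict with input_ids, labels, attention_mask, images, and image_indices.
    """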
# <s>[INST]Hello, how are you?[/INST]I'm doing great. How can I help you today?</s>[INST]I'd like to show off how chat templating works![/INST]
from ..constants import (
IMG_START_TOKEN,
IMG_END_TOKEN,
IMG_CONTEXT_TOKEN,
VID_START_TOKEN,
VID_END_TOKEN,
VID_CONTEXT_TOKEN,
PATCH_START_TOKEN,
PATCH_END_TOKEN,
PATCH_CONTEXT_TOKEN,
)
image_tag = "<image>"
video_tag = "<video>"
human_roles = ["user", "human"]
gpt_roles = ["assistant", "gpt"]
system_roles = [
"system",
]
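    # Each multimodal placeholder token must map to exactly one token id (asserted below).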
IMG_CONTEXT_ID = tokenizer(IMG_CONTEXT_TOKEN, add_special_tokens=False).input_ids
IMG_START_ID = tokenizer(IMG_START_TOKEN, add_special_tokens=False).input_ids
IMG_END_ID = tokenizer(IMG_END_TOKEN, add_special_tokens=False).input_ids
VID_CONTEXT_ID = tokenizer(VID_CONTEXT_TOKEN, add_special_tokens=False).input_ids
VID_START_ID = tokenizer(VID_START_TOKEN, add_special_tokens=False).input_ids
VID_END_ID = tokenizer(VID_END_TOKEN, add_special_tokens=False).input_ids
PATCH_CONTEXT_ID = tokenizer(PATCH_CONTEXT_TOKEN, add_special_tokens=False).input_ids
PATCH_START_ID = tokenizer(PATCH_START_TOKEN, add_special_tokens=False).input_ids
PATCH_END_ID = tokenizer(PATCH_END_TOKEN, add_special_tokens=False).input_ids
assert len(IMG_CONTEXT_ID) == 1
assert len(IMG_START_ID) == 1
assert len(IMG_END_ID) == 1
assert len(VID_CONTEXT_ID) == 1
assert len(VID_START_ID) == 1
assert len(VID_END_ID) == 1
assert len(PATCH_CONTEXT_ID) == 1
assert len(PATCH_START_ID) == 1
assert len(PATCH_END_ID) == 1
IMG_CONTEXT_ID = IMG_CONTEXT_ID[0]
IMG_START_ID = IMG_START_ID[0]
IMG_END_ID = IMG_END_ID[0]
VID_CONTEXT_ID = VID_CONTEXT_ID[0]
VID_START_ID = VID_START_ID[0]
VID_END_ID = VID_END_ID[0]
PATCH_CONTEXT_ID = PATCH_CONTEXT_ID[0]
PATCH_START_ID = PATCH_START_ID[0]
PATCH_END_ID = PATCH_END_ID[0]
BOS_ID = tokenizer.bos_token_id
EOS_ID = tokenizer.eos_token_id
B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"
nl_tokens = tokenizer("\n", add_special_tokens=False).input_ids
B_SYS_IDS = tokenizer(B_SYS, add_special_tokens=False).input_ids
E_SYS_IDS = tokenizer(E_SYS, add_special_tokens=False).input_ids
B_INST_IDS = tokenizer(B_INST, add_special_tokens=False).input_ids
E_INST_IDS = tokenizer(E_INST, add_special_tokens=False).input_ids
input_ids, targets = [], []
images = []
image_indices = []
conversations = source["conversations"]
if is_begin:
input_ids += [BOS_ID]
targets += [IGNORE_TOKEN_ID]
    custom_system = conversations[0]["role"] == "system"
if not custom_system and system_message is not None and len(system_message) > 0:
for jj in range(0, len(conversations)):
if conversations[jj]["role"] in human_roles:
conversations[jj]["content"] = (
system_message + "\n\n" + conversations[jj]["content"]
)
# print(f"input_ids {input_ids}")
# print(f"targets {targets}")
for j, sentence in enumerate(conversations):
role = sentence["role"]
content = sentence["content"]
# ----------------------------------------------------------------
# image
content = content.replace(image_tag, IMG_CONTEXT_TOKEN)
        if IMG_START_TOKEN in content:
            source["images"] = []
            bos_pos = [m.start() for m in re.finditer(IMG_START_TOKEN, content)]
            eos_pos = [m.start() for m in re.finditer(IMG_END_TOKEN, content)]
            # print(bos_pos, eos_pos)
            assert len(bos_pos) == len(eos_pos)
            new_content = ""
            st = 0
            for a, b in zip(bos_pos, eos_pos):
                # Replace each IMG_START_TOKEN ... IMG_END_TOKEN span with a single
                # IMG_CONTEXT_TOKEN placeholder and remember the enclosed image path.
                img_path = content[a + len(IMG_START_TOKEN) : b]
                new_content += content[st:a] + IMG_CONTEXT_TOKEN
                st = b + len(IMG_END_TOKEN)
                source["images"].append(img_path)
            new_content += content[st:]
            content = new_content
# ----------------------------------------------------------------
# video
content = content.replace(video_tag, VID_CONTEXT_TOKEN)
        if VID_START_TOKEN in content:
            source["videos"] = []
            bos_pos = [m.start() for m in re.finditer(VID_START_TOKEN, content)]
            eos_pos = [m.start() for m in re.finditer(VID_END_TOKEN, content)]
            # print(bos_pos, eos_pos)
            assert len(bos_pos) == len(eos_pos)
            new_content = ""
            st = 0
            for a, b in zip(bos_pos, eos_pos):
                # Replace each VID_START_TOKEN ... VID_END_TOKEN span with a single
                # VID_CONTEXT_TOKEN placeholder and remember the enclosed video path.
                vid_path = content[a + len(VID_START_TOKEN) : b]
                new_content += content[st:a] + VID_CONTEXT_TOKEN
                st = b + len(VID_END_TOKEN)
                source["videos"].append(vid_path)
            new_content += content[st:]
            content = new_content
# ----------------------------------------------------------------
# text
if role in human_roles:
_input_id = (
B_INST_IDS + tokenizer(content, add_special_tokens=False).input_ids + E_INST_IDS
)
_target = [IGNORE_TOKEN_ID] * len(_input_id)
elif role in gpt_roles:
_input_id = tokenizer(content, add_special_tokens=False).input_ids + [EOS_ID]
_target = tokenizer(content, add_special_tokens=False).input_ids + [EOS_ID]
if "type" in sentence:
if sentence["type"] == "wrong answer":
_target = [IGNORE_TOKEN_ID] * (len(_input_id) - 1) + [EOS_ID]
elif role in system_roles:
for jj in range(j + 1, len(conversations)):
if conversations[jj]["role"] in human_roles:
conversations[jj]["content"] = content + "\n\n" + conversations[jj]["content"]
_input_id = []
_target = []
else:
raise NotImplementedError
input_ids += _input_id
targets += _target
# ----------------------------------------------------------------
# image
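    # Expand every IMG_CONTEXT_ID placeholder into IMG_START_ID, image_token_length
    # IMG_CONTEXT_IDs, and IMG_END_ID (plus per-patch spans when the image is tiled
    # into sub-patches), recording where image features go via image_indices.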
if (
"images" in source
and isinstance(source["images"], list)
and None not in source["images"]
and len(source["images"])
):
img_positions = [i for i, x in enumerate(input_ids) if x == IMG_CONTEXT_ID]
assert len(img_positions) == len(source["images"]), source
new_input_ids = []
new_targets = []
st = 0
for img_idx, img_pos in enumerate(img_positions):
image_patches, (best_width, best_height) = image_processor.process_images_with_subpatch(
source["images"][img_idx]
)
images.append(image_patches)
new_input_ids += input_ids[st:img_pos]
new_targets += targets[st:img_pos]
new_input_ids += [IMG_START_ID]
new_targets += [IGNORE_TOKEN_ID]
image_indice_b = torch.zeros(
1, image_token_length, dtype=torch.int64
) # This will change in collate_fn
image_indice_s = (
torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
.unsqueeze(0)
.repeat(1, 1)
)
image_indice_b_s = torch.stack(
[image_indice_b, image_indice_s], dim=0
) # 2, num_image, image_length
image_indices.append(image_indice_b_s)
new_input_ids += [IMG_CONTEXT_ID] * image_token_length
new_targets += [IGNORE_TOKEN_ID] * image_token_length
new_input_ids += [IMG_END_ID]
new_targets += [IGNORE_TOKEN_ID]
if len(image_patches) > 1:
for i in range(0, best_height, image_processor.patch_size):
new_input_ids += nl_tokens
new_targets += [IGNORE_TOKEN_ID] * len(nl_tokens)
for j in range(0, best_width, image_processor.patch_size):
new_input_ids += [PATCH_START_ID]
new_targets += [IGNORE_TOKEN_ID]
image_indice_b = torch.zeros(
1, image_token_length, dtype=torch.int64
) # This will change in collate_fn
image_indice_s = (
torch.arange(
len(new_input_ids), len(new_input_ids) + image_token_length
)
.unsqueeze(0)
.repeat(1, 1)
)
image_indice_b_s = torch.stack(
[image_indice_b, image_indice_s], dim=0
) # 2, num_image, image_length
image_indices.append(image_indice_b_s)
new_input_ids += [PATCH_CONTEXT_ID] * image_token_length
new_targets += [IGNORE_TOKEN_ID] * image_token_length
new_input_ids += [PATCH_END_ID]
new_targets += [IGNORE_TOKEN_ID]
st = img_pos + 1
new_input_ids += input_ids[st:]
new_targets += targets[st:]
input_ids = new_input_ids
targets = new_targets
# ----------------------------------------------------------------
# video
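    # Expand every VID_CONTEXT_ID placeholder into one VID_START_ID / context / VID_END_ID
    # span per sampled frame, recording frame positions in image_indices.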
if (
"videos" in source
and isinstance(source["videos"], list)
and None not in source["videos"]
and len(source["videos"])
):
vid_positions = [i for i, x in enumerate(input_ids) if x == VID_CONTEXT_ID]
assert len(vid_positions) == len(source["videos"]), source
new_input_ids = []
new_targets = []
st = 0
for vid_idx, vid_pos in enumerate(vid_positions):
vid_path = source["videos"][vid_idx]
new_input_ids += input_ids[st:vid_pos]
new_targets += targets[st:vid_pos]
video_frames, _ = image_processor.process_video(vid_path, max_num_frame, max_fps)
images.append(video_frames)
for _ in video_frames:
new_input_ids += [VID_START_ID]
new_targets += [IGNORE_TOKEN_ID]
image_indice_b = torch.zeros(
1, image_token_length, dtype=torch.int64
) # This will change in collate_fn
image_indice_s = (
torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
.unsqueeze(0)
.repeat(1, 1)
)
image_indice_b_s = torch.stack(
[image_indice_b, image_indice_s], dim=0
) # 2, num_image, image_length
image_indices.append(image_indice_b_s)
new_input_ids += [VID_CONTEXT_ID] * image_token_length
new_targets += [IGNORE_TOKEN_ID] * image_token_length
new_input_ids += [VID_END_ID]
new_targets += [IGNORE_TOKEN_ID]
st = vid_pos + 1
new_input_ids += input_ids[st:]
new_targets += targets[st:]
input_ids = new_input_ids
targets = new_targets
if len(images) > 0:
images = torch.cat(images, dim=0)
if len(image_indices) > 0:
image_indices = torch.cat(image_indices, dim=1)
attention_mask = [1] * len(input_ids)
# print("source", source, flush=True)
# print("input_ids", input_ids[:100], flush=True)
# print("targets", targets[:100], flush=True)
# print("images", [xx.shape for x in images for xx in x], flush=True)
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=attention_mask,
images=images,
image_indices=image_indices,
)