import json
import logging
import math
import os
import pdb
import random
import re
import sys
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother

from .dataset_base import BaseDataset

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
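
# HunyuanDataset packs multiple tokenized conversations into fixed-length training
# sequences: `preprocess` turns one sample (text plus optional images/videos/audios)
# into token ids and labels, and the dataset accumulates samples per source until a
# packed sequence of `max_padding_length` tokens is ready, tracking per-sample
# boundaries in `actual_seq_len`.
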
class HunyuanDataset(BaseDataset):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(
            *args,
            **kwargs,
        )

        self.default_system_message = "You are a helpful AI assistant."
        # self.default_system_message = None

        self.ret = defaultdict(dict)
        self.is_cat = True
        if self.cross_dataset_joint:
            for i in range(2):
                self.maybe_init_ret(f"default_{i}")

    def maybe_init_ret(self, source, force=False):
        # (Re)initialize the packing buffer for one data source. Each buffer holds
        # growing token/label lists plus optional position ids and attention masks.
        if source not in self.ret or force:
            self.ret[source] = {}
            self.ret[source]["tokens"] = []
            self.ret[source]["labels"] = []
            self.ret[source]["actual_seq_len"] = []
            if self.create_position_ids:
                self.ret[source]["position_ids"] = []
            if self.create_attention_mask:
                self.ret[source]["attention_mask"] = []
            if self.create_attention_mask_2d:
                self.ret[source]["attention_mask_2d"] = torch.tril(
                    torch.ones(
                        (1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
                    )
                )
        return len(self.ret[source]["tokens"]) == 0

    def get_max_min_ret_length(self):
        # Return the longest and shortest per-source buffers (length and key).
        max_ret_length = 0
        min_ret_length = self.max_padding_length + 1
        max_ret_key = None
        min_ret_key = None
        for k, v in self.ret.items():
            cur_length = len(v["tokens"])
            if cur_length > max_ret_length:
                max_ret_length = cur_length
                max_ret_key = k
            if cur_length < min_ret_length:
                min_ret_length = cur_length
                min_ret_key = k
        return max_ret_length, max_ret_key, min_ret_length, min_ret_key

    def add_ret(self, ret, source):
        # Append one preprocessed sample to the packing buffer of `source`.
        cur_length = len(ret["input_ids"])
        cur_image_length = len(ret["images"])
        all_length = len(self.ret[source]["tokens"])
        if "images" in self.ret[source]:
            all_image_length = len(self.ret[source]["images"])
        else:
            all_image_length = 0

        if cur_image_length > 0:
            if all_image_length > 0:
                self.ret[source]["images"] = torch.cat(
                    [self.ret[source]["images"], ret["images"]], dim=0
                )
                # Shift the sequence indices of the new image tokens past the tokens
                # already in the buffer.
                ret["image_indices"][1, :, :] += all_length
                self.ret[source]["image_indices"] = torch.cat(
                    [self.ret[source]["image_indices"], ret["image_indices"]], dim=1
                )
            else:
                self.ret[source]["images"] = ret["images"]
                self.ret[source]["image_indices"] = ret["image_indices"]

        if self.create_attention_mask:
            self.ret[source]["attention_mask"] += ret["attention_mask"]
        if self.create_attention_mask_2d:
            # Block attention between the new sample and everything packed before it.
            self.ret[source]["attention_mask_2d"][:, all_length:, :all_length] = 0
        if self.create_position_ids:
            self.ret[source]["position_ids"] += list(range(cur_length))

        self.ret[source]["tokens"] += ret["input_ids"]
        self.ret[source]["labels"] += ret["labels"]
        # Cumulative end offset of each packed sample.
        self.ret[source]["actual_seq_len"] += [all_length + cur_length]

    def process_ret(self, to_ret):
        if "tokens" not in to_ret or len(to_ret["tokens"]) == 0:
            return to_ret

        if self.create_position_ids and not self.reset_position_ids:
            to_ret["position_ids"] = list(range(len(to_ret["tokens"])))

        if self.create_attention_mask_2d and not self.reset_attention_mask:
            to_ret["attention_mask_2d"] = torch.tril(
                torch.ones(
                    (1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
                )
            )

        if self.shift_token:
            to_ret["tokens"] = to_ret["tokens"][:-1]
            to_ret["labels"] = to_ret["labels"][1:]
            to_ret["actual_seq_len"][-1] -= 1
            if self.create_position_ids:
                to_ret["position_ids"] = to_ret["position_ids"][:-1]
            if self.create_attention_mask:
                to_ret["attention_mask"] = to_ret["attention_mask"][:-1]
            if self.create_attention_mask_2d:
                to_ret["attention_mask_2d"][:, :, -1] = 0
                to_ret["attention_mask_2d"][:, -1, :] = 0

        assert len(to_ret["tokens"]) == len(
            to_ret["labels"]
        ), f"{len(to_ret['tokens'])} {len(to_ret['labels'])}"

        if not self.variable_length and self.max_padding_length > len(to_ret["tokens"]):
            to_ret["tokens"] += [self.tokenizer.pad_token_id] * (
                self.max_padding_length - len(to_ret["tokens"])
            )
            to_ret["labels"] += [IGNORE_TOKEN_ID] * (
                self.max_padding_length - len(to_ret["labels"])
            )
            to_ret["actual_seq_len"][-1] = self.max_padding_length
            if self.create_position_ids:
                # to_ret["position_ids"] += to_ret["position_ids"][-1:] * (
                #     self.max_padding_length - len(to_ret["position_ids"])
                # )
                to_ret["position_ids"] += list(
                    range(to_ret["position_ids"][-1] + 1, self.max_padding_length)
                )
            if self.create_attention_mask:
                to_ret["attention_mask"] += [0] * (
                    self.max_padding_length - len(to_ret["attention_mask"])
                )

        to_ret["tokens"] = to_ret["tokens"][: self.max_padding_length]
        to_ret["labels"] = to_ret["labels"][: self.max_padding_length]
        to_ret["actual_seq_len"][-1] = self.max_padding_length
        if self.create_position_ids:
            to_ret["position_ids"] = to_ret["position_ids"][: self.max_padding_length]
        if self.create_attention_mask:
            to_ret["attention_mask"] = to_ret["attention_mask"][: self.max_padding_length]

        to_ret["tokens"] = torch.tensor(to_ret["tokens"], dtype=torch.int64)
        to_ret["labels"] = torch.tensor(to_ret["labels"], dtype=torch.int64)
        to_ret["actual_seq_len"] = torch.tensor(to_ret["actual_seq_len"], dtype=torch.int64)
        if self.create_position_ids:
            to_ret["position_ids"] = torch.tensor(to_ret["position_ids"], dtype=torch.int64)
        if self.create_attention_mask:
            to_ret["attention_mask"] = torch.tensor(to_ret["attention_mask"], dtype=torch.int64)
        if self.create_attention_mask_2d:
            attention_mask_2d = to_ret.pop("attention_mask_2d")
            attention_mask_2d = attention_mask_2d.masked_fill(
                (to_ret["attention_mask"] < 0.5).view(1, 1, self.max_padding_length), value=0
            )
            attention_mask_2d = attention_mask_2d < 0.5
            to_ret["attention_mask"] = attention_mask_2d

        if self.create_loss_mask:
            loss_mask = torch.where(to_ret["labels"] == IGNORE_TOKEN_ID, 0, 1)
            to_ret["loss_mask"] = loss_mask.to(torch.float32)

        if not self.reset_position_ids and not self.reset_attention_mask:
            to_ret.pop("actual_seq_len")

        to_ret["input_ids"] = to_ret["tokens"]
        # print("to_ret[tokens]", to_ret["tokens"])
        # print("to_ret[labels]", to_ret["labels"])
        return to_ret
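
    # Layout of the dict returned by `process_ret` above (non-variable-length case):
    #   tokens / labels / input_ids : int64 tensors of length max_padding_length
    #   actual_seq_len              : cumulative packed-sample boundaries (kept only when
    #                                 reset_position_ids or reset_attention_mask is set)
    #   loss_mask                   : float32, 1.0 where labels != IGNORE_TOKEN_ID
    #   attention_mask              : 1-D pad mask, or, with create_attention_mask_2d, a
    #                                 (1, L, L) boolean mask that is inverted so True
    #                                 marks positions excluded from attention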

    def is_skip(self):
        # Skip samples that were already consumed (e.g., when resuming training).
        if self.processed_samples < self.skip_samples:
            if self.processed_samples % 1000 == 0:
                print(
                    f"processed_samples {self.processed_samples} skip_samples {self.skip_samples}"
                )
            return True
        return False

    def show_statistic(self):
        log_interval = 10000
        if self.max_padding_length >= 2**17:
            log_interval = 500
        if self.max_padding_length >= 2**20:
            log_interval = 100
        if self.unjoint_samples % log_interval == 0:
            print(
                f"processed_samples {self.processed_samples} "
                f"unjoint_samples {self.unjoint_samples} "
                f"joint_samples {self.joint_samples} "
                f"{[len(v['tokens']) for _, v in self.ret.items()]}",
                flush=True,
            )
        return False

    def __getitem__(self, index):
        self.processor["audio"].load_model()
        while True:
            # if True:
            try:
                self.processed_samples += 1
                if self.is_skip():
                    return {}

                sample = self.raw_data[index]

                if self.cross_dataset_joint:
                    is_empty = False
                    (
                        max_ret_length,
                        max_ret_key,
                        min_ret_length,
                        min_ret_key,
                    ) = self.get_max_min_ret_length()
                else:
                    source = sample["source"]
                    is_empty = self.maybe_init_ret(source)
                    max_ret_length = min_ret_length = len(self.ret[source]["tokens"])
                    max_ret_key = min_ret_key = source

                is_begin = is_empty or self.reset_position_ids or self.reset_attention_mask

                ret = preprocess(
                    sample,
                    self.tokenizer,
                    self.image_token_length,
                    default_system_message=self.default_system_message,
                    processor=self.processor,
                    is_begin=is_begin,
                    max_num_frame=self.max_num_frame,
                    max_fps=self.max_fps,
                )
                if ret is None:
                    return {}

                cur_length = len(ret["input_ids"])
                if cur_length > self.max_padding_length:
                    return {}

                self.unjoint_samples += 1

                if not self.dataset_joint:
                    # No packing: emit whatever is buffered and restart the buffer
                    # with the current sample.
                    to_ret = self.ret.pop(max_ret_key)
                    self.maybe_init_ret(max_ret_key, force=True)
                    self.add_ret(ret, max_ret_key)
                elif min_ret_length + cur_length > self.max_padding_length:
                    # The sample does not fit even in the shortest buffer: emit the
                    # longest buffer and restart it with the current sample.
                    to_ret = self.ret.pop(max_ret_key)
                    self.joint_samples += 1
                    self.maybe_init_ret(max_ret_key, force=True)
                    self.add_ret(ret, max_ret_key)
                else:
                    # The sample still fits: keep packing and emit nothing this step.
                    to_ret = {}
                    self.add_ret(ret, min_ret_key)

                to_ret = self.process_ret(to_ret)
                self.show_statistic()
                return to_ret

            except Exception as error:
                try:
                    with open(os.path.join(self.output_dir, "data_error.log"), "a") as f:
                        print("-" * 100, file=f)
                        print(traceback.format_exc(), file=f)
                        print(self.raw_data[index], file=f)
                except Exception as error:
                    print(error)
                return {}
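
# A minimal usage sketch (hypothetical constructor arguments: the real ones are defined
# by BaseDataset, which is not shown here):
#
#     dataset = HunyuanDataset(...)      # tokenizer, processors, max_padding_length, ...
#     item = dataset[i]
#     # `item` is {} while a packed sequence is still being filled (or the sample was
#     # skipped/oversized), and a dict of tensors from `process_ret` once a sequence of
#     # max_padding_length tokens has been assembled.
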

def preprocess(
    sample,
    tokenizer: transformers.PreTrainedTokenizer,
    image_token_length: int,
    default_system_message: str = "You are a helpful assistant.",
    processor=None,
    is_begin: bool = True,
    max_num_frame: int = 8,
    max_fps: int = 1,
) -> Dict:
    from ..constants import (
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
        AUD_TAG_TOKEN,
    )

    human_roles = ["user", "human"]
    gpt_roles = ["assistant", "gpt"]
    system_roles = ["system", "observation"]

    IMG_CONTEXT_ID = tokenizer(IMG_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    IMG_START_ID = tokenizer(IMG_START_TOKEN, add_special_tokens=False).input_ids
    IMG_END_ID = tokenizer(IMG_END_TOKEN, add_special_tokens=False).input_ids

    VID_CONTEXT_ID = tokenizer(VID_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    VID_START_ID = tokenizer(VID_START_TOKEN, add_special_tokens=False).input_ids
    VID_END_ID = tokenizer(VID_END_TOKEN, add_special_tokens=False).input_ids

    PATCH_CONTEXT_ID = tokenizer(PATCH_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    PATCH_START_ID = tokenizer(PATCH_START_TOKEN, add_special_tokens=False).input_ids
    PATCH_END_ID = tokenizer(PATCH_END_TOKEN, add_special_tokens=False).input_ids

    AUD_START_ID = tokenizer(AUD_START_TOKEN, add_special_tokens=False).input_ids
    AUD_END_ID = tokenizer(AUD_END_TOKEN, add_special_tokens=False).input_ids

    IMG_TAG_ID = tokenizer(IMG_TAG_TOKEN, add_special_tokens=False).input_ids
    VID_TAG_ID = tokenizer(VID_TAG_TOKEN, add_special_tokens=False).input_ids
    AUD_TAG_ID = tokenizer(AUD_TAG_TOKEN, add_special_tokens=False).input_ids

    assert len(IMG_CONTEXT_ID) == 1
    assert len(IMG_START_ID) == 1
    assert len(IMG_END_ID) == 1
    assert len(VID_CONTEXT_ID) == 1
    assert len(VID_START_ID) == 1
    assert len(VID_END_ID) == 1
    assert len(PATCH_CONTEXT_ID) == 1
    assert len(PATCH_START_ID) == 1
    assert len(PATCH_END_ID) == 1

    IMG_CONTEXT_ID = IMG_CONTEXT_ID[0]
    IMG_START_ID = IMG_START_ID[0]
    IMG_END_ID = IMG_END_ID[0]
    VID_CONTEXT_ID = VID_CONTEXT_ID[0]
    VID_START_ID = VID_START_ID[0]
    VID_END_ID = VID_END_ID[0]
    PATCH_CONTEXT_ID = PATCH_CONTEXT_ID[0]
    PATCH_START_ID = PATCH_START_ID[0]
    PATCH_END_ID = PATCH_END_ID[0]
    AUD_START_ID = AUD_START_ID[0]
    AUD_END_ID = AUD_END_ID[0]
    IMG_TAG_ID = IMG_TAG_ID[0]
    VID_TAG_ID = VID_TAG_ID[0]
    AUD_TAG_ID = AUD_TAG_ID[0]

    startoftext = "<|startoftext|>"
    extra_4 = "<|extra_4|>"
    extra_0 = "<|extra_0|>"
    eos = "<|eos|>"

    nl_tokens = tokenizer("\n", add_special_tokens=False).input_ids
    startoftext_IDS = tokenizer(startoftext, add_special_tokens=False).input_ids
    extra_4_IDS = tokenizer(extra_4, add_special_tokens=False).input_ids
    extra_0_IDS = tokenizer(extra_0, add_special_tokens=False).input_ids
    eos_IDS = tokenizer(eos, add_special_tokens=False).input_ids
    input_ids, targets = [], []
    images = []
    image_indices = []

    messages = []
    if "conversations" in sample:
        messages = sample["conversations"]
    if len(messages) == 0 and "messages" in sample:
        messages = sample["messages"]

    # ----------------------------------------------------------------
    # system
    has_system = False
    if is_begin:
        has_system = messages[0]["role"] == "system"
        if (
            not has_system
            and default_system_message is not None
            and len(default_system_message) > 0
        ):
            messages = [{"role": "system", "content": default_system_message}] + messages
            has_system = True
    # ----------------------------------------------------------------
    # audio
    if has_audio(sample):
        audio_tokens_list = [processor["audio"].process_audios(x) for x in sample["audios"]]
        audio_tokens_list = ["".join(f"<|audio_{i}|>" for i in x) for x in audio_tokens_list]

        # Replace each audio tag in the conversation with its discrete audio tokens,
        # wrapped in AUD_START/AUD_END.
        audio_idx = 0
        for j, sentence in enumerate(messages):
            content = sentence["content"]
            while AUD_TAG_TOKEN in content:
                content = content.replace(
                    AUD_TAG_TOKEN,
                    f"{AUD_START_TOKEN}{audio_tokens_list[audio_idx]}{AUD_END_TOKEN}",
                    1,
                )
                audio_idx += 1
            sentence["content"] = content

        # Some samples use a literal "<audio>" placeholder instead of AUD_TAG_TOKEN.
        audio_idx = 0
        for j, sentence in enumerate(messages):
            content = sentence["content"]
            while "<audio>" in content:
                content = content.replace(
                    "<audio>",
                    f"{AUD_START_TOKEN}{audio_tokens_list[audio_idx]}{AUD_END_TOKEN}",
                    1,
                )
                audio_idx += 1
            sentence["content"] = content
    # ----------------------------------------------------------------
    # text
    for j, sentence in enumerate(messages):
        role = sentence["role"]
        content = sentence["content"]

        if role in human_roles:
            # first user
            if j == 1:
                if has_system:
                    _input_id = tokenizer(content, add_special_tokens=False).input_ids + extra_0_IDS
                else:
                    _input_id = (
                        startoftext_IDS
                        + startoftext_IDS
                        + tokenizer(content, add_special_tokens=False).input_ids
                        + extra_4_IDS
                        + extra_0_IDS
                    )
            else:
                _input_id = (
                    startoftext_IDS
                    + tokenizer(content, add_special_tokens=False).input_ids
                    + extra_0_IDS
                )
            _target = [IGNORE_TOKEN_ID] * len(_input_id)
        elif role in gpt_roles:
            # _input_id = tokenizer(content, add_special_tokens=False).input_ids + eos_IDS
            # _target = tokenizer(content, add_special_tokens=False).input_ids + eos_IDS
            _input_id = (
                text_audio_interval(
                    tokenizer(content, add_special_tokens=False).input_ids, AUD_START_ID, AUD_END_ID
                )
                + eos_IDS
            )
            _target = _input_id
        elif role in system_roles:
            _input_id = (
                startoftext_IDS
                + tokenizer(content, add_special_tokens=False).input_ids
                + extra_4_IDS
            )
            _target = [IGNORE_TOKEN_ID] * len(_input_id)
        else:
            raise NotImplementedError
        # print(f"_input_id {_input_id}")

        input_ids += _input_id
        targets += _target
    # ----------------------------------------------------------------
    # image
    if has_image(sample):
        img_positions = [i for i, x in enumerate(input_ids) if x == IMG_TAG_ID]
        assert len(img_positions) == len(sample["images"]), sample

        new_input_ids = []
        new_targets = []
        st = 0
        for img_idx, img_pos in enumerate(img_positions):
            image_patches, (best_width, best_height) = processor[
                "image"
            ].process_images_with_subpatch(sample["images"][img_idx])
            images.append(image_patches)

            new_input_ids += input_ids[st:img_pos]
            new_targets += targets[st:img_pos]

            new_input_ids += [IMG_START_ID]
            new_targets += [IGNORE_TOKEN_ID]

            image_indice_b = torch.zeros(
                1, image_token_length, dtype=torch.int64
            )  # This will change in collate_fn
            image_indice_s = (
                torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
                .unsqueeze(0)
                .repeat(1, 1)
            )
            image_indice_b_s = torch.stack(
                [image_indice_b, image_indice_s], dim=0
            )  # 2, num_image, image_length
            image_indices.append(image_indice_b_s)

            new_input_ids += [IMG_CONTEXT_ID] * image_token_length
            new_targets += [IGNORE_TOKEN_ID] * image_token_length

            new_input_ids += [IMG_END_ID]
            new_targets += [IGNORE_TOKEN_ID]

            if len(image_patches) > 1:
                for i in range(0, best_height, processor["image"].patch_size):
                    new_input_ids += nl_tokens
                    new_targets += [IGNORE_TOKEN_ID] * len(nl_tokens)
                    for j in range(0, best_width, processor["image"].patch_size):
                        new_input_ids += [PATCH_START_ID]
                        new_targets += [IGNORE_TOKEN_ID]

                        image_indice_b = torch.zeros(
                            1, image_token_length, dtype=torch.int64
                        )  # This will change in collate_fn
                        image_indice_s = (
                            torch.arange(
                                len(new_input_ids), len(new_input_ids) + image_token_length
                            )
                            .unsqueeze(0)
                            .repeat(1, 1)
                        )
                        image_indice_b_s = torch.stack(
                            [image_indice_b, image_indice_s], dim=0
                        )  # 2, num_image, image_length
                        image_indices.append(image_indice_b_s)

                        new_input_ids += [PATCH_CONTEXT_ID] * image_token_length
                        new_targets += [IGNORE_TOKEN_ID] * image_token_length

                        new_input_ids += [PATCH_END_ID]
                        new_targets += [IGNORE_TOKEN_ID]

            st = img_pos + 1

        new_input_ids += input_ids[st:]
        new_targets += targets[st:]

        input_ids = new_input_ids
        targets = new_targets
    # ----------------------------------------------------------------
    # video
    if has_video(sample):
        vid_positions = [i for i, x in enumerate(input_ids) if x == VID_TAG_ID]
        assert len(vid_positions) == len(sample["videos"]), sample

        new_input_ids = []
        new_targets = []
        st = 0
        for vid_idx, vid_pos in enumerate(vid_positions):
            video_frames, _ = processor["image"].process_video(
                sample["videos"][vid_idx], max_num_frame, max_fps
            )

            new_input_ids += input_ids[st:vid_pos]
            new_targets += targets[st:vid_pos]

            images.append(video_frames)
            for _ in video_frames:
                new_input_ids += [VID_START_ID]
                new_targets += [IGNORE_TOKEN_ID]

                image_indice_b = torch.zeros(
                    1, image_token_length, dtype=torch.int64
                )  # This will change in collate_fn
                image_indice_s = (
                    torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
                    .unsqueeze(0)
                    .repeat(1, 1)
                )
                image_indice_b_s = torch.stack(
                    [image_indice_b, image_indice_s], dim=0
                )  # 2, num_image, image_length
                image_indices.append(image_indice_b_s)

                new_input_ids += [VID_CONTEXT_ID] * image_token_length
                new_targets += [IGNORE_TOKEN_ID] * image_token_length

                new_input_ids += [VID_END_ID]
                new_targets += [IGNORE_TOKEN_ID]

            st = vid_pos + 1

        new_input_ids += input_ids[st:]
        new_targets += targets[st:]

        input_ids = new_input_ids
        targets = new_targets
    # # ----------------------------------------------------------------
    # # audio
    # if has_audio(sample):
    #     aud_positions = [i for i, x in enumerate(input_ids) if x == AUD_TAG_ID]
    #     assert len(aud_positions) == len(sample["audios"]), sample
    #     new_input_ids = []
    #     new_targets = []
    #     st = 0
    #     for aud_idx, aud_pos in enumerate(aud_positions):
    #         audio_tokens = processor["audio"].process_audios(
    #             sample["audios"][aud_idx],
    #         )
    #         new_input_ids += input_ids[st:aud_pos]
    #         new_targets += targets[st:aud_pos]
    #         new_input_ids += [AUD_START_ID]
    #         new_targets += [IGNORE_TOKEN_ID]
    #         new_input_ids += audio_tokens
    #         new_targets += [IGNORE_TOKEN_ID] * len(audio_tokens)
    #         new_input_ids += [AUD_END_ID]
    #         new_targets += [IGNORE_TOKEN_ID]
    #         st = aud_pos + 1
    #     new_input_ids += input_ids[st:]
    #     new_targets += targets[st:]
    #     input_ids = new_input_ids
    #     targets = new_targets
    if len(images) > 0:
        images = torch.cat(images, dim=0)
    if len(image_indices) > 0:
        image_indices = torch.cat(image_indices, dim=1)

    attention_mask = [1] * len(input_ids)

    # print("sample", sample, flush=True)
    # print("input_ids", input_ids, flush=True)
    # print("targets", targets[:100], flush=True)
    # print("images", [xx.shape for x in images for xx in x], flush=True)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=attention_mask,
        images=images,
        image_indices=image_indices,
    )
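
# The dict returned by `preprocess` is still list-based; tensors are only created later
# in `process_ret` / the collate step:
#   input_ids / labels : lists of token ids (labels use IGNORE_TOKEN_ID outside
#                        assistant turns)
#   attention_mask     : a list of 1s, one per token
#   images             : stacked image patches / video frames (torch tensor), or [] if
#                        the sample has no visual input
#   image_indices      : (2, num_images, image_token_length) placement indices; the
#                        batch row is rewritten later in collate_fn
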

def has_video(sample):
    # video
    if (
        "videos" in sample
        and isinstance(sample["videos"], list)
        and None not in sample["videos"]
        and len(sample["videos"])
    ):
        return True
    return False


def has_image(sample):
    # image
    if (
        "images" in sample
        and isinstance(sample["images"], list)
        and None not in sample["images"]
        and len(sample["images"])
    ):
        return True
    return False


def has_audio(sample):
    # audio
    if (
        "audios" in sample
        and isinstance(sample["audios"], list)
        and None not in sample["audios"]
        and len(sample["audios"])
    ):
        return True
    return False

def text_audio_interval(input_ids, AUD_START_ID, AUD_END_ID):
    # Interleave an assistant turn that contains one audio span with its text:
    # chunks of `text_num` text tokens alternate with chunks of `audio_num` audio
    # tokens, each audio chunk re-wrapped in AUD_START/AUD_END.
    audio_num = 26
    text_num = 13

    st = [i for i, x in enumerate(input_ids) if x == AUD_START_ID]
    ed = [i for i, x in enumerate(input_ids) if x == AUD_END_ID]

    # only text
    if len(st) == 0 and len(ed) == 0:
        return input_ids

    assert len(st) == 1
    assert len(ed) == 1
    st = st[0]
    ed = ed[0]
    assert st < ed

    # only audio
    if st == 0 and ed == len(input_ids) - 1:
        return input_ids

    audio_tokens = input_ids[st + 1 : ed]
    text_tokens = input_ids[:st] + input_ids[ed + 1 :]

    audio_tokens_chunks = [
        audio_tokens[i : i + audio_num] for i in range(0, len(audio_tokens), audio_num)
    ]
    text_tokens_chunks = [
        text_tokens[i : i + text_num] for i in range(0, len(text_tokens), text_num)
    ]

    # Merge the tails so both chunk lists end up with the same length.
    chunk_num = min(len(audio_tokens_chunks), len(text_tokens_chunks))
    audio_tokens_chunks = audio_tokens_chunks[: chunk_num - 1] + [
        sum(audio_tokens_chunks[chunk_num - 1 :], [])
    ]
    text_tokens_chunks = text_tokens_chunks[: chunk_num - 1] + [
        sum(text_tokens_chunks[chunk_num - 1 :], [])
    ]

    interval_input_ids = []
    for text_tokens, audio_tokens in zip(text_tokens_chunks, audio_tokens_chunks):
        interval_input_ids += text_tokens + [AUD_START_ID] + audio_tokens + [AUD_END_ID]
    return interval_input_ids
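

# Quick self-check for `text_audio_interval` (a hedged sketch: the sentinel ids below are
# made up for the demo, and the module has to be run as a package module, e.g.
# `python -m <package>.<this_module>`, so its relative imports resolve).
if __name__ == "__main__":
    AUD_START, AUD_END = 9000, 9001        # hypothetical ids, not real tokenizer ids
    text = list(range(100, 130))           # 30 fake text token ids
    audio = list(range(500, 560))          # 60 fake audio token ids
    seq = text[:20] + [AUD_START] + audio + [AUD_END] + text[20:]

    out = text_audio_interval(seq, AUD_START, AUD_END)

    # All text/audio tokens are preserved; only their order changes, and each audio
    # chunk gets its own AUD_START/AUD_END wrapper.
    body = [t for t in out if t not in (AUD_START, AUD_END)]
    assert sorted(body) == sorted(text + audio)
    print("audio chunks:", out.count(AUD_START), "interleaved length:", len(out))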