# VITA-Audio: vita_audio/data/dataset_qwen2.py

import json
import logging
import math
import os
import pdb
import random
import re
import sys
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Sequence
import numpy as np
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother
from .dataset_base import BaseDataset
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class Qwen2Dataset(BaseDataset):
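    """Qwen2 chat-format dataset with sequence packing.

    Conversations are tokenized by ``preprocess`` and accumulated into per-source packing
    buffers. Depending on the joint-packing settings, ``__getitem__`` returns either an
    empty dict (the sample was absorbed into a buffer) or a finalized fixed-length sample
    with ``input_ids``, ``labels``, ``attention_mask`` and optional image/audio fields.
    """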
def __init__(
self,
*args,
**kwargs,
):
super().__init__(
*args,
**kwargs,
)
        # No default system prompt is injected; a system message is used only when the
        # sample itself provides one (previously "You are a helpful AI assistant.").
        self.default_system_message = None
self.ret = defaultdict(dict)
self.is_cat = True
if self.cross_dataset_joint:
for i in range(2):
self.maybe_init_ret(f"default_{i}")
def maybe_init_ret(self, source, force=False):
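        """Create the packing buffer for ``source`` if it does not exist (or when
        ``force`` is set), and return True if the buffer is currently empty."""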
if source not in self.ret or force:
self.ret[source] = {}
self.ret[source]["tokens"] = []
self.ret[source]["labels"] = []
self.ret[source]["actual_seq_len"] = []
if self.create_position_ids:
self.ret[source]["position_ids"] = []
if self.create_attention_mask:
self.ret[source]["attention_mask"] = []
if self.create_attention_mask_2d:
self.ret[source]["attention_mask_2d"] = torch.tril(
torch.ones(
(1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
)
)
return len(self.ret[source]["tokens"]) == 0
def get_max_min_ret_length(self):
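        """Return the longest and shortest packing buffers as
        (max_length, max_key, min_length, min_key)."""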
        max_ret_length = 0
        min_ret_length = self.max_padding_length + 1
        max_ret_key = None
        min_ret_key = None
        for k, v in self.ret.items():
            cur_length = len(v["tokens"])
            if cur_length > max_ret_length:
                max_ret_length = cur_length
                max_ret_key = k
            if cur_length < min_ret_length:
                min_ret_length = cur_length
                min_ret_key = k
        return max_ret_length, max_ret_key, min_ret_length, min_ret_key
def add_ret(self, ret, source):
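        """Append a preprocessed conversation ``ret`` to the ``source`` packing buffer,
        offsetting its image/audio indices by the tokens already in the buffer."""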
cur_length = len(ret["input_ids"])
cur_image_length = len(ret["images"])
cur_audio_length = len(ret["audios"])
all_length = len(self.ret[source]["tokens"])
if "images" in self.ret[source]:
all_image_length = len(self.ret[source]["images"])
else:
all_image_length = 0
if cur_image_length > 0:
if all_image_length > 0:
self.ret[source]["images"] = torch.cat(
[self.ret[source]["images"], ret["images"]], dim=0
)
ret["image_indices"][1, :, :] += all_length
self.ret[source]["image_indices"] = torch.cat(
[self.ret[source]["image_indices"], ret["image_indices"]], dim=1
)
else:
self.ret[source]["images"] = ret["images"]
self.ret[source]["image_indices"] = ret["image_indices"]
if "audios" in self.ret[source]:
all_audio_length = len(self.ret[source]["audios"])
else:
all_audio_length = 0
        if cur_audio_length > 0:
            # Unlike images, audio features have variable lengths, so they are kept as
            # Python lists and extended rather than concatenated into a single tensor.
            if all_audio_length > 0:
                self.ret[source]["audios"].extend(ret["audios"])
                for audio_indice in ret["audio_indices"]:
                    audio_indice[1, :, :] += all_length
                self.ret[source]["audio_indices"].extend(ret["audio_indices"])
            else:
                self.ret[source]["audios"] = ret["audios"]
                self.ret[source]["audio_indices"] = ret["audio_indices"]
if self.create_attention_mask:
self.ret[source]["attention_mask"] += ret["attention_mask"]
if self.create_attention_mask_2d:
self.ret[source]["attention_mask_2d"][:, all_length:, :all_length] = 0
if self.create_position_ids:
self.ret[source]["position_ids"] += list(range(cur_length))
self.ret[source]["tokens"] += ret["input_ids"]
self.ret[source]["labels"] += ret["labels"]
self.ret[source]["actual_seq_len"] += [all_length + cur_length]
def process_ret(self, to_ret):
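        """Finalize a packed buffer: optionally shift tokens/labels, pad or truncate to
        ``max_padding_length``, build the requested masks, and convert fields to tensors."""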
if "tokens" in to_ret and len(to_ret["tokens"]) > 0:
pass
else:
return to_ret
        if self.create_position_ids and not self.reset_position_ids:
            to_ret["position_ids"] = list(range(len(to_ret["tokens"])))
        if self.create_attention_mask_2d and not self.reset_attention_mask:
            to_ret["attention_mask_2d"] = torch.tril(
                torch.ones(
                    (1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
                )
            )
if self.shift_token:
to_ret["tokens"] = to_ret["tokens"][:-1]
to_ret["labels"] = to_ret["labels"][1:]
to_ret["actual_seq_len"][-1] -= 1
if self.create_position_ids:
to_ret["position_ids"] = to_ret["position_ids"][:-1]
if self.create_attention_mask:
to_ret["attention_mask"] = to_ret["attention_mask"][:-1]
if self.create_attention_mask_2d:
to_ret["attention_mask_2d"][:, :, -1] = 0
to_ret["attention_mask_2d"][:, -1, :] = 0
assert len(to_ret["tokens"]) == len(
to_ret["labels"]
), f"{len(to_ret['tokens'])} {len(to_ret['labels'])}"
if not self.variable_length and self.max_padding_length > len(to_ret["tokens"]):
to_ret["tokens"] += [self.tokenizer.pad_token_id] * (
self.max_padding_length - len(to_ret["tokens"])
)
to_ret["labels"] += [IGNORE_TOKEN_ID] * (
self.max_padding_length - len(to_ret["labels"])
)
to_ret["actual_seq_len"][-1] = self.max_padding_length
if self.create_position_ids:
# to_ret["position_ids"] += to_ret["position_ids"][-1:] * (
# self.max_padding_length - len(to_ret["position_ids"])
# )
to_ret["position_ids"] += list(
range(to_ret["position_ids"][-1] + 1, self.max_padding_length)
)
if self.create_attention_mask:
to_ret["attention_mask"] += [0] * (
self.max_padding_length - len(to_ret["attention_mask"])
)
to_ret["tokens"] = to_ret["tokens"][: self.max_padding_length]
to_ret["labels"] = to_ret["labels"][: self.max_padding_length]
to_ret["actual_seq_len"][-1] = self.max_padding_length
if self.create_position_ids:
to_ret["position_ids"] = to_ret["position_ids"][: self.max_padding_length]
if self.create_attention_mask:
to_ret["attention_mask"] = to_ret["attention_mask"][: self.max_padding_length]
to_ret["tokens"] = torch.tensor(to_ret["tokens"], dtype=torch.int64)
to_ret["labels"] = torch.tensor(to_ret["labels"], dtype=torch.int64)
to_ret["actual_seq_len"] = torch.tensor(to_ret["actual_seq_len"], dtype=torch.int64)
if self.create_position_ids:
to_ret["position_ids"] = torch.tensor(to_ret["position_ids"], dtype=torch.int64)
if self.create_attention_mask:
to_ret["attention_mask"] = torch.tensor(to_ret["attention_mask"], dtype=torch.int64)
if self.create_attention_mask_2d:
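            # Fold the 1-D padding mask into the 2-D causal mask, then invert it so that
            # True marks positions that must not be attended to.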
attention_mask_2d = to_ret.pop("attention_mask_2d")
attention_mask_2d = attention_mask_2d.masked_fill(
(to_ret["attention_mask"] < 0.5).view(1, 1, self.max_padding_length), value=0
)
attention_mask_2d = attention_mask_2d < 0.5
to_ret["attention_mask"] = attention_mask_2d
if self.create_loss_mask:
loss_mask = torch.where(to_ret["labels"] == IGNORE_TOKEN_ID, 0, 1)
to_ret["loss_mask"] = loss_mask.to(torch.float32)
if not self.reset_position_ids and not self.reset_attention_mask:
to_ret.pop("actual_seq_len")
to_ret["input_ids"] = to_ret["tokens"]
# print("to_ret[tokens]", to_ret["tokens"])
# print("to_ret[labels]", to_ret["labels"])
return to_ret
def is_skip(self):
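        """Return True while the first ``skip_samples`` samples are still being skipped."""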
        if self.processed_samples < self.skip_samples:
            if self.processed_samples % 1000 == 0:
                print(
                    f"processed_samples {self.processed_samples} skip_samples {self.skip_samples}"
                )
            return True
        return False
def show_statistic(self):
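        """Periodically log packing progress; longer context lengths use a shorter log interval."""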
log_interval = 10000
if self.max_padding_length >= 2**17:
log_interval = 500
if self.max_padding_length >= 2**20:
log_interval = 100
if self.unjoint_samples % log_interval == 0:
print(
f"processed_samples {self.processed_samples} unjoint_samples {self.unjoint_samples} joint_samples {self.joint_samples} {[len(v['tokens']) for _, v in self.ret.items()]}",
flush=True,
)
return False
def __getitem__(self, index):
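        """Preprocess one raw conversation and pack it into a per-source buffer.

        Returns an empty dict when the sample is skipped, dropped (too long or errored),
        or merely absorbed into a buffer; otherwise returns the finalized packed sequence.
        """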
self.processor["audio"].load_model()
while True:
try:
self.processed_samples += 1
if self.is_skip():
return {}
sample = self.raw_data[index]
if self.cross_dataset_joint:
is_empty = False
                    (
                        max_ret_length,
                        max_ret_key,
                        min_ret_length,
                        min_ret_key,
                    ) = self.get_max_min_ret_length()
                else:
                    source = sample["source"]
                    is_empty = self.maybe_init_ret(source)
                    max_ret_length = min_ret_length = len(self.ret[source]["tokens"])
                    max_ret_key = min_ret_key = source
is_begin = is_empty or self.reset_position_ids or self.reset_attention_mask
ret = preprocess(
sample,
self.tokenizer,
self.image_token_length,
default_system_message=self.default_system_message,
processor=self.processor,
is_begin=is_begin,
max_num_frame=self.max_num_frame,
max_fps=self.max_fps,
)
if ret is None:
return {}
cur_length = len(ret["input_ids"])
if cur_length > self.max_padding_length:
return {}
self.unjoint_samples += 1
if not self.dataset_joint:
to_ret = self.ret.pop(max_ret_key)
self.maybe_init_ret(max_ret_key, force=True)
self.add_ret(ret, max_ret_key)
                elif min_ret_length + cur_length > self.max_padding_length:
to_ret = self.ret.pop(max_ret_key)
self.joint_samples += 1
self.maybe_init_ret(max_ret_key, force=True)
self.add_ret(ret, max_ret_key)
else:
to_ret = {}
self.add_ret(ret, min_ret_key)
to_ret = self.process_ret(to_ret)
self.show_statistic()
return to_ret
            except Exception:
                # Log the offending sample and skip it by returning an empty dict.
                try:
                    with open(os.path.join(self.output_dir, "data_error.log"), "a") as f:
                        print("-" * 100, file=f)
                        print(traceback.format_exc(), file=f)
                        print(self.raw_data[index], file=f)
                except Exception as log_error:
                    print(log_error)
                return {}
def preprocess(
sample,
tokenizer: transformers.PreTrainedTokenizer,
image_token_length: int,
default_system_message: str = "You are a helpful assistant.",
processor=None,
is_begin: bool = True,
max_num_frame: int = 8,
max_fps: int = 1,
) -> Dict:
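    """Tokenize one conversation with the Qwen2 chat template and expand image / video /
    audio placeholder tags into their start / context / end token spans.

    Returns a dict with ``input_ids``, ``labels``, ``attention_mask`` and the collected
    multimodal features together with the indices where they will be inserted.
    """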
# <|im_start|>system
# You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello, how are you?<|im_end|>
    # <|im_start|>assistant
    # I'm doing great. How can I help you today?<|im_end|>
# <|im_start|>user
# I'd like to show off how chat templating works!<|im_end|>
from ..constants import (
IMG_START_TOKEN,
IMG_END_TOKEN,
IMG_CONTEXT_TOKEN,
VID_START_TOKEN,
VID_END_TOKEN,
VID_CONTEXT_TOKEN,
PATCH_START_TOKEN,
PATCH_END_TOKEN,
PATCH_CONTEXT_TOKEN,
AUD_START_TOKEN,
AUD_END_TOKEN,
IMG_TAG_TOKEN,
VID_TAG_TOKEN,
AUD_TAG_TOKEN,
AUD_CONTEXT_TOKEN,
)
human_roles = ["user", "human"]
gpt_roles = ["assistant", "gpt"]
system_roles = ["system"]
IMG_CONTEXT_ID = tokenizer(IMG_CONTEXT_TOKEN, add_special_tokens=False).input_ids
IMG_START_ID = tokenizer(IMG_START_TOKEN, add_special_tokens=False).input_ids
IMG_END_ID = tokenizer(IMG_END_TOKEN, add_special_tokens=False).input_ids
VID_CONTEXT_ID = tokenizer(VID_CONTEXT_TOKEN, add_special_tokens=False).input_ids
VID_START_ID = tokenizer(VID_START_TOKEN, add_special_tokens=False).input_ids
VID_END_ID = tokenizer(VID_END_TOKEN, add_special_tokens=False).input_ids
PATCH_CONTEXT_ID = tokenizer(PATCH_CONTEXT_TOKEN, add_special_tokens=False).input_ids
PATCH_START_ID = tokenizer(PATCH_START_TOKEN, add_special_tokens=False).input_ids
PATCH_END_ID = tokenizer(PATCH_END_TOKEN, add_special_tokens=False).input_ids
AUD_CONTEXT_ID = tokenizer(AUD_CONTEXT_TOKEN, add_special_tokens=False).input_ids
AUD_START_ID = tokenizer(AUD_START_TOKEN, add_special_tokens=False).input_ids
AUD_END_ID = tokenizer(AUD_END_TOKEN, add_special_tokens=False).input_ids
IMG_TAG_ID = tokenizer(IMG_TAG_TOKEN, add_special_tokens=False).input_ids
VID_TAG_ID = tokenizer(VID_TAG_TOKEN, add_special_tokens=False).input_ids
AUD_TAG_ID = tokenizer(AUD_TAG_TOKEN, add_special_tokens=False).input_ids
    # Every special token must map to exactly one id, since only [0] is taken below.
    assert len(IMG_CONTEXT_ID) == 1
    assert len(IMG_START_ID) == 1
    assert len(IMG_END_ID) == 1
    assert len(VID_CONTEXT_ID) == 1
    assert len(VID_START_ID) == 1
    assert len(VID_END_ID) == 1
    assert len(PATCH_CONTEXT_ID) == 1
    assert len(PATCH_START_ID) == 1
    assert len(PATCH_END_ID) == 1
    assert len(AUD_CONTEXT_ID) == 1
    assert len(AUD_START_ID) == 1
    assert len(AUD_END_ID) == 1
    assert len(IMG_TAG_ID) == 1
    assert len(VID_TAG_ID) == 1
    assert len(AUD_TAG_ID) == 1
IMG_CONTEXT_ID = IMG_CONTEXT_ID[0]
IMG_START_ID = IMG_START_ID[0]
IMG_END_ID = IMG_END_ID[0]
VID_CONTEXT_ID = VID_CONTEXT_ID[0]
VID_START_ID = VID_START_ID[0]
VID_END_ID = VID_END_ID[0]
PATCH_CONTEXT_ID = PATCH_CONTEXT_ID[0]
PATCH_START_ID = PATCH_START_ID[0]
PATCH_END_ID = PATCH_END_ID[0]
AUD_CONTEXT_ID = AUD_CONTEXT_ID[0]
AUD_START_ID = AUD_START_ID[0]
AUD_END_ID = AUD_END_ID[0]
IMG_TAG_ID = IMG_TAG_ID[0]
VID_TAG_ID = VID_TAG_ID[0]
AUD_TAG_ID = AUD_TAG_ID[0]
BOS_ID = tokenizer.bos_token_id
EOS_ID = tokenizer.eos_token_id
IM_START = "<|im_start|>"
IM_END = "<|im_end|>"
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
nl_tokens = tokenizer("\n", add_special_tokens=False).input_ids
IM_START_IDS = tokenizer(IM_START, add_special_tokens=False).input_ids
IM_END_IDS = tokenizer(IM_END, add_special_tokens=False).input_ids
USER_IDS = tokenizer(USER, add_special_tokens=False).input_ids
ASSISTANT_IDS = tokenizer(ASSISTANT, add_special_tokens=False).input_ids
SYSTEM_IDS = tokenizer(SYSTEM, add_special_tokens=False).input_ids
input_ids, targets = [], []
images = []
image_indices = []
audios = []
audio_indices = []
messages = []
if "conversations" in sample:
messages = sample["conversations"]
if len(messages) == 0 and "messages" in sample:
messages = sample["messages"]
    # ----------------------------------------------------------------
    # add text to TTS: for "Convert the text to speech." requests, prepend the source
    # text to the assistant reply so the model sees the transcript it must vocalize.
    add_text = None
    # add_audio = None
    for j, sentence in enumerate(messages):
        content = sentence["content"]
        role = sentence["role"]
        if role == "user":
            if "Convert the text to speech." in content:
                add_text = content.replace("Convert the text to speech.\n", "")
                add_text = add_text.strip()
            # if "Convert the speech to text." in content:
            #     add_audio = sample["audios"][-1]
        if role == "assistant" and add_text is not None:
            sentence["content"] = add_text + content
        # if role == "assistant" and add_audio is not None:
        #     sentence["content"] = content + "\n<audio>"
        #     sample["audios"].append(add_audio)
# ----------------------------------------------------------------
# system
has_system = False
if is_begin:
        has_system = messages[0]["role"] == "system"
if (
not has_system
and default_system_message is not None
and len(default_system_message) > 0
):
messages = [{"role": "system", "content": default_system_message}] + messages
has_system = True
# ----------------------------------------------------------------
# audio
if has_audio(sample) and processor["audio"].is_discrete:
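        # Discrete audio: replace <audio> tags with <|audio_i|> token strings for the roles
        # the discrete tokenizer applies to; the remaining audio indices stay in
        # unused_audio_idxs and are handled by the contiguous-feature branch below.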
unused_audio_idxs = list(range(len(sample["audios"])))
audio_tokens_list = [
processor["audio"].process_audios(x, is_discrete=True) for x in sample["audios"]
]
audio_tokens_list = ["".join(f"<|audio_{i}|>" for i in x) for x in audio_tokens_list]
audio_idx = 0
for j, sentence in enumerate(messages):
content = sentence["content"]
role = sentence["role"]
# whether apply discrete tokenize to this role
if processor["audio"].apply_to_role(role, is_discrete=True):
while AUD_TAG_TOKEN in content:
content = content.replace(
AUD_TAG_TOKEN,
f"{AUD_START_TOKEN}{audio_tokens_list[audio_idx]}{AUD_END_TOKEN}",
1,
)
unused_audio_idxs.remove(audio_idx)
audio_idx += 1
else:
audio_idx += content.count(AUD_TAG_TOKEN)
sentence["content"] = content
# ----------------------------------------------------------------
# text
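    # Build the Qwen2 chat-template token stream. Only assistant message content (plus its
    # <|im_end|> and trailing newline) is supervised; user/system turns and the role
    # headers are masked with IGNORE_TOKEN_ID.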
for j, sentence in enumerate(messages):
role = sentence["role"]
content = sentence["content"]
if role in human_roles:
_input_id = (
IM_START_IDS
+ USER_IDS
+ nl_tokens
+ tokenizer(content, add_special_tokens=False).input_ids
+ IM_END_IDS
+ nl_tokens
)
_target = [IGNORE_TOKEN_ID] * len(_input_id)
elif role in gpt_roles:
content_input_id = tokenizer(content, add_special_tokens=False).input_ids
if processor["audio"].audio_tokenizer is not None:
content_input_id = processor["audio"].text_audio_interval(
content_input_id,
AUD_START_ID,
AUD_END_ID,
)
_input_id = (
IM_START_IDS + ASSISTANT_IDS + nl_tokens + content_input_id + IM_END_IDS + nl_tokens
)
_target = (
[IGNORE_TOKEN_ID] * len(IM_START_IDS)
+ [IGNORE_TOKEN_ID] * len(ASSISTANT_IDS)
+ [IGNORE_TOKEN_ID] * len(nl_tokens)
+ content_input_id
+ IM_END_IDS
+ nl_tokens
)
elif role in system_roles:
_input_id = (
IM_START_IDS
+ SYSTEM_IDS
+ nl_tokens
+ tokenizer(content, add_special_tokens=False).input_ids
+ IM_END_IDS
+ nl_tokens
)
_target = [IGNORE_TOKEN_ID] * len(_input_id)
else:
raise NotImplementedError
# print(f"_input_id {_input_id}")
input_ids += _input_id
targets += _target
# ----------------------------------------------------------------
# image
if has_image(sample):
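        # Replace each image tag with start/context/end spans (plus per-patch spans for
        # sub-patched images) and record, for every span, the positions where its visual
        # features will be scattered into the token sequence.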
img_positions = [i for i, x in enumerate(input_ids) if x == IMG_TAG_ID]
assert len(img_positions) == len(sample["images"]), sample
new_input_ids = []
new_targets = []
st = 0
for img_idx, img_pos in enumerate(img_positions):
image_patches, (best_width, best_height) = processor[
"image"
].process_images_with_subpatch(sample["images"][img_idx])
images.append(image_patches)
new_input_ids += input_ids[st:img_pos]
new_targets += targets[st:img_pos]
new_input_ids += [IMG_START_ID]
new_targets += [IGNORE_TOKEN_ID]
image_indice_b = torch.zeros(
1, image_token_length, dtype=torch.int64
) # This will change in collate_fn
image_indice_s = (
torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
.unsqueeze(0)
.repeat(1, 1)
)
image_indice_b_s = torch.stack(
[image_indice_b, image_indice_s], dim=0
) # 2, num_image, image_length
image_indices.append(image_indice_b_s)
new_input_ids += [IMG_CONTEXT_ID] * image_token_length
new_targets += [IGNORE_TOKEN_ID] * image_token_length
new_input_ids += [IMG_END_ID]
new_targets += [IGNORE_TOKEN_ID]
if len(image_patches) > 1:
for i in range(0, best_height, processor["image"].patch_size):
new_input_ids += nl_tokens
new_targets += [IGNORE_TOKEN_ID] * len(nl_tokens)
for j in range(0, best_width, processor["image"].patch_size):
new_input_ids += [PATCH_START_ID]
new_targets += [IGNORE_TOKEN_ID]
image_indice_b = torch.zeros(
1, image_token_length, dtype=torch.int64
) # This will change in collate_fn
image_indice_s = (
torch.arange(
len(new_input_ids), len(new_input_ids) + image_token_length
)
.unsqueeze(0)
.repeat(1, 1)
)
image_indice_b_s = torch.stack(
[image_indice_b, image_indice_s], dim=0
) # 2, num_image, image_length
image_indices.append(image_indice_b_s)
new_input_ids += [PATCH_CONTEXT_ID] * image_token_length
new_targets += [IGNORE_TOKEN_ID] * image_token_length
new_input_ids += [PATCH_END_ID]
new_targets += [IGNORE_TOKEN_ID]
st = img_pos + 1
new_input_ids += input_ids[st:]
new_targets += targets[st:]
input_ids = new_input_ids
targets = new_targets
# ----------------------------------------------------------------
# video
if has_video(sample):
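        # Videos are expanded frame by frame, each frame getting its own
        # start/context/end span and feature indices.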
vid_positions = [i for i, x in enumerate(input_ids) if x == VID_TAG_ID]
assert len(vid_positions) == len(sample["videos"]), sample
new_input_ids = []
new_targets = []
st = 0
for vid_idx, vid_pos in enumerate(vid_positions):
video_frames, _ = processor["image"].process_video(
sample["videos"][vid_idx], max_num_frame, max_fps
)
new_input_ids += input_ids[st:vid_pos]
new_targets += targets[st:vid_pos]
images.append(video_frames)
for _ in video_frames:
new_input_ids += [VID_START_ID]
new_targets += [IGNORE_TOKEN_ID]
image_indice_b = torch.zeros(
1, image_token_length, dtype=torch.int64
) # This will change in collate_fn
image_indice_s = (
torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
.unsqueeze(0)
.repeat(1, 1)
)
image_indice_b_s = torch.stack(
[image_indice_b, image_indice_s], dim=0
) # 2, num_image, image_length
image_indices.append(image_indice_b_s)
new_input_ids += [VID_CONTEXT_ID] * image_token_length
new_targets += [IGNORE_TOKEN_ID] * image_token_length
new_input_ids += [VID_END_ID]
new_targets += [IGNORE_TOKEN_ID]
st = vid_pos + 1
new_input_ids += input_ids[st:]
new_targets += targets[st:]
input_ids = new_input_ids
targets = new_targets
# ----------------------------------------------------------------
# audio
if has_audio(sample) and processor["audio"].is_contiguous:
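        # Contiguous audio: the remaining (non-discretized) audio tags are expanded into
        # start/context/end spans whose length follows each audio's extracted feature length.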
aud_positions = [i for i, x in enumerate(input_ids) if x == AUD_TAG_ID]
# assert len(aud_positions) == len(sample["audios"]), sample
assert len(aud_positions) == len(unused_audio_idxs), sample
new_input_ids = []
new_targets = []
st = 0
for aud_idx, aud_pos in enumerate(aud_positions):
aud_idx = unused_audio_idxs[aud_idx]
audio = processor["audio"].process_audios(sample["audios"][aud_idx], is_contiguous=True)
audios.append(audio)
            # The extra 4 positions beyond the raw feature length are kept from the original
            # code; they presumably reserve room for tokens added around the audio features.
            audio_token_length = audio.size(0) + 4
new_input_ids += input_ids[st:aud_pos]
new_targets += targets[st:aud_pos]
new_input_ids += [AUD_START_ID]
new_targets += [IGNORE_TOKEN_ID]
audio_indice_b = torch.zeros(
1, audio_token_length, dtype=torch.int64
) # This will change in collate_fn
audio_indice_s = (
torch.arange(len(new_input_ids), len(new_input_ids) + audio_token_length)
.unsqueeze(0)
.repeat(1, 1)
)
audio_indice_b_s = torch.stack(
[audio_indice_b, audio_indice_s], dim=0
            ) # 2, num_audio, audio_token_length
audio_indices.append(audio_indice_b_s)
new_input_ids += [AUD_CONTEXT_ID] * audio_token_length
new_targets += [IGNORE_TOKEN_ID] * audio_token_length
new_input_ids += [AUD_END_ID]
new_targets += [IGNORE_TOKEN_ID]
st = aud_pos + 1
new_input_ids += input_ids[st:]
new_targets += targets[st:]
input_ids = new_input_ids
targets = new_targets
if len(images) > 0:
images = torch.cat(images, dim=0)
if len(image_indices) > 0:
image_indices = torch.cat(image_indices, dim=1)
    # audios / audio_indices remain Python lists of variable-length tensors (they are
    # extended, not concatenated, in add_ret); only images / image_indices are concatenated.
attention_mask = [1] * len(input_ids)
# print("sample", sample, flush=True)
# print("input_ids", input_ids, flush=True)
# print("targets", targets[:100], flush=True)
# print("images", [xx.shape for x in images for xx in x], flush=True)
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=attention_mask,
images=images,
image_indices=image_indices,
audios=audios,
audio_indices=audio_indices,
)
def has_video(sample):
    """Return True if the sample carries a non-empty list of valid videos."""
    return (
        "videos" in sample
        and isinstance(sample["videos"], list)
        and None not in sample["videos"]
        and len(sample["videos"]) > 0
    )


def has_image(sample):
    """Return True if the sample carries a non-empty list of valid images."""
    return (
        "images" in sample
        and isinstance(sample["images"], list)
        and None not in sample["images"]
        and len(sample["images"]) > 0
    )


def has_audio(sample):
    """Return True if the sample carries a non-empty list of valid audios."""
    return (
        "audios" in sample
        and isinstance(sample["audios"], list)
        and None not in sample["audios"]
        and len(sample["audios"]) > 0
    )