import json
import logging
import math
import os
import pdb
import random
import re
import sys
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother

from .dataset_base import BaseDataset

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
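

# Packs multiple conversation samples into fixed-length sequences in the Qwen2
# chat format (<|im_start|>role\n...<|im_end|>) and expands image, video, and
# audio placeholder tags into the corresponding multimodal context tokens.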
class Qwen2Dataset(BaseDataset):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(
            *args,
            **kwargs,
        )
        # No default system prompt is injected unless the sample provides one.
        self.default_system_message = None
        self.ret = defaultdict(dict)
        self.is_cat = True
        if self.cross_dataset_joint:
            for i in range(2):
                self.maybe_init_ret(f"default_{i}")
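
    # Create (or, with force=True, reset) the packing buffer for `source`.
    # Returns True if the buffer for `source` is empty afterwards, i.e. the
    # next sample starts a fresh sequence.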
    def maybe_init_ret(self, source, force=False):
        if source not in self.ret or force:
            self.ret[source] = {}
            self.ret[source]["tokens"] = []
            self.ret[source]["labels"] = []
            self.ret[source]["actual_seq_len"] = []
            if self.create_position_ids:
                self.ret[source]["position_ids"] = []
            if self.create_attention_mask:
                self.ret[source]["attention_mask"] = []
            if self.create_attention_mask_2d:
                self.ret[source]["attention_mask_2d"] = torch.tril(
                    torch.ones(
                        (1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
                    )
                )
        return len(self.ret[source]["tokens"]) == 0
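
    # Return the token counts and keys of the fullest and emptiest packing buffers.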
    def get_max_min_ret_length(self):
        max_ret_length = 0
        min_ret_length = self.max_padding_length + 1
        max_ret_key = None
        min_ret_key = None
        for k, v in self.ret.items():
            cur_length = len(v["tokens"])
            if cur_length > max_ret_length:
                max_ret_length = cur_length
                max_ret_key = k
            if cur_length < min_ret_length:
                min_ret_length = cur_length
                min_ret_key = k
        return max_ret_length, max_ret_key, min_ret_length, min_ret_key
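
    # Append one preprocessed sample to the packing buffer of `source`:
    # tokens/labels are concatenated and the image/audio index tensors are
    # shifted by the current buffer length.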
    def add_ret(self, ret, source):
        cur_length = len(ret["input_ids"])
        cur_image_length = len(ret["images"])
        cur_audio_length = len(ret["audios"])
        all_length = len(self.ret[source]["tokens"])
        if "images" in self.ret[source]:
            all_image_length = len(self.ret[source]["images"])
        else:
            all_image_length = 0
        if cur_image_length > 0:
            if all_image_length > 0:
                self.ret[source]["images"] = torch.cat(
                    [self.ret[source]["images"], ret["images"]], dim=0
                )
                ret["image_indices"][1, :, :] += all_length
                self.ret[source]["image_indices"] = torch.cat(
                    [self.ret[source]["image_indices"], ret["image_indices"]], dim=1
                )
            else:
                self.ret[source]["images"] = ret["images"]
                self.ret[source]["image_indices"] = ret["image_indices"]
        if "audios" in self.ret[source]:
            all_audio_length = len(self.ret[source]["audios"])
        else:
            all_audio_length = 0
        if cur_audio_length > 0:
            if all_audio_length > 0:
                # audios and audio_indices are kept as Python lists (one entry per
                # audio, variable lengths), so extend them and shift each index
                # tensor by the current buffer length
                self.ret[source]["audios"].extend(ret["audios"])
                for audio_indice in ret["audio_indices"]:
                    audio_indice[1, :, :] += all_length
                self.ret[source]["audio_indices"].extend(ret["audio_indices"])
            else:
                self.ret[source]["audios"] = ret["audios"]
                self.ret[source]["audio_indices"] = ret["audio_indices"]
        if self.create_attention_mask:
            self.ret[source]["attention_mask"] += ret["attention_mask"]
        if self.create_attention_mask_2d:
            # the new sample cannot attend to previously packed samples
            self.ret[source]["attention_mask_2d"][:, all_length:, :all_length] = 0
        if self.create_position_ids:
            # positions restart from 0 for each packed sample
            self.ret[source]["position_ids"] += list(range(cur_length))
        self.ret[source]["tokens"] += ret["input_ids"]
        self.ret[source]["labels"] += ret["labels"]
        self.ret[source]["actual_seq_len"] += [all_length + cur_length]
    def process_ret(self, to_ret):
        if "tokens" not in to_ret or len(to_ret["tokens"]) == 0:
            return to_ret
        if self.create_position_ids and not self.reset_position_ids:
            to_ret["position_ids"] = list(range(len(to_ret["tokens"])))
        if self.create_attention_mask_2d and not self.reset_attention_mask:
            to_ret["attention_mask_2d"] = torch.tril(
                torch.ones(
                    (1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
                )
            )
        if self.shift_token:
            to_ret["tokens"] = to_ret["tokens"][:-1]
            to_ret["labels"] = to_ret["labels"][1:]
            to_ret["actual_seq_len"][-1] -= 1
            if self.create_position_ids:
                to_ret["position_ids"] = to_ret["position_ids"][:-1]
            if self.create_attention_mask:
                to_ret["attention_mask"] = to_ret["attention_mask"][:-1]
            if self.create_attention_mask_2d:
                to_ret["attention_mask_2d"][:, :, -1] = 0
                to_ret["attention_mask_2d"][:, -1, :] = 0
        assert len(to_ret["tokens"]) == len(
            to_ret["labels"]
        ), f"{len(to_ret['tokens'])} {len(to_ret['labels'])}"
        if not self.variable_length and self.max_padding_length > len(to_ret["tokens"]):
            to_ret["tokens"] += [self.tokenizer.pad_token_id] * (
                self.max_padding_length - len(to_ret["tokens"])
            )
            to_ret["labels"] += [IGNORE_TOKEN_ID] * (
                self.max_padding_length - len(to_ret["labels"])
            )
            to_ret["actual_seq_len"][-1] = self.max_padding_length
            if self.create_position_ids:
                to_ret["position_ids"] += list(
                    range(to_ret["position_ids"][-1] + 1, self.max_padding_length)
                )
            if self.create_attention_mask:
                to_ret["attention_mask"] += [0] * (
                    self.max_padding_length - len(to_ret["attention_mask"])
                )
        to_ret["tokens"] = to_ret["tokens"][: self.max_padding_length]
        to_ret["labels"] = to_ret["labels"][: self.max_padding_length]
        to_ret["actual_seq_len"][-1] = self.max_padding_length
        if self.create_position_ids:
            to_ret["position_ids"] = to_ret["position_ids"][: self.max_padding_length]
        if self.create_attention_mask:
            to_ret["attention_mask"] = to_ret["attention_mask"][: self.max_padding_length]
        to_ret["tokens"] = torch.tensor(to_ret["tokens"], dtype=torch.int64)
        to_ret["labels"] = torch.tensor(to_ret["labels"], dtype=torch.int64)
        to_ret["actual_seq_len"] = torch.tensor(to_ret["actual_seq_len"], dtype=torch.int64)
        if self.create_position_ids:
            to_ret["position_ids"] = torch.tensor(to_ret["position_ids"], dtype=torch.int64)
        if self.create_attention_mask:
            to_ret["attention_mask"] = torch.tensor(to_ret["attention_mask"], dtype=torch.int64)
        if self.create_attention_mask_2d:
            attention_mask_2d = to_ret.pop("attention_mask_2d")
            attention_mask_2d = attention_mask_2d.masked_fill(
                (to_ret["attention_mask"] < 0.5).view(1, 1, self.max_padding_length), value=0
            )
            # invert so that True marks positions that must NOT be attended to
            attention_mask_2d = attention_mask_2d < 0.5
            to_ret["attention_mask"] = attention_mask_2d
        if self.create_loss_mask:
            loss_mask = torch.where(to_ret["labels"] == IGNORE_TOKEN_ID, 0, 1)
            to_ret["loss_mask"] = loss_mask.to(torch.float32)
        if not self.reset_position_ids and not self.reset_attention_mask:
            to_ret.pop("actual_seq_len")
        to_ret["input_ids"] = to_ret["tokens"]
        return to_ret
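
    # Progress helpers: is_skip() skips samples until skip_samples is reached;
    # show_statistic() periodically prints packing statistics.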
    def is_skip(self):
        if self.processed_samples < self.skip_samples:
            if self.processed_samples % 1000 == 0:
                print(
                    f"processed_samples {self.processed_samples} skip_samples {self.skip_samples}"
                )
            return True
        return False

    def show_statistic(self):
        log_interval = 10000
        if self.max_padding_length >= 2**17:
            log_interval = 500
        if self.max_padding_length >= 2**20:
            log_interval = 100
        if self.unjoint_samples % log_interval == 0:
            print(
                f"processed_samples {self.processed_samples} unjoint_samples {self.unjoint_samples} joint_samples {self.joint_samples} {[len(v['tokens']) for _, v in self.ret.items()]}",
                flush=True,
            )
        return False
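
    # Fetch and preprocess one sample, then pack it into a buffer. When a buffer
    # would overflow (or dataset_joint is disabled), the previous buffer is
    # flushed through process_ret() and returned; otherwise the sample is only
    # packed and {} is returned ({} is also returned for skipped, oversized, or
    # failed samples).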
    def __getitem__(self, index):
        self.processor["audio"].load_model()
        while True:
            try:
                self.processed_samples += 1
                if self.is_skip():
                    return {}
                sample = self.raw_data[index]
                if self.cross_dataset_joint:
                    is_empty = False
                    (
                        max_ret_length,
                        max_ret_key,
                        min_ret_length,
                        min_ret_key,
                    ) = self.get_max_min_ret_length()
                else:
                    source = sample["source"]
                    is_empty = self.maybe_init_ret(source)
                    max_ret_length = min_ret_length = len(self.ret[source]["tokens"])
                    max_ret_key = min_ret_key = source
                is_begin = is_empty or self.reset_position_ids or self.reset_attention_mask
                ret = preprocess(
                    sample,
                    self.tokenizer,
                    self.image_token_length,
                    default_system_message=self.default_system_message,
                    processor=self.processor,
                    is_begin=is_begin,
                    max_num_frame=self.max_num_frame,
                    max_fps=self.max_fps,
                )
                if ret is None:
                    return {}
                cur_length = len(ret["input_ids"])
                if cur_length > self.max_padding_length:
                    return {}
                self.unjoint_samples += 1
                if not self.dataset_joint:
                    to_ret = self.ret.pop(max_ret_key)
                    self.maybe_init_ret(max_ret_key, force=True)
                    self.add_ret(ret, max_ret_key)
                elif min_ret_length + cur_length > self.max_padding_length:
                    to_ret = self.ret.pop(max_ret_key)
                    self.joint_samples += 1
                    self.maybe_init_ret(max_ret_key, force=True)
                    self.add_ret(ret, max_ret_key)
                else:
                    to_ret = {}
                    self.add_ret(ret, min_ret_key)
                to_ret = self.process_ret(to_ret)
                self.show_statistic()
                return to_ret
            except Exception:
                try:
                    with open(os.path.join(self.output_dir, "data_error.log"), "a") as f:
                        print("-" * 100, file=f)
                        print(traceback.format_exc(), file=f)
                        print(self.raw_data[index], file=f)
                except Exception as log_error:
                    print(log_error)
                return {}
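

# Convert one conversation sample into flat input_ids/labels in the Qwen2 chat
# format, expanding the image/video/audio tag tokens into start/context/end
# spans and collecting the matching feature tensors and index maps.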
def preprocess(
    sample,
    tokenizer: transformers.PreTrainedTokenizer,
    image_token_length: int,
    default_system_message: str = "You are a helpful assistant.",
    processor=None,
    is_begin: bool = True,
    max_num_frame: int = 8,
    max_fps: int = 1,
) -> Dict:
    # Target chat format, e.g.:
    # <|im_start|>system
    # You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Hello, how are you?<|im_end|>
    # <|im_start|>assistant
    # I'm doing great. How can I help you today?<|im_end|>
    # <|im_start|>user
    # I'd like to show off how chat templating works!<|im_end|>
    from ..constants import (
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
        AUD_TAG_TOKEN,
        AUD_CONTEXT_TOKEN,
    )

    human_roles = ["user", "human"]
    gpt_roles = ["assistant", "gpt"]
    system_roles = ["system"]

    IMG_CONTEXT_ID = tokenizer(IMG_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    IMG_START_ID = tokenizer(IMG_START_TOKEN, add_special_tokens=False).input_ids
    IMG_END_ID = tokenizer(IMG_END_TOKEN, add_special_tokens=False).input_ids
    VID_CONTEXT_ID = tokenizer(VID_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    VID_START_ID = tokenizer(VID_START_TOKEN, add_special_tokens=False).input_ids
    VID_END_ID = tokenizer(VID_END_TOKEN, add_special_tokens=False).input_ids
    PATCH_CONTEXT_ID = tokenizer(PATCH_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    PATCH_START_ID = tokenizer(PATCH_START_TOKEN, add_special_tokens=False).input_ids
    PATCH_END_ID = tokenizer(PATCH_END_TOKEN, add_special_tokens=False).input_ids
    AUD_CONTEXT_ID = tokenizer(AUD_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    AUD_START_ID = tokenizer(AUD_START_TOKEN, add_special_tokens=False).input_ids
    AUD_END_ID = tokenizer(AUD_END_TOKEN, add_special_tokens=False).input_ids
    IMG_TAG_ID = tokenizer(IMG_TAG_TOKEN, add_special_tokens=False).input_ids
    VID_TAG_ID = tokenizer(VID_TAG_TOKEN, add_special_tokens=False).input_ids
    AUD_TAG_ID = tokenizer(AUD_TAG_TOKEN, add_special_tokens=False).input_ids

    # each special token must map to exactly one id
    assert len(IMG_CONTEXT_ID) == 1
    assert len(IMG_START_ID) == 1
    assert len(IMG_END_ID) == 1
    assert len(VID_CONTEXT_ID) == 1
    assert len(VID_START_ID) == 1
    assert len(VID_END_ID) == 1
    assert len(PATCH_CONTEXT_ID) == 1
    assert len(PATCH_START_ID) == 1
    assert len(PATCH_END_ID) == 1

    IMG_CONTEXT_ID = IMG_CONTEXT_ID[0]
    IMG_START_ID = IMG_START_ID[0]
    IMG_END_ID = IMG_END_ID[0]
    VID_CONTEXT_ID = VID_CONTEXT_ID[0]
    VID_START_ID = VID_START_ID[0]
    VID_END_ID = VID_END_ID[0]
    PATCH_CONTEXT_ID = PATCH_CONTEXT_ID[0]
    PATCH_START_ID = PATCH_START_ID[0]
    PATCH_END_ID = PATCH_END_ID[0]
    AUD_CONTEXT_ID = AUD_CONTEXT_ID[0]
    AUD_START_ID = AUD_START_ID[0]
    AUD_END_ID = AUD_END_ID[0]
    IMG_TAG_ID = IMG_TAG_ID[0]
    VID_TAG_ID = VID_TAG_ID[0]
    AUD_TAG_ID = AUD_TAG_ID[0]

    BOS_ID = tokenizer.bos_token_id
    EOS_ID = tokenizer.eos_token_id

    IM_START = "<|im_start|>"
    IM_END = "<|im_end|>"
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    nl_tokens = tokenizer("\n", add_special_tokens=False).input_ids
    IM_START_IDS = tokenizer(IM_START, add_special_tokens=False).input_ids
    IM_END_IDS = tokenizer(IM_END, add_special_tokens=False).input_ids
    USER_IDS = tokenizer(USER, add_special_tokens=False).input_ids
    ASSISTANT_IDS = tokenizer(ASSISTANT, add_special_tokens=False).input_ids
    SYSTEM_IDS = tokenizer(SYSTEM, add_special_tokens=False).input_ids
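
    # The chat turns are tokenized below; multimodal tag tokens inside them are
    # expanded into start/context/end spans in the sections that follow.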
    input_ids, targets = [], []
    images = []
    image_indices = []
    audios = []
    audio_indices = []

    messages = []
    if "conversations" in sample:
        messages = sample["conversations"]
    if len(messages) == 0 and "messages" in sample:
        messages = sample["messages"]

    # ----------------------------------------------------------------
    # For TTS samples ("Convert the text to speech."), copy the text to be
    # spoken to the front of the assistant reply.
    add_text = None
    for j, sentence in enumerate(messages):
        content = sentence["content"]
        role = sentence["role"]
        if role == "user":
            if "Convert the text to speech." in content:
                add_text = content.replace("Convert the text to speech.\n", "")
                add_text = add_text.strip()
        if role == "assistant" and add_text is not None:
            sentence["content"] = add_text + content

    # ----------------------------------------------------------------
    # system
    has_system = False
    if is_begin:
        if messages[0]["role"] == "system":
            has_system = True
        else:
            has_system = False
        if (
            not has_system
            and default_system_message is not None
            and len(default_system_message) > 0
        ):
            messages = [{"role": "system", "content": default_system_message}] + messages
            has_system = True

    # ----------------------------------------------------------------
    # audio (discrete codec tokens)
    if has_audio(sample):
        # indices of audios not consumed here; the contiguous branch below uses the rest
        unused_audio_idxs = list(range(len(sample["audios"])))
    if has_audio(sample) and processor["audio"].is_discrete:
        audio_tokens_list = [
            processor["audio"].process_audios(x, is_discrete=True) for x in sample["audios"]
        ]
        audio_tokens_list = ["".join(f"<|audio_{i}|>" for i in x) for x in audio_tokens_list]
        audio_idx = 0
        for j, sentence in enumerate(messages):
            content = sentence["content"]
            role = sentence["role"]
            # only roles for which the audio processor enables discrete tokens
            if processor["audio"].apply_to_role(role, is_discrete=True):
                while AUD_TAG_TOKEN in content:
                    content = content.replace(
                        AUD_TAG_TOKEN,
                        f"{AUD_START_TOKEN}{audio_tokens_list[audio_idx]}{AUD_END_TOKEN}",
                        1,
                    )
                    unused_audio_idxs.remove(audio_idx)
                    audio_idx += 1
            else:
                audio_idx += content.count(AUD_TAG_TOKEN)
            sentence["content"] = content

    # ----------------------------------------------------------------
    # text
    for j, sentence in enumerate(messages):
        role = sentence["role"]
        content = sentence["content"]
        if role in human_roles:
            _input_id = (
                IM_START_IDS
                + USER_IDS
                + nl_tokens
                + tokenizer(content, add_special_tokens=False).input_ids
                + IM_END_IDS
                + nl_tokens
            )
            # user turns are not supervised
            _target = [IGNORE_TOKEN_ID] * len(_input_id)
        elif role in gpt_roles:
            content_input_id = tokenizer(content, add_special_tokens=False).input_ids
            if processor["audio"].audio_tokenizer is not None:
                content_input_id = processor["audio"].text_audio_interval(
                    content_input_id,
                    AUD_START_ID,
                    AUD_END_ID,
                )
            _input_id = (
                IM_START_IDS + ASSISTANT_IDS + nl_tokens + content_input_id + IM_END_IDS + nl_tokens
            )
            # only the assistant content and the closing tokens are supervised
            _target = (
                [IGNORE_TOKEN_ID] * len(IM_START_IDS)
                + [IGNORE_TOKEN_ID] * len(ASSISTANT_IDS)
                + [IGNORE_TOKEN_ID] * len(nl_tokens)
                + content_input_id
                + IM_END_IDS
                + nl_tokens
            )
        elif role in system_roles:
            _input_id = (
                IM_START_IDS
                + SYSTEM_IDS
                + nl_tokens
                + tokenizer(content, add_special_tokens=False).input_ids
                + IM_END_IDS
                + nl_tokens
            )
            _target = [IGNORE_TOKEN_ID] * len(_input_id)
        else:
            raise NotImplementedError
        input_ids += _input_id
        targets += _target

    # ----------------------------------------------------------------
    # image
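    # Each image tag token is replaced by IMG_START + image_token_length
    # IMG_CONTEXT tokens + IMG_END; when the processor returns sub-patches, a
    # grid of PATCH_START/PATCH_CONTEXT/PATCH_END blocks is appended row by row.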
    if has_image(sample):
        img_positions = [i for i, x in enumerate(input_ids) if x == IMG_TAG_ID]
        assert len(img_positions) == len(sample["images"]), sample
        new_input_ids = []
        new_targets = []
        st = 0
        for img_idx, img_pos in enumerate(img_positions):
            image_patches, (best_width, best_height) = processor[
                "image"
            ].process_images_with_subpatch(sample["images"][img_idx])
            images.append(image_patches)
            new_input_ids += input_ids[st:img_pos]
            new_targets += targets[st:img_pos]
            new_input_ids += [IMG_START_ID]
            new_targets += [IGNORE_TOKEN_ID]
            image_indice_b = torch.zeros(
                1, image_token_length, dtype=torch.int64
            )  # This will change in collate_fn
            image_indice_s = (
                torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
                .unsqueeze(0)
                .repeat(1, 1)
            )
            image_indice_b_s = torch.stack(
                [image_indice_b, image_indice_s], dim=0
            )  # 2, num_image, image_length
            image_indices.append(image_indice_b_s)
            new_input_ids += [IMG_CONTEXT_ID] * image_token_length
            new_targets += [IGNORE_TOKEN_ID] * image_token_length
            new_input_ids += [IMG_END_ID]
            new_targets += [IGNORE_TOKEN_ID]
            if len(image_patches) > 1:
                for i in range(0, best_height, processor["image"].patch_size):
                    new_input_ids += nl_tokens
                    new_targets += [IGNORE_TOKEN_ID] * len(nl_tokens)
                    for j in range(0, best_width, processor["image"].patch_size):
                        new_input_ids += [PATCH_START_ID]
                        new_targets += [IGNORE_TOKEN_ID]
                        image_indice_b = torch.zeros(
                            1, image_token_length, dtype=torch.int64
                        )  # This will change in collate_fn
                        image_indice_s = (
                            torch.arange(
                                len(new_input_ids), len(new_input_ids) + image_token_length
                            )
                            .unsqueeze(0)
                            .repeat(1, 1)
                        )
                        image_indice_b_s = torch.stack(
                            [image_indice_b, image_indice_s], dim=0
                        )  # 2, num_image, image_length
                        image_indices.append(image_indice_b_s)
                        new_input_ids += [PATCH_CONTEXT_ID] * image_token_length
                        new_targets += [IGNORE_TOKEN_ID] * image_token_length
                        new_input_ids += [PATCH_END_ID]
                        new_targets += [IGNORE_TOKEN_ID]
            st = img_pos + 1
        new_input_ids += input_ids[st:]
        new_targets += targets[st:]
        input_ids = new_input_ids
        targets = new_targets

    # ----------------------------------------------------------------
    # video
    if has_video(sample):
        vid_positions = [i for i, x in enumerate(input_ids) if x == VID_TAG_ID]
        assert len(vid_positions) == len(sample["videos"]), sample
        new_input_ids = []
        new_targets = []
        st = 0
        for vid_idx, vid_pos in enumerate(vid_positions):
            video_frames, _ = processor["image"].process_video(
                sample["videos"][vid_idx], max_num_frame, max_fps
            )
            new_input_ids += input_ids[st:vid_pos]
            new_targets += targets[st:vid_pos]
            images.append(video_frames)
            # one VID_START/VID_CONTEXT/VID_END block per sampled frame
            for _ in video_frames:
                new_input_ids += [VID_START_ID]
                new_targets += [IGNORE_TOKEN_ID]
                image_indice_b = torch.zeros(
                    1, image_token_length, dtype=torch.int64
                )  # This will change in collate_fn
                image_indice_s = (
                    torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
                    .unsqueeze(0)
                    .repeat(1, 1)
                )
                image_indice_b_s = torch.stack(
                    [image_indice_b, image_indice_s], dim=0
                )  # 2, num_image, image_length
                image_indices.append(image_indice_b_s)
                new_input_ids += [VID_CONTEXT_ID] * image_token_length
                new_targets += [IGNORE_TOKEN_ID] * image_token_length
                new_input_ids += [VID_END_ID]
                new_targets += [IGNORE_TOKEN_ID]
            st = vid_pos + 1
        new_input_ids += input_ids[st:]
        new_targets += targets[st:]
        input_ids = new_input_ids
        targets = new_targets

    # ----------------------------------------------------------------
    # audio (contiguous features)
    if has_audio(sample) and processor["audio"].is_contiguous:
        aud_positions = [i for i, x in enumerate(input_ids) if x == AUD_TAG_ID]
        # only the audios that were not already inlined as discrete tokens
        assert len(aud_positions) == len(unused_audio_idxs), sample
        new_input_ids = []
        new_targets = []
        st = 0
        for aud_idx, aud_pos in enumerate(aud_positions):
            aud_idx = unused_audio_idxs[aud_idx]
            audio = processor["audio"].process_audios(sample["audios"][aud_idx], is_contiguous=True)
            audios.append(audio)
            audio_token_length = audio.size(0) + 4
            new_input_ids += input_ids[st:aud_pos]
            new_targets += targets[st:aud_pos]
            new_input_ids += [AUD_START_ID]
            new_targets += [IGNORE_TOKEN_ID]
            audio_indice_b = torch.zeros(
                1, audio_token_length, dtype=torch.int64
            )  # This will change in collate_fn
            audio_indice_s = (
                torch.arange(len(new_input_ids), len(new_input_ids) + audio_token_length)
                .unsqueeze(0)
                .repeat(1, 1)
            )
            audio_indice_b_s = torch.stack(
                [audio_indice_b, audio_indice_s], dim=0
            )  # 2, num_audio, audio_length
            audio_indices.append(audio_indice_b_s)
            new_input_ids += [AUD_CONTEXT_ID] * audio_token_length
            new_targets += [IGNORE_TOKEN_ID] * audio_token_length
            new_input_ids += [AUD_END_ID]
            new_targets += [IGNORE_TOKEN_ID]
            st = aud_pos + 1
        new_input_ids += input_ids[st:]
        new_targets += targets[st:]
        input_ids = new_input_ids
        targets = new_targets

    if len(images) > 0:
        images = torch.cat(images, dim=0)
    if len(image_indices) > 0:
        image_indices = torch.cat(image_indices, dim=1)
    # audios and audio_indices stay as lists (one entry per audio, variable lengths)
    attention_mask = [1] * len(input_ids)
    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=attention_mask,
        images=images,
        image_indices=image_indices,
        audios=audios,
        audio_indices=audio_indices,
    )
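

# A sample is treated as having videos/images/audios only when the
# corresponding field is a non-empty list with no None entries.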
def has_video(sample):
    # video
    if (
        "videos" in sample
        and isinstance(sample["videos"], list)
        and None not in sample["videos"]
        and len(sample["videos"])
    ):
        return True
    return False


def has_image(sample):
    # image
    if (
        "images" in sample
        and isinstance(sample["images"], list)
        and None not in sample["images"]
        and len(sample["images"])
    ):
        return True
    return False


def has_audio(sample):
    # audio
    if (
        "audios" in sample
        and isinstance(sample["audios"], list)
        and None not in sample["audios"]
        and len(sample["audios"])
    ):
        return True
    return False