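"""Compute per-dataset validation losses for an audio-text language model.

For each validation dataset returned by get_audiotext_dataloader, this script
runs the model over the full split and reports one loss per dataset. It can be
run standalone with `-c /path/to/config.yaml` (see the __main__ block below).
"""
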
import argparse
import functools
import os
import random
import sys
import time

import numpy as np
import torch
import yaml
from tqdm import tqdm

sys.path.append('../')  # make the repo-level packages (data, src) importable

from data.data import get_audiotext_dataloader


@torch.no_grad()
def validation_losses(model, data_config, clap_config, tokenizer, batch_size,
                      autocast, cast_dtype, device_id, verbose=True):
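    """Return a dict mapping each validation dataset name to its loss.

    Puts `model` in eval mode for the duration of the computation, masks the
    prompt portion of each sequence (everything up to the separator token) so
    that only answer tokens contribute to the loss, and restores train mode
    before returning.
    """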
    model.eval()

    @torch.no_grad()
    def get_val_loss(validloader):
        loss_sum = 0.0
        for idx, batch in tqdm(enumerate(validloader)):
            audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True)
            audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
            # Token ids and the attention mask must stay integer-typed:
            # casting them to a half-precision float dtype would break the
            # embedding lookup, so only move them to the device.
            input_ids = batch["input_ids"].to(device_id, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device_id, non_blocking=True)

            # Build labels: ignore padding, the first position, and <audio>
            # placeholder tokens (-100 is ignored by the loss).
            labels = input_ids.clone()
            labels[labels == tokenizer.pad_token_id] = -100
            labels[:, :1] = -100
            labels[labels == media_token_id] = -100
            sep_locations = labels == tokenizer.sep_token_id  # currently unused
            eoc_locations = labels == endofchunk_token_id  # currently unused
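
            # Mask every token from the start of the sequence (and after each
            # <|endofchunk|> boundary) up to and including the separator
            # token, so the loss is computed only on answer tokens; EOS
            # tokens are always kept.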
            for i in range(labels.shape[0]):
                shouldmask = True
                for j in range(labels.shape[1]):
                    if shouldmask and (labels[i][j] != tokenizer.eos_token_id):
                        masked_value = -100
                    else:
                        masked_value = labels[i][j]
                    if labels[i][j] == tokenizer.sep_token_id:
                        shouldmask = False
                    elif labels[i][j] == endofchunk_token_id:
                        shouldmask = True
                    labels[i][j] = masked_value
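
                # If the sequence was cut off mid-answer (the last label is a
                # real token rather than EOS / padding / a chunk boundary),
                # mask the trailing partial answer so it does not count
                # toward the loss.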
                if labels[i][-1] not in [-100, tokenizer.eos_token_id, tokenizer.pad_token_id, endofchunk_token_id]:
                    for j in range(labels.shape[1] - 1, -1, -1):
                        if labels[i][j] not in [-100, tokenizer.eos_token_id, endofchunk_token_id]:
                            labels[i][j] = -100
                        else:
                            break

            labels = labels.to(device_id)

            with autocast():
                output = model(
                    audio_x=audio_clips,
                    audio_x_mask=audio_embed_mask,
                    lang_x=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
            valid_loss_no_multiplier = output.loss.item()
            loss_sum += valid_loss_no_multiplier

        # Sum of per-batch mean losses, normalized by the total number of
        # examples seen.
        return loss_sum / ((idx + 1) * batch_size)
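
    # Special-token ids used by the masking logic above. get_val_loss captures
    # them via closure; they are assigned here, before its first call.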
    media_token_id = tokenizer("<audio>", add_special_tokens=False)["input_ids"][-1]
    assert media_token_id == tokenizer.encode("<audio>")[-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]

    valid_losses = {}
    all_valid_AudioTextDataInfo = get_audiotext_dataloader(
        data_config, clap_config, tokenizer, batch_size, split='val'
    )
    for valid_dataset_name in all_valid_AudioTextDataInfo:
        if verbose:
            print('computing validation loss on {}'.format(valid_dataset_name))
        validloader = all_valid_AudioTextDataInfo[valid_dataset_name].dataloader
        valid_losses[valid_dataset_name] = get_val_loss(validloader)
        if verbose:
            print('validation loss on {} is {:.3f}'.format(valid_dataset_name, valid_losses[valid_dataset_name]))

    model.train()
    return valid_losses


if __name__ == "__main__":
    from src.factory import create_model_and_transforms
    from train_utils import Dict2Class, get_autocast, get_cast_dtype

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../configs/config.yaml', help='yaml config path')
    parsed_args = parser.parse_args()

    config = yaml.load(open(parsed_args.config), Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config['model_config']
    clap_config = config['clap_config']
    args = Dict2Class(config['train_config'])
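    # Fields read from train_config below: offline, gradient_checkpointing,
    # freeze_lm_embeddings, precision, fsdp, and batch_size.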

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning

    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
    )
    device_id = 0  # single-GPU evaluation
    model = model.to(device_id)

    autocast = get_autocast(
        args.precision, cache_enabled=(not args.fsdp)
    )  # if fsdp, disable cache to save memory
    cast_dtype = get_cast_dtype(args.precision)

    valid_losses = validation_losses(
        model,
        data_config,
        clap_config,
        tokenizer,
        args.batch_size,
        autocast,
        cast_dtype,
        device_id,
        verbose=True,
    )
    print(valid_losses)