import time
import os
from tqdm import tqdm
import sys
from copy import deepcopy
from contextlib import suppress

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.api import FullOptimStateDictConfig
from einops import rearrange

class Dict2Class:
    # Wraps a plain dict so its entries can be accessed as attributes.
    def __init__(self, data_dict):
        for key, value in data_dict.items():
            setattr(self, key, value)

class SysLogger(object):
    def __init__(self, filename="../log/log.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        # print() already supplies the trailing newline, so forward the message unchanged
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # forward flush() so callers like tqdm don't break when stdout is replaced
        self.terminal.flush()
        self.log.flush()
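
# Illustrative sketch (not used elsewhere in this file): SysLogger is typically
# installed by replacing sys.stdout, so every print() is mirrored to a log file.
# The log path below is a placeholder.
def _example_attach_syslogger(log_path="../log/log.log"):
    sys.stdout = SysLogger(log_path)
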
def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_mp_policy_dtype(precision: str):
    if "bfloat16" in precision or "bf16" in precision:
        return torch.bfloat16
    elif precision == "fp16":
        return torch.float16
    else:
        return torch.float32
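
# Illustrative sketch (assumption: FSDP wrapping happens outside this file):
# get_mp_policy_dtype is meant to feed an FSDP MixedPrecision policy like the
# one below.
def _example_mp_policy(precision="amp_bf16"):
    from torch.distributed.fsdp import MixedPrecision

    dtype = get_mp_policy_dtype(precision)
    return MixedPrecision(
        param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype
    )
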
def get_autocast(precision, cache_enabled=True):
    # every branch returns a callable so the training loop can do `with autocast():`
    if precision == "amp":
        return lambda: torch.cuda.amp.autocast(cache_enabled=cache_enabled)
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        return lambda: torch.cuda.amp.autocast(
            dtype=torch.bfloat16, cache_enabled=cache_enabled
        )
    else:
        return suppress
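
# Illustrative sketch (not called in this file) of how the precision helpers
# above combine in train_one_epoch below: `autocast` wraps the forward pass,
# while `cast_dtype` (None unless precision is exactly "bf16"/"fp16") is used
# when moving input tensors to the GPU.
def _example_precision_setup(precision="amp_bf16", use_fsdp=False):
    autocast = get_autocast(precision, cache_enabled=(not use_fsdp))
    cast_dtype = get_cast_dtype(precision)
    with autocast():
        pass  # the model forward would run here under the chosen autocast policy
    return cast_dtype
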
def train_one_epoch(
    args,
    model,
    epoch,
    trainloader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    tb,
):
    # setup loaders
    num_batches_per_epoch = len(trainloader)
    total_training_steps = num_batches_per_epoch * args.num_epochs
    print('num_batches_per_epoch={}, total_training_steps={}'.format(num_batches_per_epoch, total_training_steps))

    autocast = get_autocast(
        args.precision, cache_enabled=(not args.fsdp)
    )  # if fsdp, disable cache to save memory
    cast_dtype = get_cast_dtype(args.precision)

    # setup model
    media_token_id = tokenizer("<audio>", add_special_tokens=False)["input_ids"][-1]
    assert media_token_id == tokenizer.encode("<audio>")[-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
    model.train()
    # setup logging
    step_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    # loop through dataloader
    for num_steps, batch in tqdm(
        enumerate(trainloader),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)
        global_step = num_steps + epoch * num_batches_per_epoch

        #### FORWARD PASS ####
        audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True)  # (B, N_WINDOWS, WINDOW_LENGTH)
        audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)  # (B, N_WINDOWS)
        input_ids = batch["input_ids"].to(device_id, dtype=torch.long, non_blocking=True)  # (B, N_TOKENS)
        attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)  # (B, N_TOKENS)

        # set up labels; language model is expected to handle shifting
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, :1] = -100
        labels[labels == tokenizer.encode("<audio>")[-1]] = -100

        # mask all prompts except for between <SEP> and <|endofchunk|>
        sep_locations = labels == tokenizer.sep_token_id
        eoc_locations = labels == endofchunk_token_id
        if not all(sep_locations.sum(dim=1) == eoc_locations.sum(dim=1)):
            print("Warning: <SEP>-<EoC> pairing mismatch at step {} due to max_token limit.".format(num_steps))
        for i in range(labels.shape[0]):
            shouldmask = True
            for j in range(labels.shape[1]):
                if shouldmask and (labels[i][j] != tokenizer.eos_token_id):
                    masked_value = -100
                else:
                    masked_value = labels[i][j]
                if labels[i][j] == tokenizer.sep_token_id:
                    shouldmask = False
                elif labels[i][j] == endofchunk_token_id:
                    shouldmask = True
                labels[i][j] = masked_value

            # if the answer was cut off before <|endofchunk|>, mask the dangling partial answer
            if labels[i][-1] not in [-100, tokenizer.eos_token_id, tokenizer.pad_token_id, endofchunk_token_id]:
                for j in range(labels.shape[1] - 1, -1, -1):
                    if labels[i][j] not in [-100, tokenizer.eos_token_id, endofchunk_token_id]:
                        labels[i][j] = -100
                    else:
                        break

        labels = labels.to(device_id)
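
        # Worked example of the masking above (tokens shown symbolically; the
        # exact prompt format is an assumption for illustration):
        #   tokens: <audio> Describe the sound . <SEP> A dog barks . <|endofchunk|> <pad>
        #   labels:  -100   -100  ...  -100      -100  A dog barks . <|endofchunk|> -100
        # Only the answer span after <SEP> up to and including <|endofchunk|>
        # contributes to the loss; <SEP> itself, the prompt, padding, and any
        # truncated trailing answer are masked out with -100.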
        # gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager
        with autocast():
            output = model(
                audio_x=audio_clips,
                audio_x_mask=audio_embed_mask,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = output.loss

        divided_loss = loss / args.gradient_accumulation_steps
        train_loss = divided_loss * args.loss_multiplier
        train_loss.backward()
        if (not args.freeze_lm_embeddings) and (
            not args.fsdp or args.fsdp_use_orig_params
        ):
            # Mask gradients for input embeddings s.t. we only update the added tokens <audio> and <|endofchunk|>
            if args.fsdp:
                embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
            else:
                embed_grad = (
                    model.module.lang_encoder.get_input_embeddings().weight.grad
                )
            zero_mask = torch.zeros_like(embed_grad)
            zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
            zero_mask[endofchunk_token_id] = torch.ones_like(
                zero_mask[endofchunk_token_id]
            )
            if args.fsdp:
                model.lang_encoder.get_input_embeddings().weight.grad = (
                    embed_grad * zero_mask
                )
            else:
                model.module.lang_encoder.get_input_embeddings().weight.grad = (
                    embed_grad * zero_mask
                )

        # clip gradient norm
        if args.fsdp:
            """
            The way we clip gradients with FSDP is different than the non-FSDP case,
            because during FSDP, gradient norms are computed over certain submodules,
            rather than the entire model.
            At least for OPT-125M, this didn't seem to make a difference in performance.
            """
            model.clip_grad_norm_(1.0)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
            num_steps == num_batches_per_epoch - 1
        ):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            # update step time and reset the timer on all ranks, not only rank 0
            step_time_m.update(time.time() - end)
            end = time.time()
            # rank 0 logging
            if args.rank == 0:
                samples_per_second = (
                    args.gradient_accumulation_steps
                    * args.batch_size
                    * args.world_size
                    / step_time_m.val
                )
                samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps
                    * args.batch_size
                    / step_time_m.val
                )
                log_dict = {
                    "data_time": data_time_m.avg,
                    "step_time": step_time_m.avg,
                    "samples_per_second": samples_per_second,
                    "samples_per_second_per_gpu": samples_per_second_per_gpu,
                    "lr": optimizer.param_groups[0]["lr"],
                    "loss": loss.item(),
                }
                if ((num_steps + 1) % args.logging_steps == 0):
                    for key in log_dict:
                        tb.add_scalar("Train/{}".format(key), log_dict[key], global_step)
                    step_time_m.reset()
                    data_time_m.reset()

                # Log loss to console
                if ((num_steps + 1) % args.logging_steps == 0):
                    print(
                        f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}\n"
                    )

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
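
# Illustrative sketch of how the timing meters above are read in the logging
# code: `val` is the most recent measurement, `avg` the running mean since the
# last reset().
def _example_average_meter():
    m = AverageMeter()
    m.update(0.50)
    m.update(0.30)
    return m.val, m.avg  # (0.30, 0.40)
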
def filter_state_dict_to_trainable(model, state_dict):
    """
    Remove non-trainable parameters from model state dict.
    Exception: Embeddings will not be removed, even if frozen.
    This is because we need the new <audio> <|endofchunk|> tokens to
    be consistent across initializations.
    """
    for (
        name,
        p,
    ) in model.named_parameters():  # won't work for fsdp + use_orig_params=False
        if "fsdp" in name:
            continue
        if "embed" in name or isinstance(p, torch.nn.Embedding):
            continue
        if not p.requires_grad:
            name = name.replace("._checkpoint_wrapped_module", "")
            if name in state_dict:
                del state_dict[name]
            else:
                print(f"WARNING: filtering but {name} not in state_dict")

    # also remove the keys in state_dict generated from
    # lang_encoder.old_decoder_blocks and lang_encoder.gated_cross_attn_layers
    # because these are already saved in lang_encoder.model...
    to_delete = [
        n
        for n in state_dict.keys()
        if ("lang_encoder.old_decoder_blocks" in n)
        or ("lang_encoder.gated_cross_attn_layers" in n)
        or ("vision_encoder" in n)
    ]
    for name in to_delete:
        del state_dict[name]

    return state_dict
def save_checkpoint(model, optimizer, lr_scheduler, epoch, args):
    """
    Save training checkpoint with model, optimizer, and lr_scheduler state.
    """
    if args.fsdp:
        FSDP.set_state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
            FullOptimStateDictConfig(rank0_only=True),
        )
        model_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer, group=args.my_group)
    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    if args.rank == 0:
        if not (args.fsdp and not args.fsdp_use_orig_params):
            model_state = filter_state_dict_to_trainable(model, model_state)

        checkpoint_dir = os.path.join(args.expdir, args.run_name)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        checkpoint_dict = {
            "epoch": epoch,
            "model_state_dict": model_state,
            "optimizer_state_dict": optim_state,
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        }

        print(f"Saving checkpoint to {checkpoint_dir}/checkpoint_{epoch}.pt")
        torch.save(checkpoint_dict, f"{checkpoint_dir}/checkpoint_{epoch}.pt")

        if args.delete_previous_checkpoint:
            # keep checkpoints from epochs 0, 5, 10, ...; delete the other previous ones
            if epoch > 0 and (epoch - 1) % 5 != 0:
                os.remove(f"{checkpoint_dir}/checkpoint_{epoch-1}.pt")