# audio-flamingo-2 / train / train_utils.py
import time
import os
from tqdm import tqdm
import sys
from copy import deepcopy
from contextlib import suppress
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.api import FullOptimStateDictConfig
from einops import rearrange
class Dict2Class:
    """Expose the keys of a plain dict as object attributes (argparse-style namespace)."""
    def __init__(self, data_dict):
        for key, value in data_dict.items():
            setattr(self, key, value)
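
# Usage sketch (illustrative only, not called anywhere in this file): wrap a config
# dict so its entries can be read as attributes, like an argparse namespace.
# The field names and values below are hypothetical.
def _example_dict2class():
    cfg = Dict2Class({"batch_size": 8, "precision": "amp_bf16"})
    return cfg.batch_size, cfg.precision  # (8, "amp_bf16")
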
class SysLogger(object):
    """Write a message to the live terminal (with a trailing newline) and append it to a log file."""
    def __init__(self, filename="../log/log.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message + '\n')
        self.log.write(message)
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == "bf16":
cast_dtype = torch.bfloat16
elif precision == "fp16":
cast_dtype = torch.float16
return cast_dtype
def get_mp_policy_dtype(precision: str):
if "bfloat16" in precision or "bf16" in precision:
return torch.bfloat16
elif precision == "fp16":
return torch.float16
else:
return torch.float32
def get_autocast(precision, cache_enabled=True):
    # every branch returns a callable, so callers can uniformly write `with autocast():`
    if precision == "amp":
        return lambda: torch.cuda.amp.autocast(cache_enabled=cache_enabled)
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        return lambda: torch.cuda.amp.autocast(
            dtype=torch.bfloat16, cache_enabled=cache_enabled
        )
    else:
        return suppress
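
# Illustrative sketch of how the precision helpers above fit together; the
# precision strings and the tensor are hypothetical. "amp_bf16" wraps the
# forward pass in a bfloat16 autocast context, while "bf16" is used to cast
# inputs directly (see train_one_epoch below).
def _example_precision_helpers():
    autocast = get_autocast("amp_bf16")   # callable returning a bf16 autocast context
    cast_dtype = get_cast_dtype("bf16")   # torch.bfloat16
    x = torch.randn(2, 4).to(dtype=cast_dtype)
    with autocast():
        y = x * 2.0
    return y.dtype  # torch.bfloat16
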
def train_one_epoch(
args,
model,
epoch,
trainloader,
tokenizer,
optimizer,
lr_scheduler,
device_id,
tb
):
# setup loaders
num_batches_per_epoch = len(trainloader)
total_training_steps = num_batches_per_epoch * args.num_epochs
print('num_batches_per_epoch={}, total_training_steps={}'.format(num_batches_per_epoch, total_training_steps))
autocast = get_autocast(
args.precision, cache_enabled=(not args.fsdp)
) # if fsdp, disable cache to save memory
cast_dtype = get_cast_dtype(args.precision)
# setup model
media_token_id = tokenizer("<audio>", add_special_tokens=False)["input_ids"][-1]
assert media_token_id == tokenizer.encode("<audio>")[-1]
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
model.train()
# setup logging
step_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
# loop through dataloader
for num_steps, batch in tqdm(
enumerate(trainloader),
disable=args.rank != 0,
total=total_training_steps,
initial=(epoch * num_batches_per_epoch)
):
data_time_m.update(time.time() - end)
global_step = num_steps + epoch * num_batches_per_epoch
#### FORWARD PASS ####
audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True) # (B, N_WINDOWS, WINDOW_LENGTH)
audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True) # (B, N_WINDOWS)
input_ids = batch["input_ids"].to(device_id, dtype=torch.long, non_blocking=True) # (B, N_TOKENS)
attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True) # (B, N_TOKENS)
# set up labels; language model is expected to handle shifting
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
labels[:, :1] = -100
labels[labels == tokenizer.encode("<audio>")[-1]] = -100
        # mask out prompt tokens so that only tokens between <SEP> and <|endofchunk|> (the answers) contribute to the loss
sep_locations = labels == tokenizer.sep_token_id
eoc_locations = labels == endofchunk_token_id
if not all(sep_locations.sum(dim=1) == eoc_locations.sum(dim=1)):
print("Warning: <SEP>-<EoC> pairing mismatch at step {} due to max_token limit.".format(num_steps))
for i in range(labels.shape[0]):
shouldmask = True
for j in range(labels.shape[1]):
if shouldmask and (labels[i][j] != tokenizer.eos_token_id):
masked_value = -100
else:
masked_value = labels[i][j]
if labels[i][j] == tokenizer.sep_token_id:
shouldmask = False
elif labels[i][j] == endofchunk_token_id:
shouldmask = True
labels[i][j] = masked_value
            # if the sequence was truncated mid-answer, also mask the dangling answer tokens at the end
            if labels[i][-1] not in [-100, tokenizer.eos_token_id, tokenizer.pad_token_id, endofchunk_token_id]:
for j in range(labels.shape[1]-1, -1, -1):
if labels[i][j] not in [-100, tokenizer.eos_token_id, endofchunk_token_id]:
labels[i][j] = -100
else:
break
labels = labels.to(device_id)
# gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager
with autocast():
output = model(
audio_x=audio_clips,
audio_x_mask=audio_embed_mask,
lang_x=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = output.loss
divided_loss = loss / args.gradient_accumulation_steps
train_loss = divided_loss * args.loss_multiplier
train_loss.backward()
if (not args.freeze_lm_embeddings) and (
not args.fsdp or args.fsdp_use_orig_params
):
# Mask gradients for input embeddings s.t. we only update the added tokens <audio> and <|endofchunk|>
if args.fsdp:
embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
else:
embed_grad = (
model.module.lang_encoder.get_input_embeddings().weight.grad
)
zero_mask = torch.zeros_like(embed_grad)
zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
zero_mask[endofchunk_token_id] = torch.ones_like(
zero_mask[endofchunk_token_id]
)
if args.fsdp:
model.lang_encoder.get_input_embeddings().weight.grad = (
embed_grad * zero_mask
)
else:
model.module.lang_encoder.get_input_embeddings().weight.grad = (
embed_grad * zero_mask
)
# clip gradient norm
if args.fsdp:
"""
The way we clip gradients with FSDP is different than the non-FSDP case,
because during FSDP, gradient norms are computed over certain submodules,
rather than the entire model.
At least for OPT-125M, this didn't seem to make a difference in performance.
"""
model.clip_grad_norm_(1.0)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# step optimizer and log
if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
num_steps == num_batches_per_epoch - 1
):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
        # update step timing and reset the timer on every rank, not just rank 0
step_time_m.update(time.time() - end)
end = time.time()
# rank 0 logging
if args.rank == 0:
samples_per_second = (
args.gradient_accumulation_steps
* args.batch_size
* args.world_size
/ step_time_m.val
)
samples_per_second_per_gpu = (
args.gradient_accumulation_steps
* args.batch_size
/ step_time_m.val
)
log_dict = {
"data_time": data_time_m.avg,
"step_time": step_time_m.avg,
"samples_per_second": samples_per_second,
"samples_per_second_per_gpu": samples_per_second_per_gpu,
"lr": optimizer.param_groups[0]["lr"],
"loss": loss.item()
}
if ((num_steps + 1) % args.logging_steps == 0):
for key in log_dict:
tb.add_scalar("Train/{}".format(key), log_dict[key], global_step)
step_time_m.reset()
data_time_m.reset()
# Log loss to console
if ((num_steps + 1) % args.logging_steps == 0):
print(
f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}\n"
)
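
# Illustrative sketch of the argument namespace that train_one_epoch reads from;
# the field names follow the attribute accesses above, and the values are
# hypothetical placeholders rather than recommended settings.
def _example_train_args():
    return Dict2Class({
        "num_epochs": 1,
        "precision": "amp_bf16",
        "fsdp": False,
        "fsdp_use_orig_params": False,
        "freeze_lm_embeddings": True,
        "gradient_accumulation_steps": 2,
        "loss_multiplier": 1.0,
        "batch_size": 4,
        "world_size": 1,
        "rank": 0,
        "logging_steps": 10,
    })
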
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
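
# Usage sketch: AverageMeter keeps the most recent value and a running average,
# as used above for step/data timing. The numbers here are illustrative.
def _example_average_meter():
    meter = AverageMeter()
    for v in (0.5, 1.5, 1.0):
        meter.update(v)
    return meter.val, meter.avg  # (1.0, 1.0)
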
def filter_state_dict_to_trainable(model, state_dict):
    """
    Remove non-trainable parameters from the model state dict.
    Exception: embeddings are never removed, even if frozen, because the
    newly added <audio> and <|endofchunk|> tokens must stay consistent
    across initializations.
    """
for (
name,
p,
) in model.named_parameters(): # won't work for fsdp + use_orig_params=False
if "fsdp" in name:
continue
if "embed" in name or isinstance(p, torch.nn.Embedding):
continue
if not p.requires_grad:
name = name.replace("._checkpoint_wrapped_module", "")
if name in state_dict:
del state_dict[name]
else:
print(f"WARNING: filtering but {name} not in state_dict")
# also remove the keys in state_dict generated from
# lang_encoder.old_decoder_blocks and lang_encoder.gated_cross_attn_layers
# because these are already saved in lang_encoder.model...
to_delete = [
n
for n in state_dict.keys()
if ("lang_encoder.old_decoder_blocks" in n)
or ("lang_encoder.gated_cross_attn_layers" in n)
or ("vision_encoder" in n)
]
for name in to_delete:
del state_dict[name]
return state_dict
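
# Minimal illustration with a hypothetical toy module (not used elsewhere in this
# file): frozen non-embedding weights are dropped from the state dict, while
# embedding weights are kept even though they are frozen.
def _example_filter_state_dict():
    toy = torch.nn.Module()
    toy.embed_tokens = torch.nn.Embedding(10, 4)  # kept: "embed" is in the name
    toy.proj = torch.nn.Linear(4, 4)              # frozen below -> filtered out
    for p in toy.parameters():
        p.requires_grad = False
    filtered = filter_state_dict_to_trainable(toy, toy.state_dict())
    return sorted(filtered.keys())  # ["embed_tokens.weight"]
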
def save_checkpoint(model, optimizer, lr_scheduler, epoch, args):
"""
Save training checkpoint with model, optimizer, and lr_scheduler state.
"""
if args.fsdp:
FSDP.set_state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
FullOptimStateDictConfig(rank0_only=True),
)
model_state = model.state_dict()
optim_state = FSDP.optim_state_dict(model, optimizer, group=args.my_group)
else:
model_state = model.state_dict()
optim_state = optimizer.state_dict()
if args.rank == 0:
if not (args.fsdp and not args.fsdp_use_orig_params):
model_state = filter_state_dict_to_trainable(model, model_state)
checkpoint_dir = os.path.join(args.expdir, args.run_name)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
checkpoint_dict = {
"epoch": epoch,
"model_state_dict": model_state,
"optimizer_state_dict": optim_state,
"lr_scheduler_state_dict": lr_scheduler.state_dict(),
}
print(f"Saving checkpoint to {checkpoint_dir}/checkpoint_{epoch}.pt")
torch.save(checkpoint_dict, f"{checkpoint_dir}/checkpoint_{epoch}.pt")
        # optionally delete the previous epoch's checkpoint, keeping every 5th epoch as a milestone
        if args.delete_previous_checkpoint:
            if epoch > 0 and (epoch-1) % 5 != 0:
                os.remove(f"{checkpoint_dir}/checkpoint_{epoch-1}.pt")
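
# Minimal resume sketch for the non-FSDP case, with hypothetical caller-supplied
# model/optimizer/lr_scheduler objects and checkpoint path. strict=False is needed
# because filter_state_dict_to_trainable drops frozen, non-embedding weights from
# the saved model state dict.
def _example_load_checkpoint(model, optimizer, lr_scheduler, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
    return checkpoint["epoch"] + 1  # epoch to resume training from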