import math
import sys
from typing import Iterable
import torch
import torch.nn as nn
import accelerate
from .utils import MetricLogger, SmoothedValue


def update_ema_for_dit(model, model_ema, accelerator, decay):
    """Apply exponential moving average update.

    The weights are updated in-place as follows:
        w_ema = w_ema * decay + (1 - decay) * w

    Args:
        model: active model that is being optimized
        model_ema: running average model
        accelerator: the accelerate.Accelerator used to gather the full state dict
        decay: exponential decay parameter
    """
    with torch.no_grad():
        msd = accelerator.get_state_dict(model)
        for k, ema_v in model_ema.state_dict().items():
            if k in msd:
                model_v = msd[k].detach().to(ema_v.device, dtype=ema_v.dtype)
                ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)
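
# Note: with the default decay of 0.9999, the EMA above averages over roughly
# 1 / (1 - decay) = 10,000 recent updates. get_decay below ramps the decay up
# from zero so that an EMA started early in training is not dominated by the
# randomly initialized weights (the training loop below currently uses a fixed
# ema_decay instead).
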
def get_decay(optimization_step: int, ema_decay: float) -> float:
    """
    Compute the decay factor for the exponential moving average.
    """
    step = max(0, optimization_step - 1)
    if step <= 0:
        return 0.0
    cur_decay_value = (1 + step) / (10 + step)
    cur_decay_value = min(cur_decay_value, ema_decay)
    cur_decay_value = max(cur_decay_value, 0.0)
    return cur_decay_value
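
# A quick sanity check of the schedule above (illustrative values):
#   get_decay(1, 0.9999)   -> 0.0              (no EMA contribution on the first step)
#   get_decay(100, 0.9999) -> 100 / 109 ~ 0.917
#   get_decay(step, 0.9999) saturates at ema_decay = 0.9999 once step >= ~90,000.

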
def train_one_epoch_with_fsdp(
    runner,
    model_ema: torch.nn.Module,
    accelerator: accelerate.Accelerator,
    model_dtype: str,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    lr_schedule_values,
    device: torch.device,
    epoch: int,
    clip_grad: float = 1.0,
    start_steps=None,
    args=None,
    print_freq=20,
    iters_per_epoch=2000,
    ema_decay=0.9999,
    use_temporal_pyramid=True,
):
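    """Run one epoch of DiT training under the accelerate/FSDP setup.

    For each of `iters_per_epoch` steps, the learning rate is taken from
    `lr_schedule_values`, gradients are accumulated over
    `args.gradient_accumulation_steps` micro-batches and clipped, the
    optimizer update is skipped when an abnormally high loss is detected,
    and the EMA copy of the DiT weights is refreshed every 100 steps.
    Returns a dict of globally averaged meter values.
    """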
    runner.dit.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    train_loss = 0.0

    print("Start training epoch {}, {} iters per inner epoch. Training dtype {}".format(
        epoch, iters_per_epoch, model_dtype))

    for step in metric_logger.log_every(range(iters_per_epoch), print_freq, header):
        if step >= iters_per_epoch:
            break

        if lr_schedule_values is not None:
            for i, param_group in enumerate(optimizer.param_groups):
                param_group["lr"] = lr_schedule_values[start_steps] * param_group.get("lr_scale", 1.0)

        for _ in range(args.gradient_accumulation_steps):
            with accelerator.accumulate(runner.dit):
                # Fetch the data sample and move the inputs to the device
                samples = next(data_loader)
                video = samples['video'].to(accelerator.device)
                text = samples['text']
                identifier = samples['identifier']

                # Perform the forward pass using accelerate
                loss, log_loss = runner(video, text, identifier,
                                        use_temporal_pyramid=use_temporal_pyramid,
                                        accelerator=accelerator)

                # Check that the loss is finite (not NaN/Inf)
                loss_value = loss.item()
                if not math.isfinite(loss_value):
                    print("Loss is {}, stopping training".format(loss_value), force=True)
                    sys.exit(1)

                avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                accelerator.backward(loss)

                # Clip the gradient
                if accelerator.sync_gradients:
                    params_to_clip = runner.dit.parameters()
                    grad_norm = accelerator.clip_grad_norm_(params_to_clip, clip_grad)

                # Deal with abnormal data points: skip the update on extremely high loss
                if train_loss >= 2.0:
                    print(f"Abnormal data sample: extremely high loss {train_loss}, skipping the parameter update", force=True)
                    # Zero out the gradient, do not update
                    optimizer.zero_grad()
                    train_loss = 0.001   # placeholder loss value for logging
                else:
                    optimizer.step()
                    optimizer.zero_grad()

        if accelerator.sync_gradients:
            # Update the EMA weights every 100 steps
            if model_ema is not None and start_steps % 100 == 0:
                # cur_ema_decay = get_decay(start_steps, ema_decay)
                cur_ema_decay = ema_decay
                update_ema_for_dit(runner.dit, model_ema, accelerator, decay=cur_ema_decay)

            start_steps += 1

            # Report to tensorboard
            accelerator.log({"train_loss": train_loss}, step=start_steps)

            metric_logger.update(loss=train_loss)
            train_loss = 0.0

            min_lr = 10.
            max_lr = 0.
            for group in optimizer.param_groups:
                min_lr = min(min_lr, group["lr"])
                max_lr = max(max_lr, group["lr"])

            metric_logger.update(lr=max_lr)
            metric_logger.update(min_lr=min_lr)

            weight_decay_value = None
            for group in optimizer.param_groups:
                if group["weight_decay"] > 0:
                    weight_decay_value = group["weight_decay"]
            metric_logger.update(weight_decay=weight_decay_value)
            metric_logger.update(grad_norm=grad_norm)

    # Gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
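

if __name__ == "__main__":
    # Minimal smoke test of update_ema_for_dit (illustrative only; the real
    # training entry point constructs the runner, data loader, and optimizer
    # elsewhere in the repository). The toy Linear models below are stand-ins
    # for the DiT and its EMA copy.
    accelerator = accelerate.Accelerator()
    model = nn.Linear(4, 4)
    model_ema = nn.Linear(4, 4)
    model_ema.load_state_dict(model.state_dict())

    with torch.no_grad():
        for p in model.parameters():
            p.add_(1.0)  # perturb the live model so the EMA update is visible

    update_ema_for_dit(model, model_ema, accelerator, decay=0.9)
    # Each EMA weight should now sit 10% of the way toward the perturbed value,
    # i.e. 0.9 away from the live model.
    print("max |ema - model| after one update:",
          (model_ema.weight - model.weight).abs().max().item())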