File size: 5,899 Bytes
910e2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import math
import sys
from typing import Iterable

import torch
import torch.nn as nn
import accelerate
from .utils import MetricLogger, SmoothedValue


def update_ema_for_dit(model, model_ema, accelerator, decay):
    """Apply an exponential moving average (EMA) update to ``model_ema``.

    Every entry of ``model_ema.state_dict()`` whose key is also present in
    ``accelerator.get_state_dict(model)`` is updated in place as follows:

        w_ema = w_ema * decay + (1 - decay) * w

    Entries that are neither floating point nor complex (for example integer
    counters or boolean masks) are copied from the live model instead of
    being blended: the blend is computed in floating point and would be
    truncated when written back, so with a decay close to 1 such entries
    would never follow the live model.

    Args:
        model: active model that is being optimized
        model_ema: running average model; its state-dict tensors are
            modified in place
        accelerator: object whose ``get_state_dict(model)`` returns the
            state dict of the active model
        decay: exponential decay parameter
    """
    with torch.no_grad():
        msd = accelerator.get_state_dict(model)
        for k, ema_v in model_ema.state_dict().items():
            if k not in msd:
                # Keys missing from the live state dict are left untouched.
                continue

            # Bring the live value onto the EMA tensor's device and dtype.
            model_v = msd[k].detach().to(ema_v.device, dtype=ema_v.dtype)

            if ema_v.is_floating_point() or ema_v.is_complex():
                ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)
            else:
                # Averaging is not meaningful for integer/bool entries.
                ema_v.copy_(model_v)


def get_decay(optimization_step: int, ema_decay: float) -> float:
    """Compute the decay factor for the exponential moving average.

    The factor is 0.0 while ``optimization_step`` is at most 1. After that
    it follows ``(1 + s) / (10 + s)`` with ``s = optimization_step - 1``,
    capped from above by ``ema_decay`` and from below by 0.0.
    """
    elapsed = max(0, optimization_step - 1)
    if elapsed <= 0:
        return 0.0

    warmup_decay = (1 + elapsed) / (10 + elapsed)
    return max(min(warmup_decay, ema_decay), 0.0)


def train_one_epoch_with_fsdp(
    runner,
    model_ema: torch.nn.Module,
    accelerator: accelerate.Accelerator,
    model_dtype: str,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    lr_schedule_values,
    device: torch.device,
    epoch: int,
    clip_grad: float = 1.0,
    start_steps=None,
    args=None,
    print_freq=20,
    iters_per_epoch=2000,
    ema_decay=0.9999,
    use_temporal_pyramid=True,
):
    """Run ``iters_per_epoch`` training iterations of ``runner.dit``.

    Each iteration optionally sets the learning rate of every param group,
    then runs ``args.gradient_accumulation_steps`` micro-steps inside
    ``accelerator.accumulate``. Whenever ``accelerator.sync_gradients`` is
    set after a micro-step, the EMA model is refreshed (only when
    ``start_steps`` is a multiple of 100), ``start_steps`` is incremented
    and the loss, learning rates, weight decay and gradient norm are logged.

    Args:
        runner: callable returning ``(loss, log_loss)`` for a batch; it
            exposes the trained module as ``runner.dit``.
        model_ema: running-average copy of ``runner.dit``; ``None`` turns
            the EMA update off.
        accelerator: used for accumulation, gathering, backward, gradient
            clipping and logging.
        model_dtype: only used in the start-of-epoch message.
        data_loader: consumed with ``next()``; every item is indexed with
            the keys ``'video'``, ``'text'`` and ``'identifier'``.
        optimizer: optimizer whose ``param_groups`` are read and updated.
        lr_schedule_values: indexable by ``start_steps``; ``None`` leaves
            the learning rates untouched.
        device: not referenced in this function.
        epoch: epoch index, used in log messages only.
        clip_grad: max norm passed to ``accelerator.clip_grad_norm_``.
        start_steps: running step counter. It is used in arithmetic, so the
            ``None`` default cannot actually be used.
        args: must provide ``gradient_accumulation_steps`` and
            ``batch_size``; the ``None`` default cannot actually be used.
        print_freq: forwarded to ``MetricLogger.log_every``.
        iters_per_epoch: number of outer iterations.
        ema_decay: decay passed to ``update_ema_for_dit``.
        use_temporal_pyramid: forwarded to ``runner``.

    Returns:
        dict mapping each meter name to its ``global_avg``.
    """
    runner.dit.train()

    metric_logger = MetricLogger(delimiter="  ")
    for lr_meter in ('lr', 'min_lr'):
        metric_logger.add_meter(lr_meter, SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    # Loss summed over the micro-steps of the current accumulation window.
    train_loss = 0.0

    print("Start training epoch {}, {} iters per inner epoch. Training dtype {}".format(epoch, iters_per_epoch, model_dtype))

    for step in metric_logger.log_every(range(iters_per_epoch), print_freq, header):
        # NOTE(review): looks unreachable if log_every yields the items of
        # the range unchanged — confirm against MetricLogger.
        if step >= iters_per_epoch:
            break

        if lr_schedule_values is not None:
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_schedule_values[start_steps] * param_group.get("lr_scale", 1.0)

        for _micro_step in range(args.gradient_accumulation_steps):

            with accelerator.accumulate(runner.dit):
                # Fetch the next batch; only the video is moved to the device here.
                batch = next(data_loader)
                video = batch['video'].to(accelerator.device)
                text = batch['text']
                identifier = batch['identifier']

                # Forward pass; the second returned value is not used.
                loss, _log_loss = runner(
                    video,
                    text,
                    identifier,
                    use_temporal_pyramid=use_temporal_pyramid,
                    accelerator=accelerator,
                )

                # Abort the whole process on a NaN/inf loss.
                # NOTE(review): the built-in print() has no `force` keyword;
                # this presumably relies on a patched print installed
                # elsewhere — confirm.
                loss_value = loss.item()
                if not math.isfinite(loss_value):
                    print("Loss is {}, stopping training".format(loss_value), force=True)
                    sys.exit(1)

                # Average the loss over everything accelerator.gather returns.
                gathered = accelerator.gather(loss.repeat(args.batch_size))
                avg_loss = gathered.mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                accelerator.backward(loss)

                # Clip only on micro-steps flagged by sync_gradients.
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(runner.dit.parameters(), clip_grad)

                # Treat an accumulated loss >= 2.0 as an abnormal sample.
                # NOTE(review): this check runs on every micro-step, so it
                # can fire in the middle of an accumulation window — verify
                # that this is intended.
                if train_loss >= 2.0:
                    print(f"The ERROR data sample, finding extreme high loss {train_loss}, skip updating the parameters", force=True)
                    # Drop the gradients instead of stepping.
                    optimizer.zero_grad()
                    train_loss = 0.001    # placeholder value for logging
                else:
                    optimizer.step()
                    optimizer.zero_grad()

            if not accelerator.sync_gradients:
                continue

            # The EMA refresh is checked before the counter is incremented.
            if model_ema is not None and start_steps % 100 == 0:
                update_ema_for_dit(runner.dit, model_ema, accelerator, decay=ema_decay)

            start_steps += 1

            # Report the accumulated loss to the tracker and the meters.
            accelerator.log({"train_loss": train_loss}, step=start_steps)
            metric_logger.update(loss=train_loss)
            train_loss = 0.0

            # Smallest / largest learning rate, seeded with 10. and 0.
            group_lrs = [group["lr"] for group in optimizer.param_groups]
            min_lr = min([10.] + group_lrs)
            max_lr = max([0.] + group_lrs)
            metric_logger.update(lr=max_lr)
            metric_logger.update(min_lr=min_lr)

            # Report the weight decay of the last group that has a positive one.
            weight_decay_value = None
            for group in optimizer.param_groups:
                group_wd = group["weight_decay"]
                if group_wd > 0:
                    weight_decay_value = group_wd
            metric_logger.update(weight_decay=weight_decay_value)
            metric_logger.update(grad_norm=grad_norm)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}