File size: 1,991 Bytes
c865888
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from torch import nn
import torch
from deepspeed.runtime.zero import GatheredParameters
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def _z3_params_to_fetch(param_list):
    """Select the parameters that must be gathered before they can be read.

    A parameter qualifies when it carries a DeepSpeed ``ds_id`` attribute
    and its ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``. Objects
    without ``ds_id`` are skipped without inspecting ``ds_status``.

    Args:
        param_list: Iterable of parameters to inspect.

    Returns:
        A new list with the qualifying parameters, in input order.
    """
    pending = []
    for candidate in param_list:
        if not hasattr(candidate, 'ds_id'):
            continue
        if candidate.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            pending.append(candidate)
    return pending


def moving_average(model, model_ema, beta=0.9999, device=None, zero_stage=3):
    """Blend ``model``'s parameters into ``model_ema`` in place (EMA update).

    Each parameter ``p`` of ``model`` is paired positionally with a parameter
    ``e`` of ``model_ema``. The pair is updated as
    ``e <- beta * e + (1 - beta) * p``, computed via ``torch.lerp(p, e, beta)``.

    Args:
        model: Module supplying the current parameter values.
        model_ema: Module holding the running average; updated in place.
        beta: Fraction of the existing EMA value that is kept. Values closer
            to 1 give a slower-moving average.
        device: If not None, each source parameter is moved to this device
            before blending. Presumably this is the device ``model_ema``
            lives on — confirm with callers.
        zero_stage: DeepSpeed ZeRO stage. Only when it equals 3 are
            not-yet-available partitioned parameters gathered before use.

    Note:
        Parameters are paired with ``zip``. If the two models expose
        different numbers of parameters, the surplus ones are silently
        skipped.
    """
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for param, param_ema in zip(model.parameters(),
                                    model_ema.parameters()):
            # TODO: use prefiltering for efficiency
            params_to_fetch = _z3_params_to_fetch(
                [param, param_ema]) if zero_stage_3 else []
            should_gather_param = len(params_to_fetch) > 0
            with deepspeed.zero.GatheredParameters(
                    params_to_fetch, enabled=should_gather_param):
                data = param.data
                if device is not None:
                    data = data.to(device)
                # torch.lerp requires start/end to share a dtype. Cast the
                # source to the EMA's dtype so e.g. an fp16/bf16 model can
                # feed an fp32 EMA; this is a no-op when dtypes already match.
                data = data.to(dtype=param_ema.dtype)
                param_ema.data.copy_(torch.lerp(data, param_ema.data, beta))


def clone_zero_model(src_model, dst_model, zero_stage=0):
    """Overwrite ``dst_model``'s parameters with those of ``src_model``.

    Parameters are paired positionally via ``zip`` over ``.parameters()``,
    so any surplus parameters on either side are ignored. Copies happen
    under ``torch.no_grad()`` and use ``Tensor.copy_`` on ``.data``.

    Args:
        src_model: Module whose parameter values are read.
        dst_model: Module whose parameters are overwritten in place.
        zero_stage: DeepSpeed ZeRO stage. Only when it equals 3 are
            not-yet-available partitioned parameters gathered before copying.
    """
    gather_needed = zero_stage == 3
    with torch.no_grad():
        paired = zip(src_model.parameters(), dst_model.parameters())
        for source, target in paired:
            # TODO: use prefiltering for efficiency
            if gather_needed:
                pending = _z3_params_to_fetch([source, target])
            else:
                pending = []
            with deepspeed.zero.GatheredParameters(pending,
                                                   enabled=bool(pending)):
                target.data.copy_(source.data)