Spaces:
Running
Running
| from collections import OrderedDict | |
| import torch | |
def update_ema(
    ema_model: torch.nn.Module,
    model: torch.nn.Module,
    optimizer=None,
    decay: float = 0.9999,
    sharded: bool = True,
) -> None:
    """
    Step the EMA model towards the current model, in place.

    Every trainable parameter of ``model`` is visited, except one named exactly
    ``"pos_embed"``. For each, the same-named parameter of ``ema_model`` is
    updated as::

        ema = decay * ema + (1 - decay) * source

    ``source`` is the parameter itself. The exception is ``sharded=True`` with
    a non-float32 parameter; then ``source`` is the float32 master copy held
    by ``optimizer``.

    Args:
        ema_model: Module holding the EMA weights. It must have a parameter with
            the same name as every parameter of ``model`` that gets updated.
        model: Module whose current weights the EMA moves towards.
        optimizer: Only consulted when ``sharded`` is True and a parameter is
            not float32. It must expose ``_param_store.working_to_master_param``,
            a mapping from ``id(param)`` to the float32 master copy of that
            parameter. This is presumably a ColossalAI low-level ZeRO
            optimizer; confirm against the caller. NOTE(review): the master
            copy must be shape-compatible with the EMA parameter for ``add_``
            to succeed; verify for the sharding setup in use.
        decay: EMA decay factor. Values closer to 1 make the EMA follow
            ``model`` more slowly.
        sharded: If True, non-float32 parameters are read from the optimizer's
            master copy instead of the lower-precision working parameter.

    Raises:
        ValueError: If ``sharded`` is True, a trainable non-float32 parameter
            is encountered, and ``optimizer`` is None.
        KeyError: If ``ema_model`` has no parameter with the name of an
            updated parameter of ``model``.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    # The EMA step is bookkeeping, not part of the autograd graph. Without
    # no_grad, the in-place ops below raise if the EMA params require grad.
    with torch.no_grad():
        for name, param in model.named_parameters():
            # "pos_embed" is deliberately excluded from EMA tracking (reason not
            # visible here; presumably a fixed embedding — confirm). Frozen
            # parameters are skipped as well.
            if name == "pos_embed" or not param.requires_grad:
                continue
            if sharded and param.dtype != torch.float32:
                # Low-precision working weights: average the float32 master copy
                # kept by the optimizer instead, to avoid precision loss.
                if optimizer is None:
                    raise ValueError(
                        f"update_ema: parameter {name!r} has dtype {param.dtype} and "
                        "sharded=True, but no optimizer was given to look up its "
                        "float32 master copy"
                    )
                source = optimizer._param_store.working_to_master_param[id(param)]
            else:
                source = param
            ema_params[name].mul_(decay).add_(source.data, alpha=1 - decay)