# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import torch
import torch.distributed as dist
from fairseq.dataclass.configs import FairseqBMUFConfig
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class FairseqBMUF(FairseqOptimizer):
"""
Implements incremental block distributed data parallelism similar to
https://ieeexplore.ieee.org/document/7472805
Paper title: Scalable training of deep learning machines by incremental
block training with intra-block parallel optimization and blockwise
model-update filtering
"""
def __init__(self, cfg: FairseqBMUFConfig, optimizer):
super().__init__(cfg)
self._optimizer = optimizer
self._num_updates = 0
self.sync_iter = cfg.global_sync_iter
self.block_momentum = cfg.block_momentum
self.block_lr = cfg.block_lr
self._reset_local_data()
self.warmup_iteration = cfg.warmup_iterations
self.use_nbm = cfg.use_nbm
self.initial_state = self._optimizer.state_dict()
self.average_sync = self.cfg.average_sync
self.world_size = self.cfg.distributed_world_size
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
gen_parser_from_dataclass(parser, FairseqBMUFConfig())
@property
def optimizer(self):
return self._optimizer.optimizer
@property
def optimizer_config(self):
return self._optimizer.optimizer_config
def get_lr(self):
return self._optimizer.get_lr()
def set_lr(self, lr):
self._optimizer.set_lr(lr)
def state_dict(self):
return self._optimizer.state_dict()
def load_state_dict(self, state_dict, optimizer_overrides=None):
self._optimizer.load_state_dict(state_dict, optimizer_overrides)
self.initial_state = self._optimizer.state_dict()
def multiply_grads(self, c):
"""Multiplies grads by a constant *c*."""
self._optimizer.multiply_grads(c)
def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
"""Clips gradient norm."""
return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn)
def average_params(self):
self._optimizer.average_params()
def _block_sync(self):
if self.world_size <= 1:
return
# Update the global model using local models from all GPUs
        # (Step-1) Calculate the grad between the previously synced model and
        # the current local model
if self.block_momentum != 0:
self._calc_grad()
# (Step-2) Average gradient from all GPUs
self._avg_grad_from_all_gpus()
# (Step-3) Calculate global momentum and update the global model
if self.block_momentum != 0:
self._update_global_model()
# (Step-4) Average local optimizer params
if self.average_sync:
self.average_params()
def _is_warmup_end(self):
        # Check whether the number of training updates has reached the warmup iteration count
if self.get_num_updates() == self.warmup_iteration:
return True
return False
def _is_bmuf_iter(self):
        # Check whether the current update count falls on a BMUF sync iteration (after warmup)
if (self.get_num_updates() > self.warmup_iteration) and (
self.get_num_updates() % self.sync_iter == 0
):
return True
return False
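    # Example schedule (illustrative numbers, not necessarily the configured values):
    # with warmup_iterations=500 and global_sync_iter=50, _warmup_sync() runs once
    # at update 500 and _block_sync() runs at updates 550, 600, 650, ...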
def _warmup_sync(self, root_rank=0):
if self.world_size <= 1:
return
        # Broadcast the model from the root rank to all GPUs
for param in self.params:
dist.broadcast(param.data, src=root_rank)
# Update local optimizer state
if self.average_sync:
self._optimizer.average_params()
else:
self._optimizer.load_state_dict(self.initial_state)
self._reset_local_data()
def step(self, closure=None):
"""Performs a single optimization step."""
self._optimizer.step(closure)
self.set_num_updates(self.get_num_updates() + 1)
if self._is_warmup_end():
self._warmup_sync()
elif self._is_bmuf_iter():
self._block_sync()
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self._optimizer.zero_grad()
def get_num_updates(self):
"""Get the number of parameters updates."""
return self._num_updates
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
self._num_updates = num_updates
@torch.no_grad()
def _reset_local_data(self):
        # (Step-0) Initialize global momentum parameters and store a copy of the global model on each GPU
self.global_params = [torch.zeros_like(p.data) for p in self.params]
self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]
        # Save the global model locally for computing the gradient during BMUF sync
for param, global_param in zip(self.params, self.global_params):
global_param.copy_(param.data)
@torch.no_grad()
def _calc_grad(self):
        # global_params is the global copy saved at the previously finished
        # synchronization, while param.data is the local parameter after
        # sync_iter local updates on this GPU, so grad is the difference
        # between the previously synced model and the current local model.
for index, (param, global_param) in enumerate(
zip(self.params, self.global_params)
):
self.grads[index] = global_param - param.data
def _avg_grad_from_all_gpus(self):
for index, param in enumerate(self.params):
sync_para = param.data if self.block_momentum == 0 else self.grads[index]
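            # Divide before the SUM all-reduce so the result is the average across workers.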
sync_para /= float(dist.get_world_size())
dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)
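    # Summary of the block update implemented below, in this file's variable names:
    #   grad(t)          = global_param(t-1) - (average of the local params across workers)
    #   smoothed_grad(t) = block_momentum * smoothed_grad(t-1) + block_lr * grad(t)
    #   param            = global_param(t-1) - smoothed_grad(t)
    # With use_nbm (Nesterov-style block momentum), param is shifted by a further
    # -block_momentum * smoothed_grad(t) before local training resumes; the final
    # param is copied back into global_param for the next synchronization.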
@torch.no_grad()
def _update_global_model(self):
for index, (param, global_param, smoothed_grad, grad) in enumerate(
zip(
self.params,
self.global_params,
self.smoothed_grads,
                # All GPUs share the same value of smoothed_grad, since it is
                # always computed from synchronized gradients.
self.grads,
)
):
            # global_param is the last synchronized parameter. Although
            # smoothed_grad is stored locally, every process holds the same
            # value for it, so param ends up as a globally synchronized copy.
# smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
param.data.copy_(global_param - smoothed_grad)
            # Nesterov-style momentum: apply a partial weight update here,
            # before the next block's gradients are calculated
if self.use_nbm:
param.data.copy_(param.data - self.block_momentum * smoothed_grad)
# backup for the next synchronization.
self.smoothed_grads[index] = smoothed_grad
global_param.copy_(param.data)