import math
import os
import signal
import sys
import time
from typing import List, Optional, Tuple, Union

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# from memory_profiler import profile

import infinity.utils.dist as dist
from infinity.utils import misc


class NullCtx:
    """No-op context manager, used in place of torch.autocast when AMP is disabled."""
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


def handle_timeout(signum, frame):
    raise TimeoutError('took too long')
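
# A minimal usage sketch for handle_timeout (hypothetical timeout value and function; SIGALRM is
# POSIX-only, so this assumes a Linux training environment):
#
#   signal.signal(signal.SIGALRM, handle_timeout)
#   signal.alarm(60)                 # deliver SIGALRM (-> TimeoutError) if the next section takes > 60 s
#   try:
#       run_possibly_hanging_step()  # hypothetical function
#   finally:
#       signal.alarm(0)              # cancel the pending alarm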


def per_param_clip_grad_norm_(parameters, thresh: float, stable=False, fp=None) -> Tuple[float, float]:
    """Clip each parameter's gradient to norm `thresh` individually (rather than clipping by the global norm).

    Returns (skip_stat, max_grad): max_grad is the largest per-parameter gradient norm seen; skip_stat is 0
    if no gradient was zeroed in stable mode, else log10 of the smallest skipped clip coefficient.
    """
    skipped, max_grad = [], 0
    for pi, p in enumerate(parameters):
        if p.grad is not None:
            g = p.grad.data.norm(2).item() + 1e-7
            max_grad = max(max_grad, g)
            clip_coef = thresh / g
            if clip_coef < 1:
                if stable and clip_coef < 0.2:
                    # in "stable" mode, gradients that would be clipped by more than 5x are zeroed out instead
                    skipped.append(clip_coef)
                    p.grad.data.mul_(0)  # NOTE: inf.mul_(0) == nan would shrink the GradScaler's scale, but inf.zero_() == 0 would not
                else:
                    p.grad.data.mul_(clip_coef)
    # if fp is not None: fp.write('[per_param_clip_grad_norm_] finished.\n'); fp.flush()
    return 0 if len(skipped) == 0 else math.log10(max(min(skipped), 1e-7)), max_grad
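
# A minimal usage sketch for per_param_clip_grad_norm_ (hypothetical model and threshold). This is the
# per-parameter alternative to global-norm clipping; AmpOptimizer below currently keeps it disabled
# (self.per_param = False):
#
#   loss.backward()
#   log10_min_coef, max_grad = per_param_clip_grad_norm_(model.parameters(), thresh=2.0, stable=True)
#   optimizer.step()
#   optimizer.zero_grad(set_to_none=True)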


class AmpOptimizer:
    def __init__(
        self,
        model_name_3letters: str, mixed_precision: int,
        optimizer: torch.optim.Optimizer, model_maybe_fsdp: Union[torch.nn.Module, FSDP],
        r_accu: float, grad_clip: float, zero: int,
    ):
        self.enable_amp = mixed_precision > 0
        self.zero = zero
        if self.enable_amp:
            self.using_fp16_rather_bf16 = mixed_precision != 2
            self.max_sc = float(mixed_precision if mixed_precision > 128 else 32768)  # values > 128 are treated as a custom maximum loss scale
            # todo: on both V100 and A100, torch.get_autocast_gpu_dtype() returns fp16, not bf16.
            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0)  # todo: cache_enabled=False
            if self.using_fp16_rather_bf16:
                self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000)
            else:
                self.scaler = None  # bf16 has enough dynamic range, so no loss scaling is needed
        else:
            self.using_fp16_rather_bf16 = True
            self.amp_ctx = NullCtx()
            self.scaler = None

        # sanity check: every rank must make the same AMP choices (all on or all off, all fp16 or all bf16)
        t = torch.zeros(dist.get_world_size())
        t[dist.get_rank()] = float(self.enable_amp)
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'enable_amp: {t}'
        t = torch.zeros(dist.get_world_size())
        t[dist.get_rank()] = float(self.using_fp16_rather_bf16)
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'using_fp16_rather_bf16: {t}'

        self.model_name_3letters = model_name_3letters
        self.optimizer, self.model_maybe_fsdp = optimizer, model_maybe_fsdp
        self.r_accu = r_accu
        self.paras = self.names = ...  # todo: solve EMA-related codes
        self.grad_clip, self.grad_clip_we = grad_clip, 0  # todo: disable wclip
        if self.grad_clip > 100:
            self.grad_clip %= 100
            self.per_param = True
        else:
            self.per_param = False
        self.per_param = False  # todo: disable wclip
        self.early_clipping = grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
        self.late_clipping = grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')  # deepspeed's optimizer
        self.fp = None
        self.last_orig_norm: torch.Tensor = torch.tensor(0.1)
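
    # Constructor-argument encoding, as read from __init__ above (a summary inferred from the code, not original documentation):
    #   mixed_precision: 0 disables AMP; 2 selects bf16 (no GradScaler); any other positive value selects fp16,
    #                    and values > 128 additionally set the maximum allowed loss scale (otherwise 32768).
    #   zero:            0 means no ZeRO/FSDP sharding; nonzero means model_maybe_fsdp is FSDP-wrapped, which
    #                    changes autocast caching, gradient clipping, and GradScaler inf synchronization.
    #   r_accu:          1.0 / n_gradient_accumulation, used to rescale each micro-batch loss before backward.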

    def log_param(self, ep: int):
        if self.zero == 0:
            for name, values in get_param_for_log(self.model_name_3letters, self.model_maybe_fsdp.named_parameters()).items():
                values: List[float]
                if len(values) == 1:  # e.g., cls token will only have one value
                    values.append(values[0])
        else:
            ...
            # todo: log params

    # @profile(precision=4, stream=open('amp_sc.log', 'w+'))
    def backward_clip_step(
        self, ep: int, it: int, g_it: int, stepping: bool, logging_params: bool, loss: torch.Tensor, clip_decay_ratio=1, stable=False,
    ) -> Tuple[torch.Tensor, Optional[float]]:
        # backward
        loss = loss.mul(self.r_accu)  # r_accu == 1.0 / n_gradient_accumulation
        orig_norm = scaler_sc = None
        if self.scaler is not None:
            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
        else:
            loss.backward(retain_graph=False, create_graph=False)

        # clip gradients, then step the optimizer
        if stepping:
            if self.scaler is not None: self.scaler.unscale_(self.optimizer)  # now the unscaled gradients can be read correctly
            skipped, orig_norm = 0, self.last_orig_norm
            if self.fp is not None:
                if g_it % 10 == 0: self.fp.seek(0); self.fp.truncate(0)
                self.fp.write(f'<ep{ep} it{it} {g_it}>\n'); self.fp.flush()

            if self.early_clipping:
                c = self.grad_clip * clip_decay_ratio
                if self.zero:
                    orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(c)
                else:
                    orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), c)

            if self.scaler is not None:
                self.scaler: torch.cuda.amp.GradScaler
                if self.zero:
                    # synchronize found_inf_per_device before calling step, so that even if only some ranks found inf on their sharded params, all other ranks will know;
                    # otherwise, saving the FSDP optimizer state would raise an AssertionError saying "Different ranks have different values for step."
                    for optimizer_state in self.scaler._per_optimizer_states.values():
                        for t in optimizer_state['found_inf_per_device'].values():
                            dist.allreduce(t)  # ideally each rank only has a single t, so there is no need for an async allreduce
                self.scaler.step(self.optimizer)
                scaler_sc: Optional[float] = self.scaler.get_scale()
                if scaler_sc > self.max_sc:  # fp16 overflows above 65536, so letting the loss scale grow past 32768 could be dangerous
                    # print(f'[fp16 scaling] too large loss scale {scaler_sc}! (clip to {self.max_sc:g})')
                    self.scaler.update(new_scale=self.max_sc)
                else:
                    self.scaler.update()
                try:
                    scaler_sc = float(math.log2(scaler_sc))
                except Exception as e:
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    time.sleep(1)
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    raise e
            else:
                self.optimizer.step()

            if self.late_clipping:
                orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm
            self.last_orig_norm = orig_norm
        # no zero_grad() here; the caller may still want to log these gradients before zeroing them
        return orig_norm, scaler_sc
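
    # A minimal training-loop sketch for backward_clip_step (hypothetical names: amp_opt is an AmpOptimizer,
    # n_accu is the number of gradient-accumulation steps; the caller zeroes gradients itself, because
    # backward_clip_step deliberately does not):
    #
    #   with amp_opt.amp_ctx:
    #       loss = compute_loss(model, batch)   # hypothetical forward pass
    #   stepping = (g_it + 1) % n_accu == 0
    #   grad_norm, log2_scale = amp_opt.backward_clip_step(
    #       ep=ep, it=it, g_it=g_it, stepping=stepping, logging_params=False, loss=loss)
    #   if stepping:
    #       amp_opt.optimizer.zero_grad(set_to_none=True)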

    def state_dict(self):
        return {
            'optimizer': self.optimizer.state_dict()
        } if self.scaler is None else {
            'scaler': self.scaler.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

    def load_state_dict(self, state, strict=True):
        if self.scaler is not None:
            try: self.scaler.load_state_dict(state['scaler'])
            except Exception as e: print(f'[fp16 load_state_dict err] {e}')
        self.optimizer.load_state_dict(state['optimizer'])
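

# A minimal checkpointing sketch for state_dict / load_state_dict (hypothetical paths and variable names;
# the saved dict holds the wrapped optimizer state plus, when fp16 scaling is active, the GradScaler state):
#
#   torch.save({'amp_opt': amp_opt.state_dict()}, 'ckpt.pth')
#   ckpt = torch.load('ckpt.pth', map_location='cpu')
#   amp_opt.load_state_dict(ckpt['amp_opt'])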