# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union

import torch

from detectron2.config import CfgNode

from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR

# What a gradient clipper accepts: a single tensor or an iterable of tensors.
_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
# A gradient clipper works in place and returns nothing.
_GradientClipper = Callable[[_GradientClipperInput], None]


class GradientClipType(Enum):
    """
    Supported gradient clipping modes. The enum values are the strings that
    `_create_gradient_clipper` looks up from the config's `CLIP_TYPE` field.
    """

    VALUE = "value"
    NORM = "norm"


def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
    """
    Creates gradient clipping closure to clip by value or by norm,
    according to the provided config.

    Args:
        cfg: the gradient clipping config node; `CLIP_TYPE` and `CLIP_VALUE`
            are read from it, plus `NORM_TYPE` when clipping by norm.

    Returns:
        A callable that takes a tensor (or an iterable of tensors) and clips
        the gradients in place.

    Raises:
        ValueError: if `cfg.CLIP_TYPE` is not a value of `GradientClipType`.
    """
    # The closures below read from this clone, not from the caller's object.
    cfg = cfg.clone()

    def clip_grad_norm(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)

    _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
        GradientClipType.VALUE: clip_grad_value,
        GradientClipType.NORM: clip_grad_norm,
    }
    return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]


def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional[_GradientClipper] = None,
    global_clipper: Optional[_GradientClipper] = None
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically creates a new type that inherits the type of a given instance
    and overrides the `step` method to add gradient clipping

    Args:
        optimizer: the optimizer class to derive from.
        per_param_clipper: if given, called once for every parameter of every
            param group before the wrapped `step` runs.
        global_clipper: if given (and `per_param_clipper` is not), called once
            with an iterable over all parameters of all param groups.

    Returns:
        A subclass of `optimizer` named `<optimizer name>WithGradientClip`.
        If neither clipper is given, its `step` simply delegates to the
        parent `step` without clipping.

    Raises:
        AssertionError: if both clippers are given.
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        elif global_clipper is not None:
            # global clipper for future use with detr
            # (https://github.com/facebookresearch/detr/pull/287)
            all_params = itertools.chain.from_iterable(
                g["params"] for g in self.param_groups
            )
            global_clipper(all_params)
        # Anchor `super` on the generated class instead of `type(self)`:
        # with `type(self)`, any further subclass would resolve back to this
        # very method and recurse forever. The name is bound right below,
        # before `step` can ever be called.
        # Propagate the parent's return value (the loss when a closure is used).
        return super(OptimizerWithGradientClip, self).step(closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip


def maybe_add_gradient_clipping(
    cfg: CfgNode,
    optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer],
) -> Union[Type[torch.optim.Optimizer], torch.optim.Optimizer]:
    """
    If gradient clipping is enabled through config options, wraps the existing
    optimizer type to become a new dynamically created class OptimizerWithGradientClip
    that inherits the given optimizer and overrides the `step` method to
    include gradient clipping.

    Args:
        cfg: CfgNode, configuration options
        optimizer: type. A subclass of torch.optim.Optimizer. An optimizer
            instance is also accepted, in which case its class is swapped
            in place.

    Return:
        type: either the input `optimizer` (if gradient clipping is disabled), or
            a subclass of it with gradient clipping included in the `step` method.
            When an instance was passed in, that same instance is returned.
    """
    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
        return optimizer
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        optimizer_type = optimizer

    grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
        optimizer_type, per_param_clipper=grad_clipper
    )
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.__class__ = OptimizerWithGradientClip  # a bit hacky, not recommended
        return optimizer
    else:
        return OptimizerWithGradientClip


def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.

    Args:
        cfg: configuration options; the `SOLVER` node supplies the learning
            rate, weight decay, momentum and nesterov settings used here.
        model: the module whose trainable parameters will be optimized.

    Returns:
        A `torch.optim.SGD` instance, or an instance of its gradient clipping
        subclass when clipping is enabled in `cfg`.
    """
    params = get_default_optimizer_params(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
    )
    return maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
        params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV
    )


def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr,
    weight_decay,
    weight_decay_norm,
    bias_lr_factor=1.0,
    weight_decay_bias=None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
):
    """
    Get default param list for optimizer

    Args:
        model: module whose parameters are collected; parameters with
            `requires_grad=False` are skipped.
        base_lr: learning rate assigned to every parameter by default.
        weight_decay: weight decay assigned to every parameter by default.
        weight_decay_norm: weight decay used for the parameters of
            normalization modules.
        bias_lr_factor: multiplier applied to `base_lr` for parameters named
            `bias` that do not belong to a normalization module.
        weight_decay_bias: weight decay for those same `bias` parameters;
            falls back to `weight_decay` when `None`.
        overrides (dict: str -> (dict: str -> float)):
            if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            {"embedding": {"lr": 0.01, "weight_decay": 0.1}} will set the LR and
            weight decay values for all module parameters named `embedding` (default: None)

    Returns:
        A list with one param group dict per parameter, each holding the keys
        `params`, `lr` and `weight_decay`.
    """
    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for module in model.modules():
        # recurse=False: each module contributes only its own direct
        # parameters, so `module` is the immediate owner of `value`.
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            schedule_params = {
                "lr": base_lr,
                "weight_decay": weight_decay,
            }
            if isinstance(module, norm_module_types):
                schedule_params["weight_decay"] = weight_decay_norm
            elif module_param_name == "bias":
                # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
                # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
                # hyperparameters are by default exactly the same as for regular
                # weights.
                schedule_params["lr"] = base_lr * bias_lr_factor
                schedule_params["weight_decay"] = weight_decay_bias
            # Overrides are matched on the bare parameter name and win over
            # the norm/bias defaults chosen above.
            if overrides is not None and module_param_name in overrides:
                schedule_params.update(overrides[module_param_name])
            # Only `lr` and `weight_decay` are forwarded to the param group.
            params.append(
                {
                    "params": [value],
                    "lr": schedule_params["lr"],
                    "weight_decay": schedule_params["weight_decay"],
                }
            )

    return params


def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build a LR scheduler from config.

    Args:
        cfg: configuration options; `cfg.SOLVER.LR_SCHEDULER_NAME` selects
            the scheduler and the other `SOLVER` fields parameterize it.
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        A `WarmupMultiStepLR` or `WarmupCosineLR` instance.

    Raises:
        ValueError: if the configured scheduler name is not recognized.
    """
    name = cfg.SOLVER.LR_SCHEDULER_NAME
    if name == "WarmupMultiStepLR":
        return WarmupMultiStepLR(
            optimizer,
            cfg.SOLVER.STEPS,
            cfg.SOLVER.GAMMA,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
        )
    elif name == "WarmupCosineLR":
        return WarmupCosineLR(
            optimizer,
            cfg.SOLVER.MAX_ITER,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
        )
    else:
        raise ValueError("Unknown LR scheduler: {}".format(name))