from functools import partial from torch import optim as optim import os import torch import torch.distributed as dist import numpy as np from scipy import interpolate def build_optimizer(config, model, is_pretrain=False, logger=None): """ Build optimizer, set weight decay of normalization to 0 by default. AdamW only. """ logger.info('>>>>>>>>>> Build Optimizer') skip = {} skip_keywords = {} if hasattr(model, 'no_weight_decay'): skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay_keywords'): skip_keywords = model.no_weight_decay_keywords() if is_pretrain: parameters = get_pretrain_param_groups(model, skip, skip_keywords) else: depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' \ else config.MODEL.SWINV2.DEPTHS num_layers = sum(depths) get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))) parameters = get_finetune_param_groups(model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, get_layer_func, scales, skip, skip_keywords) optimizer = None optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) logger.info(optimizer) return optimizer def set_weight_decay(model, skip_list=(), skip_keywords=()): """ Args: model (_type_): _description_ skip_list (tuple, optional): _description_. Defaults to (). skip_keywords (tuple, optional): _description_. Defaults to (). Returns: _type_: _description_ """ has_decay = [] no_decay = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights if len(param.shape) == 1 or name.endswith(".bias") \ or (name in skip_list) or \ check_keywords_in_name(name, skip_keywords): no_decay.append(param) else: has_decay.append(param) return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}] def check_keywords_in_name(name, keywords=()): isin = False for keyword in keywords: if keyword in name: isin = True return isin def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): has_decay = [] no_decay = [] has_decay_name = [] no_decay_name = [] for name, param in model.named_parameters(): if not param.requires_grad: continue if len(param.shape) == 1 or name.endswith(".bias") or \ (name in skip_list) or \ check_keywords_in_name(name, skip_keywords): no_decay.append(param) no_decay_name.append(name) else: has_decay.append(param) has_decay_name.append(name) return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}] def get_swin_layer(name, num_layers, depths): if name in ("mask_token"): return 0 elif name.startswith("patch_embed"): return 0 elif name.startswith("layers"): layer_id = int(name.split('.')[1]) block_id = name.split('.')[3] if block_id == 'reduction' or block_id == 'norm': return sum(depths[:layer_id + 1]) layer_id = sum(depths[:layer_id]) + int(block_id) return layer_id + 1 else: return num_layers - 1 def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): parameter_group_names = {} parameter_group_vars = {} for name, param in model.named_parameters(): if not param.requires_grad: continue if len(param.shape) == 1 or name.endswith(".bias") \ or (name in skip_list) or \ check_keywords_in_name(name, skip_keywords): group_name = "no_decay" this_weight_decay = 0. else: group_name = "decay" this_weight_decay = weight_decay if get_layer_func is not None: layer_id = get_layer_func(name) group_name = "layer_%d_%s" % (layer_id, group_name) else: layer_id = None if group_name not in parameter_group_names: if scales is not None: scale = scales[layer_id] else: scale = 1. parameter_group_names[group_name] = { "group_name": group_name, "weight_decay": this_weight_decay, "params": [], "lr": lr * scale, "lr_scale": scale, } parameter_group_vars[group_name] = { "group_name": group_name, "weight_decay": this_weight_decay, "params": [], "lr": lr * scale, "lr_scale": scale } parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(name) return list(parameter_group_vars.values()) def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") if config.MODEL.RESUME.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( config.MODEL.RESUME, map_location='cpu', check_hash=True) else: checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') # re-map keys due to name change (only for loading provided models) rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k] for k in rpe_mlp_keys: checkpoint['model'][k.replace( 'rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k) msg = model.load_state_dict(checkpoint['model'], strict=False) logger.info(msg) max_accuracy = 0.0 if not config.EVAL_MODE and 'optimizer' in checkpoint \ and 'lr_scheduler' in checkpoint \ and 'scaler' in checkpoint \ and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) scaler.load_state_dict(checkpoint['scaler']) config.defrost() config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 config.freeze() logger.info( f"=> loaded successfully '{config.MODEL.RESUME}' " + f"(epoch {checkpoint['epoch']})") if 'max_accuracy' in checkpoint: max_accuracy = checkpoint['max_accuracy'] else: max_accuracy = 0.0 del checkpoint torch.cuda.empty_cache() return max_accuracy def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger): save_state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'scaler': scaler.state_dict(), 'max_accuracy': max_accuracy, 'epoch': epoch, 'config': config} save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') logger.info(f"{save_path} saving......") torch.save(save_state, save_path) logger.info(f"{save_path} saved !!!") def get_grad_norm(parameters, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = list(filter(lambda p: p.grad is not None, parameters)) norm_type = float(norm_type) total_norm = 0 for p in parameters: param_norm = p.grad.data.norm(norm_type) total_norm += param_norm.item() ** norm_type total_norm = total_norm ** (1. / norm_type) return total_norm def auto_resume_helper(output_dir, logger): checkpoints = os.listdir(output_dir) checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}") if len(checkpoints) > 0: latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) logger.info(f"The latest checkpoint founded: {latest_checkpoint}") resume_file = latest_checkpoint else: resume_file = None return resume_file def reduce_tensor(tensor): rt = tensor.clone() dist.all_reduce(rt, op=dist.ReduceOp.SUM) rt /= dist.get_world_size() return rt def load_pretrained(config, model, logger): logger.info( f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........") checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') checkpoint_model = checkpoint['model'] if any([True if 'encoder.' in k else False for k in checkpoint_model.keys()]): checkpoint_model = {k.replace( 'encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')} logger.info('Detect pre-trained model, remove [encoder.] prefix.') else: logger.info( 'Detect non-pre-trained model, pass without doing anything.') if config.MODEL.TYPE in ['swin', 'swinv2']: logger.info( ">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") checkpoint = remap_pretrained_keys_swin( model, checkpoint_model, logger) else: raise NotImplementedError msg = model.load_state_dict(checkpoint_model, strict=False) logger.info(msg) del checkpoint torch.cuda.empty_cache() logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'") def remap_pretrained_keys_swin(model, checkpoint_model, logger): state_dict = model.state_dict() # Geometric interpolation when pre-trained patch size mismatch # with fine-tuned patch size all_keys = list(checkpoint_model.keys()) for key in all_keys: if "relative_position_bias_table" in key: logger.info(f"Key: {key}") rel_position_bias_table_pretrained = checkpoint_model[key] rel_position_bias_table_current = state_dict[key] L1, nH1 = rel_position_bias_table_pretrained.size() L2, nH2 = rel_position_bias_table_current.size() if nH1 != nH2: logger.info(f"Error in loading {key}, passing......") else: if L1 != L2: logger.info( f"{key}: Interpolate " + "relative_position_bias_table using geo.") src_size = int(L1 ** 0.5) dst_size = int(L2 ** 0.5) def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.090307: # q = 1.090307 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) logger.info("Original positions = %s" % str(x)) logger.info("Target positions = %s" % str(dx)) all_rel_pos_bias = [] for i in range(nH1): z = rel_position_bias_table_pretrained[:, i].view( src_size, src_size).float().numpy() f_cubic = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias_host = \ torch.Tensor(f_cubic(dx, dy) ).contiguous().view(-1, 1) all_rel_pos_bias.append( all_rel_pos_bias_host.to( rel_position_bias_table_pretrained.device)) new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) checkpoint_model[key] = new_rel_pos_bias # delete relative_position_index since we always re-init it relative_position_index_keys = [ k for k in checkpoint_model.keys() if "relative_position_index" in k] for k in relative_position_index_keys: del checkpoint_model[k] # delete relative_coords_table since we always re-init it relative_coords_table_keys = [ k for k in checkpoint_model.keys() if "relative_coords_table" in k] for k in relative_coords_table_keys: del checkpoint_model[k] # delete attn_mask since we always re-init it attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] for k in attn_mask_keys: del checkpoint_model[k] return checkpoint_model def remap_pretrained_keys_vit(model, checkpoint_model, logger): # Duplicate shared rel_pos_bias to each layer if getattr(model, 'use_rel_pos_bias', False) and \ "rel_pos_bias.relative_position_bias_table" in checkpoint_model: logger.info( "Expand the shared relative position " + "embedding to each transformer block.") num_layers = model.get_num_layers() rel_pos_bias = \ checkpoint_model["rel_pos_bias.relative_position_bias_table"] for i in range(num_layers): checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") # Geometric interpolation when pre-trained patch # size mismatch with fine-tuned patch size all_keys = list(checkpoint_model.keys()) for key in all_keys: if "relative_position_index" in key: checkpoint_model.pop(key) if "relative_position_bias_table" in key: rel_pos_bias = checkpoint_model[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = model.state_dict()[key].size() dst_patch_shape = model.patch_embed.patch_shape if dst_patch_shape[0] != dst_patch_shape[1]: raise NotImplementedError() num_extra_tokens = dst_num_pos - \ (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) if src_size != dst_size: logger.info("Position interpolate for " + "%s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size, dst_size)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.090307: # q = 1.090307 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) logger.info("Original positions = %s" % str(x)) logger.info("Target positions = %s" % str(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view( src_size, src_size).float().numpy() f = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias_host = \ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1) all_rel_pos_bias.append( all_rel_pos_bias_host.to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat( (rel_pos_bias, extra_tokens), dim=0) checkpoint_model[key] = new_rel_pos_bias return checkpoint_model