"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

 Based on https://github.com/facebookresearch/TimeSformer
"""

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified model creation / weight loading / state_dict helpers

import logging
import math
import os
import warnings
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


# (spatial key fragment, temporal key fragment) pairs used to seed the
# temporal branch of each block from its spatial counterpart.
_TEMPORAL_INIT_RULES = (("attn", "temporal_attn"), ("norm1", "temporal_norm1"))


def _strip_prefix(state_dict, prefix):
    """Return an ``OrderedDict`` copy of *state_dict* with *prefix* cut from keys that start with it."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        stripped[name] = value
    return stripped


def _init_temporal_attention(state_dict):
    """Return a copy of *state_dict* where missing temporal block weights are copied from spatial ones.

    For every key containing ``"blocks"`` and ``"attn"`` (resp. ``"norm1"``),
    the corresponding ``"temporal_attn"`` (resp. ``"temporal_norm1"``) key is
    added when the checkpoint does not already provide it. Keys that are
    already temporal are skipped so no ``temporal_temporal_*`` entries are
    created.
    """
    new_state_dict = state_dict.copy()
    for key in state_dict:
        if "blocks" not in key:
            continue
        for spatial, temporal in _TEMPORAL_INIT_RULES:
            if spatial not in key or temporal in key:
                continue
            new_key = key.replace(spatial, temporal)
            if new_key not in state_dict:
                new_state_dict[new_key] = state_dict[key]
    return new_state_dict


def _adapt_first_conv(state_dict, conv1_name, in_chans):
    """Adapt the first-conv weight in *state_dict* (in place) to *in_chans* input channels.

    Does nothing when ``in_chans == 3``. For one channel the weights are
    summed over the input-channel dimension; for any other count the 3-channel
    weights are tiled and rescaled, or the entry is deleted when the stored
    weight does not have exactly 3 input channels.
    """
    if in_chans == 3:
        return

    weight_key = conv1_name + ".weight"
    if in_chans == 1:
        logging.info(
            "Converting first conv (%s) pretrained weights from 3 to 1 channel"
            % conv1_name
        )
    conv1_weight = state_dict[weight_key]
    conv1_type = conv1_weight.dtype
    # Do the arithmetic in float32, then cast back to the stored dtype.
    conv1_weight = conv1_weight.float()
    out_ch, in_ch, kernel_h, kernel_w = conv1_weight.shape

    if in_chans == 1:
        if in_ch > 3:
            assert in_ch % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(out_ch, in_ch // 3, 3, kernel_h, kernel_w)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        state_dict[weight_key] = conv1_weight.to(conv1_type)
    elif in_ch != 3:
        logging.warning(
            "Deleting first conv (%s) from pretrained weights." % conv1_name
        )
        del state_dict[weight_key]
    else:
        logging.info(
            "Repeating first conv (%s) weights in channel dim." % conv1_name
        )
        repeat = int(math.ceil(in_chans / 3))
        conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        # Rescale so the summed contribution over channels stays comparable.
        conv1_weight *= 3 / float(in_chans)
        state_dict[weight_key] = conv1_weight.to(conv1_type)


def load_state_dict(checkpoint_path, use_ema=False):
    """Load a state dict from a checkpoint file onto the CPU.

    The weights are looked up under ``"state_dict_ema"`` (only when *use_ema*
    is true and the key exists), then ``"state_dict"``, then ``"model_state"``.
    A leading ``module.`` prefix is stripped from ``state_dict*`` keys and a
    leading ``model.`` prefix from ``model_state`` keys. If none of these
    containers is present, the loaded object is returned unchanged.

    Args:
        checkpoint_path: Path to the checkpoint file.
        use_ema: Prefer the ``"state_dict_ema"`` entry when available.

    Returns:
        The (possibly prefix-stripped) state dict.

    Raises:
        FileNotFoundError: If *checkpoint_path* is empty or not a file.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        logging.error(message)
        raise FileNotFoundError(message)

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    is_dict = isinstance(checkpoint, dict)
    state_dict_key = "state_dict"
    if is_dict and use_ema and "state_dict_ema" in checkpoint:
        state_dict_key = "state_dict_ema"

    if is_dict and state_dict_key in checkpoint:
        # strip `module.` prefix
        state_dict = _strip_prefix(checkpoint[state_dict_key], "module.")
    elif is_dict and "model_state" in checkpoint:
        state_dict_key = "model_state"
        # strip `model.` prefix
        state_dict = _strip_prefix(checkpoint[state_dict_key], "model.")
    else:
        state_dict = checkpoint

    logging.info(
        "Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)
    )
    return state_dict


def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    """Load the weights stored at *checkpoint_path* into *model*.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``.
        checkpoint_path: Path passed to :func:`load_state_dict`.
        use_ema: Forwarded to :func:`load_state_dict`.
        strict: Forwarded to ``model.load_state_dict``.
    """
    state_dict = load_state_dict(checkpoint_path, use_ema)
    model.load_state_dict(state_dict, strict=strict)


# def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
#     resume_epoch = None
#     if os.path.isfile(checkpoint_path):
#         checkpoint = torch.load(checkpoint_path, map_location='cpu')
#         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
#             if log_info:
#                 _logger.info('Restoring model state from checkpoint...')
#             new_state_dict = OrderedDict()
#             for k, v in checkpoint['state_dict'].items():
#                 name = k[7:] if k.startswith('module') else k
#                 new_state_dict[name] = v
#             model.load_state_dict(new_state_dict)

#             if optimizer is not None and 'optimizer' in checkpoint:
#                 if log_info:
#                     _logger.info('Restoring optimizer state from checkpoint...')
#                 optimizer.load_state_dict(checkpoint['optimizer'])

#             if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
#                 if log_info:
#                     _logger.info('Restoring AMP loss scaler state from checkpoint...')
#                 loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

#             if 'epoch' in checkpoint:
#                 resume_epoch = checkpoint['epoch']
#                 if 'version' in checkpoint and checkpoint['version'] > 1:
#                     resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

#             if log_info:
#                 _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
#         else:
#             model.load_state_dict(checkpoint)
#             if log_info:
#                 _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
#         return resume_epoch
#     else:
#         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
#         raise FileNotFoundError()


def load_pretrained(
    model,
    cfg=None,
    num_classes=1000,
    in_chans=3,
    filter_fn=None,
    img_size=224,
    num_frames=8,
    num_patches=196,
    attention_type="divided_space_time",
    pretrained_model="",
    strict=True,
):
    """Load pretrained weights into *model*, adapting them to the model's configuration.

    The weights come from ``cfg["url"]`` when *pretrained_model* is empty,
    otherwise from the checkpoint file at *pretrained_model* (unwrapping a
    top-level ``"model"`` entry when present). The state dict is then adapted:
    first-conv input channels, classifier size, spatial/temporal position
    embedding lengths, and (for ``"divided_space_time"``) temporal attention
    weights seeded from the spatial ones.

    Args:
        model: Target model; must expose ``load_state_dict``.
        cfg: Config dict with ``"url"``, ``"classifier"``, ``"num_classes"``
            and (when ``in_chans != 3``) ``"first_conv"``. Defaults to
            ``model.default_cfg``.
        num_classes: Number of classes of the target classifier.
        in_chans: Number of input channels of the target model.
        filter_fn: Optional callable applied to the loaded state dict.
        img_size: Unused; kept for interface compatibility.
        num_frames: Target length of ``time_embed``.
        num_patches: Target number of patch tokens in ``pos_embed``.
        attention_type: ``"divided_space_time"`` enables temporal-weight seeding.
        pretrained_model: Optional checkpoint path.
        strict: Accepted for interface compatibility; the final load is always
            performed with ``strict=False``.

    Returns:
        None. Returns early (leaving the model untouched) when no valid
        config URL is available.
    """
    if cfg is None:
        # A model without `default_cfg` falls through to the warning below.
        cfg = getattr(model, "default_cfg", None)
    if cfg is None or "url" not in cfg or not cfg["url"]:
        logging.warning("Pretrained model URL is invalid, using random initialization.")
        return

    if len(pretrained_model) == 0:
        state_dict = model_zoo.load_url(cfg["url"], progress=False, map_location="cpu")
    else:
        state_dict = load_state_dict(pretrained_model)
        # Some checkpoints nest the weights under a top-level "model" entry.
        if isinstance(state_dict, dict) and "model" in state_dict:
            state_dict = state_dict["model"]

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans != 3:
        _adapt_first_conv(state_dict, cfg["first_conv"], in_chans)

    classifier_name = cfg["classifier"]
    classifier_weight_key = classifier_name + ".weight"
    classifier_bias_key = classifier_name + ".bias"
    if num_classes == 1000 and cfg["num_classes"] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        state_dict[classifier_weight_key] = state_dict[classifier_weight_key][1:]
        state_dict[classifier_bias_key] = state_dict[classifier_bias_key][1:]
    elif num_classes != state_dict[classifier_weight_key].size(0):
        # completely discard fully connected for all other differences between pretrained and created model
        del state_dict[classifier_weight_key]
        del state_dict[classifier_bias_key]

    ## Resizing the positional embeddings in case they don't match
    if num_patches + 1 != state_dict["pos_embed"].size(1):
        state_dict["pos_embed"] = resize_spatial_embedding(
            state_dict, "pos_embed", num_patches
        )

    ## Resizing time embeddings in case they don't match
    if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1):
        state_dict["time_embed"] = resize_temporal_embedding(
            state_dict, "time_embed", num_frames
        )

    ## Initializing temporal attention
    if attention_type == "divided_space_time":
        state_dict = _init_temporal_attention(state_dict)

    ## Loading the weights
    # Always non-strict: entries may have been deleted above.
    model.load_state_dict(state_dict, strict=False)


def load_pretrained_imagenet(
    model,
    pretrained_model,
    cfg=None,
    ignore_classifier=True,
    num_frames=8,
    num_patches=196,
    **kwargs,
):
    """Initialise *model* from timm's pretrained ``vit_base_patch16_224`` weights.

    The classifier head is dropped, temporal attention/norm weights are seeded
    from the spatial ones, and only entries whose key and shape match the
    model are loaded (non-strict). Key and shape mismatches are logged.

    All parameters other than *model* are unused and kept for interface
    compatibility.
    """
    import timm

    logging.info("Loading vit_base_patch16_224 checkpoints.")
    loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(
        pretrained=True
    ).state_dict()

    del loaded_state_dict["head.weight"]
    del loaded_state_dict["head.bias"]

    ## Initializing temporal attention
    loaded_state_dict = _init_temporal_attention(loaded_state_dict)

    # Build the model's state dict once; each state_dict() call creates a new mapping.
    model_state = model.state_dict()

    load_not_in_model = [k for k in loaded_state_dict if k not in model_state]
    model_not_in_load = [k for k in model_state if k not in loaded_state_dict]

    toload = {}
    mismatched_shape_keys = []
    for k, param in model_state.items():
        if k not in loaded_state_dict:
            continue
        if param.shape != loaded_state_dict[k].shape:
            mismatched_shape_keys.append(k)
        else:
            toload[k] = loaded_state_dict[k]

    logging.info("Keys in loaded but not in model:")
    logging.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
    logging.info("Keys in model but not in loaded:")
    logging.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")
    logging.info("Keys in model and loaded, but shape mismatched:")
    logging.info(
        f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}"
    )

    model.load_state_dict(toload, strict=False)


def load_pretrained_kinetics(
    model,
    pretrained_model,
    cfg=None,
    ignore_classifier=True,
    num_frames=8,
    num_patches=196,
    **kwargs,
):
    """Load Kinetics-pretrained weights from the checkpoint at *pretrained_model*.

    The classifier entries of the checkpoint are replaced with the model's
    current classifier weights, position/time embeddings are resized when
    their length differs, and the result is loaded strictly. A failure of the
    final load is logged (with traceback) rather than raised.

    Args:
        model: Target model.
        pretrained_model: Path to the checkpoint; must be non-empty.
        cfg: Config dict with ``"url"`` and ``"classifier"``; defaults to
            ``model.default_cfg``.
        ignore_classifier: Must be true; the false case is not implemented.
        num_frames: Target length of ``time_embed``.
        num_patches: Target number of patch tokens in ``pos_embed``.

    Raises:
        AssertionError: If *pretrained_model* is empty.
        NotImplementedError: If *ignore_classifier* is false.
    """
    if cfg is None:
        # A model without `default_cfg` falls through to the warning below.
        cfg = getattr(model, "default_cfg", None)
    if cfg is None or "url" not in cfg or not cfg["url"]:
        logging.warning("Pretrained model URL is invalid, using random initialization.")
        return

    assert (
        len(pretrained_model) > 0
    ), "Path to pre-trained Kinetics weights not provided."

    state_dict = load_state_dict(pretrained_model)

    classifier_name = cfg["classifier"]
    if ignore_classifier:
        classifier_weight_key = classifier_name + ".weight"
        classifier_bias_key = classifier_name + ".bias"

        # Keep the model's own classifier so the strict load below succeeds.
        model_state = model.state_dict()
        state_dict[classifier_weight_key] = model_state[classifier_weight_key]
        state_dict[classifier_bias_key] = model_state[classifier_bias_key]
    else:
        raise NotImplementedError(
            "[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier."
        )

    ## Resizing the positional embeddings in case they don't match
    if num_patches + 1 != state_dict["pos_embed"].size(1):
        new_pos_embed = resize_spatial_embedding(state_dict, "pos_embed", num_patches)
        state_dict["pos_embed"] = new_pos_embed

    ## Resizing time embeddings in case they don't match
    if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1):
        state_dict["time_embed"] = resize_temporal_embedding(
            state_dict, "time_embed", num_frames
        )

    ## Loading the weights
    try:
        model.load_state_dict(state_dict, strict=True)
    except Exception:
        # Best-effort load: report the failure with its traceback and carry on.
        logging.exception("Error in loading Kinetics pre-trained weights.")
    else:
        logging.info("Succeeded in loading Kinetics pre-trained weights.")


def resize_spatial_embedding(state_dict, key, num_patches):
    """Return ``state_dict[key]`` resized to ``num_patches + 1`` tokens.

    The first token of the first batch entry is kept as is; the remaining
    tokens are resampled to *num_patches* with nearest-neighbour interpolation.
    Only batch index 0 of the stored embedding is used.
    """
    logging.info(
        f"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}"
    )

    pos_embed = state_dict[key]

    cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
    # F.interpolate resamples the last dimension, so move tokens there.
    other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)

    new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode="nearest")
    new_pos_embed = new_pos_embed.transpose(1, 2)
    new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)

    return new_pos_embed


def resize_temporal_embedding(state_dict, key, num_frames):
    """Return ``state_dict[key]`` resampled along dim 1 to *num_frames* (nearest-neighbour)."""
    logging.info(
        f"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}"
    )

    # F.interpolate resamples the last dimension, so move frames there and back.
    time_embed = state_dict[key].transpose(1, 2)
    new_time_embed = F.interpolate(time_embed, size=(num_frames), mode="nearest")

    return new_time_embed.transpose(1, 2)


def detach_variable(inputs):
    """Return a tuple of detached copies of *inputs*, preserving each ``requires_grad`` flag.

    Raises:
        RuntimeError: If *inputs* is not a tuple.
    """
    if not isinstance(inputs, tuple):
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: "
            + type(inputs).__name__
        )

    out = []
    for inp in inputs:
        x = inp.detach()
        x.requires_grad = inp.requires_grad
        out.append(x)
    return tuple(out)


def check_backward_validity(inputs):
    """Warn when no element of *inputs* has ``requires_grad`` set."""
    if not any(inp.requires_grad for inp in inputs):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )