""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
Based on https://github.com/facebookresearch/TimeSformer | |
""" | |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
# Copyright 2020 Ross Wightman | |
# Modified model creation / weight loading / state_dict helpers | |
import logging
import math
import os
import warnings
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


def load_state_dict(checkpoint_path, use_ema=False):
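    """Load a state dict from a checkpoint file.

    Handles plain state dicts as well as checkpoints that wrap the weights under
    'state_dict', 'state_dict_ema' (when use_ema=True), or 'model_state', and
    strips any 'module.' / 'model.' prefixes left by wrapper modules.
    """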
    if checkpoint_path and os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        state_dict_key = "state_dict"
        if isinstance(checkpoint, dict):
            if use_ema and "state_dict_ema" in checkpoint:
                state_dict_key = "state_dict_ema"
        if state_dict_key and state_dict_key in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint[state_dict_key].items():
                # strip `module.` prefix
                name = k[7:] if k.startswith("module.") else k
                new_state_dict[name] = v
            state_dict = new_state_dict
        elif "model_state" in checkpoint:
            state_dict_key = "model_state"
            new_state_dict = OrderedDict()
            for k, v in checkpoint[state_dict_key].items():
                # strip `model.` prefix
                name = k[6:] if k.startswith("model.") else k
                new_state_dict[name] = v
            state_dict = new_state_dict
        else:
            state_dict = checkpoint
        logging.info(
            "Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)
        )
        return state_dict
    else:
        logging.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))


def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
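    """Load weights from checkpoint_path into model via load_state_dict."""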
    state_dict = load_state_dict(checkpoint_path, use_ema)
    model.load_state_dict(state_dict, strict=strict)


# def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
#     resume_epoch = None
#     if os.path.isfile(checkpoint_path):
#         checkpoint = torch.load(checkpoint_path, map_location='cpu')
#         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
#             if log_info:
#                 _logger.info('Restoring model state from checkpoint...')
#             new_state_dict = OrderedDict()
#             for k, v in checkpoint['state_dict'].items():
#                 name = k[7:] if k.startswith('module') else k
#                 new_state_dict[name] = v
#             model.load_state_dict(new_state_dict)
#             if optimizer is not None and 'optimizer' in checkpoint:
#                 if log_info:
#                     _logger.info('Restoring optimizer state from checkpoint...')
#                 optimizer.load_state_dict(checkpoint['optimizer'])
#             if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
#                 if log_info:
#                     _logger.info('Restoring AMP loss scaler state from checkpoint...')
#                 loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
#             if 'epoch' in checkpoint:
#                 resume_epoch = checkpoint['epoch']
#                 if 'version' in checkpoint and checkpoint['version'] > 1:
#                     resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
#             if log_info:
#                 _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
#         else:
#             model.load_state_dict(checkpoint)
#             if log_info:
#                 _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
#         return resume_epoch
#     else:
#         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
#         raise FileNotFoundError()


def load_pretrained(
    model,
    cfg=None,
    num_classes=1000,
    in_chans=3,
    filter_fn=None,
    img_size=224,
    num_frames=8,
    num_patches=196,
    attention_type="divided_space_time",
    pretrained_model="",
    strict=True,
):
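    """Load pretrained 2D ViT weights into a TimeSformer-style video model.

    Adapts the checkpoint where needed: converts the first conv for a different
    number of input channels, drops or trims the classifier head when the class
    count differs, resizes spatial/temporal position embeddings, and copies
    spatial attention weights to initialize divided space-time attention.
    """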
    if cfg is None:
        cfg = getattr(model, "default_cfg", None)
    if cfg is None or "url" not in cfg or not cfg["url"]:
        logging.warning("Pretrained model URL is invalid, using random initialization.")
        return

    if len(pretrained_model) == 0:
        logging.info(f"loading from default config {model.default_cfg}.")
        state_dict = model_zoo.load_url(cfg["url"], progress=False, map_location="cpu")
    else:
        try:
            state_dict = load_state_dict(pretrained_model)["model"]
        except KeyError:
            state_dict = load_state_dict(pretrained_model)
    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg["first_conv"]
        logging.info(
            "Converting first conv (%s) pretrained weights from 3 to 1 channel"
            % conv1_name
        )
        conv1_weight = state_dict[conv1_name + ".weight"]
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I > 3:
            assert conv1_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        conv1_weight = conv1_weight.to(conv1_type)
        state_dict[conv1_name + ".weight"] = conv1_weight
    elif in_chans != 3:
        conv1_name = cfg["first_conv"]
        conv1_weight = state_dict[conv1_name + ".weight"]
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I != 3:
            logging.warning(
                "Deleting first conv (%s) from pretrained weights." % conv1_name
            )
            del state_dict[conv1_name + ".weight"]
            strict = False
        else:
            logging.info(
                "Repeating first conv (%s) weights in channel dim." % conv1_name
            )
            repeat = int(math.ceil(in_chans / 3))
            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv1_weight *= 3 / float(in_chans)
            conv1_weight = conv1_weight.to(conv1_type)
            state_dict[conv1_name + ".weight"] = conv1_weight

    classifier_name = cfg["classifier"]
    if num_classes == 1000 and cfg["num_classes"] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + ".weight"]
        state_dict[classifier_name + ".weight"] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + ".bias"]
        state_dict[classifier_name + ".bias"] = classifier_bias[1:]
    elif num_classes != state_dict[classifier_name + ".weight"].size(0):
        # print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True)
        # completely discard fully connected for all other differences between pretrained and created model
        del state_dict[classifier_name + ".weight"]
        del state_dict[classifier_name + ".bias"]
        strict = False
    ## Resizing the positional embeddings in case they don't match
    if num_patches + 1 != state_dict["pos_embed"].size(1):
        logging.info(
            f"Resizing spatial position embedding from {state_dict['pos_embed'].size(1)} to {num_patches + 1}"
        )
        pos_embed = state_dict["pos_embed"]
        cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
        other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
        new_pos_embed = F.interpolate(
            other_pos_embed, size=num_patches, mode="nearest"
        )
        new_pos_embed = new_pos_embed.transpose(1, 2)
        new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
        state_dict["pos_embed"] = new_pos_embed
    ## Resizing time embeddings in case they don't match
    if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1):
        logging.info(
            f"Resizing temporal position embedding from {state_dict['time_embed'].size(1)} to {num_frames}"
        )
        time_embed = state_dict["time_embed"].transpose(1, 2)
        new_time_embed = F.interpolate(time_embed, size=num_frames, mode="nearest")
        state_dict["time_embed"] = new_time_embed.transpose(1, 2)
    ## Initializing temporal attention
    if attention_type == "divided_space_time":
        new_state_dict = state_dict.copy()
        for key in state_dict:
            if "blocks" in key and "attn" in key:
                new_key = key.replace("attn", "temporal_attn")
                if new_key not in state_dict:
                    new_state_dict[new_key] = state_dict[key]
                else:
                    new_state_dict[new_key] = state_dict[new_key]
            if "blocks" in key and "norm1" in key:
                new_key = key.replace("norm1", "temporal_norm1")
                if new_key not in state_dict:
                    new_state_dict[new_key] = state_dict[key]
                else:
                    new_state_dict[new_key] = state_dict[new_key]
        state_dict = new_state_dict

    ## Loading the weights
    model.load_state_dict(state_dict, strict=False)


def load_pretrained_imagenet(
    model,
    pretrained_model,
    cfg=None,
    ignore_classifier=True,
    num_frames=8,
    num_patches=196,
    **kwargs,
):
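    """Initialize the model from timm's ImageNet-pretrained vit_base_patch16_224.

    The classifier head is dropped, spatial attention weights are copied to the
    temporal attention branches, and only keys with matching shapes are loaded.
    """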
    import timm

    logging.info("Loading vit_base_patch16_224 checkpoints.")
    loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(
        pretrained=True
    ).state_dict()
    del loaded_state_dict["head.weight"]
    del loaded_state_dict["head.bias"]

    ## Initializing temporal attention
    new_state_dict = loaded_state_dict.copy()
    for key in loaded_state_dict:
        if "blocks" in key and "attn" in key:
            new_key = key.replace("attn", "temporal_attn")
            if new_key not in loaded_state_dict:
                new_state_dict[new_key] = loaded_state_dict[key]
            else:
                new_state_dict[new_key] = loaded_state_dict[new_key]
        if "blocks" in key and "norm1" in key:
            new_key = key.replace("norm1", "temporal_norm1")
            if new_key not in loaded_state_dict:
                new_state_dict[new_key] = loaded_state_dict[key]
            else:
                new_state_dict[new_key] = loaded_state_dict[new_key]
    loaded_state_dict = new_state_dict

    model_state_dict = model.state_dict()
    loaded_keys = loaded_state_dict.keys()
    model_keys = model_state_dict.keys()

    load_not_in_model = [k for k in loaded_keys if k not in model_keys]
    model_not_in_load = [k for k in model_keys if k not in loaded_keys]

    toload = dict()
    mismatched_shape_keys = []
    for k in model_keys:
        if k in loaded_keys:
            if model_state_dict[k].shape != loaded_state_dict[k].shape:
                mismatched_shape_keys.append(k)
            else:
                toload[k] = loaded_state_dict[k]

    logging.info("Keys in loaded but not in model:")
    logging.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
    logging.info("Keys in model but not in loaded:")
    logging.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")
    logging.info("Keys in model and loaded, but shape mismatched:")
    logging.info(
        f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}"
    )

    model.load_state_dict(toload, strict=False)


def load_pretrained_kinetics(
    model,
    pretrained_model,
    cfg=None,
    ignore_classifier=True,
    num_frames=8,
    num_patches=196,
    **kwargs,
):
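    """Load Kinetics-pretrained TimeSformer weights from pretrained_model into model.

    Requires ignore_classifier=True; the classifier head is taken from the target
    model, and position/time embeddings are resized to num_patches/num_frames.
    """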
    if cfg is None:
        cfg = getattr(model, "default_cfg", None)
    if cfg is None or "url" not in cfg or not cfg["url"]:
        logging.warning("Pretrained model URL is invalid, using random initialization.")
        return

    assert (
        len(pretrained_model) > 0
    ), "Path to pre-trained Kinetics weights not provided."
    state_dict = load_state_dict(pretrained_model)

    classifier_name = cfg["classifier"]
    if ignore_classifier:
        classifier_weight_key = classifier_name + ".weight"
        classifier_bias_key = classifier_name + ".bias"
        state_dict[classifier_weight_key] = model.state_dict()[classifier_weight_key]
        state_dict[classifier_bias_key] = model.state_dict()[classifier_bias_key]
    else:
        raise NotImplementedError(
            "[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier."
        )

    ## Resizing the positional embeddings in case they don't match
    if num_patches + 1 != state_dict["pos_embed"].size(1):
        new_pos_embed = resize_spatial_embedding(state_dict, "pos_embed", num_patches)
        state_dict["pos_embed"] = new_pos_embed

    ## Resizing time embeddings in case they don't match
    if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1):
        state_dict["time_embed"] = resize_temporal_embedding(
            state_dict, "time_embed", num_frames
        )

    ## Loading the weights
    try:
        model.load_state_dict(state_dict, strict=True)
        logging.info("Succeeded in loading Kinetics pre-trained weights.")
    except RuntimeError as e:
        logging.error("Error in loading Kinetics pre-trained weights: %s", e)


def resize_spatial_embedding(state_dict, key, num_patches):
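    """Interpolate the spatial position embedding to num_patches tokens, keeping the cls token."""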
    logging.info(
        f"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}"
    )
    pos_embed = state_dict[key]
    cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
    other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
    new_pos_embed = F.interpolate(other_pos_embed, size=num_patches, mode="nearest")
    new_pos_embed = new_pos_embed.transpose(1, 2)
    new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
    return new_pos_embed


def resize_temporal_embedding(state_dict, key, num_frames):
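    """Interpolate the temporal position embedding to num_frames positions."""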
    logging.info(
        f"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}"
    )
    time_embed = state_dict[key].transpose(1, 2)
    new_time_embed = F.interpolate(time_embed, size=num_frames, mode="nearest")
    return new_time_embed.transpose(1, 2)


def detach_variable(inputs):
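    """Detach each tensor in a tuple while preserving its requires_grad flag; raises for non-tuple inputs."""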
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: "
            + type(inputs).__name__
        )


def check_backward_validity(inputs):
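    """Warn if no input tensor requires grad, since backward would then produce no gradients."""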
    if not any(inp.requires_grad for inp in inputs):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )