John6666's picture
Upload 351 files
e84842d verified
raw
history blame
15.4 kB
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Based on https://github.com/facebookresearch/TimeSformer
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified model creation / weight loading / state_dict helpers
import logging, warnings
import os
import math
from collections import OrderedDict
import torch
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
def load_state_dict(checkpoint_path, use_ema=False):
    """Load a checkpoint file onto CPU and return the state dict it contains.

    Weights are picked from the loaded checkpoint in this order:

    * ``checkpoint["state_dict_ema"]`` if ``use_ema`` is True and that key exists,
      otherwise ``checkpoint["state_dict"]``; a leading ``module.`` is stripped
      from every key;
    * ``checkpoint["model_state"]``; a leading ``model.`` is stripped from every key;
    * otherwise the loaded object itself is returned unchanged (this also applies
      when the checkpoint is not a dict).

    Args:
        checkpoint_path: path to a file readable by ``torch.load``.
        use_ema: prefer the ``state_dict_ema`` entry when it is present.

    Returns:
        An ``OrderedDict`` with prefix-stripped keys when one of the keys above
        was found, else the raw object returned by ``torch.load``.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty or is not an existing file.
    """
    if not (checkpoint_path and os.path.isfile(checkpoint_path)):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        logging.error(message)
        raise FileNotFoundError(message)

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return checkpoint

    state_dict_key = "state_dict"
    if use_ema and "state_dict_ema" in checkpoint:
        state_dict_key = "state_dict_ema"

    if state_dict_key in checkpoint:
        prefix = "module."
    elif "model_state" in checkpoint:
        state_dict_key = "model_state"
        prefix = "model."
    else:
        # Already a plain state dict (or an unknown layout): hand it back as-is.
        logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return checkpoint

    # Match the full prefix including the dot so that keys such as "modulex.w"
    # or "modelx" are not truncated.
    state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint[state_dict_key].items()
    )
    logging.info(
        "Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)
    )
    return state_dict
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    """Read ``checkpoint_path`` with this module's ``load_state_dict`` and load it into ``model``.

    Args:
        model: object exposing ``load_state_dict(state_dict, strict=...)``.
        checkpoint_path: checkpoint file to read.
        use_ema: forwarded to ``load_state_dict`` to prefer ``state_dict_ema``.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        None; ``model`` is updated in place.
    """
    weights = load_state_dict(checkpoint_path, use_ema)
    model.load_state_dict(weights, strict=strict)
# def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
# resume_epoch = None
# if os.path.isfile(checkpoint_path):
# checkpoint = torch.load(checkpoint_path, map_location='cpu')
# if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
# if log_info:
# _logger.info('Restoring model state from checkpoint...')
# new_state_dict = OrderedDict()
# for k, v in checkpoint['state_dict'].items():
# name = k[7:] if k.startswith('module') else k
# new_state_dict[name] = v
# model.load_state_dict(new_state_dict)
# if optimizer is not None and 'optimizer' in checkpoint:
# if log_info:
# _logger.info('Restoring optimizer state from checkpoint...')
# optimizer.load_state_dict(checkpoint['optimizer'])
# if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
# if log_info:
# _logger.info('Restoring AMP loss scaler state from checkpoint...')
# loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
# if 'epoch' in checkpoint:
# resume_epoch = checkpoint['epoch']
# if 'version' in checkpoint and checkpoint['version'] > 1:
# resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
# if log_info:
# _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
# else:
# model.load_state_dict(checkpoint)
# if log_info:
# _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
# return resume_epoch
# else:
# _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
# raise FileNotFoundError()
def _adapt_first_conv(state_dict, conv1_name, in_chans):
    """Adapt ``state_dict[conv1_name + ".weight"]`` in place for ``in_chans`` input channels."""
    weight_key = conv1_name + ".weight"
    if in_chans == 1:
        logging.info(
            "Converting first conv (%s) pretrained weights from 3 to 1 channel"
            % conv1_name
        )
    conv1_weight = state_dict[weight_key]
    conv1_type = conv1_weight.dtype
    # Do the arithmetic in float32, then cast back to the checkpoint dtype at the end.
    conv1_weight = conv1_weight.float()
    O, I, J, K = conv1_weight.shape
    if in_chans == 1:
        if I > 3:
            # For models with space2depth stems: input channels come in groups of 3.
            assert I % 3 == 0
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K).sum(dim=2)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
    elif I != 3:
        # Cannot map a non-RGB stem onto in_chans channels: drop it instead.
        logging.warning(
            "Deleting first conv (%s) from pretrained weights." % conv1_name
        )
        del state_dict[weight_key]
        return
    else:
        logging.info(
            "Repeating first conv (%s) weights in channel dim." % conv1_name
        )
        repeat = int(math.ceil(in_chans / 3))
        conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        # Scale by 3 / in_chans (presumably to keep the output magnitude close to
        # that of the original 3-channel filter).
        conv1_weight *= 3 / float(in_chans)
    state_dict[weight_key] = conv1_weight.to(conv1_type)


def load_pretrained(
    model,
    cfg=None,
    num_classes=1000,
    in_chans=3,
    filter_fn=None,
    img_size=224,
    num_frames=8,
    num_patches=196,
    attention_type="divided_space_time",
    pretrained_model="",
    strict=True,
):
    """Load pretrained weights into ``model``, adapting them to its configuration.

    Weights are downloaded from ``cfg["url"]`` when ``pretrained_model`` is empty,
    otherwise read from the local checkpoint ``pretrained_model`` (its ``"model"``
    entry is used when present). The state dict is then adapted:

    * ``filter_fn`` (if given) is applied first;
    * for ``in_chans != 3`` the first conv ``cfg["first_conv"]`` is summed to one
      channel, repeated across channels, or deleted when it is not 3-channel;
    * the classifier ``cfg["classifier"]`` drops its first row when going from 1001
      to 1000 classes, or is deleted when its output size differs from ``num_classes``;
    * ``pos_embed`` / ``time_embed`` are resized with nearest-neighbour interpolation
      to ``num_patches + 1`` tokens / ``num_frames`` frames when sizes differ;
    * for ``attention_type == "divided_space_time"``, missing ``temporal_attn`` /
      ``temporal_norm1`` entries of each block are copied from ``attn`` / ``norm1``.

    Args:
        model: module to load into; its ``default_cfg`` is used when ``cfg`` is None.
        cfg: dict providing ``url`` and, as needed, ``first_conv``, ``classifier``
            and ``num_classes``.
        num_classes: number of classes of ``model``'s classifier.
        in_chans: number of input channels of ``model``.
        filter_fn: optional callable mapping the loaded state dict to a new one.
        img_size: unused; kept for interface compatibility.
        num_frames: target length of ``time_embed``.
        num_patches: target number of patch tokens in ``pos_embed`` (plus one cls token).
        attention_type: ``"divided_space_time"`` enables the temporal initialisation.
        pretrained_model: local checkpoint path; empty means download ``cfg["url"]``.
        strict: ignored; weights are always loaded with ``strict=False`` (kept for
            interface compatibility).

    Returns:
        None. Returns early without loading anything when ``cfg`` has no usable ``url``.
    """
    if cfg is None:
        cfg = getattr(model, "default_cfg", None)
    if cfg is None or "url" not in cfg or not cfg["url"]:
        logging.warning("Pretrained model URL is invalid, using random initialization.")
        return

    if not pretrained_model:
        logging.info(f"Loading pretrained weights from {cfg['url']}.")
        state_dict = model_zoo.load_url(cfg["url"], progress=False, map_location="cpu")
    else:
        # Read the file once; some checkpoints nest the weights under "model".
        state_dict = load_state_dict(pretrained_model)
        if isinstance(state_dict, dict) and "model" in state_dict:
            state_dict = state_dict["model"]

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans != 3:
        _adapt_first_conv(state_dict, cfg["first_conv"], in_chans)

    classifier_name = cfg["classifier"]
    classifier_weight_key = classifier_name + ".weight"
    classifier_bias_key = classifier_name + ".bias"
    if classifier_weight_key in state_dict:
        if num_classes == 1000 and cfg["num_classes"] == 1001:
            # special case for imagenet trained models with extra background class in pretrained weights
            state_dict[classifier_weight_key] = state_dict[classifier_weight_key][1:]
            state_dict[classifier_bias_key] = state_dict[classifier_bias_key][1:]
        elif num_classes != state_dict[classifier_weight_key].size(0):
            # completely discard fully connected for all other differences between pretrained and created model
            del state_dict[classifier_weight_key]
            state_dict.pop(classifier_bias_key, None)

    ## Resizing the positional embeddings in case they don't match
    if "pos_embed" in state_dict and num_patches + 1 != state_dict["pos_embed"].size(1):
        state_dict["pos_embed"] = resize_spatial_embedding(
            state_dict, "pos_embed", num_patches
        )

    ## Resizing time embeddings in case they don't match
    if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1):
        state_dict["time_embed"] = resize_temporal_embedding(
            state_dict, "time_embed", num_frames
        )

    ## Initializing temporal attention from the spatial weights of each block
    if attention_type == "divided_space_time":
        new_state_dict = state_dict.copy()
        for key, value in state_dict.items():
            # Skip keys that are already temporal so no "temporal_temporal_*" keys appear.
            if "blocks" not in key or "temporal_" in key:
                continue
            if "attn" in key:
                new_state_dict.setdefault(key.replace("attn", "temporal_attn"), value)
            if "norm1" in key:
                new_state_dict.setdefault(key.replace("norm1", "temporal_norm1"), value)
        state_dict = new_state_dict

    ## Loading the weights (always non-strict: the adapted dict may omit or add keys)
    model.load_state_dict(state_dict, strict=False)
def load_pretrained_imagenet(
    model,
    pretrained_model,
    cfg=None,
    ignore_classifier=True,
    num_frames=8,
    num_patches=196,
    **kwargs,
):
    """Load timm's pretrained ``vit_base_patch16_224`` weights into ``model``.

    The ViT weights are obtained via ``timm`` with ``pretrained=True``, and the
    ``head`` weight and bias are dropped. Each block's ``temporal_attn`` /
    ``temporal_norm1`` entries are then initialised by copying its spatial ``attn``
    / ``norm1`` weights. Only entries whose name and shape both match ``model``'s
    state dict are loaded (``strict=False``). Unmatched and shape-mismatched keys
    are logged. ``pos_embed`` is not resized: if its shape differs from the
    model's, it is skipped like any other mismatched key.

    Args:
        model: module to load into.
        pretrained_model, cfg, ignore_classifier, num_frames, num_patches, **kwargs:
            accepted but unused (the signature mirrors ``load_pretrained_kinetics``).

    Returns:
        None; ``model`` is updated in place.
    """
    import timm

    logging.info("Loading vit_base_patch16_224 checkpoints.")
    loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(
        pretrained=True
    ).state_dict()
    # Drop the ViT ``head`` layer; it is never loaded into ``model``.
    del loaded_state_dict["head.weight"]
    del loaded_state_dict["head.bias"]

    ## Initializing temporal attention from the spatial weights of each block
    new_state_dict = loaded_state_dict.copy()
    for key, value in loaded_state_dict.items():
        if "blocks" not in key or "temporal_" in key:
            continue
        if "attn" in key:
            new_state_dict.setdefault(key.replace("attn", "temporal_attn"), value)
        if "norm1" in key:
            new_state_dict.setdefault(key.replace("norm1", "temporal_norm1"), value)
    loaded_state_dict = new_state_dict

    # model.state_dict() builds a new dict on every call; take one snapshot
    # instead of rebuilding it for every key inside the loop.
    model_state_dict = model.state_dict()
    load_not_in_model = [k for k in loaded_state_dict if k not in model_state_dict]
    model_not_in_load = [k for k in model_state_dict if k not in loaded_state_dict]

    toload = dict()
    mismatched_shape_keys = []
    for k, model_value in model_state_dict.items():
        if k not in loaded_state_dict:
            continue
        if model_value.shape != loaded_state_dict[k].shape:
            mismatched_shape_keys.append(k)
        else:
            toload[k] = loaded_state_dict[k]

    logging.info("Keys in loaded but not in model:")
    logging.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
    logging.info("Keys in model but not in loaded:")
    logging.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")
    logging.info("Keys in model and loaded, but shape mismatched:")
    logging.info(
        f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}"
    )
    model.load_state_dict(toload, strict=False)
def load_pretrained_kinetics(
    model,
    pretrained_model,
    cfg=None,
    ignore_classifier=True,
    num_frames=8,
    num_patches=196,
    **kwargs,
):
    """Load a Kinetics-pretrained checkpoint from ``pretrained_model`` into ``model``.

    The checkpoint's classifier entries are replaced by ``model``'s own current
    classifier weights. ``pos_embed`` and ``time_embed`` are resized when their
    token or frame counts differ from ``num_patches + 1`` / ``num_frames``. The
    result is then loaded with ``strict=True``. A failed load is logged with its
    traceback and not re-raised, so ``model`` keeps its current weights.

    Args:
        model: module to load into; its ``default_cfg`` is used when ``cfg`` is None.
        pretrained_model: path of the Kinetics checkpoint (must be non-empty).
        cfg: dict providing ``url`` and ``classifier``.
        ignore_classifier: must be True; keeping the checkpoint classifier is not supported.
        num_frames: target length of ``time_embed``.
        num_patches: target number of patch tokens in ``pos_embed`` (plus one cls token).
        **kwargs: accepted but unused.

    Returns:
        None. Returns early without loading when ``cfg`` has no usable ``url``.

    Raises:
        AssertionError: if ``pretrained_model`` is empty.
        NotImplementedError: if ``ignore_classifier`` is False.
    """
    if cfg is None:
        cfg = getattr(model, "default_cfg", None)
    # NOTE(review): the weights come from the local ``pretrained_model`` file, yet
    # loading is still skipped when cfg has no "url"; kept for compatibility.
    if cfg is None or "url" not in cfg or not cfg["url"]:
        logging.warning("Pretrained model URL is invalid, using random initialization.")
        return

    assert pretrained_model, "Path to pre-trained Kinetics weights not provided."

    state_dict = load_state_dict(pretrained_model)

    classifier_name = cfg["classifier"]
    if ignore_classifier:
        # Keep the model's own classifier so checkpoint/model class counts may differ.
        model_state_dict = model.state_dict()
        for suffix in (".weight", ".bias"):
            key = classifier_name + suffix
            state_dict[key] = model_state_dict[key]
    else:
        raise NotImplementedError(
            "[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier."
        )

    ## Resizing the positional embeddings in case they don't match
    if "pos_embed" in state_dict and num_patches + 1 != state_dict["pos_embed"].size(1):
        state_dict["pos_embed"] = resize_spatial_embedding(
            state_dict, "pos_embed", num_patches
        )

    ## Resizing time embeddings in case they don't match
    if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1):
        state_dict["time_embed"] = resize_temporal_embedding(
            state_dict, "time_embed", num_frames
        )

    ## Loading the weights (best effort: failures are logged, not raised)
    try:
        model.load_state_dict(state_dict, strict=True)
    except Exception:
        logging.exception("Error in loading Kinetics pre-trained weights.")
    else:
        logging.info("Succeeded in loading Kinetics pre-trained weights.")
def resize_spatial_embedding(state_dict, key, num_patches):
    """Return ``state_dict[key]`` resampled to ``num_patches`` patch tokens plus one.

    The embedding is indexed as a rank-3 tensor (batch, tokens, dim); only batch
    entry 0 is used. Token 0 (the cls position) is kept as-is. The remaining
    tokens are resampled along the token axis with 1-D nearest-neighbour
    interpolation. ``state_dict`` itself is not modified.

    Returns:
        Tensor of shape (1, num_patches + 1, dim).
    """
    source = state_dict[key]
    logging.info(
        f"Resizing spatial position embedding from {source.size(1)} to {num_patches + 1}"
    )
    cls_token = source[0:1, 0:1, :]
    # (1, N, dim) -> (1, dim, N) so interpolation runs over the token axis.
    patch_tokens = source[0:1, 1:, :].transpose(1, 2)
    resampled = F.interpolate(patch_tokens, size=num_patches, mode="nearest")
    return torch.cat([cls_token, resampled.transpose(1, 2)], dim=1)
def resize_temporal_embedding(state_dict, key, num_frames):
    """Return ``state_dict[key]`` resampled to ``num_frames`` positions along axis 1.

    Assumes a rank-3 tensor laid out as (batch, frames, dim). Resampling uses 1-D
    nearest-neighbour interpolation over axis 1. ``state_dict`` is not modified.

    Returns:
        Tensor of shape (batch, num_frames, dim).
    """
    embed = state_dict[key]
    logging.info(
        f"Resizing temporal position embedding from {embed.size(1)} to {num_frames}"
    )
    channels_first = embed.transpose(1, 2)
    resampled = F.interpolate(channels_first, size=num_frames, mode="nearest")
    return resampled.transpose(1, 2)
def detach_variable(inputs):
    """Return a new tuple in which every tensor of ``inputs`` is detached.

    Each tensor is replaced by ``inp.detach()`` with ``requires_grad`` copied from
    the source tensor. Non-tensor entries are passed through unchanged.

    Args:
        inputs: a tuple of tensors and/or other values.

    Returns:
        A tuple of the same length as ``inputs``.

    Raises:
        RuntimeError: if ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Build one message string; passing two args made the message render as a tuple.
        raise RuntimeError(
            "Only tuple of tensors is supported. "
            f"Got Unsupported input type: {type(inputs).__name__}"
        )
    out = []
    for inp in inputs:
        if not isinstance(inp, torch.Tensor):
            out.append(inp)
            continue
        x = inp.detach()
        x.requires_grad = inp.requires_grad
        out.append(x)
    return tuple(out)
def check_backward_validity(inputs):
    """Emit a warning when no tensor in ``inputs`` has ``requires_grad=True``.

    Non-tensor entries are ignored, matching how ``detach_variable`` passes them
    through.

    Args:
        inputs: iterable of tensors and/or other values.
    """
    if not any(
        inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)
    ):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )