import copy
import os
from pathlib import Path
import torch
from timm.models import create_model
from torchvision.models import get_model
from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb
from models.individual_landmark_resnet import IndividualLandmarkResNet
from models.individual_landmark_convnext import IndividualLandmarkConvNext
from models.individual_landmark_vit import IndividualLandmarkViT
from utils import load_state_dict_pdisco


def load_model_arch(args, num_cls):
"""
Function to load the model
:param args: Arguments from the command line
:param num_cls: Number of classes in the dataset
:return:
"""
    if 'resnet' in args.model_arch:
        # Parse the depth from the architecture name (e.g. "resnet101" -> 101) to pick the
        # timm pretrained weight tag; this tag is only used by the timm ResNet branch below.
        num_layers = int(''.join(s for s in args.model_arch if s.isdigit()))
        if num_layers >= 100:
            timm_model_arch = args.model_arch + ".a1h_in1k"
        else:
            timm_model_arch = args.model_arch + ".a1_in1k"
if "resnet" in args.model_arch and args.use_torchvision_resnet_model:
weights = "DEFAULT" if args.pretrained_start_weights else None
base_model = get_model(args.model_arch, weights=weights)
elif "resnet" in args.model_arch and not args.use_torchvision_resnet_model:
        # drop_path (stochastic depth) only matters for training, so it is omitted in eval-only mode
        if args.eval_only:
base_model = create_model(
timm_model_arch,
pretrained=args.pretrained_start_weights,
num_classes=num_cls,
output_stride=args.output_stride,
)
else:
base_model = create_model(
timm_model_arch,
pretrained=args.pretrained_start_weights,
drop_path_rate=args.drop_path,
num_classes=num_cls,
output_stride=args.output_stride,
)
elif "convnext" in args.model_arch:
if args.eval_only:
base_model = create_model(
args.model_arch,
pretrained=args.pretrained_start_weights,
num_classes=num_cls,
output_stride=args.output_stride,
)
else:
base_model = create_model(
args.model_arch,
pretrained=args.pretrained_start_weights,
drop_path_rate=args.drop_path,
num_classes=num_cls,
output_stride=args.output_stride,
)
elif "vit" in args.model_arch:
if args.eval_only:
base_model = create_model(
args.model_arch,
pretrained=args.pretrained_start_weights,
img_size=args.image_size,
)
else:
base_model = create_model(
args.model_arch,
pretrained=args.pretrained_start_weights,
drop_path_rate=args.drop_path,
img_size=args.image_size,
)
vit_patch_size = base_model.patch_embed.proj.kernel_size[0]
if args.image_size % vit_patch_size != 0:
raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}")
    else:
        raise ValueError(f"Model architecture '{args.model_arch}' is not supported.")
return base_model
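
# Attributes read from `args` by load_model_arch:
#   args.model_arch                    backbone name: a "resnet*", "convnext*" or "vit*" identifier
#   args.use_torchvision_resnet_model  bool, use a torchvision ResNet instead of the timm variant
#   args.pretrained_start_weights      bool, start from pretrained backbone weights
#   args.eval_only                     bool, when True drop_path_rate is not passed to timm
#   args.drop_path                     float, stochastic depth rate (training only)
#   args.output_stride                 int, output stride for the timm ResNet and ConvNeXt backbones
#   args.image_size                    int, ViT input resolution, must be divisible by the patch size
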
def init_pdisco_model(base_model, args, num_cls):
"""
Function to initialize the model
:param base_model: Base model
:param args: Arguments from the command line
:param num_cls: Number of classes in the dataset
:return:
"""
# Initialize the network
if 'convnext' in args.model_arch:
sl_channels = base_model.stages[-1].downsample[-1].in_channels
fl_channels = base_model.head.in_features
model = IndividualLandmarkConvNext(base_model, args.num_parts, num_classes=num_cls,
sl_channels=sl_channels, fl_channels=fl_channels,
part_dropout=args.part_dropout, modulation_type=args.modulation_type,
gumbel_softmax=args.gumbel_softmax,
gumbel_softmax_temperature=args.gumbel_softmax_temperature,
gumbel_softmax_hard=args.gumbel_softmax_hard,
modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
noise_variance=args.noise_variance)
elif 'resnet' in args.model_arch:
sl_channels = base_model.layer4[0].conv1.in_channels
fl_channels = base_model.fc.in_features
model = IndividualLandmarkResNet(base_model, args.num_parts, num_classes=num_cls,
sl_channels=sl_channels, fl_channels=fl_channels,
use_torchvision_model=args.use_torchvision_resnet_model,
part_dropout=args.part_dropout, modulation_type=args.modulation_type,
gumbel_softmax=args.gumbel_softmax,
gumbel_softmax_temperature=args.gumbel_softmax_temperature,
gumbel_softmax_hard=args.gumbel_softmax_hard,
modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
noise_variance=args.noise_variance)
elif 'vit' in args.model_arch:
model = IndividualLandmarkViT(base_model, num_landmarks=args.num_parts, num_classes=num_cls,
part_dropout=args.part_dropout,
modulation_type=args.modulation_type, gumbel_softmax=args.gumbel_softmax,
gumbel_softmax_temperature=args.gumbel_softmax_temperature,
gumbel_softmax_hard=args.gumbel_softmax_hard,
modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
noise_variance=args.noise_variance)
    else:
        raise ValueError(f"Model architecture '{args.model_arch}' is not supported.")
return model
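
# init_pdisco_model probes the backbone for its channel sizes: sl_channels comes from the
# input of the last ConvNeXt/ResNet stage and fl_channels from the input features of the
# final classification head; the ViT wrapper is constructed without explicit channel arguments.
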
def load_model_pdisco(args, num_cls):
"""
Function to load the model
:param args: Arguments from the command line
:param num_cls: Number of classes in the dataset
:return:
"""
base_model = load_model_arch(args, num_cls)
model = init_pdisco_model(base_model, args, num_cls)
return model
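
# Usage sketch for load_model_pdisco; the values below are illustrative placeholders,
# not defaults shipped with the training scripts:
#
#   from argparse import Namespace
#   args = Namespace(
#       model_arch="vit_base_patch14_reg4_dinov2.lvd142m",
#       pretrained_start_weights=True, eval_only=True,
#       use_torchvision_resnet_model=False, output_stride=32,
#       drop_path=0.0, image_size=224,
#       num_parts=8, part_dropout=0.0,
#       modulation_type="original", gumbel_softmax=False,
#       gumbel_softmax_temperature=1.0, gumbel_softmax_hard=False,
#       modulation_orth=False, classifier_type="linear",
#       noise_variance=0.0,
#   )
#   model = load_model_pdisco(args, num_cls=200)
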
def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
"""
Function to load the PDiscoFormer model with ViT backbone
:param pretrained: Boolean flag to load the pretrained weights
:param model_dataset: Dataset for which the model is trained
:param k: Number of unsupervised landmarks the model is trained on
:param model_url: URL to load the model weights from
:param img_size: Image size
:param num_cls: Number of classes in the dataset
:return: PDiscoFormer model with ViT backbone
"""
model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
if pretrained:
hub_dir = torch.hub.get_dir()
model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdiscoformer_{model_dataset}")
Path(model_dir).mkdir(parents=True, exist_ok=True)
url_path = model_url + str(k) + "_parts_snapshot_best.pt"
snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
if 'model_state' in snapshot_data:
_, state_dict = load_state_dict_pdisco(snapshot_data)
else:
state_dict = copy.deepcopy(snapshot_data)
model.load_state_dict(state_dict, strict=True)
return model
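
# Minimal usage sketch for the entry point above. When pretrained=True, model_url is used
# as a prefix for f"{k}_parts_snapshot_best.pt"; with pretrained=False this function does
# not download any checkpoint:
#
#   model = pdiscoformer_vit(pretrained=False, model_dataset="cub", k=8,
#                            img_size=224, num_cls=200)
#   model.eval()
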
def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
"""
Function to load the PDiscoNet model with ViT backbone
:param pretrained: Boolean flag to load the pretrained weights
:param model_dataset: Dataset for which the model is trained
:param k: Number of unsupervised landmarks the model is trained on
:param model_url: URL to load the model weights from
:param img_size: Image size
:param num_cls: Number of classes in the dataset
:return: PDiscoNet model with ViT backbone
"""
model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
if pretrained:
hub_dir = torch.hub.get_dir()
model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
Path(model_dir).mkdir(parents=True, exist_ok=True)
url_path = model_url + str(k) + "_parts_snapshot_best.pt"
snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
if 'model_state' in snapshot_data:
_, state_dict = load_state_dict_pdisco(snapshot_data)
else:
state_dict = copy.deepcopy(snapshot_data)
model.load_state_dict(state_dict, strict=True)
    return model


def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
"""
Function to load the PDiscoNet model with ResNet-101 backbone
:param pretrained: Boolean flag to load the pretrained weights
:param model_dataset: Dataset for which the model is trained
:param k: Number of unsupervised landmarks the model is trained on
:param model_url: URL to load the model weights from
:param num_cls: Number of classes in the dataset
:return: PDiscoNet model with ResNet-101 backbone
"""
model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
if pretrained:
hub_dir = torch.hub.get_dir()
model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
Path(model_dir).mkdir(parents=True, exist_ok=True)
url_path = model_url + str(k) + "_parts_snapshot_best.pt"
snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
if 'model_state' in snapshot_data:
_, state_dict = load_state_dict_pdisco(snapshot_data)
else:
state_dict = copy.deepcopy(snapshot_data)
model.load_state_dict(state_dict, strict=True)
return model
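
# All three hub-style entry points above share the same checkpoint handling: the weight file
# is resolved as model_url + f"{k}_parts_snapshot_best.pt", cached under the torch.hub
# directory in a "pdiscoformer_checkpoints" subfolder, and unwrapped with
# load_state_dict_pdisco when the snapshot stores a full training state ('model_state' key)
# rather than a plain state dict.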