"""ResNeXt-50 (32x4d) feature extractor / classifier.

The trunk is torchvision's ``ResNet`` instantiated with ResNeXt-50
hyper-parameters (``groups=32``, ``width_per_group=4``).  ``forward``
averages features over ``opt.num_inputs`` views of each sample and
classifies the pooled global feature.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import ResNet, Bottleneck


def copy_state_dict(state_dict, model, strip=None, replace=None):
    """Copy matching entries of ``state_dict`` into ``model``.

    Args:
        state_dict: source mapping of parameter name -> tensor.
        model: target ``nn.Module`` receiving the copies (in place).
        strip: optional string.  When ``replace`` is ``None`` it is
            removed from the front of each source key; otherwise every
            occurrence of ``strip`` in the key is substituted.
        replace: replacement text used together with ``strip``.

    Keys absent from the target are skipped silently; size-mismatched
    tensors are reported and skipped; target keys that received no copy
    are printed at the end.
    """
    tgt_state = model.state_dict()
    copied_names = set()
    for name, param in state_dict.items():
        if strip is not None and replace is None and name.startswith(strip):
            name = name[len(strip):]
        if strip is not None and replace is not None:
            name = name.replace(strip, replace)
        if name not in tgt_state:
            continue
        if isinstance(param, torch.nn.Parameter):
            # Copy the raw tensor data, not the Parameter wrapper.
            param = param.data
        if param.size() != tgt_state[name].size():
            print('mismatch:', name, param.size(), tgt_state[name].size())
            continue
        tgt_state[name].copy_(param)
        copied_names.add(name)

    missing = set(tgt_state.keys()) - copied_names
    if missing:
        print("missing keys in state_dict:", missing)


class ResNeXt50(nn.Module):
    """ResNeXt-50 backbone with a 1x1 reduction conv and a classifier head.

    Expects ``opt`` to provide ``model.output_nc``, ``data.img_size``,
    ``data.num_classes`` and ``num_inputs`` (views per sample).
    """

    def __init__(self, opt):
        super(ResNeXt50, self).__init__()
        # torchvision ResNet with ResNeXt-50 (32x4d) hyper-parameters.
        self.model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32,
                            width_per_group=4)
        self.opt = opt
        # 1x1 conv reduces the layer4 channel count (512 * expansion) to 512.
        self.conv1x1 = nn.Conv2d(512 * Bottleneck.expansion, 512,
                                 kernel_size=1, padding=0)
        self.fc = nn.Linear(512 * Bottleneck.expansion, opt.data.num_classes)

    def load_pretrain(self, load_path):
        """Load pretrained trunk weights from ``load_path``.

        ``map_location='cpu'`` makes GPU-saved checkpoints loadable on
        CPU-only machines; tensors are then copied into the model's own
        (possibly GPU) parameters by ``copy_state_dict``.
        """
        check_point = torch.load(load_path, map_location='cpu')
        copy_state_dict(check_point, self.model)

    def forward_feature(self, input):
        """Run the trunk on a batch of images.

        Returns:
            net: globally average-pooled feature, flattened to
                ``(N, 512 * expansion)``.
            x: layer4 feature map passed through ``conv1x1``, i.e.
                512 channels at layer4's spatial resolution.
        """
        x = self.model.conv1(input)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        net = self.model.avgpool(x)
        net = torch.flatten(net, 1)
        x = self.conv1x1(x)
        return net, x

    def forward(self, input):
        """Classify a sample made of ``opt.num_inputs`` image views.

        ``input`` is reshaped to individual images, features are
        extracted per view and averaged across views.

        Returns:
            ``[net, x], cls_scores`` where ``net`` is the view-averaged
            global feature (channel-only), ``x`` is the view-averaged
            conv1x1 feature map pooled to 7x7 spatial size, and
            ``cls_scores`` are the classifier logits over ``net``.
        """
        input_batch = input.view(-1, self.opt.model.output_nc,
                                 self.opt.data.img_size,
                                 self.opt.data.img_size)
        net, x = self.forward_feature(input_batch)
        net = net.view(-1, self.opt.num_inputs, 512 * Bottleneck.expansion)
        x = F.adaptive_avg_pool2d(x, (7, 7))
        x = x.view(-1, self.opt.num_inputs, 512, 7, 7)
        # Average features over the views of each sample.
        net = torch.mean(net, 1)
        x = torch.mean(x, 1)
        cls_scores = self.fc(net)
        return [net, x], cls_scores