File size: 2,915 Bytes
8ed2f16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import ResNet, Bottleneck
# from util import util

# model_urls = {
#     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
# }
def copy_state_dict(state_dict, model, strip=None, replace=None):
    """Copy matching entries of ``state_dict`` into ``model`` in place.

    Source keys can be renamed before they are looked up in the model:

    * ``strip`` given, ``replace`` omitted: a leading ``strip`` prefix is
      removed from keys that start with it.
    * ``strip`` and ``replace`` both given: every occurrence of ``strip``
      inside the key is substituted with ``replace``.

    Keys that the model does not have are skipped silently. Entries whose
    size differs from the model's tensor are reported on stdout and skipped.
    After the pass, model keys that received no value are printed.

    Args:
        state_dict: Mapping from parameter name to tensor (or
            ``torch.nn.Parameter``) holding the source values.
        model: Module whose ``state_dict()`` tensors are overwritten.
        strip: Optional prefix / substring used for key renaming.
        replace: Optional replacement text; only used together with ``strip``.

    Returns:
        None. The model is modified in place.
    """
    target = model.state_dict()
    loaded = set()

    for key, value in state_dict.items():
        # Rename the key; the two renaming modes are mutually exclusive.
        if strip is not None:
            if replace is not None:
                key = key.replace(strip, replace)
            elif key.startswith(strip):
                key = key[len(strip):]

        if key not in target:
            continue

        # Work on the raw tensor when the source is wrapped in a Parameter.
        source = value.data if isinstance(value, torch.nn.Parameter) else value
        dest = target[key]
        if source.size() != dest.size():
            print('mismatch:', key, source.size(), dest.size())
            continue

        dest.copy_(source)
        loaded.add(key)

    missing = set(target.keys()) - loaded
    if missing:
        print("missing keys in state_dict:", missing)

class ResNeXt50(nn.Module):
    """ResNeXt-50 (32x4d) backbone with a classifier and a spatial feature branch.

    The backbone is a torchvision ``ResNet`` built from ``Bottleneck`` blocks
    with ``groups=32`` and ``width_per_group=4``. On top of it the module adds:

    * ``conv1x1``: a 1x1 convolution mapping the ``layer4`` feature map from
      ``512 * Bottleneck.expansion`` channels down to 512 channels.
    * ``fc``: a linear classifier from the pooled ``512 * Bottleneck.expansion``
      feature vector to ``opt.data.num_classes`` scores.

    Args:
        opt: Configuration object. The code reads ``opt.data.num_classes``,
            ``opt.data.img_size``, ``opt.model.output_nc`` and
            ``opt.num_inputs`` from it.
    """

    def __init__(self, opt):
        super().__init__()
        self.model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)
        self.opt = opt
        # self.reduced_id_dim = opt.reduced_id_dim
        self.conv1x1 = nn.Conv2d(512 * Bottleneck.expansion, 512, kernel_size=1, padding=0)
        self.fc = nn.Linear(512 * Bottleneck.expansion, opt.data.num_classes)
        # self.fc_pre = nn.Sequential(nn.Linear(512 * Bottleneck.expansion, self.reduced_id_dim), nn.ReLU())

    def load_pretrain(self, load_path):
        """Load backbone weights from a checkpoint file.

        Only ``self.model`` (the ResNet backbone) is populated; ``conv1x1``
        and ``fc`` keep their current values. Entries that do not match the
        backbone by name or size are skipped by ``copy_state_dict``.

        Args:
            load_path: Path (or file-like object) accepted by ``torch.load``;
                the loaded object is expected to be a name -> tensor mapping.
        """
        # Deserialize onto the CPU so that checkpoints saved from a GPU can be
        # loaded on machines without that device. copy_state_dict transfers
        # the values with Tensor.copy_, which moves them to whatever device
        # the backbone parameters live on.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        check_point = torch.load(load_path, map_location='cpu')
        copy_state_dict(check_point, self.model)

    def forward_feature(self, input):
        """Run the backbone and return pooled and spatial features.

        Args:
            input: Image batch fed to the backbone's first convolution.

        Returns:
            A tuple ``(net, x)`` where ``net`` is the ``layer4`` output after
            the backbone's average pooling, flattened from dim 1, and ``x`` is
            the ``layer4`` output passed through ``conv1x1`` (spatial layout
            preserved).
        """
        x = self.model.conv1(input)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        # Both outputs branch off the same layer4 feature map.
        net = self.model.avgpool(x)
        net = torch.flatten(net, 1)
        x = self.conv1x1(x)
        # x = self.fc_pre(x)
        return net, x

    def forward(self, input):
        """Encode groups of ``opt.num_inputs`` images and classify each group.

        The input is reshaped to individual images of shape
        ``(opt.model.output_nc, opt.data.img_size, opt.data.img_size)``,
        encoded, regrouped into sets of ``opt.num_inputs`` and averaged over
        the set dimension.

        Args:
            input: Tensor that can be viewed as
                ``(-1, output_nc, img_size, img_size)``.

        Returns:
            ``([net, x], cls_scores)`` where

            * ``net`` is the pooled feature vector with
              ``512 * Bottleneck.expansion`` channels, averaged over the group;
            * ``x`` is the ``conv1x1`` feature map pooled to ``7 x 7`` with
              512 channels, averaged over the group;
            * ``cls_scores`` is the output of ``fc`` applied to ``net``.
        """
        input_batch = input.view(-1, self.opt.model.output_nc, self.opt.data.img_size, self.opt.data.img_size)
        net, x = self.forward_feature(input_batch)

        # Regroup per-image features into (group, num_inputs, ...) and average
        # across the images of each group.
        net = net.view(-1, self.opt.num_inputs, 512 * Bottleneck.expansion)
        x = F.adaptive_avg_pool2d(x, (7, 7))
        x = x.view(-1, self.opt.num_inputs, 512, 7, 7)
        net = torch.mean(net, 1)
        x = torch.mean(x, 1)
        cls_scores = self.fc(net)

        return [net, x], cls_scores