garyzgao commited on
Commit
165f2cc
·
verified ·
1 Parent(s): 00cbce4

Uploading the file for model definition.

Browse files
Files changed (1) hide show
  1. model.py +134 -0
model.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import timm
5
+ from typing import Optional
6
+
7
+ class DecoderBlock(nn.Module):
8
+ def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
9
+ super(DecoderBlock, self).__init__()
10
+ self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
11
+ self.conv1 = nn.Conv2d(out_channels + skip_channels, out_channels, kernel_size=3, padding=1)
12
+ self.bn1 = nn.BatchNorm2d(out_channels)
13
+ self.relu1 = nn.ReLU(inplace=True)
14
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
15
+ self.bn2 = nn.BatchNorm2d(out_channels)
16
+ self.relu2 = nn.ReLU(inplace=True)
17
+
18
+ def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor]) -> torch.Tensor:
19
+ x = self.up(x)
20
+ if skip is not None:
21
+ if x.size() != skip.size():
22
+ x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
23
+ x = torch.cat([x, skip], dim=1)
24
+ x = self.conv1(x)
25
+ x = self.bn1(x)
26
+ x = self.relu1(x)
27
+ x = self.conv2(x)
28
+ x = self.bn2(x)
29
+ x = self.relu2(x)
30
+ return x
31
+
32
+ class UNetTimmWithClassification(nn.Module):
33
+ def __init__(self,
34
+ encoder_name: str = 'resnet50',
35
+ encoder_weights: Optional[str] = 'imagenet',
36
+ num_classes_seg: int = 1,
37
+ num_classes_cls: int = 9
38
+ ):
39
+
40
+ super(UNetTimmWithClassification, self).__init__()
41
+
42
+ self.encoder = timm.create_model(
43
+ encoder_name,
44
+ pretrained=(encoder_weights == 'imagenet'),
45
+ features_only=True,
46
+ in_chans=3
47
+ )
48
+
49
+ encoder_channels = self.encoder.feature_info.channels()
50
+ decoder_channels = [2048, 1024, 512, 256]
51
+ # decoder_channels = [512, 256, 128, 64]
52
+
53
+ self.decoder4 = DecoderBlock(
54
+ in_channels=decoder_channels[0],
55
+ skip_channels=encoder_channels[3],
56
+ out_channels=decoder_channels[1]
57
+ )
58
+ self.decoder3 = DecoderBlock(
59
+ in_channels=decoder_channels[1],
60
+ skip_channels=encoder_channels[2],
61
+ out_channels=decoder_channels[2]
62
+ )
63
+ self.decoder2 = DecoderBlock(
64
+ in_channels=decoder_channels[2],
65
+ skip_channels=encoder_channels[1],
66
+ out_channels=decoder_channels[3]
67
+ )
68
+ self.decoder1 = DecoderBlock(
69
+ in_channels=decoder_channels[3],
70
+ skip_channels=encoder_channels[0],
71
+ out_channels=decoder_channels[3]
72
+ )
73
+ self.final_up = nn.ConvTranspose2d(
74
+ in_channels=decoder_channels[-1],
75
+ out_channels=32,
76
+ kernel_size=2,
77
+ stride=2
78
+ )
79
+ self.final_conv_seg = nn.Conv2d(
80
+ in_channels=32,
81
+ out_channels=num_classes_seg,
82
+ kernel_size=1
83
+ )
84
+
85
+ #Cls head
86
+ self.classification_head = nn.Sequential(
87
+ nn.AdaptiveAvgPool2d(1),
88
+ nn.Flatten(),
89
+ nn.Linear(2048, 512),
90
+ # nn.Linear(encoder_channels[-1], 512),
91
+ nn.Dropout(0.2),
92
+ nn.Linear(512, 512),
93
+ nn.ReLU(),
94
+ nn.Linear(512, num_classes_cls)
95
+ )
96
+
97
+ if num_classes_cls > 1:
98
+ self.classification_activation = nn.Softmax(dim=1)
99
+ elif num_classes_cls == 1:
100
+ self.classification_activation = nn.Sigmoid()
101
+ else:
102
+ self.classification_activation = None
103
+
104
+ if self.classification_activation is not None:
105
+ self.classification_head.add_module("activation", self.classification_activation)
106
+
107
+ #Xavier weight initialize
108
+ if encoder_weights == 'xavier':
109
+ self.apply(self.xavier_init_weights)
110
+
111
+ def xavier_init_weights(self, m):
112
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
113
+ nn.init.xavier_uniform_(m.weight)
114
+ if m.bias is not None:
115
+ nn.init.zeros_(m.bias)
116
+ elif isinstance(m, nn.Linear):
117
+ nn.init.xavier_uniform_(m.weight)
118
+ if m.bias is not None:
119
+ nn.init.zeros_(m.bias)
120
+ elif isinstance(m, nn.BatchNorm2d):
121
+ nn.init.ones_(m.weight)
122
+ nn.init.zeros_(m.bias)
123
+
124
+ def forward(self, x: torch.Tensor) -> tuple:
125
+ features = self.encoder(x)
126
+ C0, C1, C2, C3, C4 = features
127
+ cls = self.classification_head(C4)
128
+ D4 = self.decoder4(C4, C3)
129
+ D3 = self.decoder3(D4, C2)
130
+ D2 = self.decoder2(D3, C1)
131
+ D1 = self.decoder1(D2, C0)
132
+ x = self.final_up(D1)
133
+ seg = self.final_conv_seg(x)
134
+ return seg, cls