Create new file
model_video.py  ADDED  (+297 −0)
@@ -0,0 +1,297 @@
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.optim import Adam
import numpy
from einops import rearrange
import time
from transformer import Transformer
from Intra_MLP import index_points, knn_l2

# VGG-16 configuration: numbers are conv output channels, 'M' is max-pooling
base = {'vgg': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']}

# build VGG-16 as a flat layer list
def vgg(cfg, i=3, batch_norm=True):
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return layers
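
# Layer-index note: with batch_norm=True every conv contributes
# [Conv2d, BatchNorm2d, ReLU], so in vgg(base['vgg']) the indices
# 13, 23, 33, 43 used by Model.extract below land on the four MaxPool2d
# layers producing pool2..pool5. A quick check, e.g.:
#
#     layers = vgg(base['vgg'])
#     assert all(isinstance(layers[k], nn.MaxPool2d) for k in (13, 23, 33, 43))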

def hsp(in_channel, out_channel):
    layers = nn.Sequential(nn.Conv2d(in_channel, out_channel, 1, 1),
                           nn.ReLU())
    return layers

def cls_modulation_branch(in_channel, hidden_channel):
    layers = nn.Sequential(nn.Linear(in_channel, hidden_channel),
                           nn.ReLU())
    return layers

def cls_branch(hidden_channel, class_num):
    layers = nn.Sequential(nn.Linear(hidden_channel, class_num),
                           nn.Sigmoid())
    return layers

def intra():
    layers = []
    layers += [nn.Conv2d(512, 512, 1, 1)]
    layers += [nn.Sigmoid()]
    return layers

def concat_r():
    layers = []
    layers += [nn.Conv2d(512, 512, 1, 1)]
    layers += [nn.ReLU()]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.ReLU()]
    layers += [nn.ConvTranspose2d(512, 512, 4, 2, 1)]
    return layers

def concat_1():
    layers = []
    layers += [nn.Conv2d(512, 512, 1, 1)]
    layers += [nn.ReLU()]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.ReLU()]
    return layers

def mask_branch():
    layers = []
    layers += [nn.Conv2d(512, 2, 3, 1, 1)]
    layers += [nn.ConvTranspose2d(2, 2, 8, 4, 2)]
    layers += [nn.Softmax2d()]
    return layers

def incr_channel():
    # project pool2..pool5 features (128/256/512/512 channels) to 512
    layers = []
    layers += [nn.Conv2d(128, 512, 3, 1, 1)]
    layers += [nn.Conv2d(256, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    return layers

def incr_channel2():
    layers = []
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.ReLU()]
    return layers

def norm(x, dim):
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    normed = x / torch.sqrt(squared_norm)
    return normed

def fuse_hsp(x, p, group_size=5):
    # tile each clip-level vector x[i] over its group of frames, then
    # broadcast it over the spatial grid of p
    t = torch.zeros(group_size, x.size(1))
    for i in range(x.size(0)):
        tmp = x[i, :]
        if i == 0:
            nx = tmp.expand_as(t)
        else:
            nx = torch.cat([nx, tmp.expand_as(t)], dim=0)
    nx = nx.view(x.size(0) * group_size, x.size(1), 1, 1)
    y = nx.expand_as(p)
    return y
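
# Shape sketch for fuse_hsp (hypothetical sizes, group_size=5):
#
#     x = torch.randn(2, 512)            # 2 clip-level vectors
#     p = torch.randn(10, 512, 28, 28)   # the same 2 clips, 5 frames each
#     y = fuse_hsp(x, p)                 # y.shape == p.shape
#     # y[0:5] broadcast x[0] over 28x28; y[5:10] broadcast x[1]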


class Model(nn.Module):
    def __init__(self, device, base, incr_channel, incr_channel2, hsp1, hsp2,
                 cls_m, cls, concat_r, concat_1, mask_branch, intra, demo_mode=False):
        super(Model, self).__init__()
        self.base = nn.ModuleList(base)
        self.sp1 = hsp1
        self.sp2 = hsp2
        self.cls_m = cls_m
        self.cls = cls
        self.incr_channel1 = nn.ModuleList(incr_channel)
        self.incr_channel2 = nn.ModuleList(incr_channel2)
        # concat4/3/2 wrap the same layer list, so these three decoder
        # stages share their parameters
        self.concat4 = nn.ModuleList(concat_r)
        self.concat3 = nn.ModuleList(concat_r)
        self.concat2 = nn.ModuleList(concat_r)
        self.concat1 = nn.ModuleList(concat_1)
        self.mask = nn.ModuleList(mask_branch)
        # indices of the four max-pool layers (pool2..pool5) in self.base
        self.extract = [13, 23, 33, 43]
        self.device = device
        self.group_size = 5
        self.intra = nn.ModuleList(intra)
        self.transformer_1 = Transformer(512, 4, 4, 782, group=self.group_size)
        self.transformer_2 = Transformer(512, 4, 4, 782, group=self.group_size)
        self.demo_mode = demo_mode

    def forward(self, x):
        # backbone; p collects the pool2..pool5 feature maps
        p = list()
        for k in range(len(self.base)):
            x = self.base[k](x)
            if k in self.extract:
                p.append(x)

        # project every scale to 512 channels; run the two transformers
        # on the two coarsest scales
        newp = list()
        newp_T = list()
        for k in range(len(p)):
            np = self.incr_channel1[k](p[k])
            np = self.incr_channel2[k](np)
            newp.append(self.incr_channel2[4](np))
            if k == 3:
                tmp_newp_T3 = self.transformer_1(newp[k])
                newp_T.append(tmp_newp_T3)
            if k == 2:
                newp_T.append(self.transformer_2(newp[k]))
            if k < 2:
                newp_T.append(None)

        # intra-MLP: kNN grouping over the pool5 feature map
        point = newp[3].view(newp[3].size(0), newp[3].size(1), -1)
        point = point.permute(0, 2, 1)

        idx = knn_l2(self.device, point, 4, 1)
        feat = idx
        new_point = index_points(self.device, point, idx)

        group_point = new_point.permute(0, 3, 2, 1)
        group_point = self.intra[0](group_point)
        group_point = torch.max(group_point, 2)[0]  # [B, D', S]

        # requires a 7x7 pool5, i.e. a 224x224 input
        intra_mask = group_point.view(group_point.size(0), group_point.size(1), 7, 7)
        intra_mask = intra_mask + newp[3]

        spa_mask = self.intra[1](intra_mask)

        # second-order pooling on pool5, within and across the frame group
        x = newp[3]
        x = self.sp1(x)
        x = x.view(-1, x.size(1), x.size(2) * x.size(3))
        x = torch.bmm(x, x.transpose(1, 2))
        x = x.view(-1, x.size(1) * x.size(2))
        x = x.view(x.size(0) // self.group_size, x.size(1), -1, 1)
        x = self.sp2(x)
        x = x.view(-1, x.size(1), x.size(2) * x.size(3))
        x = torch.bmm(x, x.transpose(1, 2))
        x = x.view(-1, x.size(1) * x.size(2))

        # cls pred
        cls_modulated_vector = self.cls_m(x)
        cls_pred = self.cls(cls_modulated_vector)

        # semantic and spatial modulators
        g1 = fuse_hsp(cls_modulated_vector, newp[0], self.group_size)
        g2 = fuse_hsp(cls_modulated_vector, newp[1], self.group_size)
        g3 = fuse_hsp(cls_modulated_vector, newp[2], self.group_size)
        g4 = fuse_hsp(cls_modulated_vector, newp[3], self.group_size)

        spa_1 = F.interpolate(spa_mask, size=[g1.size(2), g1.size(3)], mode='bilinear', align_corners=False)
        spa_1 = spa_1.expand_as(g1)
        spa_2 = F.interpolate(spa_mask, size=[g2.size(2), g2.size(3)], mode='bilinear', align_corners=False)
        spa_2 = spa_2.expand_as(g2)
        spa_3 = F.interpolate(spa_mask, size=[g3.size(2), g3.size(3)], mode='bilinear', align_corners=False)
        spa_3 = spa_3.expand_as(g3)
        spa_4 = F.interpolate(spa_mask, size=[g4.size(2), g4.size(3)], mode='bilinear', align_corners=False)
        spa_4 = spa_4.expand_as(g4)

        # top-down decoding: modulate each scale, then upsample and merge
        y4 = newp_T[3] * g4 + spa_4
        for k in range(len(self.concat4)):
            y4 = self.concat4[k](y4)

        y3 = newp_T[2] * g3 + spa_3
        for k in range(len(self.concat3)):
            y3 = self.concat3[k](y3)
            if k == 1:
                y3 = y3 + y4

        y2 = newp[1] * g2 + spa_2
        for k in range(len(self.concat2)):
            y2 = self.concat2[k](y2)
            if k == 1:
                y2 = y2 + y3

        y1 = newp[0] * g1 + spa_1
        for k in range(len(self.concat1)):
            y1 = self.concat1[k](y1)
            if k == 1:
                y1 = y1 + y2
        y = y1

        if self.demo_mode:
            # pairwise cosine similarities of the fused features, for visualization
            tmp = F.interpolate(y1, size=[14, 14], mode='bilinear', align_corners=False)
            tmp = tmp.permute(0, 2, 3, 1).contiguous().reshape(tmp.shape[0] * tmp.shape[2] * tmp.shape[3], tmp.shape[1])
            tmp = tmp / torch.norm(tmp, p=2, dim=1).unsqueeze(1)
            feat2 = tmp @ tmp.t()
            feat = F.interpolate(y, size=[14, 14], mode='bilinear', align_corners=False)

        # decoder: mask head upsamples back to the input resolution
        for k in range(len(self.mask)):
            y = self.mask[k](y)
        mask_pred = y[:, 0, :, :]

        if self.demo_mode:
            return cls_pred, mask_pred, feat, feat2
        else:
            return cls_pred, mask_pred


# build the whole network
def build_model(device, demo_mode=False):
    return Model(device,
                 vgg(base['vgg']),
                 incr_channel(),
                 incr_channel2(),
                 hsp(512, 64),
                 hsp(64 ** 2, 32),
                 cls_modulation_branch(32 ** 2, 512),
                 cls_branch(512, 78),
                 concat_r(),
                 concat_1(),
                 mask_branch(),
                 intra(),
                 demo_mode)

# weight init
def xavier(param):
    init.xavier_uniform_(param)

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
    elif isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)

'''import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6'
gpu_id = 'cuda:0'
device = torch.device(gpu_id)
nt = build_model(device).to(device)
it = 2
bs = 1
gs = 5
sum = 0
with torch.no_grad():
    for i in range(it):
        # 224x224 input so that pool5 is 7x7, as forward() requires
        A = torch.rand(bs * gs, 3, 224, 224).cuda()
        A = A * 2 - 1
        start = time.time()
        nt(A)
        sum += time.time() - start
print(sum / bs / gs / it)'''
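
Usage note: a minimal sketch of driving this file (assumptions: the sibling
transformer.py and Intra_MLP.py modules from the same repo are importable, and
a 224x224 input is used so that forward()'s view(..., 7, 7) on pool5 holds):

    import torch
    from model_video import build_model, weights_init

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = build_model(device).to(device)
    net.apply(weights_init)  # xavier init for convs, constants for batch norms
    net.eval()

    # one clip = group_size (5) frames, scaled to [-1, 1]
    frames = torch.rand(5, 3, 224, 224, device=device) * 2 - 1
    with torch.no_grad():
        cls_pred, mask_pred = net(frames)
    print(cls_pred.shape)   # torch.Size([1, 78])  per-clip sigmoid scores
    print(mask_pred.shape)  # torch.Size([5, 224, 224])  per-frame masks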