cheng-hust committed on
Commit 7ac818a · verified
1 Parent(s): d21a218

Upload rtdetr_decoder.py

Files changed (1)
  1. src/zoo/rtdetr/rtdetr_decoder.py +664 -0
src/zoo/rtdetr/rtdetr_decoder.py ADDED
@@ -0,0 +1,664 @@
1
+ """by lyuwenyu
2
+ """
3
+
4
+ import math
5
+ import copy
6
+ from collections import OrderedDict
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.nn.init as init
13
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
14
+ from torch.nn.parameter import Parameter
15
+ import torch.linalg
16
+
17
+ from .denoising import get_contrastive_denoising_training_group
18
+ from .utils import deformable_attention_core_func, get_activation, inverse_sigmoid
19
+ from .utils import bias_init_with_prob
20
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
21
+
22
+ from src.core import register
23
+
24
+ import numpy as np
25
+
26
+ import scipy.linalg as sl
27
+
28
+ __all__ = ['RTDETRTransformer']
29
+
30
+
31
+
32
+ class MLP(nn.Module):
33
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'):
34
+ super().__init__()
35
+ self.num_layers = num_layers
36
+ h = [hidden_dim] * (num_layers - 1)
37
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
38
+ self.act = nn.Identity() if act is None else get_activation(act)
39
+
40
+ def forward(self, x):
41
+ for i, layer in enumerate(self.layers):
42
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
43
+ return x
44
+
45
+
46
+ class CoPE(nn.Module):
47
+ def __init__(self,npos_max,head_dim):
48
+ super(CoPE, self).__init__()
49
+ self.npos_max = npos_max #?
50
+ self.pos_emb = nn.parameter.Parameter(torch.zeros(1,head_dim,npos_max))
51
+
52
+ def forward(self,query,attn_logits):
53
+ #compute positions
54
+ gates = torch.sigmoid(attn_logits) #sig(qk)
55
+ pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
56
+ pos = pos.clamp(max=self.npos_max-1)
57
+ #interpolate from integer positions
58
+ pos_ceil = pos.ceil().long()
59
+ pos_floor = pos.floor().long()
60
+ logits_int = torch.matmul(query,self.pos_emb)
61
+ logits_ceil = logits_int.gather(-1,pos_ceil)
62
+ logits_floor = logits_int.gather(-1,pos_floor)
63
+ w = pos-pos_floor
64
+ return logits_ceil*w+logits_floor*(1-w)
65
+
66
+
67
+
68
+
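The CoPE block above (Contextual Position Encoding) turns sigmoid-gated attention logits into soft key positions and linearly interpolates a learned embedding table of size `npos_max`. A minimal usage sketch, assuming the `CoPE` class defined above is in scope; the batch/query/model sizes are illustrative assumptions:

```python
import torch

B, Q, D = 2, 300, 256                        # assumed batch, query count, model dim
tgt = torch.randn(B, Q, D)                   # decoder queries
qk = torch.bmm(tgt, tgt.transpose(-1, -2))   # content attention logits, [B, Q, Q]

cope = CoPE(npos_max=24, head_dim=D)         # same sizes as the cope='24' branch below
pos_logits = cope(tgt, qk)                   # contextual position logits, [B, Q, Q]
assert pos_logits.shape == qk.shape          # added to the content logits before softmax
```

Because the gates are sigmoids of the raw logits, the computed positions are context-dependent rather than fixed integer offsets.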
69
+ class MSDeformableAttention(nn.Module):
70
+ def __init__(self, embed_dim=256, num_heads=8, num_levels=4, num_points=4,):
71
+ """
72
+ Multi-Scale Deformable Attention Module
73
+ """
74
+ super(MSDeformableAttention, self).__init__()
75
+ self.embed_dim = embed_dim
76
+ self.num_heads = num_heads
77
+ self.num_levels = num_levels
78
+ self.num_points = num_points
79
+ self.total_points = num_heads * num_levels * num_points
80
+
81
+ self.head_dim = embed_dim // num_heads
82
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
83
+
84
+ self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2,)
85
+ self.attention_weights = nn.Linear(embed_dim, self.total_points)
86
+ self.value_proj = nn.Linear(embed_dim, embed_dim)
87
+ self.output_proj = nn.Linear(embed_dim, embed_dim)
88
+
89
+ self.ms_deformable_attn_core = deformable_attention_core_func
90
+
91
+ self._reset_parameters()
92
+
93
+
94
+ def _reset_parameters(self):
95
+ # sampling_offsets
96
+ init.constant_(self.sampling_offsets.weight, 0)
97
+ thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
98
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
99
+ grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
100
+ grid_init = grid_init.reshape(self.num_heads, 1, 1, 2).tile([1, self.num_levels, self.num_points, 1])
101
+ scaling = torch.arange(1, self.num_points + 1, dtype=torch.float32).reshape(1, 1, -1, 1)
102
+ grid_init *= scaling
103
+ self.sampling_offsets.bias.data[...] = grid_init.flatten()
104
+
105
+ # attention_weights
106
+ init.constant_(self.attention_weights.weight, 0)
107
+ init.constant_(self.attention_weights.bias, 0)
108
+
109
+ # proj
110
+ init.xavier_uniform_(self.value_proj.weight)
111
+ init.constant_(self.value_proj.bias, 0)
112
+ init.xavier_uniform_(self.output_proj.weight)
113
+ init.constant_(self.output_proj.bias, 0)
114
+
115
+
116
+ def forward(self,
117
+ query,
118
+ reference_points,
119
+ value,
120
+ value_spatial_shapes,
121
+ value_mask=None):
122
+ """
123
+ Args:
124
+ query (Tensor): [bs, query_length, C]
125
+ reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
126
+ bottom-right (1, 1), including padding area
127
+ value (Tensor): [bs, value_length, C]
128
+ value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
129
+ value_level_start_index (List): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
130
+ value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
131
+
132
+ Returns:
133
+ output (Tensor): [bs, Length_{query}, C]
134
+ """
135
+ bs, Len_q = query.shape[:2]
136
+ Len_v = value.shape[1]
137
+
138
+ value = self.value_proj(value)
139
+ if value_mask is not None:
140
+ value_mask = value_mask.to(value.dtype).unsqueeze(-1)  # cast mask to value dtype (torch tensors have no .astype)
141
+ value *= value_mask
142
+ value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)
143
+
144
+ sampling_offsets = self.sampling_offsets(query).reshape(
145
+ bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2)
146
+ attention_weights = self.attention_weights(query).reshape(
147
+ bs, Len_q, self.num_heads, self.num_levels * self.num_points)
148
+ attention_weights = F.softmax(attention_weights, dim=-1).reshape(
149
+ bs, Len_q, self.num_heads, self.num_levels, self.num_points)
150
+
151
+ if reference_points.shape[-1] == 2:
152
+ offset_normalizer = torch.tensor(value_spatial_shapes)
153
+ offset_normalizer = offset_normalizer.flip([1]).reshape(
154
+ 1, 1, 1, self.num_levels, 1, 2)
155
+ sampling_locations = reference_points.reshape(
156
+ bs, Len_q, 1, self.num_levels, 1, 2
157
+ ) + sampling_offsets / offset_normalizer
158
+ elif reference_points.shape[-1] == 4:
159
+ sampling_locations = (
160
+ reference_points[:, :, None, :, None, :2] + sampling_offsets /
161
+ self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5)
162
+ else:
163
+ raise ValueError(
164
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".
165
+ format(reference_points.shape[-1]))
166
+
167
+ output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights)
168
+
169
+ output = self.output_proj(output)
170
+
171
+ return output
172
+
173
+
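The forward docstring above fixes the expected shapes; the following smoke-test sketch exercises the 4-d reference-point path with made-up sizes. It assumes `MSDeformableAttention` (and its `deformable_attention_core_func` dependency from this repo) is in scope:

```python
import torch

bs, embed_dim, num_levels = 2, 256, 3
spatial_shapes = [[32, 32], [16, 16], [8, 8]]        # (H, W) per feature level
len_v = sum(h * w for h, w in spatial_shapes)        # 1344 flattened value tokens
len_q = 300                                          # decoder queries

attn = MSDeformableAttention(embed_dim=embed_dim, num_heads=8,
                             num_levels=num_levels, num_points=4)
query = torch.randn(bs, len_q, embed_dim)
value = torch.randn(bs, len_v, embed_dim)
reference_points = torch.rand(bs, len_q, num_levels, 4)   # (cx, cy, w, h) in [0, 1]

out = attn(query, reference_points, value, spatial_shapes)
assert out.shape == (bs, len_q, embed_dim)
```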
174
+ class TransformerDecoderLayer(nn.Module):
175
+ def __init__(self,
176
+ d_model=256,
177
+ n_head=8,
178
+ dim_feedforward=1024,
179
+ dropout=0.,
180
+ activation="relu",
181
+ n_levels=4,
182
+ n_points=4,
183
+ cope='none',):
184
+ super(TransformerDecoderLayer, self).__init__()
185
+
186
+ # self attention
187
+ self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
188
+ self.dropout1 = nn.Dropout(dropout)
189
+ self.norm1 = nn.LayerNorm(d_model)
190
+
191
+ # cross attention
192
+ self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points)
193
+ self.dropout2 = nn.Dropout(dropout)
194
+ self.norm2 = nn.LayerNorm(d_model)
195
+
196
+ # ffn
197
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
198
+ self.activation = getattr(F, activation)
199
+ self.dropout3 = nn.Dropout(dropout)
200
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
201
+ self.dropout4 = nn.Dropout(dropout)
202
+ self.norm3 = nn.LayerNorm(d_model)
203
+
204
+ if cope == '24':
205
+ self.cope = CoPE(24,d_model)
206
+ elif cope == '12':
207
+ self.cope = CoPE(12,d_model)
208
+ else:
209
+ self.cope = None
210
+
211
+ # self._reset_parameters()
212
+
213
+ # def _reset_parameters(self):
214
+ # linear_init_(self.linear1)
215
+ # linear_init_(self.linear2)
216
+ # xavier_uniform_(self.linear1.weight)
217
+ # xavier_uniform_(self.linear2.weight)
218
+
219
+ def with_pos_embed(self, tensor, pos):
220
+ return tensor if pos is None else tensor + pos
221
+
222
+ def forward_ffn(self, tgt):
223
+ return self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
224
+
225
+ def forward(self,
226
+ tgt,
227
+ reference_points,
228
+ memory,
229
+ memory_spatial_shapes,
230
+ memory_level_start_index,
231
+ attn_mask=None,
232
+ memory_mask=None,
233
+ query_pos_embed=None):
234
+ # self attention
235
+ #print(query_pos_embed.shape)
236
+ #qk = torch.bmm (tgt ,tgt.transpose(-1 ,-2))
237
+ # mask = torch.tril(torch.ones_like(qk),diagonal=0)
238
+ # mask = torch.log(mask)
239
+ # query_pos_embed = self.cope(tgt,qk) #position_embedding
240
+
241
+
242
+ # n_tgt = tgt.cpu().detach().numpy()
243
+
244
+ #itgt = tgt.new_tensor(np.array([sl.pinv(i) for i in n_tgt])) #inv_tgt
245
+
246
+ # with torch.no_grad():
247
+ # try:
248
+ # itgt = torch.linalg.pinv(tgt)
249
+ # except:
250
+ # print('wrong!!')
251
+ # itgt = torch.pinverse(tgt.detach().cpu()).cuda()
252
+ # print('qk:',qk.shape)
253
+ # print('tgt:',tgt.shape)
254
+ # print((query_pos_embed@itgt.transpose(-1,-2)).shape)
255
+ # print('ik:',itgt.shape)
256
+
257
+ # print(torch.round(itgt@tgt))
258
+ # print(query_pos_embed@itgt.transpose(-1,-2))
259
+
260
+ # k = tgt
261
+ # q = tgt + (query_pos_embed@itgt.transpose(-1,-2))
262
+
263
+ # print((q@(k.transpose(-1,-2))-query_pos_embed))
264
+
265
+ # if attn_mask is not None:
266
+ # attn_mask = torch.where(
267
+ # attn_mask.to(torch.bool),
268
+ # torch.zeros_like(attn_mask),
269
+ # torch.full_like(attn_mask, float('-inf'), dtype=tgt.dtype))
270
+
271
+ if self.cope is None:
272
+ q = k = self.with_pos_embed(tgt, query_pos_embed)
273
+ else:
274
+ qk = torch.bmm (tgt ,tgt.transpose(-1 ,-2))
275
+ query_pos_embed = self.cope(tgt,qk)
276
+ with torch.no_grad():
277
+ try:
278
+ itgt = torch.linalg.pinv(tgt)
279
+ except Exception:  # torch.linalg.pinv can fail on ill-conditioned batches
280
+ print('torch.linalg.pinv failed; falling back to CPU pinverse')
281
+ itgt = torch.pinverse(tgt.detach().cpu()).to(tgt.device)  # keep the fallback device-agnostic
282
+ k = tgt
283
+ q = tgt + (query_pos_embed@itgt.transpose(-1,-2))  # so q @ k^T adds the CoPE logits to tgt @ tgt^T
284
+
285
+
286
+ tgt2, _ = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
287
+ tgt = tgt + self.dropout1(tgt2)
288
+ tgt = self.norm1(tgt)
289
+
290
+ # cross attention
291
+ if self.cope:
292
+ tgt2 = self.cross_attn(\
293
+ self.with_pos_embed(tgt, query_pos_embed@itgt.transpose(-1,-2)),  # CoPE logits mapped back through pinv(tgt); alt: self.with_pos_embed(tgt, query_pos_embed)
294
+ reference_points,
295
+ memory,
296
+ memory_spatial_shapes,
297
+ memory_mask)
298
+ else:
299
+ tgt2 = self.cross_attn(\
300
+ self.with_pos_embed(tgt, query_pos_embed),
301
+ reference_points,
302
+ memory,
303
+ memory_spatial_shapes,
304
+ memory_mask)
305
+
306
+ tgt = tgt + self.dropout2(tgt2)
307
+ tgt = self.norm2(tgt)
308
+
309
+ # ffn
310
+ tgt2 = self.forward_ffn(tgt)
311
+ tgt = tgt + self.dropout4(tgt2)
312
+ tgt = self.norm3(tgt)
313
+
314
+ return tgt
315
+
316
+
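In the CoPE branch of `TransformerDecoderLayer.forward`, the key stays `tgt` while the query is shifted by the position logits mapped back to feature space through `pinv(tgt)`, so the self-attention scores become approximately the content logits plus the CoPE logits: `q @ k^T ≈ tgt @ tgt^T + query_pos_embed`. The identity is exact only when `tgt @ pinv(tgt)` is the identity, i.e. when `tgt` has full row rank. A small numeric check of that reasoning, with sizes chosen as assumptions so the rank condition holds:

```python
import torch

torch.manual_seed(0)
B, Q, D = 2, 64, 256                       # Q <= D so a random tgt has full row rank
tgt = torch.randn(B, Q, D)                 # decoder queries
P = torch.randn(B, Q, Q)                   # stand-in for the CoPE position logits

itgt = torch.linalg.pinv(tgt)              # [B, D, Q]
q = tgt + P @ itgt.transpose(-1, -2)       # query shifted by P mapped into feature space
k = tgt

scores = q @ k.transpose(-1, -2)           # what self-attention will score
target = tgt @ tgt.transpose(-1, -2) + P   # content logits plus position logits
print((scores - target).abs().max().item())   # small float32 round-off
```

In the layer above, Q is `num_queries` (300) and D is `hidden_dim` (256), so `tgt @ pinv(tgt)` is only a rank-256 projection and the decomposition holds approximately rather than exactly.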
317
+ class TransformerDecoder(nn.Module):
318
+ def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
319
+ super(TransformerDecoder, self).__init__()
320
+ self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
321
+ self.hidden_dim = hidden_dim
322
+ self.num_layers = num_layers
323
+ self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
324
+
325
+ def forward(self,
326
+ tgt,
327
+ ref_points_unact,
328
+ memory,
329
+ memory_spatial_shapes,
330
+ memory_level_start_index,
331
+ bbox_head,
332
+ score_head,
333
+ query_pos_head,
334
+ attn_mask=None,
335
+ memory_mask=None):
336
+ output = tgt
337
+ dec_out_bboxes = []
338
+ dec_out_logits = []
339
+ ref_points_detach = F.sigmoid(ref_points_unact)
340
+
341
+ for i, layer in enumerate(self.layers):
342
+ ref_points_input = ref_points_detach.unsqueeze(2)
343
+ query_pos_embed = query_pos_head(ref_points_detach)
344
+
345
+ output = layer(output, ref_points_input, memory,
346
+ memory_spatial_shapes, memory_level_start_index,
347
+ attn_mask, memory_mask, query_pos_embed)
348
+
349
+ inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
350
+
351
+ if self.training:
352
+ dec_out_logits.append(score_head[i](output))
353
+ if i == 0:
354
+ dec_out_bboxes.append(inter_ref_bbox)
355
+ else:
356
+ dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
357
+
358
+ elif i == self.eval_idx:
359
+ dec_out_logits.append(score_head[i](output))
360
+ dec_out_bboxes.append(inter_ref_bbox)
361
+ break
362
+
363
+ ref_points = inter_ref_bbox
364
+ ref_points_detach = inter_ref_bbox.detach(
365
+ ) if self.training else inter_ref_bbox
366
+
367
+ return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
368
+
369
+
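Each iteration of `TransformerDecoder.forward` refines the reference boxes in logit space: the bbox head predicts an unbounded delta that is added to `inverse_sigmoid(ref)` and squashed back through a sigmoid, so the refined box stays in [0, 1] and a zero delta leaves it unchanged. A tiny worked example of that update rule; the values are illustrative and `inverse_sigmoid_sketch` is a stand-in for the `inverse_sigmoid` helper imported from `.utils`:

```python
import torch

def inverse_sigmoid_sketch(x, eps=1e-5):
    # clamped logit, the same idea as the repo's inverse_sigmoid helper
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

ref = torch.tensor([[0.30, 0.40, 0.20, 0.25]])     # current (cx, cy, w, h)
delta = torch.tensor([[0.50, -0.20, 0.10, 0.00]])  # bbox head output for this layer

new_ref = torch.sigmoid(delta + inverse_sigmoid_sketch(ref))
print(new_ref)   # refined box, still inside (0, 1)
```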
370
+ @register
371
+ class RTDETRTransformer(nn.Module):
372
+ __share__ = ['num_classes']
373
+ def __init__(self,
374
+ num_classes=80,
375
+ hidden_dim=256,
376
+ num_queries=300,
377
+ position_embed_type='sine',
378
+ feat_channels=[512, 1024, 2048],
379
+ feat_strides=[8, 16, 32],
380
+ num_levels=3,
381
+ num_decoder_points=4,
382
+ nhead=8,
383
+ num_decoder_layers=6,
384
+ dim_feedforward=1024,
385
+ dropout=0.,
386
+ activation="relu",
387
+ num_denoising=100,
388
+ label_noise_ratio=0.5,
389
+ box_noise_scale=1.0,
390
+ learnt_init_query=False,
391
+ eval_spatial_size=None,
392
+ eval_idx=-1,
393
+ eps=1e-2,
394
+ aux_loss=True,
395
+ cope='None',):
396
+
397
+ super(RTDETRTransformer, self).__init__()
398
+ assert position_embed_type in ['sine', 'learned'], \
399
+ f'ValueError: position_embed_type not supported {position_embed_type}!'
400
+ assert len(feat_channels) <= num_levels
401
+ assert len(feat_strides) == len(feat_channels)
402
+ for _ in range(num_levels - len(feat_strides)):
403
+ feat_strides.append(feat_strides[-1] * 2)
404
+
405
+ self.hidden_dim = hidden_dim
406
+ self.nhead = nhead
407
+ self.feat_strides = feat_strides
408
+ self.num_levels = num_levels
409
+ self.num_classes = num_classes
410
+ self.num_queries = num_queries
411
+ self.eps = eps
412
+ self.num_decoder_layers = num_decoder_layers
413
+ self.eval_spatial_size = eval_spatial_size
414
+ self.aux_loss = aux_loss
415
+
416
+ # backbone feature projection
417
+ self._build_input_proj_layer(feat_channels)
418
+
419
+ # Transformer module
420
+ decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, num_decoder_points,cope)
421
+ self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx)
422
+
423
+ self.num_denoising = num_denoising
424
+ self.label_noise_ratio = label_noise_ratio
425
+ self.box_noise_scale = box_noise_scale
426
+ # denoising part
427
+ if num_denoising > 0:
428
+ # self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim, padding_idx=num_classes-1) # TODO for load paddle weights
429
+ self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes)
430
+
431
+ # decoder embedding
432
+ self.learnt_init_query = learnt_init_query
433
+ if learnt_init_query:
434
+ self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
435
+ self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
436
+
437
+ # encoder head
438
+ self.enc_output = nn.Sequential(
439
+ nn.Linear(hidden_dim, hidden_dim),
440
+ nn.LayerNorm(hidden_dim,)
441
+ )
442
+ self.enc_score_head = nn.Linear(hidden_dim, num_classes)
443
+ self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
444
+
445
+ # decoder head
446
+ self.dec_score_head = nn.ModuleList([
447
+ nn.Linear(hidden_dim, num_classes)
448
+ for _ in range(num_decoder_layers)
449
+ ])
450
+ self.dec_bbox_head = nn.ModuleList([
451
+ MLP(hidden_dim, hidden_dim, 4, num_layers=3)
452
+ for _ in range(num_decoder_layers)
453
+ ])
454
+
455
+ # init encoder output anchors and valid_mask
456
+ if self.eval_spatial_size:
457
+ self.anchors, self.valid_mask = self._generate_anchors()
458
+
459
+ self._reset_parameters()
460
+
461
+ def _reset_parameters(self):
462
+ bias = bias_init_with_prob(0.01)
463
+
464
+ init.constant_(self.enc_score_head.bias, bias)
465
+ init.constant_(self.enc_bbox_head.layers[-1].weight, 0)
466
+ init.constant_(self.enc_bbox_head.layers[-1].bias, 0)
467
+
468
+ for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
469
+ init.constant_(cls_.bias, bias)
470
+ init.constant_(reg_.layers[-1].weight, 0)
471
+ init.constant_(reg_.layers[-1].bias, 0)
472
+
473
+ # linear_init_(self.enc_output[0])
474
+ init.xavier_uniform_(self.enc_output[0].weight)
475
+ if self.learnt_init_query:
476
+ init.xavier_uniform_(self.tgt_embed.weight)
477
+ init.xavier_uniform_(self.query_pos_head.layers[0].weight)
478
+ init.xavier_uniform_(self.query_pos_head.layers[1].weight)
479
+
480
+
481
+ def _build_input_proj_layer(self, feat_channels):
482
+ self.input_proj = nn.ModuleList()
483
+ for in_channels in feat_channels:
484
+ self.input_proj.append(
485
+ nn.Sequential(OrderedDict([
486
+ ('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)),
487
+ ('norm', nn.BatchNorm2d(self.hidden_dim,))])
488
+ )
489
+ )
490
+
491
+ in_channels = feat_channels[-1]
492
+
493
+ for _ in range(self.num_levels - len(feat_channels)):
494
+ self.input_proj.append(
495
+ nn.Sequential(OrderedDict([
496
+ ('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)),
497
+ ('norm', nn.BatchNorm2d(self.hidden_dim))])
498
+ )
499
+ )
500
+ in_channels = self.hidden_dim
501
+
502
+ def _get_encoder_input(self, feats):
503
+ # get projection features
504
+ proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
505
+ if self.num_levels > len(proj_feats):
506
+ len_srcs = len(proj_feats)
507
+ for i in range(len_srcs, self.num_levels):
508
+ if i == len_srcs:
509
+ proj_feats.append(self.input_proj[i](feats[-1]))
510
+ else:
511
+ proj_feats.append(self.input_proj[i](proj_feats[-1]))
512
+
513
+ # get encoder inputs
514
+ feat_flatten = []
515
+ spatial_shapes = []
516
+ level_start_index = [0, ]
517
+ for i, feat in enumerate(proj_feats):
518
+ _, _, h, w = feat.shape
519
+ # [b, c, h, w] -> [b, h*w, c]
520
+ feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
521
+ # [num_levels, 2]
522
+ spatial_shapes.append([h, w])
523
+ # [l], start index of each level
524
+ level_start_index.append(h * w + level_start_index[-1])
525
+
526
+ # [b, l, c]
527
+ feat_flatten = torch.concat(feat_flatten, 1)
528
+ level_start_index.pop()
529
+ return (feat_flatten, spatial_shapes, level_start_index)
530
+
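`_get_encoder_input` above flattens each projected map from [b, c, h, w] to [b, h*w, c], concatenates the levels, and tracks per-level start offsets. A standalone sketch of that bookkeeping with assumed sizes:

```python
import torch

feats = [torch.randn(2, 256, 32, 32),   # assumed projected features, hidden_dim = 256
         torch.randn(2, 256, 16, 16),
         torch.randn(2, 256, 8, 8)]

flat, spatial_shapes, level_start_index = [], [], [0]
for feat in feats:
    _, _, h, w = feat.shape
    flat.append(feat.flatten(2).permute(0, 2, 1))      # [b, h*w, c]
    spatial_shapes.append([h, w])
    level_start_index.append(h * w + level_start_index[-1])

memory = torch.concat(flat, dim=1)
level_start_index.pop()                                 # drop the trailing total, as in the method

print(memory.shape)          # torch.Size([2, 1344, 256])
print(spatial_shapes)        # [[32, 32], [16, 16], [8, 8]]
print(level_start_index)     # [0, 1024, 1280]
```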
531
+ def _generate_anchors(self,
532
+ spatial_shapes=None,
533
+ grid_size=0.05,
534
+ dtype=torch.float32,
535
+ device='cpu'):
536
+ if spatial_shapes is None:
537
+ spatial_shapes = [[int(self.eval_spatial_size[0] / s), int(self.eval_spatial_size[1] / s)]
538
+ for s in self.feat_strides
539
+ ]
540
+ anchors = []
541
+ for lvl, (h, w) in enumerate(spatial_shapes):
542
+ grid_y, grid_x = torch.meshgrid(\
543
+ torch.arange(end=h, dtype=dtype), \
544
+ torch.arange(end=w, dtype=dtype), indexing='ij')
545
+ grid_xy = torch.stack([grid_x, grid_y], -1)
546
+ valid_WH = torch.tensor([w, h]).to(dtype)
547
+ grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
548
+ wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
549
+ anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4))
550
+
551
+ anchors = torch.concat(anchors, 1).to(device)
552
+ valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
553
+ anchors = torch.log(anchors / (1 - anchors))
554
+ # anchors = torch.where(valid_mask, anchors, float('inf'))
555
+ # anchors[valid_mask] = torch.inf # valid_mask [1, 8400, 1]
556
+ anchors = torch.where(valid_mask, anchors, torch.inf)
557
+
558
+ return anchors, valid_mask
559
+
560
+
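`_generate_anchors` places one (cx, cy, w, h) proposal per feature-map cell: cell-centred coordinates, a width/height of `grid_size * 2**lvl`, a validity mask that discards near-border anchors, and a final move to logit space where invalid anchors become `inf` so they are suppressed downstream. A condensed single-level sketch with assumed sizes:

```python
import torch

h, w, lvl, grid_size, eps = 4, 4, 0, 0.05, 1e-2       # one assumed 4x4 level

grid_y, grid_x = torch.meshgrid(torch.arange(h, dtype=torch.float32),
                                torch.arange(w, dtype=torch.float32), indexing='ij')
grid_xy = (torch.stack([grid_x, grid_y], -1).unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=torch.float32)
wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
anchors = torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4)   # [1, 16, 4], values in [0, 1]

valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)
anchors = torch.log(anchors / (1 - anchors))            # to logit space, like inverse_sigmoid
anchors = torch.where(valid_mask, anchors, torch.inf)   # invalid anchors masked out later

print(anchors.shape, valid_mask.sum().item())           # torch.Size([1, 16, 4]) 16
```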
561
+ def _get_decoder_input(self,
562
+ memory,
563
+ spatial_shapes,
564
+ denoising_class=None,
565
+ denoising_bbox_unact=None):
566
+ bs, _, _ = memory.shape
567
+ # prepare input for decoder
568
+ if self.training or self.eval_spatial_size is None:
569
+ anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device)
570
+ else:
571
+ anchors, valid_mask = self.anchors.to(memory.device), self.valid_mask.to(memory.device)
572
+
573
+ # memory = torch.where(valid_mask, memory, 0)
574
+ memory = valid_mask.to(memory.dtype) * memory # TODO fix type error for onnx export
575
+
576
+ output_memory = self.enc_output(memory)
577
+
578
+ enc_outputs_class = self.enc_score_head(output_memory)
579
+ enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
580
+
581
+ _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
582
+
583
+ reference_points_unact = enc_outputs_coord_unact.gather(dim=1, \
584
+ index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1]))
585
+
586
+ enc_topk_bboxes = F.sigmoid(reference_points_unact)
587
+ if denoising_bbox_unact is not None:
588
+ reference_points_unact = torch.concat(
589
+ [denoising_bbox_unact, reference_points_unact], 1)
590
+
591
+ enc_topk_logits = enc_outputs_class.gather(dim=1, \
592
+ index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]))
593
+
594
+ # extract region features
595
+ if self.learnt_init_query:
596
+ target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
597
+ else:
598
+ target = output_memory.gather(dim=1, \
599
+ index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
600
+ target = target.detach()
601
+
602
+ if denoising_class is not None:
603
+ target = torch.concat([denoising_class, target], 1)
604
+
605
+ return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
606
+
607
+
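`_get_decoder_input` scores every flattened memory token with the encoder head, keeps the `num_queries` tokens whose best class score is highest, and gathers both their box logits and (unless `learnt_init_query` is set) their features as the initial decoder queries. A shape-level sketch of that selection with assumed sizes:

```python
import torch

bs, num_tokens, hidden_dim, num_classes, num_queries = 2, 1344, 256, 80, 300

enc_outputs_class = torch.randn(bs, num_tokens, num_classes)
enc_outputs_coord_unact = torch.randn(bs, num_tokens, 4)
output_memory = torch.randn(bs, num_tokens, hidden_dim)

# pick the tokens whose best class score is highest
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, num_queries, dim=1)

reference_points_unact = enc_outputs_coord_unact.gather(
    dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, 4))
target = output_memory.gather(
    dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, hidden_dim)).detach()

print(reference_points_unact.shape, target.shape)   # [2, 300, 4] [2, 300, 256]
```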
608
+ def forward(self, feats, targets=None):
609
+
610
+ # input projection and embedding
611
+ (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
612
+
613
+ # prepare denoising training
614
+ if self.training and self.num_denoising > 0:
615
+ denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
616
+ get_contrastive_denoising_training_group(targets, \
617
+ self.num_classes,
618
+ self.num_queries,
619
+ self.denoising_class_embed,
620
+ num_denoising=self.num_denoising,
621
+ label_noise_ratio=self.label_noise_ratio,
622
+ box_noise_scale=self.box_noise_scale, )
623
+ else:
624
+ denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
625
+
626
+ target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
627
+ self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact)
628
+
629
+ # decoder
630
+ out_bboxes, out_logits = self.decoder(
631
+ target,
632
+ init_ref_points_unact,
633
+ memory,
634
+ spatial_shapes,
635
+ level_start_index,
636
+ self.dec_bbox_head,
637
+ self.dec_score_head,
638
+ self.query_pos_head,
639
+ attn_mask=attn_mask)
640
+
641
+ if self.training and dn_meta is not None:
642
+ dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
643
+ dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)
644
+
645
+ out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
646
+
647
+ if self.training and self.aux_loss:
648
+ out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
649
+ out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
650
+
651
+ if self.training and dn_meta is not None:
652
+ out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
653
+ out['dn_meta'] = dn_meta
654
+
655
+ return out
656
+
657
+
658
+ @torch.jit.unused
659
+ def _set_aux_loss(self, outputs_class, outputs_coord):
660
+ # this is a workaround to make torchscript happy, as torchscript
661
+ # doesn't support dictionary with non-homogeneous values, such
662
+ # as a dict having both a Tensor and a list.
663
+ return [{'pred_logits': a, 'pred_boxes': b}
664
+ for a, b in zip(outputs_class, outputs_coord)]
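End to end, `RTDETRTransformer` takes the three projected backbone levels and returns `pred_logits` / `pred_boxes` (plus auxiliary and denoising outputs during training). A hedged eval-mode smoke test, assuming the repository from this commit (and its dependencies) is importable and using the default hyper-parameters:

```python
import torch
from src.zoo.rtdetr.rtdetr_decoder import RTDETRTransformer   # path as uploaded in this commit

model = RTDETRTransformer(num_classes=80,
                          feat_channels=[512, 1024, 2048],
                          feat_strides=[8, 16, 32],
                          num_levels=3).eval()

# assumed 640x640 input -> strides 8/16/32 give 80/40/20 feature maps
feats = [torch.randn(1, 512, 80, 80),
         torch.randn(1, 1024, 40, 40),
         torch.randn(1, 2048, 20, 20)]

with torch.no_grad():
    out = model(feats)

print(out['pred_logits'].shape)   # [1, 300, 80]
print(out['pred_boxes'].shape)    # [1, 300, 4]
```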