cheng-hust committed
Commit d21a218 · verified · 1 Parent(s): dfb8d8b

Delete src/zoo/rtdetr/rtdetr_decoder.py

Files changed (1)
  1. src/zoo/rtdetr/rtdetr_decoder.py +0 -627
src/zoo/rtdetr/rtdetr_decoder.py DELETED
@@ -1,627 +0,0 @@
- """by lyuwenyu
- """
-
- import math
- import copy
- from collections import OrderedDict
- from typing import Optional, Tuple
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.nn.init as init
- from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
- from torch.nn.parameter import Parameter
-
- from .denoising import get_contrastive_denoising_training_group
- from .utils import deformable_attention_core_func, get_activation, inverse_sigmoid
- from .utils import bias_init_with_prob
- from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
-
- from src.core import register
-
- import numpy as np
-
- import scipy.linalg as sl
-
- __all__ = ['RTDETRTransformer']
-
-
-
- class MLP(nn.Module):
-     def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'):
-         super().__init__()
-         self.num_layers = num_layers
-         h = [hidden_dim] * (num_layers - 1)
-         self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
-         self.act = nn.Identity() if act is None else get_activation(act)
-
-     def forward(self, x):
-         for i, layer in enumerate(self.layers):
-             x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
-         return x
-
-
- class CoPE(nn.Module):
-     def __init__(self,npos_max,head_dim):
-         super(CoPE, self).__init__()
-         self.npos_max = npos_max #?
-         self.pos_emb = nn.parameter.Parameter(torch.zeros(1,head_dim,npos_max))
-
-     def forward(self,query,attn_logits):
-         #compute positions
-         gates = torch.sigmoid(attn_logits) #sig(qk)
-         pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
-         pos = pos.clamp(max=self.npos_max-1)
-         #interpolate from integer positions
-         pos_ceil = pos.ceil().long()
-         pos_floor = pos.floor().long()
-         logits_int = torch.matmul(query,self.pos_emb)
-         logits_ceil = logits_int.gather(-1,pos_ceil)
-         logits_floor = logits_int.gather(-1,pos_floor)
-         w = pos-pos_floor
-         return logits_ceil*w+logits_floor*(1-w)
-
-
-
-
- class MSDeformableAttention(nn.Module):
-     def __init__(self, embed_dim=256, num_heads=8, num_levels=4, num_points=4,):
-         """
-         Multi-Scale Deformable Attention Module
-         """
-         super(MSDeformableAttention, self).__init__()
-         self.embed_dim = embed_dim
-         self.num_heads = num_heads
-         self.num_levels = num_levels
-         self.num_points = num_points
-         self.total_points = num_heads * num_levels * num_points
-
-         self.head_dim = embed_dim // num_heads
-         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
-
-         self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2,)
-         self.attention_weights = nn.Linear(embed_dim, self.total_points)
-         self.value_proj = nn.Linear(embed_dim, embed_dim)
-         self.output_proj = nn.Linear(embed_dim, embed_dim)
-
-         self.ms_deformable_attn_core = deformable_attention_core_func
-
-         self._reset_parameters()
-
-
-     def _reset_parameters(self):
-         # sampling_offsets
-         init.constant_(self.sampling_offsets.weight, 0)
-         thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
-         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
-         grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
-         grid_init = grid_init.reshape(self.num_heads, 1, 1, 2).tile([1, self.num_levels, self.num_points, 1])
-         scaling = torch.arange(1, self.num_points + 1, dtype=torch.float32).reshape(1, 1, -1, 1)
-         grid_init *= scaling
-         self.sampling_offsets.bias.data[...] = grid_init.flatten()
-
-         # attention_weights
-         init.constant_(self.attention_weights.weight, 0)
-         init.constant_(self.attention_weights.bias, 0)
-
-         # proj
-         init.xavier_uniform_(self.value_proj.weight)
-         init.constant_(self.value_proj.bias, 0)
-         init.xavier_uniform_(self.output_proj.weight)
-         init.constant_(self.output_proj.bias, 0)
-
-
-     def forward(self,
-                 query,
-                 reference_points,
-                 value,
-                 value_spatial_shapes,
-                 value_mask=None):
-         """
-         Args:
-             query (Tensor): [bs, query_length, C]
-             reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
-                 bottom-right (1, 1), including padding area
-             value (Tensor): [bs, value_length, C]
-             value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
-             value_level_start_index (List): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
-             value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
-
-         Returns:
-             output (Tensor): [bs, Length_{query}, C]
-         """
-         bs, Len_q = query.shape[:2]
-         Len_v = value.shape[1]
-
-         value = self.value_proj(value)
-         if value_mask is not None:
-             value_mask = value_mask.to(value.dtype).unsqueeze(-1)  # torch.Tensor has no .astype; use .to()
-             value *= value_mask
-         value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)
-
-         sampling_offsets = self.sampling_offsets(query).reshape(
-             bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2)
-         attention_weights = self.attention_weights(query).reshape(
-             bs, Len_q, self.num_heads, self.num_levels * self.num_points)
-         attention_weights = F.softmax(attention_weights, dim=-1).reshape(
-             bs, Len_q, self.num_heads, self.num_levels, self.num_points)
-
-         if reference_points.shape[-1] == 2:
-             offset_normalizer = torch.tensor(value_spatial_shapes)
-             offset_normalizer = offset_normalizer.flip([1]).reshape(
-                 1, 1, 1, self.num_levels, 1, 2)
-             sampling_locations = reference_points.reshape(
-                 bs, Len_q, 1, self.num_levels, 1, 2
-             ) + sampling_offsets / offset_normalizer
-         elif reference_points.shape[-1] == 4:
-             sampling_locations = (
-                 reference_points[:, :, None, :, None, :2] + sampling_offsets /
-                 self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5)
-         else:
-             raise ValueError(
-                 "Last dim of reference_points must be 2 or 4, but get {} instead.".
-                 format(reference_points.shape[-1]))
-
-         output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights)
-
-         output = self.output_proj(output)
-
-         return output
-
-
- class TransformerDecoderLayer(nn.Module):
-     def __init__(self,
-                  d_model=256,
-                  n_head=8,
-                  dim_feedforward=1024,
-                  dropout=0.,
-                  activation="relu",
-                  n_levels=4,
-                  n_points=4,):
-         super(TransformerDecoderLayer, self).__init__()
-
-         # self attention
-         self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
-         self.dropout1 = nn.Dropout(dropout)
-         self.norm1 = nn.LayerNorm(d_model)
-
-         # cross attention
-         self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points)
-         self.dropout2 = nn.Dropout(dropout)
-         self.norm2 = nn.LayerNorm(d_model)
-
-         # ffn
-         self.linear1 = nn.Linear(d_model, dim_feedforward)
-         self.activation = getattr(F, activation)
-         self.dropout3 = nn.Dropout(dropout)
-         self.linear2 = nn.Linear(dim_feedforward, d_model)
-         self.dropout4 = nn.Dropout(dropout)
-         self.norm3 = nn.LayerNorm(d_model)
-
-         self.cope = CoPE(12,d_model)
-
-         # self._reset_parameters()
-
-     # def _reset_parameters(self):
-     #     linear_init_(self.linear1)
-     #     linear_init_(self.linear2)
-     #     xavier_uniform_(self.linear1.weight)
-     #     xavier_uniform_(self.linear2.weight)
-
-     def with_pos_embed(self, tensor, pos):
-         return tensor if pos is None else tensor + pos
-
-     def forward_ffn(self, tgt):
-         return self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
-
-     def forward(self,
-                 tgt,
-                 reference_points,
-                 memory,
-                 memory_spatial_shapes,
-                 memory_level_start_index,
-                 attn_mask=None,
-                 memory_mask=None,
-                 query_pos_embed=None):
-         # self attention
-         #print(query_pos_embed.shape)
-         qk = torch.bmm(tgt, tgt.transpose(-1, -2))
-         mask = torch.tril(torch.ones_like(qk),diagonal=0)
-         mask = torch.log(mask)
-         query_pos_embed = self.cope(tgt,qk+mask) #position_embedding
-
-
-         n_tgt = tgt.cpu().detach().numpy()
-
-         itgt = tgt.new_tensor(np.array([sl.pinv(i) for i in n_tgt])) #inv_tgt
-
-         # print('qk:',qk.shape)
-         # print('tgt:',tgt.shape)
-         # print((query_pos_embed@itgt.transpose(-1,-2)).shape)
-         # print('ik:',itgt.shape)
-
-         # print(torch.round(itgt@tgt))
-         # print(query_pos_embed@itgt.transpose(-1,-2))
-
-         k = tgt
-         q = tgt + (query_pos_embed@itgt.transpose(-1,-2))  # so q @ k^T ≈ qk + query_pos_embed
-
-         # print((q@(k.transpose(-1,-2))-query_pos_embed))
-
-         # if attn_mask is not None:
-         #     attn_mask = torch.where(
-         #         attn_mask.to(torch.bool),
-         #         torch.zeros_like(attn_mask),
-         #         torch.full_like(attn_mask, float('-inf'), dtype=tgt.dtype))
-
-         # q = k = self.with_pos_embed(tgt, query_pos_embed)
-         tgt2, _ = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
-         tgt = tgt + self.dropout1(tgt2)
-         tgt = self.norm1(tgt)
-
-         # cross attention
-         tgt2 = self.cross_attn(\
-             self.with_pos_embed(tgt, (query_pos_embed@itgt.transpose(-1,-2))), #self.with_pos_embed(tgt, query_pos_embed),
-             reference_points,
-             memory,
-             memory_spatial_shapes,
-             memory_mask)
-         tgt = tgt + self.dropout2(tgt2)
-         tgt = self.norm2(tgt)
-
-         # ffn
-         tgt2 = self.forward_ffn(tgt)
-         tgt = tgt + self.dropout4(tgt2)
-         tgt = self.norm3(tgt)
-
-         return tgt
-
-
- class TransformerDecoder(nn.Module):
-     def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
-         super(TransformerDecoder, self).__init__()
-         self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
-         self.hidden_dim = hidden_dim
-         self.num_layers = num_layers
-         self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
-
-     def forward(self,
-                 tgt,
-                 ref_points_unact,
-                 memory,
-                 memory_spatial_shapes,
-                 memory_level_start_index,
-                 bbox_head,
-                 score_head,
-                 query_pos_head,
-                 attn_mask=None,
-                 memory_mask=None):
-         output = tgt
-         dec_out_bboxes = []
-         dec_out_logits = []
-         ref_points_detach = F.sigmoid(ref_points_unact)
-
-         for i, layer in enumerate(self.layers):
-             ref_points_input = ref_points_detach.unsqueeze(2)
-             query_pos_embed = query_pos_head(ref_points_detach)
-
-             output = layer(output, ref_points_input, memory,
-                            memory_spatial_shapes, memory_level_start_index,
-                            attn_mask, memory_mask, query_pos_embed)
-
-             inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
-
-             if self.training:
-                 dec_out_logits.append(score_head[i](output))
-                 if i == 0:
-                     dec_out_bboxes.append(inter_ref_bbox)
-                 else:
-                     dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
-
-             elif i == self.eval_idx:
-                 dec_out_logits.append(score_head[i](output))
-                 dec_out_bboxes.append(inter_ref_bbox)
-                 break
-
-             ref_points = inter_ref_bbox
-             ref_points_detach = inter_ref_bbox.detach(
-             ) if self.training else inter_ref_bbox
-
-         return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
-
-
- @register
- class RTDETRTransformer(nn.Module):
-     __share__ = ['num_classes']
-     def __init__(self,
-                  num_classes=80,
-                  hidden_dim=256,
-                  num_queries=300,
-                  position_embed_type='sine',
-                  feat_channels=[512, 1024, 2048],
-                  feat_strides=[8, 16, 32],
-                  num_levels=3,
-                  num_decoder_points=4,
-                  nhead=8,
-                  num_decoder_layers=6,
-                  dim_feedforward=1024,
-                  dropout=0.,
-                  activation="relu",
-                  num_denoising=100,
-                  label_noise_ratio=0.5,
-                  box_noise_scale=1.0,
-                  learnt_init_query=False,
-                  eval_spatial_size=None,
-                  eval_idx=-1,
-                  eps=1e-2,
-                  aux_loss=True):
-
-         super(RTDETRTransformer, self).__init__()
-         assert position_embed_type in ['sine', 'learned'], \
-             f'ValueError: position_embed_type not supported {position_embed_type}!'
-         assert len(feat_channels) <= num_levels
-         assert len(feat_strides) == len(feat_channels)
-         for _ in range(num_levels - len(feat_strides)):
-             feat_strides.append(feat_strides[-1] * 2)
-
-         self.hidden_dim = hidden_dim
-         self.nhead = nhead
-         self.feat_strides = feat_strides
-         self.num_levels = num_levels
-         self.num_classes = num_classes
-         self.num_queries = num_queries
-         self.eps = eps
-         self.num_decoder_layers = num_decoder_layers
-         self.eval_spatial_size = eval_spatial_size
-         self.aux_loss = aux_loss
-
-         # backbone feature projection
-         self._build_input_proj_layer(feat_channels)
-
-         # Transformer module
-         decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, num_decoder_points)
-         self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx)
-
-         self.num_denoising = num_denoising
-         self.label_noise_ratio = label_noise_ratio
-         self.box_noise_scale = box_noise_scale
-         # denoising part
-         if num_denoising > 0:
-             # self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim, padding_idx=num_classes-1) # TODO for load paddle weights
-             self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes)
-
-         # decoder embedding
-         self.learnt_init_query = learnt_init_query
-         if learnt_init_query:
-             self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
-         self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
-
-         # encoder head
-         self.enc_output = nn.Sequential(
-             nn.Linear(hidden_dim, hidden_dim),
-             nn.LayerNorm(hidden_dim,)
-         )
-         self.enc_score_head = nn.Linear(hidden_dim, num_classes)
-         self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
-
-         # decoder head
-         self.dec_score_head = nn.ModuleList([
-             nn.Linear(hidden_dim, num_classes)
-             for _ in range(num_decoder_layers)
-         ])
-         self.dec_bbox_head = nn.ModuleList([
-             MLP(hidden_dim, hidden_dim, 4, num_layers=3)
-             for _ in range(num_decoder_layers)
-         ])
-
-         # init encoder output anchors and valid_mask
-         if self.eval_spatial_size:
-             self.anchors, self.valid_mask = self._generate_anchors()
-
-         self._reset_parameters()
-
-     def _reset_parameters(self):
-         bias = bias_init_with_prob(0.01)
-
-         init.constant_(self.enc_score_head.bias, bias)
-         init.constant_(self.enc_bbox_head.layers[-1].weight, 0)
-         init.constant_(self.enc_bbox_head.layers[-1].bias, 0)
-
-         for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
-             init.constant_(cls_.bias, bias)
-             init.constant_(reg_.layers[-1].weight, 0)
-             init.constant_(reg_.layers[-1].bias, 0)
-
-         # linear_init_(self.enc_output[0])
-         init.xavier_uniform_(self.enc_output[0].weight)
-         if self.learnt_init_query:
-             init.xavier_uniform_(self.tgt_embed.weight)
-         init.xavier_uniform_(self.query_pos_head.layers[0].weight)
-         init.xavier_uniform_(self.query_pos_head.layers[1].weight)
-
-
-     def _build_input_proj_layer(self, feat_channels):
-         self.input_proj = nn.ModuleList()
-         for in_channels in feat_channels:
-             self.input_proj.append(
-                 nn.Sequential(OrderedDict([
-                     ('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)),
-                     ('norm', nn.BatchNorm2d(self.hidden_dim,))])
-                 )
-             )
-
-         in_channels = feat_channels[-1]
-
-         for _ in range(self.num_levels - len(feat_channels)):
-             self.input_proj.append(
-                 nn.Sequential(OrderedDict([
-                     ('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)),
-                     ('norm', nn.BatchNorm2d(self.hidden_dim))])
-                 )
-             )
-             in_channels = self.hidden_dim
-
-     def _get_encoder_input(self, feats):
-         # get projection features
-         proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
-         if self.num_levels > len(proj_feats):
-             len_srcs = len(proj_feats)
-             for i in range(len_srcs, self.num_levels):
-                 if i == len_srcs:
-                     proj_feats.append(self.input_proj[i](feats[-1]))
-                 else:
-                     proj_feats.append(self.input_proj[i](proj_feats[-1]))
-
-         # get encoder inputs
-         feat_flatten = []
-         spatial_shapes = []
-         level_start_index = [0, ]
-         for i, feat in enumerate(proj_feats):
-             _, _, h, w = feat.shape
-             # [b, c, h, w] -> [b, h*w, c]
-             feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
-             # [num_levels, 2]
-             spatial_shapes.append([h, w])
-             # [l], start index of each level
-             level_start_index.append(h * w + level_start_index[-1])
-
-         # [b, l, c]
-         feat_flatten = torch.concat(feat_flatten, 1)
-         level_start_index.pop()
-         return (feat_flatten, spatial_shapes, level_start_index)
-
-     def _generate_anchors(self,
-                           spatial_shapes=None,
-                           grid_size=0.05,
-                           dtype=torch.float32,
-                           device='cpu'):
-         if spatial_shapes is None:
-             spatial_shapes = [[int(self.eval_spatial_size[0] / s), int(self.eval_spatial_size[1] / s)]
-                 for s in self.feat_strides
-             ]
-         anchors = []
-         for lvl, (h, w) in enumerate(spatial_shapes):
-             grid_y, grid_x = torch.meshgrid(\
-                 torch.arange(end=h, dtype=dtype), \
-                 torch.arange(end=w, dtype=dtype), indexing='ij')
-             grid_xy = torch.stack([grid_x, grid_y], -1)
-             valid_WH = torch.tensor([w, h]).to(dtype)
-             grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
-             wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
-             anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4))
-
-         anchors = torch.concat(anchors, 1).to(device)
-         valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
-         anchors = torch.log(anchors / (1 - anchors))
-         # anchors = torch.where(valid_mask, anchors, float('inf'))
-         # anchors[valid_mask] = torch.inf # valid_mask [1, 8400, 1]
-         anchors = torch.where(valid_mask, anchors, torch.inf)
-
-         return anchors, valid_mask
-
-
-     def _get_decoder_input(self,
-                            memory,
-                            spatial_shapes,
-                            denoising_class=None,
-                            denoising_bbox_unact=None):
-         bs, _, _ = memory.shape
-         # prepare input for decoder
-         if self.training or self.eval_spatial_size is None:
-             anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device)
-         else:
-             anchors, valid_mask = self.anchors.to(memory.device), self.valid_mask.to(memory.device)
-
-         # memory = torch.where(valid_mask, memory, 0)
-         memory = valid_mask.to(memory.dtype) * memory # TODO fix type error for onnx export
-
-         output_memory = self.enc_output(memory)
-
-         enc_outputs_class = self.enc_score_head(output_memory)
-         enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
-
-         _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
-
-         reference_points_unact = enc_outputs_coord_unact.gather(dim=1, \
-             index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1]))
-
-         enc_topk_bboxes = F.sigmoid(reference_points_unact)
-         if denoising_bbox_unact is not None:
-             reference_points_unact = torch.concat(
-                 [denoising_bbox_unact, reference_points_unact], 1)
-
-         enc_topk_logits = enc_outputs_class.gather(dim=1, \
-             index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]))
-
-         # extract region features
-         if self.learnt_init_query:
-             target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
-         else:
-             target = output_memory.gather(dim=1, \
-                 index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
-             target = target.detach()
-
-         if denoising_class is not None:
-             target = torch.concat([denoising_class, target], 1)
-
-         return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
-
-
-     def forward(self, feats, targets=None):
-
-         # input projection and embedding
-         (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
-
-         # prepare denoising training
-         if self.training and self.num_denoising > 0:
-             denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
-                 get_contrastive_denoising_training_group(targets, \
-                     self.num_classes,
-                     self.num_queries,
-                     self.denoising_class_embed,
-                     num_denoising=self.num_denoising,
-                     label_noise_ratio=self.label_noise_ratio,
-                     box_noise_scale=self.box_noise_scale, )
-         else:
-             denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
-
-         target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
-             self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact)
-
-         # decoder
-         out_bboxes, out_logits = self.decoder(
-             target,
-             init_ref_points_unact,
-             memory,
-             spatial_shapes,
-             level_start_index,
-             self.dec_bbox_head,
-             self.dec_score_head,
-             self.query_pos_head,
-             attn_mask=attn_mask)
-
-         if self.training and dn_meta is not None:
-             dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
-             dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)
-
-         out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
-
-         if self.training and self.aux_loss:
-             out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
-             out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
-
-             if self.training and dn_meta is not None:
-                 out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
-                 out['dn_meta'] = dn_meta
-
-         return out
-
-
-     @torch.jit.unused
-     def _set_aux_loss(self, outputs_class, outputs_coord):
-         # this is a workaround to make torchscript happy, as torchscript
-         # doesn't support dictionary with non-homogeneous values, such
-         # as a dict having both a Tensor and a list.
-         return [{'pred_logits': a, 'pred_boxes': b}
-                 for a, b in zip(outputs_class, outputs_coord)]
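
Note (not part of the diff): the non-obvious piece of the deleted TransformerDecoderLayer is how it turns CoPE's contextual position logits, which have shape [bs, L, L], into an additive query term of shape [bs, L, C] by multiplying with the transposed pseudo-inverse of tgt, so that q @ k^T recovers qk + the CoPE logits. Below is a minimal, self-contained sketch of that computation under assumed illustrative shapes; torch.linalg.pinv stands in for the file's per-sample scipy.linalg.pinv loop, and the variable names are only meant to mirror the deleted code.

import torch

bs, L, C, npos_max = 2, 6, 8, 12            # illustrative sizes (the deleted layer uses C=256, npos_max=12)
tgt = torch.randn(bs, L, C)                 # decoder queries
pos_emb = torch.randn(1, C, npos_max)       # CoPE table (zero-initialized in the deleted CoPE class; random here so the check is non-trivial)

# CoPE: gate the raw logits, accumulate them into fractional positions, interpolate the table
qk = torch.bmm(tgt, tgt.transpose(-1, -2))  # raw attention logits, [bs, L, L]
gates = torch.sigmoid(qk + torch.log(torch.tril(torch.ones_like(qk))))   # log(tril) masks entries above the diagonal
pos = gates.flip(-1).cumsum(dim=-1).flip(-1).clamp(max=npos_max - 1)
logits_int = torch.matmul(tgt, pos_emb)     # [bs, L, npos_max]
w = pos - pos.floor()
cope_logits = (logits_int.gather(-1, pos.ceil().long()) * w
               + logits_int.gather(-1, pos.floor().long()) * (1 - w))    # [bs, L, L]

# Fold the [L, L] logits back into the query: with itgt = pinv(tgt),
# (tgt + cope_logits @ itgt^T) @ tgt^T = qk + cope_logits @ (tgt @ itgt)^T ≈ qk + cope_logits.
itgt = torch.linalg.pinv(tgt)               # [bs, C, L], batched pseudo-inverse
q = tgt + cope_logits @ itgt.transpose(-1, -2)   # [bs, L, C]
k = tgt
print((q @ k.transpose(-1, -2) - qk - cope_logits).abs().max())          # ≈ 0 while L <= C

The exact recovery only holds while the L query vectors are linearly independent (L <= C); with the decoder's 300+ queries and C = 256, tgt @ pinv(tgt) is a projection rather than the identity, so the injected position term is an approximation.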