Spaces:
Build error
Build error
File size: 17,200 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor
from torch.nn.modules.utils import _pair
from mmdet.models.losses import accuracy
from mmdet.models.task_modules import SamplingResult
from mmdet.models.task_modules.tracking import embed_similarity
from mmdet.registry import MODELS
@MODELS.register_module()
class RoIEmbedHead(BaseModule):
    """The roi embed head.

    This module is used in multi-object tracking methods, such as MaskTrack
    R-CNN.

    Args:
        num_convs (int): The number of convolutional layers to embed roi
            features. Defaults to 0.
        num_fcs (int): The number of fully connected layers to embed roi
            features. Defaults to 0.
        roi_feat_size (int|tuple(int)): The spatial size of roi features.
            Defaults to 7.
        in_channels (int): The input channel of roi features. Defaults to 256.
        conv_out_channels (int): The output channel of roi features after
            forwarding convolutional layers. Defaults to 256.
        with_avg_pool (bool): Whether use average pooling before passing roi
            features into fully connected layers. Defaults to False.
        fc_out_channels (int): The output channel of roi features after
            forwarding fully connected layers. Defaults to 1024.
        conv_cfg (dict): Config dict for convolution layer. Defaults to None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Defaults to None.
        loss_match (dict): The loss function. Defaults to
            dict(type='mmdet.CrossEntropyLoss', use_sigmoid=False,
            loss_weight=1.0)
        init_cfg (dict): Configuration of initialization. Defaults to None.
    """

    def __init__(self,
                 num_convs: int = 0,
                 num_fcs: int = 0,
                 roi_feat_size: int = 7,
                 in_channels: int = 256,
                 conv_out_channels: int = 256,
                 with_avg_pool: bool = False,
                 fc_out_channels: int = 1024,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: Optional[dict] = None,
                 loss_match: dict = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 init_cfg: Optional[dict] = None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.num_convs = num_convs
        self.num_fcs = num_fcs
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.with_avg_pool = with_avg_pool
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.loss_match = MODELS.build(loss_match)
        self.fp16_enabled = False

        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        # add convs and fcs
        self.convs, self.fcs, self.last_layer_dim = self._add_conv_fc_branch(
            self.num_convs, self.num_fcs, self.in_channels)
        self.relu = nn.ReLU(inplace=True)

    def _add_conv_fc_branch(
            self, num_branch_convs: int, num_branch_fcs: int,
            in_channels: int) -> Tuple[nn.ModuleList, nn.ModuleList, int]:
        """Build the embedding branch.

        The layer order is: convs -> avg pool (optional) -> fcs.

        Args:
            num_branch_convs (int): Number of 3x3 conv layers to create.
            num_branch_fcs (int): Number of linear layers to create.
            in_channels (int): Channel count of the incoming roi features.

        Returns:
            tuple: ``(branch_convs, branch_fcs, last_layer_dim)`` where
            ``last_layer_dim`` is the output dimension of the last layer
            that was created.
        """
        last_layer_dim = in_channels

        # Branch specific conv layers; each one feeds the next.
        branch_convs = nn.ModuleList()
        for _ in range(num_branch_convs):
            branch_convs.append(
                ConvModule(
                    last_layer_dim,
                    self.conv_out_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            last_layer_dim = self.conv_out_channels

        # Branch specific fc layers.
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            if not self.with_avg_pool:
                # Without pooling the whole roi map is flattened, so the
                # first fc sees channels * spatial area inputs.
                last_layer_dim *= self.roi_feat_area
            for _ in range(num_branch_fcs):
                branch_fcs.append(
                    nn.Linear(last_layer_dim, self.fc_out_channels))
                last_layer_dim = self.fc_out_channels

        return branch_convs, branch_fcs, last_layer_dim

    @property
    def custom_activation(self):
        """Value of ``loss_match.custom_activation``, False if it is absent."""
        return getattr(self.loss_match, 'custom_activation', False)

    @staticmethod
    def _compute_similarity_logits(
            x_split: Tuple[Tensor, ...],
            ref_x_split: Tuple[Tensor, ...]) -> List[Tensor]:
        """Compute per-image similarity logits with a leading zero column.

        Column 0 is an all-zero dummy column prepended to the dot-product
        similarity. ``get_targets`` uses target 0 for "no matching reference
        instance" and ``index + 1`` for a match, so targets index directly
        into these logits.

        Args:
            x_split (tuple[Tensor]): Embed features of each key image.
            ref_x_split (tuple[Tensor]): Embed features of each ref image.

        Returns:
            list[Tensor]: One similarity logit tensor per (key, ref) pair.
        """
        similarity_logits = []
        for one_x, one_ref_x in zip(x_split, ref_x_split):
            similarity_logit = embed_similarity(
                one_x, one_ref_x, method='dot_product')
            dummy = similarity_logit.new_zeros(one_x.shape[0], 1)
            similarity_logits.append(
                torch.cat((dummy, similarity_logit), dim=1))
        return similarity_logits

    def extract_feat(self, x: Tensor,
                     num_x_per_img: List[int]) -> Tuple[Tensor, ...]:
        """Extract feature from the input `x`, and split the output per image.

        Args:
            x (Tensor): of shape [N, C, H, W]. N is the number of proposals.
            num_x_per_img (list[int]): The `x` contains proposals of
                multi-images. `num_x_per_img` denotes the number of proposals
                for each image.

        Returns:
            tuple[Tensor]: Each Tensor denotes the embed features belonging to
            an image in a batch.
        """
        # ``self.convs`` is empty when ``num_convs`` is 0.
        for conv in self.convs:
            x = conv(x)

        # Average pooling is only applied when fc layers follow.
        if self.num_fcs > 0 and self.with_avg_pool:
            x = self.avg_pool(x)
        x = x.flatten(1)
        if self.num_fcs > 0:
            for fc in self.fcs:
                x = self.relu(fc(x))

        return torch.split(x, num_x_per_img, dim=0)

    def forward(
        self, x: Tensor, ref_x: Tensor, num_x_per_img: List[int],
        num_x_per_ref_img: List[int]
    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
        """Extract the embed features of `x` and `ref_x`.

        Args:
            x (Tensor): of shape [N, C, H, W]. N is the number of key frame
                proposals.
            ref_x (Tensor): of shape [M, C, H, W]. M is the number of reference
                frame proposals.
            num_x_per_img (list[int]): The `x` contains proposals of
                multi-images. `num_x_per_img` denotes the number of proposals
                for each key image.
            num_x_per_ref_img (list[int]): The `ref_x` contains proposals of
                multi-images. `num_x_per_ref_img` denotes the number of
                proposals for each reference image.

        Returns:
            tuple[tuple[Tensor], tuple[Tensor]]: Each tuple of tensor denotes
            the embed features belonging to an image in a batch.
        """
        x_split = self.extract_feat(x, num_x_per_img)
        ref_x_split = self.extract_feat(ref_x, num_x_per_ref_img)
        return x_split, ref_x_split

    def get_targets(self, sampling_results: List[SamplingResult],
                    gt_instance_ids: List[Tensor],
                    ref_gt_instance_ids: List[Tensor]) -> Tuple[List, List]:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of
                all images in a batch, each tensor has shape (num_gt, ).
            ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes
                of all reference images in a batch, each tensor has shape
                (num_gt, ).

        Returns:
            Tuple[list[Tensor]]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

                - track_id_targets (list[Tensor]): The instance ids of
                  Gt_labels for all proposals in a batch, each tensor in list
                  has shape (num_proposals,).
                - track_id_weights (list[Tensor]): Labels_weights for
                  all proposals in a batch, each tensor in list has
                  shape (num_proposals,).
        """
        track_id_targets = []
        track_id_weights = []
        for res, gt_instance_id, ref_gt_instance_id in zip(
                sampling_results, gt_instance_ids, ref_gt_instance_ids):
            # Map each reference instance id to (first index + 1); 0 is
            # reserved for "no match in the reference image". Building the
            # table once avoids a tensor->list conversion and a linear
            # search for every positive sample.
            ref_index_plus_one = {}
            for ref_idx, ref_id in enumerate(ref_gt_instance_id.tolist()):
                ref_index_plus_one.setdefault(ref_id, ref_idx + 1)

            pos_instance_ids = gt_instance_id[
                res.pos_assigned_gt_inds].tolist()
            pos_match_id = gt_instance_id.new_tensor([
                ref_index_plus_one.get(instance_id, 0)
                for instance_id in pos_instance_ids
            ])

            num_bboxes = len(res.bboxes)
            num_pos = len(res.pos_bboxes)

            # The first ``num_pos`` entries are filled as positives; the
            # remaining entries keep target 0 and weight 0.
            track_id_target = gt_instance_id.new_zeros(
                num_bboxes, dtype=torch.int64)
            track_id_target[:num_pos] = pos_match_id
            track_id_weight = res.bboxes.new_zeros(num_bboxes)
            track_id_weight[:num_pos] = 1.0

            track_id_targets.append(track_id_target)
            track_id_weights.append(track_id_weight)

        return track_id_targets, track_id_weights

    def loss(
        self,
        bbox_feats: Tensor,
        ref_bbox_feats: Tensor,
        num_bbox_per_img: List[int],
        num_bbox_per_ref_img: List[int],
        sampling_results: List[SamplingResult],
        gt_instance_ids: List[Tensor],
        ref_gt_instance_ids: List[Tensor],
        reduction_override: Optional[str] = None,
    ) -> dict:
        """Calculate the loss in a batch.

        Args:
            bbox_feats (Tensor): of shape [N, C, H, W]. N is the number of
                bboxes.
            ref_bbox_feats (Tensor): of shape [M, C, H, W]. M is the number of
                reference bboxes.
            num_bbox_per_img (list[int]): The `bbox_feats` contains proposals
                of multi-images. `num_bbox_per_img` denotes the number of
                proposals for each key image.
            num_bbox_per_ref_img (list[int]): The `ref_bbox_feats` contains
                proposals of multi-images. `num_bbox_per_ref_img` denotes the
                number of proposals for each reference image.
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of
                all images in a batch, each tensor has shape (num_gt, ).
            ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes
                of all reference images in a batch, each tensor has shape
                (num_gt, ).
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        x_split, ref_x_split = self(bbox_feats, ref_bbox_feats,
                                    num_bbox_per_img, num_bbox_per_ref_img)

        losses = self.loss_by_feat(x_split, ref_x_split, sampling_results,
                                   gt_instance_ids, ref_gt_instance_ids,
                                   reduction_override)
        return losses

    def loss_by_feat(self,
                     x_split: Tuple[Tensor, ...],
                     ref_x_split: Tuple[Tensor, ...],
                     sampling_results: List[SamplingResult],
                     gt_instance_ids: List[Tensor],
                     ref_gt_instance_ids: List[Tensor],
                     reduction_override: Optional[str] = None) -> dict:
        """Calculate losses.

        Args:
            x_split (tuple[Tensor]): The embed features belonging to key
                image.
            ref_x_split (tuple[Tensor]): The embed features belonging to ref
                image.
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of
                all images in a batch, each tensor has shape (num_gt, ).
            ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes
                of all reference images in a batch, each tensor has shape
                (num_gt, ).
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        track_id_targets, track_id_weights = self.get_targets(
            sampling_results, gt_instance_ids, ref_gt_instance_ids)
        assert isinstance(track_id_targets, list)
        assert isinstance(track_id_weights, list)
        assert len(track_id_weights) == len(track_id_targets)

        similarity_logits = self._compute_similarity_logits(
            x_split, ref_x_split)
        assert isinstance(similarity_logits, list)
        # ``zip`` below would silently truncate on a length mismatch.
        assert len(similarity_logits) == len(track_id_targets)

        losses = defaultdict(list)
        for similarity_logit, track_id_target, track_id_weight in zip(
                similarity_logits, track_id_targets, track_id_weights):
            # Pairs with empty logits contribute nothing to the sums but are
            # still counted in the final divisor.
            if similarity_logit.numel() == 0:
                continue

            # Normalise by the number of matched samples, at least 1.
            avg_factor = max(
                torch.sum(track_id_target > 0).float().item(), 1.)
            loss_match = self.loss_match(
                similarity_logit,
                track_id_target,
                track_id_weight,
                avg_factor=avg_factor,
                reduction_override=reduction_override)
            if isinstance(loss_match, dict):
                for key, value in loss_match.items():
                    losses[key].append(value)
            else:
                losses['loss_match'].append(loss_match)

            # Accuracy is measured on weighted (positive) samples only.
            valid_index = track_id_weight > 0
            valid_similarity_logit = similarity_logit[valid_index]
            valid_track_id_target = track_id_target[valid_index]
            if self.custom_activation:
                match_accuracy = self.loss_match.get_accuracy(
                    valid_similarity_logit, valid_track_id_target)
                for key, value in match_accuracy.items():
                    losses[key].append(value)
            else:
                losses['match_accuracy'].append(
                    accuracy(valid_similarity_logit,
                             valid_track_id_target))

        # Average every component over the number of image pairs.
        num_pairs = len(similarity_logits)
        for key, values in losses.items():
            losses[key] = sum(values) / num_pairs
        return losses

    def predict(self, roi_feats: Tensor,
                prev_roi_feats: Tensor) -> List[Tensor]:
        """Perform forward propagation of the tracking head and predict
        tracking results on the features of the upstream network.

        Args:
            roi_feats (Tensor): Feature map of current images rois.
            prev_roi_feats (Tensor): Feature map of previous images rois.

        Returns:
            list[Tensor]: The predicted similarity_logits of each pair of key
            image and reference image.
        """
        # All rois are treated as belonging to a single image pair.
        x_split, ref_x_split = self(roi_feats, prev_roi_feats,
                                    [roi_feats.shape[0]],
                                    [prev_roi_feats.shape[0]])

        similarity_logits = self.predict_by_feat(x_split, ref_x_split)

        return similarity_logits

    def predict_by_feat(self, x_split: Tuple[Tensor, ...],
                        ref_x_split: Tuple[Tensor, ...]) -> List[Tensor]:
        """Get similarity_logits.

        Args:
            x_split (tuple[Tensor]): The embed features belonging to key
                image.
            ref_x_split (tuple[Tensor]): The embed features belonging to ref
                image.

        Returns:
            list[Tensor]: The predicted similarity_logits of each pair of key
            image and reference image.
        """
        return self._compute_similarity_logits(x_split, ref_x_split)
|