# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple

try:
    import lap
except ImportError:
    lap = None
import numpy as np
import torch
from mmengine.structures import InstanceData

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import (bbox_cxcyah_to_xyxy, bbox_overlaps,
                                   bbox_xyxy_to_cxcyah)

from .base_tracker import BaseTracker


@MODELS.register_module()
class ByteTracker(BaseTracker):
"""Tracker for ByteTrack.
Args:
motion (dict): Configuration of motion. Defaults to None.
obj_score_thrs (dict): Detection score threshold for matching objects.
- high (float): Threshold of the first matching. Defaults to 0.6.
- low (float): Threshold of the second matching. Defaults to 0.1.
init_track_thr (float): Detection score threshold for initializing a
new tracklet. Defaults to 0.7.
weight_iou_with_det_scores (bool): Whether using detection scores to
weight IOU which is used for matching. Defaults to True.
match_iou_thrs (dict): IOU distance threshold for matching between two
frames.
- high (float): Threshold of the first matching. Defaults to 0.1.
- low (float): Threshold of the second matching. Defaults to 0.5.
- tentative (float): Threshold of the matching for tentative
tracklets. Defaults to 0.3.
num_tentatives (int, optional): Number of continuous frames to confirm
a track. Defaults to 3.
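
    Example:
        A minimal config sketch. The ``KalmanFilter`` motion type follows
        mmdet's ByteTrack configs; check the exact registry keys against
        your mmdet version::

            tracker = dict(
                type='ByteTracker',
                motion=dict(type='KalmanFilter'),
                obj_score_thrs=dict(high=0.6, low=0.1),
                init_track_thr=0.7,
                weight_iou_with_det_scores=True,
                match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3),
                num_tentatives=3)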
"""

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thrs: dict = dict(high=0.6, low=0.1),
                 init_track_thr: float = 0.7,
                 weight_iou_with_det_scores: bool = True,
                 match_iou_thrs: dict = dict(
                     high=0.1, low=0.5, tentative=0.3),
                 num_tentatives: int = 3,
                 **kwargs):
        super().__init__(**kwargs)
        if lap is None:
            raise RuntimeError(
                'lap is not installed, please install it by: pip install lap')
        if motion is not None:
            self.motion = TASK_UTILS.build(motion)

        self.obj_score_thrs = obj_score_thrs
        self.init_track_thr = init_track_thr
        self.weight_iou_with_det_scores = weight_iou_with_det_scores
        self.match_iou_thrs = match_iou_thrs
        self.num_tentatives = num_tentatives

    @property
    def confirmed_ids(self) -> List:
        """Confirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if not track.tentative]
        return ids

    @property
    def unconfirmed_ids(self) -> List:
        """Unconfirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if track.tentative]
        return ids

    def init_track(self, id: int, obj: Tuple[torch.Tensor]) -> None:
        """Initialize a track."""
        super().init_track(id, obj)
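        # A track born on the very first frame is confirmed immediately; any
        # track initialized later starts as tentative until it accumulates
        # `num_tentatives` matched boxes (see `update_track`).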
        if self.tracks[id].frame_ids[-1] == 0:
            self.tracks[id].tentative = False
        else:
            self.tracks[id].tentative = True
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
            bbox)

    def update_track(self, id: int, obj: Tuple[torch.Tensor]) -> None:
        """Update a track."""
        super().update_track(id, obj)
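        # Promote a tentative track to confirmed once it has accumulated
        # `num_tentatives` matched boxes.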
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        track_label = self.tracks[id]['labels'][-1]
        label_idx = self.memo_items.index('labels')
        obj_label = obj[label_idx]
        assert obj_label == track_label
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox)

    def pop_invalid_tracks(self, frame_id: int) -> None:
        """Pop out invalid tracks."""
        invalid_ids = []
        for k, v in self.tracks.items():
            # case1: disappeared frames >= self.num_frames_retain
            case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
            # case2: tentative tracks but not matched in this frame
            case2 = v.tentative and v['frame_ids'][-1] != frame_id
            if case1 or case2:
                invalid_ids.append(k)
        for invalid_id in invalid_ids:
            self.tracks.pop(invalid_id)

    def assign_ids(
            self,
            ids: List[int],
            det_bboxes: torch.Tensor,
            det_labels: torch.Tensor,
            det_scores: torch.Tensor,
            weight_iou_with_det_scores: Optional[bool] = False,
            match_iou_thr: Optional[float] = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Assign ids.

        Args:
            ids (list[int]): Tracking ids.
            det_bboxes (Tensor): of shape (N, 4).
            det_labels (Tensor): of shape (N,).
            det_scores (Tensor): of shape (N,).
            weight_iou_with_det_scores (bool, optional): Whether to weight
                the IoU used for matching with the detection scores.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(np.ndarray, np.ndarray): ``(row, col)``, where ``row[i]``
                is the index of the detection matched to the i-th track and
                ``col[j]`` is the index of the track matched to the j-th
                detection; unmatched entries are -1.
        """
        # get track_bboxes
        track_bboxes = np.zeros((0, 4))
        for id in ids:
            track_bboxes = np.concatenate(
                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
        track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)

        # compute distance
        ious = bbox_overlaps(track_bboxes, det_bboxes)
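        # Optionally scale each IoU column by the detection confidence, so a
        # low-score detection needs a proportionally higher IoU to be matched.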
        if weight_iou_with_det_scores:
            ious *= det_scores

        # support multi-class association
        track_labels = torch.tensor([
            self.tracks[id]['labels'][-1] for id in ids
        ]).to(det_bboxes.device)
        cate_match = det_labels[None, :] == track_labels[:, None]
        # add a prohibitively large cost to cross-category pairs, so that
        # detections and tracks of different categories are never matched
        cate_cost = (1 - cate_match.int()) * 1e6
        dists = (1 - ious + cate_cost).cpu().numpy()

        # bipartite match
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            row = np.zeros(len(ids)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
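        # `cost_limit=1 - match_iou_thr` leaves any pair whose (weighted) IoU
        # falls below `match_iou_thr` unassigned; `lap.lapjv` marks unassigned
        # rows/columns with -1.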
        return row, col

    def track(self, data_sample: DetDataSample, **kwargs) -> InstanceData:
        """Tracking forward function.

        Args:
            data_sample (:obj:`DetDataSample`): The data sample.
                It includes information such as ``pred_instances``.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
                Each InstanceData usually contains ``bboxes``, ``labels``,
                ``scores`` and ``instances_id``.
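
        Example:
            A minimal per-frame usage sketch; ``video_samples`` and the built
            ``tracker`` are assumptions, not defined in this module::

                for frame_id, data_sample in enumerate(video_samples):
                    data_sample.set_metainfo(dict(frame_id=frame_id))
                    pred_track_instances = tracker.track(data_sample)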
"""
metainfo = data_sample.metainfo
bboxes = data_sample.pred_instances.bboxes
labels = data_sample.pred_instances.labels
scores = data_sample.pred_instances.scores
frame_id = metainfo.get('frame_id', -1)
if frame_id == 0:
self.reset()
if not hasattr(self, 'kf'):
self.kf = self.motion
if self.empty or bboxes.size(0) == 0:
valid_inds = scores > self.init_track_thr
scores = scores[valid_inds]
bboxes = bboxes[valid_inds]
labels = labels[valid_inds]
num_new_tracks = bboxes.size(0)
ids = torch.arange(self.num_tracks,
self.num_tracks + num_new_tracks).to(labels)
self.num_tracks += num_new_tracks
else:
# 0. init
ids = torch.full((bboxes.size(0), ),
-1,
dtype=labels.dtype,
device=labels.device)
# get the detection bboxes for the first association
first_det_inds = scores > self.obj_score_thrs['high']
first_det_bboxes = bboxes[first_det_inds]
first_det_labels = labels[first_det_inds]
first_det_scores = scores[first_det_inds]
first_det_ids = ids[first_det_inds]
# get the detection bboxes for the second association
second_det_inds = (~first_det_inds) & (
scores > self.obj_score_thrs['low'])
second_det_bboxes = bboxes[second_det_inds]
second_det_labels = labels[second_det_inds]
second_det_scores = scores[second_det_inds]
second_det_ids = ids[second_det_inds]
# 1. use Kalman Filter to predict current location
for id in self.confirmed_ids:
# track is lost in previous frame
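                # (the Kalman state is [cx, cy, a, h] plus their velocities,
                # so mean[7] is the height velocity; it is zeroed before
                # prediction when the track missed the previous frame)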
                if self.tracks[id].frame_ids[-1] != frame_id - 1:
                    self.tracks[id].mean[7] = 0
                (self.tracks[id].mean,
                 self.tracks[id].covariance) = self.kf.predict(
                     self.tracks[id].mean, self.tracks[id].covariance)

            # 2. first match
            first_match_track_inds, first_match_det_inds = self.assign_ids(
                self.confirmed_ids, first_det_bboxes, first_det_labels,
                first_det_scores, self.weight_iou_with_det_scores,
                self.match_iou_thrs['high'])
            # '-1' means a detection box is not matched with any tracklet
            # from the previous frame
            valid = first_match_det_inds > -1
            first_det_ids[valid] = torch.tensor(
                self.confirmed_ids)[first_match_det_inds[valid]].to(labels)

            first_match_det_bboxes = first_det_bboxes[valid]
            first_match_det_labels = first_det_labels[valid]
            first_match_det_scores = first_det_scores[valid]
            first_match_det_ids = first_det_ids[valid]
            assert (first_match_det_ids > -1).all()

            first_unmatch_det_bboxes = first_det_bboxes[~valid]
            first_unmatch_det_labels = first_det_labels[~valid]
            first_unmatch_det_scores = first_det_scores[~valid]
            first_unmatch_det_ids = first_det_ids[~valid]
            assert (first_unmatch_det_ids == -1).all()

            # 3. use unmatched detection bboxes from the first match to match
            # the unconfirmed tracks
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.assign_ids(
                 self.unconfirmed_ids, first_unmatch_det_bboxes,
                 first_unmatch_det_labels, first_unmatch_det_scores,
                 self.weight_iou_with_det_scores,
                 self.match_iou_thrs['tentative'])
            valid = tentative_match_det_inds > -1
            first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)

            # 4. second match for unmatched tracks from the first match
            first_unmatch_track_ids = []
            for i, id in enumerate(self.confirmed_ids):
                # tracklet is not matched in the first match
                case_1 = first_match_track_inds[i] == -1
                # tracklet is not lost in the previous frame
                case_2 = self.tracks[id].frame_ids[-1] == frame_id - 1
                if case_1 and case_2:
                    first_unmatch_track_ids.append(id)

            second_match_track_inds, second_match_det_inds = self.assign_ids(
                first_unmatch_track_ids, second_det_bboxes, second_det_labels,
                second_det_scores, False, self.match_iou_thrs['low'])
            valid = second_match_det_inds > -1
            second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
                second_match_det_inds[valid]].to(ids)

            # 5. gather all matched detection bboxes from steps 2-4;
            # from the second match we only keep matched detection bboxes,
            # i.e. those whose id != -1
            valid = second_det_ids > -1
            bboxes = torch.cat(
                (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
            bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)
            labels = torch.cat(
                (first_match_det_labels, first_unmatch_det_labels), dim=0)
            labels = torch.cat((labels, second_det_labels[valid]), dim=0)
            scores = torch.cat(
                (first_match_det_scores, first_unmatch_det_scores), dim=0)
            scores = torch.cat((scores, second_det_scores[valid]), dim=0)
            ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
                            dim=0)
            ids = torch.cat((ids, second_det_ids[valid]), dim=0)

            # 6. assign new ids
            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            scores=scores,
            labels=labels,
            frame_ids=frame_id)

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids

        return pred_track_instances