# Copyright (c) OpenMMLab. All rights reserved.
import os

import torch
from scipy.optimize import linear_sum_assignment
from torch import Tensor
import torch.nn.functional as F

from mmdet.registry import MODELS
from mmdet.structures import SampleList, TrackDataSample

from seg.models.detectors import Mask2formerVideo
from seg.models.utils import mask_pool

BACKBONE_BATCH = 50
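# BACKBONE_BATCH caps how many frames are pushed through the backbone in one
# forward pass: in ``predict`` below, longer videos are chunked and the
# extracted features are staged on the CPU to bound GPU memory.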


def video_split(total, tube_size, overlap=0):
    """Split ``total`` frames into consecutive tubes of roughly ``tube_size``
    frames, where neighbouring tubes share ``overlap`` frames.

    Returns a list with the (cumulative) end index of every tube.
    """
    assert tube_size > overlap
    total -= overlap
    tube_size -= overlap

    if total % tube_size == 0:
        splits = total // tube_size
    else:
        splits = (total // tube_size) + 1

    ind_list = []
    for i in range(splits):
        ind_list.append((i + 1) * tube_size)

    # Shrink the last tubes so that the final end index lands exactly on ``total``.
    diff = ind_list[-1] - total
    # Currently only supports diff < splits.
    if diff < splits:
        for i in range(diff):
            ind_list[splits - 1 - i] -= diff - i
    else:
        ind_list[splits - 1] -= diff
        assert ind_list[splits - 1] > 0
        print("Warning: {} / {}".format(total, tube_size))

    for idx in range(len(ind_list)):
        ind_list[idx] += overlap
    return ind_list
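
# Example splits:
#   video_split(20, 7)            -> [7, 14, 20]   # tubes 0:7, 7:14, 14:20
#   video_split(10, 6, overlap=1) -> [6, 10]       # tubes 0:6, 5:10
# The returned values are end indices; the caller re-derives each tube's start
# by stepping the previous end back by ``overlap``.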


def match_from_embeds(tgt_embds, cur_embds):
    """Hungarian-match the current queries to the target queries by cosine
    similarity. Both inputs have shape (batch, num_queries, embed_dim); the
    result is one index permutation per batch element.
    """
    cur_embds = cur_embds / cur_embds.norm(dim=-1, keepdim=True)
    tgt_embds = tgt_embds / tgt_embds.norm(dim=-1, keepdim=True)
    cos_sim = torch.bmm(cur_embds, tgt_embds.transpose(1, 2))

    cost_embd = 1 - cos_sim
    C = 1.0 * cost_embd
    C = C.cpu()

    indices = []
    for i in range(len(cur_embds)):
        indice = linear_sum_assignment(C[i].transpose(0, 1))  # target x current
        indice = indice[1]  # permutation that aligns current to target
        indices.append(indice)
    return indices
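
# Usage sketch, mirroring how ``predict`` consumes the result (illustrative
# variable names): with a single video in the batch, ``indices[0]`` reorders
# the current tube so that query k denotes the same instance as query k of the
# previous tube, e.g.
#   indices = match_from_embeds(prev_queries, cur_queries)
#   cur_queries_aligned = cur_queries[:, indices[0]]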


@MODELS.register_module()
class Mask2formerVideoMinVIS(Mask2formerVideo):
    r"""Mask2Former for video with MinVIS-style clip-wise inference.

    Extends :class:`Mask2formerVideo` (built on `Per-Pixel Classification is
    NOT All You Need for Semantic Segmentation
    <https://arxiv.org/pdf/2107.06278>`_): long videos are split into
    overlapping tubes, each tube is predicted independently, and the instance
    queries are matched across tubes so that track identities stay consistent.
    """
    # 0/1 flags, one per class, marking classes that overlap with the training
    # vocabulary; used by ``open_voc_inference`` when set externally.
    OVERLAPPING = None

    def __init__(self,
                 *args,
                 clip_size=6,
                 clip_size_small=3,
                 whole_clip_thr=0,
                 small_clip_thr=12,
                 overlap=0,
                 **kwargs,
                 ):
        super().__init__(*args, **kwargs)
        self.clip_size = clip_size
        self.clip_size_small = clip_size_small
        self.overlap = overlap
        self.whole_clip_thr = whole_clip_thr
        self.small_clip_thr = small_clip_thr
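
    # How a video of ``num_frames`` frames is handled at inference time
    # (see ``predict``):
    #   num_frames <= whole_clip_thr                   -> whole video in one pass
    #   whole_clip_thr < num_frames <= small_clip_thr  -> tubes of clip_size_small
    #   num_frames > small_clip_thr                    -> tubes of clip_size
    # Consecutive tubes share ``overlap`` frames; the duplicated frames are
    # dropped when the per-tube masks are stitched back together.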

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            'pred_instances' and `pred_panoptic_seg`, and the
            ``pred_instances`` usually contains the following keys.

            - scores (Tensor): Classification scores, has a shape
                (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).

            And the ``pred_panoptic_seg`` contains the following key

            - sem_seg (Tensor): panoptic segmentation mask, has a
                shape (1, h, w).
        """
        assert isinstance(batch_data_samples[0], TrackDataSample)
        bs, num_frames, three, h, w = batch_inputs.shape
        assert three == 3, "Only supporting images with 3 channels."

        if num_frames <= self.whole_clip_thr:
            return super().predict(batch_inputs, batch_data_samples, rescale)

        device = batch_inputs.device
        if num_frames > self.small_clip_thr:
            tube_inds = video_split(num_frames, self.clip_size, self.overlap)
        else:
            tube_inds = video_split(num_frames, self.clip_size_small, self.overlap)
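        # ``tube_inds`` holds the cumulative end index of every tube; the tube
        # loop further below re-derives each start as the previous end minus
        # ``overlap``.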
        if num_frames > BACKBONE_BATCH:
            feat_bins = [[], [], [], []]
            num_clip = num_frames // BACKBONE_BATCH + 1
            step_size = num_frames // num_clip + 1
            for i in range(num_clip):
                start = i * step_size
                end = min(num_frames, (i + 1) * step_size)
                inputs = batch_inputs[:, start:end].reshape(
                    (bs * (end - start), three, h, w))
                _feats = self.extract_feat(inputs)
                assert len(_feats) == 4
                for idx, item in enumerate(_feats):
                    feat_bins[idx].append(item.to('cpu'))
            feats = []
            for item in feat_bins:
                feat = torch.cat(item, dim=0)
                assert feat.size(0) == bs * num_frames, "{} vs {}".format(feat.size(0), bs * num_frames)
                feats.append(feat)
        else:
            x = batch_inputs.reshape((bs * num_frames, three, h, w))
            feats = self.extract_feat(x)
            assert len(feats[0]) == bs * num_frames
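
        # ``feats`` is the 4-scale feature pyramid from the backbone, with each
        # entry shaped (bs * num_frames, C_i, H_i, W_i); for long videos it is
        # kept on the CPU and tube slices are moved back to the GPU on demand.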
        del batch_inputs

        ind_pre = 0
        cls_list = []
        mask_list = []
        query_list = []
        iou_list = []
        flag = False
        for ind in tube_inds:
            tube_feats = [itm[ind_pre:ind].to(device=device) for itm in feats]
            tube_data_samples = [TrackDataSample(video_data_samples=itm[ind_pre:ind])
                                 for itm in batch_data_samples]
            _mask_cls_results, _mask_pred_results, _query_feat, _iou_results = \
                self.panoptic_head.predict(tube_feats, tube_data_samples, return_query=True)
            cls_list.append(_mask_cls_results)
            if not flag:
                mask_list.append(_mask_pred_results.cpu())
                flag = True
            else:
                # Drop the frames that duplicate the tail of the previous tube
                # (the frame dimension is dim 2: bs x queries x frames x H x W).
                mask_list.append(_mask_pred_results[:, :, self.overlap:].cpu())
            query_list.append(_query_feat.cpu())
            iou_list.append(_iou_results)
            ind_pre = ind
            ind_pre -= self.overlap
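
        # Every tube after the first starts ``overlap`` frames before the end of
        # the previous one; since the duplicated frames were dropped above, the
        # concatenation below covers each frame exactly once.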
        num_tubes = len(tube_inds)
        out_cls = [cls_list[0]]
        out_mask = [mask_list[0]]
        out_embed = [query_list[0]]
        ious = [iou_list[0]]
        for i in range(1, num_tubes):
            indices = match_from_embeds(out_embed[-1], query_list[i])
            indices = indices[0]  # since bs == 1
            out_cls.append(cls_list[i][:, indices])
            out_mask.append(mask_list[i][:, indices])
            out_embed.append(query_list[i][:, indices])
            ious.append(iou_list[i][:, indices])
        del mask_list
        del out_embed

        mask_cls_results = sum(out_cls) / num_tubes
        mask_pred_results = torch.cat(out_mask, dim=2)
        iou_results = sum(ious) / num_tubes
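
        # After stitching: ``mask_cls_results`` is (bs, num_queries, num_classes + 1),
        # averaged over tubes; ``mask_pred_results`` is (bs, num_queries,
        # num_frames, H, W), concatenated along the frame dimension (dim=2).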
        if self.OVERLAPPING is not None:
            assert len(self.OVERLAPPING) == self.num_classes
            mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results)
        del feats

        mask_cls_results = mask_cls_results.to(device='cpu')
        iou_results = iou_results.to(device='cpu')
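
        # Per-frame fusion: each frame's mask slice goes through the panoptic
        # fusion head. When the head also returns 'pro_results' (a stack of
        # per-instance masks), the first-frame predictions are Hungarian-matched
        # to the ground-truth instances by IoU, and the resulting id mapping is
        # used to re-label every frame with ground-truth instance ids.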
        id_assigner = [{} for _ in range(bs)]
        for frame_id in range(num_frames):
            results_list_img = self.panoptic_fusion_head.predict(
                mask_cls_results,
                mask_pred_results[:, :, frame_id],
                [batch_data_samples[idx][frame_id] for idx in range(bs)],
                iou_results=iou_results,
                rescale=rescale
            )
            if frame_id == 0 and 'pro_results' in results_list_img[0]:
                for batch_id in range(bs):
                    mask = results_list_img[batch_id]['pro_results'].to(dtype=torch.int32)
                    mask_gt = torch.tensor(batch_data_samples[batch_id][frame_id].gt_instances.masks.masks,
                                           dtype=torch.int32)
                    a, b = mask.flatten(1), mask_gt.flatten(1)
                    intersection = torch.einsum('nc,mc->nm', a, b)
                    union = (a[:, None] + b[None]).clamp(min=0, max=1).sum(-1)
                    iou_cost = intersection / union
                    a_indices, b_indices = linear_sum_assignment(-iou_cost.numpy())
                    for a_ind, b_ind in zip(a_indices, b_indices):
                        id_assigner[batch_id][a_ind] = \
                            batch_data_samples[batch_id][frame_id].gt_instances.instances_ids[b_ind].item()
            if 'pro_results' in results_list_img[0]:
                # Re-render 'pro_results' as a single id map using the ids
                # assigned on the first frame.
                for batch_id in range(bs):
                    h, w = results_list_img[batch_id]['pro_results'].shape[-2:]
                    seg_map = torch.full((h, w), 0, dtype=torch.int32, device='cpu')
                    for ind in id_assigner[batch_id]:
                        seg_map[results_list_img[batch_id]['pro_results'][ind]] = id_assigner[batch_id][ind]
                    results_list_img[batch_id]['pro_results'] = seg_map.cpu().numpy()
            _ = self.add_track_pred_to_datasample(
                [batch_data_samples[idx][frame_id] for idx in range(bs)], results_list_img
            )
        results = batch_data_samples
        return results

    def open_voc_inference(self, feats, mask_cls_results, mask_pred_results):
        """Re-score the per-query class logits with the backbone's CLIP
        features; only called when ``OVERLAPPING`` is set.
        """
        if len(mask_pred_results.shape) == 5:
            batch_size = mask_cls_results.shape[0]
            num_frames = mask_pred_results.shape[2]
            mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1)
        else:
            batch_size = mask_cls_results.shape[0]
            num_frames = 0
        clip_feat = self.backbone.get_clip_feature(feats[-1]).to(device=mask_cls_results.device)
        clip_feat_mask = F.interpolate(
            mask_pred_results,
            size=clip_feat.shape[-2:],
            mode='bilinear',
            align_corners=False
        ).to(device=mask_cls_results.device)
        if num_frames > 0:
            clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
            clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
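
        # ``mask_pool`` pools the CLIP feature map inside each predicted mask,
        # giving one embedding per query, which ``forward_feat`` /
        # ``forward_logit`` below turn into a second set of class logits.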
        instance_feat = mask_pool(clip_feat, clip_feat_mask)
        instance_feat = self.backbone.forward_feat(instance_feat)
        clip_logit = self.panoptic_head.forward_logit(instance_feat)
        clip_logit = clip_logit[..., :-1]
        query_logit = mask_cls_results[..., :-1]

        clip_logit = clip_logit.softmax(-1)
        query_logit = query_logit.softmax(-1)
        overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device)

        valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to(
            torch.float32)[..., None]
        alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking
        beta = torch.ones_like(clip_logit) * self.beta * valid_masking
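
        # Geometric ensemble of the two classifiers: for classes flagged in
        # ``OVERLAPPING`` (seen) the query probability is blended with the CLIP
        # probability using exponent ``alpha``; for the remaining (unseen)
        # classes exponent ``beta`` is used instead. ``valid_masking`` zeroes
        # the exponent (query-only scoring) for queries whose mask is empty at
        # the CLIP feature resolution.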
        cls_logits_seen = (
            (query_logit ** (1 - alpha) * clip_logit ** alpha).log()
            * overlapping_mask
        )
        cls_logits_unseen = (
            (query_logit ** (1 - beta) * clip_logit ** beta).log()
            * (1 - overlapping_mask)
        )
        cls_results = cls_logits_seen + cls_logits_unseen

        is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:]
        mask_cls_results = torch.cat([
            cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1)
        mask_cls_results = torch.log(mask_cls_results + 1e-8)
        return mask_cls_results