# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
from torch import nn

from detectron2.structures import ImageList

from ..backbone import build_backbone
from ..postprocessing import detector_postprocess, sem_seg_postprocess
from ..proposal_generator import build_proposal_generator
from ..roi_heads import build_roi_heads
from .build import META_ARCH_REGISTRY
from .semantic_seg import build_sem_seg_head

__all__ = ["PanopticFPN"]


@META_ARCH_REGISTRY.register()
class PanopticFPN(nn.Module):
    """
    Implement the paper :paper:`PanopticFPN`.
    """

    def __init__(self, cfg):
        super().__init__()

        self.instance_loss_weight = cfg.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT

        # options when combining instance & semantic outputs
        self.combine_on = cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED
        self.combine_overlap_threshold = cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH
        self.combine_stuff_area_limit = cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT
        self.combine_instances_confidence_threshold = (
            cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH
        )

        self.backbone = build_backbone(cfg)
        self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
        self.roi_heads = build_roi_heads(cfg, self.backbone.output_shape())
        self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape())
        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1))

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "instances": Instances
                * "sem_seg": semantic segmentation ground truth.
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                each dict is the results for one image. The dict contains the following keys:

                * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                * "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
                  See the return value of
                  :func:`combine_semantic_and_instance_outputs` for its format.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        features = self.backbone(images.tensor)

        if "proposals" in batched_inputs[0]:
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        if "sem_seg" in batched_inputs[0]:
            gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs]
            gt_sem_seg = ImageList.from_tensors(
                gt_sem_seg, self.backbone.size_divisibility, self.sem_seg_head.ignore_value
            ).tensor
        else:
            gt_sem_seg = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(features, gt_sem_seg)

        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None
        if self.proposal_generator:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        detector_results, detector_losses = self.roi_heads(
            images, features, proposals, gt_instances
        )

        if self.training:
            losses = {}
            losses.update(sem_seg_losses)
            losses.update({k: v * self.instance_loss_weight for k, v in detector_losses.items()})
            losses.update(proposal_losses)
            return losses

        processed_results = []
        for sem_seg_result, detector_result, input_per_image, image_size in zip(
            sem_seg_results, detector_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
            detector_r = detector_postprocess(detector_result, height, width)

            processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r})

            if self.combine_on:
                panoptic_r = combine_semantic_and_instance_outputs(
                    detector_r,
                    sem_seg_r.argmax(dim=0),
                    self.combine_overlap_threshold,
                    self.combine_stuff_area_limit,
                    self.combine_instances_confidence_threshold,
                )
                processed_results[-1]["panoptic_seg"] = panoptic_r
        return processed_results
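
# A hedged usage sketch (assuming a detectron2-style `cfg` that defines the
# MODEL.PANOPTIC_FPN.* keys read in `__init__`). In inference mode, each output
# dict carries "instances", "sem_seg" and, when COMBINE.ENABLED, "panoptic_seg":
#
#   model = PanopticFPN(cfg).eval()
#   with torch.no_grad():
#       outputs = model([{"image": image_tensor, "height": 480, "width": 640}])
#   panoptic_seg, segments_info = outputs[0]["panoptic_seg"]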


def combine_semantic_and_instance_outputs(
    instance_results,
    semantic_results,
    overlap_threshold,
    stuff_area_limit,
    instances_confidence_threshold,
):
    """
    Implement a simple combining logic following
    "combine_semantic_and_instance_predictions.py" in panopticapi
    to produce panoptic segmentation outputs.

    Args:
        instance_results: output of :func:`detector_postprocess`.
        semantic_results: an (H, W) tensor, where each element is the contiguous
            semantic category id.

    Returns:
        panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
        segments_info (list[dict]): Describe each segment in `panoptic_seg`.
            Each dict contains keys "id", "category_id", "isthing".
    """
    panoptic_seg = torch.zeros_like(semantic_results, dtype=torch.int32)

    # sort instance outputs by scores
    sorted_inds = torch.argsort(-instance_results.scores)

    current_segment_id = 0
    segments_info = []

    instance_masks = instance_results.pred_masks.to(dtype=torch.bool, device=panoptic_seg.device)

    # Add instances one-by-one, check for overlaps with existing ones
    for inst_id in sorted_inds:
        score = instance_results.scores[inst_id].item()
        if score < instances_confidence_threshold:
            break
        mask = instance_masks[inst_id]  # H,W
        mask_area = mask.sum().item()
        if mask_area == 0:
            continue

        intersect = (mask > 0) & (panoptic_seg > 0)
        intersect_area = intersect.sum().item()
        if intersect_area * 1.0 / mask_area > overlap_threshold:
            continue

        if intersect_area > 0:
            mask = mask & (panoptic_seg == 0)

        current_segment_id += 1
        panoptic_seg[mask] = current_segment_id
        segments_info.append(
            {
                "id": current_segment_id,
                "isthing": True,
                "score": score,
                "category_id": instance_results.pred_classes[inst_id].item(),
                "instance_id": inst_id.item(),
            }
        )

    # Add semantic results to remaining empty areas
    semantic_labels = torch.unique(semantic_results).cpu().tolist()
    for semantic_label in semantic_labels:
        if semantic_label == 0:  # 0 is a special "thing" class
            continue
        mask = (semantic_results == semantic_label) & (panoptic_seg == 0)
        mask_area = mask.sum().item()
        if mask_area < stuff_area_limit:
            continue

        current_segment_id += 1
        panoptic_seg[mask] = current_segment_id
        segments_info.append(
            {
                "id": current_segment_id,
                "isthing": False,
                "category_id": semantic_label,
                "area": mask_area,
            }
        )

    return panoptic_seg, segments_info
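

# A minimal, self-contained sketch of the combining heuristic on toy inputs.
# `SimpleNamespace` is a stand-in for detectron2's Instances, holding only the
# fields the function reads (scores, pred_masks, pred_classes). Kept as a
# comment because this module's relative imports prevent running it directly:
#
#   from types import SimpleNamespace
#
#   semantic = torch.zeros(4, 4, dtype=torch.long)
#   semantic[2:, :] = 5  # a "stuff" region labeled with class 5
#   masks = torch.zeros(1, 4, 4, dtype=torch.bool)
#   masks[0, :2, :2] = True  # one high-confidence "thing" mask
#   instances = SimpleNamespace(
#       scores=torch.tensor([0.9]),
#       pred_masks=masks,
#       pred_classes=torch.tensor([3]),
#   )
#   seg, info = combine_semantic_and_instance_outputs(
#       instances, semantic, overlap_threshold=0.5,
#       stuff_area_limit=2, instances_confidence_threshold=0.5,
#   )
#   # seg now holds segment id 1 on the instance mask and id 2 on the stuff
#   # region; `info` describes both segments ("isthing" True/False).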