# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
import json

from detectron2.utils.events import get_event_storage
from detectron2.config import configurable
from detectron2.structures import ImageList, Instances, Boxes
import detectron2.utils.comm as comm

from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.utils.visualizer import Visualizer, _create_text_labels
from detectron2.data.detection_utils import convert_image_to_rgb

from torch.cuda.amp import autocast
from ..text.text_encoder import build_text_encoder
from ..utils import load_class_freq, get_fed_loss_inds


@META_ARCH_REGISTRY.register()
class CustomRCNN(GeneralizedRCNN):
    '''
    GeneralizedRCNN variant that also supports training with image-level
    labels, captions, and proposal-only annotations.
    '''
    @configurable
    def __init__(
        self,
        with_image_labels=False,
        dataset_loss_weight=[],
        fp16=False,
        sync_caption_batch=False,
        roi_head_name='',
        cap_batch_ratio=4,
        with_caption=False,
        dynamic_classifier=False,
        **kwargs):
        """
        Args:
            with_image_labels: enable training on image-labeled data.
            dataset_loss_weight: per-dataset loss multipliers; empty disables.
            fp16: run the backbone in half precision under autocast.
            sync_caption_batch: all-gather caption features across GPUs.
            roi_head_name: class name of the ROI heads, used to pick the
                call signature in forward().
            cap_batch_ratio: ratio of box-batch size to caption-batch size.
            dynamic_classifier: sample a subset of classes per iteration
                (federated loss); requires 'freq_weight', 'num_classes',
                and 'num_sample_cats' in kwargs.
        """
        self.with_image_labels = with_image_labels
        self.dataset_loss_weight = dataset_loss_weight
        self.fp16 = fp16
        self.with_caption = with_caption
        self.sync_caption_batch = sync_caption_batch
        self.roi_head_name = roi_head_name
        self.cap_batch_ratio = cap_batch_ratio
        self.dynamic_classifier = dynamic_classifier
        self.return_proposal = False
        if self.dynamic_classifier:
            self.freq_weight = kwargs.pop('freq_weight')
            self.num_classes = kwargs.pop('num_classes')
            self.num_sample_cats = kwargs.pop('num_sample_cats')
        super().__init__(**kwargs)
        assert self.proposal_generator is not None
        if self.with_caption:
            assert not self.dynamic_classifier
            self.text_encoder = build_text_encoder(pretrain=True)
            for v in self.text_encoder.parameters():
                v.requires_grad = False

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        ret.update({
            'with_image_labels': cfg.WITH_IMAGE_LABELS,
            'dataset_loss_weight': cfg.MODEL.DATASET_LOSS_WEIGHT,
            'fp16': cfg.FP16,
            'with_caption': cfg.MODEL.WITH_CAPTION,
            'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
            'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
            'roi_head_name': cfg.MODEL.ROI_HEADS.NAME,
            'cap_batch_ratio': cfg.MODEL.CAP_BATCH_RATIO,
        })
        if ret['dynamic_classifier']:
            ret['freq_weight'] = load_class_freq(
                cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
                cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT)
            ret['num_classes'] = cfg.MODEL.ROI_HEADS.NUM_CLASSES
            ret['num_sample_cats'] = cfg.MODEL.NUM_SAMPLE_CATS
        return ret
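
    # The keys above are not part of detectron2's default config; a sketch of
    # the extra entries a setup function would need to register (key names are
    # taken from from_config above; the values shown are illustrative
    # assumptions, not the project's defaults):
    #
    #   cfg.WITH_IMAGE_LABELS = False
    #   cfg.FP16 = False
    #   cfg.MODEL.DATASET_LOSS_WEIGHT = []
    #   cfg.MODEL.WITH_CAPTION = False
    #   cfg.MODEL.SYNC_CAPTION_BATCH = False
    #   cfg.MODEL.DYNAMIC_CLASSIFIER = False
    #   cfg.MODEL.CAP_BATCH_RATIO = 4
    #   cfg.MODEL.NUM_SAMPLE_CATS = 50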

    def inference(
        self,
        batched_inputs: Tuple[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        assert not self.training
        assert detected_instances is None
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(images, features, proposals)
        if do_postprocess:
            assert not torch.jit.is_scripting(), \
                "Scripting is not supported for postprocess."
            return CustomRCNN._postprocess(
                results, batched_inputs, images.image_sizes)
        else:
            return results

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Same as GeneralizedRCNN.forward, but tracks the annotation type
        ('ann_type') of the batch and zeroes the proposal losses when
        training on non-box data (image labels, captions, proposals).
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)

        ann_type = 'box'
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        if self.with_image_labels:
            for inst, x in zip(gt_instances, batched_inputs):
                inst._ann_type = x['ann_type']
                inst._pos_category_ids = x['pos_category_ids']
            ann_types = [x['ann_type'] for x in batched_inputs]
            assert len(set(ann_types)) == 1
            ann_type = ann_types[0]
            if ann_type in ['prop', 'proptag']:
                for t in gt_instances:
                    t.gt_classes *= 0

        if self.fp16:  # TODO (zhouxy): improve
            with autocast():
                features = self.backbone(images.tensor.half())
            features = {k: v.float() for k, v in features.items()}
        else:
            features = self.backbone(images.tensor)

        cls_features, cls_inds, caption_features = None, None, None
        if self.with_caption and 'caption' in ann_type:
            inds = [torch.randint(len(x['captions']), (1,))[0].item() \
                for x in batched_inputs]
            caps = [x['captions'][ind] for ind, x in zip(inds, batched_inputs)]
            caption_features = self.text_encoder(caps).float()
        if self.sync_caption_batch:
            caption_features = self._sync_caption_features(
                caption_features, ann_type, len(batched_inputs))

        if self.dynamic_classifier and ann_type != 'caption':
            cls_inds = self._sample_cls_inds(gt_instances, ann_type)  # inds, inv_inds
            ind_with_bg = cls_inds[0].tolist() + [-1]
            cls_features = self.roi_heads.box_predictor[
                0].cls_score.zs_weight[:, ind_with_bg].permute(1, 0).contiguous()

        classifier_info = cls_features, cls_inds, caption_features
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)

        if self.roi_head_name in ['StandardROIHeads', 'CascadeROIHeads']:
            proposals, detector_losses = self.roi_heads(
                images, features, proposals, gt_instances)
        else:
            proposals, detector_losses = self.roi_heads(
                images, features, proposals, gt_instances,
                ann_type=ann_type, classifier_info=classifier_info)

        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        if self.with_image_labels:
            if ann_type in ['box', 'prop', 'proptag']:
                losses.update(proposal_losses)
            else:  # ignore proposal loss for non-bbox data
                losses.update({k: v * 0 for k, v in proposal_losses.items()})
        else:
            losses.update(proposal_losses)

        if len(self.dataset_loss_weight) > 0:
            dataset_sources = [x['dataset_source'] for x in batched_inputs]
            assert len(set(dataset_sources)) == 1
            dataset_source = dataset_sources[0]
            for k in losses:
                losses[k] *= self.dataset_loss_weight[dataset_source]

        if self.return_proposal:
            return proposals, losses
        else:
            return losses
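
    # For reference, the per-image dict keys consumed by forward() above
    # (a sketch inferred from this file, not an exhaustive schema):
    #   'image':            CHW image tensor (read by preprocess_image)
    #   'instances':        Instances with gt_boxes / gt_classes
    #   'ann_type':         'box' | 'prop' | 'proptag' | caption-style types
    #   'pos_category_ids': image-level positive class ids
    #   'captions':         list of caption strings (with_caption only)
    #   'dataset_source':   int index into dataset_loss_weight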

    def _sync_caption_features(self, caption_features, ann_type, BS):
        has_caption_feature = (caption_features is not None)
        BS = (BS * self.cap_batch_ratio) if (ann_type == 'box') else BS
        rank = torch.full(
            (BS, 1), comm.get_rank(), dtype=torch.float32,
            device=self.device)
        if not has_caption_feature:
            caption_features = rank.new_zeros((BS, 512))
        caption_features = torch.cat([caption_features, rank], dim=1)
        global_caption_features = comm.all_gather(caption_features)
        caption_features = torch.cat(
            [x.to(self.device) for x in global_caption_features], dim=0) \
                if has_caption_feature else None  # (NB) x (D + 1)
        return caption_features
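
    # Every rank joins the all_gather (zero-padding when it has no captions)
    # so distributed training cannot deadlock; the appended rank column lets
    # downstream consumers tell which rows came from which GPU. Shape sketch
    # with illustrative numbers: 2 GPUs, 8 caption images per GPU gives
    # local (8, 512) -> rank-tagged (8, 513) -> gathered (16, 513).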

    def _sample_cls_inds(self, gt_instances, ann_type='box'):
        if ann_type == 'box':
            gt_classes = torch.cat(
                [x.gt_classes for x in gt_instances])
            C = len(self.freq_weight)
            freq_weight = self.freq_weight
        else:
            gt_classes = torch.cat(
                [torch.tensor(
                    x._pos_category_ids,
                    dtype=torch.long, device=x.gt_classes.device) \
                    for x in gt_instances])
            C = self.num_classes
            freq_weight = None
        assert gt_classes.max() < C, '{} {}'.format(gt_classes.max(), C)
        inds = get_fed_loss_inds(
            gt_classes, self.num_sample_cats, C,
            weight=freq_weight)
        cls_id_map = gt_classes.new_full(
            (self.num_classes + 1,), len(inds))
        cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device)
        return inds, cls_id_map
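

# Example usage (a minimal sketch, not part of the original file). It assumes
# a setup helper that registers the extra config keys read in from_config;
# the helper name and config path below are assumptions.
#
#   import torch
#   from detectron2.config import get_cfg
#   from detectron2.modeling import build_model
#
#   cfg = get_cfg()
#   add_custom_config(cfg)                       # hypothetical helper
#   cfg.merge_from_file('configs/custom.yaml')   # hypothetical config path
#   cfg.MODEL.META_ARCHITECTURE = 'CustomRCNN'
#   model = build_model(cfg)   # dispatched through META_ARCH_REGISTRY
#   model.eval()
#   with torch.no_grad():
#       outputs = model([{'image': img, 'height': h, 'width': w}])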