#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math

import torch
from torch import nn

from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou

from .inference import make_roi_mask_post_processor
from .loss import make_roi_mask_loss_evaluator
from .roi_mask_feature_extractors import make_roi_mask_feature_extractor
from .roi_mask_predictors import make_roi_mask_predictor

def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=has_bias
    )


def conv3x3_bn_relu(in_planes, out_planes, stride=1, has_bias=False):
    """3x3 convolution followed by batch norm and ReLU."""
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
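
# Example: with the default stride=1 and padding=1 the spatial size is
# preserved, e.g. conv3x3(64, 32)(torch.randn(1, 64, 28, 28)) has shape
# (1, 32, 28, 28).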

def keep_only_positive_boxes(boxes, batch_size_per_im):
    """
    Given a set of BoxList containing the `labels` field,
    return a set of BoxList for which `labels > 0`,
    keeping at most `batch_size_per_im` positive boxes per image.

    Arguments:
        boxes (list of BoxList)
        batch_size_per_im (int): maximum number of positive boxes per image
    """
    assert isinstance(boxes, (list, tuple))
    assert isinstance(boxes[0], BoxList)
    assert boxes[0].has_field("labels")
    positive_boxes = []
    positive_inds = []
    for boxes_per_image in boxes:
        labels = boxes_per_image.get_field("labels")
        inds_mask = labels > 0
        inds = inds_mask.nonzero().squeeze(1)
        if len(inds) > batch_size_per_im:
            new_inds = inds[:batch_size_per_im]
            # drop the overflow from the boolean mask as well, so that
            # `positive_inds` stays consistent with `positive_boxes`
            inds_mask[inds[batch_size_per_im:]] = 0
        else:
            new_inds = inds
        positive_boxes.append(boxes_per_image[new_inds])
        positive_inds.append(inds_mask)
    return positive_boxes, positive_inds
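
# Example (hypothetical values): for an image whose proposals carry
# labels = tensor([0, 2, 0, 1]) and batch_size_per_im >= 2, the returned
# BoxList keeps the proposals at indices 1 and 3, and the returned
# boolean mask is tensor([False, True, False, True]).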

# TODO
def project_char_masks_on_boxes(
    segmentation_masks, segmentation_char_masks, proposals, discretization_size
):
    """
    Given segmentation masks and the bounding boxes corresponding
    to the location of the masks in the image, this function
    crops and resizes the masks in the position defined by the
    boxes. This prepares the masks to be fed to the
    loss computation as the targets.

    Arguments:
        segmentation_masks: an instance of SegmentationMask
        segmentation_char_masks: the character-level segmentation masks,
            with the same crop/resize/convert interface
        proposals: an instance of BoxList
        discretization_size: (height, width) of the target masks
    """
    masks = []
    char_masks = []
    char_mask_weights = []
    decoder_targets = []
    word_targets = []
    M_H, M_W = discretization_size[0], discretization_size[1]
    device = proposals.bbox.device
    proposals = proposals.convert("xyxy")
    assert segmentation_masks.size == proposals.size, "{}, {}".format(
        segmentation_masks, proposals
    )
    assert segmentation_char_masks.size == proposals.size, "{}, {}".format(
        segmentation_char_masks, proposals
    )
    # TODO put the proposals on the CPU, as the representation for the
    # masks is not efficient GPU-wise (possibly several small tensors for
    # representing a single instance mask)
    proposals = proposals.bbox.to(torch.device("cpu"))
    for segmentation_mask, segmentation_char_mask, proposal in zip(
        segmentation_masks, segmentation_char_masks, proposals
    ):
        # crop the masks, resize them to the desired resolution and
        # then convert them to the tensor representation,
        # instead of the list representation that was used
        cropped_mask = segmentation_mask.crop(proposal)
        scaled_mask = cropped_mask.resize((M_W, M_H))
        mask = scaled_mask.convert(mode="mask")
        masks.append(mask)
        cropped_char_mask = segmentation_char_mask.crop(proposal)
        scaled_char_mask = cropped_char_mask.resize((M_W, M_H))
        char_mask, char_mask_weight, decoder_target, word_target = scaled_char_mask.convert(
            mode="seq_char_mask"
        )
        char_masks.append(char_mask)
        char_mask_weights.append(char_mask_weight)
        decoder_targets.append(decoder_target)
        word_targets.append(word_target)
    if len(masks) == 0:
        # return five empty tensors, matching the arity of the non-empty case
        return (
            torch.empty(0, dtype=torch.float32, device=device),
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.float32, device=device),
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )
    return (
        torch.stack(masks, dim=0).to(device, dtype=torch.float32),
        torch.stack(char_masks, dim=0).to(device, dtype=torch.long),
        torch.stack(char_mask_weights, dim=0).to(device, dtype=torch.float32),
        torch.stack(decoder_targets, dim=0).to(device, dtype=torch.long),
        torch.stack(word_targets, dim=0).to(device, dtype=torch.long),
    )
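
# The five stacked tensors are, in order: binary word masks (float),
# per-pixel character class maps (long), per-pixel weights for the character
# loss (float), decoder sequence targets (long) and word-level targets
# (long); the exact shapes depend on the `convert` implementations of the
# mask structures.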

class ROIMaskHead(torch.nn.Module):
    def __init__(self, cfg, proposal_matcher, discretization_size):
        super(ROIMaskHead, self).__init__()
        self.proposal_matcher = proposal_matcher
        self.discretization_size = discretization_size
        self.cfg = cfg.clone()
        self.feature_extractor = make_roi_mask_feature_extractor(cfg)
        self.predictor = make_roi_mask_predictor(cfg)
        self.post_processor = make_roi_mask_post_processor(cfg)
        self.loss_evaluator = make_roi_mask_loss_evaluator(cfg)
        mix_option = self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION
        if mix_option == "ATTENTION":
            # spatial attention predicted from the features concatenated
            # with the rendered ground-truth mask (hence the +1 channel)
            self.mask_attention = nn.Sequential(
                conv3x3_bn_relu(cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + 1, 32),
                conv3x3(32, 1),
                # Conv2d(cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + 1, 1, 1, 1, 0),
                nn.Sigmoid(),
            )
            self.mask_attention.apply(self.weights_init)
            # for name, param in self.named_parameters():
            #     if "bias" in name:
            #         nn.init.constant_(param, 0)
            #     elif "weight" in name:
            #         # Caffe2 implementation uses MSRAFill, which in fact
            #         # corresponds to kaiming_normal_ in PyTorch
            #         nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
        elif mix_option == "ATTENTION_DOWN":
            # same idea, but computed at 1/4 resolution and upsampled back
            self.mask_attention = nn.Sequential(
                conv3x3_bn_relu(cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + 1, 32, stride=2),
                conv3x3(32, 1, stride=2),
                nn.Upsample(scale_factor=4, mode="nearest"),
                nn.Sigmoid(),
            )
            self.mask_attention.apply(self.weights_init)
        elif mix_option == "ATTENTION_CHANNEL":
            num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] * 2
            self.channel_attention = nn.Sequential(
                nn.MaxPool2d(2),
                conv3x3_bn_relu(num_channel, num_channel, stride=2),
                conv3x3(num_channel, num_channel, stride=2),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Sigmoid(),
            )
            self.channel_attention.apply(self.weights_init)
        elif mix_option in ("ATTENTION_CHANNEL_SPLIT", "ATTENTION_CHANNEL_SPLIT_BINARY"):
            num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] * 2
            self.channel_attention = nn.Sequential(
                nn.MaxPool2d(2),
                conv3x3_bn_relu(num_channel, int(num_channel / 4), stride=2),
                conv3x3(int(num_channel / 4), 2, stride=2),
                nn.AdaptiveAvgPool2d((1, 1)),
                # nn.Sigmoid()
                nn.Softmax(dim=1),
            )
            self.channel_attention.apply(self.weights_init)
        elif mix_option == "ATTENTION_CHANNEL_2":
            num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] * 2
            self.channel_attention_2 = nn.Sequential(
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Conv2d(
                    num_channel, num_channel, kernel_size=1, stride=1, padding=0
                ),
                nn.Conv2d(
                    num_channel, num_channel, kernel_size=1, stride=1, padding=0
                ),
                nn.Softmax(dim=1),
            )
            self.channel_attention_2.apply(self.weights_init)
        elif mix_option == "ATTENTION_CHANNEL_TANH":
            feature_dim = 128
            num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] * 2
            self.mask_pooler = nn.Sequential(
                nn.MaxPool2d(2),
                conv3x3_bn_relu(num_channel, num_channel, stride=2),
            )
            self.attn = nn.Linear(feature_dim, feature_dim)
            self.v = nn.Parameter(torch.rand(feature_dim))
            stdv = 1.0 / math.sqrt(self.v.size(0))
            self.v.data.normal_(mean=0, std=stdv)
            self.mask_pooler.apply(self.weights_init)
        elif mix_option == "NEW_CAT":
            num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]
            # dilated convolutions enlarge the receptive field without
            # reducing the spatial resolution
            self.enlarge_receptive_field = nn.Sequential(
                nn.Conv2d(
                    2 * num_channel, num_channel, kernel_size=3, stride=1, padding=2, dilation=2
                ),
                nn.Conv2d(
                    num_channel, num_channel, kernel_size=3, stride=1, padding=2, dilation=2
                ),
            )
            self.enlarge_receptive_field.apply(self.weights_init)
        elif mix_option == "NEW_MASK":
            num_channel = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]
            self.new_mask = nn.Sequential(
                nn.Conv2d(
                    2 * num_channel, num_channel, kernel_size=3, stride=1, padding=2, dilation=2
                ),
                nn.Conv2d(
                    num_channel, 32, kernel_size=3, stride=1, padding=2, dilation=2
                ),
                nn.Conv2d(
                    32, 1, kernel_size=3, stride=1, padding=2, dilation=2
                ),
                nn.Sigmoid(),
            )
            self.new_mask.apply(self.weights_init)

    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find("BatchNorm") != -1:
            m.weight.data.fill_(1.0)
            m.bias.data.fill_(1e-4)

    def step_function(self, x):
        # steep sigmoid centered at 0.5: a differentiable approximation of
        # a hard 0/1 step, used to binarize attention weights
        return torch.reciprocal(1 + torch.exp(-50 * (x - 0.5)))
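
    # For instance, step_function(0.6) = 1 / (1 + exp(-5)) ≈ 0.993 and
    # step_function(0.4) ≈ 0.007, so values are pushed towards 0 or 1 while
    # gradients can still flow near the 0.5 boundary.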

    def channel_attention_tanh(self, feature, mask):
        """
        Additive (tanh) attention over channels.

        :param feature:
            pooled RoI features, in shape (B, C, H, W)
        :param mask:
            pooled instance mask, in shape (B, 1, H, W)
        :return:
            channel attention weights in shape (B, C, 1, 1)
        """
        feature = feature.reshape((feature.shape[0], feature.shape[1], -1))  # (B, C, H*W)
        masks = mask.reshape((mask.shape[0], mask.shape[1], -1)).repeat(
            1, feature.shape[1], 1
        )  # (B, C, H*W)
        fuse_feature = torch.cat([feature, masks], 2)  # (B, C, 2*H*W)
        energy = torch.tanh(self.attn(fuse_feature))  # (B, C, 2*H*W), 2*H*W == feature_dim
        energy = energy.transpose(2, 1)  # (B, 2*H*W, C)
        v = self.v.repeat(feature.shape[0], 1).unsqueeze(1)  # (B, 1, 2*H*W)
        energy = torch.bmm(v, energy)  # (B, 1, C)
        energy = energy.squeeze(1)  # (B, C)
        # normalize over channels with softmax, then broadcast to (B, C, 1, 1)
        return nn.functional.softmax(energy, dim=1).unsqueeze(2).unsqueeze(3)
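
    # Note: self.attn is Linear(feature_dim, feature_dim), so this branch
    # assumes the pooled grid satisfies 2*H*W == feature_dim (128 with the
    # default above, i.e. an 8x8 pooled grid if square).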

    def match_targets_to_proposals(self, proposal, target):
        match_quality_matrix = boxlist_iou(target, proposal)
        # match_quality_matrix = boxlist_polygon_iou(target, proposal)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # Mask R-CNN needs the "labels", "masks" and "char_masks" fields
        # for creating the targets
        target = target.copy_with_fields(["labels", "masks", "char_masks"])
        # get the corresponding GT for each proposal
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets
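
    # In maskrcnn_benchmark, Matcher.BELOW_LOW_THRESHOLD is -1 and
    # Matcher.BETWEEN_THRESHOLDS is -2, hence the clamp(min=0) above.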

    def prepare_targets(self, proposals, targets):
        masks = []
        char_masks = []
        char_mask_weights = []
        decoder_targets = []
        word_targets = []
        for proposals_per_image, targets_per_image in zip(proposals, targets):
            matched_targets = self.match_targets_to_proposals(
                proposals_per_image, targets_per_image
            )
            matched_idxs = matched_targets.get_field("matched_idxs")
            labels_per_image = matched_targets.get_field("labels")
            labels_per_image = labels_per_image.to(dtype=torch.int64)
            # this can probably be removed, but is left here for clarity
            # and completeness
            neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[neg_inds] = 0
            # mask scores are only computed on positive samples
            positive_inds = torch.nonzero(labels_per_image > 0).squeeze(1)
            segmentation_masks = matched_targets.get_field("masks")
            segmentation_masks = segmentation_masks[positive_inds]
            char_segmentation_masks = matched_targets.get_field("char_masks")
            char_segmentation_masks = char_segmentation_masks[positive_inds]
            positive_proposals = proposals_per_image[positive_inds]
            (
                masks_per_image,
                char_masks_per_image,
                char_masks_weight_per_image,
                decoder_targets_per_image,
                word_targets_per_image,
            ) = project_char_masks_on_boxes(
                segmentation_masks,
                char_segmentation_masks,
                positive_proposals,
                self.discretization_size,
            )
            masks.append(masks_per_image)
            char_masks.append(char_masks_per_image)
            char_mask_weights.append(char_masks_weight_per_image)
            decoder_targets.append(decoder_targets_per_image)
            word_targets.append(word_targets_per_image)
        return masks, char_masks, char_mask_weights, decoder_targets, word_targets
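
    # Each returned value is a per-image list; decoder_targets and
    # word_targets are concatenated across images in forward() before being
    # passed to the predictor.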

    def feature_mask(self, x, proposals):
        masks = []
        for proposal in proposals:
            segmentation_masks = proposal.get_field("masks")
            boxes = proposal.bbox.to(torch.device("cpu"))
            for segmentation_mask, box in zip(segmentation_masks, boxes):
                cropped_mask = segmentation_mask.crop(box)
                scaled_mask = cropped_mask.resize(
                    (
                        self.cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_W,
                        self.cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_H,
                    )
                )
                mask = scaled_mask.convert(mode="mask")
                masks.append(mask)
        mix_option = self.cfg.MODEL.ROI_MASK_HEAD.MIX_OPTION
        if len(masks) == 0:
            # no ground-truth masks: keep the channel count consistent with
            # what the corresponding branch below would produce
            if mix_option == "CAT":
                x = cat([x, torch.ones((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)], dim=1)
            if mix_option == "MIX" or "ATTENTION_CHANNEL" in mix_option:
                x = cat([x, x], dim=1)
            return x
        masks = torch.stack(masks, dim=0).to(x.device, dtype=torch.float32)
        if mix_option == "CAT":
            x = cat([x, masks.unsqueeze(1)], dim=1)
            return x
        if mix_option == "NEW_CAT":
            cat_x = cat([x, x * masks.unsqueeze(1)], dim=1)
            out_x = self.enlarge_receptive_field(cat_x)
            return out_x
        if mix_option == "NEW_MASK":
            cat_x = cat([x, x * masks.unsqueeze(1)], dim=1)
            new_mask = self.new_mask(cat_x)
            out_x = x * new_mask
            return out_x
        if mix_option == "ATTENTION" or mix_option == "ATTENTION_DOWN":
            x_cat = cat([x, masks.unsqueeze(1)], dim=1)
            attention = self.mask_attention(x_cat)
            x = x * attention
            return x
        if mix_option == "MIX":
            mask_x = x * masks.unsqueeze(1)
            cat_x = cat([x, mask_x], dim=1)
            return cat_x
        if mix_option == "ATTENTION_CHANNEL":
            mask_x = x * masks.unsqueeze(1)
            cat_x = cat([x, mask_x], dim=1)
            channel_attention = self.channel_attention(cat_x)
            attentioned_x = cat_x * channel_attention
            return attentioned_x
        if mix_option == "ATTENTION_CHANNEL_2":
            mask_x = x * masks.unsqueeze(1)
            cat_x = cat([x, mask_x], dim=1)
            channel_attention = self.channel_attention_2(cat_x)
            # print(channel_attention[0, :, 0, 0])
            attentioned_x = cat_x * channel_attention
            return attentioned_x
        if mix_option == "ATTENTION_CHANNEL_SPLIT":
            mask_x = x * masks.unsqueeze(1)
            cat_x = cat([x, mask_x], dim=1)
            channel_attention = self.channel_attention(cat_x)
            # print(channel_attention[0, :, 0, 0])
            attentioned_x = cat(
                [x * channel_attention[:, 0:1, :, :], mask_x * channel_attention[:, 1:, :, :]], dim=1
            )
            return attentioned_x
        if mix_option == "ATTENTION_CHANNEL_SPLIT_BINARY":
            mask_x = x * masks.unsqueeze(1)
            cat_x = cat([x, mask_x], dim=1)
            channel_attention = self.step_function(self.channel_attention(cat_x))
            # print(channel_attention[:, :, 0, 0])
            attentioned_x = cat(
                [x * channel_attention[:, 0:1, :, :], mask_x * channel_attention[:, 1:, :, :]], dim=1
            )
            # attentioned_x = cat([x * channel_attention[:, 1:, :, :], mask_x * channel_attention[:, 0:1, :, :]], dim=1)
            return attentioned_x
        if mix_option == "ATTENTION_CHANNEL_TANH":
            mask_x = x * masks.unsqueeze(1)
            cat_x = cat([x, mask_x], dim=1)
            pooler_x = self.mask_pooler(cat_x)
            pooler_mask = nn.functional.interpolate(
                masks.unsqueeze(1), scale_factor=0.25, mode="bilinear"
            )
            channel_attention = self.channel_attention_tanh(pooler_x, pooler_mask)
            attentioned_x = cat_x * channel_attention
            return attentioned_x
        # default: soft or hard masking of the features
        soft_ratio = self.cfg.MODEL.ROI_MASK_HEAD.SOFT_MASKED_FEATURE_RATIO
        if soft_ratio > 0:
            if soft_ratio < 1.0:
                x = x * (soft_ratio + (1 - soft_ratio) * masks.unsqueeze(1))
            else:
                x = x * (1.0 + soft_ratio * masks.unsqueeze(1))
        else:
            x = x * masks.unsqueeze(1)
        return x
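
    # Output channel count by MIX_OPTION (C = last entry of CONV_LAYERS):
    # 'CAT' -> C + 1; 'MIX' and the 'ATTENTION_CHANNEL*' variants -> 2 * C;
    # all other options (and the soft/hard masking default) -> C.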

    def forward(self, features, proposals, targets=None):
        """
        Arguments:
            features (list[Tensor]): feature maps from possibly several levels
            proposals (list[BoxList]): proposal boxes
            targets (list[BoxList], optional): the ground-truth targets

        Returns:
            x (Tensor): the result of the feature extractor
            proposals (list[BoxList]): during training, the original proposals
                are returned. During testing, the predicted boxlists are
                returned with the `mask` field set
            losses (dict[Tensor]): during training, returns the losses for the
                head. During testing, returns an empty dict
        """
        if self.training:
            # during training, only focus on positive boxes
            all_proposals = proposals
            proposals, positive_inds = keep_only_positive_boxes(
                proposals, self.cfg.MODEL.ROI_MASK_HEAD.MASK_BATCH_SIZE_PER_IM
            )
            if all(len(proposal) == 0 for proposal in proposals):
                return None, None, None
        if self.training and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR:
            x = features
            x = x[torch.cat(positive_inds, dim=0)]
        else:
            x = self.feature_extractor(features, proposals)
        if self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE:
            x = self.feature_mask(x, proposals)
        if self.training:
            (
                mask_targets,
                char_mask_targets,
                char_mask_weights,
                decoder_targets,
                word_targets,
            ) = self.prepare_targets(proposals, targets)
            decoder_targets = cat(decoder_targets, dim=0)
            word_targets = cat(word_targets, dim=0)
            # proposals_not_empty, targets_not = [], []
            # for proposal, target, mask_target, char_mask_target, char_mask_weight in zip(proposals, targets, mask_targets, char_mask_targets, char_mask_weights):
            #     if len(proposal_target[0]) > 0:
            #         proposals_not_empty.append(proposal)
            #         targets_not.append(proposal_target[1])
            # proposals = proposals_not_empty
            # targets = targets_not
        if self.cfg.MODEL.CHAR_MASK_ON:
            if self.cfg.SEQUENCE.SEQ_ON:
                if not self.training:
                    if x.numel() > 0:
                        (
                            mask_logits,
                            char_mask_logits,
                            seq_outputs,
                            seq_scores,
                            detailed_seq_scores,
                        ) = self.predictor(x)
                        result = self.post_processor(
                            mask_logits,
                            char_mask_logits,
                            proposals,
                            seq_outputs=seq_outputs,
                            seq_scores=seq_scores,
                            detailed_seq_scores=detailed_seq_scores,
                        )
                        return x, result, {}
                    else:
                        return None, None, {}
                mask_logits, char_mask_logits, seq_outputs = self.predictor(
                    x, decoder_targets=decoder_targets, word_targets=word_targets
                )
                loss_mask, loss_char_mask = self.loss_evaluator(
                    proposals,
                    mask_logits,
                    char_mask_logits,
                    mask_targets,
                    char_mask_targets,
                    char_mask_weights,
                )
                return (
                    x,
                    all_proposals,
                    dict(
                        loss_mask=loss_mask,
                        loss_char_mask=loss_char_mask,
                        loss_seq=seq_outputs,
                    ),
                )
            else:
                mask_logits, char_mask_logits = self.predictor(x)
                if not self.training:
                    result = self.post_processor(
                        mask_logits, char_mask_logits, proposals
                    )
                    return x, result, {}
                loss_mask, loss_char_mask = self.loss_evaluator(
                    proposals,
                    mask_logits,
                    char_mask_logits,
                    mask_targets,
                    char_mask_targets,
                    char_mask_weights,
                )
                return (
                    x,
                    all_proposals,
                    dict(loss_mask=loss_mask, loss_char_mask=loss_char_mask),
                )
        else:
            if self.cfg.SEQUENCE.SEQ_ON:
                if self.cfg.MODEL.MASK_ON:
                    if not self.training:
                        if x.numel() > 0:
                            (
                                mask_logits,
                                seq_outputs,
                                seq_scores,
                                detailed_seq_scores,
                            ) = self.predictor(x)
                            result = self.post_processor(
                                mask_logits,
                                None,
                                proposals,
                                seq_outputs=seq_outputs,
                                seq_scores=seq_scores,
                                detailed_seq_scores=detailed_seq_scores,
                            )
                            return x, result, {}
                        else:
                            return None, None, {}
                    mask_logits, seq_outputs = self.predictor(
                        x, decoder_targets=decoder_targets, word_targets=word_targets
                    )
                    loss_mask = self.loss_evaluator(
                        proposals,
                        mask_logits,
                        mask_targets,
                    )
                    return (
                        x,
                        all_proposals,
                        dict(
                            loss_mask=loss_mask,
                            loss_seq=seq_outputs,
                        ),
                    )
                else:
                    if not self.training:
                        if x.numel() > 0:
                            _, seq_outputs, seq_scores, detailed_seq_scores = self.predictor(x)
                            result = self.post_processor(
                                None,
                                None,
                                proposals,
                                seq_outputs=seq_outputs,
                                seq_scores=seq_scores,
                                detailed_seq_scores=detailed_seq_scores,
                            )
                            return x, result, {}
                        else:
                            return None, None, {}
                    _, seq_outputs = self.predictor(
                        x, decoder_targets=decoder_targets, word_targets=word_targets
                    )
                    return (
                        x,
                        all_proposals,
                        dict(
                            loss_seq=seq_outputs,
                        ),
                    )
            else:
                mask_logits = self.predictor(x)
                if not self.training:
                    result = self.post_processor(mask_logits, proposals)
                    return x, result, {}
                loss_mask = self.loss_evaluator(proposals, mask_logits, targets)
                return x, all_proposals, dict(loss_mask=loss_mask)
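
    # Note: each configuration branch in forward() assumes a matching
    # predictor signature (see roi_mask_predictors for the exact contracts);
    # at training time the predictor's seq_outputs is already the sequence
    # loss, which is why it is returned directly as loss_seq.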


def build_roi_mask_head(cfg):
    matcher = Matcher(
        cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD,
        cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    return ROIMaskHead(
        cfg,
        matcher,
        (cfg.MODEL.ROI_MASK_HEAD.RESOLUTION_H, cfg.MODEL.ROI_MASK_HEAD.RESOLUTION_W),
    )
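
# Minimal usage sketch (the config path below is hypothetical; the keys come
# from the project's config defaults):
#
#   from maskrcnn_benchmark.config import cfg
#   cfg.merge_from_file("configs/your_config.yaml")  # hypothetical path
#   mask_head = build_roi_mask_head(cfg)
#   # training:  x, proposals, losses = mask_head(features, proposals, targets)
#   # inference: x, results, _ = mask_head(features, proposals)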