from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import nms

from bbox import BBox
from extension.functional import beta_smooth_l1_loss


class RegionProposalNetwork(nn.Module):
def __init__(self, num_features_out: int, anchor_ratios: List[Tuple[int, int]], anchor_sizes: List[int],
pre_nms_top_n: int, post_nms_top_n: int, anchor_smooth_l1_loss_beta: float):
super().__init__()
self._features = nn.Sequential(
nn.Conv2d(in_channels=num_features_out, out_channels=512, kernel_size=3, padding=1),
nn.ReLU()
)
self._anchor_ratios = anchor_ratios
self._anchor_sizes = anchor_sizes
num_anchor_ratios = len(self._anchor_ratios)
num_anchor_sizes = len(self._anchor_sizes)
num_anchors = num_anchor_ratios * num_anchor_sizes
self._pre_nms_top_n = pre_nms_top_n
self._post_nms_top_n = post_nms_top_n
self._anchor_smooth_l1_loss_beta = anchor_smooth_l1_loss_beta
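        # two 1x1 convolution heads on top of the shared 3x3 feature map:
        # `_anchor_objectness` predicts 2 class scores (background/foreground) per anchor,
        # `_anchor_transformer` predicts 4 box-regression deltas per anchor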
self._anchor_objectness = nn.Conv2d(in_channels=512, out_channels=num_anchors * 2, kernel_size=1)
self._anchor_transformer = nn.Conv2d(in_channels=512, out_channels=num_anchors * 4, kernel_size=1)

    def forward(self, features: Tensor,
                anchor_bboxes: Optional[Tensor] = None, gt_bboxes_batch: Optional[Tensor] = None,
                image_width: Optional[int] = None, image_height: Optional[int] = None) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]:
batch_size = features.shape[0]
features = self._features(features)
anchor_objectnesses = self._anchor_objectness(features)
anchor_transformers = self._anchor_transformer(features)
anchor_objectnesses = anchor_objectnesses.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
anchor_transformers = anchor_transformers.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)
if not self.training:
return anchor_objectnesses, anchor_transformers
else:
            # remove cross-boundary anchors: only anchors lying entirely inside the image are used for training
# NOTE: The length of `inside_indices` is guaranteed to be a multiple of `anchor_bboxes.shape[0]` as each batch in `anchor_bboxes` is the same
inside_indices = BBox.inside(anchor_bboxes, left=0, top=0, right=image_width, bottom=image_height).nonzero().unbind(dim=1)
inside_anchor_bboxes = anchor_bboxes[inside_indices].view(batch_size, -1, anchor_bboxes.shape[2])
inside_anchor_objectnesses = anchor_objectnesses[inside_indices].view(batch_size, -1, anchor_objectnesses.shape[2])
inside_anchor_transformers = anchor_transformers[inside_indices].view(batch_size, -1, anchor_transformers.shape[2])
            # assign a label to each inside anchor: 1 = foreground, 0 = background, -1 = ignored
            labels = torch.full((batch_size, inside_anchor_bboxes.shape[1]), -1, dtype=torch.long, device=inside_anchor_bboxes.device)
            ious = BBox.iou(inside_anchor_bboxes, gt_bboxes_batch)
            anchor_max_ious, anchor_assignments = ious.max(dim=2)
            gt_max_ious, gt_assignments = ious.max(dim=1)
            # anchors sharing the highest IoU with some ground-truth box are foreground even if that IoU is below 0.7
            anchor_additions = ((ious > 0) & (ious == gt_max_ious.unsqueeze(dim=1))).nonzero()[:, :2].unbind(dim=1)
            labels[anchor_max_ious < 0.3] = 0
            labels[anchor_additions] = 1
            labels[anchor_max_ious >= 0.7] = 1
            # sample 256 x `batch_size` anchors with a foreground:background ratio of at most 1:1
            fg_indices = (labels == 1).nonzero()
            bg_indices = (labels == 0).nonzero()
            fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 128 * batch_size)]]
            bg_indices = bg_indices[torch.randperm(len(bg_indices))[:256 * batch_size - len(fg_indices)]]
selected_indices = torch.cat([fg_indices, bg_indices], dim=0)
selected_indices = selected_indices[torch.randperm(len(selected_indices))].unbind(dim=1)
inside_anchor_bboxes = inside_anchor_bboxes[selected_indices]
gt_bboxes = gt_bboxes_batch[selected_indices[0], anchor_assignments[selected_indices]]
gt_anchor_objectnesses = labels[selected_indices]
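            # regression targets are the transforms that map each sampled anchor onto its assigned ground-truth box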
gt_anchor_transformers = BBox.calc_transformer(inside_anchor_bboxes, gt_bboxes)
batch_indices = selected_indices[0]
anchor_objectness_losses, anchor_transformer_losses = self.loss(inside_anchor_objectnesses[selected_indices],
inside_anchor_transformers[selected_indices],
gt_anchor_objectnesses,
gt_anchor_transformers,
batch_size, batch_indices)
return anchor_objectnesses, anchor_transformers, anchor_objectness_losses, anchor_transformer_losses

    def loss(self, anchor_objectnesses: Tensor, anchor_transformers: Tensor,
             gt_anchor_objectnesses: Tensor, gt_anchor_transformers: Tensor,
             batch_size: int, batch_indices: Tensor) -> Tuple[Tensor, Tensor]:
cross_entropies = torch.empty(batch_size, dtype=torch.float, device=anchor_objectnesses.device)
smooth_l1_losses = torch.empty(batch_size, dtype=torch.float, device=anchor_transformers.device)
for batch_index in range(batch_size):
selected_indices = (batch_indices == batch_index).nonzero().view(-1)
cross_entropy = F.cross_entropy(input=anchor_objectnesses[selected_indices],
target=gt_anchor_objectnesses[selected_indices])
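            # only foreground anchors (label != 0) contribute to the box-regression loss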
fg_indices = gt_anchor_objectnesses[selected_indices].nonzero().view(-1)
smooth_l1_loss = beta_smooth_l1_loss(input=anchor_transformers[selected_indices][fg_indices],
target=gt_anchor_transformers[selected_indices][fg_indices],
beta=self._anchor_smooth_l1_loss_beta)
cross_entropies[batch_index] = cross_entropy
smooth_l1_losses[batch_index] = smooth_l1_loss
return cross_entropies, smooth_l1_losses

    def generate_anchors(self, image_width: int, image_height: int, num_x_anchors: int, num_y_anchors: int) -> Tensor:
center_ys = np.linspace(start=0, stop=image_height, num=num_y_anchors + 2)[1:-1]
center_xs = np.linspace(start=0, stop=image_width, num=num_x_anchors + 2)[1:-1]
ratios = np.array(self._anchor_ratios)
ratios = ratios[:, 0] / ratios[:, 1]
sizes = np.array(self._anchor_sizes)
# NOTE: it's important to let `center_ys` be the major index (i.e., move horizontally and then vertically) for consistency with 2D convolution
        # giving the string 'ij' returns a meshgrid with matrix indexing, i.e., with shape (#center_ys, #center_xs, #ratios, #sizes)
center_ys, center_xs, ratios, sizes = np.meshgrid(center_ys, center_xs, ratios, sizes, indexing='ij')
center_ys = center_ys.reshape(-1)
center_xs = center_xs.reshape(-1)
ratios = ratios.reshape(-1)
sizes = sizes.reshape(-1)
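        # each anchor keeps an area of roughly size^2 while its height:width ratio equals ratio[0] / ratio[1],
        # e.g. ratio (1, 2) with size 128 yields a ~181 x 91 (width x height) anchor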
widths = sizes * np.sqrt(1 / ratios)
heights = sizes * np.sqrt(ratios)
center_based_anchor_bboxes = np.stack((center_xs, center_ys, widths, heights), axis=1)
center_based_anchor_bboxes = torch.from_numpy(center_based_anchor_bboxes).float()
anchor_bboxes = BBox.from_center_base(center_based_anchor_bboxes)
return anchor_bboxes

    def generate_proposals(self, anchor_bboxes: Tensor, objectnesses: Tensor, transformers: Tensor, image_width: int, image_height: int) -> Tensor:
batch_size = anchor_bboxes.shape[0]
proposal_bboxes = BBox.apply_transformer(anchor_bboxes, transformers)
proposal_bboxes = BBox.clip(proposal_bboxes, left=0, top=0, right=image_width, bottom=image_height)
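        # softmax over the foreground logits is monotonic within each image, so it only serves to rank proposals for sorting and NMS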
proposal_probs = F.softmax(objectnesses[:, :, 1], dim=-1)
_, sorted_indices = torch.sort(proposal_probs, dim=-1, descending=True)
nms_proposal_bboxes_batch = []
for batch_index in range(batch_size):
sorted_bboxes = proposal_bboxes[batch_index][sorted_indices[batch_index]][:self._pre_nms_top_n]
sorted_probs = proposal_probs[batch_index][sorted_indices[batch_index]][:self._pre_nms_top_n]
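            # suppress overlapping proposals with an IoU threshold of 0.7, then keep the top-scoring survivors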
threshold = 0.7
kept_indices = nms(sorted_bboxes, sorted_probs, threshold)
nms_bboxes = sorted_bboxes[kept_indices][:self._post_nms_top_n]
nms_proposal_bboxes_batch.append(nms_bboxes)
max_nms_proposal_bboxes_length = max([len(it) for it in nms_proposal_bboxes_batch])
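        # pad each image's proposals with zero boxes so that all images can be stacked into a single batch tensor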
padded_proposal_bboxes = []
for nms_proposal_bboxes in nms_proposal_bboxes_batch:
padded_proposal_bboxes.append(
torch.cat([
nms_proposal_bboxes,
torch.zeros(max_nms_proposal_bboxes_length - len(nms_proposal_bboxes), 4).to(nms_proposal_bboxes)
])
)
padded_proposal_bboxes = torch.stack(padded_proposal_bboxes, dim=0)
return padded_proposal_bboxes
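

# A minimal usage sketch (not part of the original module). It assumes a backbone feature map with
# `num_features_out` channels and a stride of 16, and relies on the BBox helpers imported above;
# the hyper-parameters and shapes below are illustrative, not a training configuration.
if __name__ == '__main__':
    rpn = RegionProposalNetwork(num_features_out=1024,
                                anchor_ratios=[(1, 2), (1, 1), (2, 1)],
                                anchor_sizes=[128, 256, 512],
                                pre_nms_top_n=6000, post_nms_top_n=300,
                                anchor_smooth_l1_loss_beta=1.0)
    rpn.eval()  # evaluation mode: `forward` returns only objectnesses and transformers

    image_width, image_height = 800, 600
    features = torch.randn(1, 1024, image_height // 16, image_width // 16)  # dummy backbone output

    # one anchor set per feature-map location, shared by the (here, single-image) batch
    anchor_bboxes = rpn.generate_anchors(image_width, image_height,
                                         num_x_anchors=features.shape[3],
                                         num_y_anchors=features.shape[2]).unsqueeze(dim=0)

    objectnesses, transformers = rpn(features)
    proposal_bboxes = rpn.generate_proposals(anchor_bboxes, objectnesses, transformers, image_width, image_height)
    print(proposal_bboxes.shape)  # (1, <= post_nms_top_n, 4)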