|
from typing import Tuple, List, Optional, Union

import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F

from bbox import BBox
from extension.functional import beta_smooth_l1_loss
from torchvision.ops import nms

|
class RegionProposalNetwork(nn.Module):

    def __init__(self, num_features_out: int, anchor_ratios: List[Tuple[int, int]], anchor_sizes: List[int],
                 pre_nms_top_n: int, post_nms_top_n: int, anchor_smooth_l1_loss_beta: float):
        super().__init__()

        # Shared 3x3 convolutional head applied on top of the backbone feature map.
        self._features = nn.Sequential(
            nn.Conv2d(in_channels=num_features_out, out_channels=512, kernel_size=3, padding=1),
            nn.ReLU()
        )

        self._anchor_ratios = anchor_ratios
        self._anchor_sizes = anchor_sizes

        num_anchor_ratios = len(self._anchor_ratios)
        num_anchor_sizes = len(self._anchor_sizes)
        num_anchors = num_anchor_ratios * num_anchor_sizes

        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self._anchor_smooth_l1_loss_beta = anchor_smooth_l1_loss_beta

        # Per spatial location, predict 2 objectness scores (background/foreground)
        # and 4 box-regression parameters for each of the `num_anchors` anchor shapes.
        self._anchor_objectness = nn.Conv2d(in_channels=512, out_channels=num_anchors * 2, kernel_size=1)
        self._anchor_transformer = nn.Conv2d(in_channels=512, out_channels=num_anchors * 4, kernel_size=1)
|
    def forward(self, features: Tensor,
                anchor_bboxes: Optional[Tensor] = None, gt_bboxes_batch: Optional[Tensor] = None,
                image_width: Optional[int] = None, image_height: Optional[int] = None) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]:
        batch_size = features.shape[0]

        features = self._features(features)
        anchor_objectnesses = self._anchor_objectness(features)
        anchor_transformers = self._anchor_transformer(features)

        # Reorder to (batch, H, W, C) before flattening so the anchor axis matches
        # the row-major anchor layout produced by `generate_anchors`.
        anchor_objectnesses = anchor_objectnesses.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
        anchor_transformers = anchor_transformers.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)

        if not self.training:
            return anchor_objectnesses, anchor_transformers
        else:
            # Ignore cross-boundary anchors during training: keep only those lying
            # entirely inside the image. The `view` below assumes every image in
            # the batch retains the same number of inside anchors.
            inside_indices = BBox.inside(anchor_bboxes, left=0, top=0, right=image_width, bottom=image_height).nonzero().unbind(dim=1)
            inside_anchor_bboxes = anchor_bboxes[inside_indices].view(batch_size, -1, anchor_bboxes.shape[2])
            inside_anchor_objectnesses = anchor_objectnesses[inside_indices].view(batch_size, -1, anchor_objectnesses.shape[2])
            inside_anchor_transformers = anchor_transformers[inside_indices].view(batch_size, -1, anchor_transformers.shape[2])

            # Label anchors: 1 = foreground, 0 = background, -1 = ignored.
            # An anchor is foreground if its IoU with some ground truth is >= 0.7
            # or if it is the best match for a ground-truth box; it is background
            # if its best IoU falls below 0.3. The assignment order lets the
            # foreground rules override the background rule on conflicts.
            labels = torch.full((batch_size, inside_anchor_bboxes.shape[1]), -1, dtype=torch.long, device=inside_anchor_bboxes.device)
            ious = BBox.iou(inside_anchor_bboxes, gt_bboxes_batch)
            anchor_max_ious, anchor_assignments = ious.max(dim=2)
            gt_max_ious, gt_assignments = ious.max(dim=1)
            anchor_additions = ((ious > 0) & (ious == gt_max_ious.unsqueeze(dim=1))).nonzero()[:, :2].unbind(dim=1)
            labels[anchor_max_ious < 0.3] = 0
            labels[anchor_additions] = 1
            labels[anchor_max_ious >= 0.7] = 1
|
            # Sample up to 256 anchors per image for the loss, with at most half of
            # them foreground (the 1:1 sampling scheme of the Faster R-CNN paper);
            # the remainder is filled with background anchors.
            fg_indices = (labels == 1).nonzero()
            bg_indices = (labels == 0).nonzero()
            fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 128 * batch_size)]]
            bg_indices = bg_indices[torch.randperm(len(bg_indices))[:256 * batch_size - len(fg_indices)]]
            selected_indices = torch.cat([fg_indices, bg_indices], dim=0)
            selected_indices = selected_indices[torch.randperm(len(selected_indices))].unbind(dim=1)
|
            # Regression targets are computed against each anchor's best-matching
            # ground-truth box; classification targets come from the labels above.
            inside_anchor_bboxes = inside_anchor_bboxes[selected_indices]
            gt_bboxes = gt_bboxes_batch[selected_indices[0], anchor_assignments[selected_indices]]
            gt_anchor_objectnesses = labels[selected_indices]
            gt_anchor_transformers = BBox.calc_transformer(inside_anchor_bboxes, gt_bboxes)
            batch_indices = selected_indices[0]

            anchor_objectness_losses, anchor_transformer_losses = self.loss(inside_anchor_objectnesses[selected_indices],
                                                                            inside_anchor_transformers[selected_indices],
                                                                            gt_anchor_objectnesses,
                                                                            gt_anchor_transformers,
                                                                            batch_size, batch_indices)

            return anchor_objectnesses, anchor_transformers, anchor_objectness_losses, anchor_transformer_losses
|
    def loss(self, anchor_objectnesses: Tensor, anchor_transformers: Tensor,
             gt_anchor_objectnesses: Tensor, gt_anchor_transformers: Tensor,
             batch_size: int, batch_indices: Tensor) -> Tuple[Tensor, Tensor]:
        cross_entropies = torch.empty(batch_size, dtype=torch.float, device=anchor_objectnesses.device)
        smooth_l1_losses = torch.empty(batch_size, dtype=torch.float, device=anchor_transformers.device)

        # Compute one objectness loss and one box-regression loss per image.
        for batch_index in range(batch_size):
            selected_indices = (batch_indices == batch_index).nonzero().view(-1)

            cross_entropy = F.cross_entropy(input=anchor_objectnesses[selected_indices],
                                            target=gt_anchor_objectnesses[selected_indices])

            # The regression loss is computed over foreground anchors only
            # (background anchors carry label 0 and have no box target).
            fg_indices = gt_anchor_objectnesses[selected_indices].nonzero().view(-1)
            smooth_l1_loss = beta_smooth_l1_loss(input=anchor_transformers[selected_indices][fg_indices],
                                                 target=gt_anchor_transformers[selected_indices][fg_indices],
                                                 beta=self._anchor_smooth_l1_loss_beta)

            cross_entropies[batch_index] = cross_entropy
            smooth_l1_losses[batch_index] = smooth_l1_loss

        return cross_entropies, smooth_l1_losses
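
    # Note: the losses are returned per image (shape: (batch_size,)). A typical
    # caller would reduce them to scalars before backpropagation, e.g.
    # `loss = objectness_losses.mean() + transformer_losses.mean()`; that reduction
    # is an assumption about the consumer, not something this module enforces.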
|
|
|
    def generate_anchors(self, image_width: int, image_height: int, num_x_anchors: int, num_y_anchors: int) -> Tensor:
        # Place anchor centers on an evenly spaced grid, excluding the image border.
        center_ys = np.linspace(start=0, stop=image_height, num=num_y_anchors + 2)[1:-1]
        center_xs = np.linspace(start=0, stop=image_width, num=num_x_anchors + 2)[1:-1]
        ratios = np.array(self._anchor_ratios)
        ratios = ratios[:, 0] / ratios[:, 1]
        sizes = np.array(self._anchor_sizes)

        # 'ij' indexing keeps `center_ys` as the major axis, so the flattened
        # anchors run left-to-right, then top-to-bottom, matching the (H, W)
        # row-major order of the feature map after the permute in `forward`.
        center_ys, center_xs, ratios, sizes = np.meshgrid(center_ys, center_xs, ratios, sizes, indexing='ij')

        center_ys = center_ys.reshape(-1)
        center_xs = center_xs.reshape(-1)
        ratios = ratios.reshape(-1)
        sizes = sizes.reshape(-1)

        widths = sizes * np.sqrt(1 / ratios)
        heights = sizes * np.sqrt(ratios)

        center_based_anchor_bboxes = np.stack((center_xs, center_ys, widths, heights), axis=1)
        center_based_anchor_bboxes = torch.from_numpy(center_based_anchor_bboxes).float()
        anchor_bboxes = BBox.from_center_base(center_based_anchor_bboxes)

        return anchor_bboxes
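
    # Worked example of the geometry above: with anchor_ratios=[(1, 2)] and
    # anchor_sizes=[128], ratio = 1 / 2 = 0.5, so width = 128 * sqrt(2) ≈ 181.0
    # and height = 128 * sqrt(0.5) ≈ 90.5. Every anchor satisfies
    # width * height == size ** 2 and height / width == ratio, i.e. the ratio
    # is interpreted as height:width at constant area.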
|
|
|
    def generate_proposals(self, anchor_bboxes: Tensor, objectnesses: Tensor, transformers: Tensor, image_width: int, image_height: int) -> Tensor:
        batch_size = anchor_bboxes.shape[0]

        # Decode anchors into proposals and clip them to the image boundary.
        proposal_bboxes = BBox.apply_transformer(anchor_bboxes, transformers)
        proposal_bboxes = BBox.clip(proposal_bboxes, left=0, top=0, right=image_width, bottom=image_height)
        # Softmax is monotonic, so applying it over the foreground logits preserves
        # the ranking that the sort and NMS below rely on.
        proposal_probs = F.softmax(objectnesses[:, :, 1], dim=-1)

        _, sorted_indices = torch.sort(proposal_probs, dim=-1, descending=True)
        nms_proposal_bboxes_batch = []

        # Per image: keep the top `pre_nms_top_n` proposals, run NMS with an IoU
        # threshold of 0.7, then keep at most `post_nms_top_n` survivors.
        for batch_index in range(batch_size):
            sorted_bboxes = proposal_bboxes[batch_index][sorted_indices[batch_index]][:self._pre_nms_top_n]
            sorted_probs = proposal_probs[batch_index][sorted_indices[batch_index]][:self._pre_nms_top_n]
            threshold = 0.7
            kept_indices = nms(sorted_bboxes, sorted_probs, threshold)
            nms_bboxes = sorted_bboxes[kept_indices][:self._post_nms_top_n]
            nms_proposal_bboxes_batch.append(nms_bboxes)

        # Zero-pad each image's proposals to a common length so they stack into a
        # single (batch_size, N, 4) tensor.
        max_nms_proposal_bboxes_length = max([len(it) for it in nms_proposal_bboxes_batch])
        padded_proposal_bboxes = []

        for nms_proposal_bboxes in nms_proposal_bboxes_batch:
            padded_proposal_bboxes.append(
                torch.cat([
                    nms_proposal_bboxes,
                    torch.zeros(max_nms_proposal_bboxes_length - len(nms_proposal_bboxes), 4).to(nms_proposal_bboxes)
                ])
            )

        padded_proposal_bboxes = torch.stack(padded_proposal_bboxes, dim=0)
        return padded_proposal_bboxes
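

# A minimal smoke test for the inference path. The feature-map and image sizes
# below are illustrative assumptions (roughly a VGG-16 conv5 feature map with a
# stride of 16), not values prescribed by this module, and the batched behavior
# of `BBox.apply_transformer` / `BBox.clip` is assumed from their use above.
if __name__ == '__main__':
    rpn = RegionProposalNetwork(num_features_out=512, anchor_ratios=[(1, 2), (1, 1), (2, 1)],
                                anchor_sizes=[128, 256, 512], pre_nms_top_n=6000,
                                post_nms_top_n=300, anchor_smooth_l1_loss_beta=1.0)
    rpn.eval()

    image_width, image_height = 800, 608
    feature_height, feature_width = image_height // 16, image_width // 16
    features = torch.randn(1, 512, feature_height, feature_width)

    with torch.no_grad():
        objectnesses, transformers = rpn(features)

    anchor_bboxes = rpn.generate_anchors(image_width, image_height,
                                         num_x_anchors=feature_width,
                                         num_y_anchors=feature_height).unsqueeze(dim=0)
    proposals = rpn.generate_proposals(anchor_bboxes, objectnesses, transformers, image_width, image_height)
    print(proposals.shape)  # (1, <= post_nms_top_n, 4)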
|
|