# rpn/region_proposal_network.py
from typing import Tuple, List, Optional, Union
import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from bbox import BBox
from extension.functional import beta_smooth_l1_loss
from torchvision.ops import nms


class RegionProposalNetwork(nn.Module):

    def __init__(self, num_features_out: int, anchor_ratios: List[Tuple[int, int]], anchor_sizes: List[int],
                 pre_nms_top_n: int, post_nms_top_n: int, anchor_smooth_l1_loss_beta: float):
        super().__init__()

        self._features = nn.Sequential(
            nn.Conv2d(in_channels=num_features_out, out_channels=512, kernel_size=3, padding=1),
            nn.ReLU()
        )

        self._anchor_ratios = anchor_ratios
        self._anchor_sizes = anchor_sizes

        num_anchor_ratios = len(self._anchor_ratios)
        num_anchor_sizes = len(self._anchor_sizes)
        num_anchors = num_anchor_ratios * num_anchor_sizes

        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self._anchor_smooth_l1_loss_beta = anchor_smooth_l1_loss_beta
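        # two sibling 1x1 heads on top of the shared 3x3 feature map: one predicts 2 objectness
        # logits (background/foreground) per anchor, the other predicts 4 box-regression offsets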
        self._anchor_objectness = nn.Conv2d(in_channels=512, out_channels=num_anchors * 2, kernel_size=1)
        self._anchor_transformer = nn.Conv2d(in_channels=512, out_channels=num_anchors * 4, kernel_size=1)

    def forward(self, features: Tensor,
                anchor_bboxes: Optional[Tensor] = None, gt_bboxes_batch: Optional[Tensor] = None,
                image_width: Optional[int] = None, image_height: Optional[int] = None) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]:
        batch_size = features.shape[0]

        features = self._features(features)
        anchor_objectnesses = self._anchor_objectness(features)
        anchor_transformers = self._anchor_transformer(features)
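        # reshape the (N, num_anchors * k, H, W) head outputs into (N, H * W * num_anchors, k)
        # so that each row corresponds to a single anchor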
        anchor_objectnesses = anchor_objectnesses.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
        anchor_transformers = anchor_transformers.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)

        if not self.training:
            return anchor_objectnesses, anchor_transformers
        else:
            # remove cross-boundary anchors
            # NOTE: the length of `inside_indices` is guaranteed to be a multiple of `anchor_bboxes.shape[0]` as each batch in `anchor_bboxes` is the same
            inside_indices = BBox.inside(anchor_bboxes, left=0, top=0, right=image_width, bottom=image_height).nonzero().unbind(dim=1)

            inside_anchor_bboxes = anchor_bboxes[inside_indices].view(batch_size, -1, anchor_bboxes.shape[2])
            inside_anchor_objectnesses = anchor_objectnesses[inside_indices].view(batch_size, -1, anchor_objectnesses.shape[2])
            inside_anchor_transformers = anchor_transformers[inside_indices].view(batch_size, -1, anchor_transformers.shape[2])

            # assign a label to each inside anchor: -1 = ignored, 0 = background, 1 = foreground
            labels = torch.full((batch_size, inside_anchor_bboxes.shape[1]), -1, dtype=torch.long, device=inside_anchor_bboxes.device)
            ious = BBox.iou(inside_anchor_bboxes, gt_bboxes_batch)
            anchor_max_ious, anchor_assignments = ious.max(dim=2)
            gt_max_ious, gt_assignments = ious.max(dim=1)
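            # besides the IoU >= 0.7 rule below, the anchors attaining the highest IoU with each
            # ground-truth box are also marked foreground, so every ground truth gets at least one positive anchor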
            anchor_additions = ((ious > 0) & (ious == gt_max_ious.unsqueeze(dim=1))).nonzero()[:, :2].unbind(dim=1)
            labels[anchor_max_ious < 0.3] = 0
            labels[anchor_additions] = 1
            labels[anchor_max_ious >= 0.7] = 1

            # select 256 x `batch_size` samples
            fg_indices = (labels == 1).nonzero()
            bg_indices = (labels == 0).nonzero()
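            # foreground anchors are drawn first (in random order); background anchors fill
            # whatever remains of the 256 x `batch_size` sample budget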
            fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 256 * batch_size)]]
            bg_indices = bg_indices[torch.randperm(len(bg_indices))[:256 * batch_size - len(fg_indices)]]
            selected_indices = torch.cat([fg_indices, bg_indices], dim=0)
            selected_indices = selected_indices[torch.randperm(len(selected_indices))].unbind(dim=1)

            inside_anchor_bboxes = inside_anchor_bboxes[selected_indices]
            gt_bboxes = gt_bboxes_batch[selected_indices[0], anchor_assignments[selected_indices]]
            gt_anchor_objectnesses = labels[selected_indices]
            gt_anchor_transformers = BBox.calc_transformer(inside_anchor_bboxes, gt_bboxes)
            batch_indices = selected_indices[0]

            anchor_objectness_losses, anchor_transformer_losses = self.loss(inside_anchor_objectnesses[selected_indices],
                                                                            inside_anchor_transformers[selected_indices],
                                                                            gt_anchor_objectnesses,
                                                                            gt_anchor_transformers,
                                                                            batch_size, batch_indices)

            return anchor_objectnesses, anchor_transformers, anchor_objectness_losses, anchor_transformer_losses

    def loss(self, anchor_objectnesses: Tensor, anchor_transformers: Tensor,
             gt_anchor_objectnesses: Tensor, gt_anchor_transformers: Tensor,
             batch_size: int, batch_indices: Tensor) -> Tuple[Tensor, Tensor]:
        cross_entropies = torch.empty(batch_size, dtype=torch.float, device=anchor_objectnesses.device)
        smooth_l1_losses = torch.empty(batch_size, dtype=torch.float, device=anchor_transformers.device)
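        # compute both losses separately for each image in the batch: cross-entropy over all
        # sampled anchors, smooth L1 only over the foreground (label == 1) anchors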
        for batch_index in range(batch_size):
            selected_indices = (batch_indices == batch_index).nonzero().view(-1)

            cross_entropy = F.cross_entropy(input=anchor_objectnesses[selected_indices],
                                            target=gt_anchor_objectnesses[selected_indices])

            fg_indices = gt_anchor_objectnesses[selected_indices].nonzero().view(-1)
            smooth_l1_loss = beta_smooth_l1_loss(input=anchor_transformers[selected_indices][fg_indices],
                                                 target=gt_anchor_transformers[selected_indices][fg_indices],
                                                 beta=self._anchor_smooth_l1_loss_beta)

            cross_entropies[batch_index] = cross_entropy
            smooth_l1_losses[batch_index] = smooth_l1_loss

        return cross_entropies, smooth_l1_losses

    def generate_anchors(self, image_width: int, image_height: int, num_x_anchors: int, num_y_anchors: int) -> Tensor:
        # place anchor centers on an evenly spaced grid, excluding the positions on the image border
        center_ys = np.linspace(start=0, stop=image_height, num=num_y_anchors + 2)[1:-1]
        center_xs = np.linspace(start=0, stop=image_width, num=num_x_anchors + 2)[1:-1]
        ratios = np.array(self._anchor_ratios)
        ratios = ratios[:, 0] / ratios[:, 1]
        sizes = np.array(self._anchor_sizes)

        # NOTE: it's important to let `center_ys` be the major index (i.e., move horizontally and then vertically) for consistency with 2D convolution
        # giving the string 'ij' returns a meshgrid with matrix indexing, i.e., with shape (#center_ys, #center_xs, #ratios, #sizes)
        center_ys, center_xs, ratios, sizes = np.meshgrid(center_ys, center_xs, ratios, sizes, indexing='ij')

        center_ys = center_ys.reshape(-1)
        center_xs = center_xs.reshape(-1)
        ratios = ratios.reshape(-1)
        sizes = sizes.reshape(-1)
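        # for a given ratio r and size s: width = s / sqrt(r) and height = s * sqrt(r),
        # so that width * height == s**2 and height / width == r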
        widths = sizes * np.sqrt(1 / ratios)
        heights = sizes * np.sqrt(ratios)

        center_based_anchor_bboxes = np.stack((center_xs, center_ys, widths, heights), axis=1)
        center_based_anchor_bboxes = torch.from_numpy(center_based_anchor_bboxes).float()
        anchor_bboxes = BBox.from_center_base(center_based_anchor_bboxes)

        return anchor_bboxes

    def generate_proposals(self, anchor_bboxes: Tensor, objectnesses: Tensor, transformers: Tensor, image_width: int, image_height: int) -> Tensor:
        batch_size = anchor_bboxes.shape[0]

        proposal_bboxes = BBox.apply_transformer(anchor_bboxes, transformers)
        proposal_bboxes = BBox.clip(proposal_bboxes, left=0, top=0, right=image_width, bottom=image_height)

        # softmax over the (background, foreground) logits gives a foreground probability per anchor
        proposal_probs = F.softmax(objectnesses, dim=-1)[:, :, 1]

        _, sorted_indices = torch.sort(proposal_probs, dim=-1, descending=True)
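        # per image: keep the `pre_nms_top_n` highest-scoring proposals, suppress overlaps with
        # NMS at an IoU threshold of 0.7, then keep at most `post_nms_top_n` survivors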
        nms_proposal_bboxes_batch = []

        for batch_index in range(batch_size):
            sorted_bboxes = proposal_bboxes[batch_index][sorted_indices[batch_index]][:self._pre_nms_top_n]
            sorted_probs = proposal_probs[batch_index][sorted_indices[batch_index]][:self._pre_nms_top_n]
            threshold = 0.7
            kept_indices = nms(sorted_bboxes, sorted_probs, threshold)
            nms_bboxes = sorted_bboxes[kept_indices][:self._post_nms_top_n]
            nms_proposal_bboxes_batch.append(nms_bboxes)

        max_nms_proposal_bboxes_length = max([len(it) for it in nms_proposal_bboxes_batch])
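        # pad every image's proposal set with all-zero boxes up to the batch maximum so the
        # per-image results can be stacked into a single (batch_size, N, 4) tensor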
        padded_proposal_bboxes = []

        for nms_proposal_bboxes in nms_proposal_bboxes_batch:
            padded_proposal_bboxes.append(
                torch.cat([
                    nms_proposal_bboxes,
                    torch.zeros(max_nms_proposal_bboxes_length - len(nms_proposal_bboxes), 4).to(nms_proposal_bboxes)
                ])
            )

        padded_proposal_bboxes = torch.stack(padded_proposal_bboxes, dim=0)
        return padded_proposal_bboxes
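

# A minimal usage sketch (not part of the original module): it exercises the inference path only.
# The 1024-channel backbone output, the stride-16 feature map, the anchor settings and the image
# size below are illustrative assumptions, not values required by this class.
if __name__ == '__main__':
    rpn = RegionProposalNetwork(num_features_out=1024,
                                anchor_ratios=[(1, 2), (1, 1), (2, 1)],
                                anchor_sizes=[128, 256, 512],
                                pre_nms_top_n=6000, post_nms_top_n=300,
                                anchor_smooth_l1_loss_beta=1.0)
    rpn.eval()

    image_width, image_height = 800, 608                                     # assumed input resolution
    features = torch.randn(1, 1024, image_height // 16, image_width // 16)   # fake backbone output

    # eval-mode forward returns raw objectness logits and box offsets for every anchor
    objectnesses, transformers = rpn(features)

    anchor_bboxes = rpn.generate_anchors(image_width, image_height,
                                         num_x_anchors=image_width // 16,
                                         num_y_anchors=image_height // 16).unsqueeze(dim=0)
    proposal_bboxes = rpn.generate_proposals(anchor_bboxes, objectnesses, transformers,
                                             image_width, image_height)
    print(proposal_bboxes.shape)  # (1, <= post_nms_top_n, 4)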