# fasterrcnn-project-demo / rpn / region_proposal_network.py
from typing import Tuple, List, Optional, Union
import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from bbox import BBox
from extension.functional import beta_smooth_l1_loss
from torchvision.ops import nms
class RegionProposalNetwork(nn.Module):
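    """Region Proposal Network (RPN) head of Faster R-CNN.

    A shared 3x3 convolution is followed by two sibling 1x1 convolutions that
    predict, for every anchor at every feature-map position, a pair of
    objectness logits and four box-regression parameters.
    """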
def __init__(self, num_features_out: int, anchor_ratios: List[Tuple[int, int]], anchor_sizes: List[int],
pre_nms_top_n: int, post_nms_top_n: int, anchor_smooth_l1_loss_beta: float):
super().__init__()
self._features = nn.Sequential(
nn.Conv2d(in_channels=num_features_out, out_channels=512, kernel_size=3, padding=1),
nn.ReLU()
)
self._anchor_ratios = anchor_ratios
self._anchor_sizes = anchor_sizes
num_anchor_ratios = len(self._anchor_ratios)
num_anchor_sizes = len(self._anchor_sizes)
num_anchors = num_anchor_ratios * num_anchor_sizes
self._pre_nms_top_n = pre_nms_top_n
self._post_nms_top_n = post_nms_top_n
self._anchor_smooth_l1_loss_beta = anchor_smooth_l1_loss_beta
self._anchor_objectness = nn.Conv2d(in_channels=512, out_channels=num_anchors * 2, kernel_size=1)
self._anchor_transformer = nn.Conv2d(in_channels=512, out_channels=num_anchors * 4, kernel_size=1)
def forward(self, features: Tensor,
anchor_bboxes: Optional[Tensor] = None, gt_bboxes_batch: Optional[Tensor] = None,
image_width: Optional[int]=None, image_height: Optional[int]=None) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]:
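        """Predict per-anchor objectness logits and box-regression parameters.

        In evaluation mode only `(anchor_objectnesses, anchor_transformers)` is
        returned, flattened to `(batch_size, num_positions * num_anchors, 2)` and
        `(..., 4)`. In training mode, anchors are additionally matched against
        `gt_bboxes_batch`, a mini-batch of 256 * batch_size anchors is sampled
        across the images, and the two per-image loss tensors are returned as well.
        """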
batch_size = features.shape[0]
features = self._features(features)
anchor_objectnesses = self._anchor_objectness(features)
anchor_transformers = self._anchor_transformer(features)
anchor_objectnesses = anchor_objectnesses.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
anchor_transformers = anchor_transformers.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)
if not self.training:
return anchor_objectnesses, anchor_transformers
else:
# remove cross-boundary
# NOTE: The length of `inside_indices` is guaranteed to be a multiple of `anchor_bboxes.shape[0]` as each batch in `anchor_bboxes` is the same
inside_indices = BBox.inside(anchor_bboxes, left=0, top=0, right=image_width, bottom=image_height).nonzero().unbind(dim=1)
inside_anchor_bboxes = anchor_bboxes[inside_indices].view(batch_size, -1, anchor_bboxes.shape[2])
inside_anchor_objectnesses = anchor_objectnesses[inside_indices].view(batch_size, -1, anchor_objectnesses.shape[2])
inside_anchor_transformers = anchor_transformers[inside_indices].view(batch_size, -1, anchor_transformers.shape[2])
# find labels for each `anchor_bboxes`
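            # label convention: -1 = ignored, 0 = background (max IoU < 0.3),
            # 1 = foreground (IoU >= 0.7, or the best-matching anchor of some ground-truth box)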
labels = torch.full((batch_size, inside_anchor_bboxes.shape[1]), -1, dtype=torch.long, device=inside_anchor_bboxes.device)
ious = BBox.iou(inside_anchor_bboxes, gt_bboxes_batch)
anchor_max_ious, anchor_assignments = ious.max(dim=2)
gt_max_ious, gt_assignments = ious.max(dim=1)
anchor_additions = ((ious > 0) & (ious == gt_max_ious.unsqueeze(dim=1))).nonzero()[:, :2].unbind(dim=1)
labels[anchor_max_ious < 0.3] = 0
labels[anchor_additions] = 1
labels[anchor_max_ious >= 0.7] = 1
# select 256 x `batch_size` samples
fg_indices = (labels == 1).nonzero()
bg_indices = (labels == 0).nonzero()
            # cap foregrounds at half the mini-batch (up to a 1:1 fg:bg ratio, as in the Faster R-CNN paper) and fill the rest with backgrounds
            fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 128 * batch_size)]]
            bg_indices = bg_indices[torch.randperm(len(bg_indices))[:256 * batch_size - len(fg_indices)]]
selected_indices = torch.cat([fg_indices, bg_indices], dim=0)
selected_indices = selected_indices[torch.randperm(len(selected_indices))].unbind(dim=1)
inside_anchor_bboxes = inside_anchor_bboxes[selected_indices]
gt_bboxes = gt_bboxes_batch[selected_indices[0], anchor_assignments[selected_indices]]
gt_anchor_objectnesses = labels[selected_indices]
gt_anchor_transformers = BBox.calc_transformer(inside_anchor_bboxes, gt_bboxes)
batch_indices = selected_indices[0]
anchor_objectness_losses, anchor_transformer_losses = self.loss(inside_anchor_objectnesses[selected_indices],
inside_anchor_transformers[selected_indices],
gt_anchor_objectnesses,
gt_anchor_transformers,
batch_size, batch_indices)
return anchor_objectnesses, anchor_transformers, anchor_objectness_losses, anchor_transformer_losses
def loss(self, anchor_objectnesses: Tensor, anchor_transformers: Tensor,
gt_anchor_objectnesses: Tensor, gt_anchor_transformers: Tensor,
batch_size: int, batch_indices: Tensor) -> Tuple[Tensor, Tensor]:
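        """Return per-image losses for the sampled anchors.

        For each image in the batch this computes the objectness cross-entropy
        over its sampled anchors and the beta-smooth-L1 regression loss over its
        foreground anchors only, giving two tensors of shape `(batch_size,)`.
        """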
cross_entropies = torch.empty(batch_size, dtype=torch.float, device=anchor_objectnesses.device)
smooth_l1_losses = torch.empty(batch_size, dtype=torch.float, device=anchor_transformers.device)
for batch_index in range(batch_size):
selected_indices = (batch_indices == batch_index).nonzero().view(-1)
cross_entropy = F.cross_entropy(input=anchor_objectnesses[selected_indices],
target=gt_anchor_objectnesses[selected_indices])
fg_indices = gt_anchor_objectnesses[selected_indices].nonzero().view(-1)
smooth_l1_loss = beta_smooth_l1_loss(input=anchor_transformers[selected_indices][fg_indices],
target=gt_anchor_transformers[selected_indices][fg_indices],
beta=self._anchor_smooth_l1_loss_beta)
cross_entropies[batch_index] = cross_entropy
smooth_l1_losses[batch_index] = smooth_l1_loss
return cross_entropies, smooth_l1_losses
def generate_anchors(self, image_width: int, image_height: int, num_x_anchors: int, num_y_anchors: int) -> Tensor:
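        """Generate the full grid of anchor boxes for an image.

        The ratios are treated as (height, width) pairs and every anchor keeps an
        area of roughly `size * size`; e.g. ratio (1, 2) with size 256 yields a box
        of about 181 x 362 (height x width). The result has shape
        `(num_y_anchors * num_x_anchors * num_ratios * num_sizes, 4)`, assuming
        `BBox.from_center_base` converts (center_x, center_y, width, height) boxes
        to the corner-based format used elsewhere in this class.
        """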
center_ys = np.linspace(start=0, stop=image_height, num=num_y_anchors + 2)[1:-1]
center_xs = np.linspace(start=0, stop=image_width, num=num_x_anchors + 2)[1:-1]
ratios = np.array(self._anchor_ratios)
ratios = ratios[:, 0] / ratios[:, 1]
sizes = np.array(self._anchor_sizes)
# NOTE: it's important to let `center_ys` be the major index (i.e., move horizontally and then vertically) for consistency with 2D convolution
# giving the string 'ij' returns a meshgrid with matrix indexing, i.e., with shape (#center_ys, #center_xs, #ratios)
center_ys, center_xs, ratios, sizes = np.meshgrid(center_ys, center_xs, ratios, sizes, indexing='ij')
center_ys = center_ys.reshape(-1)
center_xs = center_xs.reshape(-1)
ratios = ratios.reshape(-1)
sizes = sizes.reshape(-1)
widths = sizes * np.sqrt(1 / ratios)
heights = sizes * np.sqrt(ratios)
center_based_anchor_bboxes = np.stack((center_xs, center_ys, widths, heights), axis=1)
center_based_anchor_bboxes = torch.from_numpy(center_based_anchor_bboxes).float()
anchor_bboxes = BBox.from_center_base(center_based_anchor_bboxes)
return anchor_bboxes
def generate_proposals(self, anchor_bboxes: Tensor, objectnesses: Tensor, transformers: Tensor, image_width: int, image_height: int) -> Tensor:
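        """Turn raw RPN outputs into a padded batch of proposal boxes.

        Predicted transformers are applied to the anchors, the boxes are clipped
        to the image, the top `pre_nms_top_n` per image are kept by objectness
        score, NMS with an IoU threshold of 0.7 keeps at most `post_nms_top_n`
        boxes, and each image's proposals are zero-padded to a common length.
        """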
batch_size = anchor_bboxes.shape[0]
proposal_bboxes = BBox.apply_transformer(anchor_bboxes, transformers)
proposal_bboxes = BBox.clip(proposal_bboxes, left=0, top=0, right=image_width, bottom=image_height)
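        # NOTE: softmax is taken over the anchor dimension here; it preserves the per-image
        # ordering of the foreground logits, so the sort below and NMS (which use the scores
        # only for ranking) behave as if the raw logits were used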
proposal_probs = F.softmax(objectnesses[:, :, 1], dim=-1)
_, sorted_indices = torch.sort(proposal_probs, dim=-1, descending=True)
nms_proposal_bboxes_batch = []
for batch_index in range(batch_size):
sorted_bboxes = proposal_bboxes[batch_index][sorted_indices[batch_index]][:self._pre_nms_top_n]
sorted_probs = proposal_probs[batch_index][sorted_indices[batch_index]][:self._pre_nms_top_n]
threshold = 0.7
kept_indices = nms(sorted_bboxes, sorted_probs, threshold)
nms_bboxes = sorted_bboxes[kept_indices][:self._post_nms_top_n]
nms_proposal_bboxes_batch.append(nms_bboxes)
max_nms_proposal_bboxes_length = max([len(it) for it in nms_proposal_bboxes_batch])
padded_proposal_bboxes = []
for nms_proposal_bboxes in nms_proposal_bboxes_batch:
padded_proposal_bboxes.append(
torch.cat([
nms_proposal_bboxes,
torch.zeros(max_nms_proposal_bboxes_length - len(nms_proposal_bboxes), 4).to(nms_proposal_bboxes)
])
)
padded_proposal_bboxes = torch.stack(padded_proposal_bboxes, dim=0)
return padded_proposal_bboxes
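

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original upload). It exercises only the
# inference path; the hyperparameter values, the feature-map shape, the assumed
# stride-16 backbone, and the way anchors are batched are illustrative
# assumptions, not values taken from this repository's training configuration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    image_width, image_height = 800, 608
    feature_height, feature_width = image_height // 16, image_width // 16  # assumed stride-16 backbone

    rpn = RegionProposalNetwork(num_features_out=512,
                                anchor_ratios=[(1, 1), (1, 2), (2, 1)],
                                anchor_sizes=[128, 256, 512],
                                pre_nms_top_n=12000, post_nms_top_n=2000,
                                anchor_smooth_l1_loss_beta=1.0).eval()

    features = torch.randn(1, 512, feature_height, feature_width)
    anchor_bboxes = rpn.generate_anchors(image_width, image_height,
                                         num_x_anchors=feature_width,
                                         num_y_anchors=feature_height).unsqueeze(0)

    with torch.no_grad():
        objectnesses, transformers = rpn(features)
        proposals = rpn.generate_proposals(anchor_bboxes, objectnesses, transformers,
                                           image_width, image_height)

    # expected: (1, H*W*9, 2), (1, H*W*9, 4), and at most (1, post_nms_top_n, 4)
    print(objectnesses.shape, transformers.shape, proposals.shape)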