sadimanna's picture
Upload 20 files
d6def08
raw
history blame
1.91 kB
from enum import Enum
import torch
from torch import Tensor
from torch.nn import functional as F
# from support.layer.roi_align import ROIAlign
from torchvision.ops import RoIAlign as ROIAlign
class Pooler(object):
    """Extract a fixed-size feature patch for every region proposal.

    Two strategies are available, selected through ``Pooler.Mode``:

    * ``POOLING`` - snap each box to whole feature-map cells, crop, and apply
      adaptive max pooling to the crop.
    * ``ALIGN`` - delegate to ``ROIAlign``.

    Both strategies produce a 14x14 patch per proposal, which is then reduced
    to 7x7 by a final 2x2 max pooling with stride 2.
    """

    class Mode(Enum):
        """Available ROI feature extraction strategies."""
        POOLING = 'pooling'
        ALIGN = 'align'

    # String values of the ``Mode`` members.
    OPTIONS = ['pooling', 'align']

    @staticmethod
    def apply(features: Tensor, proposal_bboxes: Tensor, proposal_batch_indices: Tensor, mode: Mode) -> Tensor:
        """Pool ``features`` over each proposal box.

        Args:
            features: 4-D feature map; the last two dimensions are treated as
                height and width, the first is indexed by batch index.
            proposal_bboxes: one row per proposal; columns 0..3 are read as
                ``x1, y1, x2, y2``. Coordinates are multiplied by 1/16 to map
                them onto the feature map (presumably input-image pixels with
                a stride-16 backbone - TODO confirm against the caller).
            proposal_batch_indices: one entry per proposal selecting which
                element of ``features`` along dim 0 the box refers to.
            mode: a ``Pooler.Mode`` member.

        Returns:
            Tensor with one 7x7 patch per proposal and channel. In ``POOLING``
            mode an empty set of proposals yields an empty
            ``(0, channels, 7, 7)`` tensor.

        Raises:
            ValueError: if ``mode`` is not a ``Pooler.Mode`` member.
        """
        _, num_channels, feature_map_height, feature_map_width = features.shape
        scale = 1 / 16
        output_size = (7 * 2, 7 * 2)

        if mode == Pooler.Mode.POOLING:
            # Move all coordinates to the host in one go rather than issuing
            # four `.item()` calls (each a potential device sync) per box.
            bboxes = proposal_bboxes.detach().tolist()
            batch_indices = proposal_batch_indices.detach().tolist()

            pool = []
            for bbox, batch_index in zip(bboxes, batch_indices):
                start_x = max(min(round(bbox[0] * scale), feature_map_width - 1), 0)    # [0, feature_map_width)
                start_y = max(min(round(bbox[1] * scale), feature_map_height - 1), 0)   # [0, feature_map_height)
                # The end index is exclusive. Forcing it past the start index
                # keeps the crop at least one cell wide/high even for a
                # degenerate box (x2 < x1 or y2 < y1), which would otherwise
                # produce an empty slice and make the pooling call fail.
                end_x = max(min(round(bbox[2] * scale) + 1, feature_map_width), start_x + 1)    # (start_x, feature_map_width]
                end_y = max(min(round(bbox[3] * scale) + 1, feature_map_height), start_y + 1)   # (start_y, feature_map_height]
                roi_feature_map = features[batch_index, :, start_y:end_y, start_x:end_x]
                pool.append(F.adaptive_max_pool2d(input=roi_feature_map, output_size=output_size))

            if not pool:
                # `torch.stack` rejects an empty list; return an empty result
                # that already has the final (post max-pool) spatial size.
                return features.new_empty((0, num_channels, output_size[0] // 2, output_size[1] // 2))
            pool = torch.stack(pool, dim=0)
        elif mode == Pooler.Mode.ALIGN:
            # ROIAlign takes one row per box: the batch index followed by the
            # four box coordinates.
            rois = torch.cat([proposal_batch_indices.view(-1, 1).float(), proposal_bboxes], dim=1)
            pool = ROIAlign(output_size, spatial_scale=scale, sampling_ratio=0)(features, rois)
        else:
            raise ValueError(f'unsupported pooling mode: {mode!r}; expected a Pooler.Mode member')

        # Halve the 14x14 patch to the final 7x7 resolution.
        pool = F.max_pool2d(input=pool, kernel_size=2, stride=2)
        return pool