Spaces:
Running
Running
File size: 1,555 Bytes
20239f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
# This file contains the function that computes the landmark centroid coordinates, as tensors, from the network's attention maps.
import torch
def landmark_coordinates(maps, grid_x=None, grid_y=None):
    """
    Compute the centroid (map-weighted mean grid position) of every landmark map.

    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19

    Parameters
    ----------
    maps: torch.Tensor
        Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability
    grid_x: torch.Tensor, optional
        The grid x coordinates. Must broadcast against ``maps``; the grid built
        by this function has shape (1, 1, height, width).
    grid_y: torch.Tensor, optional
        The grid y coordinates, same layout as ``grid_x``.
        If either grid is missing, BOTH grids are rebuilt from ``maps``.

    Returns
    ----------
    loc_x: Tensor
        The centroid x coordinates, shape (batch_size, channels). With the grid
        built here, x is the index along dimension 2 (height) of ``maps``.
    loc_y: Tensor
        The centroid y coordinates, shape (batch_size, channels). With the grid
        built here, y is the index along dimension 3 (width) of ``maps``.
    grid_x: Tensor
        Only returned when the grids were built by this call.
    grid_y: Tensor
        Only returned when the grids were built by this call.

    Notes
    -----
    The per-map normaliser is detached, so gradients reach ``maps`` only
    through the weighted coordinate sums. A map whose values sum to zero
    produces a non-finite centroid (division by zero).
    """
    # The grids are handed back only when this call had to create them, so a
    # caller can cache them and pass them in on subsequent calls.
    return_grid = grid_x is None or grid_y is None
    if return_grid:
        device = maps.device
        # Build the index grids directly on the target device instead of
        # creating them on the CPU and copying them over afterwards.
        grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2], device=device),
                                        torch.arange(maps.shape[3], device=device),
                                        indexing='ij')
        # Add leading batch/channel axes so the grids broadcast against maps.
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous()
        grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous()

    # Total mass of each map; detached so it acts as a constant normaliser.
    map_sums = maps.sum(3).sum(2).detach()

    # Weighted mean of the grid coordinates over the two spatial dimensions.
    loc_x = (grid_x * maps).sum(3).sum(2) / map_sums
    loc_y = (grid_y * maps).sum(3).sum(2) / map_sums

    if return_grid:
        return loc_x, loc_y, grid_x, grid_y
    return loc_x, loc_y
|