File size: 1,555 Bytes
20239f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# This file contains the function that computes the landmark center (centroid) coordinates, as tensors, from the network's attention maps.
import torch


def landmark_coordinates(maps, grid_x=None, grid_y=None):
    """
    Compute the centroid (weighted mean position) of every landmark attention map.
    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19

    Parameters
    ----------
    maps: torch.Tensor
        Attention maps with shape (batch_size, channels, height, width); each
        channel is used as the (not necessarily normalized) weight map of one landmark.
    grid_x: torch.Tensor, optional
        Row-index grid broadcastable against ``maps``, e.g. the (1, 1, height, width)
        grid returned by a previous call. Must be given together with ``grid_y``
        to be reused; if either grid is missing, both are rebuilt.
    grid_y: torch.Tensor, optional
        Column-index grid, same convention as ``grid_x``.

    Returns
    -------
    loc_x: torch.Tensor
        Weighted mean of the row index (dim 2, height) per map,
        shape (batch_size, channels).
    loc_y: torch.Tensor
        Weighted mean of the column index (dim 3, width) per map,
        shape (batch_size, channels).
    grid_x: torch.Tensor
        Only returned when the grids were built here: int64 row-index grid of
        shape (1, 1, height, width) on ``maps.device``.
    grid_y: torch.Tensor
        Only returned when the grids were built here: int64 column-index grid of
        shape (1, 1, height, width) on ``maps.device``.

    Notes
    -----
    The normalizing sum is detached, so gradients flow only through the
    weighted coordinate sums. A map whose entries sum to zero produces
    non-finite values (NaN/inf) for that landmark.
    """
    return_grid = grid_x is None or grid_y is None
    if return_grid:
        height, width = maps.shape[2], maps.shape[3]
        # Build the index grids directly on the target device instead of
        # allocating them on the CPU and copying them over on every call.
        rows = torch.arange(height, device=maps.device)
        cols = torch.arange(width, device=maps.device)
        # 'ij' indexing: grid_x varies along dim 2 (rows), grid_y along dim 3 (columns).
        grid_x, grid_y = torch.meshgrid(rows, cols, indexing='ij')
        # (H, W) -> (1, 1, H, W) so the grids broadcast over batch and channel dims;
        # .contiguous() materializes meshgrid's expanded views so callers can cache them.
        grid_x = grid_x[None, None].contiguous()
        grid_y = grid_y[None, None].contiguous()
    # Detached normalizer: no gradient flows through the total attention mass.
    map_sums = maps.sum(dim=(2, 3)).detach()
    loc_x = (grid_x * maps).sum(dim=(2, 3)) / map_sums
    loc_y = (grid_y * maps).sum(dim=(2, 3)) / map_sums
    if return_grid:
        return loc_x, loc_y, grid_x, grid_y
    return loc_x, loc_y