Spaces:
Sleeping
Sleeping
| # This file contains the function to generate the center coordinates as tensor for the current net. | |
| import torch | |
def landmark_coordinates(maps, grid_x=None, grid_y=None):
    """
    Generate the center coordinates as tensor for the current net.

    For every (batch, channel) map this computes the centroid of the map's
    values over the spatial index grid, i.e. ``sum(grid * map) / sum(map)``.

    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19

    Parameters
    ----------
    maps: torch.Tensor
        Attention maps with shape (batch_size, channels, height, width); each
        channel holds the (probability) map of one landmark.
    grid_x: torch.Tensor, optional
        Index grid along dim 2 (height / rows), broadcastable against ``maps``,
        e.g. the ``grid_x`` returned by a previous call.
    grid_y: torch.Tensor, optional
        Index grid along dim 3 (width / columns), broadcastable against ``maps``.
        If either grid is omitted, BOTH are rebuilt here (a single supplied grid
        is ignored) and returned.

    Returns
    -------
    loc_x: torch.Tensor
        Centroid coordinate along dim 2 (height / rows), shape (batch_size, channels).
    loc_y: torch.Tensor
        Centroid coordinate along dim 3 (width / columns), shape (batch_size, channels).
    grid_x: torch.Tensor
        Only returned when the grids were built here: int64 row indices,
        shape (1, 1, height, width), on ``maps.device``.
    grid_y: torch.Tensor
        Only returned when the grids were built here: int64 column indices,
        shape (1, 1, height, width), on ``maps.device``.

    Notes
    -----
    The normalising sum is detached, so gradients flow only through the
    weighted numerator. A map whose values sum to zero divides by zero and
    yields a NaN/inf centroid.
    """
    return_grid = grid_x is None or grid_y is None
    if return_grid:
        # Build the index grids directly on the maps' device instead of
        # allocating on the host and copying (non_blocking buys nothing for
        # pageable host memory).
        rows = torch.arange(maps.shape[2], device=maps.device)
        cols = torch.arange(maps.shape[3], device=maps.device)
        # 'ij' indexing: grid_x[i, j] == i (row), grid_y[i, j] == j (column).
        grid_x, grid_y = torch.meshgrid(rows, cols, indexing='ij')
        # meshgrid returns expanded views; materialise (1, 1, H, W) copies so
        # they broadcast over (batch, channel) and can be cached by callers.
        grid_x = grid_x[None, None].contiguous()
        grid_y = grid_y[None, None].contiguous()

    # Denominator is detached: gradients only flow through the weighted sums.
    map_sums = maps.sum(dim=(2, 3)).detach()
    loc_x = (grid_x * maps).sum(dim=(2, 3)) / map_sums
    loc_y = (grid_y * maps).sum(dim=(2, 3)) / map_sums

    if return_grid:
        return loc_x, loc_y, grid_x, grid_y
    return loc_x, loc_y