Spaces:
Running
Running
| # Copyright 2020 Toyota Research Institute. All rights reserved. | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as transforms | |
| from tqdm import tqdm | |
| from evaluation.descriptor_evaluation import (compute_homography, | |
| compute_matching_score) | |
| from evaluation.detector_evaluation import compute_repeatability | |
def _infer_device(module):
    """Return the device of the module's first parameter, or CUDA if it has none."""
    parameters = getattr(module, 'parameters', None)
    first = next(iter(parameters()), None) if callable(parameters) else None
    # Falling back to CUDA keeps the historical behaviour for parameter-less networks.
    return first.device if first is not None else torch.device('cuda')


def _to_keypoint_array(score, coord):
    """Stack coordinates and score into an (N, 3) array whose last column is the score."""
    # view(3, -1) assumes the concatenation has exactly 3 channels in total
    # (presumably 2 coordinate channels + 1 score channel) for a single image.
    return torch.cat([coord, score], dim=1).view(3, -1).t().cpu().numpy()


def _to_descriptor_array(desc):
    """Flatten a (1, C, Hc, Wc) descriptor map into an (Hc * Wc, C) array."""
    _, channels, height, width = desc.shape
    # The intermediate view deliberately drops the batch axis: it raises unless the
    # batch holds exactly one image, instead of silently mixing several images.
    return desc.view(channels, height, width).view(channels, -1).t().cpu().numpy()


def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), top_k=300,
                          conf_threshold=0.0):
    """Keypoint net evaluation script.

    Runs the network on every image / warped image pair of the loader and averages
    the repeatability, localization error, homography correctness and matching
    score over all samples.

    Parameters
    ----------
    data_loader: torch.utils.data.DataLoader
        Dataset loader. Each sample must provide the keys 'image', 'warped_image'
        and 'homography', and must hold a single image pair (batch size 1): the
        descriptor reshape fails otherwise.
    keypoint_net: torch.nn.module
        Keypoint network. Called on an image batch, it must return a
        (score, coordinates, descriptors) triple.
    output_shape: tuple
        Original image shape. It is passed reversed to the metrics as 'image_shape'
        (presumably given as (width, height) -- confirm against the dataset).
    top_k: int
        Number of keypoints to use to compute metrics, selected based on probability.
    conf_threshold: float
        Keypoints whose score is not strictly greater than this value are discarded
        before the metrics are computed.

    Returns
    -------
    tuple
        Means over all samples, in this order: repeatability, localization error,
        the three correctness values returned by compute_homography (presumably for
        increasing pixel thresholds -- confirm in descriptor_evaluation), and the
        matching score. With an empty loader every entry is NaN (np.mean of an
        empty list).
    """
    # eval() already switches the `training` flag off on the module and its children.
    keypoint_net.eval()
    # Feed the inputs to wherever the network lives instead of assuming CUDA.
    device = _infer_device(keypoint_net)

    localization_err, repeatability = [], []
    correctness1, correctness3, correctness5, MScore = [], [], [], []

    with torch.no_grad():
        # Wrapping the loader itself (no enumerate) lets tqdm display the total length.
        for sample in tqdm(data_loader, desc="Evaluate point model"):
            image = sample['image'].to(device)
            warped_image = sample['warped_image'].to(device)

            score_1, coord_1, desc1 = keypoint_net(image)
            score_2, coord_2, desc2 = keypoint_net(warped_image)

            # Scores & Descriptors
            score_1 = _to_keypoint_array(score_1, coord_1)
            score_2 = _to_keypoint_array(score_2, coord_2)
            desc1 = _to_descriptor_array(desc1)
            desc2 = _to_descriptor_array(desc2)

            # Filter based on confidence threshold (mask computed once per image so
            # keypoints and descriptors stay aligned row by row).
            keep_1 = score_1[:, 2] > conf_threshold
            keep_2 = score_2[:, 2] > conf_threshold
            desc1, score_1 = desc1[keep_1, :], score_1[keep_1, :]
            desc2, score_2 = desc2[keep_2, :], score_2[keep_2, :]

            # Prepare data for eval
            data = {
                'image': sample['image'].numpy().squeeze(),
                'image_shape': output_shape[::-1],
                'warped_image': sample['warped_image'].numpy().squeeze(),
                'homography': sample['homography'].squeeze().numpy(),
                'prob': score_1,
                'warped_prob': score_2,
                'desc': desc1,
                'warped_desc': desc2,
            }

            # Compute repeatability and localization error
            _, _, rep, loc_err = compute_repeatability(data, keep_k_points=top_k,
                                                       distance_thresh=3)
            repeatability.append(rep)
            localization_err.append(loc_err)

            # Compute correctness
            c1, c2, c3 = compute_homography(data, keep_k_points=top_k)
            correctness1.append(c1)
            correctness3.append(c2)
            correctness5.append(c3)

            # Compute matching score
            MScore.append(compute_matching_score(data, keep_k_points=top_k))

    return (np.mean(repeatability), np.mean(localization_err),
            np.mean(correctness1), np.mean(correctness3), np.mean(correctness5),
            np.mean(MScore))