Spaces:
Build error
Build error
| import logging | |
| import math | |
| from typing import Dict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import tqdm | |
| from torch.utils.data import DataLoader | |
| from saicinpainting.evaluation.utils import move_to_device | |
# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(name=__name__)
class InpaintingEvaluator:
    """Offline evaluator: runs every score over a whole dataset and aggregates the results."""

    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks; batches built from it
            must provide the keys 'image' and 'mask' (and 'inpainted' when evaluate() is called without a model)
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
            which are defined by share of area occluded by mask
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch_size for the dataloader
        :param device: device to use
        :param integral_func: optional callable that receives the results dict and returns one number;
            it is stored under the key (integral_title, 'total') as dict(mean=...)
        :param integral_title: score name used in the key of the integral_func result
        :param clamp_image_range: optional (min, max) pair; batch['image'] is clamped to it before scoring
        """
        self.scores = scores
        self.dataset = dataset
        self.area_grouping = area_grouping
        self.bins = bins
        self.device = torch.device(device)
        # shuffle=False keeps the sample order identical across passes, so the groups computed in
        # _get_bin_edges() line up with the samples seen by each score in evaluate().
        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)
        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

    def _get_bin_edges(self):
        """
        Assign every sample of the dataset to a group by the share of its mask area.

        Bins are half-open [start, end); the last bin also contains an area of exactly 1.

        :return: (groups, interval_names) where groups is a 1D np.ndarray with the bin index of every
            sample in dataset order, and interval_names has one printable name per bin (e.g. '10-20%')
        """
        bin_edges = np.linspace(0, 1, self.bins + 1)

        # Decimal places used when printing the percentages; grows with the number of bins.
        num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
        interval_names = []
        for idx_bin in range(self.bins):
            start_percent = round(100 * bin_edges[idx_bin], num_digits)
            end_percent = round(100 * bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            # integer / bool tensors cannot be averaged with .mean()
            if not mask.is_floating_point():
                mask = mask.float()
            batch_size = mask.shape[0]
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
            bin_indices[bin_indices == self.bins] = self.bins - 1
            groups.append(bin_indices)

        # np.hstack raises on an empty list, which happens for an empty dataset
        groups = np.hstack(groups) if groups else np.zeros(0, dtype=np.int64)
        return groups, interval_names

    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch.
            When None, every batch must already contain the inpainted images under the key 'inpainted'.
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        # `import tqdm` alone does not guarantee that the `tqdm.auto` submodule is loaded,
        # so import the progress bar from it explicitly.
        from tqdm.auto import tqdm as progress_bar

        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups, interval_names = None, None

        for score_name, score in progress_bar(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                score.reset()
                for batch in progress_bar(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so we expected precomputed inpainting results at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    # scores accumulate internally; values are read once via get_value() below
                    score(inpainted_batch, image_batch, mask_batch)
                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results
def ssim_fid100_f1(metrics, fid_scale=100):
    """
    Combine mean SSIM and mean FID into a single F1-like score.

    FID is turned into a "bigger is better" value via max(0, fid_scale - fid) / fid_scale,
    so any FID at or above fid_scale maps to 0. The two values a (SSIM) and b (relative FID)
    are then merged as 2 * a * b / (a + b + 1e-3); the small constant keeps the denominator non-zero.

    :param metrics: results dict with ('ssim', 'total') and ('fid', 'total') entries, each holding a 'mean'
    :param fid_scale: FID value that maps to a relative quality of zero
    :return: the combined score
    """
    ssim_mean = metrics[('ssim', 'total')]['mean']
    fid_mean = metrics[('fid', 'total')]['mean']
    relative_fid = max(0, fid_scale - fid_mean) / fid_scale
    numerator = 2 * ssim_mean * relative_fid
    return numerator / (ssim_mean + relative_fid + 1e-3)
def lpips_fid100_f1(metrics, fid_scale=100):
    """
    Combine mean LPIPS and mean FID into a single F1-like score.

    LPIPS is inverted as 1 - lpips so that bigger is better. FID is mapped to
    max(0, fid_scale - fid) / fid_scale, so any FID at or above fid_scale maps to 0.
    Both are merged as 2 * a * b / (a + b + 1e-3); the small constant keeps the denominator non-zero.

    :param metrics: results dict with ('lpips', 'total') and ('fid', 'total') entries, each holding a 'mean'
    :param fid_scale: FID value that maps to a relative quality of zero
    :return: the combined score
    """
    lpips_mean = metrics[('lpips', 'total')]['mean']
    fid_mean = metrics[('fid', 'total')]['mean']
    inverted_lpips = 1 - lpips_mean  # invert, so bigger is better
    relative_fid = max(0, fid_scale - fid_mean) / fid_scale
    numerator = 2 * inverted_lpips * relative_fid
    return numerator / (inverted_lpips + relative_fid + 1e-3)
class InpaintingEvaluatorOnline(nn.Module):
    """Evaluator that accumulates scores batch by batch; call evaluation_end() to obtain the final metrics."""

    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}; wrapped into an nn.ModuleDict
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param image_key: batch key of the reference images (passed to each score as its second argument)
        :param inpainted_key: batch key of the inpainted images (passed to each score as its first argument)
        :param integral_func: optional callable that receives the results dict and returns one number;
            it is stored under the key (integral_title, 'total') as dict(mean=...)
        :param integral_title: score name used in the key of the integral_func result
        :param clamp_image_range: optional (min, max) pair; the reference images are clamped to it before scoring
        """
        super().__init__()
        LOGGER.info('%s init called', type(self))
        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)

        # Decimal places used when printing the percentages; grows with the number of bins.
        num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
        self.interval_names = []
        for idx_bin in range(self.bins_num):
            start_percent = round(100 * self.bin_edges[idx_bin], num_digits)
            end_percent = round(100 * self.bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        # per-sample bin indices accumulated by forward(); consumed and cleared by evaluation_end()
        self.groups = []
        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range
        LOGGER.info('%s init done', type(self))

    def _get_bins(self, mask_batch):
        """
        Map every mask of the batch to the index of its area group.

        Bins are half-open [start, end) and the last bin also contains an area of exactly 1 - the same
        convention as the offline InpaintingEvaluator, so both evaluators group identical masks identically.

        :return: np.ndarray with one bin index per sample
        """
        batch_size = mask_batch.shape[0]
        # integer / bool tensors cannot be averaged with .mean()
        if not mask_batch.is_floating_point():
            mask_batch = mask_batch.float()
        # reshape (unlike view) also works for non-contiguous masks
        area = mask_batch.reshape(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        bin_indices = np.searchsorted(self.bin_edges, area, side='right') - 1
        return np.clip(bin_indices, 0, self.bins_num - 1)

    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
        :param batch: batch dict with mandatory fields mask, image, inpainted (can be overriden by self.inpainted_key)
        :return: dict {score_name: value returned by the score for this batch}
        """
        result = {}
        with torch.no_grad():
            image_batch = batch[self.image_key]
            mask_batch = batch['mask']
            inpainted_batch = batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))
            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result

    def process_batch(self, batch: Dict[str, torch.Tensor]):
        """Alias for calling the module on a batch (runs forward())."""
        return self(batch)

    def evaluation_end(self, states=None):
        """
        Finalize the accumulated scores, then reset the accumulated groups and every score.

        :param states: optional sequence of mappings {score_name: state}; for each score the list of its states
            is passed to score.get_value(states=...) (presumably states gathered from other workers - confirm)
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        LOGGER.info('%s: evaluation_end called', type(self))
        # Local array: self.groups must stay a list so forward() can keep extending it,
        # even if get_value() below raises.
        groups = np.array(self.groups)
        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info('Getting value of %s', score_name)
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=groups, states=cur_states)
            LOGGER.info('Getting value of %s done', score_name)
            results[(score_name, 'total')] = total_results
            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        LOGGER.info('%s: reset scores', type(self))
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info('%s: reset scores done', type(self))

        LOGGER.info('%s: evaluation_end done', type(self))
        return results