import collections
import io
import os
from os.path import join

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import wget
from PIL import Image
from scipy.optimize import linear_sum_assignment
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from torch.utils.tensorboard.summary import hparams
from torchmetrics import Metric
from torchvision import models
from torchvision import transforms as T

try:
    # torch._six was removed in recent PyTorch releases; fall back to the tuple it aliased
    from torch._six import string_classes
except ImportError:
    string_classes = (str, bytes)

torch.multiprocessing.set_sharing_strategy("file_system")

colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = (
    "Buildings",
    "Cultivation",
    "Natural green",
    "Wetland",
    "Water",
    "Infrastructure",
    "Background",
)
bounds = list(np.arange(len(class_names) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

def compute_biodiv_score(image):
    """Compute the biodiversity score of an image.

    Args:
        image (PIL.Image.Image): the landscape image to score

    Returns:
        float: the biodiversity score associated with the landscape in the image
    """
    pix = np.array(image.getdata())
    return np.mean(pix)

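# A minimal usage sketch for compute_biodiv_score (the file name is illustrative;
# as defined above, the score is simply the mean pixel value of the image):
#
#   score = compute_biodiv_score(Image.open("landscape.png"))
#   print(f"biodiversity score: {score:.2f}")
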
def create_video(array_images, output_path="output.mp4"):
    """Write a sequence of same-sized uint8 frames to a video file at 2 fps."""
    height, width, layers = array_images[0].shape
    size = (width, height)
    fourcc = cv2.VideoWriter_fourcc(*"VP90")
    out = cv2.VideoWriter(output_path, fourcc, 2, size)
    for image in array_images:
        out.write(image)
    out.release()
    return out

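# A minimal usage sketch for create_video (assumes `frames` is a list of
# same-sized uint8 BGR arrays; the output file name is illustrative):
#
#   frames = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(10)]
#   create_video(frames, output_path="demo.mp4")
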
def transform_to_pil(outputs, alpha=0.3):
    """Turn a model output into PIL images.

    Args:
        outputs (dict): model outputs containing an "img" batch and "linear_preds" label predictions
        alpha (float, optional): blending factor for the label overlay. Defaults to 0.3.

    Returns:
        tuple: the input image, the colorized label map, and their blend, all as PIL images
    """
    # Transform img with torch
    img = torch.moveaxis(prep_for_plot(outputs["img"][0]), -1, 0)
    img = T.ToPILImage()(img)
    # Transform label by saving it then opening it again
    label = outputs["linear_preds"][0].numpy()
    # image_label = Image.fromarray(label, mode="P")
    plt.imsave("label.png", label, cmap=cmap)
    image_label = Image.open("label.png")
    # Overlay labels on the image with alpha blending
    background = img.convert("RGBA")
    overlay = image_label.convert("RGBA")
    labeled_img = Image.blend(background, overlay, alpha)
    labeled_img = labeled_img.convert("RGB")
    return img, image_label, labeled_img

def prep_for_plot(img, rescale=True, resize=None):
    if resize is not None:
        img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
    else:
        img = img.unsqueeze(0)
    plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0)
    if rescale:
        plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min())
    return plot_img

def add_plot(writer, name, step):
    buf = io.BytesIO()
    plt.savefig(buf, format='jpeg', dpi=100)
    buf.seek(0)
    image = Image.open(buf)
    image = T.ToTensor()(image)
    writer.add_image(name, image, step)
    plt.clf()
    plt.close()

def shuffle(x):
    return x[torch.randperm(x.shape[0])]

def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k, v in metric_dict.items():
        writer.add_scalar(k, v, global_step)

def resize(classes: torch.Tensor, size: int):
    return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)


def one_hot_feats(labels, n_classes):
    return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32)

def load_model(model_type, data_dir):
    if model_type == "robust_resnet50":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'imagenet_l2_3_0.pt')
        if not os.path.exists(model_file):
            wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
                          model_file)
        model_weights = torch.load(model_file)
        model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
                                  'model' in name}
        model.load_state_dict(model_weights_modified)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densecl":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
        if not os.path.exists(model_file):
            wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
                          model_file)
        model_weights = torch.load(model_file)
        model.load_state_dict(model_weights['state_dict'], strict=False)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "resnet50":
        model = models.resnet50(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "mocov2":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
        if not os.path.exists(model_file):
            wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
                          "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
        checkpoint = torch.load(model_file)
        # rename moco pre-trained keys
        state_dict = checkpoint['state_dict']
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densenet121":
        model = models.densenet121(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    elif model_type == "vgg11":
        model = models.vgg11(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    else:
        raise ValueError("No model: {} found".format(model_type))
    model.eval()
    model.cuda()
    return model

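# A minimal usage sketch for load_model (assumes a CUDA device and a writable
# "./data" directory where checkpoints can be downloaded):
#
#   backbone = load_model("resnet50", "./data")
#   with torch.no_grad():
#       feats = backbone(torch.randn(1, 3, 224, 224).cuda())  # shape (1, 2048, 1, 1)
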
class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        for t, m, s in zip(image2, self.mean, self.std):
            t.mul_(s).add_(m)
        return image2


normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

class ToTargetTensor(object):
    def __call__(self, target):
        return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)

def prep_args():
    import sys

    old_args = sys.argv
    new_args = [old_args.pop(0)]
    while len(old_args) > 0:
        arg = old_args.pop(0)
        if len(arg.split("=")) == 2:
            new_args.append(arg)
        elif arg.startswith("--"):
            new_args.append(arg[2:] + "=" + old_args.pop(0))
        else:
            raise ValueError("Unexpected arg style {}".format(arg))
    sys.argv = new_args

def get_transform(res, is_label, crop_type):
    if crop_type == "center":
        cropper = T.CenterCrop(res)
    elif crop_type == "random":
        cropper = T.RandomCrop(res)
    elif crop_type is None:
        cropper = T.Lambda(lambda x: x)
        res = (res, res)
    else:
        raise ValueError("Unknown Cropper {}".format(crop_type))
    if is_label:
        return T.Compose([T.Resize(res, Image.NEAREST),
                          cropper,
                          ToTargetTensor()])
    else:
        return T.Compose([T.Resize(res, Image.NEAREST),
                          cropper,
                          T.ToTensor(),
                          normalize])

def _remove_axes(ax):
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)

class UnsupervisedMetrics(Metric):
    def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
                 dist_sync_on_step=True):
        # call `self.add_state` for every internal state that is needed for the metrics computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.n_classes = n_classes
        self.extra_clusters = extra_clusters
        self.compute_hungarian = compute_hungarian
        self.prefix = prefix
        self.add_state("stats",
                       default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
                       dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        with torch.no_grad():
            actual = target.reshape(-1)
            preds = preds.reshape(-1)
            mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
            actual = actual[mask]
            preds = preds[mask]
            self.stats += torch.bincount(
                (self.n_classes + self.extra_clusters) * actual + preds,
                minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
                .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)

    def map_clusters(self, clusters):
        if self.extra_clusters == 0:
            return torch.tensor(self.assignments[1])[clusters]
        else:
            missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
            cluster_to_class = self.assignments[1]
            for missing_entry in missing:
                if missing_entry == cluster_to_class.shape[0]:
                    cluster_to_class = np.append(cluster_to_class, -1)
                else:
                    cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
            cluster_to_class = torch.tensor(cluster_to_class)
            return cluster_to_class[clusters]

    def compute(self):
        if self.compute_hungarian:
            self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
            # print(self.assignments)
            if self.extra_clusters == 0:
                self.histogram = self.stats[np.argsort(self.assignments[1]), :]
            if self.extra_clusters > 0:
                self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
                histogram = self.stats[self.assignments_t[1], :]
                missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
                new_row = self.stats[missing, :].sum(0, keepdim=True)
                histogram = torch.cat([histogram, new_row], axis=0)
                new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
                self.histogram = torch.cat([histogram, new_col], axis=1)
        else:
            self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
                                torch.arange(self.n_classes).unsqueeze(1))
            self.histogram = self.stats
        tp = torch.diag(self.histogram)
        fp = torch.sum(self.histogram, dim=0) - tp
        fn = torch.sum(self.histogram, dim=1) - tp
        iou = tp / (tp + fp + fn)
        prc = tp / (tp + fn)
        opc = torch.sum(tp) / torch.sum(self.histogram)
        metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
                       self.prefix + "Accuracy": opc.item()}
        return {k: 100 * v for k, v in metric_dict.items()}

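# A minimal usage sketch for UnsupervisedMetrics (shapes and values are illustrative;
# with compute_hungarian=True the predicted clusters are matched to classes before scoring):
#
#   metrics = UnsupervisedMetrics("cluster/", n_classes=7, extra_clusters=0, compute_hungarian=True)
#   preds = torch.randint(0, 7, (4, 128, 128))
#   target = torch.randint(0, 7, (4, 128, 128))
#   metrics.update(preds, target)
#   print(metrics.compute())  # {'cluster/mIoU': ..., 'cluster/Accuracy': ...}
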
def flexible_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        try:
            return torch.stack(batch, 0, out=out)
        except RuntimeError:
            return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return flexible_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: flexible_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [flexible_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

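# A minimal usage sketch for flexible_collate (assumes `dataset` yields dicts of
# tensors; unlike default_collate it falls back to a plain list when tensor
# shapes within a batch differ):
#
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=flexible_collate)
#   batch = next(iter(loader))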