# Biomap — biomap/utils copy.py
# Source: jeremyLE-Ekimetrics (Streamlit demo), revision 9fcd62f
import collections
import os
from os.path import join
import io
import matplotlib.pyplot as plt
import numpy as np
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import wget
import datetime
from dateutil.relativedelta import relativedelta
from PIL import Image
from scipy.optimize import linear_sum_assignment
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from torchmetrics import Metric
from torchvision import models
from torchvision import transforms as T
from torch.utils.tensorboard.summary import hparams
import matplotlib as mpl
from PIL import Image
import matplotlib as mpl
import torch.multiprocessing
import torchvision.transforms as T
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from plotly.subplots import make_subplots
import os
# Allow duplicate OpenMP runtimes — works around the common torch/matplotlib
# "libiomp5 already initialized" crash on macOS/Windows.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Display colour per class, index-aligned with `class_names`.
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
# Class name -> integer id used in the segmentation masks (0 is background).
mapping_class = {
    "Buildings": 1,
    "Cultivation": 2,
    "Natural green": 3,
    "Wetland": 4,
    "Water": 5,
    "Infrastructure": 6,
    "Background": 0,
}
# Class name -> per-pixel biodiversity weight in [0, 1]
# (consumed by compute_biodiv_score).
score_attribution = {
    "Buildings" : 0.,
    "Cultivation": 0.3,
    "Natural green": 1.,
    "Wetland": 0.9,
    "Water": 0.9,
    "Infrastructure": 0.,
    "Background": 0.
}
# Colormap with one bin per class; BoundaryNorm requires
# len(bounds) == cmap.N + 1 (here 8 bounds for 7 colours).
bounds = list(np.arange(len(mapping_class.keys()) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
def compute_biodiv_score(class_image):
    """Compute the biodiversity score of a class-id image.

    Args:
        class_image (np.ndarray): integer array of class ids (the values
            of ``mapping_class``).

    Returns:
        tuple: ``(score, score_details)`` — ``score`` is the mean per-pixel
        biodiversity weight (weights from ``score_attribution``);
        ``score_details`` maps every class name to its pixel count.
    """
    # Look up each pixel's weight against the ORIGINAL class ids.  The
    # previous implementation rewrote a single array in place with chained
    # np.where calls, which only stayed correct because no attribution
    # weight happened to collide with a not-yet-processed class id.
    score_matrice = np.zeros(class_image.shape, dtype=float)
    for name, class_id in mapping_class.items():
        score_matrice = np.where(class_image == class_id, score_attribution[name], score_matrice)
    score = np.sum(score_matrice) / score_matrice.size
    # Pixel count per class.  A previous `if key not in ["background"]`
    # filter was dead code (the key is spelled "Background") and has been
    # dropped: plot_imgs_labels expects counts for every class, Background
    # included.
    score_details = {
        name: np.sum(np.where(class_image == class_id, 1, 0))
        for name, class_id in mapping_class.items()
    }
    return score, score_details
def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):
    """Build an animated 4-panel Plotly figure for a time series of landscapes.

    Panels: source image, labeled segmentation, class-repartition pie and
    the running biodiversity-score curve.  One animation frame per month.

    Args:
        months: list of "YYYY-MM-DD" strings, one per frame.
        imgs: list of images (arrays of identical shape).
        imgs_label: list of label visualisations, same length as ``imgs``.
        nb_values: per-frame dicts of pixel counts keyed by the names in
            ``mapping_class``.
        scores: per-frame biodiversity scores (drives the frame count).

    Returns:
        plotly.graph_objects.Figure with a Play button and a month slider.
    """
    # The previous version overwrote `scores` with hard-coded debug values
    # and carried ~115 lines of dead commented-out code; both removed.
    n_frames = len(scores)
    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "image"}, {"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
        subplot_titles=("Localisation visualization", "Labeled visualisation",
                        "Segments repartition", "Biodiversity scores"),
    )
    # px.imshow is only used to generate one image trace per frame.
    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
    pie_charts = [
        go.Pie(
            labels=class_names,
            values=[nb_values[k][key] for key in mapping_class.keys()],
            marker_colors=colors,
            name="Segment repartition",
            textposition='inside',
            texttemplate="%{percent:.0%}",
            textfont_size=14,
        )
        for k in range(n_frames)
    ]
    # Cumulative score curve: frame k shows the first k + 1 points.
    scatters = [
        go.Scatter(
            x=months[:i + 1],
            y=scores[:i + 1],
            mode="lines+markers+text",
            marker_color="black",
            text=[f"{score:.4f}" for score in scores[:i + 1]],
            textposition="top center",
        )
        for i in range(n_frames)
    ]
    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(pie_charts[0], row=1, col=3)
    fig.add_trace(scatters[0], row=1, col=4)

    # Pad the score panel's month axis by one month on each side.
    start_date = datetime.datetime.strptime(months[0], "%Y-%m-%d") - relativedelta(months=1)
    end_date = datetime.datetime.strptime(months[-1], "%Y-%m-%d") + relativedelta(months=1)
    interval = [start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")]

    # Axis numbering: the two image panels own (xaxis, yaxis) and
    # (xaxis2, yaxis2); the pie has no cartesian axes, so the scatter owns
    # (xaxis3, yaxis3).  The previous code set both "xaxis" and "xaxis1"
    # (aliases of the same axis) and never hid the second image's axes.
    fig.update_layout({
        "xaxis": {
            "range": [0, imgs[0].shape[1]],
            'showgrid': False,
            'zeroline': False,
            'visible': False,
        },
        "yaxis": {
            "range": [imgs[0].shape[0], 0],  # inverted so the image is upright
            'showgrid': False,
            'zeroline': False,
            'visible': False,
        },
        "xaxis2": {
            "autorange": True,
            'showgrid': False,
            'zeroline': False,
            'visible': False,
        },
        "yaxis2": {
            "autorange": True,
            'showgrid': False,
            'zeroline': False,
            'visible': False,
        },
        "xaxis3": {
            "dtick": "M3",
            "range": interval,
        },
        "yaxis3": {
            'range': [min(scores) * 0.9, max(scores) * 1.1],
            'showgrid': False,
            'zeroline': False,
            'visible': True,
        },
    })

    frames = [
        dict(
            name=f"{k}",
            data=[fig2["frames"][k]["data"][0],
                  fig3["frames"][k]["data"][0],
                  pie_charts[k],
                  scatters[k]],
            traces=[0, 1, 2, 3],
        )
        for k in range(n_frames)
    ]
    updatemenus = [dict(
        type='buttons',
        buttons=[dict(
            label='Play',
            method='animate',
            args=[
                [f'{k}' for k in range(n_frames)],
                dict(frame=dict(duration=500, redraw=False),
                     transition=dict(duration=0)),
            ],
        )],
        direction='left',
        pad=dict(r=10, t=85),
        showactive=True, x=0.1, y=0, xanchor='right', yanchor='top',
    )]
    # One slider step per frame, labeled with its month.  The previous
    # version passed [None] (animate everything) and labeled steps with the
    # bare frame index.
    sliders = [{
        'yanchor': 'top',
        'xanchor': 'left',
        'currentvalue': {'font': {'size': 16}, 'visible': True, 'xanchor': 'right'},
        'transition': {'duration': 500.0, 'easing': 'linear'},
        'pad': {'b': 10, 't': 50},
        'len': 0.9, 'x': 0.1, 'y': 0,
        'steps': [{
            'args': [[f'{k}'], {'frame': {'duration': 500.0, 'redraw': False},
                                'transition': {'duration': 0}}],
            'label': months[k],
            'method': 'animate',
        } for k in range(n_frames)],
    }]
    fig.update_layout(updatemenus=updatemenus, sliders=sliders)
    fig.update(frames=frames)
    return fig
def transform_to_pil(output, alpha=0.3):
    """Convert a model output dict into PIL images.

    Returns an (image, label, blended) triple built from ``output['img']``
    and ``output['linear_preds']``; ``alpha`` controls the label overlay.
    """
    to_pil = T.ToPILImage()
    # HWC plot tensor -> CHW -> PIL image.
    chw = torch.moveaxis(prep_for_plot(output["img"]), -1, 0)
    img = to_pil(chw)
    # Map each predicted class (1-based) onto the shared colormap.
    palette = np.array([np.array(cmap(idx)) for idx in range(cmap.N)])
    class_idx = np.array(output["linear_preds"]) - 1
    label = to_pil((palette[class_idx] * 255).astype(np.uint8))
    # Alpha-blend the label mask over the source image.
    labeled_img = Image.blend(img.convert("RGBA"), label.convert("RGBA"), alpha)
    return img, label, labeled_img
def prep_for_plot(img, rescale=True, resize=None):
    """Prepare a normalized CHW tensor for plotting.

    Un-normalizes, optionally resizes, moves channels last, and (by
    default) rescales the values into [0, 1].
    """
    batched = img.unsqueeze(0)
    if resize is not None:
        batched = F.interpolate(batched, resize, mode="bilinear")
    plot_img = unnorm(batched).squeeze(0).cpu().permute(1, 2, 0)
    if rescale:
        lo, hi = plot_img.min(), plot_img.max()
        plot_img = (plot_img - lo) / (hi - lo)
    return plot_img
def add_plot(writer, name, step):
    """Snapshot the current matplotlib figure into ``writer`` as an image,
    then clear and close the figure."""
    buffer = io.BytesIO()
    plt.savefig(buffer, format='jpeg', dpi=100)
    buffer.seek(0)
    pil_img = Image.open(buffer)
    writer.add_image(name, T.ToTensor()(pil_img), step)
    plt.clf()
    plt.close()
@torch.jit.script
def shuffle(x):
    """Return ``x`` with its first dimension randomly permuted."""
    perm = torch.randperm(x.shape[0])
    return x[perm]
def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
    """Log hyper-parameters and their metrics into the writer's own run
    (works around ``SummaryWriter.add_hparams`` creating a separate run)."""
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    for summary in (exp, ssi, sei):
        writer.file_writer.add_summary(summary)
    for metric_name, value in metric_dict.items():
        writer.add_scalar(metric_name, value, global_step)
@torch.jit.script
def resize(classes: torch.Tensor, size: int):
    """Bilinearly resize a 4-D (N, C, H, W) tensor to a square (size, size) map."""
    return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)
def one_hot_feats(labels, n_classes):
    """One-hot encode (N, H, W) integer labels into a float (N, C, H, W) tensor."""
    encoded = F.one_hot(labels, n_classes)
    return encoded.permute(0, 3, 1, 2).to(torch.float32)
def load_model(model_type, data_dir):
    """Build a pretrained backbone with its classification head removed.

    Checkpoints not bundled with torchvision are downloaded into
    ``data_dir`` on first use.

    Args:
        model_type: one of "robust_resnet50", "densecl", "resnet50",
            "mocov2", "densenet121", "vgg11".
        data_dir: directory used to cache downloaded checkpoint files.

    Returns:
        An ``nn.Sequential`` feature extractor, in eval mode, moved to CUDA.
        NOTE(review): ``.cuda()`` is unconditional — a GPU is required.

    Raises:
        ValueError: if ``model_type`` is not recognized.
    """
    if model_type == "robust_resnet50":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'imagenet_l2_3_0.pt')
        if not os.path.exists(model_file):
            wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
                          model_file)
        model_weights = torch.load(model_file)
        # Keep only the "model.*" entries and strip the prefix so the keys
        # line up with torchvision's resnet50 state dict.
        model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
                                  'model' in name}
        model.load_state_dict(model_weights_modified)
        # Drop the final FC layer: keep convolutional features only.
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densecl":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
        if not os.path.exists(model_file):
            wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
                          model_file)
        model_weights = torch.load(model_file)
        # strict=False: the DenseCL checkpoint carries extra projection-head
        # keys that the plain resnet50 does not have.
        model.load_state_dict(model_weights['state_dict'], strict=False)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "resnet50":
        model = models.resnet50(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "mocov2":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
        if not os.path.exists(model_file):
            wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
                          "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
        checkpoint = torch.load(model_file)
        # Rename MoCo pre-trained keys: keep only the query encoder
        # ("module.encoder_q.*", minus its FC head) and strip the prefix.
        state_dict = checkpoint['state_dict']
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        # Only the (discarded) FC head may be left unmatched.
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densenet121":
        model = models.densenet121(pretrained=True)
        # DenseNet's feature block ends without pooling; add global pooling.
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    elif model_type == "vgg11":
        model = models.vgg11(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    else:
        raise ValueError("No model: {} found".format(model_type))
    model.eval()
    model.cuda()
    return model
class UnNormalize(object):
    """Invert a torchvision ``Normalize``: channel-wise ``x * std + mean``."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        # Work on a copy so the caller's tensor is left untouched.
        restored = torch.clone(image)
        for channel, m, s in zip(restored, self.mean, self.std):
            channel.mul_(s).add_(m)
        return restored
# ImageNet channel statistics, shared by all transforms in this module.
normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Inverse of `normalize`, used when preparing tensors for display.
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
class ToTargetTensor(object):
    """Turn a segmentation target (e.g. a PIL image) into a (1, H, W) int64 tensor."""

    def __call__(self, target):
        array = np.array(target)
        return torch.as_tensor(array, dtype=torch.int64).unsqueeze(0)
def prep_args():
    """Normalize ``sys.argv`` into ``key=value`` tokens.

    Accepts either ``key=value`` or ``--key value`` pairs and rewrites
    ``sys.argv`` entirely in the former style (program name preserved).

    Raises:
        ValueError: on any token matching neither style.
    """
    import sys
    remaining = sys.argv
    rewritten = [remaining.pop(0)]
    while remaining:
        token = remaining.pop(0)
        if len(token.split("=")) == 2:
            rewritten.append(token)
        elif token.startswith("--"):
            # "--key value" -> "key=value"
            rewritten.append(token[2:] + "=" + remaining.pop(0))
        else:
            raise ValueError("Unexpected arg style {}".format(token))
    sys.argv = rewritten
def get_transform(res, is_label, crop_type):
    """Compose a torchvision pipeline: nearest resize + crop + tensor conversion.

    Label targets become int64 tensors; images become normalized floats.

    Raises:
        ValueError: for an unknown ``crop_type``.
    """
    if crop_type == "center":
        cropper = T.CenterCrop(res)
    elif crop_type == "random":
        cropper = T.RandomCrop(res)
    elif crop_type is None:
        # No crop: pass through, and resize to a square of side `res`.
        cropper = T.Lambda(lambda x: x)
        res = (res, res)
    else:
        raise ValueError("Unknown Cropper {}".format(crop_type))

    steps = [T.Resize(res, Image.NEAREST), cropper]
    if is_label:
        steps.append(ToTargetTensor())
    else:
        steps.extend([T.ToTensor(), normalize])
    return T.Compose(steps)
def _remove_axes(ax):
    """Strip tick marks and tick labels from one matplotlib axes."""
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])
def remove_axes(axes):
    """Apply ``_remove_axes`` to every axes in a 1-D or 2-D grid."""
    flat = axes.flatten() if len(axes.shape) == 2 else axes
    for ax in flat:
        _remove_axes(ax)
class UnsupervisedMetrics(Metric):
    """Torchmetrics metric for unsupervised segmentation (mIoU / Accuracy).

    Accumulates a (predicted cluster) x (true class) confusion matrix and,
    when ``compute_hungarian`` is set, matches clusters to classes with the
    Hungarian algorithm before scoring.
    """

    def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
                 dist_sync_on_step=True):
        # call `self.add_state` for every internal state that is needed for
        # the metrics computations; dist_reduce_fx indicates the function
        # that should be used to reduce state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.n_classes = n_classes
        self.extra_clusters = extra_clusters
        self.compute_hungarian = compute_hungarian
        self.prefix = prefix
        # Confusion counts: rows = predicted cluster, cols = true class.
        self.add_state("stats",
                       default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
                       dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the confusion matrix from one batch of predictions."""
        with torch.no_grad():
            actual = target.reshape(-1)
            preds = preds.reshape(-1)
            # Drop pixels whose label or prediction falls outside
            # [0, n_classes) (e.g. an ignore label such as -1).
            mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
            actual = actual[mask]
            preds = preds[mask]
            # bincount over flattened (class, cluster) pairs, then reshape
            # and transpose into the (cluster, class) confusion matrix.
            self.stats += torch.bincount(
                (self.n_classes + self.extra_clusters) * actual + preds,
                minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
                .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)

    def map_clusters(self, clusters):
        """Translate cluster ids into class ids using the last Hungarian
        assignment (must be called after ``compute``); unmatched extra
        clusters map to -1."""
        if self.extra_clusters == 0:
            return torch.tensor(self.assignments[1])[clusters]
        else:
            # Clusters the matching left unassigned.
            missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
            cluster_to_class = self.assignments[1]
            for missing_entry in missing:
                if missing_entry == cluster_to_class.shape[0]:
                    cluster_to_class = np.append(cluster_to_class, -1)
                else:
                    # NOTE(review): inserting at missing_entry + 1 looks off
                    # by one (missing_entry would be expected) — confirm
                    # against the upstream STEGO implementation.
                    cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
            cluster_to_class = torch.tensor(cluster_to_class)
            return cluster_to_class[clusters]

    def compute(self):
        """Return ``{prefix+"mIoU", prefix+"Accuracy"}`` in percent."""
        if self.compute_hungarian:
            # Optimal cluster -> class matching maximizing matched counts.
            self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
            if self.extra_clusters == 0:
                self.histogram = self.stats[np.argsort(self.assignments[1]), :]
            if self.extra_clusters > 0:
                # Match in the transposed direction so every class receives a
                # cluster, then fold all unmatched clusters into an extra row
                # (with a zero column appended to keep the matrix square).
                self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
                histogram = self.stats[self.assignments_t[1], :]
                missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
                new_row = self.stats[missing, :].sum(0, keepdim=True)
                histogram = torch.cat([histogram, new_row], axis=0)
                new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
                self.histogram = torch.cat([histogram, new_col], axis=1)
        else:
            # Clusters are assumed to already be aligned with class ids.
            self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
                                torch.arange(self.n_classes).unsqueeze(1))
            self.histogram = self.stats
        tp = torch.diag(self.histogram)
        fp = torch.sum(self.histogram, dim=0) - tp
        fn = torch.sum(self.histogram, dim=1) - tp
        iou = tp / (tp + fp + fn)
        prc = tp / (tp + fn)  # computed but not reported in the result dict
        opc = torch.sum(tp) / torch.sum(self.histogram)
        metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
                       self.prefix + "Accuracy": opc.item()}
        return {k: 100 * v for k, v in metric_dict.items()}
def flexible_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Identical to torch's ``default_collate`` except that a batch of tensors
    whose shapes do not match is returned as-is (a plain list) instead of
    raising.

    Args:
        batch: non-empty sequence of samples (tensors, numpy arrays,
            numbers, strings, mappings, namedtuples or sequences).

    Raises:
        TypeError: on unsupported element types.
        RuntimeError: when sequence elements have inconsistent lengths.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a worker process: stack directly into shared memory to
            # avoid an extra copy on the way to the main process.
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        try:
            return torch.stack(batch, 0, out=out)
        except RuntimeError:
            # Mismatched shapes: fall back to returning the raw list.
            return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # Arrays of strings / objects cannot be converted to tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return flexible_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        # `torch._six.string_classes` was removed in torch >= 1.13; in
        # Python 3 it was simply `str`.
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: flexible_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # Check that all elements in the batch have consistent size
        # (the original generator shadowed `elem` here).
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(other) == elem_size for other in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [flexible_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
if __name__ == "__main__":
    # NOTE(review): `months`, `imgs`, `imgs_label`, `nb_values` and `scores`
    # are not defined anywhere in this module, so running the file directly
    # raises NameError — this entry point looks like a leftover smoke test.
    fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)