File size: 6,337 Bytes
894bc0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import logging
from collections import OrderedDict
from pathlib import Path
from typing import Union, List
import torch
import torchvision
def check_is_valid_torchvision_architecture(architecture: str):
    """Raise a ValueError if ``architecture`` is not an available torchvision model.

    A name counts as available when it is a lowercase, non-dunder, callable
    entry of ``torchvision.models.__dict__``.

    Args:
        architecture: Name to look up, e.g. ``"resnet50"``.

    Raises:
        ValueError: If ``architecture`` is not among the available names. The
            message lists every accepted name.
    """
    available = sorted(
        name
        for name, obj in torchvision.models.__dict__.items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )
    if architecture not in available:
        raise ValueError(f"{architecture} not in {available}")
def build_base_model(arch: str):
    """Build a torchvision backbone with its classification head removed.

    The model is instantiated with ``pretrained=True``, its trailing child
    modules are dropped, and an adaptive average pooling layer plus a flatten
    layer are attached.

    Args:
        arch: Key into ``torchvision.models.__dict__``. Supported values are
            ``"mobilenet_v2"``, ``"densenet121"``, ``"densenet161"``,
            ``"densenet169"`` and any name containing ``"resne"``.

    Returns:
        Tuple ``(model, nfeatures)`` where ``nfeatures`` is the ``in_features``
        of the classification layer that was removed.

    Raises:
        KeyError: If ``arch`` is not an entry of ``torchvision.models``.
        NotImplementedError: If ``arch`` is not one of the supported families.
    """
    # NOTE(review): recent torchvision releases deprecate ``pretrained`` in
    # favour of ``weights``; kept as-is for compatibility with the version
    # this file was written against -- confirm before upgrading.
    model = torchvision.models.__dict__[arch](pretrained=True)

    # get input dimension before classification layer
    if arch in ["mobilenet_v2"]:
        nfeatures = model.classifier[-1].in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif arch in ["densenet121", "densenet161", "densenet169"]:
        nfeatures = model.classifier.in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif "resne" in arch:
        # usually all ResNet variants; drop the original pooling and fc layers
        nfeatures = model.fc.in_features
        model = torch.nn.Sequential(*list(model.children())[:-2])
    else:
        raise NotImplementedError(f"Unsupported architecture: {arch}")

    # Attribute assignment registers these as additional submodules of the
    # Sequential container, after the truncated backbone.
    model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    model.flatten = torch.nn.Flatten(start_dim=1)
    return model, nfeatures
def load_weights_if_available(
    model: torch.nn.Module, classifier: torch.nn.Module, weights_path: Union[str, Path]
):
    """Load backbone weights from a checkpoint file into ``model``.

    The checkpoint must be a mapping with a ``"state_dict"`` entry. Keys that
    start with ``"model"`` are loaded into ``model`` (strictly) after their
    leading ``"model."`` prefix is removed. Keys that start with
    ``"classifier"`` are collected but not loaded. Any other key is reported
    with a warning.

    Args:
        model: Module that receives the ``model.*`` weights.
        classifier: Module returned unchanged (see the note in the body).
        weights_path: Path of the checkpoint file.

    Returns:
        Tuple ``(model, classifier)``.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the keys of
            the checkpoint do not match ``model``.
    """

    def _strip_prefix(key: str, prefix: str) -> str:
        # Remove only a *leading* prefix; ``str.replace`` would also eat
        # later occurrences (e.g. the "model." inside "submodel.weight").
        return key[len(prefix):] if key.startswith(prefix) else key

    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    # The map_location callable keeps every storage on the CPU.
    checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

    state_dict_features = OrderedDict()
    state_dict_classifier = OrderedDict()
    for k, w in checkpoint["state_dict"].items():
        if k.startswith("model"):
            state_dict_features[_strip_prefix(k, "model.")] = w
        elif k.startswith("classifier"):
            state_dict_classifier[_strip_prefix(k, "classifier.")] = w
        else:
            logging.warning(f"Unexpected prefix in state_dict: {k}")

    model.load_state_dict(state_dict_features, strict=True)
    # NOTE(review): state_dict_classifier is gathered but never loaded into
    # ``classifier``. That is the existing behaviour and is preserved here;
    # confirm whether the classifier weights are meant to be restored too.
    return model, classifier
def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt):
    """Compute element-wise great-circle distances with the haversine formula.

    All four arguments are tensors of coordinates in degrees (they are scaled
    by pi/180 before the trigonometric functions are applied) and must be
    broadcastable against each other.

    Args:
        latitudes: Latitudes of the first set of points.
        longitudes: Longitudes of the first set of points.
        latitudes_gt: Latitudes of the second set of points.
        longitudes_gt: Longitudes of the second set of points.

    Returns:
        Tensor of distances on a sphere of radius 6371 (presumably km, the
        mean Earth radius).
    """
    R = 6371
    factor_rad = 0.01745329252  # pi / 180, degrees -> radians

    longitudes = factor_rad * longitudes
    longitudes_gt = factor_rad * longitudes_gt
    latitudes = factor_rad * latitudes
    latitudes_gt = factor_rad * latitudes_gt

    delta_long = longitudes_gt - longitudes
    delta_lat = latitudes_gt - latitudes

    lat_term = torch.sin(delta_lat / 2) ** 2
    long_term = (
        torch.cos(latitudes)
        * torch.cos(latitudes_gt)
        * torch.sin(delta_long / 2) ** 2
    )
    a = lat_term + long_term
    # Rounding can push ``a`` marginally outside [0, 1] (e.g. for antipodal
    # points in float32), which would make sqrt/asin return NaN.
    a = torch.clamp(a, min=0.0, max=1.0)
    c = 2 * torch.asin(torch.sqrt(a))
    return R * c
def gcd_threshold_eval(gc_dists, thresholds=(1, 25, 200, 750, 2500)):
    """Compute the fraction of distances that fall within each threshold.

    Args:
        gc_dists: Tensor of distances. The fraction is normalised by
            ``len(gc_dists)``, i.e. the size of the first dimension
            (assumes a 1-D tensor -- TODO confirm against callers).
        thresholds: Iterable of distance thresholds, in the same unit as
            ``gc_dists``.

    Returns:
        Dict mapping each threshold to a 0-dim tensor holding the share of
        entries with ``distance <= threshold``.
    """
    # calculate accuracy for given gcd thresholds
    num_samples = len(gc_dists)
    return {
        thres: torch.true_divide(torch.sum(gc_dists <= thres), num_samples)
        for thres in thresholds
    }
def accuracy(output, target, partitioning_shortnames: list, topk=(1, 5, 10)):
    """Compute top-k accuracies for every partitioning.

    Args:
        output: Indexable collection; ``output[i]`` is the
            ``(batch, num_classes)`` score tensor of partitioning ``i``.
        target: Indexable collection; ``target[i]`` is the ``(batch,)`` tensor
            of class indices for partitioning ``i``.
        partitioning_shortnames: One name per partitioning, used in the
            result keys.
        topk: Values of k to evaluate.

    Returns:
        Dict mapping ``"acc{k}_val/{pname}"`` to a tensor of shape ``(1,)``
        holding the top-k accuracy of that partitioning.
    """

    def _accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            # Never ask for more predictions than there are classes; for
            # k >= num_classes the slice below simply covers every class.
            maxk = min(max(topk), output.size(1))
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()  # (maxk, batch)
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = {}
            for k in topk:
                # ``correct`` is not contiguous after the transpose, so
                # ``view(-1)`` fails on current PyTorch; reshape copies when
                # it has to.
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res[k] = correct_k / batch_size
            return res

    with torch.no_grad():
        out_dict = {}
        for i, pname in enumerate(partitioning_shortnames):
            res_dict = _accuracy(output[i], target[i], topk=topk)
            for k, v in res_dict.items():
                out_dict[f"acc{k}_val/{pname}"] = v
        return out_dict
def summarize_gcd_stats(pnames: List[str], outputs, hierarchy=None):
    """Aggregate per-batch distances into threshold accuracies.

    For every metric name ``"gcd_{p}_val"`` the per-batch distance tensors
    found in ``outputs`` are concatenated along dim 0 and passed to
    ``gcd_threshold_eval``.

    Args:
        pnames: Partitioning names used to build the metric names.
        outputs: Sequence of per-batch dicts, each containing a distance
            tensor under every evaluated metric name.
        hierarchy: When not ``None``, the hierarchy metric is evaluated too.

    Returns:
        Dict mapping ``"{metric_name}/{threshold}"`` to the accuracy at that
        threshold.
    """
    gcd_dict = {}
    metric_names = [f"gcd_{p}_val" for p in pnames]
    if hierarchy is not None:
        # NOTE(review): the body of this branch was missing from the
        # recovered source. The key below follows the "gcd_{p}_val" naming
        # scheme -- confirm it against the validation step that fills
        # ``outputs``.
        metric_names.append("gcd_hierarchy_val")

    for metric_name in metric_names:
        # concat batches of distances
        distances_flat = torch.cat(
            [output[metric_name] for output in outputs], dim=0
        )
        gcd_results = gcd_threshold_eval(distances_flat)
        for gcd_thres, acc in gcd_results.items():
            gcd_dict[f"{metric_name}/{gcd_thres}"] = acc
    return gcd_dict
def summarize_test_gcd(pnames, outputs, hierarchy=None):
    """Aggregate test-time distances into threshold accuracies.

    Args:
        pnames: Partitioning names; each is a key into the per-batch dicts.
        outputs: Either a list of per-batch dicts (a single test set) or a
            list of such lists (one per test set).
        hierarchy: When not ``None``, the hierarchy distances are evaluated
            too.

    Returns:
        For a single test set, a dict mapping ``"acc_test/{pname}"`` to the
        result of ``gcd_threshold_eval``. For several test sets, a dict
        mapping each test set index to such a dict.

    Raises:
        TypeError: If ``outputs[0]`` is neither a dict nor a list.
    """
    # Work on a copy so the caller's list is never mutated, and extend it
    # only once even when several test sets are evaluated.
    eval_names = list(pnames)
    if hierarchy is not None:
        # NOTE(review): the body of this branch was missing from the
        # recovered source; "hierarchy" is the assumed key of the hierarchy
        # distances -- confirm it against the test step that fills
        # ``outputs``.
        eval_names.append("hierarchy")

    def _eval(output):
        # calculate acc@km for a list of given thresholds
        accuracy_outputs = {}
        for pname in eval_names:
            # concat batches of distances
            distances_flat = torch.cat([x[pname] for x in output], dim=0)
            # acc for all distances
            acc_dict = gcd_threshold_eval(distances_flat)
            accuracy_outputs[f"acc_test/{pname}"] = acc_dict
        return accuracy_outputs

    result = {}

    if isinstance(outputs[0], dict):  # only one testset
        result = _eval(outputs)
    elif isinstance(outputs[0], list):  # multiple testsets
        for testset_index, output in enumerate(outputs):
            result[testset_index] = _eval(output)
    else:
        raise TypeError(
            f"outputs[0] must be a dict or a list, got {type(outputs[0]).__name__}"
        )

    return result
def summarize_loss_acc_stats(pnames: List[str], outputs, topk=(1, 5, 10)):
    """Average loss and accuracy metrics over all batches.

    The averaged metrics are ``"loss_val/total"``, ``"acc{k}_val/{p}"`` for
    every ``k`` in ``topk`` and every partitioning ``p``, and
    ``"loss_val/{p}"`` for every partitioning ``p``.

    Args:
        pnames: Partitioning names.
        outputs: Non-empty sequence of per-batch dicts containing every
            metric listed above.
        topk: Values of k whose accuracies are averaged.

    Returns:
        Dict mapping each metric name to its mean over ``outputs``.

    Raises:
        KeyError: If a batch dict lacks one of the metrics.
        ZeroDivisionError: If ``outputs`` is empty.
    """
    metric_names = ["loss_val/total"]
    for k in topk:
        metric_names.extend(f"acc{k}_val/{p}" for p in pnames)
    metric_names.extend(f"loss_val/{p}" for p in pnames)

    loss_acc_dict = {}
    for metric_name in metric_names:
        # Start from the int 0 so the running sum takes the type of the
        # stored values (floats or tensors).
        metric_total = 0
        for output in outputs:
            metric_total += output[metric_name]
        loss_acc_dict[metric_name] = metric_total / len(outputs)
    return loss_acc_dict