Spaces:
Running
Running
File size: 4,943 Bytes
894bc0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
from metrics.utils import haversine, reverse
from torchmetrics import Metric
class HaversineMetrics(Metric):
    """Accumulate geolocation metrics over batches.

    Tracks, across all ``update`` calls:

    * the summed haversine distance between predicted and ground-truth GPS
      points (reported as a per-sample mean),
    * a Geoguessr-style score, ``5000 * exp(-distance / 1492.7)`` per sample,
    * for each radius, the fraction of samples whose distance is strictly
      below that radius,
    * for each area level, the fraction of samples whose predicted area
      equals the ground-truth area,
    * optionally, one statistic per auxiliary column.

    Args:
        acc_radiuses (list): radiuses to compute the accuracy from. The
            output key is ``Accuracy_<radius>_km_radius``, so radiuses are
            presumably in kilometres (depends on ``haversine`` -- confirm).
            Defaults to no radius.
        acc_area (list): area levels to compute the accuracy from.
            Defaults to ``["country", "region", "sub-region", "city"]``.
        aux_data (list): auxiliary columns to score. Columns handled by
            ``update`` are ``land_cover``, ``road_index``, ``climate``,
            ``soil``, ``drive_side`` and ``dist_sea``. Defaults to none.
    """

    # Auxiliary columns scored by comparing argmax over dim 1 of pred and gt.
    _ARGMAX_AUX_COLUMNS = ("land_cover", "road_index", "climate", "soil")
    _DEFAULT_ACC_AREA = ("country", "region", "sub-region", "city")

    def __init__(
        self,
        acc_radiuses=None,
        acc_area=None,
        aux_data=None,
    ):
        super().__init__()

        # ``None`` sentinels instead of mutable defaults; every instance gets
        # its own list copies so no list object is shared between instances
        # (or with the caller).
        self.acc_radius = [] if acc_radiuses is None else list(acc_radiuses)
        self.acc_area = (
            list(self._DEFAULT_ACC_AREA) if acc_area is None else list(acc_area)
        )
        self.aux_list = [] if aux_data is None else list(aux_data)
        self.aux = bool(self.aux_list)

        self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        for radius in self.acc_radius:
            self.add_state(
                f"close_enough_points_{radius}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
        for area in self.acc_area:
            self.add_state(
                f"close_enough_points_{area}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            # Area accuracies use their own denominator (see ``compute``).
            self.add_state(
                f"count_{area}", default=torch.tensor(0), dist_reduce_fx="sum"
            )
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        if self.aux:
            # Not read anywhere in this class; kept for backward compatibility.
            self.aux_count = {}
            for col in self.aux_list:
                self.add_state(
                    f"aux_{col}",
                    default=torch.tensor(0.0),
                    dist_reduce_fx="sum",
                )

    def _accumulate(self, state_name, value):
        """Add ``value`` in place to the metric state named ``state_name``."""
        # Go through normal attribute access rather than ``self.__dict__`` so
        # this does not depend on where the states are physically stored.
        state = getattr(self, state_name)
        state += value
        setattr(self, state_name, state)

    def update(self, pred, gt):
        """Accumulate statistics for one batch.

        Args:
            pred: mapping holding the predicted ``"gps"`` entry and, when
                auxiliary columns are enabled, one entry per such column.
            gt: mapping holding the ground-truth ``"gps"`` entry and the
                matching auxiliary entries; it is also passed as a whole to
                ``reverse`` to obtain ground-truth area labels.
        """
        haversine_distance = haversine(pred["gps"], gt["gps"])

        for radius in self.acc_radius:
            self._accumulate(
                f"close_enough_points_{radius}",
                (haversine_distance < radius).sum(),
            )

        if self.acc_area:
            area_pred, area_gt = reverse(pred["gps"], gt, self.acc_area)
            for area in self.acc_area:
                gt_area = area_gt[f"unique_{area}"]
                self._accumulate(
                    f"close_enough_points_{area}",
                    (area_pred[area] == gt_area).sum(),
                )
                # Denominator is the number of ground-truth entries returned
                # by ``reverse`` for this level, not the batch size.
                self._accumulate(f"count_{area}", len(gt_area))

        self.haversine_sum += haversine_distance.sum()
        self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum()

        if self.aux:
            for col in self._ARGMAX_AUX_COLUMNS:
                if col in self.aux_list:
                    self._accumulate(
                        f"aux_{col}",
                        (pred[col].argmax(dim=1) == gt[col].argmax(dim=1)).sum(),
                    )
            if "drive_side" in self.aux_list:
                # Prediction is thresholded at 0.5 before comparing with gt.
                self._accumulate(
                    "aux_drive_side",
                    ((pred["drive_side"] > 0.5).float() == gt["drive_side"]).sum(),
                )
            if "dist_sea" in self.aux_list:
                # Sum of squared errors, not a hit count: the value reported
                # under "Accuracy_dist_sea" is a squared error per sample.
                self._accumulate(
                    "aux_dist_sea",
                    (pred["dist_sea"] - gt["dist_sea"]).pow(2).sum(dim=1).sum(),
                )

        self.count += pred["gps"].shape[0]

    def compute(self):
        """Return a dict of the accumulated metrics.

        Keys are ``"Haversine"``, ``"Geoguessr"``,
        ``"Accuracy_<radius>_km_radius"`` for each radius,
        ``"Accuracy_<area>"`` for each area level and ``"Accuracy_<col>"``
        for each auxiliary column. Everything is normalised by the total
        sample count, except area accuracies which are normalised by their
        own ``count_<area>`` state.
        """
        output = {
            "Haversine": self.haversine_sum / self.count,
            "Geoguessr": self.geoguessr_sum / self.count,
        }
        for radius in self.acc_radius:
            output[f"Accuracy_{radius}_km_radius"] = (
                getattr(self, f"close_enough_points_{radius}") / self.count
            )
        for area in self.acc_area:
            output[f"Accuracy_{area}"] = getattr(
                self, f"close_enough_points_{area}"
            ) / getattr(self, f"count_{area}")
        if self.aux:
            for col in self.aux_list:
                output[f"Accuracy_{col}"] = getattr(self, f"aux_{col}") / self.count
        return output
|