Spaces:
Build error
Build error
Commit
·
de1f95e
1
Parent(s):
0bb094e
fix bugs and add improvements
Browse files- PanopticQuality.py +28 -8
PanopticQuality.py
CHANGED
|
@@ -71,7 +71,6 @@ Examples:
|
|
| 71 |
Added data ...
|
| 72 |
Start computing ...
|
| 73 |
Finished!
|
| 74 |
-
tensor(0.2082, dtype=torch.float64)
|
| 75 |
"""
|
| 76 |
|
| 77 |
|
|
@@ -81,6 +80,8 @@ class PQMetric(evaluate.Metric):
|
|
| 81 |
self,
|
| 82 |
label2id: Dict[str, int] = None,
|
| 83 |
stuff: List[str] = None,
|
|
|
|
|
|
|
| 84 |
**kwargs
|
| 85 |
):
|
| 86 |
super().__init__(**kwargs)
|
|
@@ -109,9 +110,13 @@ class PQMetric(evaluate.Metric):
|
|
| 109 |
|
| 110 |
self.label2id = label2id if label2id is not None else DEFAULT_LABEL2ID
|
| 111 |
self.stuff = stuff if stuff is not None else DEFAULT_STUFF
|
|
|
|
|
|
|
| 112 |
self.pq_metric = PanopticQuality(
|
| 113 |
things=set([self.label2id[label] for label in self.label2id.keys() if label not in self.stuff]),
|
| 114 |
-
stuffs=set([self.label2id[label] for label in self.label2id.keys() if label in self.stuff])
|
|
|
|
|
|
|
| 115 |
)
|
| 116 |
|
| 117 |
def _info(self):
|
|
@@ -151,9 +156,6 @@ class PQMetric(evaluate.Metric):
|
|
| 151 |
# in case the inputs are lists, convert them to numpy arrays
|
| 152 |
|
| 153 |
self.pq_metric.update(prediction, reference)
|
| 154 |
-
print("TP:", self.pq_metric.metric.true_positives)
|
| 155 |
-
print("FP:", self.pq_metric.metric.false_positives)
|
| 156 |
-
print("FN:", self.pq_metric.metric.false_negatives)
|
| 157 |
|
| 158 |
# does not impact the metric, but is required for the interface x_x
|
| 159 |
super(evaluate.Metric, self).add(
|
|
@@ -164,12 +166,30 @@ class PQMetric(evaluate.Metric):
|
|
| 164 |
|
| 165 |
def _compute(self, *, predictions, references, **kwargs):
|
| 166 |
"""Called within the evaluate.Metric.compute() method"""
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
id2label = {id: label for label, id in self.label2id.items()}
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
}
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
def add_payload(self, payload: Payload, model_name: str = None):
|
| 174 |
"""Converts the payload to the format expected by the metric"""
|
| 175 |
# import only if needed since fiftyone is not a direct dependency
|
|
|
|
| 71 |
Added data ...
|
| 72 |
Start computing ...
|
| 73 |
Finished!
|
|
|
|
| 74 |
"""
|
| 75 |
|
| 76 |
|
|
|
|
| 80 |
self,
|
| 81 |
label2id: Dict[str, int] = None,
|
| 82 |
stuff: List[str] = None,
|
| 83 |
+
per_class: bool = True,
|
| 84 |
+
split_sq_rq: bool = True,
|
| 85 |
**kwargs
|
| 86 |
):
|
| 87 |
super().__init__(**kwargs)
|
|
|
|
| 110 |
|
| 111 |
self.label2id = label2id if label2id is not None else DEFAULT_LABEL2ID
|
| 112 |
self.stuff = stuff if stuff is not None else DEFAULT_STUFF
|
| 113 |
+
self.per_class = per_class
|
| 114 |
+
self.split_sq_rq = split_sq_rq
|
| 115 |
self.pq_metric = PanopticQuality(
|
| 116 |
things=set([self.label2id[label] for label in self.label2id.keys() if label not in self.stuff]),
|
| 117 |
+
stuffs=set([self.label2id[label] for label in self.label2id.keys() if label in self.stuff]),
|
| 118 |
+
return_per_class=per_class,
|
| 119 |
+
return_sq_and_rq=split_sq_rq
|
| 120 |
)
|
| 121 |
|
| 122 |
def _info(self):
|
|
|
|
| 156 |
# in case the inputs are lists, convert them to numpy arrays
|
| 157 |
|
| 158 |
self.pq_metric.update(prediction, reference)
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# does not impact the metric, but is required for the interface x_x
|
| 161 |
super(evaluate.Metric, self).add(
|
|
|
|
| 166 |
|
| 167 |
def _compute(self, *, predictions, references, **kwargs):
|
| 168 |
"""Called within the evaluate.Metric.compute() method"""
|
| 169 |
+
tp = self.pq_metric.metric.true_positives.clone()
|
| 170 |
+
fp = self.pq_metric.metric.false_positives.clone()
|
| 171 |
+
fn = self.pq_metric.metric.false_negatives.clone()
|
| 172 |
+
iou = self.pq_metric.metric.iou_sum.clone()
|
| 173 |
+
|
| 174 |
id2label = {id: label for label, id in self.label2id.items()}
|
| 175 |
+
things_stuffs = sorted(self.pq_metric.things) + sorted(self.pq_metric.stuffs)
|
| 176 |
+
|
| 177 |
+
# compute scores
|
| 178 |
+
result = self.pq_metric.compute() # shape : (n_classes (sorted things + sorted stuffs), scores (pq, sq, rq))
|
| 179 |
+
|
| 180 |
+
result_dict = {
|
| 181 |
+
"numbers": {id2label[numeric_label]: [tp[i].item(), fp[i].item(), fn[i].item(), iou[i].item()] \
|
| 182 |
+
for i, numeric_label in enumerate(things_stuffs)},
|
| 183 |
+
"scores": None
|
| 184 |
}
|
| 185 |
|
| 186 |
+
if self.per_class:
|
| 187 |
+
result_dict["scores"] = {id2label[numeric_label]: result[i].tolist() for i, numeric_label in enumerate(things_stuffs)}
|
| 188 |
+
else:
|
| 189 |
+
result_dict["scores"] = result.tolist()
|
| 190 |
+
|
| 191 |
+
return result_dict
|
| 192 |
+
|
| 193 |
def add_payload(self, payload: Payload, model_name: str = None):
|
| 194 |
"""Converts the payload to the format expected by the metric"""
|
| 195 |
# import only if needed since fiftyone is not a direct dependency
|