Spaces:
Running
Running
File size: 9,365 Bytes
1999a98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from pathlib import Path
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
from ultralytics.utils.plotting import output_to_rotated_target, plot_images
class OBBValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.

    Within this class a prediction row is handled as (x, y, w, h, conf, cls, angle): columns 0-3 are rescaled with
    `xywh=True`, column 4 is used as the score, column 5 as the class index and the last column as the rotation.

    Example:
        ```python
        from ultralytics.models.yolo.obb import OBBValidator

        args = dict(model="yolov8n-obb.pt", data="dota8.yaml")
        validator = OBBValidator(args=args)
        validator(model=args["model"])
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.task = "obb"
        self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)

    def init_metrics(self, model):
        """Initialize evaluation metrics and flag whether the validation split path refers to DOTA."""
        super().init_metrics(model)
        val = self.data.get(self.args.split, "")  # validation path
        self.is_dota = isinstance(val, str) and "DOTA" in val  # is DOTA

    def postprocess(self, preds):
        """Apply rotated Non-maximum suppression to prediction outputs."""
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            nc=self.nc,
            multi_label=True,
            agnostic=self.args.single_cls or self.args.agnostic_nms,
            max_det=self.args.max_det,
            rotated=True,
        )

    def _process_batch(self, detections, gt_bboxes, gt_cls):
        """
        Perform computation of the correct prediction matrix for a batch of detections and ground truth boxes.

        Args:
            detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and
                associated data. Each detection is laid out as (x, y, w, h, conf, class, angle).
            gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box
                is laid out as (x, y, w, h, angle).
            gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth boxes.

        Returns:
            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection
                over Union) levels for each detection, indicating the accuracy of predictions compared to the ground
                truth.

        Example:
            ```python
            # `validator` is an initialized OBBValidator instance
            detections = torch.rand(100, 7)  # 100 sample detections
            gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
            gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
            correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
            ```

        Note:
            This method relies on `batch_probiou` to calculate IoU between detections and ground truth boxes.
        """
        # Drop conf/cls so the detections match the 5-column layout of the ground truth boxes.
        pred_rboxes = torch.cat([detections[:, :4], detections[:, -1:]], dim=-1)
        iou = batch_probiou(gt_bboxes, pred_rboxes)
        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def _prepare_batch(self, si, batch):
        """
        Prepare the labels of image `si` from a batch for OBB validation.

        Args:
            si (int): Index of the image inside the batch.
            batch (dict): Batch dictionary; the keys "batch_idx", "cls", "bboxes", "ori_shape", "img" and
                "ratio_pad" are read.

        Returns:
            (dict): Dictionary with keys "cls", "bbox", "ori_shape", "imgsz" and "ratio_pad" for the selected image.
        """
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]  # assumes batch["img"] is (B, C, H, W) - TODO confirm
        ratio_pad = batch["ratio_pad"][si]
        if len(cls):
            # imgsz[[1, 0, 1, 0]] reorders the two image dims to match the x, y, w, h columns.
            bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

    def _prepare_pred(self, pred, pbatch):
        """Return a copy of `pred` whose first four columns are rescaled to the original image shape."""
        predn = pred.clone()  # keep the caller's tensor untouched
        ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
        )  # native-space pred
        return predn

    def plot_predictions(self, batch, preds, ni):
        """Plot predicted bounding boxes on input images and save the result as `val_batch{ni}_pred.jpg`."""
        plot_images(
            batch["img"],
            *output_to_rotated_target(preds, max_det=self.args.max_det),
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def pred_to_json(self, predn, filename):
        """
        Serialize the predictions of one image and append them to `self.jdict`.

        Args:
            predn (torch.Tensor): Predictions with columns (x, y, w, h, conf, cls, angle).
            filename (str | Path): Image path; its stem becomes the image id (as int when purely numeric).
        """
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
        poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
        # Convert each column to Python values once instead of calling .item() per row.
        scores = predn[:, 4].tolist()
        classes = predn[:, 5].tolist()
        for r, b, score, cls in zip(rbox.tolist(), poly.tolist(), scores, classes):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(cls)],
                    "score": round(score, 5),
                    "rbox": [round(x, 3) for x in r],
                    "poly": [round(x, 3) for x in b],
                }
            )

    def save_one_txt(self, predn, save_conf, shape, file):
        """
        Save the detections of one image to a txt file through `Results.save_txt`.

        Args:
            predn (torch.Tensor): Predictions with columns (x, y, w, h, conf, cls, angle).
            save_conf (bool): Forwarded to `Results.save_txt`.
            shape (tuple): Image shape; the first two entries size the placeholder image.
            file (str | Path): Destination txt file.
        """
        import numpy as np

        from ultralytics.engine.results import Results

        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
        # xywh, r, conf, cls
        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),  # blank image, only carries the shape
            path=None,
            names=self.names,
            obb=obb,
        ).save_txt(file, save_conf=save_conf)

    @staticmethod
    def _dota_line(image_id, score, poly):
        """Format one DOTA result line: image id, score and the first eight polygon coordinates."""
        coords = " ".join(str(v) for v in poly[:8])
        return f"{image_id} {score} {coords}\n"

    @staticmethod
    def _append_dota_files(out_dir, lines_by_class):
        """Append the collected lines to `Task1_<classname>.txt` inside `out_dir`, opening each file once."""
        for classname, lines in lines_by_class.items():
            with open(f"{out_dir / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
                f.writelines(lines)

    def eval_json(self, stats):
        """
        Export predictions in DOTA text format and return `stats` unchanged.

        Nothing is written unless `save_json` is set, the dataset was detected as DOTA and `self.jdict` is
        non-empty. Two folders are filled under `self.save_dir`: `predictions_txt` with the per-patch results and
        `predictions_merged_txt` with results merged per source image through rotated NMS. Output files are opened
        in append mode, so existing files are extended rather than replaced.

        Args:
            stats (dict): Statistics to pass through.

        Returns:
            (dict): The `stats` argument, unmodified.
        """
        if not (self.args.save_json and self.is_dota and len(self.jdict)):
            return stats

        import json
        import re
        from collections import defaultdict

        pred_json = self.save_dir / "predictions.json"  # predictions
        pred_txt = self.save_dir / "predictions_txt"  # predictions
        pred_txt.mkdir(parents=True, exist_ok=True)
        with open(pred_json, encoding="utf-8") as f:  # context manager closes the handle
            data = json.load(f)

        # Save split results
        LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
        split_lines = defaultdict(list)
        for d in data:
            classname = self.names[d["category_id"] - 1].replace(" ", "-")
            split_lines[classname].append(self._dota_line(d["image_id"], d["score"], d["poly"]))
        self._append_dota_files(pred_txt, split_lines)

        # Save merged results, this could result slightly lower map than using official merging script,
        # because of the probiou calculation.
        pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
        pred_merged_txt.mkdir(parents=True, exist_ok=True)
        LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
        offset_pattern = re.compile(r"\d+___\d+")  # "<x>___<y>" part of the patch id, compiled once
        merged_results = defaultdict(list)
        for d in data:
            patch_id = d["image_id"]
            image_id = patch_id.split("__")[0]
            x, y = (int(c) for c in offset_pattern.findall(patch_id)[0].split("___"))
            rbox = d["rbox"]
            # Shift the box center by the patch offset; build a new row instead of mutating the loaded record.
            merged_results[image_id].append(
                [rbox[0] + x, rbox[1] + y, *rbox[2:], d["score"], d["category_id"] - 1]
            )

        merged_lines = defaultdict(list)
        for image_id, rows in merged_results.items():
            boxes = torch.tensor(rows)  # columns: x, y, w, h, r, score, cls
            max_wh = torch.max(boxes[:, :2]).item() * 2
            class_offset = boxes[:, 6:7] * max_wh  # classes
            scores = boxes[:, 5]  # scores
            shifted = boxes[:, :5].clone()
            # Moving centers by class index * max_wh separates the classes before NMS is applied.
            shifted[:, :2] += class_offset
            # 0.3 could get results close to the ones from official merging script, even slightly better.
            keep = ops.nms_rotated(shifted, scores, 0.3)
            boxes = boxes[keep]
            polys = ops.xywhr2xyxyxyxy(boxes[:, :5]).view(-1, 8)
            for row in torch.cat([polys, boxes[:, 5:7]], dim=-1).tolist():
                classname = self.names[int(row[-1])].replace(" ", "-")
                poly = [round(v, 3) for v in row[:-2]]  # poly
                score = round(row[-2], 3)
                merged_lines[classname].append(self._dota_line(image_id, score, poly))
        self._append_dota_files(pred_merged_txt, merged_lines)

        return stats
|