object-detection / evaluator.py
mingyang91's picture
Update UI
095f7cc verified
raw
history blame
2.33 kB
from typing import Callable
from PIL.Image import Image
from coco_eval import CocoEvaluator
from pycocotools.coco import COCO
from tqdm import tqdm
from yolo_dataset import YoloDataset
from yolo_model import YoloModel
# Type of the ``loader`` argument accepted by ``evaluate``: a callable that takes
# an image's file name (``str``) and returns a PIL ``Image``.
image_loader = Callable[[str], Image]
def evaluate(model: YoloModel, coco_gt: COCO, loader: image_loader, confidence_threshold=0.6):
    """Run ``model`` over every image in ``coco_gt`` and print COCO bbox metrics.

    Args:
        model: Wrapper whose ``.model`` attribute is called on each image and
            returns an iterable of results exposing ``.boxes``.
        coco_gt: Ground-truth COCO dataset; every image in ``coco_gt.imgs`` is
            run through the model.
        loader: Callable turning an image's ``file_name`` into a PIL image.
        confidence_threshold: Detections whose score is not strictly greater
            than this value are discarded. It is also forwarded to the model
            call as ``conf`` so the predictor's own default cut-off cannot
            silently override a lower threshold.

    Returns:
        None. Metrics are printed by ``CocoEvaluator.summarize()``; if no
        detection survived the threshold, "No detections!" is printed instead.
    """
    # Initialize the evaluator with the ground truth (gt).
    evaluator = CocoEvaluator(coco_gt=coco_gt, iou_types=["bbox"])
    print("Running evaluation...")
    # Iterate over *all* images rather than ``coco_gt.imgToAnns``: pycocotools
    # only indexes images that have at least one annotation there, so
    # background images (and false positives on them) would be skipped.
    for image_id, image_info in tqdm(coco_gt.imgs.items()):
        # The predictor (assumed ultralytics, given the ``result.boxes`` API)
        # applies its own default confidence cut-off of 0.25 unless ``conf`` is
        # given, which would make any lower confidence_threshold ineffective.
        results = model.model(
            source=loader(image_info["file_name"]),
            conf=confidence_threshold,
        )
        for result in results:
            coco_anns = yolo_boxes_to_coco_annotations(
                image_id, result.boxes, confidence_threshold=confidence_threshold
            )
            # NOTE(review): images with no detections are never handed to the
            # evaluator. Depending on how CocoEvaluator.update derives the image
            # ids it scores, their ground-truth boxes may then not count as
            # misses (inflating recall) — confirm against coco_eval.
            if not coco_anns:
                continue
            evaluator.update(coco_anns)
    if len(evaluator.eval_imgs["bbox"]) == 0:
        print("No detections!")
        return
    evaluator.synchronize_between_processes()
    evaluator.accumulate()
    evaluator.summarize()
def yolo_boxes_to_coco_annotations(image_id: int, yolo_boxes, confidence_threshold=0.6):
    """Convert YOLO detection boxes for one image into COCO result dicts.

    Args:
        image_id: COCO id of the image the boxes belong to.
        yolo_boxes: Iterable of single-box objects exposing ``cls``, ``conf``
            and ``xywh`` whose ``tolist()`` yields ``[cls]``, ``[conf]`` and
            ``[[x_center, y_center, width, height]]`` respectively (the
            ultralytics ``Boxes`` layout — assumed; confirm for other backends).
        confidence_threshold: Boxes whose score is not strictly greater than
            this value are dropped.

    Returns:
        A list of dicts with ``image_id``, ``category_id`` (int), ``area``,
        ``bbox`` in COCO ``[x_min, y_min, width, height]`` form, and ``score``.
    """
    annotations = []
    for box in yolo_boxes:
        score = box.conf.tolist()[0]
        if score > confidence_threshold:
            cx, cy, w, h = box.xywh.tolist()[0]
            annotations.append({
                "image_id": image_id,
                # COCO category ids are integers; the class may arrive as a
                # float (e.g. 2.0).
                "category_id": int(box.cls.tolist()[0]),
                "area": w * h,
                # YOLO ``xywh`` is centre-based, whereas COCO bboxes are
                # anchored at the top-left corner.
                "bbox": [cx - w / 2, cy - h / 2, w, h],
                "score": score,
            })
    return annotations
def main():
    """Evaluate a pretrained YOLOv8s checkpoint on the bundled coco8 sample.

    First runs this module's ``evaluate`` at a 0.1 confidence threshold, then
    calls the wrapped model's own ``val()`` and prints the mAP figures it
    reports (mAP50-95, mAP50, mAP75, per-category mAP50-95).
    """
    dataset = YoloDataset.from_path('tests/coco8.zip')
    ground_truth = dataset.to_coco()
    # Alternative checkpoint that can be swapped in:
    #   YoloModel("SHOU-ISD/fire-and-smoke", "yolov8n.pt")
    detector = YoloModel("ultralyticsplus/yolov8s", "yolov8s.pt")
    evaluate(
        model=detector,
        coco_gt=ground_truth,
        loader=dataset.load_image,
        confidence_threshold=0.1,
    )

    # Validate the model. No arguments are passed, so it relies on the dataset
    # and settings the loaded model already carries.
    # NOTE(review): confirm which dataset that resolves to for a hub checkpoint.
    box_metrics = detector.model.val().box
    # map50-95, map50, map75, then a list of per-category map50-95 values.
    for metric_name in ("map", "map50", "map75", "maps"):
        print(getattr(box_metrics, metric_name))


if __name__ == '__main__':
    main()