Spaces:
Runtime error
Runtime error
| from typing import Callable | |
| from PIL.Image import Image | |
| from coco_eval import CocoEvaluator | |
| from pycocotools.coco import COCO | |
| from tqdm import tqdm | |
| from yolo_dataset import YoloDataset | |
| from yolo_model import YoloModel | |
image_loader = Callable[[str], Image]


def evaluate(model: YoloModel, coco_gt: COCO, loader: image_loader, confidence_threshold=0.6):
    """Run ``model`` over every image in ``coco_gt`` and score it with COCO bbox metrics.

    For each ground-truth image, its ``file_name`` is passed to ``loader``.
    The loaded image is fed to ``model.model``. Each result's boxes are
    converted with ``yolo_boxes_to_coco_annotations`` and pushed into a
    ``CocoEvaluator`` built from ``coco_gt``.

    Args:
        model: Wrapper whose ``model`` attribute is called on each image.
        coco_gt: Ground-truth COCO dataset.
        loader: Maps an image ``file_name`` to a PIL image.
        confidence_threshold: Forwarded to ``yolo_boxes_to_coco_annotations``
            to drop low-scoring detections.

    Returns:
        None. Prints "No detections!" and returns early if the evaluator
        received nothing. Otherwise synchronizes, accumulates and summarizes
        the evaluator.
    """
    # initialize evaluator with ground truth (gt)
    evaluator = CocoEvaluator(coco_gt=coco_gt, iou_types=["bbox"])
    print("Running evaluation...")
    # Iterate over *all* images rather than coco_gt.imgToAnns: pycocotools only
    # fills imgToAnns for images with at least one annotation. Images without
    # ground truth would otherwise be skipped, and their false positives
    # would never be counted.
    for image_id, image in tqdm(coco_gt.imgs.items()):
        # NOTE(review): if model.model is an ultralytics YOLO, predict() applies
        # its own `conf` threshold (default 0.25). A confidence_threshold below
        # that has no effect unless conf= is also passed here — confirm.
        results = model.model(source=loader(image["file_name"]))
        for result in results:
            coco_anns = yolo_boxes_to_coco_annotations(
                image_id, result.boxes, confidence_threshold=confidence_threshold
            )
            if not coco_anns:
                continue
            evaluator.update(coco_anns)
    if len(evaluator.eval_imgs["bbox"]) == 0:
        print("No detections!")
        return
    evaluator.synchronize_between_processes()
    evaluator.accumulate()
    evaluator.summarize()
def yolo_boxes_to_coco_annotations(image_id: int, yolo_boxes, confidence_threshold=0.6):
    """Convert the YOLO detection boxes of one image into COCO result dicts.

    Args:
        image_id: COCO image id the boxes belong to.
        yolo_boxes: Iterable of per-detection boxes. Each box exposes ``conf``,
            ``cls`` and ``xyxy``, and ``tolist()[0]`` on them yields the
            score, the class index and ``[x_min, y_min, x_max, y_max]``
            (presumably iterating an ultralytics ``Boxes`` object — confirm).
        confidence_threshold: Boxes whose score is <= this value are skipped.

    Returns:
        A list of dicts, one per kept box, with these keys:

        - ``image_id``
        - ``category_id`` (int)
        - ``area``
        - ``bbox`` in COCO ``[x_min, y_min, width, height]`` form
        - ``score``
    """
    annotations = []
    for box in yolo_boxes:
        score = box.conf.tolist()[0]
        if score <= confidence_threshold:
            continue
        # COCO bboxes are anchored at the top-left corner. YOLO's `xywh` is
        # center-based, so derive the COCO box from the corner coordinates.
        x_min, y_min, x_max, y_max = box.xyxy.tolist()[0]
        width = x_max - x_min
        height = y_max - y_min
        annotations.append(
            {
                "image_id": image_id,
                # NOTE(review): this is the model's class index. It assumes
                # coco_gt uses the same category ids — confirm.
                "category_id": int(box.cls.tolist()[0]),
                "area": width * height,
                "bbox": [x_min, y_min, width, height],
                "score": score,
            }
        )
    return annotations
def main():
    """Evaluate the yolov8s checkpoint on the dataset stored at tests/coco8.zip.

    First runs the pycocotools-based ``evaluate`` against the dataset's COCO
    ground truth. Then runs the wrapped model's own ``val()`` and prints the
    box mAP figures it reports.
    """
    dataset = YoloDataset.from_path('tests/coco8.zip')
    ground_truth = dataset.to_coco()
    yolo = YoloModel("ultralyticsplus/yolov8s", "yolov8s.pt")
    # Alternative checkpoint: YoloModel("SHOU-ISD/fire-and-smoke", "yolov8n.pt")
    evaluate(
        model=yolo,
        coco_gt=ground_truth,
        loader=dataset.load_image,
        confidence_threshold=0.1,
    )

    # NOTE(review): val() gets no arguments, so it validates on whatever
    # dataset/settings the underlying model already carries. That is not
    # necessarily tests/coco8.zip — confirm this is the intended dataset.
    metrics = yolo.model.val()
    print(metrics.box.map)  # mAP@0.50:0.95
    print(metrics.box.map50)  # mAP@0.50
    print(metrics.box.map75)  # mAP@0.75
    print(metrics.box.maps)  # per-category mAP@0.50:0.95


if __name__ == "__main__":
    # Script entry point.
    main()