|
|
|
|
|
from pathlib import Path |
|
|
|
from ultralytics.engine.model import Model |
|
from ultralytics.models import yolo |
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel |
|
from ultralytics.utils import yaml_load, ROOT |
|
|
|
|
|
class YOLO(Model):
    """YOLO (You Only Look Once) object detection model."""

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """
        Initialize a YOLO model, switching to YOLOWorld or YOLOv10 based on the model filename.

        A model whose file stem contains '-world' and whose suffix is .pt, .yaml or .yml is loaded as YOLOWorld.
        A model whose stem contains 'yolov10' is loaded as YOLOv10. In both cases this object takes over the new
        instance's class and attribute dict. Any other model goes through the standard Model initialization.

        Args:
            model (str | Path): Model path or name. Defaults to 'yolov8n.pt'.
            task (str | None): Task passed to Model for standard YOLO models. Not used for YOLOWorld, which is
                always initialized with task 'detect'.
            verbose (bool): Verbosity flag forwarded to Model (and to YOLOWorld). Not forwarded to YOLOv10.
        """
        path = Path(model)
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:
            # Forward verbose so the caller's setting is honoured for YOLO-World models as well.
            new_instance = YOLOWorld(path, verbose=verbose)
        elif "yolov10" in path.stem:
            # Imported lazily inside the branch; presumably to avoid a circular import at module load — verify.
            from ultralytics import YOLOv10

            # NOTE(review): task and verbose are not forwarded here; confirm YOLOv10's constructor
            # signature before passing them through.
            new_instance = YOLOv10(path)
        else:
            super().__init__(model=model, task=task, verbose=verbose)
            return

        # Become the specialised model: adopt its concrete class and share its attribute dict.
        self.__class__ = type(new_instance)
        self.__dict__ = new_instance.__dict__

    @property
    def task_map(self):
        """
        Map each supported task to its model, trainer, validator, and predictor classes.

        Returns:
            (dict): Keys 'classify', 'detect', 'segment', 'pose' and 'obb'. Each maps to a dict with the keys
                'model', 'trainer', 'validator' and 'predictor'.
        """
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }


class YOLOWorld(Model):
    """YOLO-World object detection model."""

    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
        """
        Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.

        Args:
            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
            verbose (bool): Verbosity flag forwarded to the Model initializer. Defaults to False.
        """
        super().__init__(model=model, task="detect", verbose=verbose)

        # Fall back to the COCO8 dataset class names when the loaded model carries no `names` attribute.
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """
        Map the 'detect' task to its model, validator, and predictor classes.

        No trainer entry is provided for YOLO-World here.

        Returns:
            (dict): A single 'detect' key mapping to a dict with the keys 'model', 'validator' and 'predictor'.
        """
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            }
        }

    def set_classes(self, classes):
        """
        Set the classes the model should detect.

        The full sequence, including any " " background entry, is passed to the model's own `set_classes`. The
        names stored on the model, and on an existing predictor's model, exclude the first " " entry. The caller's
        sequence is not modified.

        Args:
            classes (List(str)): A list of categories i.e ["person"].
        """
        self.model.set_classes(classes)

        # Work on a copy so the caller's sequence is not mutated by the background removal below
        # (this also lets immutable sequences such as tuples be passed).
        names = list(classes)
        background = " "
        if background in names:
            names.remove(background)  # removes the first occurrence only, same as before
        self.model.names = names

        # Keep an already-created predictor's model names in sync with the new classes.
        if self.predictor:
            self.predictor.model.names = names
|
|