Spaces:
Running
Running
File size: 6,301 Bytes
1999a98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from copy import copy
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
class ClassificationTrainer(BaseTrainer):
    """
    A class extending the BaseTrainer class for training based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.models.yolo.classify import ClassificationTrainer

        args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
        trainer = ClassificationTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.

        Args:
            cfg: Base configuration forwarded to BaseTrainer (defaults to DEFAULT_CFG).
            overrides (dict | None): Optional settings forwarded to BaseTrainer. "task" is always forced to
                "classify" and "imgsz" falls back to 224 when it is missing or None.
            _callbacks: Optional callbacks forwarded unchanged to BaseTrainer.
        """
        # Work on a shallow copy so the dict owned by the caller is never mutated.
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "classify"
        if overrides.get("imgsz") is None:
            overrides["imgsz"] = 224
        super().__init__(cfg, overrides, _callbacks)

    def set_model_attributes(self):
        """Set the YOLO model's class names from the loaded dataset."""
        self.model.names = self.data["names"]

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return a ClassificationModel configured for training.

        Args:
            cfg: Model configuration passed straight to ClassificationModel.
            weights: Optional weights; when truthy they are loaded via ``model.load``.
            verbose (bool): Verbosity flag; only honoured when RANK == -1.

        Returns:
            The constructed ClassificationModel with every parameter set to require gradients.
        """
        model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        for m in model.modules():
            # When not using pretrained weights, re-initialise every module that supports it.
            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                m.reset_parameters()
            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                m.p = self.args.dropout  # set dropout
        for p in model.parameters():
            p.requires_grad = True  # for training
        return model

    def setup_model(self):
        """
        Load, create or download model for any task.

        If ``str(self.model)`` names an entry of ``torchvision.models``, that entry is called to build the model
        (with "IMAGENET1K_V1" weights when ``self.args.pretrained`` is truthy, otherwise without weights).
        Otherwise setup is delegated to ``BaseTrainer.setup_model``. In both cases the model outputs are reshaped
        to ``self.data["nc"]`` classes.

        Returns:
            None for torchvision models, otherwise whatever ``BaseTrainer.setup_model`` returns.
        """
        import torchvision  # scope for faster 'import ultralytics'

        # Use the same string key for the membership test and the lookup, so a non-str value
        # (e.g. a Path) that passes the check cannot raise KeyError on the lookup.
        name = str(self.model)
        if name in torchvision.models.__dict__:
            self.model = torchvision.models.__dict__[name](
                weights="IMAGENET1K_V1" if self.args.pretrained else None
            )
            ckpt = None
        else:
            ckpt = super().setup_model()
        ClassificationModel.reshape_outputs(self.model, self.data["nc"])
        return ckpt

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Create a ClassificationDataset instance given an image path and mode (train/test etc.).

        Args:
            img_path: Dataset root passed to ClassificationDataset.
            mode (str): Dataset mode; augmentation is enabled only when it equals "train". Also used as prefix.
            batch: Unused here; kept for signature compatibility.

        Returns:
            The constructed ClassificationDataset.
        """
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """
        Build the dataloader for the given dataset path and attach inference transforms to the model.

        Args:
            dataset_path: Path forwarded to ``build_dataset``.
            batch_size (int): Batch size forwarded to ``build_dataloader``.
            rank (int): Process rank forwarded to the distributed guard and ``build_dataloader``.
            mode (str): "train" or another mode; for non-train modes the dataset's ``torch_transforms`` are
                stored on the model as ``transforms``.

        Returns:
            The loader returned by ``build_dataloader``.
        """
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, mode)

        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
        # Attach inference transforms
        if mode != "train":
            if is_parallel(self.model):
                # Parallel wrappers hold the real model under `.module`.
                self.model.module.transforms = loader.dataset.torch_transforms
            else:
                self.model.transforms = loader.dataset.torch_transforms
        return loader

    def preprocess_batch(self, batch):
        """
        Move the "img" and "cls" entries of a batch to ``self.device``.

        Args:
            batch (dict): Batch holding "img" and "cls" entries that support ``.to(device)``.

        Returns:
            The same dict object, updated in place.
        """
        batch["img"] = batch["img"].to(self.device)
        batch["cls"] = batch["cls"].to(self.device)
        return batch

    def progress_string(self):
        """Return the formatted header row (Epoch, GPU_mem, loss names, Instances, Size) for progress logging."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def get_validator(self):
        """
        Return an instance of ClassificationValidator for validation.

        Side effect: sets ``self.loss_names`` to ``["loss"]``. The validator receives a shallow copy of
        ``self.args``.
        """
        self.loss_names = ["loss"]
        return yolo.classify.ClassificationValidator(
            self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def label_loss_items(self, loss_items=None, prefix="train"):
        """
        Return a loss dict with labelled training loss items.

        Not needed for classification but necessary for segmentation & detection.

        Args:
            loss_items: Value convertible with ``float()``; when None only the key names are returned.
            prefix (str): Prefix prepended to every loss name as "<prefix>/<name>".

        Returns:
            A list of key names when ``loss_items`` is None, otherwise a dict mapping the first key to the
            loss rounded to 5 decimals.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]
        return dict(zip(keys, loss_items))

    def plot_metrics(self):
        """Plot metrics from the results CSV file (``self.csv``)."""
        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png

    def final_eval(self):
        """
        Evaluate trained model and save validation results.

        Strips the optimizer from the last and best checkpoints when they exist; the best checkpoint is
        additionally validated, its metrics stored on ``self.metrics`` (without "fitness") and the
        "on_fit_epoch_end" callbacks are run.
        """
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f"\nValidating {f}...")
                    self.validator.args.data = self.args.data
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop("fitness", None)
                    self.run_callbacks("on_fit_epoch_end")

    def plot_training_samples(self, batch, ni):
        """
        Plot training samples with their annotations.

        Args:
            batch (dict): Batch holding "img" and "cls" entries.
            ni: Value interpolated into the output file name ``train_batch{ni}.jpg`` under ``self.save_dir``.
        """
        plot_images(
            images=batch["img"],
            batch_idx=torch.arange(len(batch["img"])),
            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )
|