App

- .gitignore +1 -0
- app.py +118 -0
- crop1.png +0 -0
- crop2.png +0 -0
- crop3.png +0 -0
- requirements.txt +6 -0
- trainer.py +418 -0
.gitignore
ADDED
@@ -0,0 +1 @@

__pycache__
app.py
ADDED
@@ -0,0 +1,118 @@

import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from trainer import CustomSemanticSegmentationTask

# Load a pre-trained semantic segmentation model
task = CustomSemanticSegmentationTask.load_from_checkpoint("maui_demo_model.ckpt", map_location="cpu")
task.freeze()
model = task.model
model = model.eval()


# Define the image transformations
preprocess = transforms.Compose([
    transforms.ToTensor(),
])

# Function to perform semantic segmentation
def segment_image(image):
    input_tensor = preprocess(image).unsqueeze(0)
    with torch.inference_mode():
        output = model(input_tensor)
    output_predictions = output.argmax(1).squeeze().numpy()
    return output_predictions


# Preexisting images
preexisting_images = ["crop1.png", "crop2.png", "crop3.png"]

# Function to handle user input and run the model
def handle_image(image):
    image = Image.open(image)
    mask = segment_image(image)

    # Decode the segmentation output
    colormap = np.array([
        [0, 0, 0],    # nodata
        [0, 0, 0],    # background
        [0, 255, 0],  # building
        [255, 0, 0],  # damage
    ])
    output = colormap[mask].astype('uint8')

    segmented_image = np.array(image)
    segmented_image[mask > 1] = (0.5 * output[mask > 1]) + (0.5 * segmented_image[mask > 1])
    segmented_image = Image.fromarray(segmented_image)
    return segmented_image

# Create the Gradio interface
image_input = gr.Image(type="filepath", label="Upload an Image", sources=["upload"])
image_output = gr.Image(type="pil", label="Output")

css_content = """
.legend {
    list-style: none;
    padding: 0;
}
.legend li {
    line-height: 20px; /* Match the height of the color-box */
}
.legend .color-box {
    display: inline-block;
    width: 20px;
    height: 20px;
    margin-right: 5px;
    border: 1px solid #000; /* Optional: adds a border around the color box */
    vertical-align: middle; /* Centers the box vertically relative to the text */
}
.background { background-color: #FFFFFF; } /* White */
.building { background-color: #00FF00; } /* Green */
.damage { background-color: #FF0000; } /* Red */
"""

html_content = """
<div style="font-size:large;">
<p>
This application demonstrates the input and output of the building damage assessment model trained through the tutorial of the
<a href="https://github.com/microsoft/building-damage/assessment/" target="_blank">Microsoft AI for Good Building Damage
Assessment Toolkit</a>. This particular model was trained on
<a href="https://radiantearth.github.io/stac-browser/#/external/maxar-opendata.s3.amazonaws.com/events/Maui-Hawaii-fires-Aug-23/collection.json?.language=en">Maxar Open Data imagery</a>
captured over Lahaina during the Maui Wildfires in August 2023, and 106 polygon annotations
created over the same imagery. The "Building Damage Assessment Toolkit" details a workflow for quickly
modeling <i>any</i> new post-disaster imagery by:
<ul style='padding-left:20px'>
<li>Setting up a web labeling tool instance with the post-disaster imagery to facilitate rapid annotation of the imagery</li>
<li>Fine-tuning a building damage assessment model using the imagery and annotations</li>
<li>Running inference with the model over potentially large scenes</li>
<li>Merging and summarizing the output of the model using different building footprint layers</li>
</ul>
This workflow allows a user to consistently and rapidly create a good model for a particular event,
but one that is overfit to that event (i.e. it will not generalize to other events).
</p>
<p>
The model outputs per-pixel predictions of building damage, with the following classes:
<ul class="legend">
<li><span class="color-box background"></span>Background (transparent)</li>
<li><span class="color-box building"></span>Building (green)</li>
<li><span class="color-box damage"></span>Damage (red)</li>
</ul>
</p>
</div>
"""

iface = gr.Interface(
    fn=handle_image,
    inputs=image_input,
    outputs=image_output,
    title="Building damage assessment model demo -- Maui Wildfires 2023",
    examples=preexisting_images,
    css=css_content,
    description=html_content,
    allow_flagging="never"
)

# Launch the app
iface.launch(share=True)
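For quick checks outside the hosted Gradio UI (for example in a notebook or a batch script), the same checkpoint can be exercised headlessly. The snippet below is a minimal sketch and not part of app.py; it assumes maui_demo_model.ckpt and crop1.png are present in the working directory and that the packages listed in requirements.txt are installed.

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from trainer import CustomSemanticSegmentationTask

# Load the checkpoint the same way app.py does
task = CustomSemanticSegmentationTask.load_from_checkpoint("maui_demo_model.ckpt", map_location="cpu")
task.freeze()
model = task.model.eval()

# Same preprocessing as the app: convert the image to a [0, 1] float tensor
image = Image.open("crop1.png")
batch = transforms.ToTensor()(image).unsqueeze(0)

with torch.inference_mode():
    logits = model(batch)

# Per-pixel class indices: 0 = nodata, 1 = background, 2 = building, 3 = damage
mask = logits.argmax(dim=1).squeeze().numpy()
print(dict(zip(*np.unique(mask, return_counts=True))))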
crop1.png
ADDED

crop2.png
ADDED

crop3.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,6 @@

lightning
numpy
segmentation-models-pytorch
torch
torchmetrics
torchvision
trainer.py
ADDED
@@ -0,0 +1,418 @@

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Trainers for semantic segmentation."""

import os
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Sequence
from typing import Any, Optional, Union

import lightning
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import Callback
from torch import Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum


def get_weight(name: str) -> WeightsEnum:
    """Get the weights enum value by its full name.

    .. versionadded:: 0.4

    Args:
        name: Name of the weight enum entry.

    Returns:
        The requested weight enum.
    """
    return eval(name)


def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]:
    """Extracts a backbone from a lightning checkpoint file.

    Args:
        path: path to checkpoint file (.ckpt)

    Returns:
        tuple containing model name and state dict

    Raises:
        ValueError: if 'model' or 'backbone' not in
            checkpoint['hyper_parameters']

    .. versionchanged:: 0.4
        Renamed from *extract_encoder* to *extract_backbone*
    """
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    if "model" in checkpoint["hyper_parameters"]:
        name = checkpoint["hyper_parameters"]["model"]
        state_dict = checkpoint["state_dict"]
        state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k})
        state_dict = OrderedDict(
            {k.replace("model.", ""): v for k, v in state_dict.items()}
        )
    elif "backbone" in checkpoint["hyper_parameters"]:
        name = checkpoint["hyper_parameters"]["backbone"]
        state_dict = checkpoint["state_dict"]
        state_dict = OrderedDict(
            {k: v for k, v in state_dict.items() if "model.backbone.model" in k}
        )
        state_dict = OrderedDict(
            {k.replace("model.backbone.model.", ""): v for k, v in state_dict.items()}
        )
    else:
        raise ValueError(
            "Unknown checkpoint task. Only backbone or model extraction is supported"
        )

    return name, state_dict


class BaseTask(LightningModule, ABC):
    """Abstract base class for all TorchGeo trainers.

    .. versionadded:: 0.5
    """

    #: Model to train.
    model: Any

    #: Performance metric to monitor in learning rate scheduler and callbacks.
    monitor = "val_loss"

    #: Whether the goal is to minimize or maximize the performance metric to monitor.
    mode = "min"

    def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None:
        """Initialize a new BaseTask instance.

        Args:
            ignore: Arguments to skip when saving hyperparameters.
        """
        super().__init__()
        self.save_hyperparameters(ignore=ignore)
        self.configure_losses()
        self.configure_metrics()
        self.configure_models()

    def configure_losses(self) -> None:
        """Initialize the loss criterion."""

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""

    @abstractmethod
    def configure_models(self) -> None:
        """Initialize the model."""

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            Optimizer and learning rate scheduler.
        """
        optimizer = AdamW(self.parameters(), lr=self.hparams["lr"])
        scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams["patience"])
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
        }

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model.

        Args:
            args: Arguments to pass to model.
            kwargs: Keyword arguments to pass to model.

        Returns:
            Output of the model.
        """
        return self.model(*args, **kwargs)


class SemanticSegmentationTask(BaseTask):
    """Semantic Segmentation."""

    def __init__(
        self,
        model: str = "unet",
        backbone: str = "resnet50",
        weights: Optional[Union[WeightsEnum, str, bool]] = None,
        in_channels: int = 3,
        num_classes: int = 1000,
        num_filters: int = 3,
        loss: str = "ce",
        class_weights: Optional[Tensor] = None,
        ignore_index: Optional[int] = None,
        lr: float = 1e-3,
        patience: int = 10,
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
    ) -> None:
        """Initialize a new SemanticSegmentationTask instance.

        Args:
            model: Name of the
                `smp <https://smp.readthedocs.io/en/latest/models.html>`__ model to use.
            backbone: Name of the `timm
                <https://smp.readthedocs.io/en/latest/encoders_timm.html>`__ or `smp
                <https://smp.readthedocs.io/en/latest/encoders.html>`__ backbone to use.
            weights: Initial model weights. Either a weight enum, the string
                representation of a weight enum, True for ImageNet weights, False or
                None for random weights, or the path to a saved model state dict. FCN
                model does not support pretrained weights. Pretrained ViT weight enums
                are not supported yet.
            in_channels: Number of input channels to model.
            num_classes: Number of prediction classes.
            num_filters: Number of filters. Only applicable when model='fcn'.
            loss: Name of the loss function, currently supports
                'ce', 'jaccard' or 'focal' loss.
            class_weights: Optional rescaling weight given to each
                class and used with 'ce' loss.
            ignore_index: Optional integer class index to ignore in the loss and
                metrics.
            lr: Learning rate for optimizer.
            patience: Patience for learning rate scheduler.
            freeze_backbone: Freeze the backbone network to fine-tune the
                decoder and segmentation head.
            freeze_decoder: Freeze the decoder network to linear probe
                the segmentation head.

        Warns:
            UserWarning: When loss='jaccard' and ignore_index is specified.

        .. versionchanged:: 0.3
            *ignore_zeros* was renamed to *ignore_index*.

        .. versionchanged:: 0.4
            *segmentation_model*, *encoder_name*, and *encoder_weights*
            were renamed to *model*, *backbone*, and *weights*.

        .. versionadded:: 0.5
            The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters.

        .. versionchanged:: 0.5
            The *weights* parameter now supports WeightEnums and checkpoint paths.
            *learning_rate* and *learning_rate_schedule_patience* were renamed to
            *lr* and *patience*.
        """
        if ignore_index is not None and loss == "jaccard":
            warnings.warn(
                "ignore_index has no effect on training when loss='jaccard'",
                UserWarning,
            )

        self.weights = weights
        super().__init__(ignore="weights")

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        ignore_index = self.hparams["ignore_index"]
        if loss == "ce":
            ignore_value = -1000 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(
                ignore_index=ignore_value, weight=self.hparams["class_weights"]
            )
        elif loss == "jaccard":
            self.criterion = smp.losses.JaccardLoss(
                mode="multiclass", classes=self.hparams["num_classes"]
            )
        elif loss == "focal":
            self.criterion = smp.losses.FocalLoss(
                "multiclass", ignore_index=ignore_index, normalized=True
            )
        else:
            raise ValueError(
                f"Loss type '{loss}' is not valid. "
                "Currently, supports 'ce', 'jaccard' or 'focal' loss."
            )

    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
          (OA) using 'micro' averaging. The number of true positives divided by the
          dataset size. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
          over union (IoU). Uses 'micro' averaging. Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
             minority class accuracy.
           * 'Macro' averaging, not used here, gives equal weight to each class, useful
             for balanced performance assessment across imbalanced classes.
        """
        num_classes: int = self.hparams["num_classes"]
        ignore_index: Optional[int] = self.hparams["ignore_index"]
        metrics = MetricCollection(
            [
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average="micro",
                ),
                MulticlassJaccardIndex(
                    num_classes=num_classes, ignore_index=ignore_index, average="micro"
                ),
            ]
        )
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")

    def configure_models(self) -> None:
        """Initialize the model.

        Raises:
            ValueError: If *model* is invalid.
        """
        model: str = self.hparams["model"]
        backbone: str = self.hparams["backbone"]
        weights = self.weights
        in_channels: int = self.hparams["in_channels"]
        num_classes: int = self.hparams["num_classes"]
        num_filters: int = self.hparams["num_filters"]

        if model == "unet":
            self.model = smp.Unet(
                encoder_name=backbone,
                encoder_weights="imagenet" if weights is True else None,
                in_channels=in_channels,
                classes=num_classes,
            )
        elif model == "deeplabv3+":
            self.model = smp.DeepLabV3Plus(
                encoder_name=backbone,
                encoder_weights="imagenet" if weights is True else None,
                in_channels=in_channels,
                classes=num_classes,
            )
        else:
            raise ValueError(
                f"Model type '{model}' is not valid. "
                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
            )

        if weights and weights is not True:
            if isinstance(weights, WeightsEnum):
                state_dict = weights.get_state_dict(progress=True)
            elif os.path.exists(weights):
                _, state_dict = extract_backbone(weights)
            else:
                state_dict = get_weight(weights).get_state_dict(progress=True)
            self.model.encoder.load_state_dict(state_dict)

        # Freeze backbone
        if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
            for param in self.model.encoder.parameters():
                param.requires_grad = False

        # Freeze decoder
        if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
            for param in self.model.decoder.parameters():
                param.requires_grad = False

    def training_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The loss tensor.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        loss: Tensor = self.criterion(y_hat, y)
        self.log("train_loss", loss)
        self.train_metrics(y_hat, y)
        self.log_dict(self.train_metrics)
        return loss

    def validation_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("val_loss", loss)
        self.val_metrics(y_hat, y)
        self.log_dict(self.val_metrics)

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("test_loss", loss)
        self.test_metrics(y_hat, y)
        self.log_dict(self.test_metrics)

    def predict_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
        """
        x = batch["image"]
        y_hat: Tensor = self(x).softmax(dim=1)
        return y_hat


class CustomSemanticSegmentationTask(SemanticSegmentationTask):
    """A custom trainer for semantic segmentation tasks."""

    def configure_callbacks(self) -> list[Callback]:
        """Configures the callbacks for the trainer.

        Returns:
            an empty list to override the default callbacks; these are set in the Trainer
        """
        return []
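trainer.py only defines the LightningModule; the fitting loop lives in whatever script or notebook drives it (the fine-tuning step of the toolkit's workflow). The following is a rough, illustrative sketch of how CustomSemanticSegmentationTask could be fitted with a lightning Trainer. RandomChips and RandomDataModule are stand-ins that emit random tensors shaped like the real batches ({"image": (B, 3, H, W), "mask": (B, H, W)}), and the hyperparameters, including num_classes=4 chosen to match the colormap in app.py, are illustrative rather than the settings used to train the Maui model.

import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, Dataset

from trainer import CustomSemanticSegmentationTask


class RandomChips(Dataset):
    """Stand-in dataset: random imagery with 4-class masks (nodata, background, building, damage)."""

    def __len__(self) -> int:
        return 32

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return {
            "image": torch.rand(3, 256, 256),
            "mask": torch.randint(0, 4, (256, 256)),
        }


class RandomDataModule(pl.LightningDataModule):
    """Stand-in datamodule; a real run would yield chips cut from the labeled Maxar imagery."""

    def train_dataloader(self) -> DataLoader:
        return DataLoader(RandomChips(), batch_size=8)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(RandomChips(), batch_size=8)


task = CustomSemanticSegmentationTask(
    model="unet",
    backbone="resnet50",
    weights=True,       # ImageNet-pretrained encoder (downloads weights on first use)
    in_channels=3,
    num_classes=4,      # illustrative; matches the 4-entry colormap in app.py
    loss="ce",
    lr=1e-3,
    patience=10,
)

trainer = pl.Trainer(max_epochs=1, accelerator="cpu", log_every_n_steps=1)
trainer.fit(task, datamodule=RandomDataModule())
trainer.save_checkpoint("demo_model.ckpt")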