calebrob6 committed
Commit c180e59 · 1 Parent(s): 2865d68
Files changed (7)
  1. .gitignore +1 -0
  2. app.py +118 -0
  3. crop1.png +0 -0
  4. crop2.png +0 -0
  5. crop3.png +0 -0
  6. requirements.txt +6 -0
  7. trainer.py +418 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
app.py ADDED
@@ -0,0 +1,118 @@
+ import gradio as gr
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ import numpy as np
+ from trainer import CustomSemanticSegmentationTask
+
+ # Load a pre-trained semantic segmentation model
+ task = CustomSemanticSegmentationTask.load_from_checkpoint("maui_demo_model.ckpt", map_location="cpu")
+ task.freeze()
+ model = task.model
+ model = model.eval()
+
+
+ # Define the image transformations
+ preprocess = transforms.Compose([
+     transforms.ToTensor(),
+ ])
+
+ # Function to perform semantic segmentation
+ def segment_image(image):
+     input_tensor = preprocess(image).unsqueeze(0)
+     with torch.inference_mode():
+         output = model(input_tensor)
+     output_predictions = output.argmax(1).squeeze().numpy()
+     return output_predictions
+
+
+ # Preexisting images
+ preexisting_images = ["crop1.png", "crop2.png", "crop3.png"]
+
+ # Function to handle user input and run the model
+ def handle_image(image):
+     image = Image.open(image)
+     mask = segment_image(image)
+
+     # Decode the segmentation output
+     colormap = np.array([
+         [0, 0, 0],    # nodata
+         [0, 0, 0],    # background
+         [0, 255, 0],  # building
+         [255, 0, 0],  # damage
+     ])
+     output = colormap[mask].astype('uint8')
+
+     segmented_image = np.array(image)
+     segmented_image[mask > 1] = (0.5 * output[mask > 1]) + (0.5 * segmented_image[mask > 1])
+     segmented_image = Image.fromarray(segmented_image)
+     return segmented_image
+
+ # Create the Gradio interface
+ image_input = gr.Image(type="filepath", label="Upload an Image", sources=["upload"])
+ image_output = gr.Image(type="pil", label="Output")
+
+ css_content = """
+ .legend {
+     list-style: none;
+     padding: 0;
+ }
+ .legend li {
+     line-height: 20px; /* Match the height of the color-box */
+ }
+ .legend .color-box {
+     display: inline-block;
+     width: 20px;
+     height: 20px;
+     margin-right: 5px;
+     border: 1px solid #000; /* Optional: adds a border around the color box */
+     vertical-align: middle; /* Centers the box vertically relative to the text */
+ }
+ .background { background-color: #FFFFFF; } /* White */
+ .building { background-color: #00FF00; } /* Green */
+ .damage { background-color: #FF0000; } /* Red */
+ """
+
+ html_content = """
+ <div style="font-size:large;">
+ <p>
+ This application demonstrates the input and output of the building damage assessment model trained by following the tutorial in the
+ <a href="https://github.com/microsoft/building-damage-assessment/" target="_blank">Microsoft AI for Good Building Damage
+ Assessment Toolkit</a>. This particular model was trained on
+ <a href="https://radiantearth.github.io/stac-browser/#/external/maxar-opendata.s3.amazonaws.com/events/Maui-Hawaii-fires-Aug-23/collection.json?.language=en">Maxar Open Data imagery</a>
+ captured over Lahaina during the Maui Wildfires in August 2023 and 106 polygon annotations
+ created over the same imagery. The "Building Damage Assessment Toolkit" details a workflow for quickly
+ modeling <i>any</i> new post-disaster imagery by:
+ <ul style='padding-left:20px'>
+     <li>Setting up a web labeling tool instance with the post-disaster imagery to facilitate rapid annotation of the imagery</li>
+     <li>Fine-tuning a building damage assessment model using the imagery and annotations</li>
+     <li>Running inference with the model over potentially large scenes</li>
+     <li>Merging and summarizing the output of the model using different building footprint layers</li>
+ </ul>
+ This workflow allows a user to consistently and rapidly create a good model for a particular event,
+ but one that is overfit to that event (i.e. it will not generalize to other events).
+ </p>
+ <p>
+ The model outputs per-pixel predictions of building damage, with the following classes:
+ <ul class="legend">
+     <li><span class="color-box background"></span>Background (transparent)</li>
+     <li><span class="color-box building"></span>Building (green)</li>
+     <li><span class="color-box damage"></span>Damage (red)</li>
+ </ul>
+ </p>
+ </div>
+ """
+
+ iface = gr.Interface(
+     fn=handle_image,
+     inputs=image_input,
+     outputs=image_output,
+     title="Building damage assessment model demo -- Maui Wildfires 2023",
+     examples=preexisting_images,
+     css=css_content,
+     description=html_content,
+     allow_flagging="never"
+ )
+
+ # Launch the app
+ iface.launch(share=True)
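
Note on the overlay step in handle_image above: only pixels whose predicted class index is greater than 1 (building = 2, damage = 3) are blended 50/50 with their colormap color; nodata and background pixels are left untouched. A toy sketch of that rule (not part of this commit, shown on a made-up 2x2 mask):

import numpy as np

colormap = np.array([[0, 0, 0], [0, 0, 0], [0, 255, 0], [255, 0, 0]])  # nodata, background, building, damage
mask = np.array([[0, 2], [3, 1]])                # toy prediction: nodata, building, damage, background
image = np.full((2, 2, 3), 100, dtype=np.uint8)  # toy grey RGB image
overlay = colormap[mask].astype("uint8")

blended = image.copy()
blended[mask > 1] = 0.5 * overlay[mask > 1] + 0.5 * blended[mask > 1]
print(blended[0, 1], blended[1, 0])  # [ 50 177  50] (building), [177  50  50] (damage); other pixels stay [100 100 100]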
crop1.png ADDED
crop2.png ADDED
crop3.png ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ lightning
+ numpy
+ segmentation-models-pytorch
+ torch
+ torchmetrics
+ torchvision
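
With these dependencies installed (for example via pip install -r requirements.txt) and the maui_demo_model.ckpt checkpoint that app.py loads placed alongside it (the checkpoint is not among the seven files added in this commit), running python app.py should launch the Gradio demo locally and, because of share=True, also request a public share link.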
trainer.py ADDED
@@ -0,0 +1,418 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Trainers for semantic segmentation."""
+
+ import os
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections import OrderedDict
+ from collections.abc import Sequence
+ from typing import Any, Optional, Union
+
+ import lightning
+ import segmentation_models_pytorch as smp
+ import torch
+ import torch.nn as nn
+ from lightning.pytorch import LightningModule
+ from lightning.pytorch.callbacks import Callback
+ from torch import Tensor
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+ from torchmetrics import MetricCollection
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
+ from torchvision.models._api import WeightsEnum
+
+
+ def get_weight(name: str) -> WeightsEnum:
+     """Get the weights enum value by its full name.
+
+     .. versionadded:: 0.4
+
+     Args:
+         name: Name of the weight enum entry.
+
+     Returns:
+         The requested weight enum.
+     """
+     return eval(name)
+
+
+ def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]:
+     """Extracts a backbone from a lightning checkpoint file.
+
+     Args:
+         path: path to checkpoint file (.ckpt)
+
+     Returns:
+         tuple containing model name and state dict
+
+     Raises:
+         ValueError: if 'model' or 'backbone' not in
+             checkpoint['hyper_parameters']
+
+     .. versionchanged:: 0.4
+         Renamed from *extract_encoder* to *extract_backbone*
+     """
+     checkpoint = torch.load(path, map_location=torch.device("cpu"))
+     if "model" in checkpoint["hyper_parameters"]:
+         name = checkpoint["hyper_parameters"]["model"]
+         state_dict = checkpoint["state_dict"]
+         state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k})
+         state_dict = OrderedDict(
+             {k.replace("model.", ""): v for k, v in state_dict.items()}
+         )
+     elif "backbone" in checkpoint["hyper_parameters"]:
+         name = checkpoint["hyper_parameters"]["backbone"]
+         state_dict = checkpoint["state_dict"]
+         state_dict = OrderedDict(
+             {k: v for k, v in state_dict.items() if "model.backbone.model" in k}
+         )
+         state_dict = OrderedDict(
+             {k.replace("model.backbone.model.", ""): v for k, v in state_dict.items()}
+         )
+     else:
+         raise ValueError(
+             "Unknown checkpoint task. Only backbone or model extraction is supported"
+         )
+
+     return name, state_dict
+
+
+ class BaseTask(LightningModule, ABC):
+     """Abstract base class for all TorchGeo trainers.
+
+     .. versionadded:: 0.5
+     """
+
+     #: Model to train.
+     model: Any
+
+     #: Performance metric to monitor in learning rate scheduler and callbacks.
+     monitor = "val_loss"
+
+     #: Whether the goal is to minimize or maximize the performance metric to monitor.
+     mode = "min"
+
+     def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None:
+         """Initialize a new BaseTask instance.
+
+         Args:
+             ignore: Arguments to skip when saving hyperparameters.
+         """
+         super().__init__()
+         self.save_hyperparameters(ignore=ignore)
+         self.configure_losses()
+         self.configure_metrics()
+         self.configure_models()
+
+     def configure_losses(self) -> None:
+         """Initialize the loss criterion."""
+
+     def configure_metrics(self) -> None:
+         """Initialize the performance metrics."""
+
+     @abstractmethod
+     def configure_models(self) -> None:
+         """Initialize the model."""
+
+     def configure_optimizers(
+         self,
+     ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
+         """Initialize the optimizer and learning rate scheduler.
+
+         Returns:
+             Optimizer and learning rate scheduler.
+         """
+         optimizer = AdamW(self.parameters(), lr=self.hparams["lr"])
+         scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams["patience"])
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
+         }
+
+     def forward(self, *args: Any, **kwargs: Any) -> Any:
+         """Forward pass of the model.
+
+         Args:
+             args: Arguments to pass to model.
+             kwargs: Keyword arguments to pass to model.
+
+         Returns:
+             Output of the model.
+         """
+         return self.model(*args, **kwargs)
+
+
+ class SemanticSegmentationTask(BaseTask):
+     """Semantic Segmentation."""
+
+     def __init__(
+         self,
+         model: str = "unet",
+         backbone: str = "resnet50",
+         weights: Optional[Union[WeightsEnum, str, bool]] = None,
+         in_channels: int = 3,
+         num_classes: int = 1000,
+         num_filters: int = 3,
+         loss: str = "ce",
+         class_weights: Optional[Tensor] = None,
+         ignore_index: Optional[int] = None,
+         lr: float = 1e-3,
+         patience: int = 10,
+         freeze_backbone: bool = False,
+         freeze_decoder: bool = False,
+     ) -> None:
+         """Initialize a new SemanticSegmentationTask instance.
+
+         Args:
+             model: Name of the
+                 `smp <https://smp.readthedocs.io/en/latest/models.html>`__ model to use.
+             backbone: Name of the `timm
+                 <https://smp.readthedocs.io/en/latest/encoders_timm.html>`__ or `smp
+                 <https://smp.readthedocs.io/en/latest/encoders.html>`__ backbone to use.
+             weights: Initial model weights. Either a weight enum, the string
+                 representation of a weight enum, True for ImageNet weights, False or
+                 None for random weights, or the path to a saved model state dict. FCN
+                 model does not support pretrained weights. Pretrained ViT weight enums
+                 are not supported yet.
+             in_channels: Number of input channels to model.
+             num_classes: Number of prediction classes.
+             num_filters: Number of filters. Only applicable when model='fcn'.
+             loss: Name of the loss function, currently supports
+                 'ce', 'jaccard' or 'focal' loss.
+             class_weights: Optional rescaling weight given to each
+                 class and used with 'ce' loss.
+             ignore_index: Optional integer class index to ignore in the loss and
+                 metrics.
+             lr: Learning rate for optimizer.
+             patience: Patience for learning rate scheduler.
+             freeze_backbone: Freeze the backbone network to fine-tune the
+                 decoder and segmentation head.
+             freeze_decoder: Freeze the decoder network to linear probe
+                 the segmentation head.
+
+         Warns:
+             UserWarning: When loss='jaccard' and ignore_index is specified.
+
+         .. versionchanged:: 0.3
+             *ignore_zeros* was renamed to *ignore_index*.
+
+         .. versionchanged:: 0.4
+             *segmentation_model*, *encoder_name*, and *encoder_weights*
+             were renamed to *model*, *backbone*, and *weights*.
+
+         .. versionadded:: 0.5
+             The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters.
+
+         .. versionchanged:: 0.5
+             The *weights* parameter now supports WeightEnums and checkpoint paths.
+             *learning_rate* and *learning_rate_schedule_patience* were renamed to
+             *lr* and *patience*.
+         """
+         if ignore_index is not None and loss == "jaccard":
+             warnings.warn(
+                 "ignore_index has no effect on training when loss='jaccard'",
+                 UserWarning,
+             )
+
+         self.weights = weights
+         super().__init__(ignore="weights")
+
+     def configure_losses(self) -> None:
+         """Initialize the loss criterion.
+
+         Raises:
+             ValueError: If *loss* is invalid.
+         """
+         loss: str = self.hparams["loss"]
+         ignore_index = self.hparams["ignore_index"]
+         if loss == "ce":
+             ignore_value = -1000 if ignore_index is None else ignore_index
+             self.criterion = nn.CrossEntropyLoss(
+                 ignore_index=ignore_value, weight=self.hparams["class_weights"]
+             )
+         elif loss == "jaccard":
+             self.criterion = smp.losses.JaccardLoss(
+                 mode="multiclass", classes=self.hparams["num_classes"]
+             )
+         elif loss == "focal":
+             self.criterion = smp.losses.FocalLoss(
+                 "multiclass", ignore_index=ignore_index, normalized=True
+             )
+         else:
+             raise ValueError(
+                 f"Loss type '{loss}' is not valid. "
+                 "Currently, supports 'ce', 'jaccard' or 'focal' loss."
+             )
+
+     def configure_metrics(self) -> None:
+         """Initialize the performance metrics.
+
+         * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
+           (OA) using 'micro' averaging. The number of true positives divided by the
+           dataset size. Higher values are better.
+         * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+           over union (IoU). Uses 'micro' averaging. Higher values are better.
+
+         .. note::
+            * 'Micro' averaging suits overall performance evaluation but may not reflect
+              minority class accuracy.
+            * 'Macro' averaging, not used here, gives equal weight to each class, useful
+              for balanced performance assessment across imbalanced classes.
+         """
+         num_classes: int = self.hparams["num_classes"]
+         ignore_index: Optional[int] = self.hparams["ignore_index"]
+         metrics = MetricCollection(
+             [
+                 MulticlassAccuracy(
+                     num_classes=num_classes,
+                     ignore_index=ignore_index,
+                     multidim_average="global",
+                     average="micro",
+                 ),
+                 MulticlassJaccardIndex(
+                     num_classes=num_classes, ignore_index=ignore_index, average="micro"
+                 ),
+             ]
+         )
+         self.train_metrics = metrics.clone(prefix="train_")
+         self.val_metrics = metrics.clone(prefix="val_")
+         self.test_metrics = metrics.clone(prefix="test_")
+
+     def configure_models(self) -> None:
+         """Initialize the model.
+
+         Raises:
+             ValueError: If *model* is invalid.
+         """
+         model: str = self.hparams["model"]
+         backbone: str = self.hparams["backbone"]
+         weights = self.weights
+         in_channels: int = self.hparams["in_channels"]
+         num_classes: int = self.hparams["num_classes"]
+         num_filters: int = self.hparams["num_filters"]
+
+         if model == "unet":
+             self.model = smp.Unet(
+                 encoder_name=backbone,
+                 encoder_weights="imagenet" if weights is True else None,
+                 in_channels=in_channels,
+                 classes=num_classes,
+             )
+         elif model == "deeplabv3+":
+             self.model = smp.DeepLabV3Plus(
+                 encoder_name=backbone,
+                 encoder_weights="imagenet" if weights is True else None,
+                 in_channels=in_channels,
+                 classes=num_classes,
+             )
+         else:
+             raise ValueError(
+                 f"Model type '{model}' is not valid. "
+                 "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+             )
+
+         if weights and weights is not True:
+             if isinstance(weights, WeightsEnum):
+                 state_dict = weights.get_state_dict(progress=True)
+             elif os.path.exists(weights):
+                 _, state_dict = extract_backbone(weights)
+             else:
+                 state_dict = get_weight(weights).get_state_dict(progress=True)
+             self.model.encoder.load_state_dict(state_dict)
+
+         # Freeze backbone
+         if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
+             for param in self.model.encoder.parameters():
+                 param.requires_grad = False
+
+         # Freeze decoder
+         if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
+             for param in self.model.decoder.parameters():
+                 param.requires_grad = False
+
+     def training_step(
+         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> Tensor:
+         """Compute the training loss and additional metrics.
+
+         Args:
+             batch: The output of your DataLoader.
+             batch_idx: Integer displaying index of this batch.
+             dataloader_idx: Index of the current dataloader.
+
+         Returns:
+             The loss tensor.
+         """
+         x = batch["image"]
+         y = batch["mask"]
+         y_hat = self(x)
+         loss: Tensor = self.criterion(y_hat, y)
+         self.log("train_loss", loss)
+         self.train_metrics(y_hat, y)
+         self.log_dict(self.train_metrics)
+         return loss
+
+     def validation_step(
+         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         """Compute the validation loss and additional metrics.
+
+         Args:
+             batch: The output of your DataLoader.
+             batch_idx: Integer displaying index of this batch.
+             dataloader_idx: Index of the current dataloader.
+         """
+         x = batch["image"]
+         y = batch["mask"]
+         y_hat = self(x)
+         loss = self.criterion(y_hat, y)
+         self.log("val_loss", loss)
+         self.val_metrics(y_hat, y)
+         self.log_dict(self.val_metrics)
+
+     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+         """Compute the test loss and additional metrics.
+
+         Args:
+             batch: The output of your DataLoader.
+             batch_idx: Integer displaying index of this batch.
+             dataloader_idx: Index of the current dataloader.
+         """
+         x = batch["image"]
+         y = batch["mask"]
+         y_hat = self(x)
+         loss = self.criterion(y_hat, y)
+         self.log("test_loss", loss)
+         self.test_metrics(y_hat, y)
+         self.log_dict(self.test_metrics)
+
+     def predict_step(
+         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> Tensor:
+         """Compute the predicted class probabilities.
+
+         Args:
+             batch: The output of your DataLoader.
+             batch_idx: Integer displaying index of this batch.
+             dataloader_idx: Index of the current dataloader.
+
+         Returns:
+             Output predicted probabilities.
+         """
+         x = batch["image"]
+         y_hat: Tensor = self(x).softmax(dim=1)
+         return y_hat
+
+
+ class CustomSemanticSegmentationTask(SemanticSegmentationTask):
+     """A custom trainer for semantic segmentation tasks."""
+
+     def configure_callbacks(self) -> list[Callback]:
+         """Configures the callbacks for the trainer.
+
+         Returns:
+             an empty list to override the default callbacks; these are set in the Trainer instead
+         """
+         return []
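
trainer.py vendors TorchGeo's semantic segmentation trainer, so the checkpoint consumed by app.py can be reproduced or fine-tuned on new annotations with a standard Lightning loop. A minimal sketch (not part of this commit; MauiDamageDataModule is a hypothetical LightningDataModule that yields batches with "image" and "mask" keys, as the *_step methods expect):

import lightning.pytorch as pl
from trainer import CustomSemanticSegmentationTask

# Four classes to match the colormap used in app.py: nodata, background, building, damage
task = CustomSemanticSegmentationTask(
    model="unet",
    backbone="resnet18",   # any smp/timm encoder name
    weights=True,          # start from an ImageNet-pretrained encoder
    in_channels=3,
    num_classes=4,
    loss="ce",
    lr=1e-3,
    patience=10,
)

trainer = pl.Trainer(max_epochs=10, accelerator="auto")
# trainer.fit(task, datamodule=MauiDamageDataModule())  # hypothetical datamodule
# trainer.save_checkpoint("maui_demo_model.ckpt")       # checkpoint loaded by app.py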