Spaces:
Runtime error
Runtime error
| import importlib | |
| from models.utils import calculate_metrics | |
| from abc import ABC, abstractmethod | |
| import pytorch_lightning as pl | |
| import torch | |
| import torch.nn as nn | |
class TrainingEnvironment(pl.LightningModule):
    """Lightning wrapper that trains ``model`` against ``criterion``.

    Every train/val/test step logs the metrics returned by
    ``calculate_metrics`` plus the loss, under a ``train/``, ``val/`` or
    ``test/`` prefix. During training, each configured experiment logger
    (built by ``load_loggers``) is also handed the raw batch.

    Args:
        model: Network whose output is fed to ``criterion``.
        criterion: Loss module called as ``criterion(outputs, labels)``.
            Any loss whose class is not named ``CrossEntropyLoss`` makes the
            metrics be computed in multi-label mode.
        config: Experiment configuration. Must contain a
            ``"training_environment"`` mapping; its optional ``"loggers"``
            entry is passed to ``load_loggers``. The whole dict is also
            saved as a hyperparameter.
        learning_rate: Learning rate for the Adam optimizer.
        *args: Forwarded to ``pl.LightningModule``.
        **kwargs: Forwarded to ``pl.LightningModule`` and also recorded as
            hyperparameters.
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        config: dict,
        learning_rate=1e-4,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.experiment_loggers = load_loggers(
            config["training_environment"].get("loggers", {})
        )
        self.config = config
        # Compared by exact class name (not isinstance), so only
        # CrossEntropyLoss itself selects single-label metrics.
        self.has_multi_label_predictions = (
            type(criterion).__name__ != "CrossEntropyLoss"
        )
        # Only the class names of model and criterion are recorded.
        self.save_hyperparameters(
            {
                "model": type(model).__name__,
                "loss": type(criterion).__name__,
                "config": config,
                **kwargs,
            }
        )

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
        """Run one optimisation step and log ``train/`` metrics and loss.

        Returns:
            The loss tensor that Lightning back-propagates.
        """
        features, labels = batch
        outputs = self.model(features)
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(
            outputs,
            labels,
            prefix="train/",
            multi_label=self.has_multi_label_predictions,
        )
        # Log the loss alongside the metrics, like validation_step does.
        metrics["train/loss"] = loss
        self.log_dict(metrics, prog_bar=True)
        # self.logger is None when the trainer runs without a logger, so
        # only touch it when there is something to forward it to.
        if self.experiment_loggers and self.logger is not None:
            experiment = self.logger.experiment
            for experiment_logger in self.experiment_loggers:
                experiment_logger.step(experiment, batch_index, features, labels)
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ):
        """Log ``val/`` metrics and ``val/loss`` (monitored by the LR scheduler)."""
        x, y = batch
        preds = self.model(x)
        metrics = calculate_metrics(
            preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
        )
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        """Log ``test/`` metrics and ``test/loss``."""
        x, y = batch
        preds = self.model(x)
        metrics = calculate_metrics(
            preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
        )
        # Same loss reporting as the validation step.
        metrics["test/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def configure_optimizers(self):
        """Use Adam, with ReduceLROnPlateau stepping on the logged ``val/loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val/loss",
        }
class ExperimentLogger(ABC):
    """Base class for hooks invoked once per training step.

    Subclasses override :meth:`step`; the default implementation does
    nothing. ``TrainingEnvironment.training_step`` calls ``step`` with four
    positional arguments, and this signature matches that call.
    """

    def step(self, experiment, batch_index, x, label) -> None:
        """Handle one training batch. The base implementation is a no-op.

        Args:
            experiment: The trainer logger's ``experiment`` object.
            batch_index: Batch index passed to ``training_step``.
            x: Input features of the batch.
            label: Labels of the batch.
        """
class SpectrogramLogger(ExperimentLogger):
    """Periodically write one random input of the batch as an image.

    Each call to :meth:`step` advances a counter. When the counter equals
    ``frequency``, it is reset and one randomly chosen element of ``x`` is
    min-max normalised to ``[0, 1]`` and passed to ``experiment.add_image``.
    With the counter starting at 0, images are written on calls
    ``frequency + 1``, ``2 * frequency + 1``, and so on.

    Args:
        frequency: Number of ``step`` calls between written images.
    """

    def __init__(self, frequency=100) -> None:
        self.frequency = frequency
        self.counter = 0

    def step(self, experiment, batch_index, x, label):
        """Maybe write one sample of ``x``; ``label`` is unused.

        ``experiment`` is presumably a TensorBoard ``SummaryWriter`` (it must
        provide ``add_image``) — confirm against the configured logger.
        """
        if self.counter == self.frequency:
            self.counter = 0
            img_index = torch.randint(0, len(x), (1,)).item()
            # Takes the first channel of one element; assumes x is
            # (batch, channel, H, W) — TODO confirm.
            img = x[img_index][0].detach()
            img = img - img.min()
            value_range = img.max()
            # A constant image has a zero range; dividing would yield all
            # NaN, so leave it as all zeros instead.
            if value_range > 0:
                img = img / value_range
            experiment.add_image(
                f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
            )
        self.counter += 1
def load_loggers(logger_config: dict) -> "list[ExperimentLogger]":
    """Instantiate experiment loggers from a ``{dotted_path: kwargs}`` mapping.

    Each key is an import path of the form ``"package.module.ClassName"``.
    The class is imported and instantiated with the mapped kwargs. A ``None``
    config or a ``None`` kwargs entry (e.g. an empty YAML key) is treated as
    empty.

    Note: this imports and instantiates whatever classes the config names,
    so the config must come from a trusted source.

    Args:
        logger_config: Mapping of dotted class path to constructor kwargs.

    Returns:
        The instantiated loggers, in the mapping's iteration order.

    Raises:
        ValueError: If a path has no ``module.ClassName`` form.
    """
    loggers = []
    for logger_path, kwargs in (logger_config or {}).items():
        module_name, separator, class_name = logger_path.rpartition(".")
        if not separator or not module_name or not class_name:
            raise ValueError(
                f"Logger path must look like 'module.ClassName', got {logger_path!r}"
            )
        module = importlib.import_module(module_name)
        logger_cls = getattr(module, class_name)
        loggers.append(logger_cls(**(kwargs or {})))
    return loggers