Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # CPP-Net Experiment Class | |
| # | |
| # @ Fabian Hörst, [email protected] | |
| # Institute for Artificial Intelligence in Medicine, | |
| # University Medicine Essen | |
| import inspect | |
| import os | |
| import sys | |
| from base_ml.base_trainer import BaseTrainer | |
| currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) | |
| parentdir = os.path.dirname(currentdir) | |
| sys.path.insert(0, parentdir) | |
| from pathlib import Path | |
| from typing import Union | |
| import torch | |
| import torch.nn as nn | |
| from torchinfo import summary | |
| from base_ml.base_loss import retrieve_loss_fn | |
| from cell_segmentation.experiments.experiment_stardist_pannuke import ( | |
| ExperimentCellViTStarDist, | |
| ) | |
| from cell_segmentation.trainer.trainer_cpp_net import CellViTCPPTrainer | |
| from models.segmentation.cell_segmentation.cellvit_cpp_net import ( | |
| CellViT256CPP, | |
| CellViTCPP, | |
| CellViTSAMCPP, | |
| ) | |
class ExperimentCellViTCPP(ExperimentCellViTStarDist):
    """CPP-Net variant of the CellViT StarDist experiment.

    Reuses the experiment pipeline of ``ExperimentCellViTStarDist`` and overrides
    the loss construction, the model factory and the trainer class.
    """

    # Losses used for a branch that is missing from the loss settings.
    # Structure: {branch_name: {loss_name: (loss_fn_name, weight)}}.
    # "tissue_types" intentionally has no default and is skipped when unset.
    _DEFAULT_BRANCH_LOSSES = {
        "dist_map": {"bceweighted": ("BCEWithLogitsLoss", 1)},
        "stardist_map": {"L1LossWeighted": ("L1LossWeighted", 1)},
        "stardist_map_refined": {"L1LossWeighted": ("L1LossWeighted", 1)},
        "nuclei_type_map": {
            "bce": ("xentropy_loss", 1),
            "dice": ("dice_loss", 1),
        },
    }

    @staticmethod
    def _build_branch_losses(branch_settings: dict) -> dict:
        """Instantiate every configured loss of a single branch.

        Args:
            branch_settings (dict): Mapping loss_name -> {"loss_fn": str, "weight": float,
                optional "args": dict of keyword arguments for the loss constructor}.

        Returns:
            dict: Mapping loss_name -> {"loss_fn": callable, "weight": weight}.
        """
        branch_losses = {}
        for loss_name, loss_sett in branch_settings.items():
            # "args:" left empty in YAML parses to None; treat it like "no args".
            parameters = loss_sett.get("args") or {}
            branch_losses[loss_name] = {
                "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                "weight": loss_sett["weight"],
            }
        return branch_losses

    def get_loss_fn(self, loss_fn_settings: dict) -> dict:
        """Create a dictionary with loss functions for all branches

        Branches: "dist_map", "stardist_map", "stardist_map_refined", "nuclei_type_map", "tissue_types"

        Args:
            loss_fn_settings (dict): Dictionary with the loss function settings. Structure
                branch_name(str):
                    loss_name(str):
                        loss_fn(str): String matching to the loss functions defined in the LOSS_DICT (base_ml.base_loss)
                        weight(float): Weighting factor as float value
                        (optional) args: Optional parameters for initializing the loss function
                            arg_name: value

                If a branch is not provided, the defaults settings (described below) are used.

                For further information, please have a look at the file configs/examples/cell_segmentation/train_cellvit.yaml
                under the section "loss"

                Example:
                      nuclei_type_map:
                        bce:
                            loss_fn: xentropy_loss
                            weight: 1
                        dice:
                            loss_fn: dice_loss
                            weight: 1

        Returns:
            dict: Dictionary with loss functions for each branch. Structure:
                branch_name(str):
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss since in the end all losses of all branches are added together for backward pass
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss since in the end all losses of all branches are added together for backward pass
                branch_name(str)
                ...

        Default loss dictionary:
            dist_map:
                bceweighted:
                    loss_fn: BCEWithLogitsLoss
                    weight: 1
            stardist_map:
                L1LossWeighted:
                    loss_fn: L1LossWeighted
                    weight: 1
            stardist_map_refined:
                L1LossWeighted:
                    loss_fn: L1LossWeighted
                    weight: 1
            nuclei_type_map
                bce:
                    loss_fn: xentropy_loss
                    weight: 1
                dice:
                    loss_fn: dice_loss
                    weight: 1
            tissue_type has no default loss and might be skipped
        """
        loss_fn_dict = {}
        for branch, defaults in self._DEFAULT_BRANCH_LOSSES.items():
            if branch in loss_fn_settings:
                loss_fn_dict[branch] = self._build_branch_losses(
                    loss_fn_settings[branch]
                )
            else:
                # defaults are only instantiated for branches the user did not configure
                loss_fn_dict[branch] = {
                    loss_name: {"loss_fn": retrieve_loss_fn(fn_name), "weight": weight}
                    for loss_name, (fn_name, weight) in defaults.items()
                }
        if "tissue_types" in loss_fn_settings:
            loss_fn_dict["tissue_types"] = self._build_branch_losses(
                loss_fn_settings["tissue_types"]
            )
        # skip default tissue loss!
        return loss_fn_dict

    def _load_pretrained_model(
        self, model: nn.Module, pretrained_model: Union[Path, str]
    ) -> None:
        """Load a full CellViT checkpoint into ``model`` with strict key matching.

        Args:
            model (nn.Module): Freshly built model to receive the weights.
            pretrained_model (Union[Path, str]): Path to the checkpoint (a state dict).
        """
        self.logger.info(
            f"Loading pretrained CellViT model from path: {pretrained_model}"
        )
        # Map to CPU so checkpoints saved on a GPU also load on CPU-only hosts;
        # the model is moved to CPU afterwards anyway.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
        self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))

    def get_train_model(
        self,
        pretrained_encoder: Union[Path, str, None] = None,
        pretrained_model: Union[Path, str, None] = None,
        backbone_type: str = "default",
        shared_decoders: bool = False,
        **kwargs,
    ) -> nn.Module:
        """Return the CellViT CPP-Net training model

        Args:
            pretrained_encoder (Union[Path, str], optional): Path to a pretrained encoder,
                used by the "vit256" and SAM backbones. Defaults to None.
            pretrained_model (Union[Path, str], optional): Path to a full pretrained model
                checkpoint, loaded with strict=True. Defaults to None.
            backbone_type (str, optional): Backbone type (case-insensitive). Currently supported
                are "default", "vit256", "sam-b", "sam-l", "sam-h". Defaults to "default".
            shared_decoders (bool, optional): Shared skip decoders are not implemented for
                CPP-Net, so this must be False. Defaults to False.
            **kwargs: Ignored.

        Raises:
            NotImplementedError: If the backbone type is unknown or shared_decoders is True.

        Returns:
            nn.Module: CPP-Net training model (moved to CPU) with given setup
        """
        # reseed needed, due to subprocess seeding compatibility
        self.seed_run(self.default_conf["random_seed"])

        # check for backbones
        backbone = backbone_type.lower()
        implemented_backbones = [
            "default",
            "vit256",
            "sam-b",
            "sam-l",
            "sam-h",
        ]
        if backbone not in implemented_backbones:
            raise NotImplementedError(
                f"Unknown Backbone Type - Currently supported are: {implemented_backbones}"
            )
        if shared_decoders:
            raise NotImplementedError(
                "Shared decoders are not implemented for CPP-Net"
            )

        data_conf = self.run_conf["data"]
        model_conf = self.run_conf["model"]
        training_conf = self.run_conf["training"]
        # Only the default backbone is built with a configurable channel count; the
        # other constructors use their own default (presumably 3, as the summary
        # call below has always assumed).
        input_channels = 3

        if backbone == "default":
            input_channels = model_conf.get("input_channels", 3)
            model = CellViTCPP(
                num_nuclei_classes=data_conf["num_nuclei_classes"],
                num_tissue_classes=data_conf["num_tissue_classes"],
                embed_dim=model_conf["embed_dim"],
                input_channels=input_channels,
                depth=model_conf["depth"],
                num_heads=model_conf["num_heads"],
                extract_layers=model_conf["extract_layers"],
                drop_rate=training_conf.get("drop_rate", 0),
                attn_drop_rate=training_conf.get("attn_drop_rate", 0),
                drop_path_rate=training_conf.get("drop_path_rate", 0),
                nrays=model_conf.get("nrays", 32),
            )
            if pretrained_model is not None:
                self._load_pretrained_model(model, pretrained_model)
                self.logger.info("Loaded CellViT model")
        elif backbone == "vit256":
            model = CellViT256CPP(
                model256_path=pretrained_encoder,
                num_nuclei_classes=data_conf["num_nuclei_classes"],
                num_tissue_classes=data_conf["num_tissue_classes"],
                drop_rate=training_conf.get("drop_rate", 0),
                attn_drop_rate=training_conf.get("attn_drop_rate", 0),
                drop_path_rate=training_conf.get("drop_path_rate", 0),
                nrays=model_conf.get("nrays", 32),
            )
            model.load_pretrained_encoder(model.model256_path)
            if pretrained_model is not None:
                self._load_pretrained_model(model, pretrained_model)
            model.freeze_encoder()
            self.logger.info("Loaded CellVit256 model")
        else:  # "sam-b", "sam-l", "sam-h"
            model = CellViTSAMCPP(
                model_path=pretrained_encoder,
                num_nuclei_classes=data_conf["num_nuclei_classes"],
                num_tissue_classes=data_conf["num_tissue_classes"],
                # the original (non-lowercased) string is forwarded unchanged
                vit_structure=backbone_type,
                drop_rate=training_conf.get("drop_rate", 0),
                nrays=model_conf.get("nrays", 32),
            )
            model.load_pretrained_encoder(model.model_path)
            if pretrained_model is not None:
                self._load_pretrained_model(model, pretrained_model)
            model.freeze_encoder()
            self.logger.info(f"Loaded CellViT-SAM model with backbone: {backbone_type}")

        self.logger.info(f"\nModel: {model}")
        model = model.to("cpu")
        self.logger.info(
            f"\n{summary(model, input_size=(1, input_channels, 256, 256), device='cpu')}"
        )

        return model

    def get_trainer(self) -> BaseTrainer:
        """Return the trainer matching this network.

        Returns:
            BaseTrainer: The trainer class ``CellViTCPPTrainer`` itself (not an
                instance); the caller is responsible for instantiating it.
        """
        return CellViTCPPTrainer