LKCell / cell_segmentation /experiments /experiment_cpp_net_pannuke.py
qingke1's picture
initial commit
aea73e2
raw
history blame
12.6 kB
# -*- coding: utf-8 -*-
# CPP-Net Experiment Class
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import inspect
import os
import sys
from base_ml.base_trainer import BaseTrainer
# Directory that contains this source file (resolved via the current frame).
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# Prepend the parent of that directory to sys.path so the imports below can be
# resolved from there when this module is executed directly.
# NOTE(review): `base_ml.base_trainer` is imported above, i.e. before this
# path adjustment takes effect — confirm it is importable without it.
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
from pathlib import Path
from typing import Union
import torch
import torch.nn as nn
from torchinfo import summary
from base_ml.base_loss import retrieve_loss_fn
from cell_segmentation.experiments.experiment_stardist_pannuke import (
ExperimentCellViTStarDist,
)
from cell_segmentation.trainer.trainer_cpp_net import CellViTCPPTrainer
from models.segmentation.cell_segmentation.cellvit_cpp_net import (
CellViT256CPP,
CellViTCPP,
CellViTSAMCPP,
)
class ExperimentCellViTCPP(ExperimentCellViTStarDist):
    """CPP-Net experiment built on the CellViT StarDist experiment.

    Only the parts that differ for CPP-Net are overridden here: the loss
    functions (including the additional ``stardist_map_refined`` branch), the
    model construction, and the trainer class. Everything else is inherited
    from ``ExperimentCellViTStarDist``.
    """

    # Branches that always get a loss: if a branch is missing from the
    # settings, ``_default_branch_losses`` supplies its defaults.
    # "tissue_types" is intentionally not listed (it has no default).
    _BRANCHES_WITH_DEFAULTS = (
        "dist_map",
        "stardist_map",
        "stardist_map_refined",
        "nuclei_type_map",
    )

    @staticmethod
    def _build_branch_losses(branch_settings: dict) -> dict:
        """Instantiate every configured loss of a single branch.

        Args:
            branch_settings (dict): Mapping ``loss_name -> settings`` where each
                settings dict has ``loss_fn`` (str), ``weight`` (float) and an
                optional ``args`` dict passed to the loss constructor.

        Returns:
            dict: Mapping ``loss_name -> {"loss_fn": Callable, "weight": float}``.
        """
        return {
            loss_name: {
                "loss_fn": retrieve_loss_fn(
                    loss_sett["loss_fn"], **loss_sett.get("args", {})
                ),
                "weight": loss_sett["weight"],
            }
            for loss_name, loss_sett in branch_settings.items()
        }

    @staticmethod
    def _default_branch_losses(branch: str) -> dict:
        """Return freshly instantiated default losses for ``branch``.

        Args:
            branch (str): One of ``_BRANCHES_WITH_DEFAULTS``.

        Raises:
            ValueError: If no default is defined for ``branch``.

        Returns:
            dict: Mapping ``loss_name -> {"loss_fn": Callable, "weight": float}``.
        """
        if branch == "dist_map":
            return {
                "bceweighted": {
                    "loss_fn": retrieve_loss_fn("BCEWithLogitsLoss"),
                    "weight": 1,
                },
            }
        if branch in ("stardist_map", "stardist_map_refined"):
            return {
                "L1LossWeighted": {
                    "loss_fn": retrieve_loss_fn("L1LossWeighted"),
                    "weight": 1,
                },
            }
        if branch == "nuclei_type_map":
            return {
                "bce": {"loss_fn": retrieve_loss_fn("xentropy_loss"), "weight": 1},
                "dice": {"loss_fn": retrieve_loss_fn("dice_loss"), "weight": 1},
            }
        raise ValueError(f"No default loss defined for branch: {branch}")

    def get_loss_fn(self, loss_fn_settings: dict) -> dict:
        """Create a dictionary with loss functions for all branches.

        Branches: "dist_map", "stardist_map", "stardist_map_refined",
        "nuclei_type_map", "tissue_types"

        Args:
            loss_fn_settings (dict): Dictionary with the loss function settings. Structure
                branch_name(str):
                    loss_name(str):
                        loss_fn(str): String matching to the loss functions defined in the LOSS_DICT (base_ml.base_loss)
                        weight(float): Weighting factor as float value
                        (optional) args: Optional parameters for initializing the loss function
                            arg_name: value

                If a branch is not provided, the default settings (described below) are used.

                For further information, please have a look at the file configs/examples/cell_segmentation/train_cellvit.yaml
                under the section "loss"

                Example:
                      nuclei_type_map:
                        bce:
                            loss_fn: xentropy_loss
                            weight: 1
                        dice:
                            loss_fn: dice_loss
                            weight: 1

        Returns:
            dict: Dictionary with loss functions for each branch. Structure:
                branch_name(str):
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss since in the end all losses of all branches are added together for backward pass
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss since in the end all losses of all branches are added together for backward pass
                branch_name(str)
                ...

        Default loss dictionary:
            dist_map:
                bceweighted:
                    loss_fn: BCEWithLogitsLoss
                    weight: 1
            stardist_map:
                L1LossWeighted:
                    loss_fn: L1LossWeighted
                    weight: 1
            stardist_map_refined:
                L1LossWeighted:
                    loss_fn: L1LossWeighted
                    weight: 1
            nuclei_type_map:
                bce:
                    loss_fn: xentropy_loss
                    weight: 1
                dice:
                    loss_fn: dice_loss
                    weight: 1
            tissue_types has no default loss and might be skipped
        """
        loss_fn_dict = {}
        for branch in self._BRANCHES_WITH_DEFAULTS:
            if branch in loss_fn_settings:
                loss_fn_dict[branch] = self._build_branch_losses(
                    loss_fn_settings[branch]
                )
            else:
                loss_fn_dict[branch] = self._default_branch_losses(branch)
        # skip default tissue loss: only added when explicitly configured
        if "tissue_types" in loss_fn_settings:
            loss_fn_dict["tissue_types"] = self._build_branch_losses(
                loss_fn_settings["tissue_types"]
            )
        return loss_fn_dict

    def _load_pretrained_cellvit(
        self, model: nn.Module, pretrained_model: Union[Path, str]
    ) -> None:
        """Load a full CellViT checkpoint into ``model`` (strict) and log the result.

        Args:
            model (nn.Module): Freshly constructed model to load weights into.
            pretrained_model (Union[Path, str]): Path to the checkpoint file.
        """
        self.logger.info(
            f"Loading pretrained CellViT model from path: {pretrained_model}"
        )
        # Map to CPU so checkpoints saved on GPU also load on CPU-only hosts;
        # get_train_model moves the model to the CPU afterwards anyway.
        # NOTE: torch.load unpickles the file — only use trusted checkpoints.
        cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
        self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))

    def get_train_model(
        self,
        pretrained_encoder: Union[Path, str, None] = None,
        pretrained_model: Union[Path, str, None] = None,
        backbone_type: str = "default",
        shared_decoders: bool = False,
        **kwargs,
    ) -> nn.Module:
        """Return the CellViT CPP-Net training model.

        Args:
            pretrained_encoder (Union[Path, str], optional): Path to a pretrained encoder
                (used by the ViT256 and SAM backbones). Defaults to None.
            pretrained_model (Union[Path, str], optional): Path to a pretrained full
                model checkpoint, loaded with ``strict=True``. Defaults to None.
            backbone_type (str, optional): Backbone type (case-insensitive). Supported are
                "default", "vit256", "sam-b", "sam-l", "sam-h". Defaults to "default".
            shared_decoders (bool, optional): If shared skip decoders should be used.
                Not implemented; must be False. Defaults to False.
            **kwargs: Ignored; accepted for interface compatibility.

        Raises:
            NotImplementedError: For an unknown backbone type or if
                ``shared_decoders`` is True.

        Returns:
            nn.Module: CPP-Net training model with the given setup, on the CPU.
        """
        # reseed needed, due to subprocess seeding compatibility
        self.seed_run(self.default_conf["random_seed"])

        # check for backbones
        implemented_backbones = [
            "default",
            "vit256",
            "sam-b",
            "sam-l",
            "sam-h",
        ]
        backbone = backbone_type.lower()
        if backbone not in implemented_backbones:
            raise NotImplementedError(
                f"Unknown Backbone Type - Currently supported are: {implemented_backbones}"
            )
        if shared_decoders:
            raise NotImplementedError(
                "Shared decoders are not implemented for StarDist"
            )

        # Channels used for the model summary below. Only the default backbone
        # receives a configurable ``input_channels``; the ViT256/SAM constructors
        # here get no such argument, so they keep 3.
        input_channels = 3

        if backbone == "default":
            input_channels = self.run_conf["model"].get("input_channels", 3)
            model = CellViTCPP(
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                embed_dim=self.run_conf["model"]["embed_dim"],
                input_channels=input_channels,
                depth=self.run_conf["model"]["depth"],
                num_heads=self.run_conf["model"]["num_heads"],
                extract_layers=self.run_conf["model"]["extract_layers"],
                drop_rate=self.run_conf["training"].get("drop_rate", 0),
                attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0),
                drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0),
                nrays=self.run_conf["model"].get("nrays", 32),
            )
            if pretrained_model is not None:
                self._load_pretrained_cellvit(model, pretrained_model)
                self.logger.info("Loaded CellViT model")

        elif backbone == "vit256":
            model = CellViT256CPP(
                model256_path=pretrained_encoder,
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                drop_rate=self.run_conf["training"].get("drop_rate", 0),
                attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0),
                drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0),
                nrays=self.run_conf["model"].get("nrays", 32),
            )
            model.load_pretrained_encoder(model.model256_path)
            if pretrained_model is not None:
                self._load_pretrained_cellvit(model, pretrained_model)
            model.freeze_encoder()
            self.logger.info("Loaded CellVit256 model")

        else:  # "sam-b", "sam-l", "sam-h" (validated above)
            model = CellViTSAMCPP(
                model_path=pretrained_encoder,
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                vit_structure=backbone_type,
                drop_rate=self.run_conf["training"].get("drop_rate", 0),
                nrays=self.run_conf["model"].get("nrays", 32),
            )
            model.load_pretrained_encoder(model.model_path)
            if pretrained_model is not None:
                self._load_pretrained_cellvit(model, pretrained_model)
            model.freeze_encoder()
            self.logger.info(f"Loaded CellViT-SAM model with backbone: {backbone_type}")

        self.logger.info(f"\nModel: {model}")
        model = model.to("cpu")
        self.logger.info(
            f"\n{summary(model, input_size=(1, input_channels, 256, 256), device='cpu')}"
        )
        return model

    def get_trainer(self) -> BaseTrainer:
        """Return the trainer class matching this network.

        Returns:
            BaseTrainer: The ``CellViTCPPTrainer`` class itself (not an instance).
        """
        return CellViTCPPTrainer