Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from segmentation_models_pytorch import MAnet | |
from segmentation_models_pytorch.base.modules import Activation | |
__all__ = ["MEDIARFormer"] | |
class MEDIARFormer(MAnet):
    """MEDIAR-Former: an ``MAnet`` with two task-specific output heads.

    The parent ``MAnet`` is built from the given encoder/decoder settings,
    after which:

    * the stock ``segmentation_head`` is discarded (set to ``None``);
    * every ``nn.ReLU`` found in the encoder and decoder is replaced by
      ``nn.Mish(inplace=True)``;
    * two ``DeepSegmentationHead`` modules are attached, both fed by the
      last decoder stage: ``cellprob_head`` (1 output channel) and
      ``gradflow_head`` (2 output channels).
    """

    def __init__(
        self,
        encoder_name="mit_b5",  # Default encoder
        encoder_weights="imagenet",  # Pre-trained weights
        decoder_channels=(1024, 512, 256, 128, 64),  # Decoder configuration
        decoder_pab_channels=256,  # Decoder Pyramid Attention Block channels
        in_channels=3,  # Number of input channels
        classes=3,  # Number of output classes
    ):
        """Build the backbone, swap its activations and attach the heads.

        All arguments are forwarded unchanged to ``MAnet.__init__``. Note
        that ``classes`` only configures the parent's default head, which is
        dropped immediately afterwards.
        """
        manet_kwargs = dict(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )
        super().__init__(**manet_kwargs)

        # The generic head built by the parent is never used by forward().
        self.segmentation_head = None

        # ReLU -> Mish, first in the encoder and then in the decoder; each
        # sub-network receives its own freshly created Mish instance.
        for backbone_part in (self.encoder, self.decoder):
            _convert_activations(backbone_part, nn.ReLU, nn.Mish(inplace=True))

        # Both heads consume the feature map of the final decoder stage.
        head_in_channels = decoder_channels[-1]
        self.cellprob_head = DeepSegmentationHead(
            in_channels=head_in_channels, out_channels=1
        )
        self.gradflow_head = DeepSegmentationHead(
            in_channels=head_in_channels, out_channels=2
        )

    def forward(self, x):
        """Run encoder, decoder and both heads on ``x``.

        Returns the two head outputs concatenated along ``dim=1`` with the
        gradient-flow channels first and the cell-probability channel last.
        """
        # Validate the input shape before doing any work.
        self.check_input_shape(x)

        # The encoder yields a sequence of feature maps that the decoder
        # receives as separate positional arguments.
        encoder_features = self.encoder(x)
        decoded = self.decoder(*encoder_features)

        # Evaluate the heads (cell probability first, then gradient flow).
        cellprob = self.cellprob_head(decoded)
        gradflow = self.gradflow_head(decoded)

        # Output channel layout: [gradflow..., cellprob].
        return torch.cat([gradflow, cellprob], dim=1)
class DeepSegmentationHead(nn.Sequential):
    """Two-convolution prediction head used for the MEDIAR-Former outputs.

    Layer order (indices in the ``nn.Sequential``):

    0. ``Conv2d(in_channels -> in_channels // 2)``
    1. ``Mish(inplace=True)``
    2. ``BatchNorm2d(in_channels // 2)``
    3. ``Conv2d(in_channels // 2 -> out_channels)``
    4. ``UpsamplingBilinear2d(scale_factor=upsampling)`` when
       ``upsampling > 1``, otherwise ``Identity``
    5. ``Activation(activation)`` when ``activation`` is truthy, otherwise
       ``Identity``

    Both convolutions use ``padding=kernel_size // 2``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        """Assemble the head.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by the head.
            kernel_size: Kernel size shared by both convolutions.
            activation: Optional activation spec handed to ``Activation``;
                a falsy value disables the final activation.
            upsampling: Bilinear upsampling factor; values ``<= 1`` disable
                upsampling.
        """
        hidden_channels = in_channels // 2
        pad = kernel_size // 2

        # Layers are created in the same order in which they are applied.
        reduce_conv = nn.Conv2d(
            in_channels, hidden_channels, kernel_size=kernel_size, padding=pad
        )
        nonlinearity = nn.Mish(inplace=True)
        norm = nn.BatchNorm2d(hidden_channels)
        predict_conv = nn.Conv2d(
            hidden_channels, out_channels, kernel_size=kernel_size, padding=pad
        )

        # Optional stages fall back to Identity so layer indices stay fixed.
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()

        if activation:
            final_activation = Activation(activation)
        else:
            final_activation = nn.Identity()

        super().__init__(
            reduce_conv,
            nonlinearity,
            norm,
            predict_conv,
            resize,
            final_activation,
        )
def _convert_activations(module, from_activation, to_activation): | |
"""Recursively convert activation functions in a module""" | |
for name, child in module.named_children(): | |
if isinstance(child, from_activation): | |
setattr(module, name, to_activation) | |
else: | |
_convert_activations(child, from_activation, to_activation) | |