import torch
import torch.nn as nn
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation

__all__ = ["MEDIARFormer"]


class MEDIARFormer(MAnet):
    """MEDIAR-Former: an MAnet with a Mix Transformer encoder, Mish activations,
    and two task-specific heads (cell probability and gradient flow)."""

    def __init__(
        self,
        encoder_name="mit_b5",  # Default encoder
        encoder_weights="imagenet",  # Pre-trained weights
        decoder_channels=(1024, 512, 256, 128, 64),  # Decoder configuration
        decoder_pab_channels=256,  # Decoder Pyramid Attention Block channels
        in_channels=3,  # Number of input channels
        classes=3,  # Number of output classes
    ):
        # Initialize the MAnet model with the provided parameters
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )

        # Remove the default segmentation head, as it's not used in this
        # architecture; the two custom heads below replace it
        self.segmentation_head = None

        # Replace all activation functions in the encoder and decoder
        # from ReLU to Mish
        _convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True))
        _convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True))

        # Add custom segmentation heads for the two segmentation tasks
        self.cellprob_head = DeepSegmentationHead(
            in_channels=decoder_channels[-1], out_channels=1
        )
        self.gradflow_head = DeepSegmentationHead(
            in_channels=decoder_channels[-1], out_channels=2
        )
    def forward(self, x):
        """Forward pass through the network."""
        # Ensure the input shape is correct
        self.check_input_shape(x)

        # Encode the input, then decode the multi-scale features
        features = self.encoder(x)
        decoder_output = self.decoder(*features)

        # Generate masks for cell probability and gradient flows
        cellprob_mask = self.cellprob_head(decoder_output)
        gradflow_mask = self.gradflow_head(decoder_output)

        # Concatenate the masks for output
        masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)

        return masks

class DeepSegmentationHead(nn.Sequential):
    """Two-convolution segmentation head (conv -> Mish -> BatchNorm -> conv)
    with optional bilinear upsampling and a final activation."""

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        # Define the sequence of layers for the segmentation head
        layers = [
            nn.Conv2d(
                in_channels,
                in_channels // 2,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            nn.Mish(inplace=True),
            nn.BatchNorm2d(in_channels // 2),
            nn.Conv2d(
                in_channels // 2,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            nn.UpsamplingBilinear2d(scale_factor=upsampling)
            if upsampling > 1
            else nn.Identity(),
            Activation(activation) if activation else nn.Identity(),
        ]
        super().__init__(*layers)

def _convert_activations(module, from_activation, to_activation):
    """Recursively convert activation functions in a module."""
    for name, child in module.named_children():
        if isinstance(child, from_activation):
            setattr(module, name, to_activation)
        else:
            _convert_activations(child, from_activation, to_activation)
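

# Minimal usage sketch, assuming segmentation_models_pytorch ships the "mit_b5"
# encoder used above. Passing encoder_weights=None skips the pre-trained weight
# download so the dry run stays self-contained.
if __name__ == "__main__":
    model = MEDIARFormer(encoder_weights=None)
    model.eval()

    # Spatial dimensions must be divisible by the encoder's output stride (32),
    # otherwise check_input_shape() raises an error.
    dummy = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        masks = model(dummy)

    # Expected output: (1, 3, 512, 512) -> two gradient-flow channels followed
    # by one cell-probability channel, concatenated along dim=1.
    print(masks.shape)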