import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from typing import Optional, Tuple


class DecoderBlock(nn.Module):
    """U-Net decoder stage: 2x upsample, optional skip concat, two conv-BN-ReLU layers.

    Args:
        in_channels: Channels of the tensor arriving from the previous (coarser) stage.
        skip_channels: Channels of the skip tensor concatenated after upsampling.
            ``conv1`` is built for ``out_channels + skip_channels`` inputs, so calling
            ``forward`` with ``skip=None`` only works when this is 0.
        out_channels: Channels produced by the upsample and by both convolutions.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        # kernel_size=2, stride=2 doubles the spatial size.
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(
            out_channels + skip_channels, out_channels, kernel_size=3, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor]) -> torch.Tensor:
        """Upsample ``x``, fuse it with ``skip`` (if given) and refine with two convs.

        Args:
            x: Feature map from the previous stage, with ``in_channels`` channels.
            skip: Encoder feature map to concatenate along dim 1, or ``None``.

        Returns:
            Tensor with ``out_channels`` channels; when ``skip`` is given its
            spatial size matches ``skip``.
        """
        x = self.up(x)
        if skip is not None:
            # Compare spatial dims only. The channel counts of ``x`` and ``skip``
            # normally differ, so comparing the full size() would resize every call.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(
                    x, size=skip.shape[2:], mode='bilinear', align_corners=True
                )
            x = torch.cat([x, skip], dim=1)
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x


class UNetTimmWithClassification(nn.Module):
    """U-Net segmentation network on a timm encoder, plus an image-level classifier.

    The encoder is created with ``features_only=True`` and must expose exactly five
    feature levels (C0..C4). The deepest level feeds both the decoder and the
    classification head.

    Args:
        encoder_name: timm model name used as the encoder.
        encoder_weights: ``'imagenet'`` loads timm pretrained weights. ``'xavier'``
            re-initialises every Conv/ConvTranspose/Linear/BatchNorm layer of the whole
            model, encoder included. Any other value keeps timm's default init.
        num_classes_seg: Output channels of the segmentation map.
        num_classes_cls: Output size of the classification head. A Softmax (dim=1) is
            appended when > 1 and a Sigmoid when == 1; no activation otherwise.
    """

    def __init__(
        self,
        encoder_name: str = 'resnet50',
        encoder_weights: Optional[str] = 'imagenet',
        num_classes_seg: int = 1,
        num_classes_cls: int = 9,
    ):
        super().__init__()
        self.encoder = timm.create_model(
            encoder_name,
            pretrained=(encoder_weights == 'imagenet'),
            features_only=True,
            in_chans=3,
        )
        encoder_channels = self.encoder.feature_info.channels()
        if len(encoder_channels) != 5:
            # forward() unpacks exactly five feature maps; fail early with a clear message.
            raise ValueError(
                f"encoder '{encoder_name}' yields {len(encoder_channels)} feature levels, "
                "expected 5"
            )

        # The first entry is the decoder input, i.e. the deepest encoder feature width.
        # It used to be hard-coded to 2048 (resnet50's value), which broke other encoders.
        decoder_channels = [encoder_channels[-1], 1024, 512, 256]

        self.decoder4 = DecoderBlock(
            in_channels=decoder_channels[0],
            skip_channels=encoder_channels[3],
            out_channels=decoder_channels[1],
        )
        self.decoder3 = DecoderBlock(
            in_channels=decoder_channels[1],
            skip_channels=encoder_channels[2],
            out_channels=decoder_channels[2],
        )
        self.decoder2 = DecoderBlock(
            in_channels=decoder_channels[2],
            skip_channels=encoder_channels[1],
            out_channels=decoder_channels[3],
        )
        self.decoder1 = DecoderBlock(
            in_channels=decoder_channels[3],
            skip_channels=encoder_channels[0],
            out_channels=decoder_channels[3],
        )
        # One more 2x upsample after the last skip connection, then a 1x1 projection.
        self.final_up = nn.ConvTranspose2d(
            in_channels=decoder_channels[-1],
            out_channels=32,
            kernel_size=2,
            stride=2,
        )
        self.final_conv_seg = nn.Conv2d(
            in_channels=32,
            out_channels=num_classes_seg,
            kernel_size=1,
        )

        # Classification head on the globally pooled deepest feature map.
        # Layer order is kept unchanged so Sequential indices (checkpoint keys) stay stable.
        self.classification_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(encoder_channels[-1], 512),
            # NOTE(review): there is no nonlinearity between this Linear and the next,
            # so the two collapse into one affine map (plus dropout). Consider a ReLU.
            nn.Dropout(0.2),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes_cls),
        )
        if num_classes_cls > 1:
            self.classification_activation = nn.Softmax(dim=1)
        elif num_classes_cls == 1:
            self.classification_activation = nn.Sigmoid()
        else:
            self.classification_activation = None
        if self.classification_activation is not None:
            # NOTE(review): the head therefore emits probabilities, not logits. If the
            # training loss is CrossEntropyLoss / BCEWithLogitsLoss, the activation is
            # applied twice — confirm against the training code.
            self.classification_head.add_module("activation", self.classification_activation)

        if encoder_weights == 'xavier':
            self.apply(self.xavier_init_weights)

    def xavier_init_weights(self, m: nn.Module) -> None:
        """Xavier-uniform weights / zero bias for conv and linear layers; unit BatchNorm."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d) and m.affine:
            # Non-affine BatchNorm has weight/bias = None and nothing to initialise.
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder once and return ``(segmentation, classification)``.

        Args:
            x: Input batch with 3 channels (the encoder is built with ``in_chans=3``).

        Returns:
            ``seg``: ``num_classes_seg``-channel map at twice the spatial size of the
            first encoder feature map. That equals the input size when that level is
            at stride 2, as for resnet encoders — confirm for other encoders.
            ``cls``: output of the classification head (post-activation, see __init__).
        """
        c0, c1, c2, c3, c4 = self.encoder(x)
        cls = self.classification_head(c4)
        d4 = self.decoder4(c4, c3)
        d3 = self.decoder3(d4, c2)
        d2 = self.decoder2(d3, c1)
        d1 = self.decoder1(d2, c0)
        seg = self.final_conv_seg(self.final_up(d1))
        return seg, cls