File size: 630 Bytes
823e567 1150876 823e567 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import segmentation_models_pytorch as smp
from .hf_config import UnetConfig
from transformers import PreTrainedModel
class HFUnetPlusPlus(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``smp.UnetPlusPlus``.

    Subclassing ``PreTrainedModel`` (with ``config_class = UnetConfig``) lets
    the segmentation_models_pytorch UNet++ network be handled by the
    ``transformers`` save/load machinery, with its hyper-parameters coming
    from a ``UnetConfig`` instance.
    """

    config_class = UnetConfig

    def __init__(self, config):
        """Build the wrapped UNet++ network from *config*.

        Args:
            config: A ``UnetConfig``. Required attributes read here:
                ``encoder_name``, ``decoder_channels``, ``input_channels``
                and ``num_classes``. Optional attributes, each falling back
                to the value this class previously hard-coded:
                ``encoder_weights`` (default ``"imagenet"``) and
                ``decoder_attention_type`` (default ``"scse"``).
        """
        super().__init__(config)
        self.model = smp.UnetPlusPlus(
            encoder_name=config.encoder_name,
            # Configurable so a config can set e.g. ``encoder_weights=None``.
            # That skips fetching ImageNet weights when a checkpoint's weights
            # will be loaded over the freshly built model anyway
            # (``from_pretrained``). The default keeps the original behaviour.
            encoder_weights=getattr(config, "encoder_weights", "imagenet"),
            decoder_channels=config.decoder_channels,
            in_channels=config.input_channels,
            classes=config.num_classes,
            decoder_attention_type=getattr(
                config, "decoder_attention_type", "scse"
            ),
        )

    def forward(self, tensor):
        """Run the wrapped UNet++ on *tensor*.

        Args:
            tensor: Input batch passed straight to ``smp.UnetPlusPlus``;
                presumably shaped (batch, input_channels, H, W) — TODO confirm.

        Returns:
            Whatever the wrapped ``smp.UnetPlusPlus`` returns, unchanged.
        """
        return self.model(tensor)
|