File size: 510 Bytes
823e567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import PretrainedConfig

class UnetConfig(PretrainedConfig):
    """Configuration container for a U-Net style model.

    Stores the constructor's hyper-parameters as attributes and forwards any
    remaining keyword arguments to ``PretrainedConfig.__init__``.

    Args:
        encoder_name: Identifier of the encoder to use (default
            ``"resnet18"``). How it is resolved is up to the model code that
            reads this config.
        num_classes: Number of output classes (default 16).
        input_channels: Number of input channels (default 1).
        decoder_channels: Sequence of channel counts, presumably one per
            decoder stage (default ``(1024, 512, 256, 128, 64)``). Any
            iterable is accepted and stored as a tuple; ``None`` is stored
            unchanged.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        encoder_name: str = "resnet18",
        num_classes: int = 16,
        input_channels: int = 1,
        decoder_channels: tuple = (1024, 512, 256, 128, 64),
        **kwargs
    ):
        self.encoder_name = encoder_name
        self.num_classes = num_classes
        self.input_channels = input_channels
        # JSON has no tuple type, so a config reloaded via from_pretrained /
        # from_dict receives a list here. Normalise to a tuple so the stored
        # type matches the annotation/default and a save/load round trip
        # compares equal (PretrainedConfig equality compares __dict__).
        self.decoder_channels = (
            tuple(decoder_channels) if decoder_channels is not None else None
        )
        super().__init__(**kwargs)