File size: 510 Bytes
823e567 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
from transformers import PretrainedConfig
class UnetConfig(PretrainedConfig):
    """Configuration container for a U-Net style model.

    Stores the hyper-parameters below as instance attributes and forwards any
    remaining keyword arguments to ``transformers.PretrainedConfig``.

    Args:
        encoder_name: Identifier of the encoder backbone (default
            ``"resnet18"``). NOTE(review): presumably resolved by the model
            class against some backbone registry -- confirm in the model code.
        num_classes: Number of output classes (default ``16``). Presumably the
            number of output mask channels -- confirm against the model head.
        input_channels: Number of channels of the input (default ``1``).
        decoder_channels: Channel widths for the decoder stages (default
            ``(1024, 512, 256, 128, 64)``). A list is converted to a tuple, so
            a value reloaded from ``config.json`` (where JSON turns the tuple
            into a list) has the same type as, and compares equal to, the
            default.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier serialized into config.json. Without it the base class's empty
    # string is used, and ``AutoConfig.register("unet", UnetConfig)`` refuses
    # the class because its model_type does not match.
    model_type = "unet"

    def __init__(
        self,
        encoder_name: str = "resnet18",
        num_classes: int = 16,
        input_channels: int = 1,
        decoder_channels: tuple = (1024, 512, 256, 128, 64),
        **kwargs,
    ):
        self.encoder_name = encoder_name
        self.num_classes = num_classes
        self.input_channels = input_channels
        # JSON has no tuple type, so from_pretrained hands back a list; restore
        # the tuple so reloaded configs match the default (and to_diff_dict does
        # not report a spurious difference). Other values pass through as-is.
        if isinstance(decoder_channels, list):
            decoder_channels = tuple(decoder_channels)
        self.decoder_channels = decoder_channels
        super().__init__(**kwargs)
|