from typing import List, Optional, Tuple

from transformers import PretrainedConfig


class MammoConfig(PretrainedConfig):
    """Configuration for the "mammo" model type: an ensemble of ``num_models`` sub-models.

    All keyword arguments not listed below are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        backbone: Name of the backbone architecture (presumably a timm model
            identifier such as the default ``"tf_efficientnetv2_s"`` -- confirm
            against the model code).
        feature_dim: Feature dimension; assumed to be the backbone's output
            width -- TODO confirm.
        dropout: Dropout probability stored for the model head.
        num_classes: Number of output classes.
        in_chans: Number of input image channels.
        num_models: Number of sub-models; ``image_sizes`` and
            ``pad_to_aspect_ratio`` must each have exactly this many entries.
        image_sizes: One size pair per sub-model. Defaults to
            ``[(2048, 1024), (1920, 1280), (1536, 1536)]``. Presumably
            ``(height, width)`` -- verify against the preprocessing code.
        pad_to_aspect_ratio: One flag per sub-model. Defaults to
            ``[True, True, False]``.

    Raises:
        ValueError: If ``len(image_sizes)``, ``len(pad_to_aspect_ratio)`` and
            ``num_models`` are not all equal.
    """

    model_type = "mammo"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_s",
        feature_dim: int = 1280,
        dropout: float = 0.1,
        num_classes: int = 5,
        in_chans: int = 1,
        num_models: int = 3,
        image_sizes: Optional[List[Tuple[int, int]]] = None,
        pad_to_aspect_ratio: Optional[List[bool]] = None,
        **kwargs,
    ):
        # None sentinels instead of list defaults: a list default is evaluated
        # once at definition time and, because it is stored on ``self``, would be
        # shared and mutable across every MammoConfig instance.
        if image_sizes is None:
            image_sizes = [(2048, 1024), (1920, 1280), (1536, 1536)]
        if pad_to_aspect_ratio is None:
            pad_to_aspect_ratio = [True, True, False]

        # Shallow-copy caller-supplied sequences so later mutation of the
        # caller's list cannot silently change this config (and vice versa).
        image_sizes = list(image_sizes)
        pad_to_aspect_ratio = list(pad_to_aspect_ratio)

        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let an inconsistent config through.
        if not (len(image_sizes) == len(pad_to_aspect_ratio) == num_models):
            raise ValueError(
                f"length of `image_sizes` [{len(image_sizes)}] and `pad_to_aspect_ratio` "
                f"[{len(pad_to_aspect_ratio)}] must be equal to `num_models` [{num_models}]."
            )

        self.backbone = backbone
        self.feature_dim = feature_dim
        self.dropout = dropout
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.num_models = num_models
        self.image_sizes = image_sizes
        self.pad_to_aspect_ratio = pad_to_aspect_ratio

        super().__init__(**kwargs)