from typing import List, Optional, Tuple

from transformers import PretrainedConfig
class MammoConfig(PretrainedConfig):
    """Configuration for the ``"mammo"`` model type.

    Stores the hyper-parameters of an ensemble of ``num_models`` sub-models.
    Each sub-model gets its own entry in ``image_sizes`` and
    ``pad_to_aspect_ratio``. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        backbone: Backbone identifier string; presumably a timm model name
            (default ``"tf_efficientnetv2_s"``) -- confirm against the model code.
        feature_dim: Feature dimension; presumably the backbone's output width.
        dropout: Dropout value stored on the config.
        num_classes: Number of output classes.
        in_chans: Number of input image channels.
        num_models: Number of ensemble members. Must equal the lengths of
            ``image_sizes`` and ``pad_to_aspect_ratio``.
        image_sizes: One size pair per model. Defaults to
            ``[(2048, 1024), (1920, 1280), (1536, 1536)]``. The pair order is
            presumably (height, width) -- TODO confirm with the consumer.
        pad_to_aspect_ratio: One boolean flag per model. Defaults to
            ``[True, True, False]``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``len(image_sizes)``, ``len(pad_to_aspect_ratio)`` and
            ``num_models`` are not all equal.
    """

    model_type = "mammo"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_s",
        feature_dim: int = 1280,
        dropout: float = 0.1,
        num_classes: int = 5,
        in_chans: int = 1,
        num_models: int = 3,
        image_sizes: Optional[List[Tuple[int, int]]] = None,
        pad_to_aspect_ratio: Optional[List[bool]] = None,
        **kwargs,
    ):
        # The defaults are built per call: list literals as parameter defaults
        # would be shared by every instance constructed without these arguments.
        if image_sizes is None:
            image_sizes = [(2048, 1024), (1920, 1280), (1536, 1536)]
        if pad_to_aspect_ratio is None:
            pad_to_aspect_ratio = [True, True, False]

        # Store private copies so later mutation of the caller's lists cannot
        # leak into the config. Sizes are normalised to tuples because configs
        # loaded from JSON (which has no tuple type) arrive as lists.
        image_sizes = [tuple(size) for size in image_sizes]
        pad_to_aspect_ratio = list(pad_to_aspect_ratio)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (len(image_sizes) == len(pad_to_aspect_ratio) == num_models):
            raise ValueError(
                f"length of `image_sizes` [{len(image_sizes)}] and `pad_to_aspect_ratio` "
                f"[{len(pad_to_aspect_ratio)}] must be equal to `num_models` [{num_models}]."
            )

        self.backbone = backbone
        self.feature_dim = feature_dim
        self.dropout = dropout
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.num_models = num_models
        self.image_sizes = image_sizes
        self.pad_to_aspect_ratio = pad_to_aspect_ratio
        super().__init__(**kwargs)