thon | |
from transformers import PretrainedConfig | |
from typing import List | |
class ResnetConfig(PretrainedConfig): | |
model_type = "resnet" | |
def __init__( | |
self, | |
block_type="bottleneck", | |
layers: List[int] = [3, 4, 6, 3], | |
num_classes: int = 1000, | |
input_channels: int = 3, | |
cardinality: int = 1, | |
base_width: int = 64, | |
stem_width: int = 64, | |
stem_type: str = "", | |
avg_down: bool = False, | |
**kwargs, | |
): | |
if block_type not in ["basic", "bottleneck"]: | |
raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.") |