|
from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
|
class FeConfig: |
|
def __init__( |
|
self, |
|
patch_size: int = 128, |
|
variant: str = "p128", |
|
num_classes: int = 0, |
|
num_filters: int = 6, |
|
is_constrained: bool = False, |
|
): |
|
self.patch_size = patch_size |
|
self.variant = variant |
|
self.num_classes = num_classes |
|
self.num_filters = num_filters |
|
self.is_constrained = is_constrained |
|
|
|
def to_dict(self): |
|
return { |
|
"patch_size": self.patch_size, |
|
"variant": self.variant, |
|
"num_classes": self.num_classes, |
|
"num_filters": self.num_filters, |
|
"is_constrained": self.is_constrained, |
|
} |
|
|
|
|
|
class CompareNetConfig: |
|
def __init__( |
|
self, |
|
hidden_dim: int = 2048, |
|
output_dim: int = 64, |
|
): |
|
self.hidden_dim = hidden_dim |
|
self.output_dim = output_dim |
|
|
|
def to_dict(self): |
|
return { |
|
"hidden_dim": self.hidden_dim, |
|
"output_dim": self.output_dim, |
|
} |
|
|
|
|
|
class FsgConfig(PretrainedConfig): |
|
model_type = "fsg" |
|
|
|
def __init__( |
|
self, |
|
fe_config=None, |
|
comparenet_config=None, |
|
fast_sim_mode: bool = True, |
|
loc_threshold: float = 0.3, |
|
stride_ratio: float = 0.5, |
|
need_input_255: bool = True, |
|
**kwargs, |
|
): |
|
super().__init__(**kwargs) |
|
self.fe_config = FeConfig() if fe_config is None else FeConfig(**fe_config) |
|
self.comparenet_config = CompareNetConfig() if comparenet_config is None else CompareNetConfig(**comparenet_config) |
|
self.fast_sim_mode = fast_sim_mode |
|
self.loc_threshold = loc_threshold |
|
self.stride_ratio = stride_ratio |
|
self.need_input_255 = need_input_255 |
|
|
|
def to_dict(self): |
|
return { |
|
"fe_config": self.fe_config.to_dict(), |
|
"comparenet_config": self.comparenet_config.to_dict(), |
|
"fast_sim_mode": self.fast_sim_mode, |
|
"loc_threshold": self.loc_threshold, |
|
"stride_ratio": self.stride_ratio, |
|
"need_input_255": self.need_input_255, |
|
} |