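# Source: forensic-similarity-graph / configuration.py
# Hosting-page metadata (not part of the module): avatar of user "ductai199x",
# commit message "Upload model", revision 45f4abb (verified).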
from transformers.configuration_utils import PretrainedConfig
class FeConfig:
    """Plain container for the ``fe_config`` section of :class:`FsgConfig`.

    The object only stores the five values it is given and can dump them back
    as a dictionary; it performs no validation.  What each field means is
    decided by the model code that consumes this config, which is not part of
    this file (NOTE(review): "Fe" presumably stands for "feature extractor" --
    confirm against the model implementation).

    Instances compare by value (see ``__eq__``), so two configs built from the
    same arguments are equal.  As a consequence instances are not hashable.

    Args:
        patch_size: Stored as ``self.patch_size``.
        variant: Stored as ``self.variant``.
        num_classes: Stored as ``self.num_classes``.
        num_filters: Stored as ``self.num_filters``.
        is_constrained: Stored as ``self.is_constrained``.
    """

    def __init__(
        self,
        patch_size: int = 128,
        variant: str = "p128",
        num_classes: int = 0,
        num_filters: int = 6,
        is_constrained: bool = False,
    ) -> None:
        self.patch_size = patch_size
        self.variant = variant
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.is_constrained = is_constrained

    def to_dict(self) -> dict:
        """Return the five config fields as a new ``dict``.

        The keys match the ``__init__`` parameter names, so
        ``FeConfig(**cfg.to_dict())`` rebuilds an equivalent object.
        """
        return {
            "patch_size": self.patch_size,
            "variant": self.variant,
            "num_classes": self.num_classes,
            "num_filters": self.num_filters,
            "is_constrained": self.is_constrained,
        }

    def __eq__(self, other: object) -> bool:
        """Compare by field values instead of object identity.

        Without this, any container that compares its attributes with ``==``
        would treat two identically-configured instances as different.
        """
        if not isinstance(other, FeConfig):
            return NotImplemented
        return self.to_dict() == other.to_dict()

    def __repr__(self) -> str:
        """Return a constructor-style representation listing every field."""
        fields = ", ".join(f"{key}={value!r}" for key, value in self.to_dict().items())
        return f"{type(self).__name__}({fields})"
class CompareNetConfig:
    """Plain container for the ``comparenet_config`` section of :class:`FsgConfig`.

    The object only stores the two values it is given and can dump them back
    as a dictionary; it performs no validation.  How the values are used is
    decided by the model code that consumes this config, which is not part of
    this file.

    Instances compare by value (see ``__eq__``), so two configs built from the
    same arguments are equal.  As a consequence instances are not hashable.

    Args:
        hidden_dim: Stored as ``self.hidden_dim``.
        output_dim: Stored as ``self.output_dim``.
    """

    def __init__(
        self,
        hidden_dim: int = 2048,
        output_dim: int = 64,
    ) -> None:
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def to_dict(self) -> dict:
        """Return both config fields as a new ``dict``.

        The keys match the ``__init__`` parameter names, so
        ``CompareNetConfig(**cfg.to_dict())`` rebuilds an equivalent object.
        """
        return {
            "hidden_dim": self.hidden_dim,
            "output_dim": self.output_dim,
        }

    def __eq__(self, other: object) -> bool:
        """Compare by field values instead of object identity.

        Without this, any container that compares its attributes with ``==``
        would treat two identically-configured instances as different.
        """
        if not isinstance(other, CompareNetConfig):
            return NotImplemented
        return self.to_dict() == other.to_dict()

    def __repr__(self) -> str:
        """Return a constructor-style representation listing every field."""
        fields = ", ".join(f"{key}={value!r}" for key, value in self.to_dict().items())
        return f"{type(self).__name__}({fields})"
class FsgConfig(PretrainedConfig):
    """Transformers configuration registered under ``model_type = "fsg"``.

    Bundles two nested sub-configurations (``fe_config`` and
    ``comparenet_config``) with four top-level options.  The meaning of the
    individual options is defined by the model code that reads this config,
    which is not part of this file.

    Args:
        fe_config: ``None`` (use ``FeConfig`` defaults), a mapping of
            ``FeConfig`` keyword arguments (e.g. loaded from JSON), or an
            already constructed ``FeConfig`` instance.
        comparenet_config: ``None`` (use ``CompareNetConfig`` defaults), a
            mapping of ``CompareNetConfig`` keyword arguments, or an already
            constructed ``CompareNetConfig`` instance.
        fast_sim_mode: Stored as ``self.fast_sim_mode``.
        loc_threshold: Stored as ``self.loc_threshold``.
        stride_ratio: Stored as ``self.stride_ratio``.
        need_input_255: Stored as ``self.need_input_255``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        TypeError: If a sub-config mapping contains a key that the
            corresponding sub-config class does not accept.
    """

    model_type = "fsg"

    def __init__(
        self,
        fe_config=None,
        comparenet_config=None,
        fast_sim_mode: bool = True,
        loc_threshold: float = 0.3,
        stride_ratio: float = 0.5,
        need_input_255: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.fe_config = self._build_sub_config(FeConfig, fe_config)
        self.comparenet_config = self._build_sub_config(CompareNetConfig, comparenet_config)
        self.fast_sim_mode = fast_sim_mode
        self.loc_threshold = loc_threshold
        self.stride_ratio = stride_ratio
        self.need_input_255 = need_input_255

    @staticmethod
    def _build_sub_config(config_cls, value):
        """Turn ``None``, a kwargs mapping, or an instance into ``config_cls``."""
        if value is None:
            return config_cls()
        if isinstance(value, config_cls):
            # Already built by the caller; unpacking it with ** would raise.
            return value
        return config_cls(**value)

    def to_dict(self):
        """Serialize this configuration to a plain dictionary.

        The first six keys (and their order) are the model-specific options,
        with the nested sub-configs converted to plain dictionaries so the
        result is JSON-serializable.  The remaining keys are whatever
        ``PretrainedConfig.to_dict()`` reports (for example ``model_type``),
        so that information is not lost when the config is saved.
        """
        output = {
            "fe_config": self.fe_config.to_dict(),
            "comparenet_config": self.comparenet_config.to_dict(),
            "fast_sim_mode": self.fast_sim_mode,
            "loc_threshold": self.loc_threshold,
            "stride_ratio": self.stride_ratio,
            "need_input_255": self.need_input_255,
        }
        # Add the base-class fields without overriding the values above; the
        # base implementation would otherwise leave the raw sub-config objects
        # under "fe_config" / "comparenet_config".
        for key, value in super().to_dict().items():
            output.setdefault(key, value)
        return output