Adding model type to the config
Browse files- configuration_bert.py +2 -0
configuration_bert.py
CHANGED
@@ -97,6 +97,7 @@ class FlexBertConfig(TransformersBertConfig):
|
|
97 |
pad_logits: bool = False,
|
98 |
compile_model: bool = False,
|
99 |
masked_prediction: bool = False,
|
|
|
100 |
**kwargs,
|
101 |
):
|
102 |
"""
|
@@ -213,6 +214,7 @@ class FlexBertConfig(TransformersBertConfig):
|
|
213 |
self.pad_logits = pad_logits
|
214 |
self.compile_model = compile_model
|
215 |
self.masked_prediction = masked_prediction
|
|
|
216 |
|
217 |
if loss_kwargs.get("return_z_loss", False):
|
218 |
if loss_function != "fa_cross_entropy":
|
|
|
97 |
pad_logits: bool = False,
|
98 |
compile_model: bool = False,
|
99 |
masked_prediction: bool = False,
|
100 |
+
model_type: str = "flex_bert"
|
101 |
**kwargs,
|
102 |
):
|
103 |
"""
|
|
|
214 |
self.pad_logits = pad_logits
|
215 |
self.compile_model = compile_model
|
216 |
self.masked_prediction = masked_prediction
|
217 |
+
self.model_type = model_type
|
218 |
|
219 |
if loss_kwargs.get("return_z_loss", False):
|
220 |
if loss_function != "fa_cross_entropy":
|