Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +4 -2
modeling_esm_plusplus.py
CHANGED
@@ -49,6 +49,7 @@ class ESMplusplusConfig(PretrainedConfig):
|
|
49 |
num_labels: int = 2,
|
50 |
problem_type: str | None = None,
|
51 |
dropout: float = 0.0,
|
|
|
52 |
**kwargs,
|
53 |
):
|
54 |
super().__init__(**kwargs)
|
@@ -59,6 +60,7 @@ class ESMplusplusConfig(PretrainedConfig):
|
|
59 |
self.num_labels = num_labels
|
60 |
self.problem_type = problem_type
|
61 |
self.dropout = dropout
|
|
|
62 |
|
63 |
|
64 |
### Rotary Embeddings
|
@@ -792,7 +794,7 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
|
|
792 |
)
|
793 |
|
794 |
|
795 |
-
class ESMplusplusForSequenceClassification(
|
796 |
"""
|
797 |
ESM++ model for sequence classification.
|
798 |
Extends the base ESM++ model with a classification head.
|
@@ -873,7 +875,7 @@ class ESMplusplusForSequenceClassification(PreTrainedESMplusplusModel):
|
|
873 |
)
|
874 |
|
875 |
|
876 |
-
class ESMplusplusForTokenClassification(
|
877 |
"""
|
878 |
ESM++ model for token classification.
|
879 |
Extends the base ESM++ model with a token classification head.
|
|
|
49 |
num_labels: int = 2,
|
50 |
problem_type: str | None = None,
|
51 |
dropout: float = 0.0,
|
52 |
+
initializer_range: float = 0.02,
|
53 |
**kwargs,
|
54 |
):
|
55 |
super().__init__(**kwargs)
|
|
|
60 |
self.num_labels = num_labels
|
61 |
self.problem_type = problem_type
|
62 |
self.dropout = dropout
|
63 |
+
self.initializer_range = initializer_range
|
64 |
|
65 |
|
66 |
### Rotary Embeddings
|
|
|
794 |
)
|
795 |
|
796 |
|
797 |
+
class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
|
798 |
"""
|
799 |
ESM++ model for sequence classification.
|
800 |
Extends the base ESM++ model with a classification head.
|
|
|
875 |
)
|
876 |
|
877 |
|
878 |
+
class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
|
879 |
"""
|
880 |
ESM++ model for token classification.
|
881 |
Extends the base ESM++ model with a token classification head.
|