Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +4 -2
modeling_esm_plusplus.py
CHANGED
|
@@ -49,6 +49,7 @@ class ESMplusplusConfig(PretrainedConfig):
|
|
| 49 |
num_labels: int = 2,
|
| 50 |
problem_type: str | None = None,
|
| 51 |
dropout: float = 0.0,
|
|
|
|
| 52 |
**kwargs,
|
| 53 |
):
|
| 54 |
super().__init__(**kwargs)
|
|
@@ -59,6 +60,7 @@ class ESMplusplusConfig(PretrainedConfig):
|
|
| 59 |
self.num_labels = num_labels
|
| 60 |
self.problem_type = problem_type
|
| 61 |
self.dropout = dropout
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
### Rotary Embeddings
|
|
@@ -792,7 +794,7 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
|
|
| 792 |
)
|
| 793 |
|
| 794 |
|
| 795 |
-
class ESMplusplusForSequenceClassification(
|
| 796 |
"""
|
| 797 |
ESM++ model for sequence classification.
|
| 798 |
Extends the base ESM++ model with a classification head.
|
|
@@ -873,7 +875,7 @@ class ESMplusplusForSequenceClassification(PreTrainedESMplusplusModel):
|
|
| 873 |
)
|
| 874 |
|
| 875 |
|
| 876 |
-
class ESMplusplusForTokenClassification(
|
| 877 |
"""
|
| 878 |
ESM++ model for token classification.
|
| 879 |
Extends the base ESM++ model with a token classification head.
|
|
|
|
| 49 |
num_labels: int = 2,
|
| 50 |
problem_type: str | None = None,
|
| 51 |
dropout: float = 0.0,
|
| 52 |
+
initializer_range: float = 0.02,
|
| 53 |
**kwargs,
|
| 54 |
):
|
| 55 |
super().__init__(**kwargs)
|
|
|
|
| 60 |
self.num_labels = num_labels
|
| 61 |
self.problem_type = problem_type
|
| 62 |
self.dropout = dropout
|
| 63 |
+
self.initializer_range = initializer_range
|
| 64 |
|
| 65 |
|
| 66 |
### Rotary Embeddings
|
|
|
|
| 794 |
)
|
| 795 |
|
| 796 |
|
| 797 |
+
class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
|
| 798 |
"""
|
| 799 |
ESM++ model for sequence classification.
|
| 800 |
Extends the base ESM++ model with a classification head.
|
|
|
|
| 875 |
)
|
| 876 |
|
| 877 |
|
| 878 |
+
class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
|
| 879 |
"""
|
| 880 |
ESM++ model for token classification.
|
| 881 |
Extends the base ESM++ model with a token classification head.
|