Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +5 -4
modeling_esm_plusplus.py
CHANGED
|
@@ -467,8 +467,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
|
|
| 467 |
Implements the base ESM++ architecture with a masked language modeling head.
|
| 468 |
"""
|
| 469 |
config_class = ESMplusplusConfig
|
| 470 |
-
def __init__(self, config: ESMplusplusConfig):
|
| 471 |
-
super().__init__(config)
|
| 472 |
self.config = config
|
| 473 |
self.vocab_size = config.vocab_size
|
| 474 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
|
@@ -642,9 +642,10 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
|
|
| 642 |
|
| 643 |
Extends the base ESM++ model with a classification head.
|
| 644 |
"""
|
| 645 |
-
def __init__(self, config: ESMplusplusConfig):
|
| 646 |
-
super().__init__(config)
|
| 647 |
self.config = config
|
|
|
|
| 648 |
self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
|
| 649 |
# Large intermediate projections help with sequence classification tasks (*4)
|
| 650 |
self.mse = nn.MSELoss()
|
|
|
|
| 467 |
Implements the base ESM++ architecture with a masked language modeling head.
|
| 468 |
"""
|
| 469 |
config_class = ESMplusplusConfig
|
| 470 |
+
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 471 |
+
super().__init__(config, **kwargs)
|
| 472 |
self.config = config
|
| 473 |
self.vocab_size = config.vocab_size
|
| 474 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
|
|
|
| 642 |
|
| 643 |
Extends the base ESM++ model with a classification head.
|
| 644 |
"""
|
| 645 |
+
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 646 |
+
super().__init__(config, **kwargs)
|
| 647 |
self.config = config
|
| 648 |
+
self.num_labels = config.num_labels
|
| 649 |
self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
|
| 650 |
# Large intermediate projections help with sequence classification tasks (*4)
|
| 651 |
self.mse = nn.MSELoss()
|