Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +5 -4
modeling_esm_plusplus.py
CHANGED
@@ -467,8 +467,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
|
|
467 |
Implements the base ESM++ architecture with a masked language modeling head.
|
468 |
"""
|
469 |
config_class = ESMplusplusConfig
|
470 |
-
def __init__(self, config: ESMplusplusConfig):
|
471 |
-
super().__init__(config)
|
472 |
self.config = config
|
473 |
self.vocab_size = config.vocab_size
|
474 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
@@ -642,9 +642,10 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
|
|
642 |
|
643 |
Extends the base ESM++ model with a classification head.
|
644 |
"""
|
645 |
-
def __init__(self, config: ESMplusplusConfig):
|
646 |
-
super().__init__(config)
|
647 |
self.config = config
|
|
|
648 |
self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
|
649 |
# Large intermediate projections help with sequence classification tasks (*4)
|
650 |
self.mse = nn.MSELoss()
|
|
|
467 |
Implements the base ESM++ architecture with a masked language modeling head.
|
468 |
"""
|
469 |
config_class = ESMplusplusConfig
|
470 |
+
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
471 |
+
super().__init__(config, **kwargs)
|
472 |
self.config = config
|
473 |
self.vocab_size = config.vocab_size
|
474 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
|
|
642 |
|
643 |
Extends the base ESM++ model with a classification head.
|
644 |
"""
|
645 |
+
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
646 |
+
super().__init__(config, **kwargs)
|
647 |
self.config = config
|
648 |
+
self.num_labels = config.num_labels
|
649 |
self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
|
650 |
# Large intermediate projections help with sequence classification tasks (*4)
|
651 |
self.mse = nn.MSELoss()
|