lhallee commited on
Commit
ab7f9c5
·
verified ·
1 Parent(s): 8c45c04

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +5 -4
modeling_esm_plusplus.py CHANGED
@@ -467,8 +467,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
467
  Implements the base ESM++ architecture with a masked language modeling head.
468
  """
469
  config_class = ESMplusplusConfig
470
- def __init__(self, config: ESMplusplusConfig):
471
- super().__init__(config)
472
  self.config = config
473
  self.vocab_size = config.vocab_size
474
  self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
@@ -642,9 +642,10 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
642
 
643
  Extends the base ESM++ model with a classification head.
644
  """
645
- def __init__(self, config: ESMplusplusConfig):
646
- super().__init__(config)
647
  self.config = config
 
648
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
649
  # Large intermediate projections help with sequence classification tasks (*4)
650
  self.mse = nn.MSELoss()
 
467
  Implements the base ESM++ architecture with a masked language modeling head.
468
  """
469
  config_class = ESMplusplusConfig
470
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
471
+ super().__init__(config, **kwargs)
472
  self.config = config
473
  self.vocab_size = config.vocab_size
474
  self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
 
642
 
643
  Extends the base ESM++ model with a classification head.
644
  """
645
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
646
+ super().__init__(config, **kwargs)
647
  self.config = config
648
+ self.num_labels = config.num_labels
649
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
650
  # Large intermediate projections help with sequence classification tasks (*4)
651
  self.mse = nn.MSELoss()