lhallee commited on
Commit
8c45c04
·
verified ·
1 Parent(s): 3f46b49

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +12 -2
modeling_esm_plusplus.py CHANGED
@@ -669,7 +669,12 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
669
  Returns:
670
  ESMplusplusOutput containing loss, logits, and hidden states
671
  """
672
- output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
 
 
 
 
 
673
  x = output.last_hidden_state
674
  cls_features = x[:, 0, :]
675
  mean_features = self.mean_pooling(x, attention_mask)
@@ -735,7 +740,12 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
735
  Returns:
736
  ESMplusplusOutput containing loss, logits, and hidden states
737
  """
738
- output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
 
 
 
 
 
739
  x = output.last_hidden_state
740
  logits = self.classifier(x)
741
  loss = None
 
669
  Returns:
670
  ESMplusplusOutput containing loss, logits, and hidden states
671
  """
672
+ output = super().forward(
673
+ input_ids=input_ids,
674
+ attention_mask=attention_mask,
675
+ labels=None,
676
+ output_hidden_states=output_hidden_states
677
+ )
678
  x = output.last_hidden_state
679
  cls_features = x[:, 0, :]
680
  mean_features = self.mean_pooling(x, attention_mask)
 
740
  Returns:
741
  ESMplusplusOutput containing loss, logits, and hidden states
742
  """
743
+ output = super().forward(
744
+ input_ids=input_ids,
745
+ attention_mask=attention_mask,
746
+ labels=None,
747
+ output_hidden_states=output_hidden_states
748
+ )
749
  x = output.last_hidden_state
750
  logits = self.classifier(x)
751
  loss = None