Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +12 -2
modeling_esm_plusplus.py
CHANGED
@@ -669,7 +669,12 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
|
|
669 |
Returns:
|
670 |
ESMplusplusOutput containing loss, logits, and hidden states
|
671 |
"""
|
672 |
-
output = super().forward(
|
|
|
|
|
|
|
|
|
|
|
673 |
x = output.last_hidden_state
|
674 |
cls_features = x[:, 0, :]
|
675 |
mean_features = self.mean_pooling(x, attention_mask)
|
@@ -735,7 +740,12 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
|
|
735 |
Returns:
|
736 |
ESMplusplusOutput containing loss, logits, and hidden states
|
737 |
"""
|
738 |
-
output = super().forward(
|
|
|
|
|
|
|
|
|
|
|
739 |
x = output.last_hidden_state
|
740 |
logits = self.classifier(x)
|
741 |
loss = None
|
|
|
669 |
Returns:
|
670 |
ESMplusplusOutput containing loss, logits, and hidden states
|
671 |
"""
|
672 |
+
output = super().forward(
|
673 |
+
input_ids=input_ids,
|
674 |
+
attention_mask=attention_mask,
|
675 |
+
labels=None,
|
676 |
+
output_hidden_states=output_hidden_states
|
677 |
+
)
|
678 |
x = output.last_hidden_state
|
679 |
cls_features = x[:, 0, :]
|
680 |
mean_features = self.mean_pooling(x, attention_mask)
|
|
|
740 |
Returns:
|
741 |
ESMplusplusOutput containing loss, logits, and hidden states
|
742 |
"""
|
743 |
+
output = super().forward(
|
744 |
+
input_ids=input_ids,
|
745 |
+
attention_mask=attention_mask,
|
746 |
+
labels=None,
|
747 |
+
output_hidden_states=output_hidden_states
|
748 |
+
)
|
749 |
x = output.last_hidden_state
|
750 |
logits = self.classifier(x)
|
751 |
loss = None
|