Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +2 -1
modeling_esm_plusplus.py
CHANGED
@@ -411,6 +411,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
|
|
411 |
if attention_mask is None:
|
412 |
return x.mean(dim=1)
|
413 |
else:
|
|
|
414 |
return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
|
415 |
|
416 |
def forward(
|
@@ -440,7 +441,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
|
|
440 |
|
441 |
if self.config.problem_type == "regression":
|
442 |
if self.num_labels == 1:
|
443 |
-
loss = self.mse(logits.
|
444 |
else:
|
445 |
loss = self.mse(logits, labels)
|
446 |
elif self.config.problem_type == "single_label_classification":
|
|
|
411 |
if attention_mask is None:
|
412 |
return x.mean(dim=1)
|
413 |
else:
|
414 |
+
attention_mask = attention_mask.unsqueeze(-1)
|
415 |
return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
|
416 |
|
417 |
def forward(
|
|
|
441 |
|
442 |
if self.config.problem_type == "regression":
|
443 |
if self.num_labels == 1:
|
444 |
+
loss = self.mse(logits.flatten(), labels.flatten())
|
445 |
else:
|
446 |
loss = self.mse(logits, labels)
|
447 |
elif self.config.problem_type == "single_label_classification":
|