lhallee commited on
Commit
013d99e
·
verified ·
1 Parent(s): e42ebfb

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +2 -1
modeling_esm_plusplus.py CHANGED
@@ -411,6 +411,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
411
  if attention_mask is None:
412
  return x.mean(dim=1)
413
  else:
 
414
  return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
415
 
416
  def forward(
@@ -440,7 +441,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
440
 
441
  if self.config.problem_type == "regression":
442
  if self.num_labels == 1:
443
- loss = self.mse(logits.squeeze(), labels.squeeze())
444
  else:
445
  loss = self.mse(logits, labels)
446
  elif self.config.problem_type == "single_label_classification":
 
411
  if attention_mask is None:
412
  return x.mean(dim=1)
413
  else:
414
+ attention_mask = attention_mask.unsqueeze(-1)
415
  return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
416
 
417
  def forward(
 
441
 
442
  if self.config.problem_type == "regression":
443
  if self.num_labels == 1:
444
+ loss = self.mse(logits.flatten(), labels.flatten())
445
  else:
446
  loss = self.mse(logits, labels)
447
  elif self.config.problem_type == "single_label_classification":