NohTow committed
Commit 98cac79 · Parent: f1c47e8

Return BaseModelOutputWithPoolingAndCrossAttentions

Files changed (1): modeling_flexbert.py (+2, -2)
modeling_flexbert.py CHANGED

@@ -66,7 +66,7 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
 )
 from transformers.models.bert.modeling_bert import BertPreTrainedModel
-from transformers import BaseModelOutput
+from transformers import BaseModelOutputWithPoolingAndCrossAttentions
 from .bert_padding import index_put_first_axis

 from .activation import get_act_fn
@@ -968,7 +968,7 @@ class FlexBertModel(FlexBertPreTrainedModel):
         if self.final_norm is not None:
            encoder_outputs = self.final_norm(encoder_outputs)

-        return BaseModelOutput(last_hidden_state=encoder_outputs)
+        return BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=encoder_outputs)

     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
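
For context, a minimal sketch of how downstream code consumes the changed return value: `FlexBertModel` now wraps the encoder output in `BaseModelOutputWithPoolingAndCrossAttentions`, so callers read the token embeddings from `.last_hidden_state`. The checkpoint name below is a hypothetical placeholder, and `trust_remote_code=True` is assumed because the model class ships in the repository's custom `modeling_flexbert.py`.

```python
# Minimal usage sketch (not part of the commit).
# "org/flexbert-checkpoint" is a hypothetical placeholder name.
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "org/flexbert-checkpoint"  # placeholder, assumption
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# forward() now returns a BaseModelOutputWithPoolingAndCrossAttentions, so the
# encoder states are exposed on .last_hidden_state; the pooling and
# cross-attention fields are left unset by this model.
token_embeddings = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
```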