Return BaseModelOutputWithPoolingAndCrossAttentions
modeling_flexbert.py +2 -2
@@ -66,7 +66,7 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
 )
 from transformers.models.bert.modeling_bert import BertPreTrainedModel
-from transformers import
+from transformers import BaseModelOutputWithPoolingAndCrossAttentions
 from .bert_padding import index_put_first_axis
 
 from .activation import get_act_fn
@@ -968,7 +968,7 @@ class FlexBertModel(FlexBertPreTrainedModel):
         if self.final_norm is not None:
             encoder_outputs = self.final_norm(encoder_outputs)
 
-        return
+        return BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=encoder_outputs)
 
     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
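This commit makes FlexBertModel.forward return a proper Hugging Face output object instead of ending in a bare return (which yielded None), so callers can read the encoder states through the standard outputs API. Note that BaseModelOutputWithPoolingAndCrossAttentions is defined in transformers.modeling_outputs, which this file already imports from; the top-level "from transformers import ..." form used in the diff assumes the class is re-exported there on the installed transformers version. A minimal caller-side sketch, assuming the model is loaded as a remote-code checkpoint (the repo id below is a placeholder, not a real checkpoint):

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder repo id; substitute a checkpoint that ships modeling_flexbert.py.
repo = "your-org/your-flexbert-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("FlexBERT now returns standard outputs.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Before this commit `outputs` was None. Now it is a
# BaseModelOutputWithPoolingAndCrossAttentions whose only populated field is
# last_hidden_state; pooler_output and the attention fields remain None.
print(outputs.last_hidden_state.shape)  # (batch_size, seq_len, hidden_size)

Since only last_hidden_state is passed to the constructor, any pooling still has to be done by the caller (e.g. taking the [CLS] position or mean-pooling over tokens).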