Return BaseModelOutputWithPoolingAndCrossAttentions
Browse files- modeling_flexbert.py +2 -2
modeling_flexbert.py
CHANGED
|
@@ -66,7 +66,7 @@ from transformers.modeling_outputs import (
|
|
| 66 |
SequenceClassifierOutput,
|
| 67 |
)
|
| 68 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 69 |
-
from transformers import
|
| 70 |
from .bert_padding import index_put_first_axis
|
| 71 |
|
| 72 |
from .activation import get_act_fn
|
|
@@ -968,7 +968,7 @@ class FlexBertModel(FlexBertPreTrainedModel):
|
|
| 968 |
if self.final_norm is not None:
|
| 969 |
encoder_outputs = self.final_norm(encoder_outputs)
|
| 970 |
|
| 971 |
-
return
|
| 972 |
|
| 973 |
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
| 974 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|
|
|
|
| 66 |
SequenceClassifierOutput,
|
| 67 |
)
|
| 68 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 69 |
+
from transformers import BaseModelOutputWithPoolingAndCrossAttentions
|
| 70 |
from .bert_padding import index_put_first_axis
|
| 71 |
|
| 72 |
from .activation import get_act_fn
|
|
|
|
| 968 |
if self.final_norm is not None:
|
| 969 |
encoder_outputs = self.final_norm(encoder_outputs)
|
| 970 |
|
| 971 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=encoder_outputs)
|
| 972 |
|
| 973 |
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
| 974 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|