Update modeling_quiet.py
modeling_quiet.py  CHANGED  +3 -29
@@ -57,31 +57,6 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
-@dataclass
-class ModelOutput:
-    """
-    Base class for model's outputs, with potential hidden states and attentions.
-    """
-
-    def to_tuple(self):
-        """
-        Convert the output to a tuple.
-        """
-        return tuple(self[k] for k in self.keys())
-
-@dataclass
-class BaseModelOutput(ModelOutput):
-    last_hidden_state: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-@dataclass
-class QuietModelOutputWithPast(BaseModelOutput):
-    last_hidden_state: torch.FloatTensor = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-    logits: torch.FloatTensor = None
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
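Note (not part of the diff): dropping the module-local ModelOutput, BaseModelOutput and QuietModelOutputWithPast dataclasses leaves the file relying on the standard output classes that transformers already ships. A minimal sketch of the import the two changed return statements below appear to assume, presumably near the top of modeling_quiet.py as in other llama-derived models (assumption, not shown in this diff):

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,   # returned by QuietModel.forward (second hunk)
    CausalLMOutputWithPast,    # returned by QuietForCausalLM.forward (third hunk)
)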
@@ -1123,12 +1098,11 @@ class QuietModel(QuietPreTrainedModel):
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return
+        return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
-            logits=self.lm_head(hidden_states),
         )
 
 
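A minimal usage sketch of the corrected return path, assuming a constructed QuietModel instance `model` and tokenized `inputs` (both hypothetical names):

# With return_dict=True the forward pass now yields a BaseModelOutputWithPast,
# so callers read named fields instead of tuple positions.
outputs = model(**inputs, return_dict=True)
hidden_states = outputs.last_hidden_state   # (batch, seq_len, hidden_size)
kv_cache = outputs.past_key_values          # key/value cache, may be None

# With return_dict=False the same values come back positionally,
# matching the tuple branch kept above.
legacy = model(**inputs, return_dict=False)
hidden_states_legacy = legacy[0]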
@@ -1274,8 +1248,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
             return ((loss,) + output) if loss is not None else output
 
         return CausalLMOutputWithPast(
-            loss=loss,
-            logits=
+            loss=loss if loss is not None else None,
+            logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
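The new logits selection in QuietForCausalLM.forward is a nested conditional expression; an equivalent, more readable sketch (rm_logits, logits, self.n_ahead and self.output_logits_at_the_end are whatever the surrounding forward pass computed, exactly as in the one-liner above):

if self.output_logits_at_the_end:
    final_logits = logits          # always report the plain language-model logits
elif self.n_ahead > 1:
    final_logits = rm_logits       # logits used when looking more than one token ahead
else:
    final_logits = logits
# final_logits corresponds to the value passed as `logits=` in the return above.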