Crystalcareai committed
Commit 44640e0 · verified · 1 Parent(s): 55e9861

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +3 -29
modeling_quiet.py CHANGED
@@ -57,31 +57,6 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
-@dataclass
-class ModelOutput:
-    """
-    Base class for model's outputs, with potential hidden states and attentions.
-    """
-
-    def to_tuple(self):
-        """
-        Convert the output to a tuple.
-        """
-        return tuple(self[k] for k in self.keys())
-
-@dataclass
-class BaseModelOutput(ModelOutput):
-    last_hidden_state: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-@dataclass
-class QuietModelOutputWithPast(BaseModelOutput):
-    last_hidden_state: torch.FloatTensor = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-    logits: torch.FloatTensor = None
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -1123,12 +1098,11 @@ class QuietModel(QuietPreTrainedModel):
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return QuietModelOutputWithPast(
+        return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
-            logits=self.lm_head(hidden_states),
         )
@@ -1274,8 +1248,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
             return ((loss,) + output) if loss is not None else output
 
         return CausalLMOutputWithPast(
-            loss=loss,
-            logits=mixed_logits,
+            loss=loss if loss is not None else None,
+            logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
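
For context on the hunks above: after this commit QuietModel.forward returns the stock BaseModelOutputWithPast, and QuietForCausalLM.forward keeps returning CausalLMOutputWithPast; both live in transformers.modeling_outputs. The deleted local classes shadowed the library ones without their machinery: as written, the local ModelOutput.to_tuple called self.keys(), which a bare @dataclass does not provide, so calling it would have raised AttributeError. The sketch below illustrates that point only; it is not code from the commit, and it assumes the standard transformers import is in scope in modeling_quiet.py, which these hunks do not show.

    import torch
    from transformers.modeling_outputs import (
        BaseModelOutputWithPast,  # now returned by QuietModel.forward
        CausalLMOutputWithPast,   # returned by QuietForCausalLM.forward
    )

    # Dummy tensors standing in for real model activations.
    hidden = torch.zeros(1, 4, 32)
    logits = torch.zeros(1, 4, 100)

    base_out = BaseModelOutputWithPast(last_hidden_state=hidden)
    lm_out = CausalLMOutputWithPast(logits=logits)

    # The library classes already ship the dict/tuple behaviour the
    # deleted local ModelOutput tried to hand-roll:
    assert base_out.to_tuple()[0] is hidden
    assert base_out["last_hidden_state"] is hidden
    assert lm_out.logits is lm_out["logits"]

In the last hunk, the returned logits are now picked inline from rm_logits, n_ahead, and output_logits_at_the_end (attributes referenced, but not defined, in these hunks) rather than taken from a precomputed mixed_logits.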