Update modeling_quiet.py
modeling_quiet.py  (+6 -2)
@@ -986,10 +986,14 @@ class QuietModel(QuietPreTrainedModel):
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
+            if isinstance(inputs_embeds, list):
+                batch_size, seq_length, _ = inputs_embeds[0].shape
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
+
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -1209,7 +1213,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
-
+            inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
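The first hunk lets shape inference tolerate inputs_embeds arriving either as a single (batch, seq, hidden) tensor or as a list of such tensors; the second hunk forwards inputs_embeds through to the inner model call in QuietForCausalLM. Below is a minimal sketch of the shape-inference logic in isolation, under that assumption; infer_batch_and_seq is a hypothetical helper used only for illustration and is not part of modeling_quiet.py.

import torch

def infer_batch_and_seq(inputs_embeds):
    # Mirrors the updated check: accept either one embedding tensor of
    # shape (batch, seq, hidden) or a list of such tensors (hypothetical
    # helper, assumed shapes; not code from modeling_quiet.py).
    if isinstance(inputs_embeds, list):
        batch_size, seq_length, _ = inputs_embeds[0].shape
    else:
        batch_size, seq_length, _ = inputs_embeds.shape
    return batch_size, seq_length

# Both call patterns resolve to the same (batch, seq) pair.
single = torch.randn(2, 16, 64)
print(infer_batch_and_seq(single))            # (2, 16)
print(infer_batch_and_seq([single, single]))  # (2, 16)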