Update modeling_quiet.py
modeling_quiet.py  +3 -1
modeling_quiet.py CHANGED
@@ -987,13 +987,15 @@ class QuietModel(QuietPreTrainedModel):
             batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
             if isinstance(inputs_embeds, list):
-                batch_size, seq_length
+                batch_size, seq_length = inputs_embeds[0].shape
+                inputs_embeds = torch.stack(inputs_embeds, dim=0)
             else:
                 batch_size, seq_length, _ = inputs_embeds.shape
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
 
+
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
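
For context, the new branch takes the batch and sequence dimensions from the first element of the list and then stacks the list into a single tensor. Below is a minimal standalone sketch of that behavior; it assumes each list element is a 2D tensor of shape (batch_size, seq_length), as the unpacking in the diff implies, and the example tensors and sizes are purely illustrative, not taken from the model.

    import torch

    # Hypothetical list input: 3 elements, each a (batch_size=2, seq_length=5) tensor.
    inputs_embeds = [torch.randn(2, 5) for _ in range(3)]

    if isinstance(inputs_embeds, list):
        batch_size, seq_length = inputs_embeds[0].shape    # dims read from the first element
        inputs_embeds = torch.stack(inputs_embeds, dim=0)   # list of 3 -> tensor of shape (3, 2, 5)

    print(batch_size, seq_length)   # 2 5
    print(inputs_embeds.shape)      # torch.Size([3, 2, 5])

Compared with the removed line, which left the unpacking unfinished, this both recovers the shape information and converts the list into a single tensor so the rest of the forward pass can treat it uniformly.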