Crystalcareai commited on
Commit
f49a8fd
·
verified ·
1 Parent(s): 2c88f19

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +6 -2
modeling_quiet.py CHANGED
@@ -986,10 +986,14 @@ class QuietModel(QuietPreTrainedModel):
986
  elif input_ids is not None:
987
  batch_size, seq_length = input_ids.shape
988
  elif inputs_embeds is not None:
989
- batch_size, seq_length, _ = inputs_embeds.shape
 
 
 
990
  else:
991
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
992
 
 
993
  if self.gradient_checkpointing and self.training:
994
  if use_cache:
995
  logger.warning_once(
@@ -1209,7 +1213,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
1209
  attention_mask=attention_mask,
1210
  position_ids=position_ids,
1211
  past_key_values=past_key_values,
1212
- # inputs_embeds=inputs_embeds,
1213
  use_cache=use_cache,
1214
  output_attentions=output_attentions,
1215
  output_hidden_states=output_hidden_states,
 
986
  elif input_ids is not None:
987
  batch_size, seq_length = input_ids.shape
988
  elif inputs_embeds is not None:
989
+ if isinstance(inputs_embeds, list):
990
+ batch_size, seq_length, _ = inputs_embeds[0].shape
991
+ else:
992
+ batch_size, seq_length, _ = inputs_embeds.shape
993
  else:
994
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
995
 
996
+
997
  if self.gradient_checkpointing and self.training:
998
  if use_cache:
999
  logger.warning_once(
 
1213
  attention_mask=attention_mask,
1214
  position_ids=position_ids,
1215
  past_key_values=past_key_values,
1216
+ inputs_embeds=inputs_embeds,
1217
  use_cache=use_cache,
1218
  output_attentions=output_attentions,
1219
  output_hidden_states=output_hidden_states,