Crystalcareai committed
Commit f874538 · verified · 1 Parent(s): 9a4caac

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py  +5 -3
modeling_quiet.py CHANGED
@@ -2293,13 +2293,15 @@ class QuietForCausalLM(QuietPreTrainedModel):
         return input_ids
 
     def prepare_thought_embeds(self, hidden_states, temperature=1.0):
+        batch_size, seq_len, hidden_size = hidden_states.shape
+
         if self.use_start_thought_token:
-            start_embed = self.start_embedding[0].unsqueeze(0) * temperature
+            start_embed = self.start_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1) * temperature
         else:
-            start_embed = hidden_states[:, 0, :]
+            start_embed = hidden_states[:, :1, :]
 
         if self.use_end_thought_token:
-            end_embed = self.end_embedding[0].unsqueeze(0) * temperature
+            end_embed = self.end_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1) * temperature
             thought_embeds = torch.cat([start_embed, hidden_states[:, 1:-1, :], end_embed], dim=1)
         else:
             thought_embeds = torch.cat([start_embed, hidden_states[:, 1:, :]], dim=1)
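For context on what the diff appears to fix: the old code built start_embed and end_embed without a batch dimension (and took hidden_states[:, 0, :] as a 2-D slice), so the torch.cat along dim=1 could not line up with the 3-D hidden_states slices. The new lines broadcast the learned start/end embeddings to (batch, 1, hidden) and keep the first-position slice 3-D. Below is a minimal, standalone shape sketch of that logic; the concrete tensor sizes and the (2, hidden_size) layout assumed for start_embedding/end_embedding are illustrative assumptions, not taken from the repository.

```python
# Minimal shape sketch of the updated prepare_thought_embeds logic.
# Assumptions (not from the commit): hidden_states is (batch, seq, hidden) and
# start_embedding / end_embedding are (2, hidden) parameters whose row 0 is used.
import torch

batch_size, seq_len, hidden_size = 4, 10, 32
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
start_embedding = torch.randn(2, hidden_size)  # assumed layout for illustration
end_embedding = torch.randn(2, hidden_size)
temperature = 1.0

# (hidden,) -> (1, 1, hidden) -> (batch, 1, hidden), matching the added diff lines
start_embed = start_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1) * temperature
end_embed = end_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1) * temperature

# All pieces now share (batch, *, hidden), so concatenation along dim=1 works
thought_embeds = torch.cat([start_embed, hidden_states[:, 1:-1, :], end_embed], dim=1)
assert thought_embeds.shape == (batch_size, seq_len, hidden_size)
```

The same shape reasoning covers the else branch: with start_embed as (batch, 1, hidden), torch.cat([start_embed, hidden_states[:, 1:, :]], dim=1) also lines up when no end-of-thought token is used.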