Update modeling_quiet.py
Browse files — modeling_quiet.py (+5 −3)
modeling_quiet.py
CHANGED
@@ -2293,13 +2293,15 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
2293 |
return input_ids
|
2294 |
|
2295 |
def prepare_thought_embeds(self, hidden_states, temperature=1.0):
|
|
|
|
|
2296 |
if self.use_start_thought_token:
|
2297 |
-
start_embed = self.start_embedding[0].unsqueeze(0) * temperature
|
2298 |
else:
|
2299 |
-
start_embed = hidden_states[:,
|
2300 |
|
2301 |
if self.use_end_thought_token:
|
2302 |
-
end_embed = self.end_embedding[0].unsqueeze(0) * temperature
|
2303 |
thought_embeds = torch.cat([start_embed, hidden_states[:, 1:-1, :], end_embed], dim=1)
|
2304 |
else:
|
2305 |
thought_embeds = torch.cat([start_embed, hidden_states[:, 1:, :]], dim=1)
|
|
|
2293 |
return input_ids
|
2294 |
|
2295 |
def prepare_thought_embeds(self, hidden_states, temperature=1.0):
    """Build the embedding sequence for a "thought" span.

    Replaces the first (and, when enabled, the last) position of
    ``hidden_states`` with the learned start/end thought-token
    embeddings, scaled by ``temperature``.

    Args:
        hidden_states: assumed to be a ``(batch, seq_len, hidden)``
            tensor — TODO(review): confirm shape against callers.
        temperature: scale factor applied to the learned start/end
            embeddings. Defaults to ``1.0``.

    Returns:
        A ``(batch, seq_len, hidden)`` tensor: ``hidden_states`` with
        the start (and optionally end) positions swapped for the scaled
        thought-token embeddings.
    """
    batch_size, seq_len, hidden_size = hidden_states.shape

    if self.use_start_thought_token:
        # (hidden,) -> (1, 1, hidden) -> (batch, 1, hidden), scaled.
        # NOTE(review): presumably start_embedding is a 2-D parameter
        # whose row 0 is the embedding — confirm against its definition.
        start_embed = (
            self.start_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1)
            * temperature
        )
    else:
        # No start token: keep the original first position.
        start_embed = hidden_states[:, :1, :]

    if self.use_end_thought_token:
        # Same expansion as the start embedding, for the final position.
        end_embed = (
            self.end_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1)
            * temperature
        )
        thought_embeds = torch.cat([start_embed, hidden_states[:, 1:-1, :], end_embed], dim=1)
    else:
        thought_embeds = torch.cat([start_embed, hidden_states[:, 1:, :]], dim=1)

    # The diff fragment ends without a return, leaving the computed
    # result dead; restore the explicit return of the assembled sequence.
    return thought_embeds