Crystalcareai committed on
Commit b087ddf · verified · 1 Parent(s): 90a26fc

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +26 -33
modeling_quiet.py CHANGED
@@ -929,40 +929,29 @@ class QuietModel(QuietPreTrainedModel):
         self.embed_tokens = value
 
     def _generate_thoughts(self, hidden_states, max_length):
-        thought_ids = []
+        batch_size = hidden_states.size(0)
+        thought_ids = torch.zeros((batch_size, self.config.num_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
         thought_embeddings = []
-
-        for _ in range(self.config.max_thoughts):
-            thought_id = torch.LongTensor([[self.config.start_token_id]]).to(hidden_states.device)
-            thought_embedding = self.embed_tokens(thought_id)
-
-            for _ in range(max_length):
-                outputs = self.forward(
-                    inputs_embeds=thought_embedding,
-                    attention_mask=None,
-                    use_cache=True,
-                    return_dict=True,  # Set return_dict=True
-                )
-                logits = self.lm_head(outputs.last_hidden_state)  # Use outputs.last_hidden_state instead of outputs.logits
-                next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
-
-                if next_token_id == self.config.end_token_id:
-                    break
-
-                thought_id = torch.cat([thought_id, next_token_id.unsqueeze(0)], dim=-1)
-                thought_embedding = torch.cat([thought_embedding, self.embed_tokens(next_token_id.unsqueeze(0))], dim=1)
-
-            thought_ids.append(thought_id.squeeze(0))
-            thought_embeddings.append(thought_embedding.squeeze(0))
-        seq_length = hidden_states.size(1)
-        thought_embeddings = [
-            torch.nn.functional.pad(emb, (0, 0, 0, seq_length - emb.size(0)), mode='constant', value=0)[:seq_length]
-            for emb in thought_embeddings
-        ]
 
+        for i in range(self.config.num_thoughts):
+            thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
+            thought_outputs = self.model.generate(
+                input_ids=thought_input_ids,
+                max_length=max_length,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                pad_token_id=self.config.pad_token_id,
+                eos_token_id=self.config.eos_token_id,
+            )
+            thought_ids[:, i, :] = thought_outputs
+            thought_embeddings.append(self.model.get_input_embeddings()(thought_outputs))
+
+        thought_embeddings = torch.stack(thought_embeddings, dim=1)
         return thought_ids, thought_embeddings
 
 
+
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -1229,13 +1218,17 @@ class QuietForCausalLM(QuietPreTrainedModel):
         hidden_states = outputs.last_hidden_state
         logits = self.lm_head(hidden_states)
 
-        thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
+
+        thought_ids, thought_embeddings = self._generate_thoughts(hidden_states, max_length=self.config.max_thought_length)
         thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
+
+        # Compute thought logits
         thought_logits = self.lm_head(thought_hidden_states)
 
-        mixing_input = torch.cat([hidden_states, thought_hidden_states], dim=-1)
-        mixing_weights = self.mixing_head(mixing_input).squeeze(-1)  # (batch_size, seq_length)
-        mixed_logits = base_logits * (1 - mixing_weights.unsqueeze(-1)) + thought_logits * mixing_weights.unsqueeze(-1)
+        # Mix base and thought logits
+        mixed_logits = logits.unsqueeze(1) + self.mixing_head(thought_logits)
+        mixed_logits = mixed_logits.view(-1, mixed_logits.size(-1))
+
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
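
For orientation, below is a minimal, self-contained sketch of the pattern this commit moves to: sample several "thought" token sequences, embed them, and mix their logits with the base logits through a small mixing head. Everything here is an illustrative assumption rather than the repository's QuietModel / QuietForCausalLM code: the toy embed, lm_head, mixing_head modules, the shapes, and the plain multinomial sampler (which stands in for the self.model.generate call with do_sample/top_k/top_p in the diff above).

import torch
import torch.nn as nn

# Toy stand-ins (assumptions, not the repo's modules).
vocab_size, hidden_size = 100, 32
num_thoughts, max_thought_len = 4, 8
batch_size, seq_len = 2, 8  # seq_len == max_thought_len so the logits broadcast below

embed = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size)
mixing_head = nn.Linear(vocab_size, vocab_size)  # assumed mixing-head shape

def generate_thoughts(batch_size):
    """Sample num_thoughts token sequences per batch element.
    Plain multinomial sampling stands in for model.generate(do_sample=True, top_k=50, top_p=0.95)."""
    thought_ids = torch.zeros(batch_size, num_thoughts, max_thought_len, dtype=torch.long)
    for i in range(num_thoughts):
        ids = torch.zeros(batch_size, 1, dtype=torch.long)          # start token id 0
        for _ in range(max_thought_len - 1):
            hidden = embed(ids)                                      # (batch, t, hidden)
            next_logits = lm_head(hidden[:, -1, :])                  # (batch, vocab)
            probs = next_logits.softmax(dim=-1)
            next_ids = torch.multinomial(probs, num_samples=1)       # (batch, 1)
            ids = torch.cat([ids, next_ids], dim=-1)
        thought_ids[:, i, :] = ids
    return thought_ids

input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
base_logits = lm_head(embed(input_ids))                              # (batch, seq, vocab)

thought_ids = generate_thoughts(batch_size)                          # (batch, thoughts, len)
thought_logits = lm_head(embed(thought_ids))                         # (batch, thoughts, len, vocab)

# Additive mixing in the spirit of the new forward(): broadcast base logits across thoughts.
mixed_logits = base_logits.unsqueeze(1) + mixing_head(thought_logits)
print(mixed_logits.shape)  # torch.Size([2, 4, 8, 100])

Note that the broadcast in the last step only lines up when the thought length matches the base sequence length, which the sketch forces by setting seq_len == max_thought_len; how the real model aligns those two dimensions is not shown in this diff.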