Update modeling_quiet.py
modeling_quiet.py (CHANGED): +26 -33

Old version of the `QuietModel._generate_thoughts` hunk (removed lines are prefixed with `-`):

@@ -929,40 +929,29 @@ class QuietModel(QuietPreTrainedModel):
         self.embed_tokens = value

     def _generate_thoughts(self, hidden_states, max_length):
-
         thought_embeddings = []
-
-        for _ in range(self.config.max_thoughts):
-            thought_id = torch.LongTensor([[self.config.start_token_id]]).to(hidden_states.device)
-            thought_embedding = self.embed_tokens(thought_id)
-
-            for _ in range(max_length):
-                outputs = self.forward(
-                    inputs_embeds=thought_embedding,
-                    attention_mask=None,
-                    use_cache=True,
-                    return_dict=True, # Set return_dict=True
-                )
-                logits = self.lm_head(outputs.last_hidden_state) # Use outputs.last_hidden_state instead of outputs.logits
-                next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
-
-                if next_token_id == self.config.end_token_id:
-                    break
-
-                thought_id = torch.cat([thought_id, next_token_id.unsqueeze(0)], dim=-1)
-                thought_embedding = torch.cat([thought_embedding, self.embed_tokens(next_token_id.unsqueeze(0))], dim=1)
-
-            thought_ids.append(thought_id.squeeze(0))
-            thought_embeddings.append(thought_embedding.squeeze(0))
-        seq_length = hidden_states.size(1)
-        thought_embeddings = [
-            torch.nn.functional.pad(emb, (0, 0, 0, seq_length - emb.size(0)), mode='constant', value=0)[:seq_length]
-            for emb in thought_embeddings
-        ]

         return thought_ids, thought_embeddings


     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     def forward(
         self,
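A note on the removed loop: `thought_id` starts as a 1x1 tensor, and the early-exit check `if next_token_id == self.config.end_token_id:` relies on the comparison producing a single-element tensor, so it only behaves as intended for batch size 1. The self-contained snippet below (the `eos_token_id` value is a placeholder for illustration, not taken from this model's config) shows why the same check breaks on a batch:

```python
import torch

eos_token_id = 2  # placeholder id, for illustration only

# One sequence: argmax yields a one-element tensor, which `if` coerces to a bool.
next_token_id = torch.tensor([2])
print(bool(next_token_id == eos_token_id))  # True

# Several sequences: the comparison yields a multi-element tensor and `if` raises.
batched_next_ids = torch.tensor([2, 5])
try:
    if batched_next_ids == eos_token_id:
        pass
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one element is ambiguous
```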
Old version of the `QuietForCausalLM.forward` hunk, with the unfinished logit-mixing stub (removed lines are prefixed with `-`):

@@ -1229,13 +1218,17 @@ class QuietForCausalLM(QuietPreTrainedModel):
         hidden_states = outputs.last_hidden_state
         logits = self.lm_head(hidden_states)

-
         thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
         thought_logits = self.lm_head(thought_hidden_states)

-
-
-        mixed_logits =
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
New version of the `QuietModel._generate_thoughts` hunk (added lines are prefixed with `+`):

         self.embed_tokens = value

     def _generate_thoughts(self, hidden_states, max_length):
+        batch_size = hidden_states.size(0)
+        thought_ids = torch.zeros((batch_size, self.config.num_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
         thought_embeddings = []

+        for i in range(self.config.num_thoughts):
+            thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
+            thought_outputs = self.model.generate(
+                input_ids=thought_input_ids,
+                max_length=max_length,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                pad_token_id=self.config.pad_token_id,
+                eos_token_id=self.config.eos_token_id,
+            )
+            thought_ids[:, i, :] = thought_outputs
+            thought_embeddings.append(self.model.get_input_embeddings()(thought_outputs))
+
+        thought_embeddings = torch.stack(thought_embeddings, dim=1)
         return thought_ids, thought_embeddings


+
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     def forward(
         self,
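One assumption in the added loop is worth flagging: the copy `thought_ids[:, i, :] = thought_outputs` expects every `generate` call to return sequences that are exactly `max_length` tokens wide, but generation can stop early at `eos_token_id`, in which case the returned tensor is narrower and the copy fails with a shape mismatch. Below is a minimal sketch of one way to guard the assignment; the helper name `pad_generated_ids` is hypothetical and not part of modeling_quiet.py, and it assumes `pad_token_id` is available on the config (as the `generate` call above already does):

```python
import torch
import torch.nn.functional as F

def pad_generated_ids(ids: torch.Tensor, target_length: int, pad_token_id: int) -> torch.Tensor:
    """Right-pad (or truncate) a (batch, seq_len) tensor of generated ids to target_length.

    Hypothetical helper: it only exists to make the fixed-width copy into
    thought_ids[:, i, :] safe when generation ends before max_length.
    """
    seq_len = ids.size(1)
    if seq_len >= target_length:
        return ids[:, :target_length]
    return F.pad(ids, (0, target_length - seq_len), value=pad_token_id)

# Sketch of how the loop body could use it:
#   thought_outputs = pad_generated_ids(thought_outputs, max_length, self.config.pad_token_id)
#   thought_ids[:, i, :] = thought_outputs
```

Padding to a fixed width also keeps the per-thought embedding tensors the same shape, which the later `torch.stack(thought_embeddings, dim=1)` requires.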
New version of the `QuietForCausalLM.forward` hunk (added lines are prefixed with `+`):

         hidden_states = outputs.last_hidden_state
         logits = self.lm_head(hidden_states)

+
+        thought_ids, thought_embeddings = self._generate_thoughts(hidden_states, max_length=self.config.max_thought_length)
         thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
+
+        # Compute thought logits
         thought_logits = self.lm_head(thought_hidden_states)

+        # Mix base and thought logits
+        mixed_logits = logits.unsqueeze(1) + self.mixing_head(thought_logits)
+        mixed_logits = mixed_logits.view(-1, mixed_logits.size(-1))
+
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
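The mixing step relies on `self.mixing_head`, which is not shown in this hunk. Purely to illustrate the shapes in `logits.unsqueeze(1) + self.mixing_head(thought_logits)`, here is a minimal, hypothetical stand-in (a vocabulary-to-vocabulary linear projection with toy sizes); the real module may be defined quite differently:

```python
import torch
import torch.nn as nn

class LogitMixingHead(nn.Module):
    """Hypothetical stand-in for `self.mixing_head`, for shape illustration only."""

    def __init__(self, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(vocab_size, vocab_size)

    def forward(self, thought_logits: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last (vocabulary) dimension, so any leading
        # (batch, num_thoughts, seq_len) dimensions pass through unchanged.
        return self.proj(thought_logits)


# Toy shape check: batch=2, num_thoughts=3, seq_len=5, vocab=11.
base_logits = torch.randn(2, 5, 11)        # (batch, seq_len, vocab)
thought_logits = torch.randn(2, 3, 5, 11)  # (batch, num_thoughts, seq_len, vocab)
mixed = base_logits.unsqueeze(1) + LogitMixingHead(11)(thought_logits)
print(mixed.shape)                          # torch.Size([2, 3, 5, 11])
print(mixed.view(-1, mixed.size(-1)).shape) # torch.Size([30, 11])
```

The final `view(-1, mixed_logits.size(-1))` flattens every leading dimension into one, giving the (tokens, vocab) layout that a cross-entropy loss over the shifted labels expects.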