Update generate.py

generate.py +12 -132

generate.py CHANGED
@@ -1,11 +1,3 @@
-import torch
-from transformers.generation.utils import (
-    GenerationMixin,
-    validate_stopping_criteria,
-    StoppingCriteriaList,
-)
-from transformers import TextStreamer
-
 def custom_generate(
     self,
     input_ids,
@@ -42,6 +34,11 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
+    print("Input IDs shape:", input_ids.shape)
+    print("Input IDs:", input_ids)
+    print("Attention Mask shape:", attention_mask.shape if attention_mask is not None else None)
+    print("Attention Mask:", attention_mask)
+
     if input_ids is None or input_ids.nelement() == 0:
         # If input_ids is None or an empty tensor, create a default input tensor
         input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
@@ -61,6 +58,8 @@ def custom_generate(
             **kwargs
         )['logits']
 
+        print(f"Step {cur_token_idx + 1}: New IDs shape: {new_ids.shape}")
+
         # Mask out the start and end thought tokens so we don't accidentally sample them
         new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
 
@@ -73,6 +72,8 @@ def custom_generate(
             new_ids_sampled = torch.multinomial(
                 torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
 
+            print(f"Step {cur_token_idx + 1}: New IDs sampled: {new_ids_sampled}")
+
             # Assign the new id to the last token
             if last_token_idx + 1 >= len(base_answer_ids):
                 # Add padding everywhere
@@ -100,128 +101,7 @@ def custom_generate(
             if streamer is not None:
                 streamer.put(new_ids_sampled)
 
-
-
-
-def generate(
-    self,
-    input_ids,
-    attention_mask=None,
-    max_new_tokens=None,
-    min_length=None,
-    do_sample=None,
-    early_stopping=None,
-    num_beams=None,
-    temperature=1.1,
-    streamer=None,
-    top_k=None,
-    top_p=None,
-    repetition_penalty=None,
-    bad_words_ids=None,
-    bos_token_id=None,
-    pad_token_id=None,
-    eos_token_id=None,
-    length_penalty=None,
-    no_repeat_ngram_size=None,
-    num_return_sequences=None,
-    decoder_start_token_id=None,
-    use_cache=None,
-    num_beam_groups=None,
-    diversity_penalty=None,
-    prefix_allowed_tokens_fn=None,
-    output_attentions=None,
-    output_hidden_states=None,
-    output_scores=None,
-    return_dict_in_generate=None,
-    forced_bos_token_id=None,
-    forced_eos_token_id=None,
-    remove_invalid_values=None,
-    synced_gpus=None,
-    n_ahead=4,
-    n_ahead_talk=4,
-    merged_talk_heads=True,
-    merged_lm_and_talk_heads=False,
-    merged_lm_and_think_heads=True,
-    use_concat_talk_head=True,
-    use_shallow_think=True,
-    use_shallow_talk=False,
-    use_complex_think_head=False,
-    use_complex_talk_head=True,
-    use_weighted_talk_head=True,
-    trust_remote_code=True,
-    torch_dtype=torch.bfloat16,
-    **model_kwargs,
-):
-
-    if max_new_tokens is None:
-        max_new_tokens = 128
-
-    # Set model attributes
-    self.max_thoughts = n_ahead + n_ahead_talk + 1
-    self.merged_talk_heads = merged_talk_heads
-    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
-    self.merged_lm_and_think_heads = merged_lm_and_think_heads
-    self.use_concat_talk_head = use_concat_talk_head
-    self.use_shallow_think = use_shallow_think
-    self.use_shallow_talk = use_shallow_talk
-    self.use_complex_think_head = use_complex_think_head
-    self.use_complex_talk_head = use_complex_talk_head
-    self.use_weighted_talk_head = use_weighted_talk_head
-
-    # Set model properties
-    self.use_end_thought_token = True
-    self.use_start_thought_token = True
-    self.n_ahead = n_ahead
-    self.n_passes = 1
-    self.eval_mode = True
-    self.first_run = False
-    self.rm_initialized = True
-    self.original_mode = False
-
-    # Check if the input is a string (for compatibility with text-generation-webui)
-    if isinstance(input_ids, str):
-        input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
-
-    # Move input_ids and attention_mask to the same device as the model
-    input_ids = input_ids.to(self.device)
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(self.device)
-
-    generated_token_ids = custom_generate(
-        self,
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
-        min_length=min_length,
-        do_sample=do_sample,
-        early_stopping=early_stopping,
-        num_beams=num_beams,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        bad_words_ids=bad_words_ids,
-        bos_token_id=bos_token_id,
-        pad_token_id=pad_token_id,
-        eos_token_id=eos_token_id,
-        length_penalty=length_penalty,
-        no_repeat_ngram_size=no_repeat_ngram_size,
-        num_return_sequences=num_return_sequences,
-        decoder_start_token_id=decoder_start_token_id,
-        use_cache=use_cache,
-        num_beam_groups=num_beam_groups,
-        diversity_penalty=diversity_penalty,
-        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        output_scores=output_scores,
-        return_dict_in_generate=return_dict_in_generate,
-        forced_bos_token_id=forced_bos_token_id,
-        forced_eos_token_id=forced_eos_token_id,
-        remove_invalid_values=remove_invalid_values,
-        synced_gpus=synced_gpus,
-        streamer=streamer,
-        **model_kwargs,
-    )
+    print("Generated Token IDs shape:", generated_token_ids.shape)
+    print("Generated Token IDs:", generated_token_ids)
 
-
+    return generated_token_ids
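
With the generate() wrapper deleted, callers now have to perform its setup themselves before invoking custom_generate directly: set the thought/talk-head attributes on the model, tokenize string prompts, and move tensors to the model's device. Below is a minimal caller-side sketch under those assumptions; run_generation is a hypothetical helper (not part of this repo), and only a few of the flags the old wrapper set are shown.

def run_generation(model, prompt, n_ahead=4, n_ahead_talk=4,
                   max_new_tokens=128, temperature=1.1):
    # Hypothetical helper: replays the setup the deleted generate()
    # wrapper performed before it delegated to custom_generate.
    model.max_thoughts = n_ahead + n_ahead_talk + 1
    model.n_ahead = n_ahead
    model.eval_mode = True
    # (The old wrapper also set merged_talk_heads, use_weighted_talk_head,
    # and the other head-configuration flags; set those as your model needs.)

    # Accept a raw string and move the tensor to the model's device,
    # as the old wrapper did.
    input_ids = model.tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    generated_token_ids = custom_generate(
        model,
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    # Assuming custom_generate returns a [batch, seq] tensor of token ids.
    return model.tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)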