Crystalcareai committed
Commit 952f2ad · verified · 1 Parent(s): 846f7b2

Update generate.py

Files changed (1)
  1. generate.py +12 -132
generate.py CHANGED
@@ -1,11 +1,3 @@
-import torch
-from transformers.generation.utils import (
-    GenerationMixin,
-    validate_stopping_criteria,
-    StoppingCriteriaList,
-)
-from transformers import TextStreamer
-
 def custom_generate(
     self,
     input_ids,
@@ -42,6 +34,11 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
+    print("Input IDs shape:", input_ids.shape)
+    print("Input IDs:", input_ids)
+    print("Attention Mask shape:", attention_mask.shape if attention_mask is not None else None)
+    print("Attention Mask:", attention_mask)
+
     if input_ids is None or input_ids.nelement() == 0:
         # If input_ids is None or an empty tensor, create a default input tensor
         input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
@@ -61,6 +58,8 @@ def custom_generate(
             **kwargs
         )['logits']
 
+        print(f"Step {cur_token_idx + 1}: New IDs shape: {new_ids.shape}")
+
         # Mask out the start and end thought tokens so we don't accidentally sample them
         new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
 
@@ -73,6 +72,8 @@ def custom_generate(
         new_ids_sampled = torch.multinomial(
             torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
 
+        print(f"Step {cur_token_idx + 1}: New IDs sampled: {new_ids_sampled}")
+
         # Assign the new id to the last token
         if last_token_idx + 1 >= len(base_answer_ids):
             # Add padding everywhere
@@ -100,128 +101,7 @@ def custom_generate(
         if streamer is not None:
             streamer.put(new_ids_sampled)
 
-    return generated_token_ids
-
-
-def generate(
-    self,
-    input_ids,
-    attention_mask=None,
-    max_new_tokens=None,
-    min_length=None,
-    do_sample=None,
-    early_stopping=None,
-    num_beams=None,
-    temperature=1.1,
-    streamer=None,
-    top_k=None,
-    top_p=None,
-    repetition_penalty=None,
-    bad_words_ids=None,
-    bos_token_id=None,
-    pad_token_id=None,
-    eos_token_id=None,
-    length_penalty=None,
-    no_repeat_ngram_size=None,
-    num_return_sequences=None,
-    decoder_start_token_id=None,
-    use_cache=None,
-    num_beam_groups=None,
-    diversity_penalty=None,
-    prefix_allowed_tokens_fn=None,
-    output_attentions=None,
-    output_hidden_states=None,
-    output_scores=None,
-    return_dict_in_generate=None,
-    forced_bos_token_id=None,
-    forced_eos_token_id=None,
-    remove_invalid_values=None,
-    synced_gpus=None,
-    n_ahead=4,
-    n_ahead_talk=4,
-    merged_talk_heads=True,
-    merged_lm_and_talk_heads=False,
-    merged_lm_and_think_heads=True,
-    use_concat_talk_head=True,
-    use_shallow_think=True,
-    use_shallow_talk=False,
-    use_complex_think_head=False,
-    use_complex_talk_head=True,
-    use_weighted_talk_head=True,
-    trust_remote_code=True,
-    torch_dtype=torch.bfloat16,
-    **model_kwargs,
-):
-
-    if max_new_tokens is None:
-        max_new_tokens = 128
-
-    # Set model attributes
-    self.max_thoughts = n_ahead + n_ahead_talk + 1
-    self.merged_talk_heads = merged_talk_heads
-    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
-    self.merged_lm_and_think_heads = merged_lm_and_think_heads
-    self.use_concat_talk_head = use_concat_talk_head
-    self.use_shallow_think = use_shallow_think
-    self.use_shallow_talk = use_shallow_talk
-    self.use_complex_think_head = use_complex_think_head
-    self.use_complex_talk_head = use_complex_talk_head
-    self.use_weighted_talk_head = use_weighted_talk_head
-
-    # Set model properties
-    self.use_end_thought_token = True
-    self.use_start_thought_token = True
-    self.n_ahead = n_ahead
-    self.n_passes = 1
-    self.eval_mode = True
-    self.first_run = False
-    self.rm_initialized = True
-    self.original_mode = False
-
-    # Check if the input is a string (for compatibility with text-generation-webui)
-    if isinstance(input_ids, str):
-        input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
-
-    # Move input_ids and attention_mask to the same device as the model
-    input_ids = input_ids.to(self.device)
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(self.device)
-
-    generated_token_ids = custom_generate(
-        self,
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
-        min_length=min_length,
-        do_sample=do_sample,
-        early_stopping=early_stopping,
-        num_beams=num_beams,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        bad_words_ids=bad_words_ids,
-        bos_token_id=bos_token_id,
-        pad_token_id=pad_token_id,
-        eos_token_id=eos_token_id,
-        length_penalty=length_penalty,
-        no_repeat_ngram_size=no_repeat_ngram_size,
-        num_return_sequences=num_return_sequences,
-        decoder_start_token_id=decoder_start_token_id,
-        use_cache=use_cache,
-        num_beam_groups=num_beam_groups,
-        diversity_penalty=diversity_penalty,
-        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        output_scores=output_scores,
-        return_dict_in_generate=return_dict_in_generate,
-        forced_bos_token_id=forced_bos_token_id,
-        forced_eos_token_id=forced_eos_token_id,
-        remove_invalid_values=remove_invalid_values,
-        synced_gpus=synced_gpus,
-        streamer=streamer,
-        **model_kwargs,
-    )
+    print("Generated Token IDs shape:", generated_token_ids.shape)
+    print("Generated Token IDs:", generated_token_ids)
 
-    return generated_token_ids
+    return generated_token_ids
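Taken together, the surviving hunks show the decoding loop doing two things per step: masking the learned thought tokens (every id past the tokenizer's base vocab_size) to -inf, then temperature-scaled multinomial sampling over what remains. Below is a minimal, self-contained sketch of that step; sample_next_token, the toy vocabulary sizes, and the default of 1.1 are illustrative stand-ins, not code from this commit.

import torch

def sample_next_token(logits, base_vocab_size, temperature=1.1):
    # logits: 1-D scores over the extended vocabulary for one position.
    # Mask every id past the base vocabulary (e.g. start/end thought tokens)
    # so multinomial sampling can never pick them, mirroring the -inf
    # masking in custom_generate above.
    logits = logits.clone()
    logits[base_vocab_size:] = -float("inf")
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# Toy check: 10 ordinary tokens plus 2 thought tokens that must never appear.
scores = torch.randn(12)
next_id = sample_next_token(scores, base_vocab_size=10)
assert next_id.item() < 10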
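The loop also forwards each sampled id to an optional streamer via streamer.put(new_ids_sampled); that object is the stock transformers TextStreamer the deleted import block used to bring in. A model-free sketch of how that API behaves (the gpt2 tokenizer is only a convenient stand-in):

import torch
from transformers import AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
streamer = TextStreamer(tokenizer)  # prints decoded text as tokens arrive

# Feed ids one at a time, the way custom_generate calls streamer.put(...).
for token_id in tokenizer("Streaming happens token by token.").input_ids:
    streamer.put(torch.tensor([[token_id]]))
streamer.end()  # flush whatever text is still buffered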