Crystalcareai committed on
Commit
29d3cfe
·
verified ·
1 Parent(s): 47f9089

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +9 -9
generate.py CHANGED
@@ -80,7 +80,7 @@ def custom_generate(
80
  if last_token_idx + 1 >= len(base_answer_ids):
81
  # Add padding everywhere
82
  new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
83
- device=device)
84
  input_ids = torch.cat([input_ids, new_padding], dim=-1)
85
  if attention_mask is not None:
86
  attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
@@ -103,8 +103,10 @@ def custom_generate(
103
  if streamer is not None:
104
  streamer.put(new_ids_sampled)
105
 
106
- return generated_token_ids
 
107
 
 
108
 
109
  def generate(
110
  self,
@@ -158,8 +160,8 @@ def generate(
158
  ):
159
 
160
  if max_new_tokens is None:
161
- max_new_tokens = 128
162
-
163
  # Set model attributes
164
  self.max_thoughts = n_ahead + n_ahead_talk + 1
165
  self.merged_talk_heads = merged_talk_heads
@@ -191,9 +193,9 @@ def generate(
191
  if attention_mask is not None:
192
  attention_mask = attention_mask.to(self.device)
193
 
194
- generated_token_ids = custom_generate(
195
  self,
196
- input_ids=input_ids,
197
  attention_mask=attention_mask,
198
  max_new_tokens=max_new_tokens,
199
  min_length=min_length,
@@ -228,6 +230,4 @@ def generate(
228
  **model_kwargs,
229
  )
230
 
231
-
232
- generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
233
- return generated_token_ids, generated_text
 
80
  if last_token_idx + 1 >= len(base_answer_ids):
81
  # Add padding everywhere
82
  new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
83
+ device=device)
84
  input_ids = torch.cat([input_ids, new_padding], dim=-1)
85
  if attention_mask is not None:
86
  attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
 
103
  if streamer is not None:
104
  streamer.put(new_ids_sampled)
105
 
106
+ # Convert generated token IDs to text
107
+ generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
108
 
109
+ return generated_token_ids, generated_text
110
 
111
  def generate(
112
  self,
 
160
  ):
161
 
162
  if max_new_tokens is None:
163
+ max_new_tokens = 128
164
+
165
  # Set model attributes
166
  self.max_thoughts = n_ahead + n_ahead_talk + 1
167
  self.merged_talk_heads = merged_talk_heads
 
193
  if attention_mask is not None:
194
  attention_mask = attention_mask.to(self.device)
195
 
196
+ generated_token_ids, generated_text = custom_generate(
197
  self,
198
+ input_ids=input_ids,
199
  attention_mask=attention_mask,
200
  max_new_tokens=max_new_tokens,
201
  min_length=min_length,
 
230
  **model_kwargs,
231
  )
232
 
233
+ return generated_token_ids, generated_text