Crystalcareai committed: Update generate.py

generate.py CHANGED (+9 −9)
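Summary of the change: custom_generate now decodes the sampled token ids itself and returns the pair (generated_token_ids, generated_text); generate unpacks that pair from the custom_generate call and returns it directly, instead of running self.tokenizer.decode on the result afterwards. The diff also supplies the device=device) continuation of the torch.full padding call, so the new pad column is created on the correct device, and keeps the max_new_tokens = 128 fallback applied when the caller passes None.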
--- a/generate.py
+++ b/generate.py
@@ -80,7 +80,7 @@ def custom_generate(
         if last_token_idx + 1 >= len(base_answer_ids):
             # Add padding everywhere
             new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
-
+                                     device=device)
             input_ids = torch.cat([input_ids, new_padding], dim=-1)
             if attention_mask is not None:
                 attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
@@ -103,8 +103,10 @@ def custom_generate(
         if streamer is not None:
             streamer.put(new_ids_sampled)
 
-
+    # Convert generated token IDs to text
+    generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
 
+    return generated_token_ids, generated_text
 
 def generate(
     self,
@@ -158,8 +160,8 @@ def generate(
 ):
 
     if max_new_tokens is None:
-        max_new_tokens = 128
-
+        max_new_tokens = 128
+
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
     self.merged_talk_heads = merged_talk_heads
@@ -191,9 +193,9 @@ def generate(
     if attention_mask is not None:
         attention_mask = attention_mask.to(self.device)
 
-    generated_token_ids = custom_generate(
+    generated_token_ids, generated_text = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,
         attention_mask=attention_mask,
         max_new_tokens=max_new_tokens,
         min_length=min_length,
@@ -228,6 +230,4 @@ def generate(
         **model_kwargs,
     )
 
-
-    generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
-    return generated_token_ids, generated_text
+    return generated_token_ids, generated_text
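The padding step in the first hunk is worth seeing end to end. Below is a minimal, self-contained sketch of the same pattern: when one more position is needed, a pad-token column is appended to input_ids and a matching zero column to attention_mask, so the padded positions stay masked out. The batch size, pad id, and dummy tensors are illustrative assumptions, not values from this repository.

import torch

# Assumptions for illustration: a tiny batch, and pad id 0 standing in for
# self.tokenizer.pad_token_id in the real code.
device = torch.device("cpu")
batch_size = 2
pad_token_id = 0

input_ids = torch.randint(1, 100, (batch_size, 5), device=device)
attention_mask = torch.ones_like(input_ids)

# Add padding everywhere: one pad column, created directly on `device`
# (the continuation argument this commit passes to torch.full).
new_padding = torch.full((batch_size, 1), pad_token_id, dtype=torch.long,
                         device=device)
input_ids = torch.cat([input_ids, new_padding], dim=-1)
if attention_mask is not None:
    # Zeros keep the padded positions out of attention.
    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)

print(input_ids.shape, attention_mask.shape)  # torch.Size([2, 6]) for both

Creating new_padding with device=device matters because torch.cat requires both tensors on the same device; an implicit CPU tensor would raise a device-mismatch error whenever input_ids lives on a GPU.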
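For callers, the visible effect is the new return contract: both custom_generate and generate now return (generated_token_ids, generated_text). A hypothetical usage sketch follows, assuming this generate is exposed as the model's generate method through the repository's remote code; the model id and loading details are placeholders, not taken from the source.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/some-model"  # placeholder, not the actual repository
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Hello, world", return_tensors="pt")
# The updated generate returns both the token ids and the decoded text,
# so the caller no longer runs tokenizer.decode itself.
generated_token_ids, generated_text = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=None,  # generate() falls back to 128 when None
)
print(generated_text)

Note that the decode in custom_generate covers only generated_token_ids[0] and keeps special tokens (skip_special_tokens=False), so with batched inputs generated_text corresponds to the first sequence only.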