Update generate.py
Browse files - generate.py +3 -8
generate.py
CHANGED
@@ -123,7 +123,8 @@ def generate(
|
|
123 |
synced_gpus=None,
|
124 |
**model_kwargs,
|
125 |
):
|
126 |
-
|
|
|
127 |
input_ids=input_ids,
|
128 |
attention_mask=attention_mask,
|
129 |
max_length=max_length,
|
@@ -147,7 +148,6 @@ def generate(
|
|
147 |
num_beam_groups=num_beam_groups,
|
148 |
diversity_penalty=diversity_penalty,
|
149 |
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
150 |
-
stopping_criteria=stopping_criteria,
|
151 |
output_attentions=output_attentions,
|
152 |
output_hidden_states=output_hidden_states,
|
153 |
output_scores=output_scores,
|
@@ -157,9 +157,4 @@ def generate(
|
|
157 |
remove_invalid_values=remove_invalid_values,
|
158 |
synced_gpus=synced_gpus,
|
159 |
**model_kwargs,
|
160 |
-
)
|
161 |
-
|
162 |
-
generation_output = GenerationMixin.generate(self, **generation_kwargs)
|
163 |
-
|
164 |
-
# Return the generated token IDs as a list
|
165 |
-
return generation_output.tolist()
|
|
|
123 |
synced_gpus=None,
|
124 |
**model_kwargs,
|
125 |
):
|
126 |
+
return custom_generate(
|
127 |
+
self,
|
128 |
input_ids=input_ids,
|
129 |
attention_mask=attention_mask,
|
130 |
max_length=max_length,
|
|
|
148 |
num_beam_groups=num_beam_groups,
|
149 |
diversity_penalty=diversity_penalty,
|
150 |
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
|
|
151 |
output_attentions=output_attentions,
|
152 |
output_hidden_states=output_hidden_states,
|
153 |
output_scores=output_scores,
|
|
|
157 |
remove_invalid_values=remove_invalid_values,
|
158 |
synced_gpus=synced_gpus,
|
159 |
**model_kwargs,
|
160 |
+
)
|
|
|
|
|
|
|
|
|
|