Crystalcareai commited on
Commit
f21015a
·
verified ·
1 Parent(s): 9564a62

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +3 -8
generate.py CHANGED
@@ -123,7 +123,8 @@ def generate(
123
  synced_gpus=None,
124
  **model_kwargs,
125
  ):
126
- generation_kwargs = dict(
 
127
  input_ids=input_ids,
128
  attention_mask=attention_mask,
129
  max_length=max_length,
@@ -147,7 +148,6 @@ def generate(
147
  num_beam_groups=num_beam_groups,
148
  diversity_penalty=diversity_penalty,
149
  prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
150
- stopping_criteria=stopping_criteria,
151
  output_attentions=output_attentions,
152
  output_hidden_states=output_hidden_states,
153
  output_scores=output_scores,
@@ -157,9 +157,4 @@ def generate(
157
  remove_invalid_values=remove_invalid_values,
158
  synced_gpus=synced_gpus,
159
  **model_kwargs,
160
- )
161
-
162
- generation_output = GenerationMixin.generate(self, **generation_kwargs)
163
-
164
- # Return the generated token IDs as a list
165
- return generation_output.tolist()
 
123
  synced_gpus=None,
124
  **model_kwargs,
125
  ):
126
+ return custom_generate(
127
+ self,
128
  input_ids=input_ids,
129
  attention_mask=attention_mask,
130
  max_length=max_length,
 
148
  num_beam_groups=num_beam_groups,
149
  diversity_penalty=diversity_penalty,
150
  prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
 
151
  output_attentions=output_attentions,
152
  output_hidden_states=output_hidden_states,
153
  output_scores=output_scores,
 
157
  remove_invalid_values=remove_invalid_values,
158
  synced_gpus=synced_gpus,
159
  **model_kwargs,
160
+ )