Update custom_generate/generate.py
Makes sure that the core `generate()` gets called, which prepares the generation config, model kwargs, logits processors, and so on. Currently the custom generation fails with some models because the model kwargs are incomplete, e.g. gemma3.
custom_generate/generate.py  CHANGED  (+23 −18)
@@ -17,12 +17,12 @@ from transformers.generation.utils import (
 )
 
 
-def generate(
+def _deepconf_generate(
     model: Any,
     input_ids: torch.LongTensor,
-    logits_processor: Optional[LogitsProcessorList] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
-    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList],
+    stopping_criteria: Optional[StoppingCriteriaList],
+    generation_config: Optional[GenerationConfig],
     synced_gpus: bool = False,
     streamer: Optional[Any] = None,
     **model_kwargs,
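Dropping the `= None` defaults makes the contract explicit: `_deepconf_generate` is a decoding loop meant to be dispatched to by the core `generate()`, which hands it an already-built `LogitsProcessorList`, `StoppingCriteriaList` and `GenerationConfig`. A minimal sketch of the objects it now assumes (the values are illustrative placeholders, not transformers internals):

import torch
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList

# Stand-ins for what core generate() builds before calling the custom loop; the real
# lists contain the processors/criteria derived from the request (min/max length,
# eos handling, user-supplied processors, ...).
input_ids = torch.tensor([[2, 106, 1645]])  # placeholder prompt ids
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
generation_config = GenerationConfig(max_new_tokens=32)

# _deepconf_generate(model, input_ids, logits_processor, stopping_criteria,
#                    generation_config, synced_gpus=False, streamer=None)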
@@ -44,12 +44,6 @@ def generate(
         depending on `return_dict_in_generate` and model type.
     """
 
-    # Ensure processors/criteria are defined
-    if logits_processor is None:
-        logits_processor = LogitsProcessorList()
-    if stopping_criteria is None:
-        stopping_criteria = StoppingCriteriaList()
-
     # Get DeepCONF parameters from generation_config or set defaults
     enable_conf = getattr(generation_config, "enable_conf", False)
     enable_early_stopping = getattr(generation_config, "enable_early_stopping", True)  # NEW: Allow disabling early stopping
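Because the DeepCONF switches are read off `generation_config` with `getattr`, they can be set directly on a config object (or, with the custom-generate mechanism, passed as extra keyword arguments to `generate()`, which are expected to be folded into the generation config). A small sketch, assuming a loaded `model`; `enable_conf` and `enable_early_stopping` are the custom attributes this file reads, not standard `GenerationConfig` fields:

from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=256)
generation_config.enable_conf = True             # turn confidence tracking on
generation_config.enable_early_stopping = False  # keep decoding even when confidence drops

# model.generate(**inputs, generation_config=generation_config, ...)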
@@ -75,14 +69,7 @@ def generate(
 
     # Initialize values
     # Handle pad token properly (following HF best practices)
-    pad_token_id = generation_config.pad_token_id
-    if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
-        pad_token_id = generation_config._pad_token_tensor
-    if pad_token_id is None and hasattr(model.config, "pad_token_id"):
-        pad_token_id = model.config.pad_token_id
-    if pad_token_id is None and generation_config.eos_token_id is not None:
-        # Use eos token as pad token if not set
-        pad_token_id = generation_config.eos_token_id
+    pad_token_id = generation_config._pad_token_tensor
 
     output_attentions = generation_config.output_attentions
     output_hidden_states = generation_config.output_hidden_states
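`generation_config._pad_token_tensor` is filled in by the core `generate()` when it prepares the special tokens, so the manual fallback chain (config pad token, then model config, then eos) is no longer needed inside the loop. If a checkpoint defines no pad token at all, the usual remedy is still to set one up front, e.g. reusing eos, which is what the removed code did by hand. A sketch, assuming a tokenizer `tok` and a `model` are already loaded:

# Hypothetical setup step: reuse eos as pad for checkpoints without a pad token,
# so padding/batching works before generation even starts.
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token
    model.generation_config.pad_token_id = tok.eos_token_id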
@@ -383,3 +370,21 @@ def generate(
         return output
     else:
         return input_ids
+
+
+def generate(model, *args, **kwargs):
+    """Custom generate function for group beam search decoding.
+    Args:
+        model (`PreTrainedModel`):
+            The model to generate from.
+        num_beams (`int`): The number of beams to use for beam search.
+        num_beam_groups (`int`): The number of beam groups to use for beam search.
+        length_penalty (`float`): The length penalty to use for beam search.
+        early_stopping (`bool`): Whether to stop beam search when sufficient beams have finished.
+        num_return_sequences (`int`): The number of sequences to return.
+        max_length (`int`): The maximum length of the generated sequence.
+    """
+    generation_outputs = GenerationMixin.generate(
+        model, *args, custom_generate=_deepconf_generate, **kwargs
+    )
+    return generation_outputs
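End to end, the strategy is meant to be pulled in through the `custom_generate` argument of `model.generate()`, which loads this repo's `custom_generate/generate.py` and enters the `generate()` wrapper above, so the core preparation runs before `_deepconf_generate`. A usage sketch; the repository id is a placeholder, the gemma3 checkpoint is only an example taken from the commit message, and extra kwargs such as `enable_conf` are assumed to reach the decoding loop via the generation config:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3-1b-it"  # example gemma3 checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tok("What is 17 * 23?", return_tensors="pt")
out = model.generate(
    **inputs,
    custom_generate="<user>/<this-deepconf-repo>",  # placeholder: the Hub repo hosting this file
    trust_remote_code=True,                         # required to run custom generation code
    enable_conf=True,                               # DeepCONF on
    max_new_tokens=128,
)
print(tok.decode(out[0], skip_special_tokens=True))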