RaushanTurganbay (HF Staff) committed
Commit b991a3b · verified · 1 Parent(s): c1cd11a

Update custom_generate/generate.py


Makes sure that the core `generate()` gets called, which prepares the generation config, model kwargs, logits processors, and so on. Currently, the custom generation fails with some models (e.g. gemma3) because the model kwargs are incomplete.
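For context, here is a minimal sketch of the dispatch pattern this commit adopts. It is illustrative only: `_deepconf_generate` is the renamed decoding loop from the diff below, and the comments restate the commit's rationale rather than documenting transformers internals.

from transformers.generation.utils import GenerationMixin

def generate(model, *args, **kwargs):
    # Route through the core generate(): it resolves the GenerationConfig,
    # completes the model kwargs (the missing pieces that break models such
    # as gemma3), and builds the logits processors and stopping criteria,
    # then dispatches to the custom decoding loop via `custom_generate`.
    return GenerationMixin.generate(
        model, *args, custom_generate=_deepconf_generate, **kwargs
    )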

Files changed (1):
  1. custom_generate/generate.py +23 -18
custom_generate/generate.py CHANGED
@@ -17,12 +17,12 @@ from transformers.generation.utils import (
 )


-def generate(
+def _deepconf_generate(
     model: Any,
     input_ids: torch.LongTensor,
-    logits_processor: Optional[LogitsProcessorList] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
-    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList],
+    stopping_criteria: Optional[StoppingCriteriaList],
+    generation_config: Optional[GenerationConfig],
     synced_gpus: bool = False,
     streamer: Optional[Any] = None,
     **model_kwargs,
@@ -44,12 +44,6 @@ def generate(
     depending on `return_dict_in_generate` and model type.
     """

-    # Ensure processors/criteria are defined
-    if logits_processor is None:
-        logits_processor = LogitsProcessorList()
-    if stopping_criteria is None:
-        stopping_criteria = StoppingCriteriaList()
-
     # Get DeepCONF parameters from generation_config or set defaults
     enable_conf = getattr(generation_config, "enable_conf", False)
     enable_early_stopping = getattr(generation_config, "enable_early_stopping", True)  # NEW: Allow disabling early stopping
@@ -75,14 +69,7 @@ def generate(

     # Initialize values
     # Handle pad token properly (following HF best practices)
-    pad_token_id = generation_config.pad_token_id
-    if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
-        pad_token_id = generation_config._pad_token_tensor
-    if pad_token_id is None and hasattr(model.config, "pad_token_id"):
-        pad_token_id = model.config.pad_token_id
-    if pad_token_id is None and generation_config.eos_token_id is not None:
-        # Use eos token as pad token if not set
-        pad_token_id = generation_config.eos_token_id
+    pad_token_id = generation_config._pad_token_tensor

     output_attentions = generation_config.output_attentions
     output_hidden_states = generation_config.output_hidden_states
@@ -383,3 +370,21 @@ def generate(
         return output
     else:
         return input_ids
+
+
+def generate(model, *args, **kwargs):
+    """Custom DeepConf generate function.
+
+    Delegates to the core `GenerationMixin.generate()`, which prepares the generation
+    config, model kwargs, and logits processors before calling `_deepconf_generate`.
+
+    Args:
+        model (`PreTrainedModel`):
+            The model to generate from.
+        enable_conf (`bool`): Whether to enable DeepConf confidence tracking.
+        enable_early_stopping (`bool`): Whether to allow early stopping on low confidence.
+    """
+    generation_outputs = GenerationMixin.generate(
+        model, *args, custom_generate=_deepconf_generate, **kwargs
+    )
+    return generation_outputs
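A hypothetical usage sketch follows. The Hub repo id is a placeholder, the model id is just the gemma3 case cited in the commit message, and `enable_conf` is the DeepConf flag read via getattr() in the diff above; this assumes a transformers version that supports loading custom generation code through the `custom_generate` argument.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3-1b-it"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tok("What is 17 * 23?", return_tensors="pt")
out = model.generate(
    **inputs,
    custom_generate="<user>/<this-repo>",  # placeholder for the Hub repo containing custom_generate/generate.py
    trust_remote_code=True,
    enable_conf=True,  # DeepConf flag consumed by _deepconf_generate
    max_new_tokens=64,
)
print(tok.decode(out[0], skip_special_tokens=True))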