import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)


@torch.no_grad()
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    temperature=1.0,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **kwargs,
):
    batch_size, seq_length = input_ids.shape

    # Fall back to the model's maximum context length if no limit is given;
    # otherwise the `while` condition below would compare an int against None.
    if max_length is None:
        max_length = self.config.max_position_embeddings

    # Configure the thought-generation and talk heads
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Put the model into single-pass evaluation mode
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.wandb_enabled = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.kill_after = 100
    self.rm_initialized = True
    self.original_mode = False

    # Track which sequences in the batch have finished generating
    finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

    while input_ids.shape[-1] < max_length:
        # Run the model only on the sequences that are still generating
        model_outputs = self(
            input_ids[~finished_generating],
            attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
            **kwargs,
        )
        logits = model_outputs.logits[:, -1, :]

        # Apply temperature scaling and sample the next token
        logits = logits / temperature
        sampled = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)

        # Scatter the sampled tokens back into a full-batch tensor so shapes
        # stay consistent across iterations; finished sequences are padded
        # with EOS. (Concatenating the raw samples directly would fail once
        # the sub-batch of live sequences is smaller than the full batch.)
        next_token_id = torch.full(
            (batch_size,),
            self.tokenizer.eos_token_id,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        next_token_id[~finished_generating] = sampled

        # Append the new tokens to every sequence
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)

        # Extend the attention mask: attend to the new token only on sequences
        # that were still live this step
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask, (~finished_generating).to(attention_mask.dtype).unsqueeze(-1)],
                dim=-1,
            )

        # Mark sequences as finished once they emit the end-of-sequence token
        finished_generating = finished_generating | (next_token_id == self.tokenizer.eos_token_id)

        # Stop once every sequence in the batch has finished
        if finished_generating.all():
            break

    # Return the generated token IDs and the (possibly extended) attention mask
    return input_ids, attention_mask
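
# A minimal usage sketch. `generate` is written as an unbound method (its first
# parameter is `self`), so it is meant to be attached to, or called with, a
# Quiet-STaR-style causal LM that exposes the head/mode attributes set above,
# plus a `tokenizer` attribute. The checkpoint name below is a placeholder, not
# a real model ID, and the prompt is arbitrary.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "your-org/your-quiet-star-checkpoint"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    model.tokenizer = tokenizer  # the sampling loop reads self.tokenizer.eos_token_id
    model.eval()

    inputs = tokenizer("Q: What is 12 * 7? A:", return_tensors="pt")
    output_ids, _ = generate(
        model,
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=64,
        temperature=0.9,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))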