# Quiet-Star-Custom / generate.py
# Crystalcareai's picture
# Create generate.py
# c0dd54c verified
# raw
# history blame
# 3.02 kB
import torch
from transformers.utils import logging
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
@torch.no_grad()
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    temperature=1.0,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **kwargs
):
    """Sample tokens one at a time until ``max_length`` or EOS is reached.

    The function is written to be bound to a model object (``self``): it
    copies the configuration flags below onto ``self``, then repeatedly calls
    ``self(...)`` on the sequences that are still generating and appends one
    sampled token per step.

    Args:
        self: Callable model. ``self(input_ids, attention_mask=..., **kwargs)``
            must return an object with a ``logits`` attribute indexed as
            ``[batch, position, vocab]``. ``self.tokenizer.eos_token_id`` is
            read to detect finished sequences.
        input_ids: 2-D tensor of token ids, one row per sequence.
        attention_mask: Optional mask with the same leading shape as
            ``input_ids``; it is extended by one column per generated token.
        max_length: Total sequence length (prompt included) at which
            generation stops. Required.
        temperature: Divisor applied to the logits before sampling. ``0``
            selects the arg-max token instead of sampling.
        n_ahead, n_ahead_talk, merged_talk_heads, merged_lm_and_talk_heads,
        merged_lm_and_think_heads, use_concat_talk_head, use_shallow_think,
        use_shallow_talk, use_complex_think_head, use_complex_talk_head,
        use_weighted_talk_head: Stored on ``self`` before generation starts
            (``max_thoughts`` is set to ``n_ahead + n_ahead_talk + 1``).
        trust_remote_code, torch_dtype: Accepted for call compatibility; not
            used inside this function.
        **kwargs: Forwarded unchanged to every ``self(...)`` call.

    Returns:
        Tuple ``(input_ids, attention_mask)`` holding the prompt plus the
        generated tokens. Rows that finished early are filled with the
        tokenizer's pad token (falling back to the EOS token) and, when a mask
        was supplied, receive ``0`` in the mask for those filler positions.

    Raises:
        TypeError: If ``max_length`` is ``None``.
    """
    if max_length is None:
        # The original loop condition compared an int with None and failed
        # with an opaque TypeError; fail early with a clear message instead.
        raise TypeError("`max_length` must be an integer, got None")

    # Unpacking into two names keeps the requirement that input_ids is 2-D.
    batch_size, _ = input_ids.shape

    # Set model attributes
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Set model properties
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.wandb_enabled = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.kill_after = 100
    self.rm_initialized = True
    self.original_mode = False

    # Nothing to generate: return before touching the tokenizer, as the
    # original never read it when the loop body did not run.
    if input_ids.shape[-1] >= max_length:
        return input_ids, attention_mask

    # Token ids are loop-invariant, so look them up once.
    tokenizer = self.tokenizer
    eos_token_id = tokenizer.eos_token_id
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = eos_token_id if eos_token_id is not None else 0

    # Keep track of which sequences have finished generating
    finished_generating = torch.zeros(
        batch_size, dtype=torch.bool, device=input_ids.device
    )

    while input_ids.shape[-1] < max_length:
        active = ~finished_generating

        # Only the still-active rows are passed through the model.
        model_outputs = self(
            input_ids[active],
            attention_mask=attention_mask[active] if attention_mask is not None else None,
            **kwargs
        )
        next_token_logits = model_outputs.logits[:, -1, :]

        if temperature == 0:
            # Dividing by zero would feed NaN/inf to multinomial; a zero
            # temperature is treated as greedy decoding.
            sampled = next_token_logits.argmax(dim=-1)
        else:
            probs = torch.softmax(next_token_logits / temperature, dim=-1)
            sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        # The model output has one row per *active* sequence, so scatter it
        # back into a full-batch column; finished rows receive the pad token.
        next_token_id = torch.full(
            (batch_size,),
            pad_token_id,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        next_token_id[active] = sampled.to(
            device=input_ids.device, dtype=input_ids.dtype
        )
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)

        # Update the attention mask if provided: 1 for freshly generated
        # tokens, 0 for filler appended to already-finished rows.
        if attention_mask is not None:
            mask_column = active.unsqueeze(-1).to(
                device=attention_mask.device, dtype=attention_mask.dtype
            )
            attention_mask = torch.cat([attention_mask, mask_column], dim=-1)

        # Mark sequences as finished if the end token is generated
        if eos_token_id is not None:
            finished_generating = finished_generating | (next_token_id == eos_token_id)

        # Stop generation if all sequences are finished
        if finished_generating.all():
            break

    # Return the generated token IDs and attention mask
    return input_ids, attention_mask