import torch

from transformers.utils import logging


logger = logging.get_logger(__name__)

@torch.no_grad()
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    temperature=1.0,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **kwargs,
):
    """Sample tokens autoregressively until every sequence emits EOS or
    ``max_length`` is reached, and return the extended ``input_ids`` and
    ``attention_mask``."""
    batch_size, seq_length = input_ids.shape

    if max_length is None:
        # Assumed fallback budget: the loop condition below compares against
        # max_length, so leaving the None default in place would raise a
        # TypeError. The 256-new-token cap is a stand-in choice, not a value
        # taken from the original.
        max_length = seq_length + 256

    # Thought-token budget and head configuration used by the forward pass.
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Fixed evaluation-time settings.
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.wandb_enabled = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.kill_after = 100
    self.rm_initialized = True
    self.original_mode = False

    # Track which sequences in the batch have already finished.
    finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

    while input_ids.shape[-1] < max_length:
        # Run the model only on sequences that are still generating.
        model_outputs = self(
            input_ids[~finished_generating],
            attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
            **kwargs,
        )
        # Temperature-scaled sampling from the last position's logits.
        logits = model_outputs.logits[:, -1, :] / temperature
        sampled = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)

        # Scatter the sampled tokens back into a full-batch tensor; the
        # original concatenated the masked-batch samples directly, which
        # breaks once any sequence finishes early. Finished rows are filled
        # with EOS (a stand-in for a pad token) so every row grows by one.
        next_token_id = torch.full(
            (batch_size,),
            self.tokenizer.eos_token_id,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        next_token_id[~finished_generating] = sampled

        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)

        if attention_mask is not None:
            # Mask out the filler appended to already-finished rows; a row
            # that just sampled EOS still attends to its own EOS token.
            new_mask = (~finished_generating).to(attention_mask.dtype).unsqueeze(-1)
            attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

        finished_generating = finished_generating | (next_token_id == self.tokenizer.eos_token_id)

        if finished_generating.all():
            break

    return input_ids, attention_mask
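

# Usage sketch, hedged: assumes this function is bound to (or called with) a
# model instance whose forward pass accepts the flags configured above and
# which exposes a `tokenizer` attribute. The `model` name, prompt, and values
# below are illustrative assumptions, not part of this file.
#
#   enc = model.tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
#   out_ids, out_mask = generate(
#       model,
#       enc.input_ids.to(model.device),
#       attention_mask=enc.attention_mask.to(model.device),
#       max_length=64,
#       temperature=0.9,
#   )
#   print(model.tokenizer.decode(out_ids[0], skip_special_tokens=True))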