File size: 5,079 Bytes
c0dd54c dc34aea c0dd54c dc34aea c0dd54c dc34aea c0dd54c dc34aea c0dd54c dc34aea c0dd54c dc34aea c0dd54c dc34aea c0dd54c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import torch
from transformers.utils import logging
from transformers.generation.utils import (
GenerationMixin,
validate_stopping_criteria,
StoppingCriteriaList,
)
# Module logger from transformers' logging utilities (transformers.utils.logging),
# named after this module. Not used by the code shown in this file.
logger = logging.get_logger(__name__)
@torch.no_grad()
def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
    """Sample up to ``max_new_tokens`` tokens per row, writing each one after the row's last non-pad token.

    Each step runs ``model`` on the rows that have not finished yet. For each such row, the
    logits at its last non-padding position are divided by ``kwargs["temperature"]``
    (default 1.0), turned into probabilities with a softmax, and one token is drawn with
    ``torch.multinomial``. Logits for ids ``>= tokenizer.vocab_size`` are set to ``-inf``
    first, so those ids are never sampled. The token goes into the position right after the
    last non-padding token. If that position does not exist yet, one padding column
    (attention mask 0) is appended to the whole batch first.

    A row stops once it samples the EOS, BOS, pad, or ``<|/assistant|>`` token. The loop
    ends early once every row has stopped.

    Args:
        model: Callable returning a mapping with a ``'logits'`` entry. It must expose
            ``model.tokenizer`` with ``vocab_size``, ``pad_token_id``, ``eos_token_id``,
            ``bos_token_id`` and ``convert_tokens_to_ids``.
        input_ids: Token ids indexed as ``[row, position]``. Presumably right-padded with
            ``pad_token_id``; TODO confirm with callers.
        attention_mask: Mask with the same layout as ``input_ids``.
        max_new_tokens: Maximum number of sampling steps.
        streamer: Object with ``put``/``end`` methods (e.g. ``TextStreamer``), or ``None``.
            After each step that does not finish the whole batch, it receives the token
            sampled for the last row processed in that step. ``end()`` is called once at the
            end so any buffered text is flushed.
        **kwargs: Only ``temperature`` is read.

    Returns:
        Tuple ``(input_ids, attention_mask)`` with the sampled tokens filled in.
    """
    tokenizer = model.tokenizer
    pad_token_id = tokenizer.pad_token_id
    temperature = kwargs.get("temperature", 1.0)
    # Sampling any of these ends a row. They do not change during generation, so resolve them once.
    stop_token_ids = {
        token_id
        for token_id in (
            tokenizer.eos_token_id,
            tokenizer.bos_token_id,
            pad_token_id,
            tokenizer.convert_tokens_to_ids("<|/assistant|>"),
        )
        if token_id is not None
    }

    finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
    for _ in range(max_new_tokens):
        active = ~finished_generating
        logits = model(input_ids[active], attention_mask=attention_mask[active])["logits"]
        # Mask out the start and end thought tokens so we don't accidentally sample them.
        logits[:, :, tokenizer.vocab_size:] = -float("inf")

        new_ids_sampled = None
        for list_idx, answer_idx in enumerate(active.nonzero(as_tuple=True)[0]):
            base_answer_ids = input_ids[answer_idx]
            # Index of the last token in this row that is not padding.
            last_token_idx = (base_answer_ids != pad_token_id).nonzero(as_tuple=True)[0].max()
            probs = torch.nn.functional.softmax(logits[list_idx][last_token_idx] / temperature, dim=-1)
            new_ids_sampled = torch.multinomial(probs, 1)

            if last_token_idx + 1 >= len(base_answer_ids):
                # No free slot in this row: grow every row by one padding column.
                new_padding = torch.full(
                    (len(input_ids), 1), pad_token_id, dtype=torch.long, device=input_ids.device
                )
                input_ids = torch.cat([input_ids, new_padding], dim=-1)
                attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
            attention_mask[answer_idx, last_token_idx + 1] = 1
            input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled

            if new_ids_sampled.item() in stop_token_ids:
                finished_generating[answer_idx] = True

        if finished_generating.all():
            break
        if streamer is not None:
            streamer.put(new_ids_sampled)

    # Streamers such as TextStreamer buffer partial text until end() is called.
    if streamer is not None:
        streamer.end()
    return input_ids, attention_mask
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs,
):
    """Store the thought/talk head settings on ``self``, then sample with ``custom_generate``.

    ``self`` is the model instance. It must expose ``tokenizer`` and be callable the way
    ``custom_generate`` expects. The only arguments used here are:

    * ``input_ids``, ``attention_mask`` and ``temperature``;
    * ``max_length`` / ``max_new_tokens``;
    * the head options ``n_ahead`` through ``use_weighted_talk_head``;
    * an optional ``streamer`` keyword.

    The other HF-style parameters are accepted but not used.

    ``max_length`` is the number of *new* tokens to sample, not the total sequence length.
    A ``max_new_tokens`` keyword takes precedence over it. If ``attention_mask`` is ``None``,
    every position is attended to. If no ``streamer`` keyword is given (or it is ``None``),
    a ``TextStreamer`` printing to stdout is used.

    Returns:
        Tuple ``(input_ids, attention_mask)`` as returned by ``custom_generate``.

    Raises:
        ValueError: If neither ``max_new_tokens`` nor ``max_length`` is given.
    """
    # custom_generate takes max_new_tokens and streamer positionally. Forwarding them again
    # via **model_kwargs would raise "got multiple values", so pull them out here.
    max_new_tokens = model_kwargs.pop("max_new_tokens", None)
    if max_new_tokens is None:
        max_new_tokens = max_length
    if max_new_tokens is None:
        raise ValueError("Pass `max_new_tokens` (or `max_length`) to bound generation.")
    streamer = model_kwargs.pop("streamer", None)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # Set model attributes
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head
    # Set model properties
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.wandb_enabled = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.kill_after = 100
    self.rm_initialized = True
    self.original_mode = False

    if streamer is None:
        # Imported locally: the module-level imports do not provide TextStreamer.
        from transformers import TextStreamer

        # custom_generate never puts the prompt into the streamer. With skip_prompt=True the
        # streamer would skip the first *generated* token instead, so keep it False.
        streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)

    input_ids, attention_mask = custom_generate(
        self,
        input_ids,
        attention_mask,
        max_new_tokens,
        streamer,
        temperature=temperature,
        **model_kwargs,
    )
    return input_ids, attention_mask