import torch
from transformers import TextStreamer
from transformers.utils import logging
from transformers.generation.utils import (
    GenerationMixin,
    validate_stopping_criteria,
    StoppingCriteriaList,
)

logger = logging.get_logger(__name__)

@torch.no_grad()
def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
    # Token-by-token sampling loop over a batch: each sequence keeps growing until it
    # emits a stop token or the max_new_tokens budget runs out.
    finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
    for cur_token_idx in range(max_new_tokens):
        # Forward pass over the sequences that are still generating
        logits = model(
            input_ids[~finished_generating],
            attention_mask=attention_mask[~finished_generating]
        )['logits']
        # Mask out the start and end thought tokens so we don't accidentally sample them
        logits[:, :, model.tokenizer.vocab_size:] = -float("inf")
        for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
            # Find the index of the last non-padding token in this sequence
            base_answer_ids = input_ids[answer_idx]
            answer_logits = logits[list_idx]
            last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
            # Sample the next token from the temperature-scaled distribution at that position
            new_ids_sampled = torch.multinomial(
                torch.nn.functional.softmax(answer_logits[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1),
                1,
            )
            # Write the sampled id into the next position, growing the padded buffers if needed
            if last_token_idx + 1 >= len(base_answer_ids):
                # Add padding everywhere
                new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long,
                                        device=input_ids.device)
                input_ids = torch.cat([input_ids, new_padding], dim=-1)
                attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
            attention_mask[answer_idx, last_token_idx + 1] = 1
            input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
            # Stop this sequence once an EOS/BOS/PAD token or the end-of-assistant marker is sampled
            sampled_id = new_ids_sampled.item()
            if sampled_id in (
                model.tokenizer.eos_token_id,
                model.tokenizer.bos_token_id,
                model.tokenizer.pad_token_id,
                model.tokenizer.convert_tokens_to_ids("<|/assistant|>"),
            ):
                finished_generating[answer_idx] = True
        if finished_generating.all():
            break
        # Stream the most recently sampled token (only meaningful when generating a single sequence)
        streamer.put(new_ids_sampled)
    return input_ids, attention_mask

def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs,
):
    # Set model attributes
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Set model properties
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.wandb_enabled = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.kill_after = 100
    self.rm_initialized = True
    self.original_mode = False

    # Initialize a TextStreamer for streaming the generated text
    streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

    # custom_generate indexes into attention_mask, so fall back to a full mask if none was given
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # Generate with the custom sampling loop; max_length is used as the max_new_tokens budget
    input_ids, attention_mask = custom_generate(
        self,
        input_ids,
        attention_mask,
        max_length,
        streamer,
        temperature=temperature,
        **model_kwargs,
    )

    return input_ids, attention_mask
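
# A minimal usage sketch, not part of the original file: it assumes a
# Quiet-STaR-style checkpoint (the repo id below is a placeholder) whose
# remote code defines the model class this module targets; the tokenizer is
# attached to the model manually because custom_generate reads model.tokenizer.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "your-org/quiet-star-checkpoint"  # placeholder repo id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    model.tokenizer = tokenizer  # custom_generate expects the tokenizer on the model

    inputs = tokenizer("What is 6 * 7?", return_tensors="pt").to(model.device)
    output_ids, _ = generate(
        model,
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=64,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))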