Crystalcareai committed
Commit dc34aea · verified · 1 parent: cfd3f6f

Update generate.py

Files changed (1)
  1. generate.py +81 -35
generate.py CHANGED
@@ -1,15 +1,83 @@
 import torch
 from transformers.utils import logging
+from transformers.generation.utils import (
+    GenerationMixin,
+    validate_stopping_criteria,
+    StoppingCriteriaList,
+)

 logger = logging.get_logger(__name__)

 @torch.no_grad()
+def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
+    finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+    for cur_token_idx in range(max_new_tokens):
+        # Sample the next token
+        new_ids = model(
+            input_ids[~finished_generating],
+            attention_mask=attention_mask[~finished_generating]
+        )['logits']
+        # Mask out the start and end thought tokens so we don't accidentally sample them
+        new_ids[:, :, model.tokenizer.vocab_size:] = -float("inf")
+        for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+            # Find the index of the last token that is not padding
+            base_answer_ids = input_ids[answer_idx]
+            new_answer_ids = new_ids[list_idx]
+            last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+            new_ids_sampled = torch.multinomial(
+                torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)
+            # Assign the new id to the last token
+            if last_token_idx + 1 >= len(base_answer_ids):
+                # Add padding everywhere
+                new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long,
+                                         device=input_ids.device)
+                input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+            attention_mask[answer_idx, last_token_idx + 1] = 1
+            input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+            if new_ids_sampled == model.tokenizer.eos_token_id or new_ids_sampled == model.tokenizer.bos_token_id or new_ids_sampled == model.tokenizer.pad_token_id:
+                finished_generating[answer_idx] = 1
+            # Check if the end token is generated
+            if new_ids_sampled == model.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
+                finished_generating[answer_idx] = 1
+        if finished_generating.all():
+            break
+        streamer.put(new_ids_sampled)
+    return input_ids, attention_mask
+
 def generate(
     self,
     input_ids,
     attention_mask=None,
     max_length=None,
+    min_length=None,
+    do_sample=None,
+    early_stopping=None,
+    num_beams=None,
     temperature=1.0,
+    top_k=None,
+    top_p=None,
+    repetition_penalty=None,
+    bad_words_ids=None,
+    bos_token_id=None,
+    pad_token_id=None,
+    eos_token_id=None,
+    length_penalty=None,
+    no_repeat_ngram_size=None,
+    num_return_sequences=None,
+    decoder_start_token_id=None,
+    use_cache=None,
+    num_beam_groups=None,
+    diversity_penalty=None,
+    prefix_allowed_tokens_fn=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    output_scores=None,
+    return_dict_in_generate=None,
+    forced_bos_token_id=None,
+    forced_eos_token_id=None,
+    remove_invalid_values=None,
+    synced_gpus=None,
     n_ahead=4,
     n_ahead_talk=4,
     merged_talk_heads=True,
@@ -23,10 +91,8 @@ def generate(
     use_weighted_talk_head=True,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16,
-    **kwargs
+    **model_kwargs,
 ):
-    batch_size, seq_length = input_ids.shape
-
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
     self.merged_talk_heads = merged_talk_heads
@@ -51,38 +117,18 @@ def generate(
     self.rm_initialized = True
     self.original_mode = False

-    # Keep track of which sequences have finished generating
-    finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
-
-    while input_ids.shape[-1] < max_length:
-        # Get the model outputs
-        model_outputs = self(
-            input_ids[~finished_generating],
-            attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
-            **kwargs
-        )
-        logits = model_outputs.logits[:, -1, :]
-
-        # Apply temperature scaling
-        logits = logits / temperature
-
-        # Sample the next token
-        next_token_logits = logits
-        next_token_id = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(-1)
-
-        # Assign the sampled token to the sequences that are still generating
-        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
+    # Initialize a TextStreamer for streaming the generated text
+    streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

-        # Update the attention mask if provided
-        if attention_mask is not None:
-            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token_id.unsqueeze(-1))], dim=-1)
-
-        # Mark sequences as finished if the end token is generated
-        finished_generating = finished_generating | (next_token_id == self.tokenizer.eos_token_id)
-
-        # Stop generation if all sequences are finished
-        if finished_generating.all():
-            break
+    # Generate using the custom generate function
+    input_ids, attention_mask = custom_generate(
+        self,
+        input_ids,
+        attention_mask,
+        max_length,
+        streamer,
+        temperature=temperature,
+        **model_kwargs,
+    )

-    # Return the generated token IDs and attention mask
     return input_ids, attention_mask
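Usage note: the updated generate() constructs a TextStreamer, but the added import block only pulls symbols from transformers.generation.utils, so TextStreamer would still need its own import (it is exported from the top-level transformers package). The new path also indexes attention_mask without a None check, so a mask must be passed, and max_length is forwarded to custom_generate as its max_new_tokens budget. Below is a minimal calling sketch under those assumptions; the model id is a placeholder, and the tokenizer is attached manually in case the remote modeling code does not already set model.tokenizer.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-quiet-star-checkpoint"  # placeholder, not taken from the diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,       # allow the repository's custom modeling/generation code
    torch_dtype=torch.bfloat16,
)
model.tokenizer = tokenizer       # custom_generate reads model.tokenizer for pad/eos ids

inputs = tokenizer("Q: What is 12 * 7?\nA:", return_tensors="pt")
output_ids, output_mask = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],  # required: custom_generate has no None check
    max_length=64,                            # forwarded as max_new_tokens
    temperature=0.9,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Unlike the stock transformers API, this generate() returns an (input_ids, attention_mask) tuple rather than a tensor of sequences, which is why the call above unpacks two values.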