Crystalcareai committed
Commit c0dd54c · verified · 1 Parent(s): bca4d85

Create generate.py

Files changed (1)
  1. generate.py +88 -0
generate.py ADDED
import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)

@torch.no_grad()
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    temperature=1.0,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **kwargs
):
    # trust_remote_code and torch_dtype are accepted for call-site
    # compatibility but are not used inside this function.
    batch_size, seq_length = input_ids.shape

    # The sampling loop below needs a concrete length budget; fall back to a
    # fixed number of new tokens (arbitrary default) if none is given.
    if max_length is None:
        max_length = seq_length + 256

    # Set model attributes
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Set model properties
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.wandb_enabled = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.kill_after = 100
    self.rm_initialized = True
    self.original_mode = False

    # Keep track of which sequences have finished generating
    finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

    # Token used to pad rows that have already finished; fall back to the
    # end-of-sequence token if the tokenizer defines no pad token.
    pad_token_id = self.tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = self.tokenizer.eos_token_id

    while input_ids.shape[-1] < max_length:
        active = ~finished_generating

        # Run the model only on the sequences that are still generating
        model_outputs = self(
            input_ids[active],
            attention_mask=attention_mask[active] if attention_mask is not None else None,
            **kwargs
        )
        logits = model_outputs.logits[:, -1, :]

        # Apply temperature scaling
        logits = logits / temperature

        # Sample the next token for each active sequence
        next_token_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)

        # Scatter the sampled tokens back into a full-batch column so shapes
        # stay aligned when some sequences have already finished
        next_column = torch.full((batch_size,), pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
        next_column[active] = next_token_id
        input_ids = torch.cat([input_ids, next_column.unsqueeze(-1)], dim=-1)

        # Update the attention mask if provided: attend to the new token only
        # on rows that are still generating
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask, active.to(attention_mask.dtype).unsqueeze(-1)], dim=-1
            )

        # Mark sequences as finished once they emit the end-of-sequence token
        finished_generating = finished_generating | (next_column == self.tokenizer.eos_token_id)

        # Stop generation if all sequences are finished
        if finished_generating.all():
            break

    # Return the generated token IDs and attention mask
    return input_ids, attention_mask
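
Since generate() takes self explicitly, it can be called with a loaded model as its first argument. Below is a minimal usage sketch, assuming a checkpoint whose custom modeling code consumes the thought-head attributes set above; the repository id is a placeholder not taken from this commit, and attaching the tokenizer to the model is an assumption required by the self.tokenizer.eos_token_id lookup in the loop.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/Quiet-Star-Custom"  # placeholder repo id, not from this commit

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,     # assumes the checkpoint ships custom modeling code
    torch_dtype=torch.bfloat16,
)
model.tokenizer = tokenizer    # assumed: generate() reads self.tokenizer.eos_token_id

inputs = tokenizer(["The capital of France is"], return_tensors="pt")
output_ids, output_mask = generate(
    model,
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=inputs.input_ids.shape[-1] + 32,
    temperature=0.7,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))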