JonasGeiping committed on
Commit 972cea6 · verified · 1 Parent(s): fdc5fee

Update raven_modeling_minimal.py

Files changed (1)
  1. raven_modeling_minimal.py +1592 -88
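
Since this modeling file is consumed through the Transformers remote-code path, a minimal smoke test of the updated file looks roughly like the sketch below (the checkpoint id is an illustrative assumption; num_steps is the recurrence-depth argument accepted by the new forward()):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "tomg-group-umd/huginn-0125"  # assumed checkpoint id, adjust to the actual repository
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
inputs = tokenizer("The capital of Westphalia is", return_tensors="pt")
outputs = model.generate(inputs.input_ids, max_new_tokens=16, num_steps=32)  # num_steps sets the recurrence depth
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
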
raven_modeling_minimal.py CHANGED
@@ -1,99 +1,1603 @@
1
- """A HuggingFace-style model configuration."""
2
 
3
- from transformers import PretrainedConfig
4
- from math import sqrt
5
 
6
 
7
- class RavenConfig(PretrainedConfig):
8
- model_type = "huginn_raven"
9
- keys_to_ignore_at_inference = [""]
10
- attribute_map = {"num_attention_heads": "n_heads", "hidden_size": "n_embd", "num_hidden_layers": "n_layers"}
11
 
12
  def __init__(
13
  self,
14
- n_embd: int = 5280,
15
- n_heads: int = 55,
16
- n_layers: int = 8, # total of prelude + recurrent + coda
17
- block_size: int = 4096,
18
- vocab_size: int = 65536,
19
- padding_multiple: int = 4096,
20
- tie_embeddings: bool = True,
21
- intermediate_size: int = 17920,
22
- bias: bool = False,
23
- architecture_class_name: str = "RecurrentGPT",
24
- block_class_name: str = "SandwichBlock",
25
- norm_class_name: str = "RMSNorm_llama",
26
- norm_eps: float = 0.000001,
27
- mlp_class_name: str = "GatedMLP",
28
- nonlin_name: str = "SiLU",
29
- init_strategy: str = "takase",
30
- init_orthogonal: bool = False,
31
- state_init: str = "like-init",
32
- injection_type: str = "linear",
33
- n_layers_in_recurrent_block: int = 4,
34
- mean_recurrence: int = 32,
35
- sampling_scheme: str = "poisson-lognormal-filling",
36
- mean_backprop_depth: int = 8,
37
- n_layers_in_prelude: int = 2,
38
- n_layers_in_coda: int = 2,
39
- qk_bias: bool = True,
40
- activation_checkpoint_impl: str = "per-iteration",
41
- rope_base: float = 50_000,
42
- torch_dtype: str = "bfloat16",
43
- transformers_version: str = "4.47.1",
44
  **kwargs,
45
  ):
46
- self.n_embd = n_embd
47
- self.n_heads = n_heads
48
- self.n_layers = n_layers
49
- self.block_size = block_size
50
- self.vocab_size = self.padded_vocab_size = vocab_size
51
- self.padding_multiple = padding_multiple
52
- self.tie_embeddings = tie_embeddings
53
- self.intermediate_size = intermediate_size
54
- self.bias = bias
55
- self.architecture_class_name = architecture_class_name
56
- self.block_class_name = block_class_name
57
- self.norm_class_name = norm_class_name
58
- self.norm_eps = norm_eps
59
- self.mlp_class_name = mlp_class_name
60
- self.nonlin_name = nonlin_name
61
- self.init_strategy = init_strategy
62
- self.init_orthogonal = init_orthogonal
63
- self.state_init = state_init
64
- self.injection_type = injection_type
65
- self.n_layers_in_recurrent_block = n_layers_in_recurrent_block
66
- self.mean_recurrence = mean_recurrence
67
- self.sampling_scheme = sampling_scheme
68
- self.mean_backprop_depth = mean_backprop_depth
69
- self.n_layers_in_prelude = n_layers_in_prelude
70
- self.n_layers_in_coda = n_layers_in_coda
71
- self.qk_bias = qk_bias
72
- self.activation_checkpoint_impl = activation_checkpoint_impl
73
- self.rope_base = rope_base
74
- self.torch_dtype = torch_dtype # Added from JSON
75
- self.transformers_version = transformers_version # Added from JSON
76
- # inference
77
- self.test_time_noise = 0
78
- self.test_time_noise_type = "fixed"
79
- # Derived
80
- self.num_key_value_heads = n_heads
81
- self.num_attention_heads = n_heads
82
- self.head_dim = n_embd // n_heads
83
- self.effective_expected_depth = (
84
- self.n_layers_in_prelude + self.n_layers_in_coda + self.n_layers_in_recurrent_block * self.mean_recurrence
85
  )
86
- self.init_values = {
87
- "std": sqrt(2 / (5 * self.n_embd)),
88
- "out_proj": sqrt(2 / (5 * self.n_embd)) / sqrt(2 * self.effective_expected_depth),
89
- "embedding": sqrt(2 / (5 * self.n_embd)),
90
- "embed_scale": sqrt(self.n_embd),
91
- }
92
 
93
- super().__init__(
94
- # pad_token_id=65509,
95
- # bos_token_id=65504,
96
- # eos_token_id=65505,
97
- tie_word_embeddings=tie_embeddings,
98
- **kwargs,
99
  )
1
+ """Modeling file for HF compatibility and zero-shot experiments."""
2
 
3
+ import torch
4
+ import math
5
 
6
+ from torch import Tensor
7
+ from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
8
+ from torch.nn.attention import bias as attn_bias
9
+ from dataclasses import dataclass
10
+ from typing import Union, Optional, Any
11
 
12
+
13
+ from .raven_config_minimal import RavenConfig
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+
16
+ ###################### Huggingface Glue code I ##################################################################
17
+ from transformers import PreTrainedModel, GenerationMixin
18
+ from transformers.utils import ModelOutput
19
+ from transformers.generation.utils import GenerateDecoderOnlyOutput
20
+
21
+ import torch.nn.functional as F
22
+ from transformers import GenerationConfig
23
+
24
+ torch.backends.cuda.enable_math_sdp(False)
25
+
26
+
27
+ class RavenPreTrainedModel(PreTrainedModel):
28
+ config_class = RavenConfig
29
+ base_model_prefix = "model"
30
+ supports_gradient_checkpointing = True
31
+ _no_split_modules = ["SandwichBlock"]
32
+ _skip_keys_device_placement = ["past_key_values"]
33
+ _tied_weights_keys = ["lm_head.weight"]
34
+ _supports_flash_attn_2 = True
35
+ _supports_sdpa = True
36
+ _supports_cache_class = True
37
+ _supports_quantized_cache = False
38
+ _supports_static_cache = True
39
+ _tp_plan = {}
40
+
41
+ def _init_weights(self, module):
42
+ if not torch.rand((1,)).is_meta:
43
+ print("Random Initialization not implemented.")
44
+
45
+
46
+ @dataclass
47
+ class CausalLMOutputRecurrentLatents(ModelOutput):
48
+ loss: Optional[torch.Tensor] = None
49
+ log_ppl: Optional[torch.Tensor] = None
50
+ logits: Optional[torch.Tensor] = None
51
+ past_key_values: Optional[Cache] = None
52
+ latent_states: Optional[torch.Tensor] = None
53
+ hidden_states: Optional[torch.Tensor] = None
54
+ attention_maps: Optional[dict[int, torch.Tensor]] = None
55
+ stats: Optional[dict] = None
56
+
57
+
58
+ ###################### Minimal implementation from here ############################################################
59
+
60
+
61
+ class RMSNorm(torch.nn.Module):
62
+ """Saner dtype handling and slightly better for fusion"""
63
+
64
+ def __init__(self, dim: int, eps: float = 1e-6):
65
+ super().__init__()
66
+ self.eps = eps
67
+ self.weight = torch.nn.Parameter(torch.ones(dim))
68
+
69
+ def _norm(self, x):
70
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
71
+
72
+ def forward(self, x):
73
+ with torch.autocast(enabled=False, device_type=x.device.type if x.device.type != "meta" else "cuda"):
74
+ return self._norm(x.float()).type_as(x) * self.weight
75
+
76
+ def reset_parameters(self) -> None:
77
+ torch.nn.init.ones_(self.weight)
78
+
79
+
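# Illustrative sketch (not from the commit): RMSNorm above normalizes by the root-mean-square of the
# last dimension in float32 and then rescales by a learned weight. Minimal reference check, assuming
# the RMSNorm class as defined above:
def _rmsnorm_reference_check() -> None:
    norm = RMSNorm(8)
    x = torch.randn(2, 3, 8)
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), ref, atol=1e-6)  # weight starts at ones, so the output matches the formula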
80
+ class HuginnDynamicCache(DynamicCache):
81
+ def __init__(self, lookup_strategy: str = "full") -> None:
82
+ super().__init__()
83
+ self._seen_tokens = 0
84
+ self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
85
+ self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
86
+ # structure: cache[index_of_layer_or_recurrent_step][index_in_sequence]
87
+ # the cache is held uncoalesced because certain recurrent steps may be missing for some sequence ids if using
88
+ # per-token adaptive compute. In those cases, the "lookup_strategy" determines how to proceed
89
+ # Also, it is critical that the head indices do not overlap with the recurrent iteration indices
90
+ self.lookup_strategy = lookup_strategy
91
+
92
+ def update(
93
+ self,
94
+ key_states: torch.Tensor,
95
+ value_states: torch.Tensor,
96
+ step_idx_tensor: torch.Tensor,
97
+ lookup_strategy: Optional[str] = None,
98
+ ) -> tuple[torch.Tensor, torch.Tensor]:
99
+ step_idx: int = int(step_idx_tensor) # todo: fix dicts with tensor step_idx, currently the memberships fail
100
+ lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
101
+ if "compress-" in self.lookup_strategy and step_idx > 1: # hardcode for current model!
102
+ if "compress-s" in self.lookup_strategy:
103
+ compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
104
+ new_step_idx = (step_idx - 2) % compression_stage + 2
105
+ elif "compress-anchor" in self.lookup_strategy:
106
+ if step_idx - 2 < 4 * 8: # anchor onto first 8 recurrence steps # noqa: SIM108
107
+ new_step_idx = step_idx
108
+ else: # then re-use the next 4 KV states = one recurrence for all future recurrence
109
+ new_step_idx = 34 + (step_idx - 34) % 4
110
+ # print(step_idx, new_step_idx)
111
+ else: # compress-r
112
+ compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
113
+ new_step_idx = (step_idx - 2) // compression_stage + 2
114
+ step_idx = new_step_idx
115
+ # Init
116
+ if step_idx not in self.key_cache:
117
+ self.key_cache[step_idx] = {}
118
+ self.value_cache[step_idx] = {}
119
+ # Update the number of seen tokens, we assume that step_idx=0 (first prelude) is always hit
120
+ if step_idx == 0:
121
+ self._seen_tokens += key_states.shape[-2]
122
+ # Add entries to cache
123
+ for idx, entry in enumerate(key_states.unbind(dim=-2)):
124
+ if "compress-" not in self.lookup_strategy:
125
+ assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
126
+ self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
127
+ for idx, entry in enumerate(value_states.unbind(dim=-2)):
128
+ self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
129
+
130
+ # Materialize past state based on lookup strategy:
131
+ if len(self.key_cache[step_idx]) == self._seen_tokens or self.lookup_strategy == "full":
132
+ # All entries are present, materialize cache as normal
133
+ return (
134
+ torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
135
+ torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
136
+ )
137
+ else: # some entries were not previously computed
138
+ if lookup_strategy.startswith("latest-m4"):
139
+ latest_keys = []
140
+ latest_values = []
141
+ for token_pos in range(self._seen_tokens):
142
+ # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
143
+ if step_idx >= 2:
144
+ # Find valid steps for this token position
145
+ valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
146
+ max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
147
+ else:
148
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
149
+ latest_keys.append(self.key_cache[max_step][token_pos])
150
+ latest_values.append(self.value_cache[max_step][token_pos])
151
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
152
+ elif lookup_strategy.startswith("available-m4"):
153
+ latest_keys = []
154
+ latest_values = []
155
+ for token_pos in range(self._seen_tokens):
156
+ if token_pos in self.key_cache[step_idx]:
157
+ step = step_idx
158
+ else:
159
+ # Find valid steps for this token position
160
+ valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
161
+ step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
162
+ latest_keys.append(self.key_cache[step][token_pos])
163
+ latest_values.append(self.value_cache[step][token_pos])
164
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
165
+ elif lookup_strategy.startswith("always-last-m4"):
166
+ latest_keys = []
167
+ latest_values = []
168
+ for token_pos in range(self._seen_tokens):
169
+ # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
170
+ if step_idx >= 2:
171
+ # Find valid steps for this token position
172
+ valid_steps = [key_step for key_step in self.key_cache if token_pos in self.key_cache[key_step]]
173
+ max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
174
+ else:
175
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
176
+ latest_keys.append(self.key_cache[max_step][token_pos])
177
+ latest_values.append(self.value_cache[max_step][token_pos])
178
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
179
+ elif lookup_strategy.startswith("skip"):
180
+ existing_keys = []
181
+ existing_values = []
182
+ for token_pos in range(self._seen_tokens):
183
+ if token_pos in self.key_cache[step_idx]:
184
+ existing_keys.append(self.key_cache[step_idx][token_pos])
185
+ existing_values.append(self.value_cache[step_idx][token_pos])
186
+ return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
187
+ elif lookup_strategy.startswith("randomized"): # sanity check
188
+ rand_keys = []
189
+ rand_values = []
190
+ for token_pos in range(self._seen_tokens):
191
+ if step_idx < 2: # For prelude steps
192
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
193
+ else: # Get all steps from same block position
194
+ curr_modulo = (step_idx - 2) % 4 + 2
195
+ valid_steps = [
196
+ s
197
+ for s in range(2, step_idx + 1)
198
+ if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s]
199
+ ]
200
+ max_step = valid_steps[torch.randint(len(valid_steps), (1,))]
201
+ rand_keys.append(self.key_cache[max_step][token_pos])
202
+ rand_values.append(self.value_cache[max_step][token_pos])
203
+ return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
204
+ else:
205
+ raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
206
+
207
+ def reset(self) -> None:
208
+ """Reset the cache state."""
209
+ self._seen_tokens = 0
210
+ self.key_cache.clear()
211
+ self.value_cache.clear()
212
+
213
+ def clear_last_k_entries(self, k: int = 0):
214
+ """Partially clear cache."""
215
+ assert self._seen_tokens >= k
216
+ self._seen_tokens = self._seen_tokens - k
217
+ # self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
218
+ self.key_cache = {
219
+ step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
220
+ for step, cache in self.key_cache.items()
221
+ }
222
+ self.value_cache = {
223
+ step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
224
+ for step, cache in self.value_cache.items()
225
+ }
226
+
227
+ def get_seq_length(self, step_idx: int = 0) -> int:
228
+ return self._seen_tokens
229
+
230
+ def get_memory_usage(self) -> float:
231
+ total_bytes = 0
232
+ # For each recurrent step/layer index
233
+ for step_idx in self.key_cache:
234
+ # Get the sequence cache for this step
235
+ key_seq_cache = self.key_cache[step_idx]
236
+ for seq_idx in key_seq_cache:
237
+ key_tensor = key_seq_cache[seq_idx]
238
+ # Add memory of key tensors, assuming values are the same size
239
+ total_bytes += key_tensor.nelement() * key_tensor.element_size()
240
+ return total_bytes * 2 / (1024 * 1024)
241
+
242
+
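# Illustrative sketch (not from the commit): the dynamic cache above is indexed as
# cache[recurrence_step][token_position]. With a "compress-s<N>" lookup strategy, every recurrent
# step >= 2 is folded onto N shared slots, so KV memory stops growing with recurrence depth.
# Standalone mirror of that step remapping from update():
def _compressed_step(step_idx: int, compression_stage: int) -> int:
    """Prelude steps 0/1 keep their slot; recurrent steps wrap modulo the compression stage."""
    if step_idx <= 1:
        return step_idx
    return (step_idx - 2) % compression_stage + 2
# e.g. with compression_stage=4, steps 2..9 map to 2, 3, 4, 5, 2, 3, 4, 5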
243
+ class HuginnStaticCache(Cache):
244
+ """Static Cache for the recurrent model"""
245
+
246
+ is_compileable = False # this is todo
247
+
248
+ def __init__(
249
+ self,
250
+ max_length: int,
251
+ max_num_steps: int,
252
+ num_heads: int,
253
+ hidden_dim: int,
254
+ batch_size: int = 1,
255
+ lookup_strategy: str = "full",
256
+ device: Optional[Union[torch.device, str]] = None,
257
+ dtype: torch.dtype = torch.float32,
258
+ ) -> None:
259
+ super().__init__()
260
+ self._seen_tokens = 0
261
+ self.max_length = max_length
262
+ self.lookup_strategy = lookup_strategy
263
+
264
+ # Adjust max_num_steps based on compression strategy
265
+ if "compress-" in lookup_strategy:
266
+ compression_stage = int(lookup_strategy.split("compress-")[1][1:])
267
+ if "compress-s" in lookup_strategy:
268
+ # For modulo compression (s), we need steps for 0,1 + compressed steps
269
+ self.max_num_steps = 4 + compression_stage
270
+ else:
271
+ # For relative compression, we need steps for 0,1 + compressed steps
272
+ self.max_num_steps = 4 + (max_num_steps - 4 + compression_stage - 1) // compression_stage
273
+ else:
274
+ self.max_num_steps = max_num_steps
275
+
276
+ # Pre-allocate cache tensors [steps, batch, heads, seq_len, head_dim]
277
+ device = torch.device(device) if device is not None else None
278
+ cache_shape = (self.max_num_steps, batch_size, num_heads, max_length, hidden_dim)
279
+
280
+ self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
281
+ self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
282
+ self.valid_mask = torch.zeros((self.max_num_steps, max_length), dtype=torch.bool, device=device)
283
+ # Mark tensors as static for compile
284
+ torch._dynamo.mark_static_address(self.key_cache)
285
+ torch._dynamo.mark_static_address(self.value_cache)
286
+ torch._dynamo.mark_static_address(self.valid_mask)
287
+
288
+ def update(
289
+ self,
290
+ key_states: torch.Tensor,
291
+ value_states: torch.Tensor,
292
+ step_idx: torch.Tensor,
293
+ lookup_strategy: Optional[str] = None,
294
+ ) -> tuple[torch.Tensor, torch.Tensor]:
295
+ if step_idx == 0:
296
+ self._seen_tokens += key_states.shape[-2]
297
+
298
+ # Adjust step_idx for compression
299
+ lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
300
+ if "compress-" in lookup_strategy and step_idx > 1:
301
+ compression_stage = int(lookup_strategy.split("compress-")[1][1:])
302
+ if "compress-s" in lookup_strategy:
303
+ step_idx = (step_idx - 2) % compression_stage + 2
304
+ else:
305
+ step_idx = (step_idx - 2) // compression_stage + 2
306
+
307
+ start_idx = self._seen_tokens - key_states.shape[-2]
308
+
309
+ indices = torch.arange(start_idx, start_idx + key_states.shape[-2], device=key_states.device)
310
+ self.key_cache[step_idx].index_copy_(2, indices, key_states)
311
+ self.value_cache[step_idx].index_copy_(2, indices, value_states)
312
+ self.valid_mask[step_idx, start_idx : start_idx + key_states.shape[-2]] = True
313
+
314
+ # Return based on lookup strategy
315
+ if lookup_strategy == "full":
316
+ return (
317
+ self.key_cache[step_idx, :, :, : self._seen_tokens],
318
+ self.value_cache[step_idx, :, :, : self._seen_tokens],
319
+ )
320
+ elif lookup_strategy.startswith("latest-m4"):
321
+ if step_idx >= 2:
322
+ pattern_steps = torch.arange(2, step_idx.item() + 1, 4, device=self.valid_mask.device)
323
+ pattern_valid = self.valid_mask[pattern_steps]
324
+ max_valid_step = pattern_steps[pattern_valid.to(torch.long).argmax(dim=0)]
325
+ return (
326
+ self.key_cache[max_valid_step, torch.arange(self._seen_tokens)],
327
+ self.value_cache[max_valid_step, torch.arange(self._seen_tokens)],
328
+ )
329
+ return self.key_cache[step_idx, :, :, : self._seen_tokens], self.value_cache[
330
+ step_idx, :, :, : self._seen_tokens
331
+ ]
332
+ elif lookup_strategy == "skip":
333
+ valid_mask = self.valid_mask[step_idx, : self._seen_tokens]
334
+ return (
335
+ self.key_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
336
+ self.value_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
337
+ )
338
+ elif lookup_strategy.startswith("randomized"):
339
+ if step_idx < 2:
340
+ max_step = step_idx
341
+ else:
342
+ curr_modulo = (step_idx - 2) % 4 + 2
343
+ valid_steps = (
344
+ torch.where(
345
+ (torch.arange(2, step_idx.item() + 1, device=self.valid_mask.device) - 2) % 4 + 2 == curr_modulo
346
+ )[0]
347
+ + 2
348
+ )
349
+ rand_idx = torch.randint(len(valid_steps), (1,), device=valid_steps.device)
350
+ max_step = valid_steps[rand_idx]
351
+ return self.key_cache[max_step, : self._seen_tokens], self.value_cache[max_step, : self._seen_tokens]
352
+ else:
353
+ raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
354
+
355
+ def reset(self) -> None:
356
+ self._seen_tokens = 0
357
+ self.key_cache.zero_()
358
+ self.value_cache.zero_()
359
+ self.valid_mask.zero_()
360
+
361
+ def get_seq_length(self, step_idx: int = 0) -> int:
362
+ return self._seen_tokens
363
+
364
+ def get_memory_usage(self) -> float:
365
+ return (self.key_cache.nelement() + self.value_cache.nelement()) * self.key_cache.element_size() / (1024 * 1024)
366
+
367
+
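# Illustrative sketch (not from the commit): the static cache pre-allocates
# [max_num_steps, batch, heads, max_length, head_dim] key and value tensors, so its footprint is fixed
# up front. Rough sizing helper using the config defaults shown in this diff (55 heads, head_dim 96)
# and the 4 + 4 * mean_recurrence step budget used in prepare_inputs_for_generation below:
def _static_cache_mib(max_length: int = 4096, num_heads: int = 55, head_dim: int = 96,
                      max_num_steps: int = 4 + 4 * 32, batch_size: int = 1) -> float:
    elements = max_num_steps * batch_size * num_heads * max_length * head_dim
    return 2 * elements * 2 / (1024 * 1024)  # keys + values at 2 bytes per bf16 element, in MiB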
368
+ ValidCache = HuginnDynamicCache | HuginnStaticCache
369
+
370
+
371
+ class CausalSelfAttention(torch.nn.Module):
372
+ def __init__(self, config: RavenConfig) -> None:
373
+ super().__init__()
374
+ self.config = config
375
+ self.n_head = config.num_attention_heads
376
+ self.n_kv_heads = config.num_key_value_heads
377
+ self.head_dim = config.n_embd // self.n_head
378
+
379
+ shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
380
+ self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]
381
+ self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
382
+ if config.qk_bias:
383
+ self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
384
+ self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False)
385
+
386
+ def forward(
387
+ self,
388
+ x: Tensor,
389
+ freqs_cis: Tensor,
390
+ block_idx: torch.Tensor,
391
+ mask: Optional[BlockMask] = None,
392
+ past_key_values: Optional[ValidCache] = None,
393
+ ) -> Tensor:
394
+ B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
395
+ q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
396
+ q = q.view(B, S, self.n_head, self.head_dim)
397
+ k = k.view(B, S, self.n_kv_heads, self.head_dim)
398
+ v = v.view(B, S, self.n_kv_heads, self.head_dim)
399
+ # apply qk bias if configured
400
+ if self.config.qk_bias:
401
+ q_bias, k_bias = self.qk_bias.split(1, dim=0)
402
+ q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)
403
+ # apply rotary
404
+ q, k = apply_rotary_emb_complex_like(q, k, freqs_cis=freqs_cis)
405
+
406
+ q = q.transpose(1, 2) # (B, nh, S, hs)
407
+ k = k.transpose(1, 2)
408
+ v = v.transpose(1, 2)
409
+
410
+ if past_key_values is not None:
411
+ k, v = past_key_values.update(k, v, block_idx)
412
+
413
+ if mask is not None:
414
+ y: torch.Tensor = flex_attention(q, k, v, block_mask=mask) # type: ignore
415
+ else:
416
+ if q.shape[2] < k.shape[2]:
417
+ if q.shape[2] > 1:
418
+ bias = attn_bias.causal_lower_right(q.shape[2], k.shape[2])
419
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, bias, dropout_p=0.0)
420
+ else:
421
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
422
+ else:
423
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
424
+ y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
425
+ return self.proj(y)
426
+
427
+
428
+ class GatedMLP(torch.nn.Module):
429
+ def __init__(self, config: RavenConfig, in_features: int = 0) -> None:
430
+ super().__init__()
431
+ in_features = config.n_embd if in_features == 0 else in_features
432
+ self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False)
433
+
434
+ self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False)
435
+ self.nonlin = torch.nn.SiLU()
436
+
437
+ def forward(self, x: Tensor) -> Tensor:
438
+ # modified to single FC layer to improve parallelism
439
+ x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1)
440
+ x = self.nonlin(x_fc_1) * x_fc_2
441
+ return self.proj(x)
442
+
443
+
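# Illustrative sketch (not from the commit): the fused `fc` above packs the SwiGLU gate and up
# projections into a single matmul and splits the result, which is numerically equivalent to two
# separate Linear layers but launches fewer kernels. Unfused reference for comparison:
def _gated_mlp_unfused(x: Tensor, fused_fc: torch.nn.Linear, intermediate_size: int) -> Tensor:
    """Equivalent to `nonlin(x_fc_1) * x_fc_2` in GatedMLP.forward, before the output projection."""
    w_gate, w_up = fused_fc.weight.split(intermediate_size, dim=0)
    return F.silu(x @ w_gate.T) * (x @ w_up.T)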
444
+ class SandwichBlock(torch.nn.Module):
445
+ expanded = False
446
+
447
+ def __init__(self, config: RavenConfig, layer_id: int) -> None:
448
+ super().__init__()
449
+ self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
450
+ self.attn = CausalSelfAttention(config)
451
+ self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
452
+ self.mlp = GatedMLP(config)
453
+ self.norm_3 = RMSNorm(config.n_embd, eps=config.norm_eps)
454
+ self.norm_4 = RMSNorm(config.n_embd, eps=config.norm_eps)
455
+ self.layer_id = layer_id
456
+
457
+ def forward(
458
+ self,
459
+ x: Tensor,
460
+ freqs_cis: Tensor,
461
+ step_idx: int,
462
+ mask: Optional[BlockMask] = None,
463
+ past_key_values: Optional[ValidCache] = None,
464
+ ) -> Tensor:
465
+ attn_out = self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values)
466
+ x = self.norm_2(attn_out + x)
467
+ x = self.norm_4(self.mlp(self.norm_3(x)) + x)
468
+ return x
469
+
470
+
471
+ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
472
+ freqs_cis: torch.Tensor
473
 
474
  def __init__(
475
  self,
476
+ config: RavenConfig,
477
+ ) -> None:
478
+ super().__init__(config)
479
+ self.config = config
480
+
481
+ # Transformer layers
482
+ prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
483
+ adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
484
+ core_block = torch.nn.ModuleList(
485
+ SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
486
+ for i in range(config.n_layers_in_recurrent_block)
487
+ )
488
+ o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
489
+ coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))
490
+
491
+ self.transformer = torch.nn.ModuleDict(
492
+ dict(
493
+ wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
494
+ prelude=prelude,
495
+ adapter=adapter,
496
+ core_block=core_block,
497
+ coda=coda,
498
+ ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), # used twice :>
499
+ )
500
+ )
501
+ self.emb_scale = config.init_values["embed_scale"]
502
+ # Head
503
+ self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
504
+ if self.config.tie_embeddings:
505
+ self.tie_weights()
506
+ # rope
507
+ self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
508
+
509
+ def get_input_embeddings(self):
510
+ return self.transformer.wte
511
+
512
+ def get_output_embeddings(self):
513
+ return self.lm_head
514
+
515
+ def _precompute_freqs_cis(self):
516
+ # can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
517
+ freqs_cis = precompute_freqs_cis(
518
+ self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1
519
+ )
520
+ return freqs_cis
521
+
522
+ def compile_mask(
523
+ self,
524
+ input_ids: torch.Tensor,
525
+ attention_mask: Optional[torch.Tensor] = None,
526
+ past_key_values: Optional[ValidCache] = None,
527
+ pad_token_id=65509,
528
+ ) -> Optional[BlockMask]:
529
+ batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
530
+
531
+ # If no padding and no attention mask, no need for a mask
532
+ if attention_mask is None and (input_ids == pad_token_id).sum() == 0:
533
+ return None
534
+
535
+ if past_key_values is not None and seq_len == 1:
536
+ return None
537
+
538
+ # Get total sequence length including cache
539
+ cache_len = past_key_values.get_seq_length() if past_key_values is not None else 0
540
+ kv_length = cache_len + seq_len
541
+
542
+ if attention_mask is None:
543
+
544
+ def mask_mod(b, h, q_idx, kv_idx):
545
+ return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id)
546
+ else:
547
+
548
+ def mask_mod(b, h, q_idx, kv_idx):
549
+ return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id) & attention_mask[b, q_idx, kv_idx]
550
+
551
+ kv_length = past_key_values.get_seq_length() if past_key_values is not None else seq_len
552
+ if kv_length == 0:
553
+ kv_length = seq_len # prefill
554
+ block_mask = create_block_mask(
555
+ mask_mod,
556
+ B=batch_size,
557
+ H=None,
558
+ Q_LEN=seq_len,
559
+ KV_LEN=kv_length,
560
+ device=input_ids.device,
561
+ )
562
+
563
+ # # Define mask_mod function
564
+ # def mask_mod(b, h, q_idx, kv_idx):
565
+ # # Always apply causal constraint
566
+ # is_causal = q_idx >= kv_idx
567
+
568
+ # # Handle cache vs current tokens
569
+ # is_cache = kv_idx < cache_len
570
+ # current_idx = kv_idx - cache_len
571
+
572
+ # # For cache: always valid; For current: check padding
573
+ # not_pad = input_ids[b, current_idx] != pad_token_id
574
+ # valid = is_cache | not_pad
575
+
576
+ # # Apply attention mask if provided
577
+ # if attention_mask is not None:
578
+ # q_idx_curr = q_idx - cache_len
579
+ # attn_valid = attention_mask[b, q_idx_curr, current_idx]
580
+ # valid = valid & (is_cache | attn_valid)
581
+
582
+ # return is_causal & valid
583
+
584
+ # def mask_mod(b, h, q_idx, kv_idx):
585
+ # is_causal = q_idx >= kv_idx
586
+ # is_current = (kv_idx >= cache_len) & (kv_idx < kv_length)
587
+ # current_idx = kv_idx - cache_len
588
+
589
+ # is_valid = (~is_current) | (
590
+ # (current_idx >= 0) & (current_idx < seq_len) & (input_ids != pad_token_id)[b, current_idx % seq_len]
591
+ # )
592
+
593
+ # return is_causal & is_valid
594
+
595
+ # # Define mask_mod function
596
+ # def mask_mod(b, h, q_idx, kv_idx):
597
+ # # Always apply causal constraint
598
+ # is_causal = q_idx >= kv_idx
599
+
600
+ # # Handle cache vs current tokens
601
+ # is_cache = kv_idx < cache_len
602
+ # current_idx = kv_idx - cache_len
603
+ # in_bounds = (current_idx >= 0) & (current_idx < seq_len)
604
+
605
+ # # For cache: always valid; For current: check padding
606
+ # not_pad = (input_ids[b, current_idx % seq_len] != pad_token_id) | ~in_bounds
607
+ # valid = is_cache | (not_pad & in_bounds)
608
+
609
+ # # Apply attention mask if provided
610
+ # if attention_mask is not None:
611
+ # q_idx_curr = q_idx - cache_len
612
+ # q_in_bounds = (q_idx_curr >= 0) & (q_idx_curr < seq_len)
613
+ # attn_valid = attention_mask[b, q_idx_curr % seq_len, current_idx % seq_len] | ~(in_bounds & q_in_bounds)
614
+ # valid = valid & (is_cache | attn_valid)
615
+
616
+ # return is_causal & valid
617
+
618
+ # Create block mask
619
+ block_mask = create_block_mask(
620
+ mask_mod,
621
+ B=batch_size,
622
+ H=None,
623
+ Q_LEN=seq_len,
624
+ KV_LEN=kv_length,
625
+ device=input_ids.device,
626
+ )
627
+
628
+ return block_mask
629
+
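# Illustrative sketch (not from the commit): mask_mod callables for flex_attention return a boolean per
# (batch, head, q_idx, kv_idx) position. Python precedence matters here: `&` binds tighter than `>=`,
# so the causal-and-not-padding condition needs the explicit parentheses used above, i.e.
# (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id). Minimal causal-only block mask for reference:
def _causal_block_mask(seq_len: int, device: torch.device) -> BlockMask:
    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx
    return create_block_mask(causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device)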
630
+ def forward(
631
+ self,
632
+ input_ids: torch.Tensor,
633
+ input_embeds: Optional[torch.Tensor] = None,
634
+ input_states: Optional[torch.Tensor] = None,
635
+ attention_mask: Optional[torch.Tensor] = None, # binary mask of shape q x kv, True=valid position
636
+ position_ids: Optional[torch.Tensor] = None,
637
+ labels: Optional[torch.Tensor] = None,
638
+ num_steps: Optional[torch.Tensor] = None,
639
+ past_key_values: Optional[ValidCache] = None,
640
+ output_details: dict = {
641
+ "return_logits": True,
642
+ "return_latents": True,
643
+ "return_head": False,
644
+ "return_stats": False,
645
+ },
646
+ use_cache: bool = False,
647
+ cache_position: Optional[torch.Tensor] = None,
648
+ init_scale: float = 1.0,
649
+ **kwargs,
650
+ ) -> CausalLMOutputRecurrentLatents:
651
+ # Support multiple position formats:
652
+ if position_ids is None and cache_position is None:
653
+ freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
654
+ elif position_ids is not None:
655
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
656
+ elif cache_position is not None:
657
+ freqs_cis = self.freqs_cis[:, cache_position]
658
+
659
+ if input_embeds is None:
660
+ input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
661
+
662
+ if self.emb_scale != 1:
663
+ input_embeds = input_embeds * self.emb_scale # type: ignore
664
+
665
+ if use_cache and past_key_values is None:
666
+ past_key_values = HuginnDynamicCache()
667
+
668
+ prepared_attn_mask = None # self.compile_mask(input_ids, attention_mask, past_key_values)
669
+ block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
670
+ # Non-recurrent prelude
671
+ for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
672
+ block_idx += 1
673
+ input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
674
+
675
+ # Main recurrence
676
+ x, num_steps_no_grad, num_steps_with_grad, xk, block_idx = self.iterate_forward(
677
+ input_embeds, # type: ignore # mystery typing error
678
+ input_states,
679
+ freqs_cis,
680
+ block_idx,
681
+ prepared_attn_mask,
682
+ past_key_values,
683
+ num_steps,
684
+ init_scale,
685
+ )
686
+ latent_states = x.clone().detach()
687
+
688
+ # Coda layers
689
+ block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
690
+ for block in self.transformer.coda: # type: ignore # types broken in 2.6+
691
+ block_idx -= 1
692
+ x = block(x, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
693
+ x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
694
+
695
+ # Prediction head, assuming labels really are labels and not equal to input_ids
696
+ if labels is not None:
697
+ logits = self.lm_head(x).float()
698
+ loss = torch.nn.functional.cross_entropy(
699
+ logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
700
+ )
701
+ log_ppl = loss.clone().detach().exp()
702
+ else:
703
+ logits = self.lm_head(x).float()
704
+ loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
705
+
706
+ return CausalLMOutputRecurrentLatents(
707
+ loss=loss,
708
+ log_ppl=log_ppl,
709
+ logits=logits if output_details["return_logits"] else None,
710
+ past_key_values=past_key_values,
711
+ hidden_states=x if output_details["return_head"] else None,
712
+ latent_states=latent_states if output_details["return_latents"] else None,
713
+ stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
714
+ if output_details["return_stats"]
715
+ else None,
716
+ )
717
+
718
+ @torch._dynamo.disable(recursive=False) # type: ignore
719
+ def iterate_forward(
720
+ self,
721
+ input_embeds: torch.Tensor,
722
+ input_states: torch.Tensor,
723
+ freqs_cis,
724
+ block_idx: torch.Tensor,
725
+ mask: Optional[BlockMask],
726
+ past_key_values: Optional[ValidCache] = None,
727
+ num_steps: Optional[torch.Tensor] = None,
728
+ init_scale: float = 1.0,
729
+ ):
730
+ x = xk = self.initialize_state(input_embeds, scale=init_scale) if input_states is None else input_states.clone()
731
+ if num_steps is None:
732
+ num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
733
+ elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
734
+ num_steps_no_grad, num_steps_with_grad = num_steps
735
+ else:
736
+ num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0
737
+
738
+ with torch.no_grad():
739
+ # ultra annoying in ddp due to
740
+ # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
741
+ # for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
742
+ # and all parameters are always used
743
+ for no_grad_step in range(num_steps_no_grad):
744
+ xk = x
745
+ x, block_idx = self.core_block_forward(
746
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, no_grad_step
747
+ )
748
+
749
+ for grad_step in range(num_steps_with_grad):
750
+ xk = x
751
+ x, block_idx = self.core_block_forward(
752
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
753
+ )
754
+ return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx # type: ignore # types broken in 2.6+
755
+
756
+ def core_block_forward(
757
+ self,
758
+ x,
759
+ input_embeds,
760
+ freqs_cis,
761
+ mask: Optional[BlockMask],
762
+ past_key_values,
763
+ block_idx: torch.Tensor,
764
+ current_step: int | Tensor,
765
+ ):
766
+ x = self._maybe_inject_noise(x, current_step)
767
+ x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1)) # type: ignore # types broken in 2.6+
768
+ for block in self.transformer.core_block: # type: ignore # types broken in 2.6+
769
+ block_idx += 1
770
+ x = block(x, freqs_cis, block_idx, mask, past_key_values)
771
+ return x, block_idx
772
+
773
+ @torch.no_grad()
774
+ def iterate_one_step(
775
+ self,
776
+ input_embeds,
777
+ input_states,
778
+ position_ids: Optional[torch.Tensor] = None,
779
+ cache_position: Optional[torch.Tensor] = None,
780
+ block_idx: torch.Tensor = torch.tensor(0, dtype=torch.long),
781
+ attention_mask: Optional[BlockMask] = None,
782
+ past_key_values: Optional[ValidCache] = None,
783
+ current_step: int = 0,
784
+ ):
785
+ if position_ids is None and cache_position is None:
786
+ freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]
787
+ elif position_ids is not None:
788
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
789
+ elif cache_position is not None:
790
+ freqs_cis = self.freqs_cis[:, cache_position]
791
+ x, block_idx = self.core_block_forward(
792
+ input_states,
793
+ input_embeds,
794
+ freqs_cis,
795
+ attention_mask,
796
+ past_key_values,
797
+ block_idx,
798
+ current_step=current_step,
799
+ )
800
+ return x, block_idx, current_step + 1
801
+
802
+ def predict_from_latents(
803
+ self,
804
+ latents,
805
+ attention_mask: Optional[BlockMask] = None,
806
+ position_ids: Optional[torch.Tensor] = None,
807
+ cache_position: Optional[torch.Tensor] = None,
808
+ past_key_values: Optional[ValidCache] = None,
809
+ ):
810
+ if position_ids is None and cache_position is None:
811
+ freqs_cis = self.freqs_cis[:, : latents.shape[1]]
812
+ elif position_ids is not None:
813
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
814
+ elif cache_position is not None:
815
+ freqs_cis = self.freqs_cis[:, cache_position]
816
+ x = self.transformer.ln_f(latents) # type: ignore # types broken in 2.6+
817
+ # Coda layers
818
+ block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
819
+ for block in self.transformer.coda: # type: ignore # types broken in 2.6+
820
+ block_idx -= 1
821
+ x = block(x, freqs_cis, block_idx, attention_mask, past_key_values)
822
+ x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
823
+
824
+ logits = self.lm_head(x).float()
825
+
826
+ return CausalLMOutputRecurrentLatents(
827
+ loss=torch.as_tensor(0.0),
828
+ log_ppl=torch.as_tensor(0.0),
829
+ logits=logits,
830
+ past_key_values=past_key_values,
831
+ latent_states=x,
832
+ )
833
+
834
+ def embed_inputs(
835
+ self,
836
+ input_ids: torch.Tensor,
837
+ attention_mask: Optional[torch.Tensor] = None,
838
+ position_ids: Optional[torch.Tensor] = None,
839
+ past_key_values: Optional[ValidCache] = None,
840
+ use_cache: bool = False,
841
+ cache_position: Optional[torch.Tensor] = None,
842
+ **kwargs,
843
+ ) -> tuple[torch.Tensor, torch.Tensor]:
844
+ # Support multiple position formats:
845
+ if position_ids is None and cache_position is None:
846
+ freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
847
+ elif position_ids is not None:
848
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
849
+ elif cache_position is not None:
850
+ freqs_cis = self.freqs_cis[:, cache_position]
851
+
852
+ input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
853
+ prepared_attn_mask = self.compile_mask(input_ids, attention_mask)
854
+
855
+ if self.emb_scale != 1:
856
+ input_embeds = input_embeds * self.emb_scale # type: ignore
857
+
858
+ if use_cache and past_key_values is None:
859
+ past_key_values = HuginnDynamicCache()
860
+
861
+ block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
862
+ # Non-recurrent prelude
863
+ for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
864
+ block_idx += 1
865
+ input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
866
+ return input_embeds, block_idx
867
+
868
+ @torch._dynamo.disable(recursive=False) # type: ignore
869
+ def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
870
+ """Outputs are long tensors so that they can be passed through compiled functions"""
871
+ t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
872
+ s = self.config.mean_backprop_depth
873
+ if torch.rand((1,)).is_meta: # annoying clause to make meta-tensor-based flop counting work
874
+ # these values are only the mean TFLOPs of the randomized sampler
875
+ # Note that this clause also breaks the contract, and returns ints in meta tensor mode
876
+ return t, s # type: ignore
877
+ if self.training:
878
+ sigma = 0.5
879
+ mu = math.log(t + s) - (sigma**2 / 2)
880
+ rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
881
+ p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
882
+ n = torch.clamp(p - s, min=0)
883
+ k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
884
+ else:
885
+ n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
886
+
887
+ return n.to(dtype=torch.long), k.to(dtype=torch.long)
888
+
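# Illustrative sketch (not from the commit): during training the sampler above draws the total number of
# recurrent iterations as p ~ 1 + Poisson(rate) with rate ~ LogNormal(mu, sigma) and
# mu = log(t + s) - sigma^2 / 2, so E[rate] = t + s = mean_recurrence. The last k = min(s, p) iterations
# keep gradients (truncated backprop through the recurrence); the first n = max(p - s, 0) run under
# no_grad. Standalone mirror of the same rule with the config defaults (mean_recurrence=32, mean_backprop_depth=8):
def _sample_depths(mean_recurrence: int = 32, mean_backprop_depth: int = 8, sigma: float = 0.5) -> tuple[int, int]:
    t, s = mean_recurrence - mean_backprop_depth, mean_backprop_depth
    mu = math.log(t + s) - sigma**2 / 2  # lognormal with this mu has mean t + s
    rate = torch.zeros(1).log_normal_(mean=mu, std=sigma)
    p = torch.poisson(rate) + 1
    return int(torch.clamp(p - s, min=0)), int(torch.minimum(torch.as_tensor(s), p))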
889
+ def initialize_state(self, input_embeds, scale: float = 1.0):
890
+ x = torch.randn_like(input_embeds)
891
+ std = self.config.init_values["std"] * scale
892
+ if std > 0:
893
+ torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
894
+ if self.emb_scale != 1:
895
+ x = x * self.emb_scale
896
+ else:
897
+ x.zero_()
898
+ return x
899
+
900
+ def _maybe_inject_noise(self, x, current_step, renorm=False):
901
+ if self.config.test_time_noise > 0:
902
+ n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
903
+ if self.config.test_time_noise_type == "geom":
904
+ step1 = torch.as_tensor(current_step + 1, device=x.device) # need to cast for compile
905
+ x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
906
+ elif self.config.test_time_noise_type == "sqrt":
907
+ step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt() # need to cast for compile
908
+ x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
909
+ elif self.config.test_time_noise_type == "line":
910
+ noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence) # type: ignore
911
+ x = x * (1 - noise) + torch.randn_like(x) * noise
912
+ elif self.config.test_time_noise_type == "chi":
913
+ noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
914
+ x = x * (1 - noise) + torch.randn_like(x) * noise
915
+ elif self.config.test_time_noise_type == "fixed":
916
+ x = x * (1 - n) + torch.randn_like(x) * n
917
+ else:
918
+ raise ValueError()
919
+
920
+ if renorm:
921
+ x = self.transformer.core_block[-1].norm_4(x) # type: ignore moduledict types still broken in pytorch
922
+ return x
923
+
924
+ def prepare_inputs_for_generation(
925
+ self,
926
+ input_ids: torch.Tensor,
927
+ past_key_values: Optional[Cache] = None,
928
+ attention_mask: Optional[torch.Tensor] = None,
929
+ inputs_embeds: Optional[torch.FloatTensor] = None,
930
+ cache_position: Optional[torch.Tensor] = None,
931
+ cache_lookup_strategy: str = "full",
932
  **kwargs,
933
  ):
934
+ model_inputs = {}
935
+ model_inputs["cache_position"] = cache_position
936
+ current_input_length = input_ids.shape[1]
937
+
938
+ if past_key_values is not None:
939
+ if not isinstance(past_key_values, (HuginnDynamicCache, HuginnStaticCache)):
940
+ assert past_key_values.get_seq_length() == 0 # only replace empty caches
941
+ # Need to use custom cache, detect and replace HF cache if generate injects it
942
+ if isinstance(past_key_values, StaticCache):
943
+ past_key_values = HuginnStaticCache(
944
+ max_length=getattr(self.generation_config, "max_length", self.config.block_size),
945
+ max_num_steps=4 + kwargs.get("num_steps", self.config.mean_recurrence) * 4,
946
+ num_heads=self.config.num_key_value_heads,
947
+ hidden_dim=self.config.n_embd // self.config.num_attention_heads,
948
+ dtype=torch.bfloat16,
949
+ device=input_ids.device,
950
+ lookup_strategy=cache_lookup_strategy,
951
+ )
952
+ else:
953
+ past_key_values = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
954
+ model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
955
+ input_ids = input_ids[:, cache_position] # type: ignore
956
+
957
+ model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
958
+ if cache_position is None:
959
+ position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
960
+ model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
961
+ memory_format=torch.contiguous_format
962
+ ) # some form of position_ids is a critical argument for the model to correctly apply rope!
963
+
964
+ # forward all other entries
965
+ for key, value in kwargs.items():
966
+ if key not in model_inputs:
967
+ model_inputs[key] = value
968
+ return model_inputs
969
+
970
+ @torch.no_grad()
971
+ def generate(self, *args, **kwargs):
972
+ """Dispatcher - use HF generate in all normal cases."""
973
+ self.generation_config = args[1] if len(args) > 1 else self.generation_config
974
+ if any(k in kwargs for k in ("criterion", "exit_threshold")):
975
+ # print("Dispatching to custom generate_adaptive function call")
976
+ return self.generate_with_adaptive_compute(*args, **kwargs)
977
+ elif "continuous_compute" in kwargs:
978
+ # print("Dispatching to custom generate_minimal function call")
979
+ return self.generate_minimal(*args, **kwargs)
980
+ else:
981
+ return super().generate(*args, **kwargs)
982
+
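# Illustrative sketch (not from the commit): the dispatcher above only routes to the custom loops when
# their keywords are present. Assuming `model` and `input_ids` are prepared by the caller:
def _generate_examples(model: "RavenForCausalLM", input_ids: torch.Tensor) -> torch.Tensor:
    model.generate(input_ids, max_new_tokens=32)  # plain HF generation path
    model.generate(input_ids, max_new_tokens=32, num_steps=16)  # fixed recurrence depth per token
    # criterion / exit_threshold dispatch to generate_with_adaptive_compute:
    return model.generate(input_ids, max_new_tokens=32, criterion="latent-diff", exit_threshold="auto")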
983
+ @torch.no_grad()
984
+ def _prep_generate_args(
985
+ self,
986
+ input_ids: torch.Tensor,
987
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
988
+ cache_lookup_strategy: str = "full",
989
+ model_kwargs: dict = {},
990
+ ):
991
+ # Setup
992
+ if generation_config is None:
993
+ generation_config: GenerationConfig = self.generation_config # type: ignore
994
+ if "max_new_tokens" in model_kwargs:
995
+ max_new_tokens = model_kwargs["max_new_tokens"]
996
+ if "max_length" in model_kwargs:
997
+ max_new_tokens = min(max_new_tokens, model_kwargs["max_length"] - input_ids.shape[1])
998
+ else:
999
+ max_length = model_kwargs.get("max_length", generation_config.max_length)
1000
+ max_new_tokens = max_length - input_ids.shape[1]
1001
+
1002
+ if "cache_implementation" not in model_kwargs or model_kwargs["cache_implementation"] == "dynamic":
1003
+ model_kwargs["past_key_values"] = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
1004
+ else:
1005
+ model_kwargs["past_key_values"] = HuginnStaticCache(
1006
+ max_length=max_length,
1007
+ max_num_steps=4 + model_kwargs.get("num_steps", self.config.mean_recurrence) * 4,
1008
+ num_heads=self.config.num_key_value_heads,
1009
+ hidden_dim=self.config.n_embd // self.config.num_attention_heads,
1010
+ batch_size=input_ids.shape[0],
1011
+ dtype=torch.bfloat16,
1012
+ device=input_ids.device,
1013
+ lookup_strategy=cache_lookup_strategy,
1014
+ )
1015
+ model_kwargs["use_cache"] = True
1016
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1017
+ return model_kwargs, generation_config, max_new_tokens
1018
+
1019
+ @torch.no_grad()
1020
+ def generate_minimal(
1021
+ self,
1022
+ input_ids: torch.Tensor,
1023
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
1024
+ tokenizer=None,
1025
+ streamer=None,
1026
+ continuous_compute=False, # warm-start state / continuous CoT
1027
+ init_scale: float = 1.0,
1028
+ cache_lookup_strategy: str = "full",
1029
+ **model_kwargs,
1030
+ ) -> Union[torch.Tensor, dict[str, Any]]:
1031
+ """Minimal single-sequence generation. Template for more complicated generate tasks"""
1032
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1033
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1034
  )
1035
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1036
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1037
+
1038
+ # Set up continuous compute if enabled
1039
+ if continuous_compute:
1040
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1041
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1042
+
1043
+ # Generate tokens
1044
+ batch_size = input_ids.shape[0]
1045
+ for _ in range(max_new_tokens):
1046
+ # Forward pass
1047
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1048
+ outputs = self(**model_inputs, init_scale=init_scale)
1049
+
1050
+ # Get next token
1051
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
1052
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1053
+
1054
+ # Append token to sequence
1055
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1056
+
1057
+ if streamer:
1058
+ streamer.put(next_token.cpu())
1059
+
1060
+ # Update model kwargs
1061
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1062
+ if continuous_compute:
1063
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1064
+
1065
+ if stop_tokens is not None:
1066
+ for i in range(batch_size):
1067
+ if unfinished_sequences[i] and next_token[i, 0].item() in stop_tokens:
1068
+ unfinished_sequences[i] = 0
1069
+ if "stopping_criteria" in model_kwargs:
1070
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1071
+ if unfinished_sequences.max() == 0:
1072
+ break
1073
+
1074
+ if streamer:
1075
+ streamer.end()
1076
+
1077
+ if generation_config.return_dict_in_generate:
1078
+ return GenerateDecoderOnlyOutput(
1079
+ sequences=input_ids, # type: ignore
1080
+ scores=None,
1081
+ logits=None,
1082
+ attentions=None,
1083
+ hidden_states=None,
1084
+ past_key_values=model_kwargs.get("past_key_values"),
1085
+ )
1086
+ return input_ids
1087
+
1088
+ @torch.no_grad()
1089
+ def generate_with_adaptive_compute(
1090
+ self,
1091
+ input_ids: torch.Tensor,
1092
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
1093
+ tokenizer=None,
1094
+ streamer=None,
1095
+ continuous_compute=False, # warm-start state / continuous CoT
1096
+ criterion="none", # off by default, turn on by choosing an exit criterion
1097
+ exit_threshold: Union[str, float, int] = "auto",
1098
+ init_scale: float = 1.0,
1099
+ cache_lookup_strategy: str = "full",
1100
+ **model_kwargs,
1101
+ ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
1102
+ """
1103
+ Generate tokens with adaptive compute. This is NOT the most efficient implementation.
1104
+ For batches, on each token, we iterate until the entire batch finishes.
1105
+ """
1106
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1107
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1108
+ )
1109
+ max_steps = model_kwargs.get("num_steps", self.config.mean_recurrence)
1110
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1111
+ logit_type = dict(copy=True, dtype=torch.float32, device=input_ids.device)
1112
+ batch_size = input_ids.shape[0]
1113
+ compute_steps = []
1114
+
1115
+ # Set up continuous compute if enabled
1116
+ if continuous_compute:
1117
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1118
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1119
+
1120
+ # Track which sequences have finished (using unfinished_sequences to match generate_minimal)
1121
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1122
+
1123
+ # Generate tokens
1124
+ for _ in range(max_new_tokens):
1125
+ # Adaptive compute forward
1126
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1127
+ aux_inputs = {
1128
+ k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
1129
+ }
1130
+ embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
1131
+ current_latents = (
1132
+ self.initialize_state(embedded_inputs, scale=init_scale)
1133
+ if not continuous_compute
1134
+ else model_kwargs["input_states"]
1135
+ )
1136
+
1137
+ # Initialize criterion tracking for each sequence in batch
1138
+ exit_values_per_seq = [[] for _ in range(batch_size)]
1139
+ compute_steps_per_seq = [0] * batch_size
1140
+ exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
1141
+
1142
+ # Set up criterions based on selected strategy
1143
+ if criterion == "entropy-diff":
1144
+ entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
1145
+ exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
1146
+ elif criterion == "latent-diff":
1147
+ exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
1148
+ elif "kl" in criterion:
1149
+ V = self.config.padded_vocab_size
1150
+ log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log()
1151
+ if criterion == "minp-kl":
1152
+ exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
1153
+ else:
1154
+ exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
1155
+ elif criterion == "argmax-stability":
1156
+ stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
1157
+ current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
1158
+ exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
1159
+ elif criterion == "none":
1160
+ exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold)
1161
+ else:
1162
+ raise ValueError("Invalid adaptive compute strategy.")
1163
+
1164
+ next_token_logits = None
1165
+
1166
+ # Iterate through compute steps
1167
+ for compute_step in range(max_steps):
1168
+ prev_latents = current_latents.clone()
1169
+ current_latents, block_idx, _ = self.iterate_one_step(
1170
+ embedded_inputs,
1171
+ current_latents,
1172
+ block_idx=block_idx,
1173
+ **aux_inputs,
1174
+ current_step=compute_step,
1175
+ )
1176
+
1177
+ if token_idx > 0:  # do not exit during prefill (the first token of the generation loop)
1178
+ # Check exit condition for each sequence in batch
1179
+ if criterion == "entropy-diff":
1180
+ prev_entropy = entropy
1181
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1182
+ logits: torch.Tensor = outputs.logits # type: ignore
1183
+ probs = F.softmax(logits[:, -1, :], dim=-1)
1184
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
1185
+ exit_values = (entropy - prev_entropy).abs()
1186
+ elif criterion == "latent-diff":
1187
+ norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
1188
+ exit_values = norm_diff.mean(dim=-1)
1189
+ elif "kl" in criterion:
1190
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1191
+ logits: torch.Tensor = outputs.logits # type: ignore
1192
+ prev_log_probs = log_probs
1193
+ if criterion == "minp-kl":
1194
+ probs = F.softmax(logits[:, -1, :].float(), dim=-1)
1195
+ max_probs = probs.max(dim=-1, keepdim=True)[0]
1196
+ probs_mask = probs < (0.1 * max_probs)
1197
+ masked_probs = probs.clone()
1198
+ masked_probs[probs_mask] = 1 / V
1199
+ probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
1200
+ log_probs = probs.log()
1201
+ else:
1202
+ log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
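+ # with log_target=True, F.kl_div(log_probs, prev_log_probs) is KL(previous step || current step)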
1203
+ exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
1204
+ elif criterion == "argmax-stability":
1205
+ prev_argmax = current_argmax
1206
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1207
+ logits: torch.Tensor = outputs.logits # type: ignore
1208
+ current_argmax = logits[:, -1, :].argmax(dim=-1)
1209
+ stable_for_n_steps = torch.where(
1210
+ current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
1211
+ )
1212
+ exit_values = stable_for_n_steps
1213
+ elif criterion == "none":
1214
+ exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold
1215
 
1216
+ # Record values and check exits for each sequence
1217
+ for i in range(batch_size):
1218
+ if not exit_reached[i] and unfinished_sequences[i].bool():
1219
+ exit_values_per_seq[i].append(exit_values[i].item())
1220
+
1221
+ # Check for new exits, respecting unfinished_sequences
1222
+ new_exits = (
1223
+ exit_values < exit_threshold
1224
+ if criterion != "argmax-stability"
1225
+ else exit_values >= exit_threshold
1226
+ )
1227
+ new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()
1228
+
1229
+ if new_exits.any():
1230
+ exit_reached = exit_reached | new_exits
1231
+ if criterion == "latent-diff":
1232
+ # Normally we don't compute the output for latent-diff, but when there is an exit,
1233
+ # we need to compute and save the output
1234
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1235
+ logits: torch.Tensor = outputs.logits # type: ignore
1236
+ if next_token_logits is None:
1237
+ next_token_logits = logits[:, -1, :].to(**logit_type) # type: ignore
1238
+ else:
1239
+ for i in range(batch_size):
1240
+ if new_exits[i]:
1241
+ next_token_logits[i] = logits[i, -1, :].to(**logit_type) # type: ignore
1242
+ for i in range(batch_size):
1243
+ if new_exits[i]:
1244
+ compute_steps_per_seq[i] = compute_step + 1
1245
+
1246
+ # If all sequences have exited or finished, break early
1247
+ if (exit_reached | ~unfinished_sequences.bool()).all():
1248
+ break
1249
+ # for-else: executes only if the compute loop exhausted max_steps without breaking
1250
+ else:
1251
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1252
+
1253
+ # For sequences that didn't exit early, use the final logits
1254
+ if next_token_logits is None:
1255
+ next_token_logits = outputs.logits[:, -1, :].to(**logit_type) # type: ignore
1256
+ else:
1257
+ for i in range(batch_size):
1258
+ if not exit_reached[i] and unfinished_sequences[i].bool():
1259
+ next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type) # type: ignore
1260
+ compute_steps_per_seq[i] = max_steps
1261
+
1262
+ # Save latent states for continuous compute if enabled
1263
+ if continuous_compute:
1264
+ model_kwargs["input_states"] = current_latents[:, -1:, :]
1265
+
1266
+ # Record compute steps for this token generation
1267
+ compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
1268
+
1269
+ # Sample or select next token based on generation config
1270
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1271
+
1272
+ # Append token to sequence
1273
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1274
+
1275
+ if streamer:
1276
+ streamer.put(next_token.cpu())
1277
+
1278
+ # Update model kwargs for next iteration
1279
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1280
+
1281
+ # Check for stop tokens and update unfinished sequences
1282
+ for i in range(batch_size):
1283
+ if (
1284
+ unfinished_sequences[i].bool()
1285
+ and stop_tokens is not None
1286
+ and next_token[i, 0].item() in stop_tokens
1287
+ ):
1288
+ unfinished_sequences[i] = 0
1289
+
1290
+ # Apply any custom stopping criteria
1291
+ if "stopping_criteria" in model_kwargs:
1292
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1293
+
1294
+ # Break if all sequences are finished
1295
+ if unfinished_sequences.max() == 0:
1296
+ break
1297
+
1298
+ if streamer:
1299
+ streamer.end()
1300
+
1301
+ if generation_config.return_dict_in_generate:
1302
+ return GenerateDecoderOnlyOutput(
1303
+ sequences=input_ids, # type: ignore
1304
+ scores=compute_steps, # type: ignore
1305
+ logits=None,
1306
+ attentions=None,
1307
+ hidden_states=None,
1308
+ past_key_values=model_kwargs.get("past_key_values"),
1309
+ )
1310
+ return input_ids
1311
+
1312
+ def _get_stops(self, generation_config, tokenizer, model_kwargs):
1313
+ stop_tokens = {65504, 65505, 65508} # begin_text, end_text, end_turn
1314
+ if generation_config.eos_token_id is not None:
1315
+ stop_tokens.add(generation_config.eos_token_id)
1316
+ if "stopping_criteria" in model_kwargs and tokenizer is None:
1317
+ tokenizer = model_kwargs["stopping_criteria"][0].tokenizer
1318
+ if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
1319
+ for s in generation_config.stop_strings:
1320
+ token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
1321
+ stop_tokens.add(token_id)
1322
+ return torch.tensor(list(stop_tokens))
1323
+
1324
+ def _sample_next_token(self, next_token_logits, generation_config):
1325
+ """Helper function to sample the next token."""
1326
+ if generation_config.do_sample:
1327
+ if generation_config.temperature:
1328
+ next_token_logits = next_token_logits.float() / generation_config.temperature
1329
+
1330
+ probs = F.softmax(next_token_logits, dim=-1)
1331
+
1332
+ # Apply top_k
1333
+ if generation_config.top_k:
1334
+ top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1)
1335
+ min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
1336
+ probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
1337
+
1338
+ # Apply top_p (nucleus sampling)
1339
+ if generation_config.top_p:
1340
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
1341
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
1342
+
1343
+ # Create mask for probs to keep
1344
+ remove_indices = cumulative_probs > generation_config.top_p
1345
+ remove_indices[:, 0] = False # Keep at least the top probability
1346
+
1347
+ # Convert sorted indices mask back to original indices mask
1348
+ mask = torch.zeros_like(probs, dtype=torch.bool)
1349
+ for i in range(probs.shape[0]):
1350
+ mask[i, sorted_indices[i, remove_indices[i]]] = True
1351
+
1352
+ probs = torch.where(mask, torch.zeros_like(probs), probs)
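+ # Loop-free alternative (a sketch, not used here): the sorted removal mask can be
+ # scattered back to vocabulary order in a single call,
+ #   mask = torch.zeros_like(probs, dtype=torch.bool).scatter(-1, sorted_indices, remove_indices)
+ # which avoids the per-row Python loop for large batches.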
1353
+
1354
+ # Apply min_p
1355
+ if generation_config.min_p:
1356
+ max_probs = probs.max(dim=-1, keepdim=True)[0]
1357
+ min_p_threshold = generation_config.min_p * max_probs
1358
+ probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
1359
+
1360
+ # Renormalize probabilities
1361
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
1362
+
1363
+ # Sample from the distribution
1364
+ return torch.multinomial(probs, num_samples=1)
1365
+ else:
1366
+ return torch.argmax(next_token_logits, dim=-1, keepdim=True)
1367
+
1368
+ @torch.no_grad()
1369
+ def generate_speculative(
1370
+ self,
1371
+ input_ids: torch.Tensor,
1372
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
1373
+ tokenizer=None,
1374
+ streamer=None,
1375
+ continuous_compute=False, # warm-start state / continuous CoT
1376
+ init_scale: float = 1.0,
1377
+ cache_lookup_strategy: str = "full",
1378
+ draft_steps=32,
1379
+ lookahead_for_draft=8,
1380
+ verification_threshold=1,
1381
+ num_steps: int = 32, # intercept deliberately
1382
+ **model_kwargs,
1383
+ ) -> Union[torch.Tensor, dict[str, Any]]:
1384
+ """Batched speculative decoding with per-sequence acceptance."""
1385
+ assert lookahead_for_draft > 0
1386
+ pad_id = 65509
1387
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1388
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1389
  )
1390
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1391
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1392
+
1393
+ # Set up continuous compute if enabled
1394
+ if continuous_compute:
1395
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1396
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1397
+
1398
+ tokens_generated = 0
1399
+ # Prefill cache with full num_steps
1400
+ if model_kwargs["past_key_values"].get_seq_length() == 0:
1401
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1402
+ outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
1403
+ next_token = self._sample_next_token(
1404
+ outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32), generation_config
1405
+ )
1406
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1407
+ tokens_generated += 1
1408
+ if streamer:
1409
+ streamer.put(next_token.cpu())
1410
+ model_kwargs["cache_position"] = torch.as_tensor(
1411
+ [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
1412
+ )
1413
+ if continuous_compute:
1414
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1415
+
1416
+ # Generate tokens
1417
+ batch_size, prefix_seq_len = input_ids.shape[0], input_ids.shape[1]
1418
+ accepted_tokens = []
1419
+
1420
+ while tokens_generated < max_new_tokens:
1421
+ ### Run the next draft ####
1422
+ drafted_inputs = input_ids.clone()
1423
+ current_len = input_ids.shape[1]
1424
+
1425
+ for _ in range(lookahead_for_draft):
1426
+ model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
1427
+ outputs = self(**model_inputs, num_steps=draft_steps, init_scale=init_scale)
1428
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32)
1429
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1430
+ drafted_inputs = torch.cat([drafted_inputs, next_token], dim=-1)
1431
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
1432
+ if continuous_compute:
1433
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1434
+
1435
+ model_kwargs["past_key_values"].clear_last_k_entries(lookahead_for_draft)
1436
+
1437
+ ## Verify drafted tokens ###
1438
+ model_kwargs["cache_position"] = torch.arange(
1439
+ current_len - 1, current_len + lookahead_for_draft - 1, device=input_ids.device
1440
+ )
1441
+ model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
1442
+ outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
1443
+ verified_next_token_preds = outputs.logits.argmax(dim=-1)
1444
+
1445
+ if verification_threshold >= 1:
1446
+ mismatched_tokens = (
1447
+ verified_next_token_preds[:, -lookahead_for_draft:] != drafted_inputs[:, current_len:]
1448
+ )
1449
+ not_all_matched, first_mismatch = torch.max(mismatched_tokens, dim=1)
1450
+ else:
1451
+ verified_logits = outputs.logits[:, -lookahead_for_draft:, :]
1452
+ verified_probs = F.softmax(verified_logits, dim=-1)
1453
+ drafted_token_probs = torch.gather(
1454
+ verified_probs, -1, drafted_inputs[:, current_len:].unsqueeze(-1)
1455
+ ).squeeze(-1)
1456
+ max_probs = verified_probs.max(dim=-1)[0]
1457
+ verification_passed = drafted_token_probs >= verification_threshold * max_probs
1458
+ not_all_matched, first_mismatch = torch.max(~verification_passed, dim=1)
1459
+
1460
+ # Per-sequence acceptance handling
1461
+ acceptance_lengths = torch.where(not_all_matched, first_mismatch, lookahead_for_draft)
1462
+
1463
+ # Build next_tokens for each sequence
1464
+ next_tokens_batch = []
1465
+ for i in range(batch_size):
1466
+ seq_acceptance = acceptance_lengths[i].item()
1467
+ if not_all_matched[i] and seq_acceptance < lookahead_for_draft:
1468
+ # Accept up to mismatch + sample final token
1469
+ accepted_part = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
1470
+ final_token_logits = outputs.logits[i : i + 1, seq_acceptance, :].to(copy=True, dtype=torch.float32)
1471
+ final_token = self._sample_next_token(final_token_logits, generation_config)
1472
+ seq_tokens = torch.cat([accepted_part, final_token], dim=-1) if seq_acceptance > 0 else final_token
1473
+ else:
1474
+ # Accept all drafted tokens
1475
+ seq_tokens = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
1476
+ next_tokens_batch.append(seq_tokens)
1477
+
1478
+ # Clean up KV cache - only if any sequence had mismatches
1479
+ if not_all_matched.any():
1480
+ min_first_mismatch = first_mismatch.min().item()
1481
+ model_inputs["past_key_values"].clear_last_k_entries(lookahead_for_draft - min_first_mismatch - 1)
1482
+
1483
+ # Concatenate accepted tokens to input_ids
1484
+ batch_accepted_counts = [tokens.shape[1] for tokens in next_tokens_batch]
1485
+ max_len = max(batch_accepted_counts)
1486
+ padded_tokens = [
1487
+ torch.cat(
1488
+ [
1489
+ tokens,
1490
+ pad_id * torch.ones((1, max_len - tokens.shape[1]), dtype=tokens.dtype, device=tokens.device),
1491
+ ],
1492
+ dim=-1,
1493
+ )
1494
+ if tokens.shape[1] < max_len
1495
+ else tokens
1496
+ for tokens in next_tokens_batch
1497
+ ]
1498
+ next_tokens = torch.cat(padded_tokens, dim=0)
1499
+ input_ids = torch.cat([input_ids, next_tokens], dim=-1)
1500
+
1501
+ accepted_tokens.append(batch_accepted_counts)
1502
+ tokens_generated += max(batch_accepted_counts)
1503
+
1504
+ if streamer:
1505
+ streamer.put(next_tokens_batch[0].cpu())
1506
+
1507
+ model_kwargs["cache_position"] = torch.as_tensor(
1508
+ [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
1509
+ )
1510
+ if continuous_compute:
1511
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1512
+
1513
+ # Check stopping conditions
1514
+ if stop_tokens is not None:
1515
+ for i in range(batch_size):
1516
+ if unfinished_sequences[i] and torch.isin(next_tokens_batch[i], stop_tokens).any():
1517
+ unfinished_sequences[i] = 0
1518
+ if "stopping_criteria" in model_kwargs:
1519
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1520
+ if unfinished_sequences.max() == 0:
1521
+ break
1522
+
1523
+ if streamer:
1524
+ streamer.end()
1525
+
1526
+ # Cut off extraneous parts of the sequence per batch element
1527
+ if stop_tokens is not None:
1528
+ for i in range(batch_size):
1529
+ stop_positions = torch.isin(input_ids[i, prefix_seq_len:], stop_tokens).nonzero()
1530
+ if len(stop_positions) > 0:
1531
+ input_ids[i, prefix_seq_len + stop_positions[0].item() + 1 :] = pad_id
1532
+ # Trim tensor to remove columns that are pad_id across all sequences
1533
+ non_pad_mask = input_ids != pad_id
1534
+ last_real_token = non_pad_mask.any(dim=0).nonzero()
1535
+ if len(last_real_token) > 0:
1536
+ input_ids = input_ids[:, : last_real_token[-1].item() + 1]
1537
+
1538
+ if generation_config.return_dict_in_generate:
1539
+ return GenerateDecoderOnlyOutput(
1540
+ sequences=input_ids, # type: ignore
1541
+ scores=accepted_tokens, # type: ignore
1542
+ logits=None,
1543
+ attentions=None,
1544
+ hidden_states=None,
1545
+ past_key_values=model_kwargs.get("past_key_values"),
1546
+ )
1547
+ return input_ids
1548
+
1549
+ def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
1550
+ probs = torch.softmax(logits.float(), dim=-1)
1551
+ prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
1552
+ residual_diff = (x - latent_states).norm(dim=-1)
1553
+ rel_residual = residual_diff / latent_states.norm(dim=-1)
1554
+ stats = {
1555
+ "entropy": prob_entropy,
1556
+ "residual_diff": residual_diff,
1557
+ "rel_residual": rel_residual,
1558
+ "num_steps_no_grad": num_steps_no_grad,
1559
+ "num_steps_with_grad": num_steps_with_grad,
1560
+ }
1561
+ return stats
1562
+
1563
+
1564
+ #################################### Utils #######################################################################
1565
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1):
1566
+ with torch.autocast("cuda", enabled=False):
1567
+ inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
1568
+ t = torch.arange(end, dtype=torch.float32, device=inv_freqs.device) / condense_ratio
1569
+ freqs = torch.outer(t, inv_freqs).float()
1570
+ return torch.stack([torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], dim=4)
1571
+ # equivalent to
1572
+ # freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
1573
+ # cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
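+ # Self-check of the equivalence noted above, with example sizes dim=64, end=128 (a sketch):
+ #   cache = precompute_freqs_cis(64, 128)  # shape [1, 128, 1, 32, 2]
+ #   freqs = torch.outer(torch.arange(128.0), 1.0 / (10000 ** (torch.arange(0, 64, 2) / 64)))
+ #   ref = torch.polar(torch.ones_like(freqs), freqs)
+ #   assert torch.allclose(cache[0, :, 0, :, 0], ref.real) and torch.allclose(cache[0, :, 0, :, 1], ref.imag)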
1574
+
1575
+
1576
+ def apply_rotary_emb_complex_like(q: Tensor, k: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
1577
+ with torch.autocast("cuda", enabled=False):
1578
+ qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float()  # cast to float32 for a numerically stable rotation
1579
+ rotated_qk_r2 = torch.stack(
1580
+ [
1581
+ qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1],
1582
+ qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1],
1583
+ ],
1584
+ -1,
1585
+ ).flatten(3)
1586
+ rotated_qk = rotated_qk_r2
1587
+ return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) # type: ignore
1588
+
1589
+
1590
+ #################################### HF registration ############################################################
1591
+
1592
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
1593
+
1594
+ # New-style: register the classes for auto_map export
1595
+ RavenConfig.register_for_auto_class()
1596
+
1597
+ RavenForCausalLM.register_for_auto_class("AutoModel")
1598
+ RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
1599
+
1600
+ # Older-style explicit registration, kept so the local Auto classes also resolve
1601
+ AutoConfig.register("huginn_raven", RavenConfig)
1602
+ AutoModel.register(RavenConfig, RavenForCausalLM)
1603
+ AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
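+ # With the registrations above, the checkpoint loads through the Auto classes as long as
+ # trust_remote_code=True is passed (the hub id below is the assumed public checkpoint name):
+ #   from transformers import AutoModelForCausalLM, AutoTokenizer
+ #   model = AutoModelForCausalLM.from_pretrained("tomg-group-umd/huginn-0125", trust_remote_code=True)
+ #   tokenizer = AutoTokenizer.from_pretrained("tomg-group-umd/huginn-0125")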