|
"""Minimal modeling.py file for HF compatibility and funny zero-shot experiments. Use only for inference.""" |
|
|
|
import torch |
|
import math |
|
|
|
from torch import Tensor |
|
from dataclasses import dataclass |
|
from typing import Optional, Union, Any |
|
|
|
from .raven_config_minimal import RavenConfig |
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
|
|
|
from transformers import PreTrainedModel, GenerationMixin |
|
from transformers.utils import ModelOutput |
|
from transformers.generation.utils import GenerateDecoderOnlyOutput |
|
|
|
import torch.nn.functional as F |
|
from transformers import GenerationConfig |
|
|
|
|
|
class RavenPreTrainedModel(PreTrainedModel): |
|
config_class = RavenConfig |
|
base_model_prefix = "model" |
|
supports_gradient_checkpointing = True |
|
_no_split_modules = ["SandwichBlock"] |
|
_skip_keys_device_placement = ["past_key_values"] |
|
_supports_flash_attn_2 = True |
|
_supports_sdpa = True |
|
_supports_cache_class = True |
|
_supports_quantized_cache = False |
|
_supports_static_cache = False |
|
|
|
    def _init_weights(self, module):
        # Weights are expected to come from a pretrained checkpoint; only warn when the model
        # is actually materialized (i.e., not instantiated on the meta device).
        if not torch.rand((1,)).is_meta:
            print("Random Initialization not implemented.")
|
|
|
|
|
@dataclass |
|
class CausalLMOutputRecurrentLatents(ModelOutput): |
|
loss: Optional[torch.Tensor] = None |
|
log_ppl: Optional[torch.Tensor] = None |
|
logits: Optional[torch.Tensor] = None |
|
past_key_values: Optional[Cache] = None |
|
latent_states: Optional[torch.Tensor] = None |
|
hidden_states: Optional[torch.Tensor] = None |
|
attention_maps: Optional[dict[int, torch.Tensor]] = None |
|
stats: Optional[dict] = None |
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(torch.nn.Module): |
|
"""Saner dtype handling and slightly better for fusion""" |
|
|
|
def __init__(self, dim: int, eps: float = 1e-6): |
|
super().__init__() |
|
self.eps = eps |
|
self.weight = torch.nn.Parameter(torch.ones(dim)) |
|
|
|
def _norm(self, x): |
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
|
|
|
def forward(self, x): |
|
with torch.autocast(enabled=False, device_type=x.device.type): |
|
return self._norm(x.float()).type_as(x) * self.weight |
|
|
|
def reset_parameters(self) -> None: |
|
torch.nn.init.ones_(self.weight) |
|
|
|
|
|
class HuginnDynamicCache(DynamicCache):
    """KV cache for the recurrent-depth model, stored per recurrence step and per absolute
    token position (key_cache[step_idx][token_pos] -> Tensor), so that the lookup strategies
    in `update` ("full", "latest-m4", "skip", "randomized", "compress-*") can reassemble a
    per-step cache at read time."""

    def __init__(self, lookup_strategy: str = "full") -> None:
        super().__init__()
        self._seen_tokens = 0
        self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
        self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
        self.lookup_strategy = lookup_strategy
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
step_idx: int, |
|
lookup_strategy: Optional[str] = None, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
        lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
        if "compress-" in self.lookup_strategy and step_idx > 1:
            compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
            if "compress-s" in self.lookup_strategy:
                # Share cache rows across recurrent steps by step modulo.
                new_step_idx = (step_idx - 2) % compression_stage + 2
            else:
                # Share cache rows across recurrent steps by integer division.
                new_step_idx = (step_idx - 2) // compression_stage + 2
            step_idx = new_step_idx
|
|
|
if step_idx not in self.key_cache: |
|
self.key_cache[step_idx] = {} |
|
self.value_cache[step_idx] = {} |
|
|
|
if step_idx == 0: |
|
self._seen_tokens += key_states.shape[-2] |
|
|
|
for idx, entry in enumerate(key_states.unbind(dim=-2)): |
|
if "compress-" not in self.lookup_strategy: |
|
assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx] |
|
|
|
self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry |
|
for idx, entry in enumerate(value_states.unbind(dim=-2)): |
|
self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry |
|
|
|
|
|
        # Fast path: the cache row for this step is dense (or the strategy is "full"); return it directly.
        if len(self.key_cache[step_idx]) == self._seen_tokens or self.lookup_strategy == "full":
|
|
|
return ( |
|
torch.stack(list(self.key_cache[step_idx].values()), dim=-2), |
|
torch.stack(list(self.value_cache[step_idx].values()), dim=-2), |
|
) |
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if lookup_strategy.startswith("latest-m4"): |
|
latest_keys = [] |
|
latest_values = [] |
|
for token_pos in range(self._seen_tokens): |
|
|
|
if step_idx >= 2: |
|
|
|
valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]] |
|
max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4]) |
|
else: |
|
max_step = step_idx if token_pos in self.key_cache[step_idx] else 0 |
|
if max_step is None: |
|
raise ValueError(f"No cache entry found for token position {token_pos}") |
|
latest_keys.append(self.key_cache[max_step][token_pos]) |
|
latest_values.append(self.value_cache[max_step][token_pos]) |
|
return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2) |
|
elif lookup_strategy.startswith("skip"): |
|
existing_keys = [] |
|
existing_values = [] |
|
for token_pos in range(self._seen_tokens): |
|
if token_pos in self.key_cache[step_idx]: |
|
existing_keys.append(self.key_cache[step_idx][token_pos]) |
|
existing_values.append(self.value_cache[step_idx][token_pos]) |
|
return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2) |
|
elif lookup_strategy.startswith("randomized"): |
|
rand_keys = [] |
|
rand_values = [] |
|
for token_pos in range(self._seen_tokens): |
|
if step_idx < 2: |
|
max_step = step_idx if token_pos in self.key_cache[step_idx] else 0 |
|
else: |
|
curr_modulo = (step_idx - 2) % 4 + 2 |
|
valid_steps = [ |
|
s |
|
for s in range(2, step_idx + 1) |
|
if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s] |
|
] |
|
                        max_step = valid_steps[int(torch.randint(len(valid_steps), (1,)))]
|
rand_keys.append(self.key_cache[max_step][token_pos]) |
|
rand_values.append(self.value_cache[max_step][token_pos]) |
|
return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2) |
|
else: |
|
raise ValueError(f"Unknown lookup strategy: {lookup_strategy}") |
|
|
|
def reset(self) -> None: |
|
"""Reset the cache state.""" |
|
self._seen_tokens = 0 |
|
self.key_cache.clear() |
|
self.value_cache.clear() |
|
|
|
def get_seq_length(self, step_idx: int = 0) -> int: |
|
return self._seen_tokens |
|
|
|
def get_memory_usage(self) -> float: |
|
total_bytes = 0 |
|
|
|
for step_idx in self.key_cache: |
|
|
|
key_seq_cache = self.key_cache[step_idx] |
|
for seq_idx in key_seq_cache: |
|
key_tensor = key_seq_cache[seq_idx] |
|
|
|
total_bytes += key_tensor.nelement() * key_tensor.element_size() |
|
        return total_bytes * 2 / (1024 * 1024)  # keys and values, reported in MiB
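
# Illustrative sketch of the cache layout above (hypothetical shapes; a single update at recurrence step 0):
#
#   cache = HuginnDynamicCache(lookup_strategy="full")
#   k = torch.randn(1, 8, 5, 64)  # (batch, kv_heads, seq_len, head_dim)
#   v = torch.randn(1, 8, 5, 64)
#   k_all, v_all = cache.update(k, v, step_idx=0)  # tokens 0..4 stored under step 0
#   assert cache.get_seq_length() == 5 and k_all.shape == k.shape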
|
|
|
|
|
class CausalSelfAttention(torch.nn.Module): |
|
def __init__(self, config: RavenConfig) -> None: |
|
super().__init__() |
|
self.config = config |
|
self.n_head = config.num_attention_heads |
|
self.n_kv_heads = config.num_key_value_heads |
|
self.head_dim = config.n_embd // self.n_head |
|
|
|
shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim |
|
self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim] |
|
self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False) |
|
if config.qk_bias: |
|
self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim)) |
|
self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
freqs_cis: Tensor, |
|
step_idx: int, |
|
mask: Optional[Tensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
return_attn: bool = False, |
|
) -> tuple[Tensor, Optional[Tensor]]: |
|
B, S, E = x.shape |
|
q, k, v = self.Wqkv(x).split(self.chunks, dim=2) |
|
q = q.view(B, S, self.n_head, self.head_dim) |
|
k = k.view(B, S, self.n_kv_heads, self.head_dim) |
|
v = v.view(B, S, self.n_kv_heads, self.head_dim) |
|
|
|
if self.config.qk_bias: |
|
q_bias, k_bias = self.qk_bias.split(1, dim=0) |
|
q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype) |
|
|
|
q, k = apply_rotary_emb_complex_like(q, k, freqs_cis=freqs_cis) |
|
|
|
q = q.transpose(1, 2) |
|
k = k.transpose(1, 2) |
|
v = v.transpose(1, 2) |
|
|
|
if past_key_values is not None: |
|
k, v = past_key_values.update(k, v, step_idx) |
|
|
|
if return_attn: |
|
y, attention_map = self.compute_eager_sdpa(q, k, v, attn_mask=mask) |
|
else: |
|
y = torch.nn.functional.scaled_dot_product_attention( |
|
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=q.shape[2] > 1 |
|
) |
|
y = y.transpose(1, 2).reshape(B, S, E).contiguous() |
|
        return self.proj(y), (attention_map if return_attn else None)
|
|
|
def compute_eager_sdpa(self, q, k, v, attn_mask): |
|
scale = 1.0 / math.sqrt(self.head_dim) |
|
scores = torch.matmul(q, k.transpose(-2, -1)) * scale |
|
|
|
if attn_mask is not None: |
|
scores = scores + attn_mask |
|
if q.shape[2] > 1: |
|
causal_mask = torch.triu(torch.ones(q.shape[2], q.shape[2]), diagonal=1).bool() |
|
scores.masked_fill_(causal_mask.to(scores.device), float("-inf")) |
|
|
|
        attention_weights = torch.nn.functional.softmax(scores, dim=-1)
        y = torch.matmul(attention_weights, v)
        # Return outputs plus the attention map max-pooled over heads (used for visualization).
        return y, attention_weights.max(dim=1)[0]
|
|
|
|
|
class GatedMLP(torch.nn.Module): |
|
def __init__(self, config: RavenConfig, in_features: int = 0) -> None: |
|
super().__init__() |
|
in_features = config.n_embd if in_features == 0 else in_features |
|
self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False) |
|
|
|
self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False) |
|
self.nonlin = torch.nn.SiLU() |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
|
        # SiLU-gated MLP (SwiGLU): a single projection produces both the gate and the value halves.
        x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1)
        x = self.nonlin(x_fc_1) * x_fc_2
        return self.proj(x)
|
|
|
|
|
class SandwichBlock(torch.nn.Module): |
|
expanded = False |
|
|
|
def __init__(self, config: RavenConfig, layer_id: int) -> None: |
|
super().__init__() |
|
self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps) |
|
self.attn = CausalSelfAttention(config) |
|
self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps) |
|
self.mlp = GatedMLP(config) |
|
self.norm_3 = RMSNorm(config.n_embd, eps=config.norm_eps) |
|
self.norm_4 = RMSNorm(config.n_embd, eps=config.norm_eps) |
|
self.layer_id = layer_id |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
freqs_cis: Tensor, |
|
step_idx: int, |
|
mask: Optional[Tensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
return_attn: bool = False, |
|
) -> tuple[Tensor, Optional[Tensor]]: |
|
        # "Sandwich" format: normalization before and after both the attention and MLP sublayers.
        attn_out, attn_map = self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values, return_attn)
        x = self.norm_2(attn_out + x)
        x = self.norm_4(self.mlp(self.norm_3(x)) + x)
        return x, attn_map
|
|
|
|
|
class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin): |
|
def __init__( |
|
self, |
|
config: RavenConfig, |
|
) -> None: |
|
super().__init__(config) |
|
self.config = config |
|
|
|
|
|
        # Prelude: non-recurrent embedding layers.
        prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
        # Adapter projects [latent state, input embedding] back to model width.
        adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
        # Core block: iterated a variable number of times at inference.
        core_block = torch.nn.ModuleList(
            SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
            for i in range(config.n_layers_in_recurrent_block)
        )
        # Coda layer ids continue after the core unrolled to its mean recurrence.
        o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
        coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))
|
|
|
self.transformer = torch.nn.ModuleDict( |
|
dict( |
|
wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd), |
|
prelude=prelude, |
|
adapter=adapter, |
|
core_block=core_block, |
|
coda=coda, |
|
ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), |
|
) |
|
) |
|
self.emb_scale = config.init_values["embed_scale"] |
|
|
|
self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) |
|
if self.config.tie_embeddings: |
|
self.lm_head.weight = self.transformer.wte.weight |
|
|
|
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) |
|
|
|
def _precompute_freqs_cis(self): |
|
|
|
freqs_cis = precompute_freqs_cis( |
|
self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1 |
|
) |
|
return freqs_cis |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.Tensor, |
|
input_embeds: Optional[torch.Tensor] = None, |
|
input_states: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
num_steps: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
output_details: dict = { |
|
"return_logits": True, |
|
"return_latents": True, |
|
"return_attention": False, |
|
"return_head": False, |
|
"return_stats": False, |
|
}, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.Tensor] = None, |
|
**kwargs, |
|
) -> CausalLMOutputRecurrentLatents: |
|
|
|
if position_ids is None and cache_position is None: |
|
freqs_cis = self.freqs_cis[:, : input_ids.shape[1]] |
|
elif position_ids is not None: |
|
freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze()) |
|
elif cache_position is not None: |
|
freqs_cis = self.freqs_cis[:, cache_position] |
|
|
|
if input_embeds is None: |
|
input_embeds = self.transformer.wte(input_ids) |
|
|
|
if self.emb_scale != 1: |
|
input_embeds = input_embeds * self.emb_scale |
|
|
|
if use_cache and past_key_values is None: |
|
past_key_values = HuginnDynamicCache() |
|
attn_maps = {} |
|
return_attn = output_details["return_attention"] |
|
|
|
|
|
        # Non-recurrent prelude.
        for block_idx, block in enumerate(self.transformer.prelude):
|
input_embeds, attn_map = block( |
|
input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn |
|
) |
|
attn_maps[block_idx] = attn_map |
|
|
|
|
|
        # Main recurrence over the core block.
        x, num_steps_no_grad, num_steps_with_grad, xk, block_idx, attn_maps = self.iterate_forward(
|
input_embeds, |
|
input_states, |
|
freqs_cis, |
|
block_idx, |
|
attention_mask, |
|
past_key_values, |
|
num_steps, |
|
attn_maps, |
|
return_attn=return_attn, |
|
) |
|
latent_states = x.clone().detach() |
|
|
|
|
|
        # Coda layers and final norm.
        for block_idx, block in enumerate(self.transformer.coda, start=1):
|
x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn) |
|
attn_maps[-block_idx] = attn_map |
|
x = self.transformer.ln_f(x) |
|
|
|
|
|
if labels is not None: |
|
logits = self.lm_head(x).float() |
|
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1)) |
|
log_ppl = loss.clone().detach() |
|
else: |
|
logits = self.lm_head(x).float() |
|
loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0) |
|
|
|
return CausalLMOutputRecurrentLatents( |
|
loss=loss, |
|
log_ppl=log_ppl, |
|
logits=logits if output_details["return_logits"] else None, |
|
past_key_values=past_key_values, |
|
hidden_states=x if output_details["return_head"] else None, |
|
latent_states=latent_states if output_details["return_latents"] else None, |
|
attention_maps=attn_maps if output_details["return_attention"] else None, |
|
stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad) |
|
if output_details["return_stats"] |
|
else None, |
|
) |
|
|
|
@torch._dynamo.disable(recursive=False) |
|
def iterate_forward( |
|
self, |
|
input_embeds, |
|
input_states, |
|
freqs_cis, |
|
block_idx, |
|
mask, |
|
past_key_values: Optional[Cache] = None, |
|
num_steps: Optional[torch.Tensor] = None, |
|
attn_maps: dict = {}, |
|
return_attn: bool = False, |
|
): |
|
x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone() |
|
if num_steps is None: |
|
num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() |
|
elif hasattr(num_steps, "__len__") and len(num_steps) > 1: |
|
num_steps_no_grad, num_steps_with_grad = num_steps |
|
else: |
|
num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) |
|
|
|
        with torch.no_grad():
            # Iterations without gradient (truncated backpropagation through the recurrence).
            for step in range(num_steps_no_grad):
                xk = x
                x, block_idx, attn_maps = self.core_block_forward(
                    xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
                )

        # Iterations with gradient (at inference this is typically zero steps).
        for step in range(num_steps_with_grad):
            xk = x
            x, block_idx, attn_maps = self.core_block_forward(
                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
            )
        return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps
|
|
|
def core_block_forward( |
|
self, |
|
x, |
|
input_embeds, |
|
freqs_cis, |
|
mask, |
|
past_key_values, |
|
block_idx: Union[torch.Tensor, int], |
|
attn_maps: dict = {}, |
|
return_attn: bool = False, |
|
): |
|
x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1)) |
|
for idx, block in enumerate(self.transformer.core_block, start=1): |
|
x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn) |
|
attn_maps[block_idx + idx] = attn_map |
|
return x, block_idx + idx, attn_maps |
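
    # One recurrence step in equation form: given input embeddings e and latent state x_k,
    #   x_{k+1} = core_blocks(adapter(concat(x_k, e))),
    # i.e. the input embeddings are re-injected into the latent state at every iteration.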
|
|
|
@torch.no_grad() |
|
def iterate_one_step( |
|
self, |
|
input_embeds, |
|
input_states, |
|
position_ids: Optional[torch.Tensor] = None, |
|
cache_position: Optional[torch.Tensor] = None, |
|
block_idx: Union[torch.Tensor, int] = 0, |
|
attention_mask: Optional[Tensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
attn_maps: dict = {}, |
|
): |
|
if position_ids is None and cache_position is None: |
|
freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]] |
|
elif position_ids is not None: |
|
freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze()) |
|
elif cache_position is not None: |
|
freqs_cis = self.freqs_cis[:, cache_position] |
|
x, block_idx, attn_maps = self.core_block_forward( |
|
input_states, input_embeds, freqs_cis, attention_mask, past_key_values, block_idx, attn_maps |
|
) |
|
return x, block_idx, attn_maps |
|
|
|
def predict_from_latents( |
|
self, |
|
latents, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
cache_position: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
return_attn: bool = False, |
|
attn_maps: dict = {}, |
|
): |
|
if position_ids is None and cache_position is None: |
|
freqs_cis = self.freqs_cis[:, : latents.shape[1]] |
|
elif position_ids is not None: |
|
freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze()) |
|
elif cache_position is not None: |
|
freqs_cis = self.freqs_cis[:, cache_position] |
|
x = self.transformer.ln_f(latents) |
|
|
|
for block_idx, block in enumerate(self.transformer.coda, start=1): |
|
x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values) |
|
attn_maps[block_idx] = attn_map |
|
x = self.transformer.ln_f(x) |
|
|
|
logits = self.lm_head(x).float() |
|
|
|
return CausalLMOutputRecurrentLatents( |
|
loss=torch.as_tensor(0.0), |
|
log_ppl=torch.as_tensor(0.0), |
|
logits=logits, |
|
past_key_values=past_key_values, |
|
attention_maps=attn_maps if len(attn_maps) > 0 else None, |
|
) |
|
|
|
def embed_inputs( |
|
self, |
|
input_ids: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.Tensor] = None, |
|
return_attn: bool = False, |
|
**kwargs, |
|
) -> tuple[torch.Tensor, int, dict[int, Tensor]]: |
|
|
|
if position_ids is None and cache_position is None: |
|
freqs_cis = self.freqs_cis[:, : input_ids.shape[1]] |
|
elif position_ids is not None: |
|
freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze()) |
|
elif cache_position is not None: |
|
freqs_cis = self.freqs_cis[:, cache_position] |
|
|
|
input_embeds = self.transformer.wte(input_ids) |
|
|
|
if self.emb_scale != 1: |
|
input_embeds = input_embeds * self.emb_scale |
|
|
|
if use_cache and past_key_values is None: |
|
past_key_values = HuginnDynamicCache() |
|
|
|
|
|
attn_maps = {} |
|
        for block_idx, block in enumerate(self.transformer.prelude):
            input_embeds, attn_map = block(
                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
            )
            attn_maps[block_idx] = attn_map
        return input_embeds, block_idx, attn_maps
|
|
|
@torch._dynamo.disable(recursive=False) |
|
def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]: |
|
"""Outputs are long tensors so that they can be passed through compiled functions""" |
|
t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0) |
|
s = self.config.mean_backprop_depth |
|
if self.training: |
|
sigma = 0.5 |
|
mu = math.log(t + s) - (sigma**2 / 2) |
|
rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma) |
|
p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1 |
|
n = torch.clamp(p - s, min=0) |
|
k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p)) |
|
else: |
|
n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0) |
|
|
|
return n.to(dtype=torch.long), k.to(dtype=torch.long) |
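
    # Sketch of the schedule above, with t = mean_recurrence - mean_backprop_depth and s = mean_backprop_depth:
    # draw rate ~ LogNormal(log(t + s) - sigma^2 / 2, sigma), p ~ Poisson(rate) + 1, then run
    # n = max(p - s, 0) steps without grad and k = min(s, p) steps with grad. At eval time this
    # collapses to n = mean_recurrence and k = 0.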
|
|
|
def initialize_state(self, input_embeds, deterministic: bool = False): |
|
x = torch.randn_like(input_embeds) |
|
std = self.config.init_values["std"] |
|
torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std) |
|
if self.emb_scale != 1: |
|
x = x * self.emb_scale |
|
return x if not deterministic else x.zero_() |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids: torch.LongTensor, |
|
past_key_values: Optional[Cache] = None, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
**kwargs, |
|
): |
|
model_inputs = {} |
|
model_inputs["cache_position"] = cache_position |
|
current_input_length = input_ids.shape[1] |
|
        if past_key_values is not None:
            if not isinstance(past_key_values, HuginnDynamicCache):
                # generate() may inject a standard HF cache; it must still be empty, so replace it with ours.
                assert past_key_values.get_seq_length() == 0
                past_key_values = HuginnDynamicCache()
            model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
            input_ids = input_ids[:, cache_position]
|
model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format) |
|
|
|
if cache_position is None: |
|
position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device) |
|
model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone( |
|
memory_format=torch.contiguous_format |
|
) |
|
|
|
|
|
for key, value in kwargs.items(): |
|
if key not in model_inputs: |
|
model_inputs[key] = value |
|
return model_inputs |
|
|
|
@torch.no_grad() |
|
def generate(self, *args, **kwargs): |
|
"""Dispatcher - use HF generate in all normal cases.""" |
|
if any( |
|
k in kwargs |
|
for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs") |
|
): |
|
print("Dispatching to custom generate function call") |
|
return self.generate_with_adaptive_compute(*args, **kwargs) |
|
else: |
|
return super().generate(*args, **kwargs) |
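
    # Illustrative calls (argument values are assumptions, not defaults of this file):
    #
    #   model.generate(input_ids, model.generation_config, num_steps=32)          # plain HF generation
    #   model.generate(input_ids, model.generation_config, num_steps=32,
    #                  criterion="latent-diff", exit_threshold="auto")             # adaptive-compute path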
|
|
|
@torch.no_grad() |
|
def generate_minimal( |
|
self, |
|
input_ids: torch.LongTensor, |
|
generation_config: Optional[GenerationConfig] = None, |
|
tokenizer=None, |
|
streamer=None, |
|
continuous_compute=False, |
|
cache_kwargs: dict = {}, |
|
**model_kwargs, |
|
) -> Union[torch.Tensor, dict[str, Any]]: |
|
"""Minimal single-sequence generation. Template for more complicated generate tasks""" |
|
|
|
if generation_config is None: |
|
generation_config: GenerationConfig = self.generation_config |
|
model_kwargs["past_key_values"] = HuginnDynamicCache(**cache_kwargs) |
|
model_kwargs["use_cache"] = True |
|
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) |
|
stop_tokens = self._get_stops(generation_config, tokenizer).to(input_ids.device) |
|
if continuous_compute: |
|
embedded_inputs, _, _ = self.embed_inputs(input_ids) |
|
model_kwargs["input_states"] = self.initialize_state(embedded_inputs) |
|
|
|
for _ in range(generation_config.max_length - input_ids.shape[1]): |
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
outputs = self(**model_inputs) |
|
next_token_logits = outputs.logits[0, -1, :] |
|
if continuous_compute: |
|
current_last_latent = outputs.latent_states[:, -1:, :] |
|
|
|
|
|
if generation_config.do_sample: |
|
if generation_config.temperature: |
|
next_token_logits = next_token_logits / generation_config.temperature |
|
|
|
probs = F.softmax(next_token_logits, dim=-1) |
|
|
|
|
|
if generation_config.top_k: |
|
top_k_probs, _ = torch.topk(probs, generation_config.top_k) |
|
probs[probs < top_k_probs[-1]] = 0 |
|
|
|
                if generation_config.top_p:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumsum = torch.cumsum(sorted_probs, dim=-1)
                    # Zero out tokens outside the nucleus, mapping the sorted-order mask back to token order.
                    sorted_probs[cumsum - sorted_probs > generation_config.top_p] = 0
                    probs = torch.zeros_like(probs).scatter_(0, sorted_indices, sorted_probs)
|
|
|
if generation_config.min_p: |
|
probs[probs < generation_config.min_p * probs.max()] = 0 |
|
|
|
probs = probs / probs.sum() |
|
next_token = torch.multinomial(probs, num_samples=1) |
|
else: |
|
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) |
|
|
|
input_ids = torch.cat([input_ids, next_token[None, :]], dim=-1) |
|
|
|
if streamer: |
|
streamer.put(next_token.cpu()) |
|
|
|
|
|
model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs) |
|
if continuous_compute: |
|
model_kwargs["input_states"] = current_last_latent |
|
|
|
|
|
if stop_tokens is not None and next_token in stop_tokens: |
|
break |
|
|
|
if streamer: |
|
streamer.end() |
|
|
|
if generation_config.return_dict_in_generate: |
|
return GenerateDecoderOnlyOutput( |
|
sequences=input_ids, |
|
scores=None, |
|
logits=None, |
|
attentions=None, |
|
hidden_states=None, |
|
past_key_values=model_kwargs.get("past_key_values"), |
|
) |
|
return input_ids |
|
|
|
@torch.no_grad() |
|
def generate_with_adaptive_compute( |
|
self, |
|
input_ids: torch.LongTensor, |
|
generation_config: Optional[GenerationConfig] = None, |
|
tokenizer=None, |
|
streamer=None, |
|
continuous_compute=False, |
|
latent_dampening=False, |
|
criterion="entropy-diff", |
|
exit_threshold: Union[str, float, int] = "auto", |
|
cache_kwargs: dict = {}, |
|
**model_kwargs, |
|
) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]: |
|
"""Minimal single-sequence generation. Template for more complicated generate tasks""" |
|
|
|
if generation_config is None: |
|
generation_config: GenerationConfig = self.generation_config |
|
model_kwargs["past_key_values"] = HuginnDynamicCache(**cache_kwargs) |
|
model_kwargs["use_cache"] = True |
|
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) |
|
stop_tokens = self._get_stops(generation_config, tokenizer).to(input_ids.device) |
|
if continuous_compute: |
|
embedded_inputs, _, _ = self.embed_inputs(input_ids) |
|
current_last_latent = self.initialize_state(embedded_inputs) |
|
compute_steps = [] |
|
|
|
|
|
for step in range(generation_config.max_length - input_ids.shape[1]): |
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
aux_inputs = { |
|
k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs |
|
} |
|
embedded_inputs, block_idx, _ = self.embed_inputs(model_inputs["input_ids"], **aux_inputs) |
|
if not continuous_compute: |
|
current_latents = self.initialize_state(embedded_inputs, deterministic=False) |
|
else: |
|
current_latents = current_last_latent |
|
|
|
|
|
if criterion == "entropy-diff": |
|
entropy = torch.tensor(100.0, device=input_ids.device) |
|
exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold) |
|
elif criterion in ["latent-diff", "none"]: |
|
exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold) |
|
elif "kl" in criterion: |
|
V = self.config.padded_vocab_size |
|
log_probs = (1 / V * torch.ones(V, device=input_ids.device)).log() |
|
if criterion == "minp-kl": |
|
exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold) |
|
else: |
|
exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold) |
|
elif criterion == "argmax-stability": |
|
stable_for_n_steps = 0 |
|
current_argmax = torch.tensor(-1, dtype=torch.long, device=input_ids.device) |
|
exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold) |
|
else: |
|
raise ValueError("Invalid adaptive compute strategy.") |
|
|
|
all_latents = [] |
|
exit_values = [] |
|
            # The caller must supply "num_steps" (maximum recurrence) through the model kwargs.
            for compute_step in range(model_inputs["num_steps"]):
|
prev_latents = current_latents.clone() |
|
current_latents, block_idx, _ = self.iterate_one_step( |
|
embedded_inputs, current_latents, block_idx=block_idx, **aux_inputs |
|
) |
|
all_latents.append(current_latents if latent_dampening else None) |
|
                if step > 0:  # do not exit early during prefill
|
if criterion == "entropy-diff": |
|
prev_entropy = entropy.clone() |
|
outputs = self.predict_from_latents(current_latents, **aux_inputs) |
|
probs = F.softmax(outputs.logits[:, -1, :], dim=-1) |
|
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1).mean() |
|
entropy_diff = (entropy - prev_entropy).abs() |
|
exit_values.append(entropy_diff.item()) |
|
if entropy_diff < exit_threshold: |
|
break |
|
elif criterion == "latent-diff": |
|
norm_diff = (prev_latents - current_latents).norm() / current_latents.norm() |
|
exit_values.append(norm_diff.item()) |
|
if norm_diff < exit_threshold: |
|
break |
|
elif criterion == "kl": |
|
prev_log_probs = log_probs.clone() |
|
outputs = self.predict_from_latents(current_latents, **aux_inputs) |
|
log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1) |
|
kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1) |
|
exit_values.append(kl.item()) |
|
if kl < exit_threshold: |
|
break |
|
elif criterion == "minp-kl": |
|
prev_log_probs = log_probs.clone() |
|
outputs = self.predict_from_latents(current_latents, **aux_inputs) |
|
probs = F.softmax(outputs.logits[:, -1, :], dim=-1) |
|
probs[probs < 0.1 * probs.max()] = 1 / V |
|
probs = probs / probs.sum() |
|
log_probs = probs.log() |
|
kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1) |
|
exit_values.append(kl.item()) |
|
if kl < exit_threshold: |
|
break |
|
elif criterion == "argmax-stability": |
|
prev_argmax = current_argmax.clone() |
|
outputs = self.predict_from_latents(current_latents, **aux_inputs) |
|
current_argmax = outputs.logits[0, -1, :].argmax(dim=-1) |
|
if current_argmax == prev_argmax: |
|
stable_for_n_steps += 1 |
|
else: |
|
stable_for_n_steps = 0 |
|
exit_values.append(stable_for_n_steps) |
|
if stable_for_n_steps >= exit_threshold: |
|
break |
|
elif criterion == "none": |
|
pass |
|
|
|
else: |
|
if not latent_dampening: |
|
outputs = self.predict_from_latents(current_latents, **aux_inputs) |
|
else: |
|
dampened_latents = torch.sum(torch.cat(all_latents, dim=0), dim=0, keepdim=True) |
|
outputs = self.predict_from_latents(dampened_latents, **aux_inputs) |
|
compute_steps.append([compute_step + 1, exit_values]) |
|
|
|
next_token_logits = outputs.logits[0, -1, :] |
|
if continuous_compute: |
|
current_last_latent = current_latents[:, -1:, :] |
|
|
|
|
|
if generation_config.do_sample: |
|
if generation_config.temperature: |
|
next_token_logits = next_token_logits / generation_config.temperature |
|
|
|
probs = F.softmax(next_token_logits, dim=-1) |
|
|
|
if generation_config.top_k: |
|
top_k_probs, _ = torch.topk(probs, generation_config.top_k) |
|
probs[probs < top_k_probs[-1]] = 0 |
|
|
|
                if generation_config.top_p:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumsum = torch.cumsum(sorted_probs, dim=-1)
                    # Zero out tokens outside the nucleus, mapping the sorted-order mask back to token order.
                    sorted_probs[cumsum - sorted_probs > generation_config.top_p] = 0
                    probs = torch.zeros_like(probs).scatter_(0, sorted_indices, sorted_probs)
|
|
|
if generation_config.min_p: |
|
probs[probs < generation_config.min_p * probs.max()] = 0 |
|
|
|
probs = probs / probs.sum() |
|
next_token = torch.multinomial(probs, num_samples=1) |
|
else: |
|
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) |
|
|
|
input_ids = torch.cat([input_ids, next_token[None, :]], dim=-1) |
|
|
|
if streamer: |
|
streamer.put(next_token.cpu()) |
|
|
|
|
|
model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs) |
|
|
|
|
|
if stop_tokens is not None and next_token in stop_tokens: |
|
break |
|
|
|
if streamer: |
|
streamer.end() |
|
|
|
if generation_config.return_dict_in_generate: |
|
return GenerateDecoderOnlyOutput( |
|
sequences=input_ids, |
|
scores=compute_steps, |
|
logits=None, |
|
attentions=None, |
|
hidden_states=None, |
|
past_key_values=model_kwargs.get("past_key_values"), |
|
) |
|
return input_ids |
|
|
|
def _get_stops(self, generation_config, tokenizer): |
|
stop_tokens = set() |
|
if generation_config.eos_token_id is not None: |
|
stop_tokens.add(generation_config.eos_token_id) |
|
if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings: |
|
for s in generation_config.stop_strings: |
|
token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0] |
|
stop_tokens.add(token_id) |
|
return torch.tensor(list(stop_tokens)) |
|
|
|
def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad): |
|
probs = torch.softmax(logits.float(), dim=-1) |
|
prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1) |
|
residual_diff = (x - latent_states).norm(dim=-1) |
|
rel_residual = residual_diff / latent_states.norm(dim=-1) |
|
stats = { |
|
"entropy": prob_entropy, |
|
"residual_diff": residual_diff, |
|
"rel_residual": rel_residual, |
|
"num_steps_no_grad": num_steps_no_grad, |
|
"num_steps_with_grad": num_steps_with_grad, |
|
} |
|
return stats |
|
|
|
|
|
|
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1): |
|
with torch.autocast("cuda", enabled=False): |
|
inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) |
|
t = torch.arange(end, dtype=torch.float32, device=inv_freqs.device) / condense_ratio |
|
freqs = torch.outer(t, inv_freqs).float() |
|
return torch.stack([torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], dim=4) |
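
# Reference for the tensor built above: for position p and frequency index j it stores the pair
#   (cos(p * theta^(-2j / dim)), sin(p * theta^(-2j / dim)))
# with shape (1, end, 1, dim // 2, 2). apply_rotary_emb_complex_like below rotates each consecutive
# (even, odd) channel pair of q and k by exactly this angle.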
|
|
|
|
|
|
|
|
|
|
|
def apply_rotary_emb_complex_like(q: Tensor, k: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: |
|
with torch.autocast("cuda", enabled=False): |
|
qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() |
|
rotated_qk_r2 = torch.stack( |
|
[ |
|
qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1], |
|
qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1], |
|
], |
|
-1, |
|
).flatten(3) |
|
        rotated_qk = rotated_qk_r2  # shape (B, S, n_head + n_kv_heads, head_dim)
|
return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) |
|
|
|
|
|
|
|
|
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM |
|
|
|
|
|
# Register the config and model classes so Auto* loading (with trust_remote_code) resolves them.
RavenConfig.register_for_auto_class()
|
|
|
RavenForCausalLM.register_for_auto_class("AutoModel") |
|
RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM") |
|
|
|
|
|
AutoConfig.register("huginn_raven", RavenConfig) |
|
AutoModel.register(RavenConfig, RavenForCausalLM) |
|
AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM) |
|
|