Update raven_modeling_minimal.py
raven_modeling_minimal.py (+22 -8)
@@ -242,7 +242,6 @@ class CausalSelfAttention(torch.nn.Module):
         if past_key_values is not None:
             k, v = past_key_values.update(k, v, step_idx)

-        return_attn = False  # hardcode for now
         if return_attn:
             y, attention_map = self.compute_eager_sdpa(q, k, v, attn_mask=mask)
         else:
@@ -369,7 +368,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             "return_latents": True,
             "return_attention": False,
             "return_head": False,
-            "return_stats": True,
+            "return_stats": False,
         },
         use_cache: bool = False,
         cache_position: Optional[torch.Tensor] = None,
@@ -397,7 +396,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         # Non-recurrent prelude
         for block_idx, block in enumerate(self.transformer.prelude):
             input_embeds, attn_map = block(
-                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
+                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
             )
             attn_maps[block_idx] = attn_map

@@ -411,12 +410,13 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             past_key_values,
             num_steps,
             attn_maps,
+            return_attn=return_attn,
         )
         latent_states = x.clone().detach()

         # Coda layers
         for block_idx, block in enumerate(self.transformer.coda, start=1):
-            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
+            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
             attn_maps[-block_idx] = attn_map
         x = self.transformer.ln_f(x)

@@ -453,6 +453,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         past_key_values: Optional[Cache] = None,
         num_steps: Optional[torch.Tensor] = None,
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
         if num_steps is None:
@@ -470,13 +471,13 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             for step in range(num_steps_no_grad):
                 xk = x
                 x, block_idx, attn_maps = self.core_block_forward(
-                    xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                    xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
                 )

         for step in range(num_steps_with_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )
         return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps

@@ -489,10 +490,11 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         past_key_values,
         block_idx: Union[torch.Tensor, int],
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
         for idx, block in enumerate(self.transformer.core_block, start=1):
-            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=False)
+            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
             attn_maps[block_idx + idx] = attn_map
         return x, block_idx + idx, attn_maps

@@ -625,7 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         model_inputs["cache_position"] = cache_position
         current_input_length = input_ids.shape[1]
         if past_key_values is not None:
-            if type(past_key_values) == DynamicCache:
+            if type(past_key_values) != HuginnDynamicCache:
                 # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
                 assert past_key_values.get_seq_length() == 0
                 past_key_values = HuginnDynamicCache()
@@ -645,6 +647,18 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             model_inputs[key] = value
         return model_inputs

+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        """Dispatcher - use HF generate in all normal cases."""
+        if any(
+            k in kwargs
+            for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
+        ):
+            print("Dispatching to custom generate function call")
+            return self.generate_with_adaptive_compute(*args, **kwargs)
+        else:
+            return super().generate(*args, **kwargs)
+
     @torch.no_grad()
     def generate_minimal(
         self,
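The commit threads a single return_attn flag from forward through the prelude, core, and coda block calls, so per-block attention maps are only computed when the caller requests them. A minimal usage sketch follows, assuming the dict shown in the second hunk is the forward argument named output_details and that the checkpoint shipping this modeling file is loaded with trust_remote_code; the repo id and token ids below are illustrative assumptions, and the field that exposes the collected maps on the returned output object is not shown in this commit.

import torch
from transformers import AutoModelForCausalLM

# Repo id is an assumption for illustration; any checkpoint that bundles this
# raven_modeling_minimal.py would be loaded the same way via trust_remote_code.
model = AutoModelForCausalLM.from_pretrained(
    "tomg-group-umd/huginn-0125", trust_remote_code=True, torch_dtype=torch.bfloat16
)
model.eval()

input_ids = torch.tensor([[1, 2, 3, 4]])  # placeholder token ids

# "return_attention" is the flag that forward now passes to every block call
# as return_attn; the other keys mirror the defaults shown in the diff above.
outputs = model(
    input_ids,
    output_details={
        "return_logits": True,
        "return_latents": True,
        "return_attention": True,
        "return_head": False,
        "return_stats": False,
    },
)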
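The generate override added at the end of the file is a thin dispatcher: if the call carries any of the keyword arguments listed in its tuple (continuous_compute, latent_dampening, criterion, exit_threshold, cache_kwargs), it routes to generate_with_adaptive_compute, otherwise it falls through to the standard GenerationMixin.generate from transformers. A sketch of both paths, continuing from the snippet above; the criterion name and threshold value are illustrative assumptions, not defaults documented in this commit.

# Plain decoding: no custom kwargs, so the dispatcher defers to HF generate.
plain_ids = model.generate(input_ids, max_new_tokens=32)

# Adaptive-compute decoding: any key from the dispatcher's tuple triggers the
# custom path (and prints "Dispatching to custom generate function call").
adaptive_ids = model.generate(
    input_ids,
    max_new_tokens=32,
    criterion="entropy-diff",  # assumed criterion name, for illustration only
    exit_threshold=5e-4,       # assumed threshold value, for illustration only
)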
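The cache guard in prepare_inputs_for_generation swaps in the model's own cache whenever the incoming object is not already a HuginnDynamicCache, which matters because transformers' generate can inject a stock DynamicCache before this hook runs. The same guard shown standalone, assuming HuginnDynamicCache can be imported from this modeling file:

from transformers import DynamicCache

from raven_modeling_minimal import HuginnDynamicCache  # import path is an assumption

cache = DynamicCache()                  # what generate() would inject by default
if type(cache) != HuginnDynamicCache:   # the check from this commit
    assert cache.get_seq_length() == 0  # only swap while the cache is still empty
    cache = HuginnDynamicCache()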
|