JonasGeiping committed
Commit b93dc0f · verified · 1 Parent(s): f6c3335

Update raven_modeling_minimal.py

Files changed (1)
raven_modeling_minimal.py  +22 -8
raven_modeling_minimal.py CHANGED
@@ -242,7 +242,6 @@ class CausalSelfAttention(torch.nn.Module):
         if past_key_values is not None:
             k, v = past_key_values.update(k, v, step_idx)
 
-        return_attn = False # hardcode for now
         if return_attn:
             y, attention_map = self.compute_eager_sdpa(q, k, v, attn_mask=mask)
         else:
@@ -369,7 +368,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             "return_latents": True,
             "return_attention": False,
             "return_head": False,
-            "return_stats": True,
+            "return_stats": False,
         },
         use_cache: bool = False,
         cache_position: Optional[torch.Tensor] = None,
@@ -397,7 +396,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         # Non-recurrent prelude
         for block_idx, block in enumerate(self.transformer.prelude):
             input_embeds, attn_map = block(
-                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
+                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
             )
             attn_maps[block_idx] = attn_map
 
@@ -411,12 +410,13 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             past_key_values,
             num_steps,
             attn_maps,
+            return_attn=return_attn,
         )
         latent_states = x.clone().detach()
 
         # Coda layers
         for block_idx, block in enumerate(self.transformer.coda, start=1):
-            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
+            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
             attn_maps[-block_idx] = attn_map
         x = self.transformer.ln_f(x)
 
@@ -453,6 +453,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         past_key_values: Optional[Cache] = None,
         num_steps: Optional[torch.Tensor] = None,
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
         if num_steps is None:
@@ -470,13 +471,13 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             for step in range(num_steps_no_grad):
                 xk = x
                 x, block_idx, attn_maps = self.core_block_forward(
-                    xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                    xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
                 )
 
         for step in range(num_steps_with_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )
         return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps
 
@@ -489,10 +490,11 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         past_key_values,
         block_idx: Union[torch.Tensor, int],
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
         for idx, block in enumerate(self.transformer.core_block, start=1):
-            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=len(attn_maps) > 0)
+            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
             attn_maps[block_idx + idx] = attn_map
         return x, block_idx + idx, attn_maps
 
@@ -625,7 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         model_inputs["cache_position"] = cache_position
         current_input_length = input_ids.shape[1]
         if past_key_values is not None:
-            if type(past_key_values) == DynamicCache:
+            if type(past_key_values) != HuginnDynamicCache:
                 # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
                 assert past_key_values.get_seq_length() == 0
                 past_key_values = HuginnDynamicCache()
@@ -645,6 +647,18 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
                 model_inputs[key] = value
         return model_inputs
 
+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        """Dispatcher - use HF generate in all normal cases."""
+        if any(
+            k in kwargs
+            for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
+        ):
+            print("Dispatching to custom generate function call")
+            return self.generate_with_adaptive_compute(*args, **kwargs)
+        else:
+            return super().generate(*args, **kwargs)
+
     @torch.no_grad()
     def generate_minimal(
         self,
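A minimal usage sketch of what the return_attn threading enables: the hardcoded return_attn = False in CausalSelfAttention is gone, and the flag now flows from the forward call through the prelude, recurrent core, and coda blocks. The checkpoint id, the prompt, and the return_logits key are assumptions for illustration; the full output_details default lives in the file itself and may contain more keys than the four visible in this diff.

# Sketch under stated assumptions; copy the complete output_details default from
# forward's signature in practice and flip only return_attention.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tomg-group-umd/huginn-0125"  # assumed checkpoint serving raven_modeling_minimal.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of Westphalia is", return_tensors="pt")
outputs = model(
    input_ids=inputs["input_ids"],
    output_details={
        "return_logits": True,      # assumed key, not visible in this diff
        "return_latents": True,
        "return_attention": True,   # threaded to every block as return_attn by this commit
        "return_head": False,
        "return_stats": False,
    },
)
# Attention maps are collected per block index: prelude blocks count up from 0,
# recurrent core blocks continue at increasing positive indices per step,
# and coda blocks use -1, -2, ...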
 
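A companion sketch of the new generate dispatcher added in the last hunk; the checkpoint id and argument values are placeholders, since the accepted options are defined by generate_with_adaptive_compute rather than by this diff.

# Sketch under stated assumptions; argument values are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tomg-group-umd/huginn-0125"  # assumed checkpoint serving raven_modeling_minimal.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
inputs = tokenizer("The capital of Westphalia is", return_tensors="pt")

# No custom kwarg present: the dispatcher falls through to the standard
# Hugging Face GenerationMixin.generate. prepare_inputs_for_generation now
# swaps in a HuginnDynamicCache for any other (still empty) cache it receives.
out_plain = model.generate(**inputs, max_new_tokens=32)

# Any of "continuous_compute", "latent_dampening", "criterion", "exit_threshold",
# or "cache_kwargs" routes the call to generate_with_adaptive_compute instead.
out_adaptive = model.generate(
    **inputs,
    max_new_tokens=32,
    criterion="entropy-diff",  # placeholder value; accepted criteria are defined in generate_with_adaptive_compute
    exit_threshold=5e-4,       # placeholder value
)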