JonasGeiping committed (verified)
Commit a213de4 · Parent: 05323d0

Update raven_modeling_minimal.py

Files changed (1):
  1. raven_modeling_minimal.py +40 -28
raven_modeling_minimal.py CHANGED
@@ -11,7 +11,7 @@ from .raven_config_minimal import RavenConfig
 from transformers.cache_utils import Cache, DynamicCache
 
 ###################### Huggingface Glue code I ##################################################################
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationMixin
 from transformers.utils import ModelOutput
 from transformers.generation.utils import GenerateDecoderOnlyOutput
 
@@ -32,7 +32,8 @@ class RavenPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False
 
     def _init_weights(self, module):
-        print("Random Initialization not implemented.")
+        if not torch.rand((1,)).is_meta:
+            print("Random Initialization not implemented.")
 
 
 @dataclass
@@ -309,7 +310,7 @@ class SandwichBlock(torch.nn.Module):
         return x, attn_map
 
 
-class RavenForCausalLM(RavenPreTrainedModel):
+class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
     def __init__(
         self,
         config: RavenConfig,
@@ -367,7 +368,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
             "return_latents": True,
             "return_attention": False,
             "return_head": False,
-            "return_stats": True,
+            "return_stats": False,
         },
         use_cache: bool = False,
         cache_position: Optional[torch.Tensor] = None,
@@ -395,7 +396,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         # Non-recurrent prelude
         for block_idx, block in enumerate(self.transformer.prelude):
             input_embeds, attn_map = block(
-                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
+                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
             )
             attn_maps[block_idx] = attn_map
 
@@ -409,12 +410,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
             past_key_values,
             num_steps,
             attn_maps,
+            return_attn=return_attn,
         )
         latent_states = x.clone().detach()
 
         # Coda layers
         for block_idx, block in enumerate(self.transformer.coda, start=1):
-            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
+            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
             attn_maps[-block_idx] = attn_map
         x = self.transformer.ln_f(x)
 
@@ -451,6 +453,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         past_key_values: Optional[Cache] = None,
         num_steps: Optional[torch.Tensor] = None,
         attn_maps: dict = {},
+        return_attn: bool = False,
    ):
         x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
         if num_steps is None:
@@ -468,13 +471,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
         for step in range(num_steps_no_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )
 
         for step in range(num_steps_with_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )
         return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps
 
@@ -487,10 +490,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
         past_key_values,
         block_idx: Union[torch.Tensor, int],
         attn_maps: dict = {},
+        return_attn: bool = False,
    ):
         x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
         for idx, block in enumerate(self.transformer.core_block, start=1):
-            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=len(attn_maps) > 0)
+            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
             attn_maps[block_idx + idx] = attn_map
         return x, block_idx + idx, attn_maps
 
@@ -623,7 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         model_inputs["cache_position"] = cache_position
         current_input_length = input_ids.shape[1]
         if past_key_values is not None:
-            if type(past_key_values) == DynamicCache:
+            if type(past_key_values) != HuginnDynamicCache:
                 # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
                 assert past_key_values.get_seq_length() == 0
                 past_key_values = HuginnDynamicCache()
@@ -643,6 +647,18 @@ class RavenForCausalLM(RavenPreTrainedModel):
             model_inputs[key] = value
         return model_inputs
 
+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        """Dispatcher - use HF generate in all normal cases."""
+        if any(
+            k in kwargs
+            for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
+        ):
+            print("Dispatching to custom generate function call")
+            return self.generate_with_adaptive_compute(*args, **kwargs)
+        else:
+            return super().generate(*args, **kwargs)
+
     @torch.no_grad()
     def generate_minimal(
         self,
@@ -788,37 +804,35 @@ class RavenForCausalLM(RavenPreTrainedModel):
                 raise ValueError("Invalid adaptive compute strategy.")
 
             all_latents = []
-            exit_value = float("NaN")
-            for compute_step in range(1, model_inputs["num_steps"]):
+            exit_values = []
+            for compute_step in range(model_inputs["num_steps"]):
                 prev_latents = current_latents.clone()
                 current_latents, block_idx, _ = self.iterate_one_step(
                     embedded_inputs, current_latents, block_idx=block_idx, **aux_inputs
                 )
                 all_latents.append(current_latents if latent_dampening else None)
-                if compute_step > 1 and step > 0:  # do not exit in prefill:
+                if step > 0:  # do not exit in prefill:
                     if criterion == "entropy-diff":
                         prev_entropy = entropy.clone()
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
                         probs = F.softmax(outputs.logits[:, -1, :], dim=-1)  # type: ignore
                         entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1).mean()
-                        entropy_diff = exit_value = (entropy - prev_entropy).abs()
+                        entropy_diff = (entropy - prev_entropy).abs()
+                        exit_values.append(entropy_diff.item())
                         if entropy_diff < exit_threshold:
-                            compute_steps.append([compute_step, entropy_diff.item()])
                             break
                     elif criterion == "latent-diff":
-                        norm_diff = exit_value = (prev_latents - current_latents).norm() / current_latents.norm()
+                        norm_diff = (prev_latents - current_latents).norm() / current_latents.norm()
+                        exit_values.append(norm_diff.item())
                        if norm_diff < exit_threshold:
-                            compute_steps.append([compute_step, norm_diff.item()])
                             break
                     elif criterion == "kl":
                         prev_log_probs = log_probs.clone()
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
                         log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1)  # type: ignore
-                        kl = exit_value = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(
-                            dim=-1
-                        )
+                        kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
+                        exit_values.append(kl.item())
                         if kl < exit_threshold:
-                            compute_steps.append([compute_step, kl.item()])
                             break
                     elif criterion == "minp-kl":
                         prev_log_probs = log_probs.clone()
@@ -827,33 +841,31 @@ class RavenForCausalLM(RavenPreTrainedModel):
                         probs[probs < 0.1 * probs.max()] = 1 / V
                         probs = probs / probs.sum()
                         log_probs = probs.log()
-                        kl = exit_value = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(
-                            dim=-1
-                        )
+                        kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
+                        exit_values.append(kl.item())
                         if kl < exit_threshold:
-                            compute_steps.append([compute_step, kl.item()])
                             break
                     elif criterion == "argmax-stability":
                         prev_argmax = current_argmax.clone()
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
-                        current_argmax = exit_value = outputs.logits[0, -1, :].argmax(dim=-1)  # type: ignore
+                        current_argmax = outputs.logits[0, -1, :].argmax(dim=-1)  # type: ignore
                         if current_argmax == prev_argmax:
                             stable_for_n_steps += 1
                         else:
                             stable_for_n_steps = 0
+                        exit_values.append(stable_for_n_steps)
                         if stable_for_n_steps >= exit_threshold:
-                            compute_steps.append([compute_step, stable_for_n_steps])
                             break
                     elif criterion == "none":
                         pass
 
             else:
-                compute_steps.append([compute_step, exit_value])
                 if not latent_dampening:
                     outputs = self.predict_from_latents(current_latents, **aux_inputs)
                 else:
                     dampened_latents = torch.sum(torch.cat(all_latents, dim=0), dim=0, keepdim=True)
                     outputs = self.predict_from_latents(dampened_latents, **aux_inputs)
+            compute_steps.append([compute_step + 1, exit_values])
 
             next_token_logits = outputs.logits[0, -1, :]  # type: ignore
             if continuous_compute:  # Save last latent
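For reference, a minimal usage sketch of the new generate dispatcher added in this commit: any of the adaptive-compute keywords (continuous_compute, latent_dampening, criterion, exit_threshold, cache_kwargs) routes the call to generate_with_adaptive_compute, while a plain call falls through to the standard GenerationMixin.generate now inherited by RavenForCausalLM. The checkpoint id and tokenizer setup below are illustrative assumptions, not part of this diff, and trust_remote_code=True is assumed so that this modeling file is loaded.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # "<raven-checkpoint>" is a placeholder for a checkpoint that ships this modeling file.
    model = AutoModelForCausalLM.from_pretrained("<raven-checkpoint>", trust_remote_code=True, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained("<raven-checkpoint>")
    input_ids = tokenizer("The capital of Westphalia is", return_tensors="pt").input_ids

    # No adaptive-compute kwargs present: dispatched to super().generate(), i.e. standard HF generation.
    out_plain = model.generate(input_ids, max_new_tokens=32)

    # "criterion" / "exit_threshold" in kwargs: dispatched to generate_with_adaptive_compute,
    # here exiting the recurrence early once successive latent states differ by less than 3% in relative norm.
    out_adaptive = model.generate(input_ids, criterion="latent-diff", exit_threshold=0.03)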