JonasGeiping committed
Commit adef816 · verified · 1 Parent(s): def35f6

Upstream improvements from seal_rg repo

Files changed (1): raven_modeling_minimal.py (+198 -88)
raven_modeling_minimal.py CHANGED
@@ -25,6 +25,7 @@ class RavenPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["SandwichBlock"]
     _skip_keys_device_placement = ["past_key_values"]
+    _tied_weights_keys = ["lm_head.weight"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
@@ -63,7 +64,7 @@ class RMSNorm(torch.nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        with torch.autocast(enabled=False, device_type=x.device.type):
+        with torch.autocast(enabled=False, device_type=x.device.type if x.device.type != "meta" else "cuda"):
             return self._norm(x.float()).type_as(x) * self.weight
 
     def reset_parameters(self) -> None:
@@ -342,10 +343,16 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         # Head
         self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
         if self.config.tie_embeddings:
-            self.lm_head.weight = self.transformer.wte.weight
+            self.tie_weights()
         # rope
         self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
 
+    def get_input_embeddings(self):
+        return self.transformer.wte
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
     def _precompute_freqs_cis(self):
         # can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
         freqs_cis = precompute_freqs_cis(
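Note: a minimal sketch of what the tying change amounts to, assuming a checkpoint with tie_embeddings enabled and the default HF tie_weights behavior; the repo id below is a placeholder, not part of this diff.

# Illustrative check only; "<repo-id-hosting-this-file>" is a hypothetical placeholder.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("<repo-id-hosting-this-file>", trust_remote_code=True)

# With get_input_embeddings/get_output_embeddings defined and _tied_weights_keys declared,
# tie_weights() shares the embedding matrix with the head, replacing the old manual assignment
# self.lm_head.weight = self.transformer.wte.weight
assert model.lm_head.weight is model.get_input_embeddings().weight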
@@ -461,7 +468,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
             num_steps_no_grad, num_steps_with_grad = num_steps
         else:
-            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
+            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0
 
         with torch.no_grad():
             # ultra annoying in ddp due to
@@ -594,6 +601,10 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         """Outputs are long tensors so that they can be passed through compiled functions"""
         t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
         s = self.config.mean_backprop_depth
+        if torch.rand((1,)).is_meta:  # annoying clause to make meta-tensor-based flop counting work
+            # these values are only the mean TFLOPs of the randomized sampler
+            # Note that this clause also breaks the contract, and returns ints in meta tensor mode
+            return t, s  # type: ignore
         if self.training:
             sigma = 0.5
             mu = math.log(t + s) - (sigma**2 / 2)
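Note: the new `is_meta` guards (here and in RMSNorm above) exist so that meta-device tracing, e.g. for FLOP counting, does not mix real and data-less tensors. A minimal sketch of the check they rely on, assuming PyTorch 2.x device-context semantics for factory functions:

# Illustrative only: factory calls made under the meta device return data-less meta tensors,
# which is exactly what the `torch.rand((1,)).is_meta` clause above detects.
import torch

with torch.device("meta"):
    x = torch.rand((1,))

print(x.is_meta)  # True -> the sampler short-circuits and returns plain ints (t, s)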
@@ -649,11 +660,10 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
 
     @torch.no_grad()
     def generate(self, *args, **kwargs):
-        """Dispatcher - use HF generate in all normal cases."""
-        if any(
-            k in kwargs
-            for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
-        ):
+        """Dispatcher - use HF generate in all normal cases.
+        If BOTH `criterion` AND `exit_threshold` are provided as not None, we use adaptive compute.
+        """
+        if kwargs.get("criterion", None) is not None and kwargs.get("exit_threshold", None) is not None:
             print("Dispatching to custom generate function call")
             return self.generate_with_adaptive_compute(*args, **kwargs)
         else:
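Note: with this change, only passing both `criterion` and `exit_threshold` routes a call into the adaptive-compute path. A hedged usage sketch; the repo id and prompt are placeholders, not part of this diff.

# Hypothetical usage; "<repo-id-hosting-this-file>" is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("<repo-id-hosting-this-file>", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("<repo-id-hosting-this-file>")
input_ids = tokenizer("The capital of Westphalia is", return_tensors="pt").input_ids

# Both criterion and exit_threshold set -> dispatches to generate_with_adaptive_compute
out = model.generate(input_ids, criterion="latent-diff", exit_threshold="auto")

# Neither set -> falls through to the standard Hugging Face generate path
out = model.generate(input_ids)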
@@ -757,7 +767,10 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         cache_kwargs: dict = {},
         **model_kwargs,
     ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
-        """Minimal single-sequence generation. Template for more complicated generate tasks"""
+        """
+        Generate tokens with adaptive compute. This is NOT the most efficient implementation.
+        For batches, on each token, we iterate until the entire batch finishes.
+        """
         # Setup
         if generation_config is None:
             generation_config: GenerationConfig = self.generation_config  # type: ignore
@@ -765,10 +778,16 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         model_kwargs["use_cache"] = True
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
         stop_tokens = self._get_stops(generation_config, tokenizer).to(input_ids.device)
+        batch_size = input_ids.shape[0]
+        compute_steps = []
+
+        # Set up continuous compute if enabled
         if continuous_compute:
             embedded_inputs, _, _ = self.embed_inputs(input_ids)
-            current_last_latent = self.initialize_state(embedded_inputs)
-        compute_steps = []
+            current_last_latents = self.initialize_state(embedded_inputs)
+
+        # Track which sequences have finished
+        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
 
         # Generate tokens
         for step in range(generation_config.max_length - input_ids.shape[1]):
@@ -781,130 +800,179 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             if not continuous_compute:
                 current_latents = self.initialize_state(embedded_inputs, deterministic=False)
             else:
-                current_latents = current_last_latent
+                current_latents = current_last_latents
+
+            # Initialize criterion tracking for each sequence in batch
+            exit_values_per_seq = [[] for _ in range(batch_size)]
+            compute_steps_per_seq = [0] * batch_size
+            exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
 
-            # Prep criterions:
+            # Set up criterions based on selected strategy
             if criterion == "entropy-diff":
-                entropy = torch.tensor(100.0, device=input_ids.device)
+                entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
                 exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
             elif criterion in ["latent-diff", "none"]:
                 exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
             elif "kl" in criterion:
                 V = self.config.padded_vocab_size
-                log_probs = (1 / V * torch.ones(V, device=input_ids.device)).log()
+                log_probs = ((1 / V) * torch.ones(batch_size, V, device=input_ids.device)).log()
                 if criterion == "minp-kl":
                     exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
                 else:
                     exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
             elif criterion == "argmax-stability":
-                stable_for_n_steps = 0
-                current_argmax = torch.tensor(-1, dtype=torch.long, device=input_ids.device)
+                stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
+                current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
                 exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
             else:
                 raise ValueError("Invalid adaptive compute strategy.")
 
             all_latents = []
-            exit_values = []
+            next_token_logits = None
+
+            # Iterate through compute steps
             for compute_step in range(model_inputs["num_steps"]):
                 prev_latents = current_latents.clone()
                 current_latents, block_idx, _ = self.iterate_one_step(
                     embedded_inputs, current_latents, block_idx=block_idx, **aux_inputs
                 )
-                all_latents.append(current_latents if latent_dampening else None)
+
+                if latent_dampening:
+                    all_latents.append(current_latents)
+
                 if step > 0:  # do not exit in prefill:
+                    # Check exit condition for each sequence in batch
                     if criterion == "entropy-diff":
-                        prev_entropy = entropy.clone()
+                        prev_entropy = entropy
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
-                        probs = F.softmax(outputs.logits[:, -1, :], dim=-1)  # type: ignore
-                        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1).mean()
-                        entropy_diff = (entropy - prev_entropy).abs()
-                        exit_values.append(entropy_diff.item())
-                        if entropy_diff < exit_threshold:
-                            break
+                        logits: torch.Tensor = outputs.logits  # type: ignore
+                        probs = F.softmax(logits[:, -1, :], dim=-1)
+                        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
+                        exit_values = (entropy - prev_entropy).abs()
+
                     elif criterion == "latent-diff":
-                        norm_diff = (prev_latents - current_latents).norm() / current_latents.norm()
-                        exit_values.append(norm_diff.item())
-                        if norm_diff < exit_threshold:
-                            break
-                    elif criterion == "kl":
-                        prev_log_probs = log_probs.clone()
-                        outputs = self.predict_from_latents(current_latents, **aux_inputs)
-                        log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1)  # type: ignore
-                        kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
-                        exit_values.append(kl.item())
-                        if kl < exit_threshold:
-                            break
-                    elif criterion == "minp-kl":
-                        prev_log_probs = log_probs.clone()
+                        norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
+                        exit_values = norm_diff.mean(dim=-1)
+
+                    elif "kl" in criterion:
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
-                        probs = F.softmax(outputs.logits[:, -1, :], dim=-1)  # type: ignore
-                        probs[probs < 0.1 * probs.max()] = 1 / V
-                        probs = probs / probs.sum()
-                        log_probs = probs.log()
-                        kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
-                        exit_values.append(kl.item())
-                        if kl < exit_threshold:
-                            break
+                        logits: torch.Tensor = outputs.logits  # type: ignore
+                        prev_log_probs = log_probs
+                        if criterion == "minp-kl":
+                            probs = F.softmax(logits[:, -1, :], dim=-1)
+                            max_probs = probs.max(dim=-1, keepdim=True)[0]
+                            probs_mask = probs < (0.1 * max_probs)
+                            masked_probs = probs
+                            masked_probs[probs_mask] = 1 / V
+                            probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+                            log_probs = probs.log()
+                        else:
+                            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
+                        exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
+
                     elif criterion == "argmax-stability":
-                        prev_argmax = current_argmax.clone()
+                        prev_argmax = current_argmax
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
-                        current_argmax = outputs.logits[0, -1, :].argmax(dim=-1)  # type: ignore
-                        if current_argmax == prev_argmax:
-                            stable_for_n_steps += 1
-                        else:
-                            stable_for_n_steps = 0
-                        exit_values.append(stable_for_n_steps)
-                        if stable_for_n_steps >= exit_threshold:
-                            break
-                    elif criterion == "none":
-                        pass
-
+                        logits: torch.Tensor = outputs.logits  # type: ignore
+                        current_argmax = logits[:, -1, :].argmax(dim=-1)
+                        stable_for_n_steps = torch.where(
+                            current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
+                        )
+                        exit_values = stable_for_n_steps
+
+                    # Record values and check exits for each sequence
+                    for i in range(batch_size):
+                        if not exit_reached[i] and not finished_sequences[i]:
+                            exit_values_per_seq[i].append(exit_values[i].item())
+
+                    new_exits = (
+                        exit_values < exit_threshold
+                        if criterion != "argmax-stability"
+                        else exit_values >= exit_threshold
+                    )
+                    new_exits = new_exits & ~exit_reached & ~finished_sequences
+
+                    if new_exits.any():
+                        exit_reached = exit_reached | new_exits
+                        if criterion == "latent-diff":
+                            # Normally we don't compute the output for latent-diff, but when there is an exit,
+                            # we need to compute and save the output
+                            outputs = self.predict_from_latents(current_latents, **aux_inputs)
+                            logits: torch.Tensor = outputs.logits  # type: ignore
+                        if next_token_logits is None:
+                            next_token_logits = logits[:, -1, :].clone()
+                        else:
+                            next_token_logits = torch.where(
+                                new_exits.unsqueeze(1).expand_as(logits[:, -1, :]), logits[:, -1, :], next_token_logits
+                            )
+                        for i in range(batch_size):
+                            if new_exits[i]:
+                                compute_steps_per_seq[i] = compute_step + 1
+
+                    # If all sequences have exited, break early
+                    if (exit_reached | finished_sequences).all():
+                        break
+            # This else is if the for loop finished without breaking
             else:
                 if not latent_dampening:
                     outputs = self.predict_from_latents(current_latents, **aux_inputs)
                 else:
                     dampened_latents = torch.sum(torch.cat(all_latents, dim=0), dim=0, keepdim=True)
                     outputs = self.predict_from_latents(dampened_latents, **aux_inputs)
-            compute_steps.append([compute_step + 1, exit_values])
 
-            next_token_logits = outputs.logits[0, -1, :]  # type: ignore
-            if continuous_compute:  # Save last latent
-                current_last_latent = current_latents[:, -1:, :]
+                # For sequences that didn't exit early, use the final logits
+                if next_token_logits is None:
+                    next_token_logits = outputs.logits[:, -1, :]  # type: ignore
+                else:
+                    # Only update logits for sequences that didn't exit early
+                    non_exit_mask = ~exit_reached & ~finished_sequences
+                    next_token_logits = torch.where(
+                        non_exit_mask.unsqueeze(1).expand_as(next_token_logits),
+                        outputs.logits[:, -1, :],  # type: ignore
+                        next_token_logits,
+                    )
+
+                # Record compute steps for non-exited sequences
+                for i in range(batch_size):
+                    if non_exit_mask[i]:
+                        compute_steps_per_seq[i] = model_inputs["num_steps"]
+
+            # Save latent states for continuous compute if enabled
+            if continuous_compute:
+                current_last_latents = current_latents[:, -1:, :]
 
-            # Sample or select next token
+            # Record compute steps for this token generation
+            compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
+
+            # Sample or select next token based on generation config
             if generation_config.do_sample:
-                if generation_config.temperature:
-                    next_token_logits = next_token_logits / generation_config.temperature
-
-                probs = F.softmax(next_token_logits, dim=-1)
-                # Apply top_k
-                if generation_config.top_k:
-                    top_k_probs, _ = torch.topk(probs, generation_config.top_k)
-                    probs[probs < top_k_probs[-1]] = 0
-                # Apply top_p
-                if generation_config.top_p:
-                    sorted_probs = torch.sort(probs, descending=True)[0]
-                    cumsum = torch.cumsum(sorted_probs, dim=-1)
-                    probs[cumsum > generation_config.top_p] = 0
-                # Apply min_p
-                if generation_config.min_p:
-                    probs[probs < generation_config.min_p * probs.max()] = 0
-
-                probs = probs / probs.sum()
-                next_token = torch.multinomial(probs, num_samples=1)
+                next_token = self._sample_next_token(
+                    next_token_logits,
+                    generation_config.temperature,
+                    generation_config.top_k,
+                    generation_config.top_p,
+                    generation_config.min_p,
+                )
             else:
-                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # type: ignore
 
-            input_ids = torch.cat([input_ids, next_token[None, :]], dim=-1)  # type: ignore
+            input_ids = torch.cat([input_ids, next_token], dim=-1)  # type: ignore
 
             if streamer:
                 streamer.put(next_token.cpu())
 
             # Update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
+            if continuous_compute:
+                model_kwargs["input_states"] = current_last_latents
 
-            # Check if we hit a stop token
-            if stop_tokens is not None and next_token in stop_tokens:
+            # Check for finished sequences
+            for i in range(batch_size):
+                if not finished_sequences[i] and stop_tokens is not None and next_token[i, 0] in stop_tokens:
+                    finished_sequences[i] = True
+
+            # Break if all sequences are finished
+            if finished_sequences.all():
                 break
 
             if streamer:
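Note: a standalone sketch of the batched KL exit test introduced above, run on dummy logits. The shapes and the "auto" threshold mirror the diff, but this is not the model code itself.

import torch
import torch.nn.functional as F

batch_size, V = 2, 16
exit_threshold = 5e-4  # the "auto" default for criterion="kl" in this file

# uniform prior over the (padded) vocabulary, one row per sequence, as in the diff
prev_log_probs = ((1 / V) * torch.ones(batch_size, V)).log()
logits = torch.randn(batch_size, V)  # stand-in for outputs.logits[:, -1, :]

log_probs = F.log_softmax(logits, dim=-1)
exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)

new_exits = exit_values < exit_threshold  # one boolean per sequence, as in the batched loop above
print(exit_values, new_exits)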
@@ -931,6 +999,48 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             stop_tokens.add(token_id)
         return torch.tensor(list(stop_tokens))
 
+    def _sample_next_token(self, next_token_logits, temperature=None, top_k=None, top_p=None, min_p=None):
+        """Helper function to sample the next token."""
+        if temperature:
+            next_token_logits = next_token_logits / temperature
+
+        probs = F.softmax(next_token_logits, dim=-1)
+
+        # Apply top_k
+        if top_k:
+            top_k_values, _ = torch.topk(probs, top_k, dim=-1)
+            min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
+            probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
+
+        # Apply top_p (nucleus sampling)
+        if top_p:
+            sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+            # Create mask for probs to keep
+            remove_indices = cumulative_probs > top_p
+            remove_indices[:, 0] = False  # Keep at least the top probability
+
+            # Convert sorted indices mask back to original indices mask
+            mask = torch.zeros_like(probs, dtype=torch.bool)
+            for i in range(probs.shape[0]):
+                mask[i, sorted_indices[i, remove_indices[i]]] = True
+
+            probs = torch.where(mask, torch.zeros_like(probs), probs)
+
+        # Apply min_p
+        if min_p:
+            max_probs = probs.max(dim=-1, keepdim=True)[0]
+            min_p_threshold = min_p * max_probs
+            probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
+
+        # Renormalize probabilities
+        probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
+
+        # Sample from the distribution
+        next_token = torch.multinomial(probs, num_samples=1)
+        return next_token
+
     def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
         probs = torch.softmax(logits.float(), dim=-1)
         prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
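Note: a toy check of the nucleus (top_p) masking used by the new _sample_next_token helper; the probabilities below are made up, only the masking logic follows the diff.

import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
top_p = 0.8

sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)  # [0.50, 0.80, 0.95, 1.00]
remove_indices = cumulative_probs > top_p              # [False, False, True, True]
remove_indices[:, 0] = False                           # always keep the top token

mask = torch.zeros_like(probs, dtype=torch.bool)
for i in range(probs.shape[0]):
    mask[i, sorted_indices[i, remove_indices[i]]] = True

filtered = torch.where(mask, torch.zeros_like(probs), probs)
print(filtered / filtered.sum(dim=-1, keepdim=True))   # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])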
@@ -985,4 +1095,4 @@ RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 # Old?
 AutoConfig.register("huginn_raven", RavenConfig)
 AutoModel.register(RavenConfig, RavenForCausalLM)
-AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
+AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)