Upstream improvements from seal_rg repo

raven_modeling_minimal.py (CHANGED: +198, -88)

@@ -25,6 +25,7 @@ class RavenPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["SandwichBlock"]
     _skip_keys_device_placement = ["past_key_values"]
+    _tied_weights_keys = ["lm_head.weight"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True

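The new `_tied_weights_keys` entry tells the Hugging Face loading machinery that `lm_head.weight` is shared with the input embedding when `tie_embeddings` is set. A minimal standalone sketch of what that tying amounts to (generic modules, not the actual Raven classes):

```python
import torch

# Sketch only: the output head reuses the input embedding matrix,
# so a single tensor backs both modules.
vocab, dim = 100, 16
wte = torch.nn.Embedding(vocab, dim)
lm_head = torch.nn.Linear(dim, vocab, bias=False)

lm_head.weight = wte.weight  # what tie_weights() effectively does

assert lm_head.weight.data_ptr() == wte.weight.data_ptr()  # one shared storage
```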
@@ -63,7 +64,7 @@ class RMSNorm(torch.nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

     def forward(self, x):
-        with torch.autocast(enabled=False, device_type=x.device.type):
+        with torch.autocast(enabled=False, device_type=x.device.type if x.device.type != "meta" else "cuda"):
             return self._norm(x.float()).type_as(x) * self.weight

     def reset_parameters(self) -> None:

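The autocast guard substitutes "cuda" for the device type when the input is a meta tensor, since "meta" is not a usable autocast device type. A sketch of the pattern with a standalone RMSNorm-like module, assuming PyTorch 2.x where `torch.device("meta")` works as a context manager:

```python
import torch

class TinyRMSNorm(torch.nn.Module):
    """Standalone stand-in for the RMSNorm in this file, for illustration only."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Same guard as the diff: pick a real device_type when x lives on "meta".
        device_type = x.device.type if x.device.type != "meta" else "cuda"
        with torch.autocast(enabled=False, device_type=device_type):
            norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
            return norm.type_as(x) * self.weight

with torch.device("meta"):  # shape/FLOP tracing without allocating real memory
    y = TinyRMSNorm(64)(torch.randn(2, 8, 64))
print(y.shape, y.is_meta)  # torch.Size([2, 8, 64]) True
```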
@@ -342,10 +343,16 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         # Head
         self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
         if self.config.tie_embeddings:
+            self.tie_weights()
         # rope
         self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

+    def get_input_embeddings(self):
+        return self.transformer.wte
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
     def _precompute_freqs_cis(self):
         # can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
         freqs_cis = precompute_freqs_cis(

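`get_input_embeddings` and `get_output_embeddings` are the standard `PreTrainedModel` accessors that utilities such as weight tying and `resize_token_embeddings` rely on. Illustrative usage only; the checkpoint name below is a placeholder, not taken from this repository:

```python
from transformers import AutoModelForCausalLM

# Placeholder checkpoint name; substitute the real one.
model = AutoModelForCausalLM.from_pretrained("org/raven-checkpoint", trust_remote_code=True)

wte = model.get_input_embeddings()    # -> model.transformer.wte
head = model.get_output_embeddings()  # -> model.lm_head
if model.config.tie_embeddings:
    assert head.weight.data_ptr() == wte.weight.data_ptr()
```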
@@ -461,7 +468,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
             num_steps_no_grad, num_steps_with_grad = num_steps
         else:
-            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
+            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0

         with torch.no_grad():
             # ultra annoying in ddp due to

@@ -594,6 +601,10 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         """Outputs are long tensors so that they can be passed through compiled functions"""
         t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
         s = self.config.mean_backprop_depth
+        if torch.rand((1,)).is_meta:  # annoying clause to make meta-tensor-based flop counting work
+            # these values are only the mean TFLOPs of the randomized sampler
+            # Note that this clause also breaks the contract, and returns ints in meta tensor mode
+            return t, s  # type: ignore
         if self.training:
             sigma = 0.5
             mu = math.log(t + s) - (sigma**2 / 2)

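The added clause detects meta-tensor execution (where `torch.rand` returns a meta tensor) and skips the randomized depth sampler, returning the mean depths as plain ints. A small sketch of the detection trick; the numeric defaults and the Poisson stand-in are illustrative, not the model's actual log-normal sampler:

```python
import torch

def sample_depths(mean_recurrence: int = 32, mean_backprop_depth: int = 8):
    t = max(mean_recurrence - mean_backprop_depth, 0)
    s = mean_backprop_depth
    if torch.rand((1,)).is_meta:  # true under a meta-device tracing context
        return t, s  # plain ints, mirroring the clause in the diff
    n = torch.poisson(torch.tensor(float(t)))  # stand-in for the real log-normal sampler
    return n.to(torch.long), torch.as_tensor(s)

with torch.device("meta"):
    print(sample_depths())  # (24, 8): deterministic means, no randomness on meta
```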
@@ -649,11 +660,10 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):

     @torch.no_grad()
     def generate(self, *args, **kwargs):
+        """Dispatcher - use HF generate in all normal cases.
+        If BOTH `criterion` AND `exit_threshold` are provided as not None, we use adaptive compute.
+        """
+        if kwargs.get("criterion", None) is not None and kwargs.get("exit_threshold", None) is not None:
             print("Dispatching to custom generate function call")
             return self.generate_with_adaptive_compute(*args, **kwargs)
         else:

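With this dispatcher, adaptive compute is opt-in through kwargs: both `criterion` and `exit_threshold` must be non-None, otherwise the call falls through to the regular Hugging Face generation loop. Illustrative call, assuming `model` and `input_ids` already exist:

```python
# Adaptive-compute path: both kwargs are set.
out = model.generate(
    input_ids,
    max_length=256,
    criterion="latent-diff",  # also: "entropy-diff", "kl", "minp-kl", "argmax-stability", "none"
    exit_threshold="auto",    # or an explicit numeric threshold
)

# Regular HF generate: at least one of the two kwargs is missing.
out = model.generate(input_ids, max_length=256)
```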
@@ -757,7 +767,10 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         cache_kwargs: dict = {},
         **model_kwargs,
     ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
+        """
+        Generate tokens with adaptive compute. This is NOT the most efficient implementation.
+        For batches, on each token, we iterate until the entire batch finishes.
+        """
         # Setup
         if generation_config is None:
             generation_config: GenerationConfig = self.generation_config  # type: ignore

@@ -765,10 +778,16 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
         model_kwargs["use_cache"] = True
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
         stop_tokens = self._get_stops(generation_config, tokenizer).to(input_ids.device)
+        batch_size = input_ids.shape[0]
+        compute_steps = []
+
+        # Set up continuous compute if enabled
         if continuous_compute:
             embedded_inputs, _, _ = self.embed_inputs(input_ids)
+            current_last_latents = self.initialize_state(embedded_inputs)
+
+        # Track which sequences have finished
+        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

         # Generate tokens
         for step in range(generation_config.max_length - input_ids.shape[1]):

@@ -781,130 +800,179 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             if not continuous_compute:
                 current_latents = self.initialize_state(embedded_inputs, deterministic=False)
             else:
+                current_latents = current_last_latents
+
+            # Initialize criterion tracking for each sequence in batch
+            exit_values_per_seq = [[] for _ in range(batch_size)]
+            compute_steps_per_seq = [0] * batch_size
+            exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

+            # Set up criterions based on selected strategy
             if criterion == "entropy-diff":
+                entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
                 exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
             elif criterion in ["latent-diff", "none"]:
                 exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
             elif "kl" in criterion:
                 V = self.config.padded_vocab_size
-                log_probs = (1 / V * torch.ones(V, device=input_ids.device)).log()
+                log_probs = ((1 / V) * torch.ones(batch_size, V, device=input_ids.device)).log()
                 if criterion == "minp-kl":
                     exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
                 else:
                     exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
             elif criterion == "argmax-stability":
+                stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
+                current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
                 exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
             else:
                 raise ValueError("Invalid adaptive compute strategy.")

             all_latents = []
+            next_token_logits = None
+
+            # Iterate through compute steps
             for compute_step in range(model_inputs["num_steps"]):
                 prev_latents = current_latents.clone()
                 current_latents, block_idx, _ = self.iterate_one_step(
                     embedded_inputs, current_latents, block_idx=block_idx, **aux_inputs
                 )
+
+                if latent_dampening:
+                    all_latents.append(current_latents)
+
                 if step > 0:  # do not exit in prefill:
+                    # Check exit condition for each sequence in batch
                     if criterion == "entropy-diff":
+                        prev_entropy = entropy
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
+                        logits: torch.Tensor = outputs.logits  # type: ignore
+                        probs = F.softmax(logits[:, -1, :], dim=-1)
+                        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
+                        exit_values = (entropy - prev_entropy).abs()
+
                     elif criterion == "latent-diff":
-                        norm_diff = (prev_latents - current_latents).norm() / current_latents.norm()
+                        norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
+                        exit_values = norm_diff.mean(dim=-1)
+
-                    elif criterion == "kl":
-                        prev_log_probs = log_probs.clone()
-                        outputs = self.predict_from_latents(current_latents, **aux_inputs)
-                        log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1)  # type: ignore
-                        kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
-                        exit_values.append(kl.item())
-                        if kl < exit_threshold:
-                            break
-                    elif criterion == "minp-kl":
-                        prev_log_probs = log_probs.clone()
+                    elif "kl" in criterion:
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
+                        logits: torch.Tensor = outputs.logits  # type: ignore
+                        prev_log_probs = log_probs
+                        if criterion == "minp-kl":
+                            probs = F.softmax(logits[:, -1, :], dim=-1)
+                            max_probs = probs.max(dim=-1, keepdim=True)[0]
+                            probs_mask = probs < (0.1 * max_probs)
+                            masked_probs = probs
+                            masked_probs[probs_mask] = 1 / V
+                            probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+                            log_probs = probs.log()
+                        else:
+                            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
+                        exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
+
                     elif criterion == "argmax-stability":
+                        prev_argmax = current_argmax
                         outputs = self.predict_from_latents(current_latents, **aux_inputs)
+                        logits: torch.Tensor = outputs.logits  # type: ignore
+                        current_argmax = logits[:, -1, :].argmax(dim=-1)
+                        stable_for_n_steps = torch.where(
+                            current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
+                        )
+                        exit_values = stable_for_n_steps
+
+                    # Record values and check exits for each sequence
+                    for i in range(batch_size):
+                        if not exit_reached[i] and not finished_sequences[i]:
+                            exit_values_per_seq[i].append(exit_values[i].item())
+
+                    new_exits = (
+                        exit_values < exit_threshold
+                        if criterion != "argmax-stability"
+                        else exit_values >= exit_threshold
+                    )
+                    new_exits = new_exits & ~exit_reached & ~finished_sequences
+
+                    if new_exits.any():
+                        exit_reached = exit_reached | new_exits
+                        if criterion == "latent-diff":
+                            # Normally we don't compute the output for latent-diff, but when there is an exit,
+                            # we need to compute and save the output
+                            outputs = self.predict_from_latents(current_latents, **aux_inputs)
+                            logits: torch.Tensor = outputs.logits  # type: ignore
+                        if next_token_logits is None:
+                            next_token_logits = logits[:, -1, :].clone()
                         else:
+                            next_token_logits = torch.where(
+                                new_exits.unsqueeze(1).expand_as(logits[:, -1, :]), logits[:, -1, :], next_token_logits
+                            )
+                        for i in range(batch_size):
+                            if new_exits[i]:
+                                compute_steps_per_seq[i] = compute_step + 1
+
+                    # If all sequences have exited, break early
+                    if (exit_reached | finished_sequences).all():
+                        break
+            # This else is if the for loop finished without breaking
             else:
                 if not latent_dampening:
                     outputs = self.predict_from_latents(current_latents, **aux_inputs)
                 else:
                     dampened_latents = torch.sum(torch.cat(all_latents, dim=0), dim=0, keepdim=True)
                     outputs = self.predict_from_latents(dampened_latents, **aux_inputs)
-            compute_steps.append([compute_step + 1, exit_values])

+            # For sequences that didn't exit early, use the final logits
+            if next_token_logits is None:
+                next_token_logits = outputs.logits[:, -1, :]  # type: ignore
+            else:
+                # Only update logits for sequences that didn't exit early
+                non_exit_mask = ~exit_reached & ~finished_sequences
+                next_token_logits = torch.where(
+                    non_exit_mask.unsqueeze(1).expand_as(next_token_logits),
+                    outputs.logits[:, -1, :],  # type: ignore
+                    next_token_logits,
+                )
+
+                # Record compute steps for non-exited sequences
+                for i in range(batch_size):
+                    if non_exit_mask[i]:
+                        compute_steps_per_seq[i] = model_inputs["num_steps"]
+
+            # Save latent states for continuous compute if enabled
+            if continuous_compute:
+                current_last_latents = current_latents[:, -1:, :]

+            # Record compute steps for this token generation
+            compute_steps.append([compute_steps_per_seq, exit_values_per_seq])

+            # Sample or select next token based on generation config
+            if generation_config.do_sample:
+                next_token = self._sample_next_token(
+                    next_token_logits,
+                    generation_config.temperature,
+                    generation_config.top_k,
+                    generation_config.top_p,
+                    generation_config.min_p,
+                )
             else:
+                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # type: ignore

+            input_ids = torch.cat([input_ids, next_token], dim=-1)  # type: ignore

             if streamer:
                 streamer.put(next_token.cpu())

             # Update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
+            if continuous_compute:
+                model_kwargs["input_states"] = current_last_latents

+            # Check for finished sequences
+            for i in range(batch_size):
+                if not finished_sequences[i] and stop_tokens is not None and next_token[i, 0] in stop_tokens:
+                    finished_sequences[i] = True
+
+            # Break if all sequences are finished
+            if finished_sequences.all():
                 break

         if streamer:

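For the KL criteria, the per-sequence exit value is the divergence between the last-token distribution from the previous recurrence step and the current one, and the sequence exits once it drops below the threshold. A small numeric sketch of that test in isolation:

```python
import torch
import torch.nn.functional as F

# Last-token logits for one sequence at two consecutive recurrence steps.
prev_logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
curr_logits = torch.tensor([[2.1, 1.0, 0.1, -1.0]])  # distribution has almost stopped moving

prev_log_probs = F.log_softmax(prev_logits, dim=-1)
log_probs = F.log_softmax(curr_logits, dim=-1)

# Same call as in the loop above: KL(previous || current), one value per sequence.
kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
print(kl)                        # roughly 1e-3
print(bool((kl < 5e-4).item()))  # False here; the sequence keeps iterating until this holds
```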
@@ -931,6 +999,48 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
                 stop_tokens.add(token_id)
         return torch.tensor(list(stop_tokens))

+    def _sample_next_token(self, next_token_logits, temperature=None, top_k=None, top_p=None, min_p=None):
+        """Helper function to sample the next token."""
+        if temperature:
+            next_token_logits = next_token_logits / temperature
+
+        probs = F.softmax(next_token_logits, dim=-1)
+
+        # Apply top_k
+        if top_k:
+            top_k_values, _ = torch.topk(probs, top_k, dim=-1)
+            min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
+            probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
+
+        # Apply top_p (nucleus sampling)
+        if top_p:
+            sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+            # Create mask for probs to keep
+            remove_indices = cumulative_probs > top_p
+            remove_indices[:, 0] = False  # Keep at least the top probability
+
+            # Convert sorted indices mask back to original indices mask
+            mask = torch.zeros_like(probs, dtype=torch.bool)
+            for i in range(probs.shape[0]):
+                mask[i, sorted_indices[i, remove_indices[i]]] = True
+
+            probs = torch.where(mask, torch.zeros_like(probs), probs)
+
+        # Apply min_p
+        if min_p:
+            max_probs = probs.max(dim=-1, keepdim=True)[0]
+            min_p_threshold = min_p * max_probs
+            probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
+
+        # Renormalize probabilities
+        probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
+
+        # Sample from the distribution
+        next_token = torch.multinomial(probs, num_samples=1)
+        return next_token
+
     def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
         probs = torch.softmax(logits.float(), dim=-1)
         prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)

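The helper filters the probability distribution (temperature, then top-k, top-p, min-p), renormalizes, and samples with `torch.multinomial`. A toy walk-through of the top_p step on a four-token distribution:

```python
import torch

probs = torch.tensor([[0.45, 0.30, 0.15, 0.10]])
top_p = 0.8

sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)  # [0.45, 0.75, 0.90, 1.00]
remove_indices = cumulative_probs > top_p              # [False, False, True, True]
remove_indices[:, 0] = False                           # always keep the top token

mask = torch.zeros_like(probs, dtype=torch.bool)
for i in range(probs.shape[0]):
    mask[i, sorted_indices[i, remove_indices[i]]] = True

probs = torch.where(mask, torch.zeros_like(probs), probs)
probs = probs / probs.sum(dim=-1, keepdim=True)
print(probs)                                    # tensor([[0.6000, 0.4000, 0.0000, 0.0000]])
print(torch.multinomial(probs, num_samples=1))  # samples token 0 or 1 only
```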
@@ -985,4 +1095,4 @@ RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 # Old?
 AutoConfig.register("huginn_raven", RavenConfig)
 AutoModel.register(RavenConfig, RavenForCausalLM)
-AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
+AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)

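With the registrations above in place, the model resolves through the standard Auto classes once this module has been imported. A short sketch (builds randomly initialized weights, for illustration only):

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model("huginn_raven")     # -> RavenConfig
model = AutoModelForCausalLM.from_config(config)  # -> RavenForCausalLM, random weights
print(type(model).__name__)
```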