from __future__ import annotations

import copy
import inspect
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.flash_attn_interface import (
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_qkvpacked_func,
)
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast

from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
from .configuration_mixformer_sequential import MixFormerSequentialConfig

@dataclass
class InferenceParams:
    """Inference parameters passed to the main model so that the context
    can be efficiently computed and cached during inference.
    Adapted from https://github.com/Dao-AILab/flash-attention."""

    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[torch.Tensor] = None
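
# Illustrative sketch (not part of the model): during generation a single
# `InferenceParams` instance is shared by all layers, and each `MHA` layer
# stores its KV cache under its own `layer_idx` key. For example (shapes are
# assumptions for illustration only):
#
#   params = InferenceParams(max_sequence_len=2048, max_batch_size=1)
#   # after prefilling a 10-token prompt:
#   #   params.sequence_len_offset == 10
#   #   params.key_value_memory_dict[layer_idx].shape
#   #       == (max_batch_size, max_sequence_len, 2, n_head, head_dim)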

class Embedding(nn.Module):
    """Token embedding with dropout."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.wte(input_ids)
        hidden_states = self.drop(hidden_states)

        return hidden_states

class RotaryEmbedding(nn.Module):
    """PyTorch implementation of the `flash-attn` RotaryEmbedding layer.
    Adapted from https://github.com/Dao-AILab/flash-attention."""

    def __init__(
        self,
        dim: int,
        base: Optional[int] = 10000,
        scale_base: Optional[float] = None,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()

        if scale_base is not None:
            raise NotImplementedError

        self.dim = dim
        self.base = base
        self.scale_base = scale_base
        self.device = device

        # Inverse frequencies for each pair of rotary dimensions.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
            / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(
        self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0
    ) -> None:
        seqlen = x.shape[1] + seqlen_offset

        # Re-generate `inv_freq` if the buffer was cast away from fp32
        # (e.g. by a model-wide `.half()` call).
        if self.inv_freq.dtype != torch.float32:
            self.inv_freq = 1.0 / (
                self.base
                ** (
                    torch.arange(
                        0, self.dim, 2, device=self.device, dtype=torch.float32
                    )
                    / self.dim
                )
            )

        # Rebuild the cos/sin tables if the sequence length has grown or the
        # cached tables live on the wrong device or dtype.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=x.device, dtype=torch.float32)

            freqs = torch.outer(
                t, self.inv_freq.to(device=t.device, dtype=torch.float32)
            )
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(
                    power, "s -> s 1"
                )

                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def apply_rotary_emb_qkv(
        self,
        qkv: torch.FloatTensor,
        sin: torch.FloatTensor,
        cos: torch.FloatTensor,
        sin_k: Optional[torch.FloatTensor] = None,
        cos_k: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        _, seqlen, three, _, headdim = qkv.shape
        assert three == 3

        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen

        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert (
            sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        )

        q_rot = qkv[:, :, 0, :, :rotary_dim]
        q_pass = qkv[:, :, 0, :, rotary_dim:]

        k_rot = qkv[:, :, 1, :, :rotary_dim]
        k_pass = qkv[:, :, 1, :, rotary_dim:]

        # Split the rotary dimensions into two halves and broadcast cos/sin
        # over the head dimension.
        q1, q2 = q_rot.chunk(2, dim=-1)
        k1, k2 = k_rot.chunk(2, dim=-1)
        c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
            sin[:seqlen], "s d -> s 1 d"
        )

        # Compute the rotation in fp32 for numerical stability.
        q1, q2, k1, k2, c, s = [
            t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]
        ]

        q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], dim=-1).to(qkv.dtype)
        k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], dim=-1).to(qkv.dtype)

        return torch.cat(
            [
                torch.cat([q_rot, q_pass], dim=-1).unsqueeze(2),
                torch.cat([k_rot, k_pass], dim=-1).unsqueeze(2),
                qkv[:, :, 2:3, :, :],
            ],
            dim=2,
        )

    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> torch.Tensor:
        """Perform the forward pass.

        Args:
            qkv: Query, key and value tensor of shape (batch, seqlen, 3, nheads, headdim).
            seqlen_offset: Used in generation, where the passed `qkv` contains only the last token in the batch.

        Returns:
            `qkv` with rotary embeddings applied to the queries and keys.

        """
        self._update_cos_sin_cache(qkv, seqlen_offset)

        return self.apply_rotary_emb_qkv(
            qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
        )
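
# Illustrative sketch (not part of the model): each half-pair (x1, x2) of the
# rotary dimensions is rotated by a position-dependent angle theta:
#
#   x1' = x1 * cos(theta) - x2 * sin(theta)
#   x2' = x1 * sin(theta) + x2 * cos(theta)
#
# A toy shape check (values assumed for illustration):
#
#   rope = RotaryEmbedding(dim=4)
#   qkv = torch.randn(1, 8, 3, 2, 4)  # (batch, seqlen, 3, nheads, headdim)
#   out = rope(qkv)  # same shape; the value slice qkv[:, :, 2] is untouched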

def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
    Adapted from https://github.com/Dao-AILab/flash-attention."""

    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_sequence_len,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]

    # Write the new key/value states into the pre-allocated cache and return
    # the cache contents up to the current sequence position.
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.sequence_len_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]

    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
    return kv
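
# Illustrative sketch (not part of the model): for a 10-token prompt followed
# by one decoded token (shapes assumed for illustration):
#
#   prefill: kv.shape == (1, 10, 2, H, D), sequence_len_offset == 0
#            -> cache[:1, 0:10] is written; returns cache[:1, :10]
#   decode:  kv.shape == (1, 1, 2, H, D), sequence_len_offset == 10
#            -> cache[:1, 10:11] is written; returns cache[:1, :11]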

class MLP(nn.Module):
    """Multi-Layer Perceptron.

    Reference:
        Attention Is All You Need.
        https://arxiv.org/pdf/1706.03762.pdf.

    """

    def __init__(
        self,
        config: PretrainedConfig,
        n_inner: Optional[int] = None,
        act_fn: Optional[str] = None,
    ) -> None:
        super().__init__()

        act_fn = config.activation_function if act_fn is None else act_fn
        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."

        n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
        n_inner = n_inner if n_inner is not None else 4 * config.n_embd

        self.fc1 = nn.Linear(config.n_embd, n_inner)
        self.fc2 = nn.Linear(n_inner, config.n_embd)
        self.act = ACT2FN[act_fn]

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Remap legacy checkpoint keys (`fc_in`/`fc_out`) to the current
        # layer names (`fc1`/`fc2`).
        old_keys = [
            prefix + "fc_in.weight",
            prefix + "fc_out.weight",
            prefix + "fc_in.bias",
            prefix + "fc_out.bias",
        ]
        new_keys = [
            prefix + "fc1.weight",
            prefix + "fc2.weight",
            prefix + "fc1.bias",
            prefix + "fc2.bias",
        ]

        if all(k in state_dict for k in old_keys) and not all(
            k in state_dict for k in new_keys
        ):
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)

        return hidden_states

class FusedMLP(nn.Module):
    """Fused Multi-Layer Perceptron from `flash-attn`.

    Reference:
        https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.

    """

    def __init__(
        self,
        config: PretrainedConfig,
        n_inner: Optional[int] = None,
        act_fn: Optional[str] = None,
        raise_on_missing: bool = False,
    ) -> None:
        super().__init__()

        act_fn = config.activation_function if act_fn is None else act_fn
        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."

        n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
        n_inner = n_inner if n_inner is not None else 4 * config.n_embd

        # NOTE: `activation` is computed for parity with the upstream fused
        # implementation, but this version falls back to the plain `MLP`.
        gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
        activation = "gelu_approx" if act_fn in gelu_activations else "relu"

        self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        return self.mlp(hidden_states)

class SelfAttention(nn.Module):
    """Implement the scaled dot-product attention with softmax.
    Adapted from https://github.com/Dao-AILab/flash-attention.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention.
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(
        self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: If passed, will override `self.causal`.
            key_padding_mask: Boolean mask to apply to the attention weights. True means to keep,
                              False means to mask out. (B, S)
            cu_seqlens: Cumulative sequence lengths for packed (varlen) batches.
            max_seqlen: Maximum sequence length in a packed batch.
        """
        causal = self.causal if causal is None else causal
        if cu_seqlens is not None:
            # Packed path: sequences are concatenated along the token axis and
            # delimited by `cu_seqlens`.
            return flash_attn_varlen_qkvpacked_func(
                qkv.squeeze(0),
                cu_seqlens,
                max_seqlen,
                dropout_p=self.drop.p,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                dropout_p=self.drop.p,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
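
# Illustrative sketch (not part of the model): with sample packing, three
# sequences of lengths 3, 5, and 2 are concatenated into one row of 10 tokens
# and delimited by cumulative offsets:
#
#   cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
#   max_seqlen = 5
#
# `flash_attn_varlen_qkvpacked_func` then attends independently within each
# [start, end) slice, so tokens never attend across packed-sequence borders.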

class CrossAttention(nn.Module):
    """Implement the scaled dot-product attention with softmax.
    Adapted from https://github.com/Dao-AILab/flash-attention.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention.
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
            causal: If passed, will override `self.causal`.
            key_padding_mask: Boolean mask to apply to the attention weights. True means to keep,
                              False means to mask out. (B, Sk)
        """
        causal = self.causal if causal is None else causal
        return flash_attn_kvpacked_func(
            q,
            kv,
            dropout_p=self.drop.p,
            softmax_scale=self.softmax_scale,
            causal=causal,
        )

def find_mha_dims(
    config: PretrainedConfig,
    n_head: Optional[int] = None,
    head_dim: Optional[int] = None,
) -> Tuple[int, int]:
    """Validate and return the number of heads and head dimension for multi-head attention.

    Args:
        config: Model configuration.
        n_head: Number of heads.
        head_dim: Head dimension.

    Returns:
        Number of heads and head dimension.

    """

    assert all(
        hasattr(config, attr) for attr in ["n_embd", "n_head"]
    ), "`config` must have `n_embd` and `n_head` attributes."

    if head_dim is None:
        assert (
            config.n_embd % config.n_head == 0
        ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."

    if n_head is None and head_dim is None:
        head_dim = config.n_embd // config.n_head
        n_head = config.n_head
    elif n_head is None or head_dim is None:
        raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")

    return n_head, head_dim

class MHA(nn.Module):
    """Multi-head attention layer.
    Adapted from https://github.com/Dao-AILab/flash-attention."""

    def __init__(
        self,
        config: PretrainedConfig,
        rotary_dim: Optional[int] = None,
        n_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        bias: Optional[bool] = True,
        dropout: Optional[float] = 0.0,
        softmax_scale: Optional[float] = None,
        causal: Optional[bool] = True,
        layer_idx: Optional[int] = None,
        rotary_emb_scale_base: Optional[float] = None,
        return_residual: Optional[bool] = False,
        checkpointing: Optional[bool] = False,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        fused_dense: Optional[bool] = True,
        flash_attn: Optional[bool] = True,
        cutlass_attn: Optional[bool] = False,
        flash_rotary: Optional[bool] = True,
        raise_on_missing: Optional[bool] = False,
    ) -> None:
        super().__init__()

        factory_kwargs = {"device": device, "dtype": dtype}
        n_head, head_dim = find_mha_dims(config, n_head, head_dim)

        self.hidden_size = config.n_embd
        self.n_head = n_head
        self.head_dim = head_dim
        self.op_size = n_head * head_dim

        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = (
            rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
        )
        self.fused_dense = fused_dense
        self.flash_attn = flash_attn
        self.cutlass_attn = cutlass_attn
        self.flash_rotary = flash_rotary
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        if self.rotary_emb_dim > 0:
            rotary_kwargs = {"device": device}
            if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                rotary_kwargs["scale_base"] = rotary_emb_scale_base

            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)

        self.Wqkv = nn.Linear(
            self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs
        )
        self.out_proj = nn.Linear(
            self.op_size, self.hidden_size, bias=bias, **factory_kwargs
        )

        self.inner_attn = SelfAttention(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = CrossAttention(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )

    def _update_kv_cache(
        self, kv: torch.FloatTensor, inference_params: InferenceParams
    ) -> None:
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
        Adapted from https://github.com/Dao-AILab/flash-attention."""

        assert (
            self.layer_idx is not None
        ), "Generation requires `layer_idx` in the constructor."

        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def forward(
        self,
        x: torch.FloatTensor,
        x_kv: Optional[torch.FloatTensor] = None,
        key_padding_mask: Optional[torch.BoolTensor] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        max_seqlen: Optional[int] = None,
        mixer_subset: Optional[torch.LongTensor] = None,
        past_cache: Optional[InferenceParams] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Perform the forward pass.

        Args:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num_heads * head_dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim), where
                total is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            key_padding_mask: Boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            mixer_subset: For cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g. ViT, where we only care
                about the CLS token in the last layer.
            past_cache: For generation only.

        Returns:
            (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
            else (total, hidden_dim), where total is the sum of the sequence lengths
            in the batch.

        """

        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.flash_attn

        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.flash_attn

        if past_cache is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None

        attn_kwargs = {"key_padding_mask": key_padding_mask}

        assert x_kv is None and mixer_subset is None

        qkv = self.Wqkv(x)
        qkv = rearrange(
            qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
        )

        if past_cache is None:
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(qkv)
            context = self.inner_attn(
                qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs
            )
        else:
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(
                    qkv, seqlen_offset=past_cache.sequence_len_offset
                )
            q = qkv[:, :, 0]
            kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)

            # During incremental decoding the single-token query attends to the
            # whole cache, so the causal mask is only needed for the prefill step.
            causal = None if past_cache.sequence_len_offset == 0 else False
            context = self.inner_cross_attn(q, kv, causal=causal)

        out = rearrange(context, "... h d -> ... (h d)")
        out = self.out_proj(out)

        return out if not self.return_residual else (out, x)

class ParallelBlock(nn.Module):
    """Parallel block.

    This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).

    """

    def __init__(
        self,
        config: PretrainedConfig,
        mixer: Optional[Dict[str, Any]] = None,
        mlp: Optional[Dict[str, Any]] = None,
        block_idx: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.block_idx = block_idx

        self.mixer = MHA(config, layer_idx=block_idx)
        self.mlp = MLP(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        past_cache: Optional[torch.FloatTensor] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> torch.FloatTensor:
        residual = hidden_states
        hidden_states = self.ln(hidden_states)

        attn_outputs = self.mixer(
            hidden_states,
            past_cache=past_cache,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if isinstance(attn_outputs, tuple):
            attn_outputs = attn_outputs[0]

        attn_outputs = self.resid_dropout(attn_outputs)
        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))

        # Attention and MLP branches share the same LayerNorm input and are
        # summed with the residual (parallel formulation).
        hidden_states = attn_outputs + feed_forward_hidden_states + residual

        return hidden_states

class CausalLMHead(nn.Module):
    """Causal Language Modeling head.

    Reference:
        Improving Language Understanding by Generative Pre-Training.
        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.

    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.linear = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.ln(hidden_states)
        logits = self.linear(hidden_states).to(torch.float32)

        return logits

class CausalLMLoss(nn.Module):
    """Causal Language Modeling loss.

    Reference:
        Improving Language Understanding by Generative Pre-Training.
        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.

    """

    def __init__(self, shift_labels: Optional[bool] = True) -> None:
        super().__init__()

        self.shift_labels = shift_labels
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(
        self, logits: torch.FloatTensor, labels: torch.LongTensor
    ) -> torch.FloatTensor:
        if self.shift_labels:
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return loss
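
# Illustrative sketch (not part of the model): with `shift_labels=True` the
# logits at position t are scored against the token at position t + 1. For
# labels [A, B, C, D] the scored pairs are (logits[0], B), (logits[1], C),
# (logits[2], D).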

class MixFormerSequentialPreTrainedModel(PreTrainedModel):
    """MixFormer (sequential for DeepSpeed) pre-trained model."""

    config_class = MixFormerSequentialConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def __init__(self, *inputs, **kwargs) -> None:
        super().__init__(*inputs, **kwargs)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, **kwargs
    ) -> Dict[str, Any]:
        if "use_cache" in kwargs and not kwargs["use_cache"]:
            return {"input_ids": input_ids}

        if past_key_values is None or not isinstance(
            past_key_values, InferenceParams
        ):
            past_key_values = InferenceParams(
                max_batch_size=input_ids.shape[0],
                max_sequence_len=self.config.n_positions,
                sequence_len_offset=0,
                batch_size_offset=0,
                fused_ft_kernel=False,
                key_value_memory_dict={},
            )
        else:
            # Only the newest token is needed once the cache has been primed.
            past_key_values.sequence_len_offset = len(input_ids[0]) - 1
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
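
# Illustrative sketch (not part of the model): `generate()` calls the method
# above once per decoding step. On the first call a fresh `InferenceParams`
# cache is allocated and the whole prompt is fed; on later calls only the
# newest token is fed, with `sequence_len_offset` marking how many positions
# are already cached.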

class PackedSequential(nn.Sequential):
    """`nn.Sequential` that forwards the packing metadata (`cu_seqlens`,
    `max_seqlen`) to the modules whose forward signature accepts it."""

    def forward(
        self,
        input,
        cu_seqlens: Optional[torch.LongTensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        for module in self:
            sig = inspect.signature(module.forward)
            if "cu_seqlens" in sig.parameters:
                input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
            else:
                input = module(input)
        return input

class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
    """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""

    _keys_to_ignore_on_load_missing = [""]
    _keys_to_ignore_on_load_unexpected = [
        r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
    ]
    _no_split_modules = ["ParallelBlock"]

    def __init__(self, config: MixFormerSequentialConfig) -> None:
        super().__init__(config)

        modules = [Embedding(config)]
        block_config = config.architecture

        if not isinstance(block_config, list):
            block_config = [block_config for _ in range(config.n_layer)]

        if config.n_layer != len(block_config):
            config.n_layer = len(block_config)

        for block_idx, block in enumerate(block_config):
            # `path` and `block_cls` are kept for backward compatibility only;
            # both are dropped before the remaining keys are passed as kwargs.
            block = copy.deepcopy(block) or {"block_cls": "parallel"}
            block.pop("path", None)
            block.pop("block_cls", None)

            block["block_idx"] = block_idx
            modules.append(ParallelBlock(config, **block))

        modules.append(CausalLMHead(config))

        self.layers = PackedSequential(*modules)
        self.loss = CausalLMLoss()

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.layers[0].wte

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        self.layers[0].wte = new_embeddings

    def get_output_embeddings(self) -> nn.Linear:
        return self.layers[-1].linear

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.layers[-1].linear = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        cu_seqlens: Optional[torch.LongTensor] = None
        max_seqlen: Optional[int] = None
        if position_ids is not None:
            # Packed samples: recover per-sequence boundaries from position ids
            # that restart at zero for every packed sequence.
            batch_size, seq_length = input_ids.shape
            position_ids = position_ids.view(-1, seq_length).long()
            cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
            cu_seqlens = cu_seqlens.squeeze()

        if not past_key_values:
            lm_logits = self.layers(
                input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
            )
        else:
            hidden_layer = self.layers[0](input_ids)
            for module in self.layers[1:-1]:
                hidden_layer = module(
                    hidden_layer,
                    past_cache=past_key_values,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                )
            lm_logits = self.layers[-1](hidden_layer)

        loss = None
        if labels is not None:
            loss = self.loss(lm_logits, labels)

        return CausalLMOutputWithPast(
            loss=loss, logits=lm_logits, past_key_values=past_key_values
        )
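
# Illustrative usage sketch (not part of the model; assumes flash-attn and a
# CUDA device, and that `MixFormerSequentialConfig` defaults provide the
# fields referenced above):
#
#   config = MixFormerSequentialConfig()
#   model = MixFormerSequentialForCausalLM(config).half().cuda()
#   input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
#   # position_ids restarting at 0 mark packed-sequence boundaries:
#   position_ids = torch.arange(16, device="cuda").unsqueeze(0)
#   out = model(input_ids, labels=input_ids, position_ids=position_ids)
#   print(out.loss, out.logits.shape)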