Update modeling_bailing_moe.py
modeling_bailing_moe.py  CHANGED  (+127 -22)
@@ -207,6 +207,90 @@ class BailingMoeDynamicNTKScalingRotaryEmbedding(BailingMoeRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
+# Inverse dim formula to find dim based on number of rotations
+def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+
+# Find dim range bounds based on rotations
+def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
+    low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+class BailingMoeYarnRotaryEmbedding(BailingMoeRotaryEmbedding):
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.mscale_all_dim = mscale_all_dim
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        dim = self.dim
+
+        freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+        freq_inter = 1.0 / (
+            self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.original_max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
+        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+        freqs = torch.outer(t, inv_freq)
+
+        _mscale = float(
+            yarn_get_mscale(self.scaling_factor, self.mscale)
+            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
+        )
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
+        self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
+
+
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
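The hunk above adds the usual YaRN (NTK-by-parts) helpers and a BailingMoeYarnRotaryEmbedding that blends the original and position-interpolated RoPE frequencies. For orientation, here is a minimal standalone sketch of that blend; the dim, base, scaling_factor and beta values are illustrative assumptions for this sketch, not the checkpoint's configuration.

import math
import torch

# Illustrative values only (assumptions, not the model config).
dim, base, scaling_factor = 128, 10000.0, 4.0
original_max_pos, beta_fast, beta_slow = 4096, 32, 1

def correction_dim(num_rotations):
    # same formula as yarn_find_correction_dim in the hunk above
    return (dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

low = max(math.floor(correction_dim(beta_fast)), 0)
high = min(math.ceil(correction_dim(beta_slow)), dim - 1)

# linear ramp over the half dimension, as in yarn_linear_ramp_mask
ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)

freq_extra = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # original RoPE frequencies
freq_inter = freq_extra / scaling_factor                                           # position-interpolated frequencies

# ramp == 0 (fast-rotating dims) keeps the original frequencies,
# ramp == 1 (slow-rotating dims) uses the interpolated ones, with a blend in between.
inv_freq = freq_inter * ramp + freq_extra * (1 - ramp)

# the cos/sin caches are additionally multiplied by the attention-scale correction
mscale = (0.1 * 1 * math.log(scaling_factor) + 1.0) / 1.0  # yarn_get_mscale(4, 1) / yarn_get_mscale(4, 0)
print(low, high, round(mscale, 3), inv_freq[:4])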
@@ -278,7 +362,7 @@ class BailingMoeGate(nn.Module):
 
         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, sort=False):
         bsz, seq_len, h = hidden_states.shape
         # compute gating score
         hidden_states = hidden_states.view(-1, h)
@@ -286,7 +370,7 @@ class BailingMoeGate(nn.Module):
         scores = logits.softmax(dim=-1, dtype=torch.float32)
 
         # select top-k experts
-        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=
+        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=sort)
 
         # norm gate to sum 1
         if self.top_k > 1 and self.norm_topk_prob:
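BailingMoeGate.forward now takes a sort flag (default False) that is passed straight through to torch.topk as sorted=sort. A tiny standalone illustration of what that flag controls, with made-up gating scores; the renormalization line is roughly what the "# norm gate to sum 1" branch does when norm_topk_prob is enabled.

import torch

scores = torch.tensor([[0.10, 0.40, 0.25, 0.25]])  # made-up scores for one token over 4 experts

# sorted=False (the new default) lets topk skip the final descending sort of the k results.
topk_weight, topk_idx = torch.topk(scores, k=2, dim=-1, sorted=False)

# renormalize the selected weights so they sum to 1
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
print(topk_idx, topk_weight)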
@@ -305,7 +389,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
         super().__init__()
         self.config = config
         self.num_experts_per_tok = config.num_experts_per_tok
-        self.
+        self._setup_experts()
         self.gate = BailingMoeGate(config)
         if config.num_shared_experts is not None:
             self.shared_experts = BailingMoeMLP(
@@ -313,7 +397,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
             )
 
     def _setup_experts(self):
-
+        self.experts = nn.ModuleList(
             [
                 BailingMoeMLP(config=self.config, intermediate_size=self.config.moe_intermediate_size)
                 for _ in range(self.config.num_experts)
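Expert construction is now factored into _setup_experts(), called from __init__, and simply builds an nn.ModuleList of per-expert MLPs. A generic sketch of the pattern, using a stand-in ToyMLP and made-up sizes rather than the real BailingMoeMLP and config:

import torch.nn as nn

class ToyMLP(nn.Module):  # stand-in for BailingMoeMLP, for illustration only
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.down(self.act(self.up(x)))

num_experts, hidden_size, intermediate_size = 8, 64, 128  # illustrative sizes
experts = nn.ModuleList(ToyMLP(hidden_size, intermediate_size) for _ in range(num_experts))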
@@ -443,6 +527,25 @@ class BailingMoeAttention(nn.Module):
                     scaling_factor=scaling_factor,
                     base=self.rope_theta,
                 )
+            elif scaling_type == "yarn":
+                kwargs = {
+                    key: self.config.rope_scaling[key]
+                    for key in [
+                        "original_max_position_embeddings",
+                        "beta_fast",
+                        "beta_slow",
+                        "mscale",
+                        "mscale_all_dim",
+                    ]
+                    if key in self.config.rope_scaling
+                }
+                self.rotary_emb = BailingMoeYarnRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                    **kwargs,
+                )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
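With this branch, the attention module recognizes a "yarn" RoPE scaling type and forwards whichever of the listed optional keys are present in config.rope_scaling to BailingMoeYarnRotaryEmbedding. For reference, a hypothetical rope_scaling entry that would exercise the branch; the values are illustrative, and the "type"/"factor" key names (read elsewhere in the file, outside this diff, as scaling_type and scaling_factor) are assumed to follow the usual transformers convention.

rope_scaling = {
    "type": "yarn",                             # selects the new elif branch above (assumed key name)
    "factor": 4.0,                              # used as scaling_factor (assumed key name)
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 1,
    "mscale_all_dim": 0,
}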
@@ -1258,6 +1361,24 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
     def get_decoder(self):
         return self.model
 
+    def compute_logit(self, hidden_states):
+        if self.norm_head:
+            if self.training:
+                norm_weight = (
+                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
+                )
+                logits = F.linear(hidden_states, norm_weight, None)
+            else:
+                self.lm_head.weight.data = (
+                    self.lm_head.weight.data.float()
+                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
+                ).to(hidden_states.dtype)
+                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
+                self.norm_head = False
+        else:
+            logits = self.lm_head(hidden_states)
+        return logits
+
     @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
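compute_logit() is the former inline logit path from forward(), extracted as a method (see the next hunk). When norm_head is set, it divides lm_head.weight by its L2 norm taken over dim=0 (the vocabulary axis) plus a small epsilon: during training the norm is detached so gradients flow through the unnormalized weight, while at inference the normalization is folded into the stored weight once and norm_head is cleared. A minimal standalone sketch of the same idea, with made-up sizes and a plain nn.Linear standing in for the model's head:

import torch
import torch.nn.functional as F
from torch import nn

vocab_size, hidden_size = 32, 16                 # made-up sizes
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(2, 5, hidden_size)   # (batch, seq, hidden)

# training path: normalize on the fly, detaching the norm so gradients
# flow through the unnormalized weight
norm = torch.norm(lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7   # shape (1, hidden)
logits_train = F.linear(hidden_states, lm_head.weight / norm.detach())

# inference path: fold the normalization into the stored weight once,
# after which the plain projection can be used
with torch.no_grad():
    w = lm_head.weight.float()
    lm_head.weight.data = (w / (torch.norm(w, p=2, dim=0, keepdim=True) + 1e-7)).to(lm_head.weight.dtype)
logits_eval = lm_head(hidden_states)
print(logits_train.shape, logits_eval.shape)     # both (2, 5, 32)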
@@ -1325,22 +1446,7 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
 
         hidden_states = outputs[0]
 
-        if self.norm_head:
-            if self.training:
-                norm_weight = (
-                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
-                )
-                logits = F.linear(hidden_states, norm_weight, None)
-            else:
-                self.lm_head.weight.data = (
-                    self.lm_head.weight.data.float()
-                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
-                ).to(hidden_states.dtype)
-                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
-                self.norm_head = False
-        else:
-            logits = self.lm_head(hidden_states)
-
+        logits = self.compute_logit(hidden_states=hidden_states)
         logits = logits.float()
 
         loss = None
@@ -1392,8 +1498,7 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
-            # input)
+            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard