Update modeling_llama.py
modeling_llama.py  +77 -52
CHANGED
@@ -30,8 +30,8 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from configuration_clex import CLEXLlamaConfig
-from clex_layer import LlamaCLEXScalingRotaryEmbedding
+from .configuration_clex import CLEXLlamaConfig
+from .clex_layer import LlamaCLEXScalingRotaryEmbedding
 from einops import rearrange
 import importlib.metadata
 import importlib.util
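Note: switching to package-relative imports (`from .configuration_clex import ...`) is what lets Python resolve the sibling configuration_clex.py and clex_layer.py files when this module is loaded as remote code rather than run from its own directory. A minimal usage sketch; the repo id below is a placeholder, not taken from this commit:

from transformers import AutoConfig, AutoModelForCausalLM

# "ORG/clex-llama-checkpoint" is illustrative; substitute the repo that ships
# modeling_llama.py, configuration_clex.py and clex_layer.py together.
config = AutoConfig.from_pretrained("ORG/clex-llama-checkpoint", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "ORG/clex-llama-checkpoint",
    config=config,
    trust_remote_code=True,  # allows the custom CLEXLlamaConfig / CLEX layers to be imported
)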
@@ -60,14 +60,10 @@ def is_flash_attn_available():
         return False
 
     # Let's add an extra check to see if cuda is available
-    import torch
 
     return _is_package_available("flash_attn") and torch.cuda.is_available()
 
-
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
-# from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+
 
 
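Note: for reference, a self-contained sketch of the availability check this hunk converges on: flash-attn is only reported as usable when the package is installed and a CUDA device is present, so CPU-only hosts fall back to the eager attention path. The helper below mirrors the names in the diff, but its body is an assumption where the original lines were not captured:

import importlib.util

import torch

def _is_package_available(name: str) -> bool:
    # The package counts as available if an import spec can be found for it.
    return importlib.util.find_spec(name) is not None

def is_flash_attn_available() -> bool:
    # flash-attn kernels run only on CUDA, so a missing GPU means "not available"
    # and callers should use the plain matmul/softmax attention instead.
    return _is_package_available("flash_attn") and torch.cuda.is_available()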
@@ -170,14 +166,17 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin,
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-
-
-
-
+    cos_q = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin_q = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+
+    cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
+    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
     return q_embed, k_embed
 
 
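Note: the rewritten apply_rotary_pos_emb gathers cos/sin separately for queries and keys, which is what allows a single new query token (position_ids of length 1) to be rotated against the positions of every cached key. A standalone shape sketch with dummy tensors; sizes are illustrative, not from the diff:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bs, n_heads, q_len, kv_len, head_dim = 2, 4, 1, 8, 16
cos = torch.randn(1, 1, kv_len, head_dim)  # rotary cache, [1, 1, seq_len, dim]
sin = torch.randn(1, 1, kv_len, head_dim)
q = torch.randn(bs, n_heads, q_len, head_dim)
k = torch.randn(bs, n_heads, kv_len, head_dim)
position_ids = torch.full((bs, q_len), kv_len - 1)    # position of the new query token
key_position_ids = torch.arange(kv_len).unsqueeze(0)  # positions of all cached keys

cos, sin = cos.squeeze(1).squeeze(0), sin.squeeze(1).squeeze(0)                         # [seq_len, dim]
cos_q, sin_q = cos[position_ids].unsqueeze(1), sin[position_ids].unsqueeze(1)           # [bs, 1, q_len, dim]
cos_k, sin_k = cos[key_position_ids].unsqueeze(1), sin[key_position_ids].unsqueeze(1)   # [1, 1, kv_len, dim]
q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
assert q_embed.shape == q.shape and k_embed.shape == k.shape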
@@ -232,7 +231,10 @@ class LlamaAttention(nn.Module):
 
         attention_mask: [bsz, q_len]
         """
-
+        if is_flash_attn_available():
+            from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
+            # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+            from flash_attn.bert_padding import unpad_input, pad_input
         bsz, q_len, *_ = qkv.size()
 
         if key_padding_mask is None:
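Note: pulling the flash-attn imports inside the method (guarded by is_flash_attn_available()) means that merely importing modeling_llama.py no longer requires flash-attn to be installed; the dependency is only touched when that code path actually runs. A generic sketch of the same lazy-import pattern; the wrapper name is illustrative, not from this file:

def _packed_flash_attention(qkv):
    # Imported lazily so CPU-only environments can still import the module
    # and use the eager attention path.
    try:
        from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
    except ImportError as exc:
        raise RuntimeError("flash-attn must be installed for this code path") from exc
    # qkv: [bsz, seqlen, 3, n_heads, head_dim], fp16/bf16 tensors on a CUDA device
    return flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)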
@@ -283,63 +285,86 @@ class LlamaAttention(nn.Module):
 
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-
+            cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        else:
+            cache_key_states = key_states
 
         if pack_cos_sin is not None:
             cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states,
+        query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
+            # key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        past_key_value = (
+        past_key_value = (cache_key_states, value_states) if use_cache else None
 
-
+        use_flashattn = self.config.use_flashattn and is_flash_attn_available()
 
         if self.log_scale:
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
+        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] and not use_flashattn:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
                 )
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
 
-
-
-
-
-
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
-
-
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
 
-
-
-
-
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
 
-
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+            attn_output = self.o_proj(attn_output)
+
+            if not output_attentions:
+                attn_weights = None
+
+            return attn_output, attn_weights, past_key_value
+        # use flash attention
+        elif past_key_value is not None:
+            from flash_attn.flash_attn_interface import flash_attn_with_kvcache
+            output = flash_attn_with_kvcache(
+                query_states.transpose(1, 2),
+                key_states.transpose(1, 2),
+                value_states.transpose(1, 2),
+                cache_seqlens=kv_seq_len,
+                causal=True,
+            )
+            attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
+        else:
+            qkv = torch.stack(
+                [query_states, key_states, value_states], dim=2
+            )  # [bsz, nh, 3, q_len, hd]
+            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+            attn_output = self.flash_attn_forward(qkv)
+        return attn_output, None, past_key_value
 
 
 class LlamaDecoderLayer(nn.Module):
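Note: the rewritten forward now dispatches three ways: a plain matmul/softmax path (used for single-token queries or when flash-attn is unavailable), flash_attn_with_kvcache when decoding against an existing key/value cache, and the packed-qkv flash kernel otherwise. One detail worth spelling out in the new guard on query_states.shape[-2]: Python's `or` binds more loosely than `and`, so the condition is equivalent to the parenthesized form in this small sketch (not code from the file):

def takes_eager_path(q_len: int, kv_len: int, use_flashattn: bool) -> bool:
    # Same parse as:
    #   query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] and not use_flashattn
    return q_len == 1 or (q_len != kv_len and not use_flashattn)

assert takes_eager_path(1, 9, True)      # single-token decode always uses the eager path
assert not takes_eager_path(9, 9, True)  # equal query/key lengths with flash-attn enabled fall through to the flash branches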
@@ -629,14 +654,14 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-
-
-
-
-
-
-
-        attention_mask = None
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+        # attention_mask = None
 
 
         hidden_states = inputs_embeds
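Note: the model now builds and expands a real attention mask again instead of forcing attention_mask = None. For intuition, a simplified stand-in for what the prepared decoder mask looks like; this is not the implementation of _prepare_decoder_attention_mask, just an assumed illustration of its output shape and semantics:

import torch

def build_additive_causal_mask(attention_mask: torch.Tensor, q_len: int, past_len: int, dtype=torch.float32):
    # attention_mask: [bsz, past_len + q_len], 1 for real tokens, 0 for padding.
    bsz, src_len = attention_mask.shape
    min_val = torch.finfo(dtype).min
    # Causal part: query position i may look at cached positions and at j <= i.
    mask = torch.full((q_len, src_len), min_val, dtype=dtype)
    mask = torch.triu(mask, diagonal=past_len + 1)
    mask = mask[None, None, :, :].expand(bsz, 1, q_len, src_len)
    # Padding part: block columns that correspond to padded tokens.
    mask = mask.masked_fill(attention_mask[:, None, None, :] == 0, min_val)
    return mask  # [bsz, 1, q_len, past_len + q_len], added to the attention scores

mask = build_additive_causal_mask(torch.ones(2, 6), q_len=4, past_len=2)
assert mask.shape == (2, 1, 4, 6)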