Update modeling_llama.py
modeling_llama.py  CHANGED  (+77 -52)
@@ -30,8 +30,8 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from configuration_clex import CLEXLlamaConfig
-from clex_layer import LlamaCLEXScalingRotaryEmbedding
+from .configuration_clex import CLEXLlamaConfig
+from .clex_layer import LlamaCLEXScalingRotaryEmbedding
 from einops import rearrange
 import importlib.metadata
 import importlib.util
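The two local imports become package-relative, which is what the `transformers` dynamic-module loader expects when the modeling code ships inside a Hub repo. A minimal usage sketch, assuming the files are fetched via `trust_remote_code`; the repo id below is a placeholder, not taken from this commit:

```python
# Assumes the CLEX modeling/configuration files live in a Hub repo and are loaded
# as remote code; "your-org/your-clex-model" is a placeholder repo id.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-clex-model"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
```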
@@ -60,14 +60,10 @@ def is_flash_attn_available():
         return False
 
     # Let's add an extra check to see if cuda is available
-    import torch
 
     return _is_package_available("flash_attn") and torch.cuda.is_available()
 
-
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
-# from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+
 
 
 
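With the module-level flash-attn imports removed, importing `modeling_llama.py` no longer requires flash-attn to be installed; the kernels are only touched when `is_flash_attn_available()` reports both the package and a CUDA device. A standalone sketch of that gate, assuming nothing beyond the standard library and torch (`flash_attn_usable` is an illustrative name, not the function defined in this file):

```python
# Minimal sketch of the availability gate: flash-attn is only imported when the
# package is installed and a CUDA device is visible.
import importlib.util

import torch


def flash_attn_usable() -> bool:
    return importlib.util.find_spec("flash_attn") is not None and torch.cuda.is_available()


if flash_attn_usable():
    # Safe to pull in the kernels now; mirrors the deferred imports added further below.
    from flash_attn.bert_padding import pad_input, unpad_input  # noqa: F401
```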
@@ -170,14 +166,17 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    cos_q = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin_q = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+
+    cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
+    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
     return q_embed, k_embed
 
 
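The new signature decouples query and key positions: queries are rotated at `position_ids`, while the (possibly cache-extended) keys are rotated at `key_position_ids`. A self-contained sketch of one cached decoding step; the two helpers are restated from this diff so the snippet runs on its own, and the tensor sizes are illustrative assumptions:

```python
import torch


def rotate_half(x):
    # Split the last dimension in half and swap the halves (same convention as the file).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
    # Mirrors the new version in this diff: queries and keys use separate position ids.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos_q = cos[position_ids].unsqueeze(1)
    sin_q = sin[position_ids].unsqueeze(1)
    cos_k = cos[key_position_ids].unsqueeze(1)
    sin_k = sin[key_position_ids].unsqueeze(1)
    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
    return q_embed, k_embed


# One cached decoding step with illustrative sizes (assumed, not taken from the diff):
bsz, n_heads, head_dim, kv_seq_len = 1, 4, 8, 17
q = torch.randn(bsz, n_heads, 1, head_dim)           # only the newest query token
k = torch.randn(bsz, n_heads, kv_seq_len, head_dim)  # cached keys plus the new one
cos = torch.randn(1, 1, kv_seq_len, head_dim)        # stand-in for the rotary cache
sin = torch.randn(1, 1, kv_seq_len, head_dim)

position_ids = torch.tensor([[kv_seq_len - 1]])           # absolute position of the new query
key_position_ids = torch.arange(kv_seq_len).unsqueeze(0)  # positions 0 .. kv_seq_len-1

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```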
@@ -232,7 +231,10 @@ class LlamaAttention(nn.Module):
 
             attention_mask: [bsz, q_len]
         """
-
+        if is_flash_attn_available():
+            from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
+            # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+            from flash_attn.bert_padding import unpad_input, pad_input
         bsz, q_len, *_ = qkv.size()
 
         if key_padding_mask is None:
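`flash_attn_forward` now pulls in its flash-attn symbols lazily, behind `is_flash_attn_available()`. The function consumes packed QKV; a rough sketch of the `[bsz, seqlen, 3, n_heads, head_dim]` layout that `flash_attn_qkvpacked_func` expects in flash-attn 2.x, where the shapes, dtype, and guard are illustrative assumptions and a CUDA device plus the flash-attn package are required for it to run:

```python
# Illustrative only: the packed QKV layout consumed by flash_attn_qkvpacked_func.
# Guarded so it is a no-op without flash-attn and CUDA.
import importlib.util

import torch

if importlib.util.find_spec("flash_attn") is not None and torch.cuda.is_available():
    from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func

    bsz, seqlen, n_heads, head_dim = 2, 128, 8, 64
    qkv = torch.randn(bsz, seqlen, 3, n_heads, head_dim, device="cuda", dtype=torch.float16)
    out = flash_attn_qkvpacked_func(qkv, causal=True)  # [bsz, seqlen, n_heads, head_dim]
```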
@@ -283,63 +285,86 @@
 
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-
+            cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        else:
+            cache_key_states = key_states
 
         if pack_cos_sin is not None:
             cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states,
+        query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
+            # key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        past_key_value = (key_states, value_states) if use_cache else None
+        past_key_value = (cache_key_states, value_states) if use_cache else None
 
-
+        use_flashattn = self.config.use_flashattn and is_flash_attn_available()
 
         if self.log_scale:
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
+        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] and not use_flashattn:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
                 )
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
 
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
 
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
 
-        return attn_output, attn_weights, past_key_value
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+            attn_output = self.o_proj(attn_output)
+
+            if not output_attentions:
+                attn_weights = None
+
+            return attn_output, attn_weights, past_key_value
+        # use flash attention
+        elif past_key_value is not None:
+            from flash_attn.flash_attn_interface import flash_attn_with_kvcache
+            output = flash_attn_with_kvcache(
+                query_states.transpose(1, 2),
+                key_states.transpose(1, 2),
+                value_states.transpose(1, 2),
+                cache_seqlens=kv_seq_len,
+                causal=True,
+            )
+            attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
+        else:
+            qkv = torch.stack(
+                [query_states, key_states, value_states], dim=2
+            )  # [bsz, nh, 3, q_len, hd]
+            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+            attn_output = self.flash_attn_forward(qkv)
+        return attn_output, None, past_key_value
 
 
 class LlamaDecoderLayer(nn.Module):
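In short: the key cache now stores the un-rotated, concatenated keys (`cache_key_states`), rotary embeddings are re-applied over the whole key sequence at every step with separate query and key position ids, and when `log_scale` is enabled the query is additionally scaled by log(kv_seq_len) / log(max_position_embeddings). Attention then dispatches between an eager path and two flash-attention paths; note that `and` binds tighter than `or` in the new branch condition, so a single-token query always takes the eager path. A minimal sketch of the branch selection, where the function name and string labels are illustrative rather than the module's API:

```python
# Minimal sketch of the new dispatch logic in the rewritten attention forward.
def pick_attention_path(q_len: int, kv_len: int, use_flashattn: bool, has_cache: bool) -> str:
    if q_len == 1 or (q_len != kv_len and not use_flashattn):
        return "eager"          # matmul scores + mask + fp32 softmax, as in the hunk above
    elif has_cache:
        return "flash_kvcache"  # flash_attn_with_kvcache on transposed q/k/v
    else:
        return "flash_packed"   # stack q/k/v to [bsz, q_len, 3, nh, hd], call flash_attn_forward


assert pick_attention_path(q_len=1, kv_len=33, use_flashattn=True, has_cache=True) == "eager"
assert pick_attention_path(q_len=512, kv_len=512, use_flashattn=True, has_cache=False) == "flash_packed"
```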
@@ -629,14 +654,14 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-
-
-
-
-
-
-
-        attention_mask = None
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+        # attention_mask = None
 
 
         hidden_states = inputs_embeds
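`LlamaModel.forward` stops short-circuiting the mask to `None`: it builds the default all-ones padding mask over the past-plus-current length and runs it through `_prepare_decoder_attention_mask`, producing the additive `[bsz, 1, q_len, kv_len]` mask the eager attention path above checks for. A rough, self-contained stand-in for that kind of combined causal-plus-padding mask; this is a simplified sketch, not the actual helper:

```python
import torch


def make_additive_causal_mask(padding_mask: torch.Tensor, past_len: int, dtype=torch.float32):
    # padding_mask: [bsz, kv_len] bools, True = attend.  Returns [bsz, 1, q_len, kv_len]
    # with 0 where attention is allowed and a large negative value where it is not.
    bsz, kv_len = padding_mask.shape
    q_len = kv_len - past_len
    min_val = torch.finfo(dtype).min

    key_pos = torch.arange(kv_len)
    query_pos = torch.arange(past_len, kv_len).unsqueeze(-1)
    causal = torch.full((q_len, kv_len), min_val)
    causal.masked_fill_(key_pos <= query_pos, 0.0)          # query i sees keys 0 .. past_len + i

    padding = torch.zeros(bsz, kv_len).masked_fill(~padding_mask, min_val)

    mask = causal.unsqueeze(0) + padding[:, None, :]        # [bsz, q_len, kv_len]
    return mask.clamp(min=min_val).unsqueeze(1).to(dtype)   # [bsz, 1, q_len, kv_len]


mask = make_additive_causal_mask(torch.ones(2, 7, dtype=torch.bool), past_len=3)
print(mask.shape)  # torch.Size([2, 1, 4, 7])
```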