Guanzheng committed
Commit 39684af · 1 Parent(s): a572663

Update modeling_llama.py

Files changed (1):
  1. modeling_llama.py +77 -52

modeling_llama.py CHANGED
@@ -30,8 +30,8 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from configuration_clex import CLEXLlamaConfig
-from clex_layer import LlamaCLEXScalingRotaryEmbedding
+from .configuration_clex import CLEXLlamaConfig
+from .clex_layer import LlamaCLEXScalingRotaryEmbedding
 from einops import rearrange
 import importlib.metadata
 import importlib.util
@@ -60,14 +60,10 @@ def is_flash_attn_available():
     return False
 
     # Let's add an extra check to see if cuda is available
-    import torch
 
     return _is_package_available("flash_attn") and torch.cuda.is_available()
 
-if is_flash_attn_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
-    # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-    from flash_attn.bert_padding import unpad_input, pad_input
+
 
 
 
@@ -170,14 +166,17 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, q_len, position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
-    q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :])
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    cos_q = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+    sin_q = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+
+    cos_k = cos[key_position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+    sin_k = sin[key_position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
+    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
     return q_embed, k_embed
 
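Illustrative sketch, not part of the commit: the reworked apply_rotary_pos_emb rotates queries and keys with their own position ids, so during decoding a single new query can be rotated at its absolute offset while the cached keys keep positions 0..kv_len-1. The toy rotary table, sizes, and names below are assumptions made only for this example.

import torch

def rotate_half(x):
    # Same helper as in the file: rotate half of the last dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
    # Mirrors the updated signature: queries and keys are indexed with their
    # own position ids instead of sharing one slice of the cos/sin table.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos_q = cos[position_ids].unsqueeze(1)      # [bs, 1, q_len, dim]
    sin_q = sin[position_ids].unsqueeze(1)
    cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, kv_len, dim]
    sin_k = sin[key_position_ids].unsqueeze(1)
    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
    return q_embed, k_embed

# Toy decode step (sizes assumed): one new query position against a 5-token key cache.
bs, n_heads, head_dim, kv_len = 1, 2, 8, 5
q = torch.randn(bs, n_heads, 1, head_dim)
k = torch.randn(bs, n_heads, kv_len, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(kv_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)                  # [kv_len, head_dim]
cos, sin = emb.cos()[None, None], emb.sin()[None, None]  # [1, 1, kv_len, head_dim]
q_pos = torch.tensor([[kv_len - 1]])                     # the query sits at the last position
k_pos = torch.arange(kv_len).unsqueeze(0)                # keys cover the whole cache
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, q_pos, k_pos)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 1, 8]) torch.Size([1, 2, 5, 8])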
 
@@ -232,7 +231,10 @@ class LlamaAttention(nn.Module):
 
         attention_mask: [bsz, q_len]
         """
-
+        if is_flash_attn_available():
+            from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
+            # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+            from flash_attn.bert_padding import unpad_input, pad_input
         bsz, q_len, *_ = qkv.size()
 
         if key_padding_mask is None:
@@ -283,63 +285,86 @@ class LlamaAttention(nn.Module):
 
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        else:
+            cache_key_states = key_states
 
         if pack_cos_sin is not None:
             cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, q_len, key_position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
+            # key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        past_key_value = (key_states, value_states) if use_cache else None
+        past_key_value = (cache_key_states, value_states) if use_cache else None
 
-        use_flashatn = self.config.use_flashattn and is_flash_attn_available()
+        use_flashattn = self.config.use_flashattn and is_flash_attn_available()
 
         if self.log_scale:
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
+        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] and not use_flashattn:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
                 )
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
 
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
 
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
 
-        return attn_output, attn_weights, past_key_value
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+            attn_output = self.o_proj(attn_output)
+
+            if not output_attentions:
+                attn_weights = None
+
+            return attn_output, attn_weights, past_key_value
+        # use flash attention
+        elif past_key_value is not None:
+            from flash_attn.flash_attn_interface import flash_attn_with_kvcache
+            output = flash_attn_with_kvcache(
+                query_states.transpose(1, 2),
+                key_states.transpose(1, 2),
+                value_states.transpose(1, 2),
+                cache_seqlens=kv_seq_len,
+                causal=True,
+            )
+            attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
+        else:
+            qkv = torch.stack(
+                [query_states, key_states, value_states], dim=2
+            ) # [bsz, nh, 3, q_len, hd]
+            qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+            attn_output = self.flash_attn_forward(qkv)
+        return attn_output, None, past_key_value
 
 
 class LlamaDecoderLayer(nn.Module):
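For orientation, a minimal shape sketch under assumed toy sizes (flash-attn itself is not imported here): the rewritten forward chooses between an eager matmul/softmax branch, a flash_attn_with_kvcache decode branch, and a packed-qkv branch routed through flash_attn_forward. The sketch shows the [bsz, nh, q_len, kv_len] score shape the eager branch validates, the packed [bsz, q_len, 3, nh, hd] layout, and the log-n query scaling factor.

import math
import torch

# All sizes assumed purely for illustration.
bsz, n_heads, head_dim, q_len = 1, 4, 16, 8
kv_seq_len = q_len  # prefill case: no key/value cache yet

query_states = torch.randn(bsz, n_heads, q_len, head_dim)
key_states = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
value_states = torch.randn(bsz, n_heads, kv_seq_len, head_dim)

# Log-n query scaling from the `self.log_scale` branch: the factor is
# ln(kv_seq_len) / ln(max_position_embeddings), e.g. ln(8192) / ln(4096) = 13/12 ≈ 1.083
# once the context grows past the training length.
log_n = math.log(8192.0) / math.log(4096.0)

# Eager branch: attention scores have shape [bsz, num_heads, q_len, kv_seq_len].
attn_weights = torch.matmul(query_states * log_n, key_states.transpose(2, 3)) / math.sqrt(head_dim)
print(attn_weights.shape)  # torch.Size([1, 4, 8, 8])

# Packed layout built in the final `else:` branch before calling flash_attn_forward.
qkv = torch.stack([query_states, key_states, value_states], dim=2)  # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3)                                           # [bsz, q_len, 3, nh, hd]
print(qkv.shape)  # torch.Size([1, 8, 3, 4, 16])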
@@ -629,14 +654,14 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-        # if attention_mask is None:
-        #     attention_mask = torch.ones(
-        #         (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
-        #     )
-        # attention_mask = self._prepare_decoder_attention_mask(
-        #     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        # )
-        attention_mask = None
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+        # attention_mask = None
 
 
         hidden_states = inputs_embeds
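Also illustrative only, with assumed sizes: with the mask preparation re-enabled, the eager attention branch above receives an additive [bsz, 1, q_len, kv_seq_len] mask instead of None. A hand-rolled stand-in for the causal part of such a mask (not the transformers helper itself) looks like this:

import torch

# Assumed toy sizes: 4 new queries attending over a 3-token past plus themselves.
bsz, q_len, past_len = 2, 4, 3
kv_seq_len = past_len + q_len
min_val = torch.finfo(torch.float32).min

# 0 where a query may attend, dtype-min where it may not: query i (absolute
# position past_len + i) sees keys 0 .. past_len + i.
causal = torch.triu(torch.full((q_len, kv_seq_len), min_val), diagonal=past_len + 1)
attention_mask = causal[None, None].expand(bsz, 1, q_len, kv_seq_len)
print(attention_mask.shape)  # torch.Size([2, 1, 4, 7])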
 