Guanzheng committed
Commit 5d8ca76 · 1 Parent(s): 75ec811

Update modeling_llama.py

Files changed (1)
  1. modeling_llama.py +28 -47
modeling_llama.py CHANGED
@@ -311,60 +311,41 @@ class LlamaAttention(nn.Module):
             query_states = query_states * log_n
 
 
-        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] and not use_flashattn:
-            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                    f" {attn_weights.size()}"
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
-            if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    raise ValueError(
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
 
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.matmul(attn_weights, value_states)
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
 
-            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-                raise ValueError(
-                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                    f" {attn_output.size()}"
-                )
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-            attn_output = self.o_proj(attn_output)
-
-            if not output_attentions:
-                attn_weights = None
-
-            return attn_output, attn_weights, past_key_value
-        # use flash attention
-        elif past_key_value is not None:
-            from flash_attn.flash_attn_interface import flash_attn_with_kvcache
-            output = flash_attn_with_kvcache(
-                query_states.transpose(1, 2),
-                key_states.transpose(1, 2),
-                value_states.transpose(1, 2),
-                cache_seqlens=kv_seq_len,
-                causal=True,
-            )
-            attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
-        else:
-            qkv = torch.stack(
-                [query_states, key_states, value_states], dim=2
-            )  # [bsz, nh, 3, q_len, hd]
-            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-            attn_output = self.flash_attn_forward(qkv)
-        return attn_output, None, past_key_value
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
 
 
 class LlamaDecoderLayer(nn.Module):
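
For reference, below is a minimal, self-contained sketch of the eager attention path that the updated hunk keeps (the `+` lines), run on toy tensors. The dimension names (bsz, num_heads, q_len, kv_seq_len, head_dim) follow the diff, but the randomly initialized tensors and the hand-built causal mask are illustrative assumptions, not part of the module itself.

# Sketch of the eager attention path retained by this commit (toy shapes, not the actual module).
import math

import torch
import torch.nn as nn

bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 4, 8, 8, 16
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

# Illustrative causal mask: 0 on visible positions, a large negative value on future ones.
attention_mask = torch.triu(
    torch.full((q_len, kv_seq_len), torch.finfo(torch.float32).min), diagonal=1
)[None, None, :, :]

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# Upcast to fp32 for the softmax, then cast back, as in the diff.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)  # [bsz, num_heads, q_len, head_dim]
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([1, 8, 64])

One consequence of dropping the flash-attention branches in favor of this single path is that the layer always materializes the full [bsz, num_heads, q_len, kv_seq_len] score matrix, so attention memory grows quadratically with sequence length.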