Guanzheng committed
Commit bf0e3a8 · 1 Parent(s): f6c39bc

Update modeling_llama.py

Files changed (1):
  1. modeling_llama.py +29 -46
modeling_llama.py CHANGED
@@ -304,59 +304,42 @@ class LlamaAttention(nn.Module):
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
-        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] or use_flashatn:
-            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                    f" {attn_weights.size()}"
-                )
-
-            if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    raise ValueError(
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.matmul(attn_weights, value_states)
-
-            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-                raise ValueError(
-                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                    f" {attn_output.size()}"
-                )
-
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-            attn_output = self.o_proj(attn_output)
-
-            if not output_attentions:
-                attn_weights = None
-
-            return attn_output, attn_weights, past_key_value
-        # use flash attention
-        elif past_key_value is not None:
-            output = flash_attn_with_kvcache(
-                query_states.transpose(1, 2),
-                key_states.transpose(1, 2),
-                value_states.transpose(1, 2),
-                cache_seqlens=kv_seq_len,
-                causal=True,
-            )
-            attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
-        else:
-            qkv = torch.stack(
-                [query_states, key_states, value_states], dim=2
-            )  # [bsz, nh, 3, q_len, hd]
-            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-            attn_output = self.flash_attn_forward(qkv)
-            return attn_output, None, past_key_value
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
 
 
 class LlamaDecoderLayer(nn.Module):
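
For context, the unchanged lines at the top of the hunk apply log-n scaling to the query states before attention. Below is a minimal sketch of that factor in isolation; the 4096/16384 values and the toy tensor shape are illustrative assumptions, not values taken from this repository.

    import torch

    # Assumed, illustrative values: a model trained with max_position_embeddings = 4096
    # that is now attending over a 16384-token KV cache.
    max_position_embeddings = 4096
    kv_seq_len = 16384

    # Same factor as in the context lines of the hunk:
    # log(kv_seq_len) / log(max_position_embeddings), which exceeds 1 only once the
    # cache grows past the training window.
    log_n = torch.log(torch.tensor(kv_seq_len * 1.0)) / torch.log(torch.tensor(max_position_embeddings * 1.0))
    print(float(log_n))  # ~1.167

    # The hunk multiplies the query states by this factor before the QK^T matmul,
    # a scheme commonly motivated as keeping attention entropy roughly stable beyond the training length.
    query_states = torch.randn(1, 32, 1, 128)  # (bsz, num_heads, q_len, head_dim), toy shape
    query_states = query_states * log_n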
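
The added lines replace the flash-attention branches with a single eager attention path. Here is a self-contained sketch of that path under toy shapes (the shapes are assumptions; the operations mirror the added lines of the diff):

    import math
    import torch
    import torch.nn as nn

    # Toy shapes, assumed for illustration: batch 1, 4 heads, 5 queries over an 8-entry KV cache.
    bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 4, 5, 8, 16
    query_states = torch.randn(bsz, num_heads, q_len, head_dim)
    key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
    value_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

    # Scaled dot-product scores, as in the added lines.
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    assert attn_weights.size() == (bsz, num_heads, q_len, kv_seq_len)

    # Softmax is upcast to fp32 and cast back to the query dtype, then applied to the values.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    assert attn_output.size() == (bsz, num_heads, q_len, head_dim)

    # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads * head_dim), ready for o_proj.
    attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)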