Update modeling_llama.py
modeling_llama.py  CHANGED  +28 -47
@@ -311,60 +311,41 @@ class LlamaAttention(nn.Module):

Old version (lines 311-370):

          query_states = query_states * log_n


-        …
-            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                    f" {attn_weights.size()}"
                 )

-            if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    raise ValueError(
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.matmul(attn_weights, value_states)

-            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-                raise ValueError(
-                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                    f" {attn_output.size()}"
-                )

-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-            attn_output = self.o_proj(attn_output)
-
-            if not output_attentions:
-                attn_weights = None
-
-            return attn_output, attn_weights, past_key_value
-        # use flash attention
-        elif past_key_value is not None:
-            from flash_attn.flash_attn_interface import flash_attn_with_kvcache
-            output = flash_attn_with_kvcache(
-                query_states.transpose(1, 2),
-                key_states.transpose(1, 2),
-                value_states.transpose(1, 2),
-                cache_seqlens=kv_seq_len,
-                causal=True,
-            )
-            attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
-        else:
-            qkv = torch.stack(
-                [query_states, key_states, value_states], dim=2
-            )  # [bsz, nh, 3, q_len, hd]
-            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-            attn_output = self.flash_attn_forward(qkv)
-        return attn_output, None, past_key_value


 class LlamaDecoderLayer(nn.Module):
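The removed hunk above dropped the two FlashAttention branches that used to follow the eager computation: a decode branch that called flash_attn_with_kvcache whenever past_key_value was present, and a fallback that stacked the projections into a [bsz, q_len, 3, nh, hd] tensor for self.flash_attn_forward, a repo helper whose body is not part of this diff. The eager block itself sat one indentation level deeper, under a guard whose text is not preserved in this rendering (the "…" line above). As a shape reference for the deleted decode branch, here is a minimal standalone sketch; the call mirrors the removed lines, while the function name, its parameters, and the shape comments are illustrative additions rather than code from this repository.

from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_with_kvcache


def flash_decode_sketch(query_states, key_states, value_states, kv_seq_len, o_proj):
    # Projections arrive as [bsz, num_heads, seq_len, head_dim]; flash_attn_with_kvcache
    # expects [bsz, seq_len, num_heads, head_dim], hence the transposes.
    output = flash_attn_with_kvcache(
        query_states.transpose(1, 2),
        key_states.transpose(1, 2),
        value_states.transpose(1, 2),
        cache_seqlens=kv_seq_len,
        causal=True,
    )
    # output: [bsz, seq_len, num_heads, head_dim] -> merge heads before the output projection.
    return o_proj(rearrange(output, "b s h d -> b s (h d)"))

The other removed branch presumably wrapped flash_attn's qkv-packed kernel, given the [bsz, q_len, 3, nh, hd] layout noted in its comments.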
New version (lines 311-351):

          query_states = query_states * log_n


+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)

+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )

+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value


 class LlamaDecoderLayer(nn.Module):
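After this change LlamaAttention.forward always takes the eager path of the new version: scaled dot-product scores, an additive attention mask clamped at the dtype minimum, a softmax upcast to fp32 and cast back, and a head merge ahead of o_proj; the query states are still pre-scaled by log_n just before this block (context line 311). The standalone sketch below restates that path outside the class so the shapes are explicit; the function name, its argument list, and the shape comments are illustrative additions, not part of modeling_llama.py.

import math

import torch
import torch.nn as nn


def eager_attention_sketch(query_states, key_states, value_states, attention_mask, o_proj):
    # query_states: [bsz, num_heads, q_len, head_dim]
    # key_states / value_states: [bsz, num_heads, kv_seq_len, head_dim]
    # attention_mask: additive mask of shape [bsz, 1, q_len, kv_seq_len], or None
    bsz, num_heads, q_len, head_dim = query_states.shape

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
        # Clamp so rows that are fully masked stay at a finite minimum before the softmax.
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

    # Upcast to fp32 for a numerically stable softmax, then return to the working dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)  # [bsz, num_heads, q_len, head_dim]

    # Merge heads: [bsz, num_heads, q_len, head_dim] -> [bsz, q_len, num_heads * head_dim].
    attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
    return o_proj(attn_output), attn_weights

The same computation could be dispatched through torch.nn.functional.scaled_dot_product_attention, but that call does not expose the attention weights, and keeping them materialized is what lets the new code honor output_attentions.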