Update modeling_llama.py
modeling_llama.py  CHANGED  +29 -46
@@ -304,59 +304,42 @@ class LlamaAttention(nn.Module):
         log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
             torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
         query_states = query_states * log_n
-                raise ValueError(
-                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                    f" {attn_weights.size()}"
-                )
-                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-            attn_output = torch.matmul(attn_weights, value_states)
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-            if not output_attentions:
-                attn_weights = None
-
-            return attn_output, attn_weights, past_key_value
-        # use flash attention
-        elif past_key_value is not None:
-            output = flash_attn_with_kvcache(
-                query_states.transpose(1, 2),
-                key_states.transpose(1, 2),
-                value_states.transpose(1, 2),
-                cache_seqlens=kv_seq_len,
-                causal=True,
-            )
-            attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
-        else:
-            qkv = torch.stack(
-                [query_states, key_states, value_states], dim=2
-            )  # [bsz, nh, 3, q_len, hd]
-            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-            attn_output = self.flash_attn_forward(qkv)
-        return attn_output, None, past_key_value
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value


 class LlamaDecoderLayer(nn.Module):
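For context, a minimal self-contained sketch of the eager attention path this commit switches to, including the log-n query scaling from the unchanged lines at the top of the hunk. The function name, the dummy tensor shapes, and max_position_embeddings=4096 are illustrative assumptions; the real LlamaAttention additionally applies rotary embeddings, repeats KV heads, manages the KV cache, and projects the merged heads through self.o_proj, all of which are omitted here.

import math
import torch

def eager_attention(query_states, key_states, value_states, attention_mask=None,
                    max_position_embeddings=4096):
    # query/key/value states: [bsz, num_heads, seq_len, head_dim]
    bsz, num_heads, q_len, head_dim = query_states.shape
    kv_seq_len = key_states.shape[-2]

    # log-n query scaling, as in the context lines of the hunk: the factor is
    # log(kv_seq_len) / log(max_position_embeddings), so it grows logarithmically
    # with the length of the key/value sequence.
    log_n = math.log(kv_seq_len) / math.log(max_position_embeddings)
    query_states = query_states * log_n

    # scaled dot-product scores: [bsz, num_heads, q_len, kv_seq_len]
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    if attention_mask is not None:
        # additive mask: 0 for visible positions, a large negative value for masked ones
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

    # softmax in fp32 for numerical stability, then cast back to the query dtype
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

    # weighted sum of values, then merge heads: [bsz, q_len, num_heads * head_dim]
    attn_output = torch.matmul(attn_weights, value_states)
    return attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)

if __name__ == "__main__":
    bsz, num_heads, q_len, head_dim = 1, 2, 8, 16
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_heads, q_len, head_dim)
    v = torch.randn(bsz, num_heads, q_len, head_dim)
    # causal mask in the additive form expected above
    mask = torch.full((q_len, q_len), torch.finfo(torch.float32).min).triu(1)[None, None]
    print(eager_attention(q, k, v, mask).shape)  # torch.Size([1, 8, 32])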