Update modeling_llama.py

modeling_llama.py  (+7 -18)
@@ -166,17 +166,14 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos_q = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin_q = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-
-    cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
-    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
+    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

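The removed version rotated queries and keys with two separate index tensors (position_ids for queries, key_position_ids for keys); the added version indexes the cos/sin cache once with a single position_ids and applies the same rotation to both. Below is a minimal standalone sketch of how the simplified helper behaves with toy shapes; the tensor sizes, the toy cos/sin cache, and every variable value are illustrative assumptions, not code from the surrounding repository.

# Minimal sketch of the simplified rotary application, assuming the usual Llama
# shapes: q/k are [bs, num_heads, seq_len, head_dim] and the cos/sin cache is
# [1, 1, seq_len, head_dim]. All names and values here are illustrative only.
import torch

def rotate_half(x):
    # Rotate half of the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Queries and keys are rotated with the same position_ids after this commit.
    cos = cos.squeeze(1).squeeze(0)       # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)       # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

bs, n_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)

# Toy cos/sin cache shaped like the output of a rotary-embedding module.
t = torch.arange(seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]  # [1, 1, seq_len, dim]

position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 4, 8, 16]) twice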
@@ -285,30 +282,22 @@ class LlamaAttention(nn.Module):

         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-            cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        else:
-            cache_key_states = key_states

         if pack_cos_sin is not None:
             cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
             # reuse k, v, self_attention
-
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)

         past_key_value = (cache_key_states, value_states) if use_cache else None

         use_flashattn = self.config.use_flashattn and is_flash_attn_available()

-        if self.log_scale:
-            log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
-                torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
-            query_states = query_states * log_n


         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
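After this change, rotary embeddings are applied only to the freshly computed key_states, and the cached keys from past_key_value are concatenated afterwards, so the cache is expected to hold keys that were already rotated when they were first produced; the removed path instead concatenated the raw cached keys up front and re-rotated the full key sequence on every step via key_position_ids. The log-n query scaling branch (self.log_scale) is also dropped; it multiplied the queries by log(kv_seq_len)/log(max_position_embeddings), e.g. ln(8192)/ln(4096) ≈ 1.083 for a hypothetical 4096-position model reading 8192 tokens. Note that the unchanged line past_key_value = (cache_key_states, value_states) still references cache_key_states, which nothing in the new hunk assigns, unless it is set somewhere outside the lines shown. The decode-step sketch below follows the new ordering and assumes the intent is to cache the rotated, concatenated key_states; the rotary helpers are repeated so the snippet runs on its own, and all shapes and names are illustrative.

# Decode-step sketch of the new cache handling: rotary is applied to the single
# fresh key, and the already-rotated cached keys are concatenated afterwards.
# Everything here is a toy illustration, not code from this repository.
import math
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, q_len, dim]
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

bs, n_heads, head_dim, past_len = 1, 2, 8, 4
query_states = torch.randn(bs, n_heads, 1, head_dim)  # one new token
key_states   = torch.randn(bs, n_heads, 1, head_dim)
value_states = torch.randn(bs, n_heads, 1, head_dim)
past_key_value = (torch.randn(bs, n_heads, past_len, head_dim),   # cached (rotated) keys
                  torch.randn(bs, n_heads, past_len, head_dim))   # cached values

kv_seq_len = key_states.shape[-2] + past_key_value[0].shape[-2]

# Toy cos/sin cache covering all kv_seq_len positions, shaped [1, 1, seq, dim].
t = torch.arange(kv_seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
emb = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

# Rotate only the new position; cached keys keep the rotation they received
# when they were first computed.
position_ids = torch.tensor([[past_len]])
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# Concatenate after rotation, then store the rotated keys back in the cache
# (assumed intent; the hunk itself still stores cache_key_states).
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
print(attn_weights.shape)  # torch.Size([1, 2, 1, 5])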