Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update modeling_llama.py
Browse files- modeling_llama.py +7 -18
    	
        modeling_llama.py
    CHANGED
    
    | @@ -166,17 +166,14 @@ def rotate_half(x): | |
| 166 | 
             
                return torch.cat((-x2, x1), dim=-1)
         | 
| 167 |  | 
| 168 |  | 
| 169 | 
            -
            def apply_rotary_pos_emb(q, k, cos, sin,  position_ids | 
| 170 | 
             
                # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
         | 
| 171 | 
             
                cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
         | 
| 172 | 
             
                sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
         | 
| 173 | 
            -
                 | 
| 174 | 
            -
                 | 
| 175 | 
            -
             | 
| 176 | 
            -
                 | 
| 177 | 
            -
                sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
         | 
| 178 | 
            -
                q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
         | 
| 179 | 
            -
                k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
         | 
| 180 | 
             
                return q_embed, k_embed
         | 
| 181 |  | 
| 182 |  | 
| @@ -285,30 +282,22 @@ class LlamaAttention(nn.Module): | |
| 285 |  | 
| 286 | 
             
                    if past_key_value is not None:
         | 
| 287 | 
             
                        kv_seq_len += past_key_value[0].shape[-2]
         | 
| 288 | 
            -
                        cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
         | 
| 289 | 
            -
                    else:
         | 
| 290 | 
            -
                        cache_key_states = key_states
         | 
| 291 |  | 
| 292 | 
             
                    if pack_cos_sin is not None:
         | 
| 293 | 
             
                        cos, sin = pack_cos_sin.to(query_states.device)
         | 
| 294 | 
             
                    else:
         | 
| 295 | 
             
                        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         | 
| 296 | 
            -
                     | 
| 297 | 
            -
                    query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
         | 
| 298 |  | 
| 299 | 
             
                    if past_key_value is not None:
         | 
| 300 | 
             
                        # reuse k, v, self_attention
         | 
| 301 | 
            -
                         | 
| 302 | 
             
                        value_states = torch.cat([past_key_value[1], value_states], dim=2)
         | 
| 303 |  | 
| 304 | 
             
                    past_key_value = (cache_key_states, value_states) if use_cache else None
         | 
| 305 |  | 
| 306 | 
             
                    use_flashattn =  self.config.use_flashattn and is_flash_attn_available()
         | 
| 307 |  | 
| 308 | 
            -
                    if self.log_scale:
         | 
| 309 | 
            -
                        log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
         | 
| 310 | 
            -
                                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
         | 
| 311 | 
            -
                        query_states = query_states * log_n
         | 
| 312 |  | 
| 313 |  | 
| 314 | 
             
                    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
         | 
|  | |
| 166 | 
             
                return torch.cat((-x2, x1), dim=-1)
         | 
| 167 |  | 
| 168 |  | 
| 169 | 
            +
            def apply_rotary_pos_emb(q, k, cos, sin,  position_ids):
         | 
| 170 | 
             
                # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
         | 
| 171 | 
             
                cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
         | 
| 172 | 
             
                sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
         | 
| 173 | 
            +
                cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
         | 
| 174 | 
            +
                sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
         | 
| 175 | 
            +
                q_embed = (q * cos) + (rotate_half(q) * sin)
         | 
| 176 | 
            +
                k_embed = (k * cos) + (rotate_half(k) * sin)
         | 
|  | |
|  | |
|  | |
| 177 | 
             
                return q_embed, k_embed
         | 
| 178 |  | 
| 179 |  | 
|  | |
| 282 |  | 
| 283 | 
             
                    if past_key_value is not None:
         | 
| 284 | 
             
                        kv_seq_len += past_key_value[0].shape[-2]
         | 
|  | |
|  | |
|  | |
| 285 |  | 
| 286 | 
             
                    if pack_cos_sin is not None:
         | 
| 287 | 
             
                        cos, sin = pack_cos_sin.to(query_states.device)
         | 
| 288 | 
             
                    else:
         | 
| 289 | 
             
                        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         | 
| 290 | 
            +
                    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         | 
|  | |
| 291 |  | 
| 292 | 
             
                    if past_key_value is not None:
         | 
| 293 | 
             
                        # reuse k, v, self_attention
         | 
| 294 | 
            +
                        key_states = torch.cat([past_key_value[0], key_states], dim=2)
         | 
| 295 | 
             
                        value_states = torch.cat([past_key_value[1], value_states], dim=2)
         | 
| 296 |  | 
| 297 | 
             
                    past_key_value = (cache_key_states, value_states) if use_cache else None
         | 
| 298 |  | 
| 299 | 
             
                    use_flashattn =  self.config.use_flashattn and is_flash_attn_available()
         | 
| 300 |  | 
|  | |
|  | |
|  | |
|  | |
| 301 |  | 
| 302 |  | 
| 303 | 
             
                    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
         | 
