File size: 1,289 Bytes
067283f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#original code from https://github.com/genmoai/models under apache 2.0 license

# Based on Llama3 Implementation.
import torch


def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate interleaved feature pairs of ``xqk`` (rotary embedding, real arithmetic).

    No complex tensors are created. Each consecutive feature pair
    ``(x[..., 2i], x[..., 2i + 1])`` is rotated as a 2-D vector by the angle
    whose cosine and sine are supplied in ``freqs_cos`` / ``freqs_sin``.

    Args:
        xqk: Query and/or key tensor, documented as shape
            (B, S, *, num_heads, D). It may hold only queries, only keys, or
            both stacked along a batch or extra (*) dimension.
        freqs_cos: Precomputed cosines. Must broadcast against
            ``xqk[..., ::2]``.
        freqs_sin: Precomputed sines. Must broadcast against
            ``xqk[..., 1::2]``.

    Returns:
        A tensor cast back to ``xqk``'s type (via ``type_as``). Its layout is
        ``[r0, s0, r1, s1, ...]`` along the last dimension, where ``r``/``s``
        are the rotated first/second components of each pair.
    """
    evens = xqk[..., ::2]
    odds = xqk[..., 1::2]

    # 2-D rotation of every (even, odd) pair. The products may promote to a
    # wider dtype (e.g. float32 freqs against bf16 inputs). Each half is cast
    # back to xqk's type before re-interleaving.
    rotated_first = (evens * freqs_cos - odds * freqs_sin).type_as(xqk)
    rotated_second = (evens * freqs_sin + odds * freqs_cos).type_as(xqk)

    # Put the two halves side by side on a new trailing axis. Then merge that
    # axis into the feature axis so the pairs come out interleaved again.
    paired = torch.stack((rotated_first, rotated_second), dim=-1)
    *lead, half, two = paired.shape
    return paired.reshape(*lead, half * two)