|
|
|
|
|
|
|
|
|
from torch import nn |
|
import math |
|
import torch |
|
from ..utils.compile import torch_compile_lazy |
|
|
|
|
|
def _rotate_pairs(x: torch.Tensor, rotr: torch.Tensor, roti: torch.Tensor) -> torch.Tensor:
    """Rotate each channel pair ``(x[..., 2i], x[..., 2i + 1])`` of ``x``.

    Each pair is treated as the complex number ``x[..., 2i] + j * x[..., 2i + 1]`` and
    multiplied by ``rotr + j * roti``. The arithmetic is done in float32; the result is
    cast back to ``x.dtype`` and has the same shape as ``x``.
    """
    dims = x.shape[:-1]
    D = x.shape[-1]
    pairs = x.view(*dims, D // 2, 2)
    # Compute in float32 regardless of the input dtype.
    xr = pairs[..., 0].float()
    xi = pairs[..., 1].float()
    out_r = xr * rotr - xi * roti
    out_i = xr * roti + xi * rotr
    out = torch.stack([out_r.to(x.dtype), out_i.to(x.dtype)], dim=-1)
    return out.view(*dims, D)


@torch_compile_lazy
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
    """Apply rotary positional embedding (RoPE) to queries and keys.

    Channel pair ``i`` (channels ``2i`` and ``2i + 1``) at absolute time step ``t`` is
    rotated by the angle ``t * max_period ** (-2i / D)``.

    Args:
        q (torch.Tensor): queries, shape `[B, H, T, D]`, or `[B, T, H, D]` if
            `time_before_heads` is True. `D` must be even.
        k (torch.Tensor): keys, same shape as `q`.
        offset (torch.Tensor or int): time step of the first position, e.g. the number
            of steps already processed when streaming. Must hold a single value.
        max_period (float): maximum period for the cos and sin.
        time_before_heads (bool): if True, inputs are `[B, T, H, D]`, else `[B, H, T, D]`.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: the rotated queries and keys, each with the
        same shape and dtype as the corresponding input.
    """
    if time_before_heads:
        _, T, _, D = q.shape
    else:
        _, _, T, D = q.shape
    assert k.shape == q.shape, "q and k must have the same shape"
    assert D > 0, "head dimension must be positive"
    assert D % 2 == 0, "head dimension must be even"
    assert max_period > 0, "max_period must be positive"

    if not isinstance(offset, torch.Tensor):
        # Accept a plain Python number as well as a tensor offset.
        offset = torch.tensor(offset, device=q.device)

    ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
    # freqs[i] == max_period ** (-2i / D): 1 for the first pair, decaying geometrically.
    freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))

    ts = offset.float() + torch.arange(T, device=q.device, dtype=torch.float32)
    # Shape positions so that `freqs * ts` broadcasts against the [..., D // 2] pair view:
    # [T, 1, 1] lines up with [B, T, H, D // 2]; [1, T, 1] with [B, H, T, D // 2].
    if time_before_heads:
        ts = ts.view(-1, 1, 1)
    else:
        ts = ts.view(1, -1, 1)

    angles = freqs * ts
    rotr = torch.cos(angles)
    roti = torch.sin(angles)
    return _rotate_pairs(q, rotr, roti), _rotate_pairs(k, rotr, roti)
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2021](https://arxiv.org/abs/2104.09864).

    Module wrapper around `apply_rope` that stores the maximum rotation period.

    Args:
        max_period (float): Maximum period of the rotation frequencies.
    """

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        self.max_period = max_period

    def extra_repr(self) -> str:
        # Show the configuration when the module is printed, e.g. `print(model)`.
        return f"max_period={self.max_period}"

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        """Apply the rope rotation to both the query and the key tensor.

        Args:
            q (torch.Tensor): queries, `[B, H, T, D]`, or `[B, T, H, D]` if
                `time_before_heads` is True.
            k (torch.Tensor): keys, same shape as `q`.
            offset (torch.Tensor): time step of the first position, e.g. when streaming.
            time_before_heads (bool): layout flag forwarded to `apply_rope`.

        Returns:
            The rotated `(q, k)` pair produced by `apply_rope`.
        """
        return apply_rope(q, k, offset, self.max_period, time_before_heads)
|
|