# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from torch import nn

from ..utils.compile import torch_compile_lazy


@torch_compile_lazy
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
"""
Args:
q (torch.Tensor): queries, shape `[B, T, H, D]`.
k (torch.Tensor): keys, shape `[B, T, H, D]`.
offset (int): current offset, e.g. when streaming.
max_period (float): maximum period for the cos and sin.
time_before_heads (bool): if True, expected [B, T, H, D], else [B, H, T ,D]
"""
    if time_before_heads:
        B, T, H, D = q.shape
    else:
        B, H, T, D = q.shape
    assert k.shape == q.shape
    assert D > 0
    assert D % 2 == 0
    assert max_period > 0

    # Inverse frequencies, one per channel pair d: max_period ** (-2 * d / D).
    ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
    freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))
    # Absolute time steps, shifted by the streaming offset.
    ts = offset.float() + torch.arange(T, device=q.device, dtype=torch.float32)
    if time_before_heads:
        ts = ts.view(-1, 1, 1)
    else:
        ts = ts.view(1, -1, 1)

    # Split the channel dimension into (D // 2) pairs to be rotated.
    dims = q.shape[:-1]
    q = q.view(*dims, D // 2, 2)
    k = k.view(*dims, D // 2, 2)

    # Convention: `r` suffix is the real part, `i` is the imaginary part.
    qr = q[..., 0].float()
    qi = q[..., 1].float()
    kr = k[..., 0].float()
    ki = k[..., 1].float()

    # Rotate each pair by the angle `freqs * ts`.
    rotr = torch.cos(freqs * ts)
    roti = torch.sin(freqs * ts)
    qor = qr * rotr - qi * roti
    qoi = qr * roti + qi * rotr

    kor = kr * rotr - ki * roti
    koi = kr * roti + ki * rotr

    dtype = q.dtype
    qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
    ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)

    return qo.view(*dims, D), ko.view(*dims, D)


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).

    Args:
        max_period (float): Maximum period of the rotation frequencies.
    """

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        self.max_period = max_period
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        """Apply the RoPE rotation to the query and key tensors."""
        return apply_rope(q, k, offset, self.max_period, time_before_heads)
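

# Illustrative usage sketch (not part of the module): the shapes and values
# below are made up for the example and assume the default `[B, H, T, D]`
# layout with `time_before_heads=False`.
#
#     rope = RotaryEmbedding(max_period=10_000.0)
#     q = torch.randn(2, 8, 16, 64)   # [B, H, T, D]
#     k = torch.randn(2, 8, 16, 64)
#     offset = torch.zeros(1)         # start of the stream
#     q_rot, k_rot = rope(q, k, offset)
#     # q_rot and k_rot have the same shape and dtype as q and k.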