File size: 10,208 Bytes
9aa8ed3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 |
import torch
import triton
import triton.language as tl
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rope.py
# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@triton.jit
def _triton_rope(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_row_stride,
    sin,
    sin_row_stride,
    sl,
    bs: tl.constexpr,
    cos_bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    """Rotate every q/k head of one token in place (forward or backward).

    One program instance handles one (batch, position) row.

    Expected layouts (contiguous):
        q:       (bsz, seq_len, num_q_heads, head_dim)
        k:       (bsz, seq_len, num_kv_heads, head_dim)
        cos/sin: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)

    Args:
        q_ptr, k_ptr: base pointers of q and k; results overwrite the inputs.
        q_row_stride, k_row_stride: elements between consecutive tokens.
        cos, sin: base pointers of the cos / sin tables.
        cos_row_stride, sin_row_stride: elements between consecutive positions.
        sl: sequence length.
        bs: batch size (not read by the kernel body).
        cos_bs: leading dimension of cos/sin; 1 means shared across the batch.
        n_qh, n_kh, hd: real number of q heads, kv heads and head dimension.
        pad_n_qh, pad_n_kh, pad_hd: the same sizes as passed by the launcher
            for use as tl.arange extents (the launchers round them up to a
            power of two).
        BLOCK_SIZE: not read by the kernel body.
        BACKWARD_PASS: when True, apply the inverse rotation (sin negated).
    """
    # program_id is a 32-bit value; promote it so that pid * row_stride cannot
    # wrap around when q or k hold more than 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)

    # locate the first element of this token's row
    q_ptr = q_ptr + pid * q_row_stride
    k_ptr = k_ptr + pid * k_row_stride

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################
    # 1. program instances form a 1D vector of size bsz * seq_len, i.e. a
    #    [bsz, seq_len] grid with seq_len the fastest changing dimension, so
    #    pid // sl is the batch index and pid % sl the sequence index.
    # 2. only the left half of each cos / sin row is loaded because the right
    #    half is a clone of the left half.
    batch_idx = pid // sl
    cos_row_idx = pid % sl
    cos = cos + tl.where(
        cos_bs == 1,
        cos_row_idx * cos_row_stride,
        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
    )
    sin = sin + tl.where(
        cos_bs == 1,
        cos_row_idx * sin_row_stride,
        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
    )

    cos_offsets = tl.arange(0, pad_hd // 2)
    cos_mask = cos_offsets < hd // 2
    cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
    sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    q_head_idx = tl.arange(0, pad_n_qh)[:, None]
    k_head_idx = tl.arange(0, pad_n_kh)[:, None]
    half_col_idx = cos_offsets[None, :]
    half_col_mask = cos_mask[None, :]

    # left half of every head; padded heads / columns are masked out
    first_half_q_offsets = q_head_idx * hd + half_col_idx
    first_half_k_offsets = k_head_idx * hd + half_col_idx
    first_q_mask = (q_head_idx < n_qh) & half_col_mask
    first_k_mask = (k_head_idx < n_kh) & half_col_mask
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

    # right half of every head: same tile shifted by hd // 2 columns
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
def rope_forward(q, k, cos, sin):
    """Apply rotary position embedding to q and k with the Triton kernel.

    q is unpacked as (batch_size, seq_len, n_q_head, head_dim) and k is read
    as having its head count in dimension 2. No transpose is performed here;
    callers must already supply that layout.

    The kernel writes its result over its inputs, so when q / k are already
    contiguous the caller's tensors are modified in place.

    Returns:
        (q, k, cos, sin): the rotated q and k plus the contiguous cos and sin.
    """
    # The kernel addresses raw storage, so every operand must be dense.
    # .contiguous() is a no-op for tensors that already are.
    q, k = q.contiguous(), k.contiguous()
    cos, sin = cos.contiguous(), sin.contiguous()

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]

    # tl.arange extents are rounded up to the next power of two; the kernel
    # masks out the padding.
    padded_q_heads = triton.next_power_of_2(n_q_head)
    padded_kv_heads = triton.next_power_of_2(n_kv_head)
    padded_head_dim = triton.next_power_of_2(head_dim)

    # one program instance per (batch, position) token
    grid = (batch_size * seq_len,)
    _triton_rope[grid](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        bs=batch_size,
        cos_bs=cos.shape[0],
        n_qh=n_q_head,
        n_kh=n_kv_head,
        hd=head_dim,
        pad_n_qh=padded_q_heads,
        pad_n_kh=padded_kv_heads,
        pad_hd=padded_head_dim,
        BLOCK_SIZE=max(padded_q_heads, padded_kv_heads),
        BACKWARD_PASS=False,
    )
    return q, k, cos, sin
def rope_backward(dq, dk, cos, sin):
    """Back-propagate through the rotary embedding with the Triton kernel.

    dq is unpacked as (batch_size, seq_len, n_q_head, head_dim) and dk is read
    as having its head count in dimension 2. The kernel is launched with
    BACKWARD_PASS=True.

    The kernel writes its result over its inputs, so when dq / dk are already
    contiguous the caller's gradient tensors are modified in place.

    Returns:
        (dq, dk): gradients with respect to the un-rotated q and k.
    """
    batch_size, seq_len, n_q_head, head_dim = dq.shape
    n_kv_head = dk.shape[2]

    # tl.arange extents are rounded up to the next power of two; the kernel
    # masks out the padding.
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    # one program instance per (batch, position) token
    n_row = batch_size * seq_len

    # ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()
    # The kernel assumes dense cos / sin (unit stride in the last dimension,
    # batch stride == seq_len * row stride). This is a no-op for tensors that
    # are already contiguous, but protects direct callers that pass strided
    # tables.
    cos = cos.contiguous()
    sin = sin.contiguous()
    cos_batch_size = cos.shape[0]

    # backward is similar to forward except swapping few ops
    _triton_rope[(n_row,)](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=True,
    )
    return dq, dk
class LigerRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
    this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
    than the original RoPE paper.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184

    For more details about the rotation matrix used here, please refer to:
    https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """
        Rotate q and k and save cos / sin for the backward pass.

        Layouts below are the ones rope_forward unpacks in this file (no
        transpose is applied before the kernel launch):

        q size: (bsz, seq_len, n_q_head, head_dim)
        k size: (bsz, seq_len, n_kv_head, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)

        position_ids and unsqueeze_dim are accepted but not used here.

        Returns the rotated (q, k). rope_forward only copies non-contiguous
        inputs, so contiguous q / k are overwritten in place.
        """
        q, k, cos, sin = rope_forward(q, k, cos, sin)
        ctx.save_for_backward(cos, sin)
        return q, k

    # torch.autograd.Function requires backward to be a static method, the
    # same as forward.
    @staticmethod
    def backward(ctx, dq, dk):
        """
        Propagate gradients through the rotation using the saved cos / sin.

        dq size: (bsz, seq_len, n_q_head, head_dim)
        dk size: (bsz, seq_len, n_kv_head, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)

        Returns one entry per forward input: gradients for q and k, and None
        for cos, sin, position_ids and unsqueeze_dim.
        """
        cos, sin = ctx.saved_tensors
        dq, dk = rope_backward(dq, dk, cos, sin)
        return dq, dk, None, None, None, None