# Copyright (c) 2024, Tri Dao, Albert Gu.

import torch

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.jit
def _swiglu_fwd_kernel(
    X,
    Y,
    OUT,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_out_row,
    ncols,
    BLOCK_N: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    OUT += row * stride_out_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
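    # SwiGLU: out = silu(x) * y, where silu(x) = x * sigmoid(x).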
    out = x * tl.sigmoid(x) * y
    tl.store(OUT + cols, out, mask=cols < ncols)


def _swiglu_fwd(xy, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
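    # The last dimension packs x and y side by side; split them apart.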
    x, y = xy.chunk(2, dim=-1)
    if out is None:
        out = torch.empty_like(x)
    else:
        out = out.reshape(-1, out.shape[-1])
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    M, N = x.shape
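    # One program per (row, block of BLOCK_N columns); BLOCK_N is chosen by the autotuner.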
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
    return out.reshape(*batch_shape, out.shape[-1])


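# The backward kernel can optionally recompute the forward output in the same
# launch; the heuristic below sets RECOMPUTE_OUTPUT based on whether an OUT
# buffer is passed in.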
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
@triton.jit
def _swiglu_bwd_kernel(
    X,
    Y,
    DOUT,
    OUT,
    DX,
    DY,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dout_row,
    stride_out_row,
    stride_dx_row,
    stride_dy_row,
    ncols,
    BLOCK_N: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    DOUT += row * stride_dout_row
    if RECOMPUTE_OUTPUT:
        OUT += row * stride_out_row
    DX += row * stride_dx_row
    DY += row * stride_dy_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
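    # out = silu(x) * y with silu(x) = x * sigmoid(x), so:
    #   d out / dx = sigmoid(x) * (1 + x * (1 - sigmoid(x))) * y
    #   d out / dy = x * sigmoid(x)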
    x_sigmoid = tl.sigmoid(x)
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    tl.store(DX + cols, dx, mask=cols < ncols)
    tl.store(DY + cols, dy, mask=cols < ncols)
    if RECOMPUTE_OUTPUT:
        out = x * x_sigmoid * y
        tl.store(OUT + cols, out, mask=cols < ncols)


def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    dout = dout.reshape(-1, dout.shape[-1])
    assert dout.shape == x.shape
    if dxy is None:
        dxy = torch.empty_like(xy)
    else:
        dxy = dxy.reshape(-1, dxy.shape[-1])
        assert dxy.shape == xy.shape
    dx, dy = dxy.chunk(2, dim=-1)
    assert dx.stride(-1) == 1
    assert dy.stride(-1) == 1
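    # Optionally recompute the forward output here so the fwd pass need not store it.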
    if recompute_output:
        if out is None:
            out = torch.empty_like(x)
        else:
            out = out.reshape(-1, out.shape[-1])
            assert out.shape == x.shape
        assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
                                 x.stride(0), y.stride(0), dout.stride(0),
                                 out.stride(0) if recompute_output else 0,
                                 dx.stride(0), dy.stride(0),
                                 N)
    if not recompute_output:
        return dxy.reshape(*batch_shape, dxy.shape[-1])
    else:
        return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])


class SwiGLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, xy):
        ctx.save_for_backward(xy)
        return _swiglu_fwd(xy)

    @staticmethod
    def backward(ctx, dout):
        xy, = ctx.saved_tensors
        return _swiglu_bwd(xy, dout)


swiglu = SwiGLU.apply
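

# A minimal usage sketch, not part of the original file: it assumes a CUDA
# device with Triton installed, and checks the kernel against a plain PyTorch
# reference, F.silu(x) * y.
if __name__ == "__main__":
    import torch.nn.functional as F

    # The last dim holds x and y concatenated, so 512 splits into two halves of 256.
    xy = torch.randn(4, 1024, 512, device="cuda", requires_grad=True)
    out = swiglu(xy)
    x, y = xy.chunk(2, dim=-1)
    ref = F.silu(x) * y
    print("fwd matches:", torch.allclose(out, ref, atol=1e-4, rtol=1e-4))
    # Exercise the custom backward as well.
    out.sum().backward()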