# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from einops import rearrange, repeat

from causal_conv1d import causal_conv1d_fn
import causal_conv1d_cuda
import selective_scan_cuda


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                return_last_state=False):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
            False  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (du, ddelta, dA, dB, dC,
                dD if D is not None else None,
                dz,
                ddelta_bias if delta_bias is not None else None,
                None,
                None)


def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
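

# A minimal usage sketch (not part of the original file). Shapes follow the
# selective_scan_ref docstring below; it assumes a CUDA device, the compiled
# selective_scan_cuda extension, and arbitrary illustrative values.
def _selective_scan_usage_example():
    batch, dim, dstate, seqlen = 2, 4, 16, 64
    device = "cuda"
    u = torch.randn(batch, dim, seqlen, device=device)      # input sequence, r(B D L)
    delta = torch.rand(batch, dim, seqlen, device=device)   # timestep, r(B D L)
    A = -torch.rand(dim, dstate, device=device)             # state matrix, r(D N)
    B = torch.randn(batch, dstate, seqlen, device=device)   # input projection, r(B N L)
    C = torch.randn(batch, dstate, seqlen, device=device)   # output projection, r(B N L)
    D = torch.randn(dim, device=device)                     # skip connection, r(D)
    out, last_state = selective_scan_fn(u, delta, A, B, C, D,
                                        delta_softplus=True, return_last_state=True)
    return out, last_state                                  # (B D L), (B D N)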


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                       return_last_state=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)
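

# Sanity-check sketch (not part of the original file): the fused CUDA path and the
# pure-PyTorch reference above should agree up to numerical tolerance. Assumes a
# CUDA device and the compiled extension; tensor sizes here are illustrative only.
def _selective_scan_matches_reference():
    torch.manual_seed(0)
    batch, dim, dstate, seqlen = 1, 4, 8, 32
    kw = dict(device="cuda", dtype=torch.float32)
    u = torch.randn(batch, dim, seqlen, **kw)
    delta = torch.rand(batch, dim, seqlen, **kw)
    A = -torch.rand(dim, dstate, **kw)
    B = torch.randn(batch, dstate, seqlen, **kw)
    C = torch.randn(batch, dstate, seqlen, **kw)
    D = torch.randn(dim, **kw)
    z = torch.randn(batch, dim, seqlen, **kw)  # gate branch, applied as out * silu(z)
    out_cuda = selective_scan_fn(u, delta, A, B, C, D, z=z, delta_softplus=True)
    out_ref = selective_scan_ref(u, delta, A, B, C, D, z=z, delta_softplus=True)
    assert torch.allclose(out_cuda, out_ref, rtol=1e-3, atol=1e-3)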


class MambaInnerFnNoOutProj(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
        xz: (batch, dim, seqlen)
        """
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        # return rearrange(out_z, "b d l -> b l d")
        return out_z

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, dim, seqlen), since forward returns out_z without rearranging
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        # dout_y = rearrange(dout, "b l d -> b d l")  # not needed: no rearrange at the end of forward, so dout is already (b d l)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None)


class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
        xz: (batch, dim, seqlen)
        """
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None)


class BiMambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
        xz: (batch, dim, seqlen)
        """
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        assert not A_b.is_complex(), "A_b should not be complex!!"
        out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
            conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D,
            z.flip([-1]), delta_bias, delta_softplus,
        )
        out_z = out_z_f + out_z_b.flip([-1])
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, A_b, B, C, D, delta_bias, scan_intermediates_f,
                              scan_intermediates_b, out_f, out_b)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f,
         scan_intermediates_b, out_f, out_b) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        # backward of the flipped (reverse-direction) scan
        dz_b = torch.empty_like(dz)
        dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(
            conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D,
            z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
        ddelta = ddelta + ddelta_f_b.flip([-1])
        dB = dB + dB_f_b.flip([-1])
        dC = dC + dC_f_b.flip([-1])
        dD = dD + dD_b
        ddelta_bias = ddelta_bias + ddelta_bias_b
        dz = dz + dz_b.flip([-1])
        out_z = out_z_f + out_z_b.flip([-1])
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dA_b, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)


def bimamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return BiMambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                out_proj_weight, out_proj_bias,
                                A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)


def mamba_inner_fn_no_out_proj(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                       A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)


def mamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
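

# Shape sketch for mamba_inner_fn (not part of the original file). The weight shapes
# below mirror what a Mamba block would pass in, but all names and sizes here are
# illustrative assumptions, not the module's actual parameters. Requires a CUDA device
# and the compiled causal_conv1d / selective_scan kernels.
def _mamba_inner_fn_shape_example():
    batch, seqlen = 2, 64
    d_model, d_inner, d_state, dt_rank, d_conv = 32, 64, 16, 2, 4
    kw = dict(device="cuda", dtype=torch.float32)
    xz = torch.randn(batch, 2 * d_inner, seqlen, **kw)                 # concatenated x and z branches
    conv1d_weight = torch.randn(d_inner, 1, d_conv, **kw)
    conv1d_bias = torch.randn(d_inner, **kw)
    x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner, **kw)  # produces delta, B, C
    delta_proj_weight = torch.randn(d_inner, dt_rank, **kw)
    out_proj_weight = torch.randn(d_model, d_inner, **kw)
    A = -torch.rand(d_inner, d_state, **kw)
    D = torch.randn(d_inner, **kw)
    delta_bias = torch.randn(d_inner, **kw)
    out = mamba_inner_fn(
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        out_proj_weight, None,      # out_proj_bias
        A, None, None, D,           # B and C are left as None, i.e. computed from x_proj
        delta_bias=delta_bias, delta_softplus=True,
    )
    return out                      # (batch, seqlen, d_model)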


def bimamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
    y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D,
                            z.flip([-1]), delta_bias, delta_softplus=True)
    y = y + y_b.flip([-1])
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
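

# Bidirectional combination sketch (not part of the original file): the BiMamba path
# above runs one scan forward in time and a second scan on time-flipped inputs with
# its own state matrix A_b, then flips that result back and sums the two. The same
# pattern in isolation, for arbitrary tensors shaped as in selective_scan_ref:
def _bidirectional_scan(x, delta, A, A_b, B, C, D, z, delta_bias):
    y_fwd = selective_scan_fn(x, delta, A, B, C, D, z=z,
                              delta_bias=delta_bias, delta_softplus=True)
    y_bwd = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b,
                              B.flip([-1]), C.flip([-1]), D, z.flip([-1]),
                              delta_bias, delta_softplus=True)
    return y_fwd + y_bwd.flip([-1])  # flip back so both directions align along seqlen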