# Copyright (c) 2023, Tri Dao, Albert Gu. import torch import torch.nn.functional as F from torch.cuda.amp import custom_bwd, custom_fwd from einops import rearrange, repeat from causal_conv1d import causal_conv1d_fn import causal_conv1d_cuda import selective_scan_cuda class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: delta = delta.contiguous() if D is not None: D = D.contiguous() if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() if z is not None and z.stride(-1) != 1: z = z.contiguous() if B.dim() == 3: B = rearrange(B, "b dstate l -> b 1 dstate l") ctx.squeeze_B = True if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out if not return_last_state else (out, last_state) else: ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) @staticmethod def backward(ctx, dout, *args): if not ctx.has_z: u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors z = None out = None else: u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, False # option to recompute out_z, not used here ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC return (du, ddelta, dA, dB, dC, dD if D is not None else None, dz, ddelta_bias if delta_bias is not None else None, None, None) def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. """ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): """ u: r(B D L) delta: r(B D L) A: c(D N) or r(D N) B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 out: r(B D L) last_state (optional): r(B D dstate) or c(B D dstate) """ dtype_in = u.dtype u = u.float() delta = delta.float() if delta_bias is not None: delta = delta + delta_bias[..., None].float() if delta_softplus: delta = F.softplus(delta) batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] is_variable_B = B.dim() >= 3 is_variable_C = C.dim() >= 3 if A.is_complex(): if is_variable_B: B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) if is_variable_C: C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) else: B = B.float() C = C.float() x = A.new_zeros((batch, dim, dstate)) ys = [] deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) if not is_variable_B: deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) else: if B.dim() == 3: deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) else: B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: if C.dim() == 3: y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) else: y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) if i == u.shape[2] - 1: last_state = x if y.is_complex(): y = y.real * 2 ys.append(y) y = torch.stack(ys, dim=2) # (batch dim L) out = y if D is None else y + u * rearrange(D, "d -> d 1") if z is not None: out = out * F.silu(z) out = out.to(dtype=dtype_in) return out if not return_last_state else (out, last_state) class MambaInnerFnNoOutProj(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ assert checkpoint_lvl in [0, 1] L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) if torch.is_autocast_enabled(): x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) if xz.stride(-1) != 1: xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None ctx.C_proj_bias_is_None = C_proj_bias is None if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if B.stride(-1) != 1: B = B.contiguous() if C is None: # variable C C = x_dbl[:, -d_state:] # (bl dstate) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if C.stride(-1) != 1: C = C.contiguous() if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus ) ctx.delta_softplus = delta_softplus ctx.checkpoint_lvl = checkpoint_lvl if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) # return rearrange(out_z, "b d l -> b l d") return out_z @staticmethod @custom_bwd def backward(ctx, dout): # dout: (batch, seqlen, dim) (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() if ctx.checkpoint_lvl == 1: conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). dxz = torch.empty_like(xz) # (batch, dim, seqlen) dx, dz = dxz.chunk(2, dim=1) # dout_y = rearrange(dout, "b l d -> b d l") # because no arrange at end of forward, so dout shape is b d l dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz, ctx.delta_softplus, True # option to recompute out_z ) dD = dD if D is not None else None dx_dbl = torch.empty_like(x_dbl) dB_proj_bias = None if ctx.is_variable_B: if not A.is_complex(): dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() else: dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) dB = None dC_proj_bias = None if ctx.is_variable_C: if not A.is_complex(): dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() else: dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None dx_dbl[:, -d_state:] = dC # (bl d) dC = None ddelta = rearrange(ddelta, "b d l -> d (b l)") ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, dB_proj_bias, dC_proj_bias, None) class MambaInnerFn(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ assert checkpoint_lvl in [0, 1] L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) if torch.is_autocast_enabled(): x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None) if xz.stride(-1) != 1: xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None ctx.C_proj_bias_is_None = C_proj_bias is None if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if B.stride(-1) != 1: B = B.contiguous() if C is None: # variable C C = x_dbl[:, -d_state:] # (bl dstate) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if C.stride(-1) != 1: C = C.contiguous() if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None ctx.checkpoint_lvl = checkpoint_lvl if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @custom_bwd def backward(ctx, dout): # dout: (batch, seqlen, dim) (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() if ctx.checkpoint_lvl == 1: conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). dxz = torch.empty_like(xz) # (batch, dim, seqlen) dx, dz = dxz.chunk(2, dim=1) dout = rearrange(dout, "b l e -> e (b l)") dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, ctx.delta_softplus, True # option to recompute out_z ) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None dD = dD if D is not None else None dx_dbl = torch.empty_like(x_dbl) dB_proj_bias = None if ctx.is_variable_B: if not A.is_complex(): dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() else: dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) dB = None dC_proj_bias = None if ctx.is_variable_C: if not A.is_complex(): dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() else: dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None dx_dbl[:, -d_state:] = dC # (bl d) dC = None ddelta = rearrange(ddelta, "b d l -> d (b l)") ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, dB_proj_bias, dC_proj_bias, None) class BiMambaInnerFn(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ assert checkpoint_lvl in [0, 1] L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) if torch.is_autocast_enabled(): x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None) if xz.stride(-1) != 1: xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None ctx.C_proj_bias_is_None = C_proj_bias is None if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if B.stride(-1) != 1: B = B.contiguous() if C is None: # variable C C = x_dbl[:, -d_state:] # (bl dstate) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if C.stride(-1) != 1: C = C.contiguous() if D is not None: D = D.contiguous() out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd( conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus ) assert not A_b.is_complex(), "A should not be complex!!" out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd( conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus, ) out_z = out_z_f + out_z_b.flip([-1]) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None ctx.checkpoint_lvl = checkpoint_lvl if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @custom_bwd def backward(ctx, dout): # dout: (batch, seqlen, dim) (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() if ctx.checkpoint_lvl == 1: conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). dxz = torch.empty_like(xz) # (batch, dim, seqlen) dx, dz = dxz.chunk(2, dim=1) dout = rearrange(dout, "b l e -> e (b l)") dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz, ctx.delta_softplus, True # option to recompute out_z ) # flip one dz_b = torch.empty_like(dz) dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd( conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b, ctx.delta_softplus, True # option to recompute out_z ) dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1]) ddelta = ddelta + ddelta_f_b.flip([-1]) dB = dB + dB_f_b.flip([-1]) dC = dC + dC_f_b.flip([-1]) dD = dD + dD_b ddelta_bias = ddelta_bias + ddelta_bias_b dz = dz + dz_b.flip([-1]) out_z = out_z_f + out_z_b.flip([-1]) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None dD = dD if D is not None else None dx_dbl = torch.empty_like(x_dbl) dB_proj_bias = None if ctx.is_variable_B: if not A.is_complex(): dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() else: dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) dB = None dC_proj_bias = None if ctx.is_variable_C: if not A.is_complex(): dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() else: dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None dx_dbl[:, -d_state:] = dC # (bl d) dC = None ddelta = rearrange(ddelta, "b d l -> d (b l)") ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dout_proj_weight, dout_proj_bias, dA, dA_b, dB, dC, dD, ddelta_bias if delta_bias is not None else None, dB_proj_bias, dC_proj_bias, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) def bimamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): return BiMambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) def mamba_inner_fn_no_out_proj( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) def bimamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True) y = y + y_b.flip([-1]) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)