import torch
import torch.nn.functional as F

from ..utils.torch import custom_fwd, custom_bwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

from .triton.layer_norm import _layer_norm_fwd

from .._ops import ops


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        B,
        C,
        D=None,
        z=None,
        delta_bias=None,
        delta_softplus=False,
        return_last_state=False,
    ):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
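        # Add a singleton group dimension so B and C always have the
        # (batch, n_groups, dstate, seqlen) layout expected by the fused kernel.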
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = ops.selective_scan_fwd(
            u, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
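        # x holds the checkpointed scan states; the last chunk contains the final
        # SSM state of shape (batch, dim, dstate).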
        last_state = x[:, :, -1, 1::2]
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
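        # The kernel can take a pre-allocated dz (e.g. to fuse this backward with another
        # op); here None is passed so dz is allocated inside the kernel, and the trailing
        # False skips recomputing out_z.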
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            dout,
            x,
            out,
            None,
            ctx.delta_softplus,
            False,
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            None,
            None,
        )


def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
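    """Normalize x over its last dimension using the fused Triton forward kernel
    (RMSNorm when is_rms_norm=True, LayerNorm otherwise).
    """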
    if x.stride(-1) != 1:
        x = x.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y = _layer_norm_fwd(
        x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
    )[0]
    return y


def selective_scan_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """If return_last_state is True, returns (out, last_state).

    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(
        u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
    )
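

# Usage sketch (illustrative only): the shapes follow the docstring above, and the tensors
# are assumed to live on a CUDA device, since the fused selective-scan kernel is CUDA-only.
#
#     batch, dim, dstate, seqlen = 2, 4, 16, 64
#     u = torch.randn(batch, dim, seqlen, device="cuda", requires_grad=True)
#     delta = torch.rand(batch, dim, seqlen, device="cuda", requires_grad=True)
#     A = -torch.rand(dim, dstate, device="cuda", requires_grad=True)
#     B = torch.randn(batch, dstate, seqlen, device="cuda")
#     C = torch.randn(batch, dstate, seqlen, device="cuda")
#     out, last_state = selective_scan_fn(
#         u, delta, A, B, C, delta_softplus=True, return_last_state=True
#     )
#     out.sum().backward()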


def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """
    Shapes use r(...) for real tensors and c(...) for complex tensors:

    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or r(B G N 2L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or r(B G N 2L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(
                rearrange(B.float(), "... (L two) -> ... L two", two=2)
            )
        if is_variable_C:
            C = torch.view_as_complex(
                rearrange(C.float(), "... (L two) -> ... L two", two=2)
            )
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
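    # Discretize the continuous-time parameters: deltaA[b, d, t] = exp(delta[b, d, t] * A[d])
    # and deltaB_u[b, d, t] = delta[b, d, t] * B_t * u[b, d, t]; the loop below then runs the
    # linear recurrence x_t = deltaA_t * x_{t-1} + deltaB_u_t with readout y_t = <C_t, x_t>.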
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)
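

# Quick equivalence check (illustrative; small CUDA tensors as in the sketch above,
# tolerances are arbitrary):
#
#     out_ref = selective_scan_ref(u, delta, A, B, C, delta_softplus=True)
#     out_fused = selective_scan_fn(u, delta, A, B, C, delta_softplus=True)
#     torch.testing.assert_close(out_fused, out_ref, rtol=1e-2, atol=1e-2)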


class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B=None,
        C=None,
        D=None,
        delta_bias=None,
        B_proj_bias=None,
        C_proj_bias=None,
        delta_softplus=True,
        checkpoint_lvl=1,
        b_rms_weight=None,
        c_rms_weight=None,
        dt_rms_weight=None,
        b_c_dt_rms_eps=1e-6,
    ):
        """
        xz: (batch, dim, seqlen)
        """
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(
                dtype=torch.get_autocast_gpu_dtype()
            )
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (
                out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                if out_proj_bias is not None
                else None
            )
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
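        # Keep d as the slowest-moving dimension and L as the fastest-moving one for delta,
        # since that is the layout the selective-scan kernel expects; the transposed matmul
        # below avoids an extra permute.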
        x_dbl = F.linear(
            rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
        )
        delta = rearrange(
            delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
        )
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:
            B = x_dbl[:, delta_rank : delta_rank + d_state]
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(
                    B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:
            C = x_dbl[:, -d_state:]
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(
                    C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(
                delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

        out, scan_intermediates, out_z = ops.selective_scan_fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.b_rms_weight = b_rms_weight
        ctx.c_rms_weight = c_rms_weight
        ctx.dt_rms_weight = dt_rms_weight
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
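        # With checkpoint_lvl >= 1, conv1d_out and delta are dropped here and recomputed in
        # backward(), trading compute for activation memory.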
        if checkpoint_lvl >= 1:
            conv1d_out, delta = None, None
        ctx.save_for_backward(
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        )
        return F.linear(
            rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
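        # dout: (batch, seqlen, dim), the gradient w.r.t. the projected output.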
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        ) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
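        # conv1d_out and delta were not saved when checkpoint_lvl == 1; recompute them here,
        # re-applying the optional RMSNorms.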
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(
                delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
            )
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(
                    delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
            if b_rms_weight is not None:
                B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            if c_rms_weight is not None:
                C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
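        # The scan kernel accepts a pre-allocated dz, so dxz is allocated up front and its
        # z half (dz) is filled in place by ops.selective_scan_bwd below.
        # dxz: (batch, dim, seqlen)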
        dxz = torch.empty_like(xz)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
            ops.selective_scan_bwd(
                conv1d_out,
                delta,
                A,
                B,
                C,
                D,
                z,
                delta_bias,
                dout_y,
                scan_intermediates,
                out,
                dz,
                ctx.delta_softplus,
                True,
            )
        )
        dout_proj_weight = torch.einsum(
            "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
        )
        dout_proj_bias = (
            dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        )
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(
                    dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank : delta_rank + d_state] = dB
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(
                    dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum(
            "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
        )
        dconv1d_out = torch.addmm(
            dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
        )
        dconv1d_out = rearrange(
            dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
        )
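        # causal_conv1d_bwd also takes a pre-allocated dx (the x half of dxz) and writes the
        # input gradient into it in place.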
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out,
            None,
            None,
            None,
            dx,
            False,
            True,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (
            dxz,
            dconv1d_weight,
            dconv1d_bias,
            dx_proj_weight,
            ddelta_proj_weight,
            dout_proj_weight,
            dout_proj_bias,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias if delta_bias is not None else None,
            dB_proj_bias,
            dC_proj_bias,
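            # The six trailing Nones are for delta_softplus, checkpoint_lvl, b_rms_weight,
            # c_rms_weight, dt_rms_weight and b_c_dt_rms_eps, which take no gradient.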
            None,
            None,
            None,
            None,
            None,
            None,
        )


def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
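    """Fused Mamba block inner pass: causal conv1d, input-dependent delta/B/C projections
    (with optional RMSNorm on delta, B and C), the selective scan, and the output
    projection, all dispatched through the custom CUDA kernels via MambaInnerFn.
    """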
    return MambaInnerFn.apply(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )


def mamba_inner_ref(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
):
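    """Reference (unfused) implementation of mamba_inner_fn, composed from
    causal_conv1d_fn and selective_scan_fn; useful for correctness checks.
    """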
    assert (
        causal_conv1d_fn is not None
    ), "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(
        x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
    )
    x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:
        B = x_dbl[:, delta_rank : delta_rank + d_state]
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(
                B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    if C is None:
        C = x_dbl[:, -d_state:]
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(
                C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    y = selective_scan_fn(
        x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus
    )
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)