import dropout_layer_norm
import torch
from torch.nn import init


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()


def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )

    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )

    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )

    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )

    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )

    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat,
        dx1mat,
        dresidualmat,
        dgamma0,
        dbeta0,
        dgamma1,
        dbeta1,
        *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )

    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )

        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        dz = maybe_align(dz.contiguous(), 16)
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )

        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x_shape))
        else:
            z = zmat.view(z_shape)
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        dz = maybe_align(dz.contiguous(), 16)
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


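# Usage sketch (illustrative, not part of the library API surface): layer_norm()
# runs the fused kernel as a plain LayerNorm, i.e. no residual and dropout_p=0.0.
# Shapes, dtype, and device below are assumptions; the fused kernel expects CUDA
# tensors with a supported hidden size.
#
#   x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
#   w = torch.ones(1024, device="cuda", dtype=torch.float16)
#   b = torch.zeros(1024, device="cuda", dtype=torch.float16)
#   out = layer_norm(x, w, b, 1e-5)
#   # Expected to match torch.nn.functional.layer_norm(x, (1024,), w, b, 1e-5)
#   # up to small numerical differences of the fused kernel.

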
def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


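# Usage sketch (illustrative assumptions: CUDA device, fp16 tensors, supported
# hidden size). With prenorm=True the call also returns the updated residual
# stream x = dropout(x0) + residual, which a pre-norm Transformer block feeds to
# the next layer:
#
#   hidden = torch.randn(4, 128, 1024, device="cuda", dtype=torch.float16)
#   residual = torch.randn_like(hidden)
#   w = torch.ones(1024, device="cuda", dtype=torch.float16)
#   b = torch.zeros(1024, device="cuda", dtype=torch.float16)
#   out, new_residual = dropout_add_layer_norm(
#       hidden, residual, w, b, 0.1, 1e-5, prenorm=True
#   )

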
def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


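# Illustrative sketch (assumptions noted): with x0_subset=None and out_subset=None
# this entry point behaves like dropout_add_layer_norm without a per-row rowscale.
# When given, x0_subset / out_subset are row-index tensors restricting which rows
# of x0 take part in the dropout-add and which output rows are produced
# (out_numrows of them); their exact indexing convention is defined by the fused
# kernel and is not documented in this file.
#
#   x0 = torch.randn(4, 128, 1024, device="cuda", dtype=torch.float16)
#   res = torch.randn_like(x0)
#   w = torch.ones(1024, device="cuda", dtype=torch.float16)
#   b = torch.zeros(1024, device="cuda", dtype=torch.float16)
#   out = dropout_add_layer_norm_subset(x0, res, w, b, 0.1, 1e-5)

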
def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


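# Usage sketch for the parallel-residual form (two branches sharing one residual
# stream, each normalized with its own affine parameters, as in GPT-J-style
# parallel blocks). Shapes, dtypes, and device below are illustrative
# assumptions, not requirements stated in this file:
#
#   x0 = torch.randn(2, 64, 1024, device="cuda", dtype=torch.float16)
#   x1 = torch.randn_like(x0)
#   res = torch.randn_like(x0)
#   w0 = torch.ones(1024, device="cuda", dtype=torch.float16)
#   b0 = torch.zeros_like(w0)
#   w1, b1 = torch.ones_like(w0), torch.zeros_like(w0)
#   z0, z1, new_res = dropout_add_layer_norm_parallel_residual(
#       x0, x1, res, w0, b0, w1, b1, 0.1, 1e-5, prenorm=True
#   )
#   # z0 and z1 normalize the same updated residual stream (shared mu/rsigma),
#   # z0 with (w0, b0) and z1 with (w1, b1); prenorm=True also returns that stream.

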
class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
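

# Module usage sketch (illustrative; assumes a CUDA device and a hidden size the
# fused kernel supports). Dropout probability p is applied only in training mode;
# in eval mode forward() above falls back to p=0.0.
#
#   m = DropoutAddLayerNorm(1024, prenorm=True, p=0.1, device="cuda", dtype=torch.float16)
#   h = torch.randn(4, 128, 1024, device="cuda", dtype=torch.float16)
#   res = torch.randn_like(h)
#   out, new_res = m(h, res)  # out = LN(dropout(h) + res); new_res = dropout(h) + res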