import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
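    # Positional call into the fused dropout + residual-add + normalization
    # autograd function: no residual, bias, rowscale, or layerscale, dropout_p
    # is 0.0, and the trailing True selects the RMSNorm variant (same argument
    # order as the wrappers below), so this computes a plain RMSNorm of x.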
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
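    # Fused dropout(x0) + residual followed by RMSNorm (rowscale / layerscale
    # optionally rescale x0 first); the hard-coded True selects the RMSNorm
    # variant of the fused kernel, and with prenorm=True the updated residual
    # is returned alongside the normalized output.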
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
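    # Subset variant of the fused call: x0_subset / out_subset select which
    # rows are read and written; the hard-coded True again selects the RMSNorm
    # path.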
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
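    # Parallel-residual variant: the two streams x0 and x1 (e.g. attention and
    # MLP branches) share one residual, which is then normalized with the
    # separate (weight0, bias0) and (weight1, bias1) parameters; the hard-coded
    # True again selects the RMSNorm path.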
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
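        # RMSNorm rescales but does not re-center, so no bias parameter is
        # created; it is registered as None so the attribute still exists.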
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
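        # Dropout is active only in training mode; bias is None (RMSNorm), and
        # with prenorm=True the updated residual is returned together with the
        # normalized output so it can be passed to the next block.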
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
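

# Usage sketch (not part of the library; assumes a CUDA device and an input
# dtype supported by the fused kernel, e.g. fp16, and that prenorm=True makes
# the call return both the output and the updated residual):
#
#   norm = DropoutAddRMSNorm(1024, prenorm=True, p=0.1, residual_in_fp32=True).cuda()
#   x0 = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
#   out, residual = norm(x0)             # first block: residual is created from x0
#   out, residual = norm(out, residual)  # later blocks: carry the residual stream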