import torch
from torch import nn

from .layer_norm import rms_norm_fn


class LlamaRMSNorm(nn.Module):
    """
    Triton-optimized RMS layer norm for Llama models.

    The interface is compatible with `LlamaRMSNorm` in `transformers`.

    Attributes:
        weight (`torch.Tensor`): The learnable scaling parameter.
        variance_epsilon (`float`): The epsilon value for numerical stability.
    """

    weight: torch.Tensor
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization to the input hidden states.

        Args:
            hidden_states (`torch.Tensor`): Input tensor of shape
                `(batch_size, sequence_length, hidden_size)`, or any shape
                whose last dimension is the feature dimension to be
                normalized.

        Returns:
            `torch.Tensor`: The normalized tensor, with the same shape as the
            input `hidden_states`.
        """
        return rms_norm_fn(
            hidden_states,
            self.weight,
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )


__all__ = ["LlamaRMSNorm"]
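
# A minimal usage and sanity-check sketch (illustrative only, not part of the
# module's public surface). It assumes the Triton kernel behind `rms_norm_fn`
# is available and that a CUDA device is present; the hidden size (4096),
# dtype, and tolerances are arbitrary choices. The reference line mirrors the
# eager RMS norm math:
#
#     y = x * rsqrt(mean(x**2, dim=-1, keepdim=True) + eps) * weight
#
# Because this module uses a relative import, run the sketch with
# `python -m <package>.<module>` rather than executing the file directly.
if __name__ == "__main__":
    # No custom __init__ is defined, so nn.Module's default constructor runs;
    # weight and variance_epsilon are assigned afterwards, matching how the
    # class is normally populated from an existing `transformers` layer.
    norm = LlamaRMSNorm()
    norm.weight = nn.Parameter(torch.ones(4096, device="cuda", dtype=torch.float16))
    norm.variance_epsilon = 1e-6

    x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
    out = norm(x)

    # Eager reference, computed in float32 for stability and cast back.
    ref = (
        x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
    ).to(x.dtype) * norm.weight
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)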