import torch
from torch import nn
from .layer_norm import rms_norm_fn


class LlamaRMSNorm(nn.Module):
    """
    RMS layer norm for Llama models.

    Triton-optimized RMS layer norm. The interface is compatible with
    `LlamaRMSNorm` in `transformers`.

    Attributes:
        weight (`torch.Tensor`): The learnable scaling parameter.
        variance_epsilon (`float`): The epsilon value for numerical stability.
    """

    weight: torch.Tensor
    variance_epsilon: float
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization to the input hidden states.

        Args:
            hidden_states (`torch.Tensor`):
                Input tensor of shape `(batch_size, sequence_length, hidden_size)`, or
                any shape where the last dimension is the feature dimension to be
                normalized.

        Returns:
            `torch.Tensor`:
                The normalized tensor, with the same shape as the input `hidden_states`.
        """
        # Plain RMS norm: the fused bias, residual-add, and dropout paths of the
        # Triton kernel are disabled to match `transformers`' `LlamaRMSNorm`.
        return rms_norm_fn(
            hidden_states,
            self.weight,
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )

__all__ = ["LlamaRMSNorm"]
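

# Minimal usage sketch (illustrative only, not part of the kernel build).
# Assumptions: a CUDA device is available (the kernel is Triton-based) and a
# hidden size of 4096. In practice this class is used as a drop-in replacement
# for `transformers`' `LlamaRMSNorm`, which owns `weight` and
# `variance_epsilon`; here they are set by hand for demonstration.
if __name__ == "__main__":
    norm = LlamaRMSNorm()
    norm.weight = nn.Parameter(torch.ones(4096, device="cuda"))
    norm.variance_epsilon = 1e-6

    hidden_states = torch.randn(2, 128, 4096, device="cuda", dtype=torch.float16)
    out = norm(hidden_states)
    assert out.shape == hidden_states.shape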