import torch
from torch import nn

from .layer_norm import rms_norm_fn


class LlamaRMSNorm(nn.Module):
"""
RMS Layer Norm for Llama models.
Triton-optimized RMS layer norm. The interface is compatible with `LLamaRMSNorm` in
`transformers`.
Attributes:
weight (`torch.Tensor`): The learnable scaling parameter.
variance_epsilon (`float`): The epsilon value for numerical stability.
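
    Example:
        A minimal usage sketch (not part of the original file). The class defines no
        `__init__`, so `weight` and `variance_epsilon` must be supplied externally,
        as the `transformers` module this layer stands in for normally provides; a
        CUDA device is assumed, since `rms_norm_fn` is a Triton kernel.

        ```python
        norm = LlamaRMSNorm()
        norm.weight = nn.Parameter(torch.ones(4096, device="cuda", dtype=torch.float16))
        norm.variance_epsilon = 1e-6

        x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
        y = norm(x)  # same shape and dtype as `x`
        ```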
"""
weight: torch.Tensor
variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization to the input hidden states.

        Args:
            hidden_states (`torch.Tensor`):
                Input tensor of shape `(batch_size, sequence_length, hidden_size)` or any shape
                where the last dimension is the feature dimension to be normalized.

        Returns:
            `torch.Tensor`:
                The normalized tensor with the same shape as the input `hidden_states`.
        """
        return rms_norm_fn(
            hidden_states,
            self.weight,
            bias=None,  # RMS norm applies no additive bias
            residual=None,  # no residual connection is fused in
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,  # return only the normalized output
            residual_in_fp32=False,
        )


__all__ = ["LlamaRMSNorm"]