|
import torch |
|
from einops import rearrange, repeat |
|
from torch import Tensor |
|
from torch.amp import autocast |
|
|
|
|
|
def rotate_half(x):
    """Rotate adjacent feature pairs: (a, b) -> (-b, a).

    Also known as "interleaved" style or GPT-J style: the last dimension is
    treated as consecutive (even, odd) pairs rather than two half-blocks.

    Args:
        x: tensor whose last dimension has even size ``2 * d``.

    Returns:
        Tensor of the same shape with each consecutive pair (a, b)
        replaced by (-b, a).
    """
    # View the last dim as d pairs, negate/swap within each pair, then
    # flatten back. Torch-native equivalent of the previous einops
    # rearrange round-trip — same math, one less dependency on a hot path.
    pairs = x.unflatten(-1, (-1, 2))
    x1, x2 = pairs.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
|
|
|
|
|
@autocast("cuda", enabled=False)
def apply_rotary_emb(
    freqs: Tensor, t: Tensor, start_index: int = 0, scale: float = 1.0, seq_dim=-2
):
    """Apply rotary position embeddings to a slice of t's feature dimension.

    Only features ``[start_index, start_index + freqs.shape[-1])`` are
    rotated; features outside that window pass through unchanged. Runs in
    full precision (autocast disabled) and casts back to t's dtype.

    Args:
        freqs: per-position rotation angles, last dim is the rotary width.
        t: input tensor; when 3-D, ``seq_dim`` indexes its sequence axis.
        start_index: first feature index to rotate.
        scale: multiplier applied to the rotated slice.
        seq_dim: which axis of a 3-D ``t`` holds the sequence.

    Returns:
        Tensor shaped like ``t``, in ``t``'s original dtype.
    """
    orig_dtype = t.dtype

    # For (batch, seq, dim) inputs, keep only the freqs covering the last
    # seq positions so a longer precomputed table can be reused.
    if t.ndim == 3:
        freqs = freqs[-t.shape[seq_dim]:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    # Split off the untouched flanks, rotate the middle window only.
    left = t[..., :start_index]
    middle = t[..., start_index:end_index]
    right = t[..., end_index:]

    rotated = (middle * freqs.cos() * scale) + (rotate_half(middle) * freqs.sin() * scale)

    return torch.cat((left, rotated, right), dim=-1).to(orig_dtype)
|
|
|
|
|
def precompute_freqs_cis(
    dim: int,
    max_seqlen: int,
    theta: float = 10_000.0,
    theta_rescale_factor: float = 1.0,
    dtype: torch.dtype = torch.float32,
):
    """Precompute rotary-embedding angles for every position.

    Args:
        dim: rotary feature dimension (expected even).
        max_seqlen: number of positions to precompute.
        theta: base wavelength of the inverse-frequency spectrum.
        theta_rescale_factor: NTK-aware rescaling factor for length
            extrapolation; 1.0 (the default) leaves ``theta`` untouched.
        dtype: dtype used for the computation and the returned table.

    Returns:
        Tensor of shape ``(max_seqlen, dim)`` where each frequency appears
        twice consecutively, matching interleaved ``rotate_half`` pairs.
    """
    # NTK-aware theta rescaling. Guarded so the default factor of 1.0 is a
    # true no-op — unguarded, the exponent dim / (dim - 2) would raise
    # ZeroDivisionError for dim == 2 even when no rescaling was requested.
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    pos = torch.arange(max_seqlen, dtype=dtype)
    inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
    # Outer product: angle[p, f] = p * inv_freq[f].
    freqs = torch.outer(pos, inv_freqs)
    # [f0, f1, ...] -> [f0, f0, f1, f1, ...] so each angle spans one
    # (even, odd) feature pair.
    return freqs.repeat_interleave(2, dim=-1)
|
|