File size: 1,616 Bytes
88afac1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import torch
from einops import rearrange, repeat
from torch import Tensor
from torch.amp import autocast
def rotate_half(x):
    """Also known as "interleaved" style or GPT-J style.

    Treats the last dimension as adjacent pairs (x0, x1) and maps each
    pair to (-x1, x0), i.e. a 90-degree rotation within every pair.
    Requires the last dimension to be even.
    """
    # View the feature dim as (d, 2) pairs without copying.
    pairs = x.unflatten(-1, (-1, 2))
    even, odd = pairs[..., 0], pairs[..., 1]
    # (x0, x1) -> (-x1, x0), then collapse the pair dim back.
    return torch.stack((-odd, even), dim=-1).flatten(-2)
@autocast("cuda", enabled=False)
def apply_rotary_emb(
    freqs: Tensor, t: Tensor, start_index: int = 0, scale: float = 1.0, seq_dim=-2
):
    """Rotate a slice of t's feature dimension by the angles in freqs.

    Applies the interleaved (GPT-J style) rotary embedding
    ``t * cos(freqs) * scale + rotate_half(t) * sin(freqs) * scale`` to
    features ``[start_index, start_index + freqs.shape[-1])``; features
    outside that window pass through untouched. Autocast is disabled so the
    math runs in the tensors' own dtypes; the result is cast back to t's
    original dtype before returning.

    Args:
        freqs: angle table whose last dim is the rotated width; for 3-D t
            it is trimmed to the trailing ``t.shape[seq_dim]`` positions.
        t: input tensor; last dim is the feature dim.
        start_index: first feature index to rotate.
        scale: multiplier applied to both cos and sin terms (e.g. xpos).
        seq_dim: which dim of a 3-D t holds the sequence length.
    """
    orig_dtype = t.dtype
    # For (batch, seq, dim) inputs, align freqs with the trailing positions.
    if t.ndim == 3:
        freqs = freqs[-t.shape[seq_dim]:]
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    left = t[..., :start_index]
    mid = t[..., start_index:end_index]
    right = t[..., end_index:]
    # Inlined rotate_half: pairs (x0, x1) -> (-x1, x0) on the rotated slice.
    pairs = mid.unflatten(-1, (-1, 2))
    rotated = torch.stack((-pairs[..., 1], pairs[..., 0]), dim=-1).flatten(-2)
    mid = (mid * freqs.cos() * scale) + (rotated * freqs.sin() * scale)
    return torch.cat((left, mid, right), dim=-1).to(orig_dtype)
def precompute_freqs_cis(
    dim: int,
    max_seqlen: int,
    theta: float = 10_000.0,
    theta_rescale_factor: float = 1.0,
    dtype: torch.dtype = torch.float32,
):
    """Build the (max_seqlen, dim) table of rotary angles ``pos * inv_freq``.

    Each of the dim//2 inverse frequencies ``theta**-(2i/dim)`` is duplicated
    into adjacent columns so the table lines up with the interleaved pairing
    used by `rotate_half`.

    Args:
        dim: rotary feature width (assumed even).
        max_seqlen: number of positions to tabulate.
        theta: frequency base.
        theta_rescale_factor: NTK-aware base rescaling; 1.0 leaves theta as-is.
        dtype: dtype used to build positions and frequencies.
    """
    # NTK-aware rescaling of the base (no-op when theta_rescale_factor == 1).
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    positions = torch.arange(max_seqlen, dtype=dtype)
    inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
    # Outer product: angles[p, i] = p * inv_freqs[i].
    angles = torch.outer(positions.to(inv_freqs.dtype), inv_freqs)
    # Duplicate each column: [f0, f1, ...] -> [f0, f0, f1, f1, ...]
    return angles.repeat_interleave(2, dim=-1)
|