# joaogante's picture
# joaogante HF Staff
# add sanity checks
# 835a8a6
import torch
from typing import Any, Dict, List, Optional, Tuple
from transformers import Cache, GenerationConfig
# Generation arguments that the custom `generate` function in this file rejects. This is kept as a list because
# its repr is interpolated verbatim into the error message raised for an unsupported argument.
UNSUPPORTED_GENERATION_ARGS = [
    # Cache-related arguments: this file always generates with `SinkCache`.
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
    # Beam search (and cousin techniques) are not supported.
    "num_beams",
    # `SinkCache` doesn't support `torch.compile`.
    "compile_config",
    # `SinkCache` doesn't support speculative decoding either.
    "assistant_model",
]
class SinkCache(Cache):
    """
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    This class was copied from transformers 4.52.0, with minor modifications.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        super().__init__()
        # Per-layer key/value tensors, indexed by `layer_idx`.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Maps "number of new tokens" -> (cos, sin) used to re-rotate the kept keys after a shift.
        self.cos_sin_rerotation_cache = {}
        # sin/cos values collected from the `cache_kwargs` passed to `update` on layer 0.
        self._cos_cache = None
        self._sin_cache = None

    @staticmethod
    def _rotate_half(x):
        """Splits the last dimension of `x` in two halves `(x1, x2)` and returns `cat(-x2, x1)`."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Applies the rotation described by `cos`/`sin` to `key_states` and returns the rotated keys."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the `(cos, sin)` pair that moves the kept keys `key_states.shape[-2]` positions earlier.

        The result is memoized in `self.cos_sin_rerotation_cache`, keyed by the number of new tokens
        (`key_states.shape[-2]`), and is returned in the dtype of `key_states` with a leading broadcast dimension.
        """
        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating the kept keys from their original
            # positions to positions `key_states.shape[-2]` earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            # Angle-difference identities: cos(shifted - original) and sin(shifted - original)
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

    def _shift_start(self, cached_len: int, num_new_tokens: int) -> int:
        """
        Returns the index along the sequence dimension from which cached (non-sink) tokens survive a shift.

        The number of surviving tokens is clamped at 0, so that a chunk of `window_length - num_sink_tokens` or more
        new tokens evicts every non-sink token instead of flipping a "last k tokens" slice into a "from index k" slice.
        """
        num_kept = max(self.window_length - self.num_sink_tokens - num_new_tokens, 0)
        return max(cached_len - num_kept, 0)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        if cache_kwargs is None:
            cache_kwargs = {}
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the sin/cos cache, which holds sin/cos values for all possible positions
        if using_rope and layer_idx == 0:
            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
            # after all RoPE models have a llama-like cache utilization.
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache. The start index is computed explicitly (instead of with a negative slice bound) so that
            # it cannot turn non-negative, and thus change meaning, when many new tokens arrive at once.
            key_start = self._shift_start(self.key_cache[layer_idx].shape[-2], key_states.shape[-2])
            keys_to_keep = self.key_cache[layer_idx][:, :, key_start:]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted. There is nothing to
            # re-rotate when no non-sink token survives the shift.
            if using_rope and keys_to_keep.shape[-2] > 0:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            value_start = self._shift_start(self.value_cache[layer_idx].shape[-2], value_states.shape[-2])
            values_to_keep = self.value_cache[layer_idx][:, :, value_start:]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
    """Custom generate function for SinkCache.

    Validates the call, prepares a `SinkCache` if none was passed, and delegates to `model.generate` with
    `use_cache=True`.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        window_length (`int`, *optional*, defaults to 256):
            The length of the context window. Only used when `past_key_values` is not passed.
        num_sink_tokens (`int`, *optional*, defaults to 4):
            The number of sink tokens. See the original paper for more information. Only used when
            `past_key_values` is not passed.
        **kwargs:
            Forwarded to `model.generate`, except for `custom_generate` (dropped) and `past_key_values`
            (validated and forwarded explicitly).

    Returns:
        Whatever `model.generate` returns.

    Raises:
        ValueError: If an argument listed in `UNSUPPORTED_GENERATION_ARGS` is set, if the model is an
            encoder-decoder model, or if `past_key_values` is passed and is not a `SinkCache`.
    """
    # 1. General sanity checks
    # 1.a. A few arguments are not allowed, especially arguments that control caches.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = False
        if generation_config is not None:
            # Not every unsupported argument is a generation config attribute (some are `generate` keyword
            # arguments only), so a missing attribute is treated as `None` instead of raising `AttributeError`.
            custom_value = getattr(generation_config, arg, None)
            # = not (match model-specific default or match global default)
            has_custom_gen_config_arg = not (
                getattr(default_model_generation_config, arg, None) == custom_value
                or getattr(default_global_generation_config, arg, None) == custom_value
            )
        kwargs_has_arg = kwargs.get(arg) is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
    kwargs.pop("custom_generate", None)

    # 2. Generate with SinkCache
    # 2.a. prepare the cache, if it was not passed.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
    elif not isinstance(past_key_values, SinkCache):
        raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")

    # 2.b. generate with the cache
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs