from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import Cache, GenerationConfig


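# Generation arguments that are not supported by the custom `generate` below, because they either
# configure a different cache or rely on decoding features (beam search, assisted decoding,
# compilation) that this cache-swapping loop does not account for.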
UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
    "num_beams",
    "compile_config",
    "assistant_model",
]


class SinkCache(Cache):
    """
    A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), that allows the model to
    generate beyond the length of its context window without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    This class was copied from transformers 4.52.0, with minor modifications.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Rerotation cos/sin tensors, keyed by the number of incoming tokens (see `_get_rerotation_cos_sin`)
        self.cos_sin_rerotation_cache = {}
        # Full RoPE cos/sin tables, accumulated in `update` so evicted positions can be re-rotated
        self._cos_cache = None
        self._sin_cache = None

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
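            # Upcast to float32 temporarily for better accuracy when composing rotations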
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

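            # The kept keys were rotated for their old absolute positions; after eviction they sit
            # `key_states.shape[-2]` positions earlier. Using the angle-difference identities
            # cos(b - a) = cos(a)cos(b) + sin(a)sin(b) and sin(b - a) = sin(b)cos(a) - cos(b)sin(a),
            # compute the cos/sin of the correction angle that maps old positions to new ones.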
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

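            # Memoize per number of incoming tokens, so single-token decoding computes this only once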
            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length of the cache object. For `SinkCache`, this is the window length."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """

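        # RoPE models pass `sin`/`cos` (and optionally `partial_rotation_size`) through `cache_kwargs`,
        # which lets the cache re-rotate keys when old positions are evicted from the window.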
        if cache_kwargs is None:
            cache_kwargs = {}
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

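        # Keep a running copy of the full sin/cos tables (populated once per step, on layer 0) so that
        # `_get_rerotation_cos_sin` can index arbitrary positions later. Some models pass `sin`/`cos`
        # with a leading batch dimension, hence the two code paths below.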
        if using_rope and layer_idx == 0:
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

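        # Three cases: (1) first call for this layer -> store the states as-is; (2) the cache is still
        # shorter than the window -> append; (3) the window is full -> keep the sink tokens, drop the
        # oldest non-sink tokens, and (for RoPE models) re-rotate the surviving keys.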
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

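            # On RoPE models, the kept keys were rotated for their old absolute positions; re-rotate them
            # so that they match the positions they now occupy. With partial rotary embeddings, only the
            # first `partial_rotation_size` channels are rotary and need the correction.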
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

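            # Values carry no positional encoding, so they are sliced and concatenated without rerotation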
            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]


def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
    """Custom generate function for SinkCache.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        window_length (`int`, *optional*, defaults to 256):
            The length of the context window.
        num_sink_tokens (`int`, *optional*, defaults to 4):
            The number of sink tokens. See the original paper for more information.

    All remaining keyword arguments are forwarded to `model.generate`.
    """

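    # Reject unsupported arguments, whether they were passed directly as kwargs or set to a non-default
    # value on a user-provided GenerationConfig (values matching the model or global defaults are fine).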
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = (
            generation_config is not None
            and not (
                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
            )
        )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

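    # Remove `custom_generate` from the kwargs, if present, so the inner `model.generate` call below
    # runs the standard generation loop instead of re-dispatching to this function.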
    kwargs.pop("custom_generate", None)

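    # Reuse a caller-provided SinkCache (e.g. to continue a multi-turn conversation) or create a new one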
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
    elif not isinstance(past_key_values, SinkCache):
        raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")

    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs
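

# Minimal usage sketch (not part of the cache implementation): calls the custom `generate` above
# directly. The checkpoint name is only an example of a small decoder-only RoPE model; any similar
# causal LM works, and every extra keyword argument is forwarded untouched to `model.generate`.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # example checkpoint, swap in any decoder-only model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    inputs = tokenizer("Attention sinks let a model keep generating because", return_tensors="pt")
    # Keep at most 256 positions in the cache, 4 of which are permanent "sink" tokens
    outputs = generate(model, window_length=256, num_sink_tokens=4, **inputs, max_new_tokens=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))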