|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Union, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
from diffusers.utils import logging |
|
from diffusers.models.attention_processor import Attention |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class CustomLiteLAProcessor2_0:
    """Linear ("lite") attention processor for SD3-like self-attention projections.

    Softmax attention is replaced by ReLU-kernel linear attention:
    ``out = (V_pad @ relu(K)) @ relu(Q)``. Here ``V_pad`` is ``V`` with one extra
    row filled with ``pad_val``, so the last row of the product holds the
    normaliser that the other rows are divided by. ``attn.norm_q`` / ``attn.norm_k``
    (e.g. RMS norm) are applied to query and key when present. Rotary position
    embeddings (RoPE) are applied when ``rotary_freqs_cis`` is given.
    """

    def __init__(self):
        # ReLU feature map used as the linear-attention kernel.
        self.kernel_func = nn.ReLU(inplace=False)
        # Added to the normaliser to avoid division by zero.
        self.eps = 1e-15
        # Fill value of the extra row padded onto V. With 1.0 that row of the
        # output becomes sum_n relu(k_n) . relu(q), i.e. the normaliser.
        self.pad_val = 1.0

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        """Apply rotary position embeddings to a query or key tensor.

        Adjacent channel pairs ``(x[..., 2i], x[..., 2i + 1])`` are treated as the
        real and imaginary parts of a complex number and rotated using the
        precomputed ``cos`` / ``sin`` tables.

        Args:
            x (`torch.Tensor`):
                Query or key tensor of shape ``[B, H, S, D]`` (``D`` must be even).
            freqs_cis (`Tuple[torch.Tensor]`):
                ``(cos, sin)`` tables, each ``[S, D]``. They are moved to
                ``x.device`` and broadcast over batch and heads.

        Returns:
            torch.Tensor: ``x`` with rotary embeddings applied. It has the same
            shape and dtype as ``x``; the arithmetic itself is done in float32.
        """
        cos, sin = freqs_cis
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # Interleave (-imag, real) back into the last dim; flatten(3) assumes 4-D x.
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """Run ReLU linear attention with the projections of ``attn``.

        Args:
            attn: Module providing ``to_q``/``to_k``/``to_v``, ``heads``,
                ``norm_q``/``norm_k``, ``is_cross_attention``, ``context_pre_only``
                and ``to_out``. It may also provide ``add_q_proj``/``add_k_proj``/
                ``add_v_proj`` and ``to_add_out`` for the context stream.
            hidden_states: Main input, ``[B, N, C]`` or ``[B, C, H, W]``.
            encoder_hidden_states: Optional context, ``[B, M, C]`` or
                ``[B, C, H', W']``. When ``attn`` has the ``add_*_proj`` layers and
                is not cross-attention, the projected context tokens are appended
                to the sequence (joint attention). The output is then split back.
            attention_mask: Optional 2-D ``[B, L]`` multiplicative mask. ``L`` is
                the length of the attended sequence (including appended context
                tokens). It is applied to Q, and to K/V outside cross-attention.
            encoder_attention_mask: Optional 2-D mask applied to K/V in the
                cross-attention case.
            rotary_freqs_cis: Optional ``(cos, sin)`` RoPE tables for Q (and K
                outside cross-attention).
            rotary_freqs_cis_cross: Optional RoPE tables for K in cross-attention.

        Returns:
            ``(hidden_states, encoder_hidden_states)``. The second item is
            ``None`` when no context was given. 4-D inputs come back in their
            original 4-D layout.
        """
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_input_ndim = None
        if encoder_hidden_states is not None:
            context_input_ndim = encoder_hidden_states.ndim
            if context_input_ndim == 4:
                # Separate names so the context's shape does not overwrite the
                # main stream's (needed to restore hidden_states' 4-D layout).
                ctx_batch, ctx_channel, ctx_height, ctx_width = encoder_hidden_states.shape
                encoder_hidden_states = encoder_hidden_states.view(
                    ctx_batch, ctx_channel, ctx_height * ctx_width
                ).transpose(1, 2)

        # Token count of the main stream, measured AFTER flattening 4-D input,
        # so the joint-attention output is split at the right index below.
        hidden_states_len = hidden_states.shape[1]
        batch_size = hidden_states.shape[0]

        dtype = hidden_states.dtype
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Context ("encoder") projections, only when the module defines them.
        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
        )
        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            if not attn.is_cross_attention:
                # Joint attention: append the context tokens to the sequence.
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            else:
                # NOTE(review): this branch discards the to_q / add_k / add_v
                # projections computed above and attends with the raw inputs.
                # Behaviour preserved as-is; confirm whether projected tensors
                # were intended here.
                query = hidden_states
                key = encoder_hidden_states
                value = encoder_hidden_states

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [B, L, H*D] -> query/value [B, H, D, L], key [B, H, L, D].
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # Temporarily move query to [B, H, L, D] so norm and RoPE act on the head dim.
        query = query.permute(0, 1, 3, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        # Back to [B, H, D, L].
        query = query.permute(0, 1, 3, 2)

        if attention_mask is not None:
            # [B, L] -> [B, 1, L, 1]; zero out masked tokens multiplicatively.
            attention_mask = attention_mask[:, None, :, None].to(key.dtype)
            query = query * attention_mask.permute(0, 1, 3, 2)
            if not attn.is_cross_attention:
                key = key * attention_mask
                value = value * attention_mask.permute(0, 1, 3, 2)

        if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype)
            key = key * encoder_attention_mask
            value = value * encoder_attention_mask.permute(0, 1, 3, 2)

        query = self.kernel_func(query)
        key = self.kernel_func(key)

        query, key, value = query.float(), key.float(), value.float()

        # Append one row of `pad_val` to V: [B, H, D + 1, L].
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
        vk = torch.matmul(value, key)  # [B, H, D + 1, D]
        hidden_states = torch.matmul(vk, query)  # [B, H, D + 1, L]

        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
            # matmul can still return half precision (e.g. under autocast);
            # do the normalisation in float32.
            hidden_states = hidden_states.float()

        # Divide by the normaliser held in the padded last row.
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
        # [B, H, D, L] -> [B, L, H*D]
        hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)

        hidden_states = hidden_states.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
            # Split the joint sequence back into main and context streams.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :hidden_states_len],
                hidden_states[:, hidden_states_len:],
            )

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        if (
            encoder_hidden_states is not None
            and not attn.context_pre_only
            and not attn.is_cross_attention
            and hasattr(attn, "to_add_out")
        ):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if encoder_hidden_states is not None and context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                ctx_batch, ctx_channel, ctx_height, ctx_width
            )

        # Prefer the non-deprecated API when this PyTorch version provides it.
        get_autocast_dtype = getattr(torch, "get_autocast_dtype", None)
        if get_autocast_dtype is not None:
            autocast_dtype = get_autocast_dtype("cuda")
        else:
            autocast_dtype = torch.get_autocast_gpu_dtype()
        # NOTE(review): this reads the configured CUDA autocast dtype, which may
        # be float16 even when autocast is not active. Confirm the clip is
        # intended in that case.
        if autocast_dtype == torch.float16:
            # Keep values inside the finite float16 range.
            hidden_states = hidden_states.clip(-65504, 65504)
            if encoder_hidden_states is not None:
                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return hidden_states, encoder_hidden_states
|
|
|
|
|
class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    On top of plain SDPA it applies optional query/key normalisation
    (``attn.norm_q`` / ``attn.norm_k``) and rotary position embeddings. For
    cross-attention it also builds a combined query/context padding mask.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        """Apply rotary position embeddings to a query or key tensor.

        Adjacent channel pairs ``(x[..., 2i], x[..., 2i + 1])`` are treated as the
        real and imaginary parts of a complex number and rotated using the
        precomputed ``cos`` / ``sin`` tables.

        Args:
            x (`torch.Tensor`):
                Query or key tensor of shape ``[B, H, S, D]`` (``D`` must be even).
            freqs_cis (`Tuple[torch.Tensor]`):
                ``(cos, sin)`` tables, each ``[S, D]``. They are moved to
                ``x.device`` and broadcast over batch and heads.

        Returns:
            torch.Tensor: ``x`` with rotary embeddings applied. It has the same
            shape and dtype as ``x``; the arithmetic itself is done in float32.
        """
        cos, sin = freqs_cis
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # Interleave (-imag, real) back into the last dim; flatten(3) assumes 4-D x.
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run scaled dot-product attention using the projections of ``attn``.

        Args:
            attn: Module providing ``to_q``/``to_k``/``to_v``, ``heads``,
                ``group_norm``, ``norm_cross``, ``norm_q``/``norm_k``,
                ``is_cross_attention``, ``prepare_attention_mask``, ``to_out``,
                ``residual_connection`` and ``rescale_output_factor``.
            hidden_states: Query input, ``[B, N, C]`` or ``[B, C, H, W]``.
            encoder_hidden_states: Optional key/value input. When it is None,
                ``hidden_states`` is used (self-attention).
            attention_mask: In cross-attention with ``encoder_attention_mask``,
                a 2-D ``[B, N]`` query mask where 1 marks a valid token; ``None``
                means every query token is valid. In self-attention it goes
                through ``attn.prepare_attention_mask``. In other cases it is
                passed to SDPA unchanged.
            encoder_attention_mask: 2-D ``[B, M]`` key mask (1 = valid), used only
                in cross-attention when ``attn`` has ``add_*_proj`` layers.
            rotary_freqs_cis: Optional ``(cos, sin)`` RoPE tables for Q (and K
                outside cross-attention).
            rotary_freqs_cis_cross: Optional RoPE tables for K in cross-attention.

        Returns:
            torch.Tensor: Attention output in the layout of ``hidden_states``,
            with the optional residual added and divided by
            ``attn.rescale_output_factor``.
        """
        residual = hidden_states
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [B, L, H*D] -> [B, H, L, D]
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
            if attention_mask is None:
                # No query mask supplied: treat every query token as valid
                # instead of failing on None indexing.
                attention_mask = torch.ones(
                    query.shape[0],
                    query.shape[2],
                    dtype=encoder_attention_mask.dtype,
                    device=encoder_attention_mask.device,
                )
            # [B, N, 1] * [B, 1, M] -> [B, N, M]; 1 only where query and key are both valid.
            combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
            # Additive bias for SDPA: 0 keeps a pair, -inf removes it.
            # NOTE(review): a query row with no valid key is all -inf; depending
            # on the PyTorch version / SDPA backend that row may come out as NaN.
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)

        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Give SDPA one mask per head: [B, H, -1, key_len].
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # [B, H, N, D] -> [B, N, H*D]
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|