mac9087 committed
Commit 527aa85 · verified · 1 Parent(s): b910820

Create transformer/attention.py

Files changed (1)
  1. tsr/models/transformer/attention.py +653 -0
tsr/models/transformer/attention.py ADDED
@@ -0,0 +1,653 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# --------
+#
+# Modified 2024 by the Tripo AI and Stability AI Team.
+#
+# Copyright (c) 2024 Tripo AI & Stability AI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Attention(nn.Module):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`):
+            The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+        heads (`int`, *optional*, defaults to 8):
+            The number of heads to use for multi-head attention.
+        dim_head (`int`, *optional*, defaults to 64):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the attention computation to `float32`.
+        upcast_softmax (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the softmax computation to `float32`.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the group norm in the cross attention.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        norm_num_groups (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the group norm in the attention.
+        spatial_norm_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the spatial normalization.
+        out_bias (`bool`, *optional*, defaults to `True`):
+            Set to `True` to use a bias in the output linear layer.
+        scale_qk (`bool`, *optional*, defaults to `True`):
+            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+        only_cross_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+            `added_kv_proj_dim` is not `None`.
+        eps (`float`, *optional*, defaults to 1e-5):
+            An additional value added to the denominator in group normalization that is used for numerical stability.
+        rescale_output_factor (`float`, *optional*, defaults to 1.0):
+            A factor to rescale the output by dividing it with this value.
+        residual_connection (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add the residual connection to the output.
+        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+            Set to `True` if the attention block is loaded from a deprecated state dict.
+        processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+            `AttnProcessor` otherwise.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
+        added_kv_proj_dim: Optional[int] = None,
+        norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = (
+            cross_attention_dim if cross_attention_dim is not None else query_dim
+        )
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(
+                num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
+            )
+        else:
+            self.group_norm = None
+
+        self.spatial_norm = None
+
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
+            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = self.cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels,
+                num_groups=cross_attention_norm_num_groups,
+                eps=1e-5,
+                affine=True,
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )
+
+        linear_cls = nn.Linear
+
+        self.linear_cls = linear_cls
+        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        # set attention processor
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        if processor is None:
+            processor = (
+                AttnProcessor2_0()
+                if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `Attention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
+        # The `Attention` class can call different attention processors / attention functions
+        # here we simply pass along all tensors to the selected processor class
+        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+        is the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(
+            batch_size // head_size, seq_len, dim * head_size
+        )
+        return tensor
+
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads` is
+        the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+                reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+        return tensor
+
+    def get_attention_scores(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+    ) -> torch.Tensor:
+        r"""
+        Compute the attention scores.
+
+        Args:
+            query (`torch.Tensor`): The query tensor.
+            key (`torch.Tensor`): The key tensor.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+        Returns:
+            `torch.Tensor`: The attention probabilities/scores.
+        """
+        dtype = query.dtype
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0],
+                query.shape[1],
+                key.shape[1],
+                dtype=query.dtype,
+                device=query.device,
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
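+        # Note: torch.baddbmm computes beta * input + alpha * (batch1 @ batch2), so this
+        # evaluates attention_mask + scale * (query @ key^T) when a mask is supplied;
+        # with beta=0 the uninitialized buffer is ignored and only the scaled dot product remains.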
+        attention_scores = torch.baddbmm(
+            baddbmm_input,
+            query,
+            key.transpose(-1, -2),
+            beta=beta,
+            alpha=self.scale,
+        )
+        del baddbmm_input
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+        del attention_scores
+
+        attention_probs = attention_probs.to(dtype)
+
+        return attention_probs
+
+    def prepare_attention_mask(
+        self,
+        attention_mask: torch.Tensor,
+        target_length: int,
+        batch_size: int,
+        out_dim: int = 3,
+    ) -> torch.Tensor:
+        r"""
+        Prepare the attention mask for the attention computation.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                The attention mask to prepare.
+            target_length (`int`):
+                The target length of the attention mask. This is the length of the attention mask after padding.
+            batch_size (`int`):
+                The batch size, which is used to repeat the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`):
+                The output dimension of the attention mask. Can be either `3` or `4`.
+
+        Returns:
+            `torch.Tensor`: The prepared attention mask.
+        """
+        head_size = self.heads
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS does not support padding by more than the dimension of the input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (
+                    attention_mask.shape[0],
+                    attention_mask.shape[1],
+                    target_length,
+                )
+                padding = torch.zeros(
+                    padding_shape,
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+                attention_mask = torch.cat([attention_mask, padding], dim=2)
+            else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                # we want to instead pad by (0, remaining_length), where remaining_length is:
+                # remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
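+        # Duplicate the mask once per attention head: (batch, ...) -> (batch * heads, ...)
+        # for the 3-D layout used by bmm/baddbmm, or (batch, heads, ...) for the 4-D layout
+        # used by scaled_dot_product_attention.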
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+        return attention_mask
+
+    def norm_encoder_hidden_states(
+        self, encoder_hidden_states: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+        `Attention` class.
+
+        Args:
+            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+        Returns:
+            `torch.Tensor`: The normalized encoder hidden states.
+        """
+        assert (
+            self.norm_cross is not None
+        ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+
+    @torch.no_grad()
+    def fuse_projections(self, fuse=True):
+        is_cross_attention = self.cross_attention_dim != self.query_dim
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
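+        # In self-attention, q/k/v all read the same input, so their weight matrices can be
+        # stacked along the output dimension into a single [3 * inner_dim, query_dim]
+        # projection; in cross-attention only k and v share an input and can be fused.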
+        if not is_cross_attention:
+            # fetch weight matrices.
+            concatenated_weights = torch.cat(
+                [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            # create a new single projection layer and copy over the weights.
+            self.to_qkv = self.linear_cls(
+                in_features, out_features, bias=False, device=device, dtype=dtype
+            )
+            self.to_qkv.weight.copy_(concatenated_weights)
+
+        else:
+            concatenated_weights = torch.cat(
+                [self.to_k.weight.data, self.to_v.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = self.linear_cls(
+                in_features, out_features, bias=False, device=device, dtype=dtype
+            )
+            self.to_kv.weight.copy_(concatenated_weights)
+
+        self.fused_projections = fuse
+
+
+class AttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, height * width
+            ).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape
+            if encoder_hidden_states is None
+            else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size
+        )
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+                1, 2
+            )
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(
+                encoder_hidden_states
+            )
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
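+        # Fold the head dimension into the batch dimension, (batch, seq_len, inner_dim) ->
+        # (batch * heads, seq_len, head_dim), so attention reduces to plain batched matmuls.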
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, height * width
+            ).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape
+            if encoder_hidden_states is None
+            else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(
+                attention_mask, sequence_length, batch_size
+            )
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(
+                batch_size, attn.heads, -1, attention_mask.shape[-1]
+            )
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+                1, 2
+            )
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(
+                encoder_hidden_states
+            )
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
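+        # Split the projections into heads, (batch, seq_len, inner_dim) ->
+        # (batch, heads, seq_len, head_dim), the layout expected by F.scaled_dot_product_attention.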
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(
+            batch_size, -1, attn.heads * head_dim
+        )
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
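
For reference, a minimal usage sketch of the Attention layer added in this commit. This is illustrative only and not part of the diff; the shapes, dimensions, and the assumption that the tsr package is importable are arbitrary choices.

import torch
from tsr.models.transformer.attention import Attention

# Cross attention: 512-dim queries attending over 768-dim encoder features.
attn = Attention(query_dim=512, cross_attention_dim=768, heads=8, dim_head=64)

hidden_states = torch.randn(2, 1024, 512)        # (batch, query_seq_len, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, kv_seq_len, cross_attention_dim)

# forward() dispatches to AttnProcessor2_0 (PyTorch 2.x SDPA) or AttnProcessor automatically.
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 1024, 512])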