mac9087 committed
Commit 1d6325d · verified · 1 Parent(s): 527aa85

Create basic_transformer_block.py

tsr/models/transformer/basic_transformer_block.py ADDED
@@ -0,0 +1,334 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# --------
#
# Modified 2024 by the Tripo AI and Stability AI Team.
#
# Copyright (c) 2024 Tripo AI & Stability AI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from .attention import Attention

class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attention layers should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        assert norm_type == "layer_norm"

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned
            # during the second cross attention block.
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is None
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets the chunk size (and the dimension to chunk along) used by the feed-forward pass in `forward`.
        self._chunk_size = chunk_size
        self._chunk_dim = dim

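    # A minimal sketch of how chunking is typically enabled; the chunk size and the
    # 1024-token sequence mentioned below are illustrative assumptions, not values
    # defined in this repository:
    #
    #   >>> block.set_chunk_feed_forward(chunk_size=256, dim=1)
    #   >>> # forward() now runs the feed-forward on 1024 / 256 = 4 slices of the token
    #   >>> # axis and concatenates the results, trading a little compute for peak memory.
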
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks (pre-norm).
        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=(
                encoder_hidden_states if self.only_cross_attention else None
            ),
            attention_mask=attention_mask,
        )

        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice)
                    for hid_slice in norm_hidden_states.chunk(
                        num_chunks, dim=self._chunk_dim
                    )
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states

        return hidden_states

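
# A minimal usage sketch for the block above. The shapes and hyperparameters are
# illustrative assumptions (not defaults defined in this file), and the sibling
# `.attention.Attention` module must be importable:
#
#   >>> block = BasicTransformerBlock(dim=512, num_attention_heads=8, attention_head_dim=64)
#   >>> tokens = torch.randn(2, 1024, 512)           # (batch, sequence, dim)
#   >>> block(tokens).shape                          # residual connections keep the shape
#   torch.Size([2, 1024, 512])
#
#   >>> # with cross-attention to a conditioning sequence of width 768 (illustrative):
#   >>> block = BasicTransformerBlock(512, 8, 64, cross_attention_dim=768)
#   >>> cond = torch.randn(2, 77, 768)
#   >>> block(tokens, encoder_hidden_states=cond).shape
#   torch.Size([2, 1024, 512])
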

class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = nn.Linear

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(f"Unsupported activation_fn: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states

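
# A minimal usage sketch for FeedForward; shapes are illustrative assumptions. With the
# default "geglu" activation the hidden width is dim * mult (doubled once inside GEGLU to
# hold the gate), and the output width falls back to `dim` when `dim_out` is not given:
#
#   >>> ff = FeedForward(dim=512, mult=4)            # GEGLU proj: 512 -> 4096 (value + gate halves of 2048)
#   >>> x = torch.randn(2, 1024, 512)
#   >>> ff(x).shape
#   torch.Size([2, 1024, 512])
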

class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
            dtype=gate.dtype
        )

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states

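
# For reference, the tanh approximation selected by `approximate="tanh"` is
#   gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
# which closely tracks the exact erf-based GELU, e.g. (illustrative check):
#
#   >>> x = torch.linspace(-3, 3, 7)
#   >>> (F.gelu(x) - F.gelu(x, approximate="tanh")).abs().max() < 1e-2
#   tensor(True)
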

class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        linear_cls = nn.Linear

        self.proj = linear_cls(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, scale: float = 1.0):
        # `scale` is unused here; it is kept for signature compatibility with diffusers-style callers.
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)

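
# GEGLU computes roughly (x W1) * GELU(x W2): a single projection to twice the output
# width is split into a value half and a gate half. A minimal sketch (illustrative sizes):
#
#   >>> geglu = GEGLU(dim_in=512, dim_out=2048)       # proj: Linear(512, 4096)
#   >>> h = torch.randn(2, 1024, 512)
#   >>> geglu(h).shape                                # value * gelu(gate)
#   torch.Size([2, 1024, 2048])
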

class ApproximateGELU(nn.Module):
    r"""
    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of
    https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)
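
# The sigmoid form used above, x * sigmoid(1.702 * x), is the cheap GELU approximation from
# the paper linked in the docstring: 1.702 is the constant for which sigmoid(1.702 * x) best
# matches the Gaussian CDF Φ(x), so x * sigmoid(1.702 * x) ≈ x * Φ(x).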