|
from typing import Optional, Tuple |
|
from torch import nn |
|
from diffusers.models.resnet import Downsample2D |
|
import torch |
|
from diffusers.utils import is_torch_version |
|
from .ResnetBlock2D import ResnetBlock2D |
|
|
|
class DownBlock2D(nn.Module):
    """Stack of ``ResnetBlock2D`` layers followed by an optional ``Downsample2D``.

    The block applies ``num_layers`` residual layers in sequence and then, when
    ``add_downsample`` is true, a single convolutional downsampler. The output of
    every residual layer and of the downsampler is collected and returned next to
    the final hidden states.

    Args:
        in_channels: Input channels of the first residual layer.
        out_channels: Output channels of every residual layer (and the input
            channels of every layer after the first) and of the downsampler.
        temb_channels: Forwarded to ``ResnetBlock2D`` as ``temb_channels``.
        dropout: Forwarded to ``ResnetBlock2D`` as ``dropout``.
        num_layers: Number of residual layers to stack.
        resnet_eps: Forwarded to ``ResnetBlock2D`` as ``eps``.
        resnet_time_scale_shift: Forwarded as ``time_embedding_norm``.
        resnet_act_fn: Forwarded as ``non_linearity``.
        resnet_groups: Forwarded as ``groups``.
        resnet_pre_norm: Forwarded as ``pre_norm``.
        output_scale_factor: Forwarded as ``output_scale_factor``.
        add_downsample: Whether to append a ``Downsample2D`` after the layers.
        downsample_padding: Padding handed to ``Downsample2D``.
        normalization_type: Forwarded unchanged to ``ResnetBlock2D``.
            NOTE(review): presumably selects the (SPADE) normalization variant;
            confirm against ``ResnetBlock2D``.
        SPADE_chs: Forwarded unchanged to ``ResnetBlock2D``.
        is_crossAttn: Forwarded unchanged to ``ResnetBlock2D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        normalization_type=None,
        SPADE_chs=(320, 640, 1280, 1280),
        is_crossAttn=False,
    ):
        super().__init__()

        # Only the first layer sees `in_channels`; later layers consume the
        # `out_channels` produced by their predecessor.
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    SPADE_chs=SPADE_chs,
                    normalization_type=normalization_type,
                    is_crossAttn=is_crossAttn,
                )
                for i in range(num_layers)
            ]
        )

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        # Off by default; flipped from outside to trade compute for memory.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        segmap=None,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        """Run the residual layers and the optional downsampler.

        Args:
            hidden_states: Input passed to the first residual layer.
            temb: Passed positionally to every residual layer.
            scale: Passed as ``scale`` to every residual layer and downsampler.
            segmap: Passed as ``segmaps`` to every residual layer.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a tuple
            holding the output of each residual layer followed by the output of
            each downsampler, in execution order.
        """
        output_states = ()
        use_checkpointing = self.training and self.gradient_checkpointing

        if use_checkpointing:

            def create_custom_forward(module):
                # `scale` and `segmap` are captured by the closure so that the
                # checkpointed call matches the regular call below; passing only
                # (hidden_states, temb) would silently drop both.
                def custom_forward(*inputs):
                    return module(*inputs, scale=scale, segmaps=segmap)

                return custom_forward

            # `use_reentrant` is only passed on torch >= 1.11.0.
            checkpoint_kwargs = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )

        for resnet in self.resnets:
            if use_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **checkpoint_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb, scale=scale, segmaps=segmap)

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, scale=scale)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states