import torch import torch.nn as nn import torch.nn.functional as F from base.base_modules import * from timm.models import create_model from functools import partial class Backbone(nn.Module): """ Model backbone to extract features """ def __init__(self, input_channels: int = 3, channels: tuple = (32, 64, 128, 256, 512), strides: tuple = (2, 2, 2, 2), use_dropout: bool = False, norm: str = 'BATCH', leaky: bool = True): """ Args: input_channels: the number of input channels channels: length-5 tuple, define the number of channels in each stage strides: tuple, define the stride in each stage use_dropout: bool, whether to use dropout norm: str, normalization type leaky: bool, whether to use leaky relu """ super().__init__() self.nb_filter = channels self.strides = strides + (5 - len(strides)) * (1,) res_unit = ResBlock if channels[-1] <= 320 else ResBottleneck self.conv0_0 = nn.Sequential( nn.Conv2d(input_channels, channels[0], kernel_size=7, stride=self.strides[0], padding=3), nn.GroupNorm(1, channels[0]) if norm == 'GROUP' else nn.BatchNorm2d(channels[0]) if norm == 'BATCH' else nn.InstanceNorm2d(channels[0]), nn.LeakyReLU() if leaky else nn.ReLU(), ) self.conv1_0 = res_unit(self.nb_filter[0], self.nb_filter[1], self.strides[1], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv2_0 = res_unit(self.nb_filter[1], self.nb_filter[2], self.strides[2], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv3_0 = res_unit(self.nb_filter[2], self.nb_filter[3], self.strides[3], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv4_0 = res_unit(self.nb_filter[3], self.nb_filter[4], self.strides[4], use_dropout=use_dropout, norm=norm, leaky=leaky) def forward(self, x): x0_0 = self.conv0_0(x) x1_0 = self.conv1_0(x0_0) x2_0 = self.conv2_0(x1_0) x3_0 = self.conv3_0(x2_0) x4_0 = self.conv4_0(x3_0) return x0_0, x1_0, x2_0, x3_0, x4_0 class TimmBackbone(nn.Module): """ Timm backbone to extract features, utilizing pretrained weights """ def __init__(self, model_name) -> None: super().__init__() self.backbone = create_model(model_name, pretrained=True, features_only=True) self.determine_nb_filters() def determine_nb_filters(self): dummy = torch.randn(1, 3, 256, 256) out = self.backbone(dummy) nb_filters = [] for o in out: nb_filters.append(o.size(1)) self.nb_filter = nb_filters def forward(self, inputs): return self.backbone(inputs) class UNet(nn.Module): def __init__(self, model_name: str = None, in_channels: int = 1, out_channels: int = None, channels: tuple = (64, 128, 256, 320, 512), strides: tuple = (2, 2, 2, 2, 2), use_dropout: bool = False, norm: str = 'INSTANCE', leaky: bool = True, use_dilated_bottleneck: bool = False): """ Args: model_name: timm model name input_channels: the number of input channels in_channels: the number of output channels channels: length-5 tuple, define the number of channels in each stage strides: tuple, define the stride in each stage use_dropout: bool, whether to use dropout norm: str, normalization type leaky: bool, whether to use leaky relu """ super().__init__() if model_name not in [None, 'none', 'None']: # use Timm backbone and overrides any other input arguments self.backbone = TimmBackbone(model_name) else: self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides, use_dropout=use_dropout, norm=norm, leaky=leaky) nb_filter = self.backbone.nb_filter res_unit = ResBlock if nb_filter[-1] <= 512 else ResBottleneck # decoder self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], use_dropout=use_dropout, norm=norm, leaky=leaky) # dilated bottleneck: optional if use_dilated_bottleneck: self.dilation = nn.Sequential( nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=2, dilation=2), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=5, dilation=5), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=2, dilation=2), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=5, dilation=5), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), ) else: self.dilation = nn.Identity() if out_channels is not None: self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False) else: self.convds0 = None def upsample(self, inputs, target): return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False) def extract_features(self, x): x0, x1, x2, x3, x4 = self.backbone(x) x4 = self.dilation(x4) x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1)) x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1)) x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1)) x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1)) return x4, x0_4 def forward(self, x): size = x.shape[2:] x0, x1, x2, x3, x4 = self.backbone(x) x4 = self.dilation(x4) x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1)) x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1)) x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1)) x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1)) if self.convds0 is not None: x_out = self.convds0(x0_4) out = F.interpolate(x_out, size=size, mode='bilinear', align_corners=False) else: out = x0_4 return out def freeze(self): # freeze the network for p in self.parameters(): p.requires_grad = False def unfreeze(self): # unfreeze the network to allow parameter update for p in self.parameters(): p.requires_grad = True class PromptAttentionUNet(nn.Module): def __init__(self, model_name: str = None, in_channels: int = 1, out_channels: int = None, channels: tuple = (64, 128, 256, 320, 512), strides: tuple = (2, 2, 2, 2, 2), use_dropout: bool = False, norm: str = 'INSTANCE', leaky: bool = True, use_dilated_bottleneck: bool = False): """ Args: model_name: timm model name input_channels: the number of input channels in_channels: the number of output channels channels: length-5 tuple, define the number of channels in each stage strides: tuple, define the stride in each stage use_dropout: bool, whether to use dropout norm: str, normalization type leaky: bool, whether to use leaky relu """ super().__init__() if model_name not in [None, 'none', 'None']: # use Timm backbone and overrides any other input arguments self.backbone = TimmBackbone(model_name) else: self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides, use_dropout=use_dropout, norm=norm, leaky=leaky) nb_filter = self.backbone.nb_filter res_unit = PromptResBlock if nb_filter[-1] <= 512 else PromptResBottleneck # decoder self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], use_dropout=use_dropout, norm=norm, leaky=leaky) self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], use_dropout=use_dropout, norm=norm, leaky=leaky) # dilated bottleneck: optional if use_dilated_bottleneck: self.dilation = nn.Sequential( nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=2), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=5), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=2), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=5), nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]), nn.LeakyReLU() if leaky else nn.ReLU(), ) else: self.dilation = nn.Identity() if out_channels is not None: self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False) def upsample(self, inputs, target): return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False) def extract_features(self, x): x0, x1, x2, x3, x4 = self.backbone(x) x4 = self.dilation(x4) x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1)) x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1)) x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1)) x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1)) return x4, x0_4 def forward(self, x, prompt_in): size = x.shape[2:] x0, x1, x2, x3, x4 = self.backbone(x) x4 = self.dilation(x4) x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1), prompt_in) x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1), prompt_in) x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1), prompt_in) x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1), prompt_in) x_out = self.convds0(x0_4) out = F.interpolate(x_out, size=size, mode='bilinear', align_corners=False) return out def freeze(self): # freeze the network for p in self.parameters(): p.requires_grad = False def unfreeze(self): # unfreeze the network to allow parameter update for p in self.parameters(): p.requires_grad = True class CLIPDrivenUNet(nn.Module): def __init__(self, encoding: str, model_name: str = None, in_channels: int = 1, out_channels: int = 1, channels: tuple = (32, 64, 128, 256, 512), strides: tuple = (2, 2, 2, 2, 2), norm: str = 'INSTANCE', leaky: bool = True) -> None: super().__init__() self.encoding = encoding self.num_classes = out_channels self.backbone = UNet(model_name=model_name, in_channels=in_channels, out_channels=None, channels=channels, strides=strides, use_dropout=False, norm=norm, leaky=leaky) self.gap = nn.AdaptiveAvgPool2d(1) self.precls_conv = nn.Sequential( nn.InstanceNorm2d(32), nn.LeakyReLU(), nn.Conv2d(32, 8, kernel_size=1) ) self.weight_nums = [8*8, 8*8, 8*1] self.bias_nums = [8, 8, 1] self.controller = nn.Conv2d(256 + channels[-1], sum(self.weight_nums + self.bias_nums), kernel_size=1, stride=1, padding=0) if encoding == 'CLIP': self.register_buffer('protein_embedding', torch.randn(self.num_classes, 512)) self.text_to_vision = nn.Linear(512, 256) elif encoding == 'RAND': self.register_buffer('protein_embedding', torch.randn(self.num_classes, 256)) def parse_dynamic_params(self, params, channels, weight_nums, bias_nums): assert params.dim() == 2 assert len(weight_nums) == len(bias_nums) assert params.size(1) == sum(weight_nums) + sum(bias_nums) num_insts = params.size(0) num_layers = len(weight_nums) params_splits = list(torch.split_with_sizes( params, weight_nums + bias_nums, dim=1 )) weight_splits = params_splits[:num_layers] bias_splits = params_splits[num_layers:] for l in range(num_layers): if l < num_layers - 1: weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1) bias_splits[l] = bias_splits[l].reshape(num_insts * channels) else: weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1) bias_splits[l] = bias_splits[l].reshape(num_insts * 1) # print(weight_splits[l].shape, bias_splits[l].shape) return weight_splits, bias_splits def heads_forward(self, features, weights, biases, num_insts): n_layers = len(weights) x = features for i, (w, b) in enumerate(zip(weights, biases)): x = F.conv2d( x, w, bias=b, stride=1, padding=0, groups=num_insts ) if i < n_layers - 1: x = F.leaky_relu(x) return x def forward(self, x_in): out_shape = x_in.shape[2:] dec4, out = self.backbone.extract_features(x_in) # dec4: (B, channels[-1], H, W), out: (B, channels[0], H, W) if self.encoding == 'RAND': task_encoding = self.protein_embedding[..., None, None] # (num_classes, 256, 1, 1) elif self.encoding == 'CLIP': task_encoding = F.leaky_relu(self.text_to_vision(self.protein_embedding))[..., None, None] # (num_classes, 256, 1, 1) else: raise NotImplementedError x_feat = self.gap(dec4) b = x_feat.shape[0] logits_array = [] for i in range(b): x_cond = torch.cat([x_feat[i].unsqueeze(0).repeat(self.num_classes, 1, 1, 1), task_encoding], 1) params = self.controller(x_cond) # (num_classes, num_params, 1, 1) params.squeeze_(-1).squeeze_(-1) # (num_classes, num_params) head_inputs = self.precls_conv(out[i].unsqueeze(0)) head_inputs = head_inputs.repeat(self.num_classes, 1, 1, 1) # (num_classes, 8, H, W) N, _, H, W = head_inputs.size() head_inputs = head_inputs.reshape(1, -1, H, W) # print(head_inputs.shape, params.shape) weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums) logits = self.heads_forward(head_inputs, weights, biases, N) logits_array.append(logits.reshape(1, -1, H, W)) out = torch.cat(logits_array, dim=0) out = F.interpolate(out, size=out_shape, mode='bilinear', align_corners=False) # print(out.shape) return out class NLayerDiscriminator(nn.Module): """Defines a PatchGAN discriminator""" def __init__(self, input_nc, norm='INSTANCE', ndf=64, n_layers=3): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super(NLayerDiscriminator, self).__init__() norm_layer = norm_dict[norm] use_bias = norm_layer == nn.InstanceNorm2d kw = 4 padw = 1 sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] nf_mult_prev = nf_mult nf_mult = min(2 ** n_layers, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map self.model = nn.Sequential(*sequence) def forward(self, input): """Standard forward.""" return self.model(input) class PatchDiscriminator(nn.Module): def __init__(self, in_channels, norm_type='INSTANCE'): super().__init__() nb_filters = [32, 64, 128, 256, 512] strides = [2, 2, 2, 2, 2] self.layer1 = ConvNorm(in_channels=in_channels, out_channels=nb_filters[0], kernel_size=5, stride=strides[0], norm='NONE', leaky=True) self.layer2 = ConvNorm(in_channels=nb_filters[0], out_channels=nb_filters[1], kernel_size=5, stride=strides[1], norm=norm_type, leaky=True) self.layer3 = ConvNorm(in_channels=nb_filters[1], out_channels=nb_filters[2], kernel_size=5, stride=strides[2], norm=norm_type, leaky=True) self.layer4 = ConvNorm(in_channels=nb_filters[2], out_channels=nb_filters[3], kernel_size=5, stride=strides[3], norm=norm_type, leaky=True) self.layer5 = ConvNorm(in_channels=nb_filters[3], out_channels=nb_filters[4], kernel_size=5, stride=strides[4], norm=norm_type, leaky=True) self.dense_pred = ConvNorm(in_channels=nb_filters[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False) def forward(self, inputs): x1 = self.layer1(inputs) x2 = self.layer2(x1) x3 = self.layer3(x2) x4 = self.layer4(x3) x5 = self.layer5(x4) output = self.dense_pred(x5) output_list = [x1, x2, x3, x4, x5, output] return output_list class PromptPatchDiscriminator(nn.Module): def __init__(self, in_channels, norm_type='INSTANCE'): super().__init__() nb_filters = [32, 64, 128, 256, 512] strides = [2, 2, 2, 2, 2] self.layer1 = ConvNorm(in_channels=in_channels, out_channels=nb_filters[0], kernel_size=5, stride=strides[0], norm='NONE', leaky=True) self.layer2 = ConvNorm(in_channels=nb_filters[0], out_channels=nb_filters[1], kernel_size=5, stride=strides[1], norm=norm_type, leaky=True) self.layer3 = ConvNorm(in_channels=nb_filters[1], out_channels=nb_filters[2], kernel_size=5, stride=strides[2], norm=norm_type, leaky=True) self.layer4 = ConvNorm(in_channels=nb_filters[2], out_channels=nb_filters[3], kernel_size=5, stride=strides[3], norm=norm_type, leaky=True) self.layer5 = ConvNorm(in_channels=nb_filters[3], out_channels=nb_filters[4], kernel_size=5, stride=strides[4], norm=norm_type, leaky=True) self.attn4 = PromptAttentionModule(in_channels=nb_filters[3], prompt_channels=512, mid_channels=nb_filters[3] // 4) self.attn5 = PromptAttentionModule(in_channels=nb_filters[4], prompt_channels=512, mid_channels=nb_filters[4] // 4) self.dense_pred = ConvNorm(in_channels=nb_filters[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False) def forward(self, inputs, prompt_in): x1 = self.layer1(inputs) x2 = self.layer2(x1) x3 = self.layer3(x2) x4 = self.layer4(x3) x4 = self.attn4(x4, prompt_in) x5 = self.layer5(x4) x5 = self.attn5(x5, prompt_in) output = self.dense_pred(x5) output_list = [x1, x2, x3, x4, x5, output] return output_list class MultiScaleDiscriminator(nn.Module): def __init__(self, in_channels, norm='INSTANCE', num_D=3): super(MultiScaleDiscriminator, self).__init__() self.num_D = num_D module = PatchDiscriminator for i in range(num_D): netD = module(in_channels, norm) setattr(self, 'layer' + str(i), netD) self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) def singleD_forward(self, model, input): return model(input) def forward(self, input): num_D = self.num_D result = [] input_downsampled = input for i in range(num_D): model = getattr(self, 'layer' + str(num_D - 1 - i)) result.append(self.singleD_forward(model, input_downsampled)) if i != (num_D - 1): input_downsampled = self.downsample(input_downsampled) return result class PromptMultiScaleDiscriminator(nn.Module): def __init__(self, in_channels, norm='INSTANCE', num_D=3): super(PromptMultiScaleDiscriminator, self).__init__() self.num_D = num_D module = PromptPatchDiscriminator for i in range(num_D): netD = module(in_channels, norm) setattr(self, 'layer' + str(i), netD) self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) def singleD_forward(self, model, input, prompt_in): return model(input, prompt_in) def forward(self, input, prompt_in): num_D = self.num_D result = [] input_downsampled = input for i in range(num_D): model = getattr(self, 'layer' + str(num_D - 1 - i)) result.append(self.singleD_forward(model, input_downsampled, prompt_in)) if i != (num_D - 1): input_downsampled = self.downsample(input_downsampled) return result class HighResEnhancer(nn.Module): """ Design a global-local network for high res generation and enhance boundary information. """ def __init__(self, model_name: str = None, in_channels: int = 1, out_channels: int = None, coarse_channels: tuple = (16, 32, 64, 128, 256), channels: tuple = (32, 64, 128, 256, 512), use_dropout: bool = False, norm: str = 'INSTANCE', leaky: bool = True, use_dilated_bottleneck: bool = False): super().__init__() # define basic blocks self.norm = norm self.leaky = leaky norm_layer = self.get_norm_layer() act_layer = self.get_act_layer() res_unit = ResBlock if channels[-1] <= 512 else ResBottleneck # check input channels assert channels[1] == coarse_channels[2], 'The number of channel-2 for coarse and number of channel-1 for fine branch should be the same.' # downsample and edge information extraction: # the downsample operation provides the input for coarse branch self.downsample = nn.AvgPool2d(3, stride=2, padding=1) # the sobel filter is operated on the downsampled image to provide edge information self.sobel = SobelEdge(input_dim=2, channels=in_channels) self.sobel_conv = nn.Sequential( nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1), norm_layer(channels[0]), act_layer() ) # coarse generator: in_channels -> coarse_channels[2] # input stride: 0 # output stride: 4 (as input is already 2x downsampled) self.coarse = nn.Sequential( nn.Conv2d(in_channels, coarse_channels[0], kernel_size=3, stride=2, padding=1), norm_layer(coarse_channels[0]), act_layer(), res_unit(coarse_channels[0], coarse_channels[1], stride=2), res_unit(coarse_channels[1], coarse_channels[2], stride=2), res_unit(coarse_channels[2], coarse_channels[3], stride=2), res_unit(coarse_channels[3], coarse_channels[4], stride=1), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), res_unit(coarse_channels[4], coarse_channels[3], stride=1), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), res_unit(coarse_channels[3], coarse_channels[2], stride=1), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), res_unit(coarse_channels[2], coarse_channels[2], stride=1), ) # fine generator: used to enhance the generation for better details # 1. simple encoder: channels[0] -> channels[1] # input stride: 0 # output stride: 4 self.fine_encoder = nn.Sequential( nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1), norm_layer(channels[0]), act_layer(), nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=2, padding=1), norm_layer(channels[1]), act_layer() ) # 2. bottleneck: channels[1] -> channels[4] # input stride: 4 # output stride: 16 self.bottleneck = nn.Sequential( res_unit(channels[1], channels[2], stride=2), res_unit(channels[2], channels[3], stride=2), res_unit(channels[3], channels[4], stride=1), res_unit(channels[4], channels[4], stride=1), ) if use_dilated_bottleneck: self.bottleneck.add_module('dilated_block_1', nn.Sequential( nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=1, dilation=1), norm_layer(channels[4]), act_layer() )) self.bottleneck.add_module('dilated_block_2', nn.Sequential( nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=2, dilation=2), norm_layer(channels[4]), act_layer() )) self.bottleneck.add_module('dilated_block_3', nn.Sequential( nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=5, dilation=5), norm_layer(channels[4]), act_layer() )) self.bottleneck.add_module('dilated_block_4', nn.Sequential( nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=1, dilation=1), norm_layer(channels[4]), act_layer() )) self.bottleneck.add_module('dilated_block_5', nn.Sequential( nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=2, dilation=2), norm_layer(channels[4]), act_layer() )) self.bottleneck.add_module('dilated_block_6', nn.Sequential( nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=5, dilation=5), norm_layer(channels[4]), act_layer() )) # 3. simple decoder: channels[4] -> channels[0] # input stride: 16 # output stride: 2 self.decoder = nn.Sequential( res_unit(channels[4], channels[3], stride=1), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), res_unit(channels[3], channels[2], stride=1), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), res_unit(channels[2], channels[1], stride=1), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), res_unit(channels[1], channels[0], stride=1), ) # output operation that combines both feature branch and edge branch # input stride: 2 # output stride: 0 self.output = nn.Sequential( nn.Conv2d(2 * channels[0], channels[0], kernel_size=3, stride=1, padding=1), norm_layer(channels[0]), act_layer(), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), nn.Conv2d(channels[0], out_channels, kernel_size=1, stride=1, bias=False) ) def get_norm_layer(self): if self.norm == 'INSTANCE': return partial(nn.InstanceNorm2d, affine=False) elif self.norm == 'BATCH': return partial(nn.BatchNorm2d, affine=True, track_running_stats=True) elif self.norm == 'GROUP': return partial(nn.GroupNorm, num_groups=8) else: raise NotImplementedError(f'Normalization layer {self.norm} is not implemented.') def get_act_layer(self): if self.leaky: return partial(nn.LeakyReLU, inplace=False) else: return partial(nn.ReLU, inplace=False) def forward(self, inputs): """ Args: inputs: (B, C, H, W), input IMC image """ # downsample and edge information extraction downsampled = self.downsample(inputs) # 0 -> 2x stride edge = self.sobel(inputs) edge = self.sobel_conv(edge) # coarse generator coarse = self.coarse(downsampled) # 2x stride -> 4x stride # fine generator fine = self.fine_encoder(inputs) # 0x stride -> 4x stride # add coarse and fine information together fine = self.bottleneck(fine + coarse) # 4x stride -> 16x stride fine = self.decoder(fine) # 16x stride -> 2x stride # output operation output = self.output(torch.cat([edge, fine], dim=1)) return output