Mbi2Spi / models /modules /networks.py
hsiangyualex's picture
Upload 64 files
f97a499 verified
raw
history blame
34.3 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from base.base_modules import *
from timm.models import create_model
from functools import partial
class Backbone(nn.Module):
    """
    Five-stage convolutional encoder that returns the feature map of every stage.

    Stage 0 is a 7x7 convolution followed by a normalization layer and an
    activation. Stages 1-4 are residual units: ``ResBlock`` when the last
    stage has at most 320 channels, ``ResBottleneck`` otherwise (both come
    from outside this file).
    """

    def __init__(self,
                 input_channels: int = 3,
                 channels: tuple = (32, 64, 128, 256, 512),
                 strides: tuple = (2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'BATCH',
                 leaky: bool = True):
        """
        Args:
            input_channels: the number of input channels
            channels: length-5 sequence, define the number of channels in each stage
            strides: tuple or list, define the stride in each stage; when fewer than
                five values are given the remaining stages use stride 1
            use_dropout: bool, whether to use dropout (forwarded to the residual units)
            norm: str, normalization type; 'GROUP' and 'BATCH' are recognised by the
                stem, any other value selects instance normalization there
            leaky: bool, whether to use leaky relu
        """
        super().__init__()
        self.nb_filter = channels
        # Accept lists (e.g. parsed from a config file) as well as tuples:
        # list + tuple concatenation would raise a TypeError.
        strides = tuple(strides)
        self.strides = strides + (5 - len(strides)) * (1,)
        res_unit = ResBlock if channels[-1] <= 320 else ResBottleneck

        if norm == 'GROUP':
            stem_norm = nn.GroupNorm(1, channels[0])
        elif norm == 'BATCH':
            stem_norm = nn.BatchNorm2d(channels[0])
        else:
            stem_norm = nn.InstanceNorm2d(channels[0])

        self.conv0_0 = nn.Sequential(
            nn.Conv2d(input_channels, channels[0], kernel_size=7, stride=self.strides[0], padding=3),
            stem_norm,
            nn.LeakyReLU() if leaky else nn.ReLU(),
        )
        self.conv1_0 = res_unit(self.nb_filter[0], self.nb_filter[1], self.strides[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv2_0 = res_unit(self.nb_filter[1], self.nb_filter[2], self.strides[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv3_0 = res_unit(self.nb_filter[2], self.nb_filter[3], self.strides[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv4_0 = res_unit(self.nb_filter[3], self.nb_filter[4], self.strides[4], use_dropout=use_dropout, norm=norm, leaky=leaky)

    def forward(self, x):
        """
        Run the five stages in sequence.

        Returns:
            tuple of the five stage outputs, ordered from the first stage to the last
        """
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(x0_0)
        x2_0 = self.conv2_0(x1_0)
        x3_0 = self.conv3_0(x2_0)
        x4_0 = self.conv4_0(x3_0)
        return x0_0, x1_0, x2_0, x3_0, x4_0
class TimmBackbone(nn.Module):
    """
    Timm backbone to extract features, utilizing pretrained weights
    """

    def __init__(self, model_name) -> None:
        """
        Args:
            model_name: timm model name passed to ``create_model``; the model is
                created with ``pretrained=True`` and ``features_only=True``
        """
        super().__init__()
        self.backbone = create_model(model_name, pretrained=True, features_only=True)
        self.determine_nb_filters()

    def determine_nb_filters(self):
        """
        Store the channel count of every backbone output in ``self.nb_filter``.

        The counts are measured by passing a random 3-channel 256x256 image
        through the backbone. The probe runs in eval mode with autograd
        disabled so that it neither overwrites the running statistics of
        normalization layers with noise nor builds a gradient graph; the
        previous train/eval mode is restored afterwards.
        """
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                out = self.backbone(torch.randn(1, 3, 256, 256))
        finally:
            self.backbone.train(was_training)
        self.nb_filter = [o.size(1) for o in out]

    def forward(self, inputs):
        """Return whatever the wrapped backbone returns for ``inputs``."""
        return self.backbone(inputs)
class UNet(nn.Module):
    """
    U-Net style encoder-decoder.

    The encoder is either the built-in ``Backbone`` or a ``TimmBackbone``; the
    decoder fuses each encoder stage with the bilinearly upsampled output of
    the stage below it through a residual unit.
    """

    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 channels: tuple = (64, 128, 256, 320, 512),
                 strides: tuple = (2, 2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        """
        Args:
            model_name: timm model name; ``None``, ``'none'`` or ``'None'`` selects
                the built-in ``Backbone``. When a timm model is used, ``in_channels``,
                ``channels`` and ``strides`` are ignored by the encoder.
            in_channels: the number of input channels
            out_channels: the number of output channels; when ``None`` no prediction
                layer is created and ``forward`` returns the decoder features
            channels: length-5 tuple, define the number of channels in each stage
            strides: tuple, define the stride in each stage
            use_dropout: bool, whether to use dropout
            norm: str, normalization type
            leaky: bool, whether to use leaky relu
            use_dilated_bottleneck: bool, whether to refine the deepest feature map
                with a stack of dilated 3x3 convolutions
        """
        super().__init__()
        if model_name not in [None, 'none', 'None']:
            # use Timm backbone and overrides any other input arguments
            self.backbone = TimmBackbone(model_name)
        else:
            self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides,
                                     use_dropout=use_dropout, norm=norm, leaky=leaky)
        nb_filter = self.backbone.nb_filter
        res_unit = ResBlock if nb_filter[-1] <= 512 else ResBottleneck

        # decoder: each unit consumes [skip features, upsampled deeper features]
        unit_kwargs = dict(use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], **unit_kwargs)
        self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], **unit_kwargs)
        self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], **unit_kwargs)
        self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], **unit_kwargs)

        # dilated bottleneck: optional
        if use_dilated_bottleneck:
            self.dilation = self._make_dilated_bottleneck(nb_filter[4], norm, leaky)
        else:
            self.dilation = nn.Identity()

        if out_channels is not None:
            self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False)
        else:
            self.convds0 = None

    @staticmethod
    def _make_dilated_bottleneck(width, norm, leaky):
        """Build six conv-norm-activation triples with dilation rates 1, 2, 5, 1, 2, 5."""
        layers = []
        for rate in (1, 2, 5, 1, 2, 5):
            # padding == dilation keeps the spatial size of a 3x3 convolution unchanged
            layers.append(nn.Conv2d(width, width, kernel_size=3, stride=1, padding=rate, dilation=rate))
            if norm == 'GROUP':
                layers.append(nn.GroupNorm(16, width))
            elif norm == 'BATCH':
                layers.append(nn.BatchNorm2d(width))
            else:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU() if leaky else nn.ReLU())
        return nn.Sequential(*layers)

    def upsample(self, inputs, target):
        """Bilinearly resize ``inputs`` to the spatial size of ``target``."""
        return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False)

    def extract_features(self, x):
        """
        Run the encoder and the decoder without the prediction layer.

        Returns:
            tuple ``(x4, x0_4)``: the deepest encoder feature map (after the
            optional dilated bottleneck) and the final decoder feature map
        """
        x0, x1, x2, x3, x4 = self.backbone(x)
        x4 = self.dilation(x4)
        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
        return x4, x0_4

    def forward(self, x):
        """
        Returns:
            the prediction resized to the spatial size of ``x`` when a prediction
            layer exists, otherwise the final decoder feature map (not resized)
        """
        _, x0_4 = self.extract_features(x)
        if self.convds0 is None:
            return x0_4
        x_out = self.convds0(x0_4)
        return F.interpolate(x_out, size=x.shape[2:], mode='bilinear', align_corners=False)

    def freeze(self):
        # freeze the network
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self):
        # unfreeze the network to allow parameter update
        for p in self.parameters():
            p.requires_grad = True
class PromptAttentionUNet(nn.Module):
    """
    U-Net style encoder-decoder whose decoder units also receive a prompt.

    The encoder is either the built-in ``Backbone`` or a ``TimmBackbone``. The
    decoder units are ``PromptResBlock`` / ``PromptResBottleneck`` (defined
    outside this file) and are called as ``unit(features, prompt_in)``.
    """

    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 channels: tuple = (64, 128, 256, 320, 512),
                 strides: tuple = (2, 2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        """
        Args:
            model_name: timm model name; ``None``, ``'none'`` or ``'None'`` selects
                the built-in ``Backbone``. When a timm model is used, ``in_channels``,
                ``channels`` and ``strides`` are ignored by the encoder.
            in_channels: the number of input channels
            out_channels: the number of output channels; when ``None`` no prediction
                layer is created and ``forward`` returns the decoder features
            channels: length-5 tuple, define the number of channels in each stage
            strides: tuple, define the stride in each stage
            use_dropout: bool, whether to use dropout
            norm: str, normalization type
            leaky: bool, whether to use leaky relu
            use_dilated_bottleneck: bool, whether to refine the deepest feature map
                with a stack of dilated 3x3 convolutions
        """
        super().__init__()
        if model_name not in [None, 'none', 'None']:
            # use Timm backbone and overrides any other input arguments
            self.backbone = TimmBackbone(model_name)
        else:
            self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides,
                                     use_dropout=use_dropout, norm=norm, leaky=leaky)
        nb_filter = self.backbone.nb_filter
        res_unit = PromptResBlock if nb_filter[-1] <= 512 else PromptResBottleneck

        # decoder: each unit consumes [skip features, upsampled deeper features]
        unit_kwargs = dict(use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], **unit_kwargs)
        self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], **unit_kwargs)
        self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], **unit_kwargs)
        self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], **unit_kwargs)

        # dilated bottleneck: optional
        if use_dilated_bottleneck:
            self.dilation = self._make_dilated_bottleneck(nb_filter[4], norm, leaky)
        else:
            self.dilation = nn.Identity()

        # Always define the attribute so forward() can test for it instead of
        # failing with AttributeError when no prediction layer was requested.
        if out_channels is not None:
            self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False)
        else:
            self.convds0 = None

    @staticmethod
    def _make_dilated_bottleneck(width, norm, leaky):
        """Build six conv-norm-activation triples with dilation rates 1, 2, 5, 1, 2, 5."""
        layers = []
        for rate in (1, 2, 5, 1, 2, 5):
            # padding must equal the dilation rate for a 3x3 kernel; a fixed
            # padding of 1 would shrink the map by 2 * (rate - 1) pixels per layer.
            layers.append(nn.Conv2d(width, width, kernel_size=3, stride=1, padding=rate, dilation=rate))
            if norm == 'GROUP':
                layers.append(nn.GroupNorm(16, width))
            elif norm == 'BATCH':
                layers.append(nn.BatchNorm2d(width))
            else:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU() if leaky else nn.ReLU())
        return nn.Sequential(*layers)

    def upsample(self, inputs, target):
        """Bilinearly resize ``inputs`` to the spatial size of ``target``."""
        return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False)

    def _decode(self, x, *prompt_args):
        """Run encoder and decoder; ``prompt_args`` is appended to every decoder-unit call."""
        x0, x1, x2, x3, x4 = self.backbone(x)
        x4 = self.dilation(x4)
        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1), *prompt_args)
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1), *prompt_args)
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1), *prompt_args)
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1), *prompt_args)
        return x4, x0_4

    def extract_features(self, x, prompt_in=None):
        """
        Run the encoder and the decoder without the prediction layer.

        Args:
            x: input tensor
            prompt_in: optional prompt forwarded to every decoder unit; when
                ``None`` the units are called with the features only
                (NOTE(review): whether the prompt units accept a single
                argument depends on their definition outside this file)

        Returns:
            tuple ``(x4, x0_4)``: the deepest encoder feature map (after the
            optional dilated bottleneck) and the final decoder feature map
        """
        if prompt_in is None:
            return self._decode(x)
        return self._decode(x, prompt_in)

    def forward(self, x, prompt_in):
        """
        Args:
            x: input tensor
            prompt_in: prompt passed as second argument to every decoder unit

        Returns:
            the prediction resized to the spatial size of ``x`` when a prediction
            layer exists, otherwise the final decoder feature map (not resized)
        """
        _, x0_4 = self._decode(x, prompt_in)
        if self.convds0 is None:
            return x0_4
        x_out = self.convds0(x0_4)
        return F.interpolate(x_out, size=x.shape[2:], mode='bilinear', align_corners=False)

    def freeze(self):
        # freeze the network
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self):
        # unfreeze the network to allow parameter update
        for p in self.parameters():
            p.requires_grad = True
class CLIPDrivenUNet(nn.Module):
    """
    UNet whose per-class prediction head is generated dynamically.

    A 1x1 "controller" convolution maps the concatenation of a globally pooled
    bottleneck feature and a per-class embedding to the weights and biases of
    a three-layer 1x1 convolutional head (8 -> 8 -> 8 -> 1 channels), which is
    then applied to the decoder features separately for every class.
    """

    def __init__(self, encoding: str, model_name: str = None, in_channels: int = 1, out_channels: int = 1,
                 channels: tuple = (32, 64, 128, 256, 512), strides: tuple = (2, 2, 2, 2, 2),
                 norm: str = 'INSTANCE', leaky: bool = True) -> None:
        """
        Args:
            encoding: 'CLIP' (512-d embedding buffer projected to 256-d by a linear
                layer) or 'RAND' (256-d embedding buffer used directly); any other
                value makes ``forward`` raise ``NotImplementedError``
            model_name: timm model name forwarded to ``UNet``
            in_channels: the number of input channels
            out_channels: the number of classes, i.e. output channels
            channels: length-5 tuple, define the number of channels in each stage
            strides: tuple, define the stride in each stage
            norm: str, normalization type
            leaky: bool, whether to use leaky relu
        """
        super().__init__()
        self.encoding = encoding
        self.num_classes = out_channels
        self.backbone = UNet(model_name=model_name, in_channels=in_channels, out_channels=None, channels=channels,
                             strides=strides, use_dropout=False, norm=norm, leaky=leaky)
        # Use the encoder's real channel counts instead of hard-coded values so
        # the head also matches timm backbones and channels[0] != 32.
        nb_filter = self.backbone.backbone.nb_filter
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.precls_conv = nn.Sequential(
            nn.InstanceNorm2d(nb_filter[0]),
            nn.LeakyReLU(),
            nn.Conv2d(nb_filter[0], 8, kernel_size=1)
        )
        # parameter counts of the generated head: 8->8, 8->8, 8->1 (1x1 convolutions)
        self.weight_nums = [8*8, 8*8, 8*1]
        self.bias_nums = [8, 8, 1]
        # input: pooled bottleneck feature (nb_filter[-1]) + 256-d class embedding
        self.controller = nn.Conv2d(256 + nb_filter[-1], sum(self.weight_nums + self.bias_nums), kernel_size=1, stride=1, padding=0)
        if encoding == 'CLIP':
            # random placeholder of the expected shape (presumably overwritten with
            # real text embeddings elsewhere, e.g. from a checkpoint — TODO confirm)
            self.register_buffer('protein_embedding', torch.randn(self.num_classes, 512))
            self.text_to_vision = nn.Linear(512, 256)
        elif encoding == 'RAND':
            self.register_buffer('protein_embedding', torch.randn(self.num_classes, 256))

    def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
        """
        Split a flat parameter matrix into per-layer weights and biases.

        Args:
            params: tensor of shape (num_insts, sum(weight_nums) + sum(bias_nums))
            channels: output channels of every layer except the last
            weight_nums: number of weight values per layer
            bias_nums: number of bias values per layer

        Returns:
            ``(weights, biases)``: lists with one tensor per layer. Weights are
            reshaped to (num_insts * out_channels, -1, 1, 1) and biases to
            (num_insts * out_channels,), ready for a grouped 1x1 convolution;
            the last layer has a single output channel per instance.
        """
        assert params.dim() == 2
        assert len(weight_nums) == len(bias_nums)
        assert params.size(1) == sum(weight_nums) + sum(bias_nums)
        num_insts = params.size(0)
        num_layers = len(weight_nums)
        params_splits = list(torch.split_with_sizes(
            params, weight_nums + bias_nums, dim=1
        ))
        weight_splits = params_splits[:num_layers]
        bias_splits = params_splits[num_layers:]
        for layer in range(num_layers):
            out_per_inst = channels if layer < num_layers - 1 else 1
            weight_splits[layer] = weight_splits[layer].reshape(num_insts * out_per_inst, -1, 1, 1)
            bias_splits[layer] = bias_splits[layer].reshape(num_insts * out_per_inst)
        return weight_splits, bias_splits

    def heads_forward(self, features, weights, biases, num_insts):
        """
        Apply the generated head as a chain of grouped 1x1 convolutions.

        Each of the ``num_insts`` groups is one instance's private head; a
        leaky ReLU follows every layer except the last.
        """
        n_layers = len(weights)
        x = features
        for i, (w, b) in enumerate(zip(weights, biases)):
            x = F.conv2d(
                x, w, bias=b,
                stride=1, padding=0,
                groups=num_insts
            )
            if i < n_layers - 1:
                x = F.leaky_relu(x)
        return x

    def forward(self, x_in):
        """
        Args:
            x_in: input tensor; the output is resized to its spatial size

        Returns:
            tensor with ``num_classes`` channels and the spatial size of ``x_in``

        Raises:
            NotImplementedError: if ``encoding`` is neither 'RAND' nor 'CLIP'
        """
        out_shape = x_in.shape[2:]
        # dec4: deepest encoder features, decoder_feat: final decoder features
        dec4, decoder_feat = self.backbone.extract_features(x_in)
        if self.encoding == 'RAND':
            task_encoding = self.protein_embedding[..., None, None]  # (num_classes, 256, 1, 1)
        elif self.encoding == 'CLIP':
            task_encoding = F.leaky_relu(self.text_to_vision(self.protein_embedding))[..., None, None]  # (num_classes, 256, 1, 1)
        else:
            raise NotImplementedError
        x_feat = self.gap(dec4)
        batch_size = x_feat.shape[0]
        logits_array = []
        for i in range(batch_size):
            # pair this sample's pooled feature with every class embedding
            x_cond = torch.cat([x_feat[i].unsqueeze(0).repeat(self.num_classes, 1, 1, 1), task_encoding], 1)
            params = self.controller(x_cond)  # (num_classes, num_params, 1, 1)
            params = params.flatten(1)  # (num_classes, num_params)
            head_inputs = self.precls_conv(decoder_feat[i].unsqueeze(0))
            head_inputs = head_inputs.repeat(self.num_classes, 1, 1, 1)  # (num_classes, 8, H, W)
            num_insts, _, height, width = head_inputs.size()
            # fold the class axis into channels so one grouped conv runs all heads
            head_inputs = head_inputs.reshape(1, -1, height, width)
            weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums)
            logits = self.heads_forward(head_inputs, weights, biases, num_insts)
            logits_array.append(logits.reshape(1, -1, height, width))
        out = torch.cat(logits_array, dim=0)
        out = F.interpolate(out, size=out_shape, mode='bilinear', align_corners=False)
        return out
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, norm='INSTANCE', ndf=64, n_layers=3):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int) -- the number of channels in input images
            norm (str) -- key used to look up the normalization layer in norm_dict
            ndf (int) -- the number of filters in the first conv layer; later
                           layers use ndf times a multiplier capped at 8
            n_layers (int) -- the number of conv layers in the discriminator
        """
        super().__init__()
        norm_layer = norm_dict[norm]
        # the convolutions in front of a norm layer carry a bias only when the
        # looked-up layer compares equal to nn.InstanceNorm2d
        use_bias = norm_layer == nn.InstanceNorm2d
        kernel = 4
        pad = 1

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # width multipliers: stride-2 stages first, then one stride-1 stage
        multipliers = [min(2 ** n, 8) for n in range(1, n_layers)]
        multipliers.append(min(2 ** n_layers, 8))
        stage_strides = [2] * (n_layers - 1) + [1]

        previous = 1
        for current, stride in zip(multipliers, stage_strides):
            layers.extend([
                nn.Conv2d(ndf * previous, ndf * current, kernel_size=kernel, stride=stride, padding=pad, bias=use_bias),
                norm_layer(ndf * current),
                nn.LeakyReLU(0.2, True),
            ])
            previous = current

        # output 1 channel prediction map
        layers.append(nn.Conv2d(ndf * previous, 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
class PatchDiscriminator(nn.Module):
    """
    Five stride-2 ``ConvNorm`` stages followed by a one-channel prediction layer.

    ``forward`` returns every intermediate feature map together with the
    prediction.
    """

    def __init__(self, in_channels, norm_type='INSTANCE'):
        """
        Args:
            in_channels: the number of channels of the input
            norm_type: normalization passed to stages 2-5; stage 1 and the
                prediction layer are built with norm='NONE'
        """
        super().__init__()
        widths = (32, 64, 128, 256, 512)
        previous = in_channels
        for index, width in enumerate(widths, start=1):
            stage_norm = 'NONE' if index == 1 else norm_type
            stage = ConvNorm(in_channels=previous, out_channels=width, kernel_size=5, stride=2, norm=stage_norm, leaky=True)
            setattr(self, f'layer{index}', stage)
            previous = width
        self.dense_pred = ConvNorm(in_channels=previous, out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False)

    def forward(self, inputs):
        """
        Returns:
            list ``[x1, x2, x3, x4, x5, output]``: the five stage outputs and
            the prediction computed from the last one
        """
        collected = []
        current = inputs
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            current = stage(current)
            collected.append(current)
        collected.append(self.dense_pred(current))
        return collected
class PromptPatchDiscriminator(nn.Module):
    """
    Patch discriminator whose two deepest stages are followed by a
    ``PromptAttentionModule`` that also receives the prompt.

    ``forward`` returns every intermediate feature map together with the
    prediction.
    """

    def __init__(self, in_channels, norm_type='INSTANCE'):
        """
        Args:
            in_channels: the number of channels of the input
            norm_type: normalization passed to stages 2-5; stage 1 and the
                prediction layer are built with norm='NONE'
        """
        super().__init__()
        widths = (32, 64, 128, 256, 512)
        previous = in_channels
        for index, width in enumerate(widths, start=1):
            stage_norm = 'NONE' if index == 1 else norm_type
            stage = ConvNorm(in_channels=previous, out_channels=width, kernel_size=5, stride=2, norm=stage_norm, leaky=True)
            setattr(self, f'layer{index}', stage)
            previous = width
        # prompt attention after stage 4 and stage 5; the prompt width is fixed to 512
        self.attn4 = PromptAttentionModule(in_channels=widths[3], prompt_channels=512, mid_channels=widths[3] // 4)
        self.attn5 = PromptAttentionModule(in_channels=widths[4], prompt_channels=512, mid_channels=widths[4] // 4)
        self.dense_pred = ConvNorm(in_channels=widths[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False)

    def forward(self, inputs, prompt_in):
        """
        Args:
            inputs: input tensor
            prompt_in: prompt passed to both attention modules

        Returns:
            list ``[x1, x2, x3, x4, x5, output]`` where ``x4`` and ``x5`` are
            the attention outputs of stages 4 and 5
        """
        collected = []
        current = inputs
        for stage in (self.layer1, self.layer2, self.layer3):
            current = stage(current)
            collected.append(current)
        for stage, attention in ((self.layer4, self.attn4), (self.layer5, self.attn5)):
            current = attention(stage(current), prompt_in)
            collected.append(current)
        collected.append(self.dense_pred(current))
        return collected
class MultiScaleDiscriminator(nn.Module):
    """
    Bank of ``num_D`` ``PatchDiscriminator`` networks applied to an image
    pyramid built by repeated average pooling.
    """

    def __init__(self, in_channels, norm='INSTANCE', num_D=3):
        """
        Args:
            in_channels: the number of channels of the input
            norm: normalization type handed to every discriminator
            num_D: the number of discriminators, i.e. pyramid levels
        """
        super().__init__()
        self.num_D = num_D
        for index in range(num_D):
            setattr(self, f'layer{index}', PatchDiscriminator(in_channels, norm))
        # halves the resolution between consecutive pyramid levels
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Apply one discriminator to one pyramid level."""
        return model(input)

    def forward(self, input):
        """
        Returns:
            list with one entry per scale, finest scale first. The
            highest-numbered discriminator sees the full-resolution input,
            ``layer0`` the most downsampled one.
        """
        outputs = []
        scaled = input
        last = self.num_D - 1
        for step in range(self.num_D):
            net = getattr(self, f'layer{last - step}')
            outputs.append(self.singleD_forward(net, scaled))
            if step < last:
                scaled = self.downsample(scaled)
        return outputs
class PromptMultiScaleDiscriminator(nn.Module):
    """
    Bank of ``num_D`` ``PromptPatchDiscriminator`` networks applied to an
    image pyramid built by repeated average pooling; every network receives
    the same prompt.
    """

    def __init__(self, in_channels, norm='INSTANCE', num_D=3):
        """
        Args:
            in_channels: the number of channels of the input
            norm: normalization type handed to every discriminator
            num_D: the number of discriminators, i.e. pyramid levels
        """
        super().__init__()
        self.num_D = num_D
        for index in range(num_D):
            setattr(self, f'layer{index}', PromptPatchDiscriminator(in_channels, norm))
        # halves the resolution between consecutive pyramid levels
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input, prompt_in):
        """Apply one discriminator to one pyramid level with the prompt."""
        return model(input, prompt_in)

    def forward(self, input, prompt_in):
        """
        Returns:
            list with one entry per scale, finest scale first. The
            highest-numbered discriminator sees the full-resolution input,
            ``layer0`` the most downsampled one.
        """
        outputs = []
        scaled = input
        last = self.num_D - 1
        for step in range(self.num_D):
            net = getattr(self, f'layer{last - step}')
            outputs.append(self.singleD_forward(net, scaled, prompt_in))
            if step < last:
                scaled = self.downsample(scaled)
        return outputs
class HighResEnhancer(nn.Module):
    """
    Design a global-local network for high res generation and enhance boundary information.

    Three branches are combined:
      * a coarse branch that runs on a 2x downsampled copy of the input,
      * a fine branch that runs on the full-resolution input and receives the
        coarse features by addition before its bottleneck,
      * an edge branch that applies ``SobelEdge`` (defined outside this file)
        to the input and is concatenated with the fine branch before the
        output layers.
    """

    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 coarse_channels: tuple = (16, 32, 64, 128, 256),
                 channels: tuple = (32, 64, 128, 256, 512),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        """
        Args:
            model_name: not used by this class
            in_channels: the number of input channels
            out_channels: the number of output channels; must be given
            coarse_channels: length-5 tuple, channels of the coarse branch
            channels: length-5 tuple, channels of the fine branch;
                ``channels[1]`` must equal ``coarse_channels[2]``
            use_dropout: not used by this class
            norm: str, normalization type: 'INSTANCE', 'BATCH' or 'GROUP'
            leaky: bool, whether to use leaky relu
            use_dilated_bottleneck: bool, whether to append six dilated
                convolution blocks to the bottleneck

        Raises:
            TypeError: if ``out_channels`` is ``None``
            NotImplementedError: if ``norm`` is not a supported type
        """
        super().__init__()
        if out_channels is None:
            # fail early with a readable message instead of deep inside nn.Conv2d
            raise TypeError('out_channels must be an int, got None.')
        # define basic blocks
        self.norm = norm
        self.leaky = leaky
        norm_layer = self.get_norm_layer()
        act_layer = self.get_act_layer()
        res_unit = ResBlock if channels[-1] <= 512 else ResBottleneck
        # NOTE(review): the residual units below are built without norm / leaky /
        # use_dropout, so they use whatever defaults their definition has.
        # check input channels
        assert channels[1] == coarse_channels[2], 'The number of channel-2 for coarse and number of channel-1 for fine branch should be the same.'
        # downsample and edge information extraction:
        # the downsample operation provides the input for coarse branch
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1)
        # the sobel filter is applied to the full-resolution input (see forward);
        # the stride-2 convolution then brings the edge features to 2x stride
        self.sobel = SobelEdge(input_dim=2, channels=in_channels)
        self.sobel_conv = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[0]),
            act_layer()
        )
        # coarse generator: in_channels -> coarse_channels[2]
        # input stride: 0
        # output stride: 4 (as input is already 2x downsampled)
        self.coarse = nn.Sequential(
            nn.Conv2d(in_channels, coarse_channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(coarse_channels[0]),
            act_layer(),
            res_unit(coarse_channels[0], coarse_channels[1], stride=2),
            res_unit(coarse_channels[1], coarse_channels[2], stride=2),
            res_unit(coarse_channels[2], coarse_channels[3], stride=2),
            res_unit(coarse_channels[3], coarse_channels[4], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[4], coarse_channels[3], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[3], coarse_channels[2], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[2], coarse_channels[2], stride=1),
        )
        # fine generator: used to enhance the generation for better details
        # 1. simple encoder: channels[0] -> channels[1]
        # input stride: 0
        # output stride: 4
        self.fine_encoder = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[0]),
            act_layer(),
            nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[1]),
            act_layer()
        )
        # 2. bottleneck: channels[1] -> channels[4]
        # input stride: 4
        # output stride: 16
        self.bottleneck = nn.Sequential(
            res_unit(channels[1], channels[2], stride=2),
            res_unit(channels[2], channels[3], stride=2),
            res_unit(channels[3], channels[4], stride=1),
            res_unit(channels[4], channels[4], stride=1),
        )
        if use_dilated_bottleneck:
            # padding == dilation keeps the spatial size of a 3x3 convolution
            for index, rate in enumerate((1, 2, 5, 1, 2, 5), start=1):
                self.bottleneck.add_module(
                    f'dilated_block_{index}',
                    nn.Sequential(
                        nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=rate, dilation=rate),
                        norm_layer(channels[4]),
                        act_layer()
                    ))
        # 3. simple decoder: channels[4] -> channels[0]
        # input stride: 16
        # output stride: 2
        self.decoder = nn.Sequential(
            res_unit(channels[4], channels[3], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[3], channels[2], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[2], channels[1], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[1], channels[0], stride=1),
        )
        # output operation that combines both feature branch and edge branch
        # input stride: 2
        # output stride: 0
        self.output = nn.Sequential(
            nn.Conv2d(2 * channels[0], channels[0], kernel_size=3, stride=1, padding=1),
            norm_layer(channels[0]),
            act_layer(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(channels[0], out_channels, kernel_size=1, stride=1, bias=False)
        )

    def get_norm_layer(self):
        """
        Return a factory ``f(num_channels)`` for the configured normalization.

        Raises:
            NotImplementedError: if ``self.norm`` is not 'INSTANCE', 'BATCH' or 'GROUP'
        """
        if self.norm == 'INSTANCE':
            return partial(nn.InstanceNorm2d, affine=False)
        elif self.norm == 'BATCH':
            return partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        elif self.norm == 'GROUP':
            # GroupNorm's signature is (num_groups, num_channels): bind the group
            # count positionally so the factory's argument lands on num_channels.
            return partial(nn.GroupNorm, 8)
        else:
            raise NotImplementedError(f'Normalization layer {self.norm} is not implemented.')

    def get_act_layer(self):
        """Return a zero-argument factory for the configured activation."""
        if self.leaky:
            return partial(nn.LeakyReLU, inplace=False)
        else:
            return partial(nn.ReLU, inplace=False)

    def forward(self, inputs):
        """
        Args:
            inputs: (B, C, H, W), input IMC image
        """
        # downsample and edge information extraction
        downsampled = self.downsample(inputs)  # 0 -> 2x stride
        edge = self.sobel(inputs)
        edge = self.sobel_conv(edge)
        # coarse generator
        coarse = self.coarse(downsampled)  # 2x stride -> 4x stride
        # fine generator
        fine = self.fine_encoder(inputs)  # 0x stride -> 4x stride
        # add coarse and fine information together
        fine = self.bottleneck(fine + coarse)  # 4x stride -> 16x stride
        fine = self.decoder(fine)  # 16x stride -> 2x stride
        # output operation
        output = self.output(torch.cat([edge, fine], dim=1))
        return output