import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models import create_model

from base.base_modules import *


class Backbone(nn.Module):
    """
    Model backbone to extract features.
    """
    def __init__(self,
                 input_channels: int = 3,
                 channels: tuple = (32, 64, 128, 256, 512),
                 strides: tuple = (2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'BATCH',
                 leaky: bool = True):
        """
        Args:
            input_channels: the number of input channels
            channels: length-5 tuple, the number of channels in each stage
            strides: tuple, the stride of each stage; padded with 1s up to length 5
            use_dropout: bool, whether to use dropout
            norm: str, normalization type ('BATCH', 'GROUP', or 'INSTANCE')
            leaky: bool, whether to use LeakyReLU instead of ReLU
        """
        super().__init__()
        self.nb_filter = channels
        self.strides = strides + (5 - len(strides)) * (1,)
        res_unit = ResBlock if channels[-1] <= 320 else ResBottleneck
        self.conv0_0 = nn.Sequential(
            nn.Conv2d(input_channels, channels[0], kernel_size=7, stride=self.strides[0], padding=3),
            nn.GroupNorm(1, channels[0]) if norm == 'GROUP' else nn.BatchNorm2d(channels[0]) if norm == 'BATCH' else nn.InstanceNorm2d(channels[0]),
            nn.LeakyReLU() if leaky else nn.ReLU(),
        )
        self.conv1_0 = res_unit(self.nb_filter[0], self.nb_filter[1], self.strides[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv2_0 = res_unit(self.nb_filter[1], self.nb_filter[2], self.strides[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv3_0 = res_unit(self.nb_filter[2], self.nb_filter[3], self.strides[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv4_0 = res_unit(self.nb_filter[3], self.nb_filter[4], self.strides[4], use_dropout=use_dropout, norm=norm, leaky=leaky)

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(x0_0)
        x2_0 = self.conv2_0(x1_0)
        x3_0 = self.conv3_0(x2_0)
        x4_0 = self.conv4_0(x3_0)
        return x0_0, x1_0, x2_0, x3_0, x4_0
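
# Example (sketch, not part of the original file): with the default strides
# (2, 2, 2, 2), padded to (2, 2, 2, 2, 1), the five feature maps sit at strides
# 2/4/8/16/16 of the input. Assumes ResBlock/ResBottleneck from
# base.base_modules accept (in_channels, out_channels, stride, ...) and behave
# like standard residual blocks.
#
#   backbone = Backbone(input_channels=3)
#   feats = backbone(torch.randn(1, 3, 256, 256))
#   print([tuple(f.shape) for f in feats])
#   # [(1, 32, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 16, 16)]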

class TimmBackbone(nn.Module):
    """
    timm backbone to extract features, utilizing pretrained weights.
    """
    def __init__(self, model_name) -> None:
        super().__init__()
        self.backbone = create_model(model_name, pretrained=True, features_only=True)
        self.determine_nb_filters()

    def determine_nb_filters(self):
        # probe the backbone with a dummy input to record the channel count of each feature stage
        dummy = torch.randn(1, 3, 256, 256)
        out = self.backbone(dummy)
        nb_filters = []
        for o in out:
            nb_filters.append(o.size(1))
        self.nb_filter = nb_filters

    def forward(self, inputs):
        return self.backbone(inputs)
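
# Example (sketch): any timm model that supports features_only=True can be used;
# 'resnet34' below is an illustrative choice, not one mandated by this file.
#
#   tb = TimmBackbone('resnet34')
#   print(tb.nb_filter)  # per-stage channel counts, e.g. [64, 64, 128, 256, 512]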

class UNet(nn.Module):
    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 channels: tuple = (64, 128, 256, 320, 512),
                 strides: tuple = (2, 2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        """
        Args:
            model_name: timm model name; if given, a pretrained timm backbone is used
            in_channels: the number of input channels
            out_channels: the number of output channels; if None, no prediction head is built
            channels: length-5 tuple, the number of channels in each stage
            strides: tuple, the stride of each stage
            use_dropout: bool, whether to use dropout
            norm: str, normalization type
            leaky: bool, whether to use LeakyReLU instead of ReLU
            use_dilated_bottleneck: bool, whether to add dilated convolutions after the encoder
        """
        super().__init__()
        if model_name not in [None, 'none', 'None']:
            # a timm backbone overrides the hand-crafted backbone and its arguments
            self.backbone = TimmBackbone(model_name)
        else:
            self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides,
                                     use_dropout=use_dropout, norm=norm, leaky=leaky)
        nb_filter = self.backbone.nb_filter
        res_unit = ResBlock if nb_filter[-1] <= 512 else ResBottleneck
        # decoder
        self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], use_dropout=use_dropout, norm=norm, leaky=leaky)
        # optional dilated bottleneck: two stacks of dilated convs (dilation 1, 2, 5),
        # with padding equal to the dilation so the spatial size is preserved
        if use_dilated_bottleneck:
            def _norm(c):
                return nn.GroupNorm(16, c) if norm == 'GROUP' else nn.BatchNorm2d(c) if norm == 'BATCH' else nn.InstanceNorm2d(c)
            layers = []
            for d in (1, 2, 5, 1, 2, 5):
                layers += [
                    nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=d, dilation=d),
                    _norm(nb_filter[4]),
                    nn.LeakyReLU() if leaky else nn.ReLU(),
                ]
            self.dilation = nn.Sequential(*layers)
        else:
            self.dilation = nn.Identity()
        if out_channels is not None:
            self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False)
        else:
            self.convds0 = None

    def upsample(self, inputs, target):
        return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False)

    def extract_features(self, x):
        x0, x1, x2, x3, x4 = self.backbone(x)
        x4 = self.dilation(x4)
        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
        return x4, x0_4

    def forward(self, x):
        size = x.shape[2:]
        x0, x1, x2, x3, x4 = self.backbone(x)
        x4 = self.dilation(x4)
        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
        if self.convds0 is not None:
            x_out = self.convds0(x0_4)
            out = F.interpolate(x_out, size=size, mode='bilinear', align_corners=False)
        else:
            out = x0_4
        return out

    def freeze(self):
        # freeze the network
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self):
        # unfreeze the network to allow parameter updates
        for p in self.parameters():
            p.requires_grad = True
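
# Example (sketch): dense prediction at input resolution; out_channels=2 is an
# illustrative value. With out_channels=None, forward returns the decoder
# feature map instead of logits.
#
#   net = UNet(in_channels=1, out_channels=2)
#   logits = net(torch.randn(1, 1, 256, 256))  # (1, 2, 256, 256)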

class PromptAttentionUNet(nn.Module):
    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 channels: tuple = (64, 128, 256, 320, 512),
                 strides: tuple = (2, 2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        """
        Args:
            model_name: timm model name; if given, a pretrained timm backbone is used
            in_channels: the number of input channels
            out_channels: the number of output channels
            channels: length-5 tuple, the number of channels in each stage
            strides: tuple, the stride of each stage
            use_dropout: bool, whether to use dropout
            norm: str, normalization type
            leaky: bool, whether to use LeakyReLU instead of ReLU
            use_dilated_bottleneck: bool, whether to add dilated convolutions after the encoder
        """
        super().__init__()
        if model_name not in [None, 'none', 'None']:
            # a timm backbone overrides the hand-crafted backbone and its arguments
            self.backbone = TimmBackbone(model_name)
        else:
            self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides,
                                     use_dropout=use_dropout, norm=norm, leaky=leaky)
        nb_filter = self.backbone.nb_filter
        res_unit = PromptResBlock if nb_filter[-1] <= 512 else PromptResBottleneck
        # decoder
        self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], use_dropout=use_dropout, norm=norm, leaky=leaky)
        # optional dilated bottleneck: two stacks of dilated convs (dilation 1, 2, 5);
        # padding is matched to the dilation so the spatial size is preserved
        if use_dilated_bottleneck:
            def _norm(c):
                return nn.GroupNorm(16, c) if norm == 'GROUP' else nn.BatchNorm2d(c) if norm == 'BATCH' else nn.InstanceNorm2d(c)
            layers = []
            for d in (1, 2, 5, 1, 2, 5):
                layers += [
                    nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=d, dilation=d),
                    _norm(nb_filter[4]),
                    nn.LeakyReLU() if leaky else nn.ReLU(),
                ]
            self.dilation = nn.Sequential(*layers)
        else:
            self.dilation = nn.Identity()
        if out_channels is not None:
            self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False)

    def upsample(self, inputs, target):
        return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False)

    def extract_features(self, x):
        x0, x1, x2, x3, x4 = self.backbone(x)
        x4 = self.dilation(x4)
        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
        return x4, x0_4

    def forward(self, x, prompt_in):
        size = x.shape[2:]
        x0, x1, x2, x3, x4 = self.backbone(x)
        x4 = self.dilation(x4)
        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1), prompt_in)
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1), prompt_in)
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1), prompt_in)
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1), prompt_in)
        x_out = self.convds0(x0_4)
        out = F.interpolate(x_out, size=size, mode='bilinear', align_corners=False)
        return out

    def freeze(self):
        # freeze the network
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self):
        # unfreeze the network to allow parameter updates
        for p in self.parameters():
            p.requires_grad = True
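
# Example (sketch): each decoder block additionally consumes a prompt embedding.
# The 512-wide prompt below mirrors the prompt_channels used by
# PromptAttentionModule later in this file; the exact shape expected by
# PromptResBlock is defined in base.base_modules, so treat this as an assumption.
#
#   net = PromptAttentionUNet(in_channels=1, out_channels=2)
#   logits = net(torch.randn(1, 1, 256, 256), prompt_in=torch.randn(1, 512))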

class CLIPDrivenUNet(nn.Module):
    def __init__(self, encoding: str, model_name: str = None, in_channels: int = 1, out_channels: int = 1, channels: tuple = (32, 64, 128, 256, 512),
                 strides: tuple = (2, 2, 2, 2, 2), norm: str = 'INSTANCE', leaky: bool = True) -> None:
        super().__init__()
        self.encoding = encoding
        self.num_classes = out_channels
        self.backbone = UNet(model_name=model_name, in_channels=in_channels, out_channels=None, channels=channels,
                             strides=strides, use_dropout=False, norm=norm, leaky=leaky)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.precls_conv = nn.Sequential(
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 8, kernel_size=1)
        )
        # dynamic head: three 1x1 conv layers (8 -> 8 -> 8 -> 1) per class
        self.weight_nums = [8 * 8, 8 * 8, 8 * 1]
        self.bias_nums = [8, 8, 1]
        self.controller = nn.Conv2d(256 + channels[-1], sum(self.weight_nums + self.bias_nums), kernel_size=1, stride=1, padding=0)
        if encoding == 'CLIP':
            self.register_buffer('protein_embedding', torch.randn(self.num_classes, 512))
            self.text_to_vision = nn.Linear(512, 256)
        elif encoding == 'RAND':
            self.register_buffer('protein_embedding', torch.randn(self.num_classes, 256))

    def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
        assert params.dim() == 2
        assert len(weight_nums) == len(bias_nums)
        assert params.size(1) == sum(weight_nums) + sum(bias_nums)
        num_insts = params.size(0)
        num_layers = len(weight_nums)
        params_splits = list(torch.split_with_sizes(
            params, weight_nums + bias_nums, dim=1
        ))
        weight_splits = params_splits[:num_layers]
        bias_splits = params_splits[num_layers:]
        for l in range(num_layers):
            if l < num_layers - 1:
                weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1)
                bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
            else:
                # the last layer predicts a single output channel per instance
                weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1)
                bias_splits[l] = bias_splits[l].reshape(num_insts * 1)
        return weight_splits, bias_splits
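
    # Worked example of the split above for this model's configuration
    # (channels=8, weight_nums=[64, 64, 8], bias_nums=[8, 8, 1]): each row of
    # `params` carries 64 + 64 + 8 + 8 + 8 + 1 = 153 values, unpacked into three
    # 1x1 convolutions (8 -> 8 -> 8 -> 1); heads_forward then applies all
    # per-class heads at once via a grouped convolution.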

    def heads_forward(self, features, weights, biases, num_insts):
        n_layers = len(weights)
        x = features
        for i, (w, b) in enumerate(zip(weights, biases)):
            x = F.conv2d(
                x, w, bias=b,
                stride=1, padding=0,
                groups=num_insts
            )
            if i < n_layers - 1:
                x = F.leaky_relu(x)
        return x

    def forward(self, x_in):
        out_shape = x_in.shape[2:]
        # dec4: bottleneck features at the coarsest stride; out: decoder features at the finest decoder stride
        dec4, out = self.backbone.extract_features(x_in)
        if self.encoding == 'RAND':
            task_encoding = self.protein_embedding[..., None, None]  # (num_classes, 256, 1, 1)
        elif self.encoding == 'CLIP':
            task_encoding = F.leaky_relu(self.text_to_vision(self.protein_embedding))[..., None, None]  # (num_classes, 256, 1, 1)
        else:
            raise NotImplementedError
        x_feat = self.gap(dec4)
        b = x_feat.shape[0]
        logits_array = []
        for i in range(b):
            x_cond = torch.cat([x_feat[i].unsqueeze(0).repeat(self.num_classes, 1, 1, 1), task_encoding], 1)
            params = self.controller(x_cond)  # (num_classes, num_params, 1, 1)
            params.squeeze_(-1).squeeze_(-1)  # (num_classes, num_params)
            head_inputs = self.precls_conv(out[i].unsqueeze(0))
            head_inputs = head_inputs.repeat(self.num_classes, 1, 1, 1)  # (num_classes, 8, H, W)
            N, _, H, W = head_inputs.size()
            head_inputs = head_inputs.reshape(1, -1, H, W)
            weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums)
            logits = self.heads_forward(head_inputs, weights, biases, N)
            logits_array.append(logits.reshape(1, -1, H, W))
        out = torch.cat(logits_array, dim=0)
        out = F.interpolate(out, size=out_shape, mode='bilinear', align_corners=False)
        return out
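
# Example (sketch): one score map per class from the per-class dynamic heads.
# encoding='RAND' uses the randomly initialized embedding buffer; with 'CLIP',
# `protein_embedding` is presumably overwritten with real 512-d CLIP text
# embeddings (e.g. via load_state_dict) rather than kept random.
#
#   net = CLIPDrivenUNet(encoding='RAND', in_channels=1, out_channels=4)
#   logits = net(torch.randn(2, 1, 256, 256))  # (2, 4, 256, 256)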

class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator."""
    def __init__(self, input_nc, norm='INSTANCE', ndf=64, n_layers=3):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            norm (str)      -- the type of normalization layer
            ndf (int)       -- the number of filters in the first conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
        """
        super().__init__()
        norm_layer = norm_dict[norm]
        use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output a 1-channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
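
# Example (sketch): with the defaults (ndf=64, n_layers=3) this is the classic
# 70x70 PatchGAN; the output is a map of per-patch scores, not a single scalar.
#
#   d = NLayerDiscriminator(input_nc=2)
#   scores = d(torch.randn(1, 2, 256, 256))  # (1, 1, 30, 30)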

class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels, norm_type='INSTANCE'):
        super().__init__()
        nb_filters = [32, 64, 128, 256, 512]
        strides = [2, 2, 2, 2, 2]
        self.layer1 = ConvNorm(in_channels=in_channels, out_channels=nb_filters[0], kernel_size=5, stride=strides[0], norm='NONE', leaky=True)
        self.layer2 = ConvNorm(in_channels=nb_filters[0], out_channels=nb_filters[1], kernel_size=5, stride=strides[1], norm=norm_type, leaky=True)
        self.layer3 = ConvNorm(in_channels=nb_filters[1], out_channels=nb_filters[2], kernel_size=5, stride=strides[2], norm=norm_type, leaky=True)
        self.layer4 = ConvNorm(in_channels=nb_filters[2], out_channels=nb_filters[3], kernel_size=5, stride=strides[3], norm=norm_type, leaky=True)
        self.layer5 = ConvNorm(in_channels=nb_filters[3], out_channels=nb_filters[4], kernel_size=5, stride=strides[4], norm=norm_type, leaky=True)
        self.dense_pred = ConvNorm(in_channels=nb_filters[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False)

    def forward(self, inputs):
        x1 = self.layer1(inputs)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x5 = self.layer5(x4)
        output = self.dense_pred(x5)
        # return the intermediate features as well, e.g. for feature-matching losses
        output_list = [x1, x2, x3, x4, x5, output]
        return output_list

class PromptPatchDiscriminator(nn.Module):
    def __init__(self, in_channels, norm_type='INSTANCE'):
        super().__init__()
        nb_filters = [32, 64, 128, 256, 512]
        strides = [2, 2, 2, 2, 2]
        self.layer1 = ConvNorm(in_channels=in_channels, out_channels=nb_filters[0], kernel_size=5, stride=strides[0], norm='NONE', leaky=True)
        self.layer2 = ConvNorm(in_channels=nb_filters[0], out_channels=nb_filters[1], kernel_size=5, stride=strides[1], norm=norm_type, leaky=True)
        self.layer3 = ConvNorm(in_channels=nb_filters[1], out_channels=nb_filters[2], kernel_size=5, stride=strides[2], norm=norm_type, leaky=True)
        self.layer4 = ConvNorm(in_channels=nb_filters[2], out_channels=nb_filters[3], kernel_size=5, stride=strides[3], norm=norm_type, leaky=True)
        self.layer5 = ConvNorm(in_channels=nb_filters[3], out_channels=nb_filters[4], kernel_size=5, stride=strides[4], norm=norm_type, leaky=True)
        self.attn4 = PromptAttentionModule(in_channels=nb_filters[3], prompt_channels=512, mid_channels=nb_filters[3] // 4)
        self.attn5 = PromptAttentionModule(in_channels=nb_filters[4], prompt_channels=512, mid_channels=nb_filters[4] // 4)
        self.dense_pred = ConvNorm(in_channels=nb_filters[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False)

    def forward(self, inputs, prompt_in):
        x1 = self.layer1(inputs)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x4 = self.attn4(x4, prompt_in)
        x5 = self.layer5(x4)
        x5 = self.attn5(x5, prompt_in)
        output = self.dense_pred(x5)
        output_list = [x1, x2, x3, x4, x5, output]
        return output_list

class MultiScaleDiscriminator(nn.Module):
    def __init__(self, in_channels, norm='INSTANCE', num_D=3):
        super().__init__()
        self.num_D = num_D
        module = PatchDiscriminator
        for i in range(num_D):
            netD = module(in_channels, norm)
            setattr(self, 'layer' + str(i), netD)
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)

    def singleD_forward(self, model, input):
        return model(input)

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            # the highest-indexed discriminator sees the full resolution; lower indices see downsampled inputs
            model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result
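
# Example (sketch): one PatchDiscriminator per scale, with the input average-
# pooled between scales; each element of the result is that scale's list of
# intermediate features plus the final score map (6 tensors with the
# PatchDiscriminator above).
#
#   d = MultiScaleDiscriminator(in_channels=2, num_D=3)
#   outs = d(torch.randn(1, 2, 256, 256))
#   print(len(outs), len(outs[0]))  # 3 6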

class PromptMultiScaleDiscriminator(nn.Module):
    def __init__(self, in_channels, norm='INSTANCE', num_D=3):
        super().__init__()
        self.num_D = num_D
        module = PromptPatchDiscriminator
        for i in range(num_D):
            netD = module(in_channels, norm)
            setattr(self, 'layer' + str(i), netD)
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)

    def singleD_forward(self, model, input, prompt_in):
        return model(input, prompt_in)

    def forward(self, input, prompt_in):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled, prompt_in))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result

class HighResEnhancer(nn.Module):
    """
    A global-local network for high-resolution generation that enhances boundary information.
    """
    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 coarse_channels: tuple = (16, 32, 64, 128, 256),
                 channels: tuple = (32, 64, 128, 256, 512),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        super().__init__()
        # define basic blocks
        self.norm = norm
        self.leaky = leaky
        norm_layer = self.get_norm_layer()
        act_layer = self.get_act_layer()
        res_unit = ResBlock if channels[-1] <= 512 else ResBottleneck
        # the coarse and fine features are summed later, so their widths must match
        assert channels[1] == coarse_channels[2], 'channels[1] (fine branch) must equal coarse_channels[2] (coarse branch).'
        # downsampling and edge information extraction:
        # the downsampling operation provides the input for the coarse branch
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1)
        # the Sobel filter operates on the full-resolution image to provide edge
        # information, which a strided conv then brings to stride 2
        self.sobel = SobelEdge(input_dim=2, channels=in_channels)
        self.sobel_conv = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[0]),
            act_layer()
        )
        # coarse generator: in_channels -> coarse_channels[2]
        # input: the 2x-downsampled image
        # output stride: 4 relative to the original image
        self.coarse = nn.Sequential(
            nn.Conv2d(in_channels, coarse_channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(coarse_channels[0]),
            act_layer(),
            res_unit(coarse_channels[0], coarse_channels[1], stride=2),
            res_unit(coarse_channels[1], coarse_channels[2], stride=2),
            res_unit(coarse_channels[2], coarse_channels[3], stride=2),
            res_unit(coarse_channels[3], coarse_channels[4], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[4], coarse_channels[3], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[3], coarse_channels[2], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[2], coarse_channels[2], stride=1),
        )
        # fine generator: enhances the generation with better details
        # 1. simple encoder: in_channels -> channels[1]
        #    input: full resolution; output stride: 4
        self.fine_encoder = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[0]),
            act_layer(),
            nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[1]),
            act_layer()
        )
        # 2. bottleneck: channels[1] -> channels[4]
        #    input stride: 4; output stride: 16
        self.bottleneck = nn.Sequential(
            res_unit(channels[1], channels[2], stride=2),
            res_unit(channels[2], channels[3], stride=2),
            res_unit(channels[3], channels[4], stride=1),
            res_unit(channels[4], channels[4], stride=1),
        )
        if use_dilated_bottleneck:
            # six extra dilated blocks (dilation 1, 2, 5, repeated twice), padding matched to dilation
            for idx, d in enumerate((1, 2, 5, 1, 2, 5), start=1):
                self.bottleneck.add_module(f'dilated_block_{idx}',
                                           nn.Sequential(
                                               nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=d, dilation=d),
                                               norm_layer(channels[4]),
                                               act_layer()
                                           ))
        # 3. simple decoder: channels[4] -> channels[0]
        #    input stride: 16; output stride: 2
        self.decoder = nn.Sequential(
            res_unit(channels[4], channels[3], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[3], channels[2], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[2], channels[1], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[1], channels[0], stride=1),
        )
        # output operation that combines the feature branch and the edge branch
        # input stride: 2; output: full resolution
        self.output = nn.Sequential(
            nn.Conv2d(2 * channels[0], channels[0], kernel_size=3, stride=1, padding=1),
            norm_layer(channels[0]),
            act_layer(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(channels[0], out_channels, kernel_size=1, stride=1, bias=False)
        )

    def get_norm_layer(self):
        if self.norm == 'INSTANCE':
            return partial(nn.InstanceNorm2d, affine=False)
        elif self.norm == 'BATCH':
            return partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        elif self.norm == 'GROUP':
            # nn.GroupNorm takes (num_groups, num_channels); bind the group count so the
            # returned factory can be called with the channel count alone
            return lambda num_channels: nn.GroupNorm(8, num_channels)
        else:
            raise NotImplementedError(f'Normalization layer {self.norm} is not implemented.')

    def get_act_layer(self):
        if self.leaky:
            return partial(nn.LeakyReLU, inplace=False)
        else:
            return partial(nn.ReLU, inplace=False)

    def forward(self, inputs):
        """
        Args:
            inputs: (B, C, H, W), input IMC image
        """
        # downsampling and edge information extraction
        downsampled = self.downsample(inputs)  # full resolution -> 2x stride
        edge = self.sobel(inputs)
        edge = self.sobel_conv(edge)
        # coarse generator
        coarse = self.coarse(downsampled)  # 2x stride -> 4x stride
        # fine generator
        fine = self.fine_encoder(inputs)  # full resolution -> 4x stride
        # fuse the coarse and fine information
        fine = self.bottleneck(fine + coarse)  # 4x stride -> 16x stride
        fine = self.decoder(fine)  # 16x stride -> 2x stride
        # combine the feature branch and the edge branch
        output = self.output(torch.cat([edge, fine], dim=1))
        return output
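
# Example (sketch): smoke test for the global-local generator. H and W should be
# divisible by 16 so the coarse and fine branches align when summed. Assumes
# SobelEdge and the res blocks from base.base_modules behave as used above.
#
#   net = HighResEnhancer(in_channels=1, out_channels=1)
#   y = net(torch.randn(1, 1, 256, 256))  # (1, 1, 256, 256)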