# Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os from typing import Any, cast, Dict, List, Union import torch from torch import nn, Tensor from torch.nn import functional as F_torch from torchvision import models, transforms from torchvision.models.feature_extraction import create_feature_extractor __all__ = [ "DiscriminatorForVGG", "RRDBNet", "ContentLoss", "discriminator_for_vgg", "rrdbnet_x2", "rrdbnet_x4", "rrdbnet_x8" ] feature_extractor_net_cfgs: Dict[str, List[Union[str, int]]] = { "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], } def _make_layers(net_cfg_name: str, batch_norm: bool = False) -> nn.Sequential: net_cfg = feature_extractor_net_cfgs[net_cfg_name] layers: nn.Sequential[nn.Module] = nn.Sequential() in_channels = 3 for v in net_cfg: if v == "M": layers.append(nn.MaxPool2d((2, 2), (2, 2))) else: v = cast(int, v) conv2d = nn.Conv2d(in_channels, v, (3, 3), (1, 1), (1, 1)) if batch_norm: layers.append(conv2d) layers.append(nn.BatchNorm2d(v)) layers.append(nn.ReLU(True)) else: layers.append(conv2d) layers.append(nn.ReLU(True)) in_channels = v return layers class _FeatureExtractor(nn.Module): def __init__( self, net_cfg_name: str = "vgg19", batch_norm: bool = False, num_classes: int = 1000) -> None: super(_FeatureExtractor, self).__init__() self.features = _make_layers(net_cfg_name, batch_norm) self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) self.classifier = nn.Sequential( nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(0.5), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(0.5), nn.Linear(4096, num_classes), ) # Initialize neural network weights for module in self.modules(): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.BatchNorm2d): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) elif isinstance(module, nn.Linear): nn.init.normal_(module.weight, 0, 0.01) nn.init.constant_(module.bias, 0) def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) # Support torch.script function def _forward_impl(self, x: Tensor) -> Tensor: x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.classifier(x) return x class RRDBNet(nn.Module): def __init__( self, in_channels: int = 3, out_channels: int = 3, channels: int = 64, growth_channels: int = 32, num_rrdb: int = 23, upscale: int = 4, ) -> None: super(RRDBNet, self).__init__() self.upscale = upscale # The first layer of convolutional layer. self.conv1 = nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1)) # Feature extraction backbone network. trunk = [] for _ in range(num_rrdb): trunk.append(_ResidualResidualDenseBlock(channels, growth_channels)) self.trunk = nn.Sequential(*trunk) # After the feature extraction network, reconnect a layer of convolutional blocks. self.conv2 = nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)) # Upsampling convolutional layer. if upscale == 2: self.upsampling1 = nn.Sequential( nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), nn.LeakyReLU(0.2, True) ) if upscale == 4: self.upsampling1 = nn.Sequential( nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), nn.LeakyReLU(0.2, True) ) self.upsampling2 = nn.Sequential( nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), nn.LeakyReLU(0.2, True) ) if upscale == 8: self.upsampling1 = nn.Sequential( nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), nn.LeakyReLU(0.2, True) ) self.upsampling2 = nn.Sequential( nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), nn.LeakyReLU(0.2, True) ) self.upsampling3 = nn.Sequential( nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), nn.LeakyReLU(0.2, True) ) # Reconnect a layer of convolution block after upsampling. self.conv3 = nn.Sequential( nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), nn.LeakyReLU(0.2, True) ) # Output layer. self.conv4 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1)) # Initialize all layer for module in self.modules(): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight) module.weight.data *= 0.2 if module.bias is not None: nn.init.constant_(module.bias, 0) # The model should be defined in the Torch.script method. def _forward_impl(self, x: Tensor) -> Tensor: conv1 = self.conv1(x) x = self.trunk(conv1) x = self.conv2(x) x = torch.add(x, conv1) if self.upscale == 2: x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest")) if self.upscale == 4: x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest")) x = self.upsampling2(F_torch.interpolate(x, scale_factor=2, mode="nearest")) if self.upscale == 8: x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest")) x = self.upsampling2(F_torch.interpolate(x, scale_factor=2, mode="nearest")) x = self.upsampling3(F_torch.interpolate(x, scale_factor=2, mode="nearest")) x = self.conv3(x) x = self.conv4(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) class _ResidualDenseBlock(nn.Module): """Achieves densely connected convolutional layers. `Densely Connected Convolutional Networks ` paper. Args: channels (int): The number of channels in the input image. growth_channels (int): The number of channels that increase in each layer of convolution. """ def __init__(self, channels: int, growth_channels: int) -> None: super(_ResidualDenseBlock, self).__init__() self.conv1 = nn.Conv2d(channels + growth_channels * 0, growth_channels, (3, 3), (1, 1), (1, 1)) self.conv2 = nn.Conv2d(channels + growth_channels * 1, growth_channels, (3, 3), (1, 1), (1, 1)) self.conv3 = nn.Conv2d(channels + growth_channels * 2, growth_channels, (3, 3), (1, 1), (1, 1)) self.conv4 = nn.Conv2d(channels + growth_channels * 3, growth_channels, (3, 3), (1, 1), (1, 1)) self.conv5 = nn.Conv2d(channels + growth_channels * 4, channels, (3, 3), (1, 1), (1, 1)) self.leaky_relu = nn.LeakyReLU(0.2, True) self.identity = nn.Identity() def forward(self, x: Tensor) -> Tensor: identity = x out1 = self.leaky_relu(self.conv1(x)) out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1))) out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1))) out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1))) out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1))) x = torch.mul(out5, 0.2) x = torch.add(x, identity) return x class _ResidualResidualDenseBlock(nn.Module): """Multi-layer residual dense convolution block. Args: channels (int): The number of channels in the input image. growth_channels (int): The number of channels that increase in each layer of convolution. """ def __init__(self, channels: int, growth_channels: int) -> None: super(_ResidualResidualDenseBlock, self).__init__() self.rdb1 = _ResidualDenseBlock(channels, growth_channels) self.rdb2 = _ResidualDenseBlock(channels, growth_channels) self.rdb3 = _ResidualDenseBlock(channels, growth_channels) def forward(self, x: Tensor) -> Tensor: identity = x x = self.rdb1(x) x = self.rdb2(x) x = self.rdb3(x) x = torch.mul(x, 0.2) x = torch.add(x, identity) return x class DiscriminatorForVGG(nn.Module): def __init__( self, in_channels: int = 3, out_channels: int = 3, channels: int = 64, ) -> None: super(DiscriminatorForVGG, self).__init__() self.features = nn.Sequential( # input size. (3) x 128 x 128 nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1), bias=True), nn.LeakyReLU(0.2, True), # state size. (64) x 64 x 64 nn.Conv2d(channels, channels, (4, 4), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(channels), nn.LeakyReLU(0.2, True), nn.Conv2d(channels, int(2 * channels), (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(int(2 * channels)), nn.LeakyReLU(0.2, True), # state size. (128) x 32 x 32 nn.Conv2d(int(2 * channels), int(2 * channels), (4, 4), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(int(2 * channels)), nn.LeakyReLU(0.2, True), nn.Conv2d(int(2 * channels), int(4 * channels), (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(int(4 * channels)), nn.LeakyReLU(0.2, True), # state size. (256) x 16 x 16 nn.Conv2d(int(4 * channels), int(4 * channels), (4, 4), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(int(4 * channels)), nn.LeakyReLU(0.2, True), nn.Conv2d(int(4 * channels), int(8 * channels), (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(int(8 * channels)), nn.LeakyReLU(0.2, True), # state size. (512) x 8 x 8 nn.Conv2d(int(8 * channels), int(8 * channels), (4, 4), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(int(8 * channels)), nn.LeakyReLU(0.2, True), nn.Conv2d(int(8 * channels), int(8 * channels), (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(int(8 * channels)), nn.LeakyReLU(0.2, True), # state size. (512) x 4 x 4 nn.Conv2d(int(8 * channels), int(8 * channels), (4, 4), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(int(8 * channels)), nn.LeakyReLU(0.2, True) ) self.classifier = nn.Sequential( nn.Linear(int(8 * channels) * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, out_channels) ) def forward(self, x: Tensor) -> Tensor: out = self.features(x) out = torch.flatten(out, 1) out = self.classifier(out) return out class ContentLoss(nn.Module): """Constructs a content loss function based on the VGG19 network. Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image. Paper reference list: -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network ` paper. -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks ` paper. -`Perceptual Extreme Super Resolution Network with Receptive Field Block ` paper. """ def __init__( self, net_cfg_name: str = "vgg19", batch_norm: bool = False, num_classes: int = 1000, model_weights_path: str = "", feature_nodes: list = None, feature_normalize_mean: list = None, feature_normalize_std: list = None, ) -> None: super(ContentLoss, self).__init__() # Define the feature extraction model model = _FeatureExtractor(net_cfg_name, batch_norm, num_classes) # Load the pre-trained model if model_weights_path == "": model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1) elif model_weights_path is not None and os.path.exists(model_weights_path): checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage) if "state_dict" in checkpoint.keys(): model.load_state_dict(checkpoint["state_dict"]) else: model.load_state_dict(checkpoint) else: raise FileNotFoundError("Model weight file not found") # Extract the output of the feature extraction layer self.feature_extractor = create_feature_extractor(model, feature_nodes) # Select the specified layers as the feature extraction layer self.feature_extractor_nodes = feature_nodes # input normalization self.normalize = transforms.Normalize(feature_normalize_mean, feature_normalize_std) # Freeze model parameters without derivatives for model_parameters in self.feature_extractor.parameters(): model_parameters.requires_grad = False self.feature_extractor.eval() def forward(self, sr_tensor: Tensor, gt_tensor: Tensor) -> [Tensor]: assert sr_tensor.size() == gt_tensor.size(), "Two tensor must have the same size" device = sr_tensor.device losses = [] # input normalization sr_tensor = self.normalize(sr_tensor) gt_tensor = self.normalize(gt_tensor) # Get the output of the feature extraction layer sr_feature = self.feature_extractor(sr_tensor) gt_feature = self.feature_extractor(gt_tensor) # Compute feature loss for i in range(len(self.feature_extractor_nodes)): losses.append(F_torch.l1_loss(sr_feature[self.feature_extractor_nodes[i]], gt_feature[self.feature_extractor_nodes[i]])) losses = torch.Tensor([losses]).to(device) return losses def rrdbnet_x2(**kwargs: Any) -> RRDBNet: model = RRDBNet(upscale=2, **kwargs) return model def rrdbnet_x4(**kwargs: Any) -> RRDBNet: model = RRDBNet(upscale=4, **kwargs) return model def rrdbnet_x8(**kwargs: Any) -> RRDBNet: model = RRDBNet(upscale=8, **kwargs) return model def discriminator_for_vgg(**kwargs) -> DiscriminatorForVGG: model = DiscriminatorForVGG(**kwargs) return model