rahul7star's picture
Upload 303 files
e0336bc verified
# Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import Any, cast, Dict, List, Union
import torch
from torch import nn, Tensor
from torch.nn import functional as F_torch
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
__all__ = [
"DiscriminatorForVGG", "RRDBNet", "ContentLoss",
"discriminator_for_vgg", "rrdbnet_x2", "rrdbnet_x4", "rrdbnet_x8"
]
feature_extractor_net_cfgs: Dict[str, List[Union[str, int]]] = {
"vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
"vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
def _make_layers(net_cfg_name: str, batch_norm: bool = False) -> nn.Sequential:
net_cfg = feature_extractor_net_cfgs[net_cfg_name]
layers: nn.Sequential[nn.Module] = nn.Sequential()
in_channels = 3
for v in net_cfg:
if v == "M":
layers.append(nn.MaxPool2d((2, 2), (2, 2)))
else:
v = cast(int, v)
conv2d = nn.Conv2d(in_channels, v, (3, 3), (1, 1), (1, 1))
if batch_norm:
layers.append(conv2d)
layers.append(nn.BatchNorm2d(v))
layers.append(nn.ReLU(True))
else:
layers.append(conv2d)
layers.append(nn.ReLU(True))
in_channels = v
return layers
class _FeatureExtractor(nn.Module):
def __init__(
self,
net_cfg_name: str = "vgg19",
batch_norm: bool = False,
num_classes: int = 1000) -> None:
super(_FeatureExtractor, self).__init__()
self.features = _make_layers(net_cfg_name, batch_norm)
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(0.5),
nn.Linear(4096, num_classes),
)
# Initialize neural network weights
for module in self.modules():
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
if module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.BatchNorm2d):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, 0, 0.01)
nn.init.constant_(module.bias, 0)
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
# Support torch.script function
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
class RRDBNet(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
channels: int = 64,
growth_channels: int = 32,
num_rrdb: int = 23,
upscale: int = 4,
) -> None:
super(RRDBNet, self).__init__()
self.upscale = upscale
# The first layer of convolutional layer.
self.conv1 = nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1))
# Feature extraction backbone network.
trunk = []
for _ in range(num_rrdb):
trunk.append(_ResidualResidualDenseBlock(channels, growth_channels))
self.trunk = nn.Sequential(*trunk)
# After the feature extraction network, reconnect a layer of convolutional blocks.
self.conv2 = nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1))
# Upsampling convolutional layer.
if upscale == 2:
self.upsampling1 = nn.Sequential(
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
nn.LeakyReLU(0.2, True)
)
if upscale == 4:
self.upsampling1 = nn.Sequential(
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
nn.LeakyReLU(0.2, True)
)
self.upsampling2 = nn.Sequential(
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
nn.LeakyReLU(0.2, True)
)
if upscale == 8:
self.upsampling1 = nn.Sequential(
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
nn.LeakyReLU(0.2, True)
)
self.upsampling2 = nn.Sequential(
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
nn.LeakyReLU(0.2, True)
)
self.upsampling3 = nn.Sequential(
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
nn.LeakyReLU(0.2, True)
)
# Reconnect a layer of convolution block after upsampling.
self.conv3 = nn.Sequential(
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
nn.LeakyReLU(0.2, True)
)
# Output layer.
self.conv4 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1))
# Initialize all layer
for module in self.modules():
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight)
module.weight.data *= 0.2
if module.bias is not None:
nn.init.constant_(module.bias, 0)
# The model should be defined in the Torch.script method.
def _forward_impl(self, x: Tensor) -> Tensor:
conv1 = self.conv1(x)
x = self.trunk(conv1)
x = self.conv2(x)
x = torch.add(x, conv1)
if self.upscale == 2:
x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest"))
if self.upscale == 4:
x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest"))
x = self.upsampling2(F_torch.interpolate(x, scale_factor=2, mode="nearest"))
if self.upscale == 8:
x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest"))
x = self.upsampling2(F_torch.interpolate(x, scale_factor=2, mode="nearest"))
x = self.upsampling3(F_torch.interpolate(x, scale_factor=2, mode="nearest"))
x = self.conv3(x)
x = self.conv4(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
class _ResidualDenseBlock(nn.Module):
"""Achieves densely connected convolutional layers.
`Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
Args:
channels (int): The number of channels in the input image.
growth_channels (int): The number of channels that increase in each layer of convolution.
"""
def __init__(self, channels: int, growth_channels: int) -> None:
super(_ResidualDenseBlock, self).__init__()
self.conv1 = nn.Conv2d(channels + growth_channels * 0, growth_channels, (3, 3), (1, 1), (1, 1))
self.conv2 = nn.Conv2d(channels + growth_channels * 1, growth_channels, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(channels + growth_channels * 2, growth_channels, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(channels + growth_channels * 3, growth_channels, (3, 3), (1, 1), (1, 1))
self.conv5 = nn.Conv2d(channels + growth_channels * 4, channels, (3, 3), (1, 1), (1, 1))
self.leaky_relu = nn.LeakyReLU(0.2, True)
self.identity = nn.Identity()
def forward(self, x: Tensor) -> Tensor:
identity = x
out1 = self.leaky_relu(self.conv1(x))
out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
x = torch.mul(out5, 0.2)
x = torch.add(x, identity)
return x
class _ResidualResidualDenseBlock(nn.Module):
"""Multi-layer residual dense convolution block.
Args:
channels (int): The number of channels in the input image.
growth_channels (int): The number of channels that increase in each layer of convolution.
"""
def __init__(self, channels: int, growth_channels: int) -> None:
super(_ResidualResidualDenseBlock, self).__init__()
self.rdb1 = _ResidualDenseBlock(channels, growth_channels)
self.rdb2 = _ResidualDenseBlock(channels, growth_channels)
self.rdb3 = _ResidualDenseBlock(channels, growth_channels)
def forward(self, x: Tensor) -> Tensor:
identity = x
x = self.rdb1(x)
x = self.rdb2(x)
x = self.rdb3(x)
x = torch.mul(x, 0.2)
x = torch.add(x, identity)
return x
class DiscriminatorForVGG(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
channels: int = 64,
) -> None:
super(DiscriminatorForVGG, self).__init__()
self.features = nn.Sequential(
# input size. (3) x 128 x 128
nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1), bias=True),
nn.LeakyReLU(0.2, True),
# state size. (64) x 64 x 64
nn.Conv2d(channels, channels, (4, 4), (2, 2), (1, 1), bias=False),
nn.BatchNorm2d(channels),
nn.LeakyReLU(0.2, True),
nn.Conv2d(channels, int(2 * channels), (3, 3), (1, 1), (1, 1), bias=False),
nn.BatchNorm2d(int(2 * channels)),
nn.LeakyReLU(0.2, True),
# state size. (128) x 32 x 32
nn.Conv2d(int(2 * channels), int(2 * channels), (4, 4), (2, 2), (1, 1), bias=False),
nn.BatchNorm2d(int(2 * channels)),
nn.LeakyReLU(0.2, True),
nn.Conv2d(int(2 * channels), int(4 * channels), (3, 3), (1, 1), (1, 1), bias=False),
nn.BatchNorm2d(int(4 * channels)),
nn.LeakyReLU(0.2, True),
# state size. (256) x 16 x 16
nn.Conv2d(int(4 * channels), int(4 * channels), (4, 4), (2, 2), (1, 1), bias=False),
nn.BatchNorm2d(int(4 * channels)),
nn.LeakyReLU(0.2, True),
nn.Conv2d(int(4 * channels), int(8 * channels), (3, 3), (1, 1), (1, 1), bias=False),
nn.BatchNorm2d(int(8 * channels)),
nn.LeakyReLU(0.2, True),
# state size. (512) x 8 x 8
nn.Conv2d(int(8 * channels), int(8 * channels), (4, 4), (2, 2), (1, 1), bias=False),
nn.BatchNorm2d(int(8 * channels)),
nn.LeakyReLU(0.2, True),
nn.Conv2d(int(8 * channels), int(8 * channels), (3, 3), (1, 1), (1, 1), bias=False),
nn.BatchNorm2d(int(8 * channels)),
nn.LeakyReLU(0.2, True),
# state size. (512) x 4 x 4
nn.Conv2d(int(8 * channels), int(8 * channels), (4, 4), (2, 2), (1, 1), bias=False),
nn.BatchNorm2d(int(8 * channels)),
nn.LeakyReLU(0.2, True)
)
self.classifier = nn.Sequential(
nn.Linear(int(8 * channels) * 4 * 4, 100),
nn.LeakyReLU(0.2, True),
nn.Linear(100, out_channels)
)
def forward(self, x: Tensor) -> Tensor:
out = self.features(x)
out = torch.flatten(out, 1)
out = self.classifier(out)
return out
class ContentLoss(nn.Module):
"""Constructs a content loss function based on the VGG19 network.
Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.
Paper reference list:
-`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
-`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
-`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
"""
def __init__(
self,
net_cfg_name: str = "vgg19",
batch_norm: bool = False,
num_classes: int = 1000,
model_weights_path: str = "",
feature_nodes: list = None,
feature_normalize_mean: list = None,
feature_normalize_std: list = None,
) -> None:
super(ContentLoss, self).__init__()
# Define the feature extraction model
model = _FeatureExtractor(net_cfg_name, batch_norm, num_classes)
# Load the pre-trained model
if model_weights_path == "":
model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
elif model_weights_path is not None and os.path.exists(model_weights_path):
checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
if "state_dict" in checkpoint.keys():
model.load_state_dict(checkpoint["state_dict"])
else:
model.load_state_dict(checkpoint)
else:
raise FileNotFoundError("Model weight file not found")
# Extract the output of the feature extraction layer
self.feature_extractor = create_feature_extractor(model, feature_nodes)
# Select the specified layers as the feature extraction layer
self.feature_extractor_nodes = feature_nodes
# input normalization
self.normalize = transforms.Normalize(feature_normalize_mean, feature_normalize_std)
# Freeze model parameters without derivatives
for model_parameters in self.feature_extractor.parameters():
model_parameters.requires_grad = False
self.feature_extractor.eval()
def forward(self, sr_tensor: Tensor, gt_tensor: Tensor) -> [Tensor]:
assert sr_tensor.size() == gt_tensor.size(), "Two tensor must have the same size"
device = sr_tensor.device
losses = []
# input normalization
sr_tensor = self.normalize(sr_tensor)
gt_tensor = self.normalize(gt_tensor)
# Get the output of the feature extraction layer
sr_feature = self.feature_extractor(sr_tensor)
gt_feature = self.feature_extractor(gt_tensor)
# Compute feature loss
for i in range(len(self.feature_extractor_nodes)):
losses.append(F_torch.l1_loss(sr_feature[self.feature_extractor_nodes[i]],
gt_feature[self.feature_extractor_nodes[i]]))
losses = torch.Tensor([losses]).to(device)
return losses
def rrdbnet_x2(**kwargs: Any) -> RRDBNet:
model = RRDBNet(upscale=2, **kwargs)
return model
def rrdbnet_x4(**kwargs: Any) -> RRDBNet:
model = RRDBNet(upscale=4, **kwargs)
return model
def rrdbnet_x8(**kwargs: Any) -> RRDBNet:
model = RRDBNet(upscale=8, **kwargs)
return model
def discriminator_for_vgg(**kwargs) -> DiscriminatorForVGG:
model = DiscriminatorForVGG(**kwargs)
return model