# newTryOn/models/Net.py
import os
import numpy as np
import torch
from torch import nn
from torchvision import transforms
from models.CtrlHair.external_code.face_parsing.my_parsing_util import FaceParsing_tensor
from models.stylegan2.model import Generator
from utils.drive import download_weight
transform_to_256 = transforms.Compose([
transforms.Resize((256, 256)),
])
__all__ = ['Net', 'iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200', 'FeatureEncoderMult',
'IBasicBlock', 'conv1x1', 'get_segmentation']
class Net(nn.Module):
def __init__(self, opts):
super(Net, self).__init__()
self.opts = opts
self.generator = Generator(opts.size, opts.latent, opts.n_mlp, channel_multiplier=opts.channel_multiplier)
self.cal_layer_num()
self.load_weights()
self.load_PCA_model()
        # Calling parsing_img() once with no arguments appears to warm up /
        # lazily initialize the face-parsing model.
        FaceParsing_tensor.parsing_img()
def load_weights(self):
if not os.path.exists(self.opts.ckpt):
print('Downloading StyleGAN2 checkpoint: {}'.format(self.opts.ckpt))
download_weight(self.opts.ckpt)
print('Loading StyleGAN2 from checkpoint: {}'.format(self.opts.ckpt))
        checkpoint = torch.load(self.opts.ckpt, map_location='cpu')  # load on CPU; moved to opts.device below
device = self.opts.device
self.generator.load_state_dict(checkpoint['g_ema'])
self.latent_avg = checkpoint['latent_avg']
self.generator.to(device)
self.latent_avg = self.latent_avg.to(device)
for param in self.generator.parameters():
param.requires_grad = False
self.generator.eval()
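
    # The StyleGAN2 generator is used strictly frozen: weights come from the
    # 'g_ema' (EMA) copy, gradients are disabled, and eval() is set, so only
    # latents / feature tensors are ever optimized against it.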
def build_PCA_model(self, PCA_path):
with torch.no_grad():
            # Sampling 1M latents costs ~2 GB of float32 for this buffer alone.
            latent = torch.randn((1000000, 512), dtype=torch.float32)
            # For a quicker, noisier estimate, a smaller sample such as
            # torch.randn((10000, 512), dtype=torch.float32) also works.
            self.generator.style.cpu()
            # LeakyReLU(5) effectively inverts the LeakyReLU(0.2) activations of
            # the mapping network, pulling W back into the PULSE-style "P" space.
            pulse_space = torch.nn.LeakyReLU(5)(self.generator.style(latent)).numpy()
self.generator.style.to(self.opts.device)
from utils.PCA_utils import IPCAEstimator
transformer = IPCAEstimator(512)
X_mean = pulse_space.mean(0)
transformer.fit(pulse_space - X_mean)
X_comp, X_stdev, X_var_ratio = transformer.get_components()
np.savez(PCA_path, X_mean=X_mean, X_comp=X_comp, X_stdev=X_stdev, X_var_ratio=X_var_ratio)
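
    # A note on the saved stats (assuming IPCAEstimator wraps an incremental
    # PCA, e.g. sklearn's IncrementalPCA): X_comp holds the 512 principal
    # directions, X_stdev the per-component standard deviations, and X_mean the
    # mean of the LeakyReLU(5)-mapped latents. load_PCA_model below reloads the
    # .npz instead of recomputing on every run.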
def load_PCA_model(self):
device = self.opts.device
PCA_path = self.opts.ckpt[:-3] + '_PCA.npz'
if not os.path.isfile(PCA_path):
self.build_PCA_model(PCA_path)
PCA_model = np.load(PCA_path)
self.X_mean = torch.from_numpy(PCA_model['X_mean']).float().to(device)
self.X_comp = torch.from_numpy(PCA_model['X_comp']).float().to(device)
self.X_stdev = torch.from_numpy(PCA_model['X_stdev']).float().to(device)
    def cal_layer_num(self):
        if self.opts.size == 1024:
            self.layer_num = 18
        elif self.opts.size == 512:
            self.layer_num = 16
        elif self.opts.size == 256:
            self.layer_num = 14
        else:
            raise ValueError('Unsupported StyleGAN2 size: {}'.format(self.opts.size))
        self.S_index = self.layer_num - 11
def cal_p_norm_loss(self, latent_in):
latent_p_norm = (torch.nn.LeakyReLU(negative_slope=5)(latent_in) - self.X_mean).bmm(
self.X_comp.T.unsqueeze(0)) / self.X_stdev
p_norm_loss = self.opts.p_norm_lambda * (latent_p_norm.pow(2).mean())
return p_norm_loss
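
    # In symbols, with w the latent, mu = X_mean, C = X_comp, s = X_stdev:
    #     p = (LeakyReLU_5(w) - mu) C^T / s,    L_p = p_norm_lambda * mean(p^2)
    # i.e. latents are penalized for drifting outside the PCA-whitened
    # distribution of mapped latents.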
def cal_l_F(self, latent_F, F_init):
return self.opts.l_F_lambda * (latent_F - F_init).pow(2).mean()
def get_segmentation(img_rgb, resize=True):
    """Face-parse an RGB image and return a CelebAMask-style label map of
    shape (1, 1, H, W); nearest-neighbor resize keeps the labels integral."""
    parsing, _ = FaceParsing_tensor.parsing_img(img_rgb)
    parsing = FaceParsing_tensor.swap_parsing_label_to_celeba_mask(parsing)
    mask_img = parsing.long()[None, None, ...]
    if resize:
        mask_img = transforms.functional.resize(mask_img, (256, 256),
                                                interpolation=transforms.InterpolationMode.NEAREST)
    return mask_img
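
# Usage sketch (hypothetical; assumes FaceParsing_tensor accepts an RGB image
# tensor as in CtrlHair's my_parsing_util):
#   mask = get_segmentation(img_rgb)        # (1, 1, 256, 256) integer labels
#   hair = (mask == HAIR_LABEL)             # HAIR_LABEL per the CelebAMask mapping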
# Kernel sizes / strides for the feature-space (FS) content layers, indexed by
# (shifted) fs-layer index. The 'fs_kernals' spelling is kept as-is since other
# modules may import it by this name.
fs_kernals = {
0: (12, 12),
1: (12, 12),
2: (6, 6),
3: (6, 6),
4: (3, 3),
5: (3, 3),
6: (3, 3),
7: (3, 3),
}
fs_strides = {
0: (7, 7),
1: (7, 7),
2: (4, 4),
3: (4, 4),
4: (2, 2),
5: (2, 2),
6: (1, 1),
7: (1, 1),
}
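# Worked example: with a 256x256 input, block_3's output map is 32x32; for fs
# layer 5 the content conv has kernel (3, 3), stride (2, 2), padding (1, 1),
# so the output side is floor((32 + 2 - 3) / 2) + 1 = 16, i.e. a 512x16x16
# content tensor. Layer 0 (kernel 12, stride 7) gives floor(22 / 7) + 1 = 4.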
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes,
out_planes,
kernel_size=1,
stride=stride,
bias=False)
class IBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
groups=1, base_width=64, dilation=1):
super(IBasicBlock, self).__init__()
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.bn1(x)
out = self.conv1(out)
out = self.bn2(out)
out = self.prelu(out)
out = self.conv2(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
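
# Unlike torchvision's BasicBlock, this ArcFace-style block normalizes first
# (BN -> conv -> BN -> PReLU -> conv -> BN) and applies no activation after the
# residual addition.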
class IResNet(nn.Module):
fc_scale = 7 * 7
def __init__(self,
block, layers, dropout=0, num_features=512, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
super(IResNet, self).__init__()
self.fp16 = fp16
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
self.prelu = nn.PReLU(self.inplanes)
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
self.layer2 = self._make_layer(block,
128,
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block,
256,
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block,
512,
layers[3],
stride=2,
dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
self.dropout = nn.Dropout(p=dropout, inplace=True)
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
nn.init.constant_(self.features.weight, 1.0)
self.features.weight.requires_grad = False
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight, 0, 0.1)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if zero_init_residual:
for m in self.modules():
if isinstance(m, IBasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
)
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation))
return nn.Sequential(*layers)
def forward(self, x, return_features=False):
out = []
with torch.cuda.amp.autocast(self.fp16):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.layer1(x)
out.append(x)
x = self.layer2(x)
out.append(x)
x = self.layer3(x)
out.append(x)
x = self.layer4(x)
out.append(x)
x = self.bn2(x)
x = torch.flatten(x, 1)
x = self.dropout(x)
x = self.fc(x.float() if self.fp16 else x)
x = self.features(x)
if return_features:
out.append(x)
return out
return x
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError('Pretrained weights are not bundled for {}; load them manually.'.format(arch))
    return model
def iresnet18(pretrained=False, progress=True, **kwargs):
return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
progress, **kwargs)
def iresnet34(pretrained=False, progress=True, **kwargs):
return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
progress, **kwargs)
def iresnet50(pretrained=False, progress=True, **kwargs):
return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
progress, **kwargs)
def iresnet100(pretrained=False, progress=True, **kwargs):
return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
progress, **kwargs)
def iresnet200(pretrained=False, progress=True, **kwargs):
return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
progress, **kwargs)
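
# Usage sketch: ArcFace-style backbones expect 112x112 inputs (fc_scale = 7*7
# requires the final 512-channel map to be 7x7 = 112 / 2^4):
#   net = iresnet50()
#   emb = net(torch.randn(1, 3, 112, 112))                        # (1, 512)
#   feats = net(torch.randn(1, 3, 112, 112), return_features=True)
#   # feats = [layer1..layer4 activations, embedding]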
class FeatureEncoder(nn.Module):
def __init__(self, n_styles=18, opts=None, residual=False,
use_coeff=False, resnet_layer=None,
video_input=False, f_maps=512, stride=(1, 1)):
super(FeatureEncoder, self).__init__()
        resnet50 = iresnet50()
        # opts must provide arcface_model_path (ArcFace iresnet50 weights).
        resnet50.load_state_dict(torch.load(opts.arcface_model_path, map_location='cpu'))
# input conv layer
if video_input:
self.conv = nn.Sequential(
nn.Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
*list(resnet50.children())[1:3]
)
else:
self.conv = nn.Sequential(*list(resnet50.children())[:3])
# define layers
self.block_1 = list(resnet50.children())[3] # 15-18
self.block_2 = list(resnet50.children())[4] # 10-14
self.block_3 = list(resnet50.children())[5] # 5-9
self.block_4 = list(resnet50.children())[6] # 1-4
self.content_layer = nn.Sequential(
nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.PReLU(num_parameters=512),
nn.Conv2d(512, 512, kernel_size=(3, 3), stride=stride, padding=(1, 1), bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
self.avg_pool = nn.AdaptiveAvgPool2d((3, 3))
self.styles = nn.ModuleList()
for i in range(n_styles):
self.styles.append(nn.Linear(960 * 9, 512))
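        # 960 * 9 = (64 + 128 + 256 + 512) channels from the four blocks, each
        # adaptively pooled to 3x3, concatenated and flattened in forward().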
def apply_head(self, x):
latents = []
for i in range(len(self.styles)):
latents.append(self.styles[i](x))
out = torch.stack(latents, dim=1)
return out
def forward(self, x):
latents = []
features = []
x = self.conv(x)
x = self.block_1(x)
features.append(self.avg_pool(x))
x = self.block_2(x)
features.append(self.avg_pool(x))
x = self.block_3(x)
content = self.content_layer(x)
features.append(self.avg_pool(x))
x = self.block_4(x)
features.append(self.avg_pool(x))
x = torch.cat(features, dim=1)
x = x.view(x.size(0), -1)
return self.apply_head(x), content
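
# Usage sketch (hypothetical opts; only arcface_model_path is read here):
#   from types import SimpleNamespace
#   opts = SimpleNamespace(arcface_model_path='pretrained_models/iresnet50.pth')
#   enc = FeatureEncoder(n_styles=18, opts=opts)
#   w_plus, content = enc(torch.randn(1, 3, 256, 256))
#   # w_plus: (1, 18, 512); content: (1, 512, 32, 32) with the default stride=(1, 1)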
class FeatureEncoderMult(FeatureEncoder):
def __init__(self, fs_layers=(5,), ranks=None, **kwargs):
super().__init__(**kwargs)
self.fs_layers = fs_layers
self.content_layer = nn.ModuleList()
self.ranks = ranks
        # fs layers above 7 are taken after block_2 (128 channels, 64x64), so
        # the kernel/stride tables are shifted by 2 and input channels halved.
        shift = 0 if max(fs_layers) <= 7 else 2
        scale = 1 if max(fs_layers) <= 7 else 2
        for i in range(len(fs_layers)):
            if ranks is not None:
                # NOTE: ranks_data is not defined in this file; the surrounding
                # codebase is expected to supply a mapping from (shifted) rank
                # to a (stride, kernel) pair.
                stride, kern = ranks_data[ranks[i] - shift]
layer1 = nn.Sequential(
nn.BatchNorm2d(256 // scale, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.Conv2d(256 // scale, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.PReLU(num_parameters=512),
nn.Conv2d(512, 512, kernel_size=(fs_kernals[fs_layers[i] - shift][0], kern),
stride=(fs_strides[fs_layers[i] - shift][0], stride),
padding=(1, 1), bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
layer2 = nn.Sequential(
nn.BatchNorm2d(256 // scale, eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True),
nn.Conv2d(256 // scale, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True),
nn.PReLU(num_parameters=512),
nn.Conv2d(512, 512, kernel_size=(kern, fs_kernals[fs_layers[i] - shift][1]),
stride=(stride, fs_strides[fs_layers[i] - shift][1]),
padding=(1, 1), bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True)
)
layer = nn.ModuleList([layer1, layer2])
else:
layer = nn.Sequential(
nn.BatchNorm2d(256 // scale, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.Conv2d(256 // scale, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.PReLU(num_parameters=512),
nn.Conv2d(512, 512, kernel_size=fs_kernals[fs_layers[i] - shift],
stride=fs_strides[fs_layers[i] - shift],
padding=(1, 1), bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
self.content_layer.append(layer)
def forward(self, x):
x = transform_to_256(x)
features = []
content = []
x = self.conv(x)
x = self.block_1(x)
features.append(self.avg_pool(x))
x = self.block_2(x)
if max(self.fs_layers) > 7:
for layer in self.content_layer:
if self.ranks is not None:
mat1 = layer[0](x)
mat2 = layer[1](x)
content.append(torch.matmul(mat1, mat2))
else:
content.append(layer(x))
features.append(self.avg_pool(x))
x = self.block_3(x)
if len(content) == 0:
for layer in self.content_layer:
if self.ranks is not None:
mat1 = layer[0](x)
mat2 = layer[1](x)
content.append(torch.matmul(mat1, mat2))
else:
content.append(layer(x))
features.append(self.avg_pool(x))
x = self.block_4(x)
features.append(self.avg_pool(x))
x = torch.cat(features, dim=1)
x = x.view(x.size(0), -1)
return self.apply_head(x), content
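
# Usage sketch: FeatureEncoderMult resizes inputs to 256x256 itself and returns
# one content tensor per entry in fs_layers (opts as in the FeatureEncoder
# sketch above):
#   enc = FeatureEncoderMult(fs_layers=(5,), opts=opts)
#   w_plus, contents = enc(torch.randn(1, 3, 1024, 1024))
#   # contents[0]: (1, 512, 16, 16) for fs layer 5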
def get_keys(d, name, key="state_dict"):
    """Return the entries of d (or d[key], if present) whose keys start with
    'name.', with that prefix stripped."""
    if key in d:
        d = d[key]
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[: len(name) + 1] == name + '.'}
    return d_filt
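
# Usage sketch (hypothetical checkpoint layout with a top-level 'state_dict'):
#   ckpt = torch.load('checkpoint.pt', map_location='cpu')
#   encoder.load_state_dict(get_keys(ckpt, 'encoder'))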