import os

import numpy as np
import torch
from torch import nn
from torchvision import transforms

from models.CtrlHair.external_code.face_parsing.my_parsing_util import FaceParsing_tensor
from models.stylegan2.model import Generator
from utils.drive import download_weight

transform_to_256 = transforms.Compose([
    transforms.Resize((256, 256)),
])

__all__ = ['Net', 'iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200',
           'FeatureEncoderMult', 'IBasicBlock', 'conv1x1', 'get_segmentation']


class Net(nn.Module):
    def __init__(self, opts):
        super(Net, self).__init__()
        self.opts = opts
        self.generator = Generator(opts.size, opts.latent, opts.n_mlp,
                                   channel_multiplier=opts.channel_multiplier)
        self.cal_layer_num()
        self.load_weights()
        self.load_PCA_model()
        # Warm-up call: makes the face-parsing model load ahead of first use.
        FaceParsing_tensor.parsing_img()

    def load_weights(self):
        if not os.path.exists(self.opts.ckpt):
            print('Downloading StyleGAN2 checkpoint: {}'.format(self.opts.ckpt))
            download_weight(self.opts.ckpt)

        print('Loading StyleGAN2 from checkpoint: {}'.format(self.opts.ckpt))
        checkpoint = torch.load(self.opts.ckpt)
        device = self.opts.device
        self.generator.load_state_dict(checkpoint['g_ema'])
        self.latent_avg = checkpoint['latent_avg']
        self.generator.to(device)
        self.latent_avg = self.latent_avg.to(device)

        # The generator is used as a fixed decoder: freeze it and keep it in eval mode.
        for param in self.generator.parameters():
            param.requires_grad = False
        self.generator.eval()

    def build_PCA_model(self, PCA_path):
        with torch.no_grad():
            latent = torch.randn((1000000, 512), dtype=torch.float32)
            # LeakyReLU(5) inverts the LeakyReLU(0.2) at the end of the mapping
            # network, so the PCA is fit in the pre-activation ("pulse") space.
            self.generator.style.cpu()
            pulse_space = torch.nn.LeakyReLU(5)(self.generator.style(latent)).numpy()
            self.generator.style.to(self.opts.device)

        from utils.PCA_utils import IPCAEstimator

        transformer = IPCAEstimator(512)
        X_mean = pulse_space.mean(0)
        transformer.fit(pulse_space - X_mean)
        X_comp, X_stdev, X_var_ratio = transformer.get_components()
        np.savez(PCA_path, X_mean=X_mean, X_comp=X_comp, X_stdev=X_stdev, X_var_ratio=X_var_ratio)

    def load_PCA_model(self):
        device = self.opts.device

        # The PCA file is stored next to the checkpoint, e.g. 'model.pt' -> 'model_PCA.npz'.
        PCA_path = self.opts.ckpt[:-3] + '_PCA.npz'

        if not os.path.isfile(PCA_path):
            self.build_PCA_model(PCA_path)

        PCA_model = np.load(PCA_path)
        self.X_mean = torch.from_numpy(PCA_model['X_mean']).float().to(device)
        self.X_comp = torch.from_numpy(PCA_model['X_comp']).float().to(device)
        self.X_stdev = torch.from_numpy(PCA_model['X_stdev']).float().to(device)

    def cal_layer_num(self):
        if self.opts.size == 1024:
            self.layer_num = 18
        elif self.opts.size == 512:
            self.layer_num = 16
        elif self.opts.size == 256:
            self.layer_num = 14
        else:
            raise ValueError('Unsupported generator size: {}'.format(self.opts.size))
        self.S_index = self.layer_num - 11

    def cal_p_norm_loss(self, latent_in):
        latent_p_norm = (torch.nn.LeakyReLU(negative_slope=5)(latent_in) - self.X_mean).bmm(
            self.X_comp.T.unsqueeze(0)) / self.X_stdev
        p_norm_loss = self.opts.p_norm_lambda * (latent_p_norm.pow(2).mean())
        return p_norm_loss

    def cal_l_F(self, latent_F, F_init):
        return self.opts.l_F_lambda * (latent_F - F_init).pow(2).mean()
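
# Usage sketch (illustrative only, not executed on import): `opts` is assumed to carry
# at least size/latent/n_mlp/channel_multiplier, the checkpoint path `ckpt`, a `device`,
# and the loss weights p_norm_lambda / l_F_lambda. cal_p_norm_loss regularizes an
# optimized W+ latent towards the PCA statistics of the mapping network:
#
#   net = Net(opts)
#   latent_in = net.latent_avg.repeat(1, net.layer_num, 1).detach().clone()  # (1, L, 512)
#   latent_in.requires_grad_(True)
#   loss = net.cal_p_norm_loss(latent_in)   # scalar, scaled by opts.p_norm_lambda
#   loss.backward()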
def get_segmentation(img_rgb, resize=True):
    parsing, _ = FaceParsing_tensor.parsing_img(img_rgb)
    parsing = FaceParsing_tensor.swap_parsing_label_to_celeba_mask(parsing)

    mask_img = parsing.long()[None, None, ...]
    if resize:
        mask_img = transforms.functional.resize(mask_img, (256, 256),
                                                interpolation=transforms.InterpolationMode.NEAREST)
    return mask_img
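
# Usage sketch (illustrative): `img_rgb` is assumed to be an RGB image in whatever
# format FaceParsing_tensor.parsing_img expects. NEAREST interpolation is used above
# so that resizing never blends neighbouring label ids.
#
#   mask = get_segmentation(img_rgb)      # (1, 1, 256, 256) long tensor of class labels
#   hair_mask = (mask == HAIR_LABEL)      # HAIR_LABEL: assumed hair index in the CelebA mask convention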
# Kernel and stride choices for the content heads, indexed by fs layer
# (consumed by FeatureEncoderMult below). Note: 'kernals' keeps the original spelling.
fs_kernals = {
    0: (12, 12),
    1: (12, 12),
    2: (6, 6),
    3: (6, 6),
    4: (3, 3),
    5: (3, 3),
    6: (3, 3),
    7: (3, 3),
}

fs_strides = {
    0: (7, 7),
    1: (7, 7),
    2: (4, 4),
    3: (4, 4),
    4: (2, 2),
    5: (2, 2),
    6: (1, 1),
    7: (1, 1),
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Pre-activation residual block with PReLU, as used in ArcFace's IResNet.
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self, block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, groups=self.groups,
                      base_width=self.base_width, dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x, return_features=False):
        # Collects the layer1..layer4 activations so callers can tap intermediate
        # features instead of (or in addition to) the final embedding.
        out = []
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            out.append(x)
            x = self.layer2(x)
            out.append(x)
            x = self.layer3(x)
            out.append(x)
            x = self.layer4(x)
            out.append(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        if return_features:
            out.append(x)
            return out
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError('Pretrained weights are not available for {}'.format(arch))
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, progress, **kwargs)


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, progress, **kwargs)
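
# Usage sketch (illustrative): as a recognition network the iresnet backbones expect
# ArcFace-style 3x112x112 crops (112 / 2^4 = 7, matching fc_scale = 7 * 7), while the
# encoders below feed them 256x256 images and only tap the intermediate blocks:
#
#   backbone = iresnet50()
#   feats = backbone(torch.randn(1, 3, 112, 112), return_features=True)
#   # feats[0..3] are the layer1..layer4 activations, feats[4] is the 512-d embedding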
class FeatureEncoder(nn.Module):
    def __init__(self, n_styles=18, opts=None, residual=False, use_coeff=False, resnet_layer=None,
                 video_input=False, f_maps=512, stride=(1, 1)):
        super(FeatureEncoder, self).__init__()

        resnet50 = iresnet50()
        resnet50.load_state_dict(torch.load(opts.arcface_model_path))

        # input conv layer
        if video_input:
            self.conv = nn.Sequential(
                nn.Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                *list(resnet50.children())[1:3]
            )
        else:
            self.conv = nn.Sequential(*list(resnet50.children())[:3])

        # define layers
        self.block_1 = list(resnet50.children())[3]  # 15-18
        self.block_2 = list(resnet50.children())[4]  # 10-14
        self.block_3 = list(resnet50.children())[5]  # 5-9
        self.block_4 = list(resnet50.children())[6]  # 1-4

        self.content_layer = nn.Sequential(
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.PReLU(num_parameters=512),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=stride, padding=(1, 1), bias=False),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

        self.avg_pool = nn.AdaptiveAvgPool2d((3, 3))
        self.styles = nn.ModuleList()
        # Each style head maps the concatenated multi-scale features
        # ((64 + 128 + 256 + 512) = 960 channels, pooled to 3x3) to one 512-d style vector.
        for i in range(n_styles):
            self.styles.append(nn.Linear(960 * 9, 512))

    def apply_head(self, x):
        latents = []
        for i in range(len(self.styles)):
            latents.append(self.styles[i](x))
        out = torch.stack(latents, dim=1)
        return out

    def forward(self, x):
        features = []
        x = self.conv(x)
        x = self.block_1(x)
        features.append(self.avg_pool(x))
        x = self.block_2(x)
        features.append(self.avg_pool(x))
        x = self.block_3(x)
        content = self.content_layer(x)
        features.append(self.avg_pool(x))
        x = self.block_4(x)
        features.append(self.avg_pool(x))
        x = torch.cat(features, dim=1)
        x = x.view(x.size(0), -1)
        return self.apply_head(x), content


class FeatureEncoderMult(FeatureEncoder):
    def __init__(self, fs_layers=(5,), ranks=None, **kwargs):
        super().__init__(**kwargs)
        self.fs_layers = fs_layers
        self.content_layer = nn.ModuleList()
        self.ranks = ranks

        # fs layers above 7 tap the shallower block_2 features (half the channels),
        # so the dict lookups are shifted and the channel count halved.
        shift = 0 if max(fs_layers) <= 7 else 2
        scale = 1 if max(fs_layers) <= 7 else 2

        for i in range(len(fs_layers)):
            if ranks is not None:
                # Low-rank variant: two heads produce rectangular factors whose
                # matrix product approximates the full content tensor.
                # NOTE: `ranks_data` (mapping a rank to its (stride, kernel) pair) is
                # not defined in this file and must be provided by the surrounding project.
                stride, kern = ranks_data[ranks[i] - shift]
                layer1 = nn.Sequential(
                    nn.BatchNorm2d(256 // scale, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True),
                    nn.Conv2d(256 // scale, 512, kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False),
                    nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True),
                    nn.PReLU(num_parameters=512),
                    nn.Conv2d(512, 512,
                              kernel_size=(fs_kernals[fs_layers[i] - shift][0], kern),
                              stride=(fs_strides[fs_layers[i] - shift][0], stride),
                              padding=(1, 1), bias=False),
                    nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True)
                )
                layer2 = nn.Sequential(
                    nn.BatchNorm2d(256 // scale, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True),
                    nn.Conv2d(256 // scale, 512, kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False),
                    nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True),
                    nn.PReLU(num_parameters=512),
                    nn.Conv2d(512, 512,
                              kernel_size=(kern, fs_kernals[fs_layers[i] - shift][1]),
                              stride=(stride, fs_strides[fs_layers[i] - shift][1]),
                              padding=(1, 1), bias=False),
                    nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True)
                )
                layer = nn.ModuleList([layer1, layer2])
            else:
                layer = nn.Sequential(
                    nn.BatchNorm2d(256 // scale, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True),
                    nn.Conv2d(256 // scale, 512, kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False),
                    nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True),
                    nn.PReLU(num_parameters=512),
                    nn.Conv2d(512, 512, kernel_size=fs_kernals[fs_layers[i] - shift],
                              stride=fs_strides[fs_layers[i] - shift],
                              padding=(1, 1), bias=False),
                    nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
                                   track_running_stats=True)
                )
            self.content_layer.append(layer)

    def forward(self, x):
        x = transform_to_256(x)

        features = []
        content = []
        x = self.conv(x)
        x = self.block_1(x)
        features.append(self.avg_pool(x))
        x = self.block_2(x)
        # fs layers above 7 take their content from block_2; otherwise the content
        # heads run on block_3 below (the `len(content) == 0` check).
        if max(self.fs_layers) > 7:
            for layer in self.content_layer:
                if self.ranks is not None:
                    mat1 = layer[0](x)
                    mat2 = layer[1](x)
                    content.append(torch.matmul(mat1, mat2))
                else:
                    content.append(layer(x))
        features.append(self.avg_pool(x))
        x = self.block_3(x)
        if len(content) == 0:
            for layer in self.content_layer:
                if self.ranks is not None:
                    mat1 = layer[0](x)
                    mat2 = layer[1](x)
                    content.append(torch.matmul(mat1, mat2))
                else:
                    content.append(layer(x))
        features.append(self.avg_pool(x))
        x = self.block_4(x)
        features.append(self.avg_pool(x))
        x = torch.cat(features, dim=1)
        x = x.view(x.size(0), -1)
        return self.apply_head(x), content


def get_keys(d, name, key="state_dict"):
    """Return the entries of `d[key]` whose keys start with `name.`, with that prefix stripped."""
    if key in d:
        d = d[key]
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[: len(name) + 1] == name + '.'}
    return d_filt
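
# Usage sketch (illustrative): get_keys is typically used to pull one submodule's
# weights out of a combined training checkpoint. The checkpoint path and the
# 'encoder' prefix below are assumptions, not fixed by this file:
#
#   ckpt = torch.load('checkpoint.pt', map_location='cpu')
#   encoder = FeatureEncoderMult(fs_layers=(5,), opts=opts)   # opts must supply arcface_model_path
#   encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)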