# fs_encoder_v2: image encoder built on a pretrained ArcFace iresnet50 backbone.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torchvision import models, utils

from arcface.iresnet import *
class fs_encoder_v2(nn.Module):
    """Encoder mapping an image to per-layer style latents plus a spatial content tensor.

    Built on a pretrained ArcFace ``iresnet50`` backbone whose weights are
    loaded from ``opts.arcface_model_path``.  Multi-scale features from the
    four residual stages are average-pooled to 3x3, concatenated, and fed to
    ``n_styles`` independent linear heads.

    Args:
        n_styles: number of style latent vectors produced (one Linear head each).
        opts: options namespace; must provide ``arcface_model_path`` with the
            pretrained iresnet50 state dict.
        residual, use_coeff, resnet_layer, f_maps: accepted for interface
            compatibility with callers but unused in this implementation.
        video_input: if True, replace the backbone stem conv with a 6-channel
            conv (presumably two stacked RGB frames — TODO confirm with caller),
            reusing the backbone's BN + PReLU.
        stride: stride of the second conv inside ``content_layer``.
    """

    def __init__(self, n_styles=18, opts=None, residual=False, use_coeff=False, resnet_layer=None, video_input=False, f_maps=512, stride=(1, 1)):
        super(fs_encoder_v2, self).__init__()
        resnet50 = iresnet50()
        resnet50.load_state_dict(torch.load(opts.arcface_model_path))

        # Materialize the backbone's children once instead of re-listing
        # them for every slice below.
        children = list(resnet50.children())

        # Input conv layer: children[:3] is the backbone stem (conv, bn, prelu);
        # for video input the first conv is swapped for a 6-channel one.
        if video_input:
            self.conv = nn.Sequential(
                nn.Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                *children[1:3]
            )
        else:
            self.conv = nn.Sequential(*children[:3])

        # The four residual stages of iresnet50.
        self.block_1 = children[3]  # 15-18
        self.block_2 = children[4]  # 10-14
        self.block_3 = children[5]  # 5-9
        self.block_4 = children[6]  # 1-4

        # Projects the 256-channel block_3 features to the 512-channel spatial
        # "content" tensor returned alongside the latents.
        self.content_layer = nn.Sequential(
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.PReLU(num_parameters=512),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=stride, padding=(1, 1), bias=False),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

        # Each stage's features are pooled to 3x3; the Linear input size 960*9
        # matches the concatenated pooled channels times the 3x3 spatial grid.
        self.avg_pool = nn.AdaptiveAvgPool2d((3, 3))
        self.styles = nn.ModuleList(
            nn.Linear(960 * 9, 512) for _ in range(n_styles)
        )

    def forward(self, x):
        """Encode ``x``.

        Returns:
            out: tensor of shape (batch, n_styles, 512), one latent per head.
            content: 512-channel feature map derived from block_3 via
                ``content_layer``.
        """
        features = []
        x = self.conv(x)
        x = self.block_1(x)
        features.append(self.avg_pool(x))
        x = self.block_2(x)
        features.append(self.avg_pool(x))
        x = self.block_3(x)
        content = self.content_layer(x)  # content branch taps block_3 output
        features.append(self.avg_pool(x))
        x = self.block_4(x)
        features.append(self.avg_pool(x))

        # Flatten the multi-scale pooled features into one vector per sample.
        x = torch.cat(features, dim=1)
        x = x.view(x.size(0), -1)

        # One independent linear head per style vector.
        latents = [head(x) for head in self.styles]
        out = torch.stack(latents, dim=1)
        return out, content