import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torchvision import models, utils
from arcface.iresnet import *


class fs_encoder_v2(nn.Module):
    def __init__(self, n_styles=18, opts=None, residual=False, use_coeff=False, resnet_layer=None, video_input=False, f_maps=512, stride=(1, 1)):
        super(fs_encoder_v2, self).__init__()

        # ArcFace iresnet50 backbone, initialized from a pretrained checkpoint
        resnet50 = iresnet50()
        resnet50.load_state_dict(torch.load(opts.arcface_model_path))

        # input conv layer; video input uses 6 input channels instead of 3
        if video_input:
            self.conv = nn.Sequential(
                nn.Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                *list(resnet50.children())[1:3]
            )
        else:
            self.conv = nn.Sequential(*list(resnet50.children())[:3])

        # define layers (residual stages reused from the iresnet50 backbone)
        self.block_1 = list(resnet50.children())[3]  # 15-18
        self.block_2 = list(resnet50.children())[4]  # 10-14
        self.block_3 = list(resnet50.children())[5]  # 5-9
        self.block_4 = list(resnet50.children())[6]  # 1-4

        # maps the 256-channel block_3 feature map to the 512-channel content tensor
        self.content_layer = nn.Sequential(
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.PReLU(num_parameters=512),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=stride, padding=(1, 1), bias=False),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

        self.avg_pool = nn.AdaptiveAvgPool2d((3, 3))

        # one linear head per style vector; each takes the 3x3-pooled outputs of all
        # four blocks concatenated along channels (64 + 128 + 256 + 512 = 960)
        self.styles = nn.ModuleList()
        for i in range(n_styles):
            self.styles.append(nn.Linear(960 * 9, 512))
    def forward(self, x):
        latents = []
        features = []
        x = self.conv(x)

        # collect each block's output, pooled to 3x3, for the style heads
        x = self.block_1(x)
        features.append(self.avg_pool(x))
        x = self.block_2(x)
        features.append(self.avg_pool(x))
        x = self.block_3(x)
        content = self.content_layer(x)  # spatial content tensor branches off here
        features.append(self.avg_pool(x))
        x = self.block_4(x)
        features.append(self.avg_pool(x))

        # concatenate along channels and flatten: (B, 960, 3, 3) -> (B, 960 * 9)
        x = torch.cat(features, dim=1)
        x = x.view(x.size(0), -1)

        # one 512-d latent per style head, stacked to (B, n_styles, 512)
        for i in range(len(self.styles)):
            latents.append(self.styles[i](x))
        out = torch.stack(latents, dim=1)
        return out, content
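

# A minimal shape-check sketch, not part of the original file. It assumes an
# ArcFace iresnet50 checkpoint is available locally; the path below is a
# placeholder, not a file shipped with this repo.
if __name__ == "__main__":
    from types import SimpleNamespace

    # hypothetical options object; only arcface_model_path is read by the encoder
    opts = SimpleNamespace(arcface_model_path="pretrained_models/iresnet50.pth")
    encoder = fs_encoder_v2(n_styles=18, opts=opts)
    encoder.eval()

    with torch.no_grad():
        img = torch.randn(1, 3, 256, 256)  # single 256x256 RGB image
        latents, content = encoder(img)

    print(latents.shape)  # expected for a 256x256 input: torch.Size([1, 18, 512])
    print(content.shape)  # expected for a 256x256 input: torch.Size([1, 512, 32, 32])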