# newTryOn/models/FeatureStyleEncoder/nets/feature_style_encoder.py
import torch
import torch.nn as nn

from arcface.iresnet import iresnet50


class fs_encoder_v2(nn.Module):
    def __init__(self, n_styles=18, opts=None, residual=False, use_coeff=False, resnet_layer=None, video_input=False, f_maps=512, stride=(1, 1)):
        super(fs_encoder_v2, self).__init__()
        # Backbone: ArcFace iresnet50; `opts` must provide `arcface_model_path`
        # (the `None` default would fail on the next line).
        resnet50 = iresnet50()
        resnet50.load_state_dict(torch.load(opts.arcface_model_path))

        # Input stem: for video input, swap the first conv for one accepting
        # two stacked RGB frames (6 channels) and reuse the pretrained
        # bn/prelu; otherwise reuse the pretrained stem unchanged.
        if video_input:
            self.conv = nn.Sequential(
                nn.Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                *list(resnet50.children())[1:3]
            )
        else:
            self.conv = nn.Sequential(*list(resnet50.children())[:3])

        # The four residual stages of iresnet50 (children 3-6 are layer1..layer4).
        self.block_1 = list(resnet50.children())[3]  # 15-18
        self.block_2 = list(resnet50.children())[4]  # 10-14
        self.block_3 = list(resnet50.children())[5]  # 5-9
        self.block_4 = list(resnet50.children())[6]  # 1-4

        # Content branch: projects the 256-channel block_3 feature map to a
        # 512-channel spatial tensor that is returned alongside the latents.
        self.content_layer = nn.Sequential(
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.PReLU(num_parameters=512),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=stride, padding=(1, 1), bias=False),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

        self.avg_pool = nn.AdaptiveAvgPool2d((3, 3))
        # One linear head per style vector. The input size 960 * 9 comes from
        # the concatenated pooled features: 64 + 128 + 256 + 512 = 960 channels,
        # each reduced to a 3x3 (= 9) spatial grid.
        self.styles = nn.ModuleList()
        for _ in range(n_styles):
            self.styles.append(nn.Linear(960 * 9, 512))

    def forward(self, x):
        latents = []
        features = []
        x = self.conv(x)
        x = self.block_1(x)
        features.append(self.avg_pool(x))  # 64 channels, pooled to 3x3
        x = self.block_2(x)
        features.append(self.avg_pool(x))  # 128 channels, pooled to 3x3
        x = self.block_3(x)
        content = self.content_layer(x)    # spatial content tensor
        features.append(self.avg_pool(x))  # 256 channels, pooled to 3x3
        x = self.block_4(x)
        features.append(self.avg_pool(x))  # 512 channels, pooled to 3x3
        # Concatenate the four pooled levels and flatten: (B, 960 * 9).
        x = torch.cat(features, dim=1)
        x = x.view(x.size(0), -1)
        # One 512-d latent per style head, stacked to (B, n_styles, 512).
        for i in range(len(self.styles)):
            latents.append(self.styles[i](x))
        out = torch.stack(latents, dim=1)
        return out, content
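

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). A minimal
# smoke test, assuming `opts` only needs to expose `arcface_model_path` and
# that it points at valid ArcFace iresnet50 weights; the path below is a
# placeholder, not a file shipped with this repo.
if __name__ == "__main__":
    from types import SimpleNamespace

    opts = SimpleNamespace(arcface_model_path="arcface/iresnet50.pth")  # placeholder path
    encoder = fs_encoder_v2(n_styles=18, opts=opts)
    encoder.eval()

    # ArcFace backbones are trained on 112x112 aligned face crops.
    img = torch.randn(1, 3, 112, 112)
    with torch.no_grad():
        latents, content = encoder(img)

    print(latents.shape)  # torch.Size([1, 18, 512]): one latent per style head
    print(content.shape)  # torch.Size([1, 512, 14, 14]) for a 112x112 input with the default stride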