import argparse

import clip
import torch
import torch.nn as nn
from torch.nn import Linear, LayerNorm, LeakyReLU, Sequential
from torchvision import transforms as T

from models.Net import FeatureEncoderMult, IBasicBlock, conv1x1
from models.stylegan2.model import PixelNorm


class ModulationModule(nn.Module):
    """Conditional modulation block: passes x through a linear layer and LayerNorm,
    then scales and shifts it with gamma/beta predicted from `embedding`."""

    def __init__(self, layernum, last=False, inp=512, middle=512):
        super().__init__()
        self.layernum = layernum
        self.last = last
        self.fc = Linear(512, 512)
        self.norm = LayerNorm([self.layernum, 512], elementwise_affine=False)
        self.gamma_function = Sequential(Linear(inp, middle), LayerNorm([middle]), LeakyReLU(), Linear(middle, 512))
        self.beta_function = Sequential(Linear(inp, middle), LayerNorm([middle]), LeakyReLU(), Linear(middle, 512))
        self.leakyrelu = LeakyReLU()

    def forward(self, x, embedding):
        x = self.fc(x)
        x = self.norm(x)
        gamma = self.gamma_function(embedding)
        beta = self.beta_function(embedding)
        out = x * (1 + gamma) + beta
        if not self.last:
            out = self.leakyrelu(out)
        return out
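

# Illustrative usage sketch (not part of the original module): a minimal smoke
# test for ModulationModule, assuming latents shaped [batch, layernum, 512] and
# a conditioning embedding whose trailing dimension matches `inp`.
def _demo_modulation_module():
    module = ModulationModule(layernum=6)
    x = torch.randn(2, 6, 512)          # latent to be modulated
    embedding = torch.randn(2, 6, 512)  # conditioning embedding (inp=512)
    out = module(x, embedding)
    assert out.shape == (2, 6, 512)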


class FeatureiResnet(nn.Module):
    """Sequential stack of IBasicBlock residual blocks built from `blocks`, a list of
    [planes, num_blocks] pairs; a 1x1 conv + BatchNorm downsample path is added
    whenever the channel count changes."""

    def __init__(self, blocks, inplanes=1024):
        super().__init__()
        self.res_blocks = {}
        for n, block in enumerate(blocks, start=1):
            planes, num_blocks = block
            for k in range(1, num_blocks + 1):
                downsample = None
                if inplanes != planes:
                    downsample = nn.Sequential(conv1x1(inplanes, planes, 1), nn.BatchNorm2d(planes, eps=1e-05))
                self.res_blocks[f'res_block_{n}_{k}'] = IBasicBlock(inplanes, planes, 1, downsample, 1, 64, 1)
                inplanes = planes
        self.res_blocks = nn.ModuleDict(self.res_blocks)

    def forward(self, x):
        for module in self.res_blocks.values():
            x = module(x)
        return x
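

# Illustrative usage sketch (not part of the original module): feeds a random
# 1024-channel feature map through FeatureiResnet. The 64x64 spatial size is an
# assumption for the example only; spatial size is preserved provided IBasicBlock
# keeps resolution at stride 1, as in the standard iResNet block.
def _demo_feature_iresnet():
    net = FeatureiResnet([[1024, 2], [768, 2], [512, 2]])
    x = torch.randn(1, 1024, 64, 64)
    y = net(x)
    assert y.shape == (1, 512, 64, 64)  # channels follow the last block group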


class RotateModel(nn.Module):
    """Predicts a small residual update for a [batch, 6, 512] latent, conditioned on
    a second latent of the same shape."""

    def __init__(self):
        super().__init__()
        self.pixelnorm = PixelNorm()
        self.modulation_module_list = nn.ModuleList([ModulationModule(6, i == 4) for i in range(5)])

    def forward(self, latent_from, latent_to):
        dt_latent = self.pixelnorm(latent_from)
        for modulation_module in self.modulation_module_list:
            dt_latent = modulation_module(dt_latent, latent_to)
        output = latent_from + 0.1 * dt_latent
        return output
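

# Illustrative usage sketch (not part of the original module): RotateModel
# consumes two latents shaped [batch, 6, 512] (dummy tensors here) and returns
# a refined latent of the same shape.
def _demo_rotate_model():
    model = RotateModel()
    latent_from = torch.randn(2, 6, 512)
    latent_to = torch.randn(2, 6, 512)
    out = model(latent_from, latent_to)
    assert out.shape == (2, 6, 512)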


class ClipBlendingModel(nn.Module):
    """Refines a 12x512 latent conditioned on a second latent and CLIP image embeddings
    of a target face and a hair-color reference; the CLIP encoder is frozen."""

    def __init__(self, clip_model="ViT-B/32"):
        super().__init__()
        self.pixelnorm = PixelNorm()
        self.clip_model, _ = clip.load(clip_model, device="cuda")
        self.transform = T.Compose(
            [T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        self.modulation_module_list = nn.ModuleList(
            [ModulationModule(12, i == 4, inp=512 * 3, middle=1024) for i in range(5)]
        )
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def get_image_embed(self, image_tensor):
        # Resize to 224x224, map from [-1, 1] to [0, 1], then apply CLIP normalization.
        resized_tensor = self.face_pool(image_tensor)
        renormed_tensor = self.transform(resized_tensor * 0.5 + 0.5)
        return self.clip_model.encode_image(renormed_tensor)

    def forward(self, latent_face, latent_color, target_face, hair_color):
        embed_face = self.get_image_embed(target_face).unsqueeze(1).expand(-1, 12, -1)
        embed_color = self.get_image_embed(hair_color).unsqueeze(1).expand(-1, 12, -1)
        latent_in = torch.cat((latent_color, embed_face, embed_color), dim=-1)
        dt_latent = self.pixelnorm(latent_face)
        for modulation_module in self.modulation_module_list:
            dt_latent = modulation_module(dt_latent, latent_in)
        output = latent_face + 0.1 * dt_latent
        return output
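

# Illustrative usage sketch (not part of the original module): requires a CUDA
# device and downloads the CLIP ViT-B/32 weights on first use. Latents are
# assumed to be [batch, 12, 512]; images are [batch, 3, H, W] in [-1, 1].
def _demo_clip_blending_model():
    model = ClipBlendingModel().cuda()
    latent_face = torch.randn(1, 12, 512, device="cuda")
    latent_color = torch.randn(1, 12, 512, device="cuda")
    target_face = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1
    hair_color = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1
    out = model(latent_face, latent_color, target_face, hair_color)
    assert out.shape == (1, 12, 512)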


class PostProcessModel(nn.Module):
    """Encodes source and target images with a shared encoder, cross-modulates their
    latents around `latent_avg`, and fuses their feature maps through FeatureiResnet."""

    def __init__(self):
        super().__init__()
        self.encoder_face = FeatureEncoderMult(fs_layers=[9], opts=argparse.Namespace(
            **{'arcface_model_path': "pretrained_models/ArcFace/backbone_ir50.pth"}))
        self.latent_avg = torch.load('pretrained_models/PostProcess/latent_avg.pt', map_location=torch.device('cuda'))
        self.to_feature = FeatureiResnet([[1024, 2], [768, 2], [512, 2]])
        self.to_latent_1 = nn.ModuleList([ModulationModule(18, i == 4) for i in range(5)])
        self.to_latent_2 = nn.ModuleList([ModulationModule(18, i == 4) for i in range(5)])
        self.pixelnorm = PixelNorm()

    def forward(self, source, target):
        s_face, [f_face] = self.encoder_face(source)
        s_hair, [f_hair] = self.encoder_face(target)
        dt_latent_face = self.pixelnorm(s_face)
        dt_latent_hair = self.pixelnorm(s_hair)
        for mod_module in self.to_latent_1:
            dt_latent_face = mod_module(dt_latent_face, s_hair)
        for mod_module in self.to_latent_2:
            dt_latent_hair = mod_module(dt_latent_hair, s_face)
        final_s = self.latent_avg + 0.1 * (dt_latent_face + dt_latent_hair)
        cat_f = torch.cat((f_face, f_hair), dim=1)
        final_f = self.to_feature(cat_f)
        return final_s, final_f
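

# Illustrative usage sketch (not part of the original module): PostProcessModel
# needs a CUDA device plus the pretrained ArcFace encoder and latent_avg files
# referenced in __init__. The 256x256 resolution and [-1, 1] range are assumptions
# for the example; use whatever inputs FeatureEncoderMult expects in this repo.
def _demo_post_process_model():
    model = PostProcessModel().cuda()
    source = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1  # aligned face image
    target = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1  # hairstyle reference
    final_s, final_f = model(source, target)
    print(final_s.shape, final_f.shape)  # refined latent and fused feature map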


class ClipModel(nn.Module):
    """Frozen CLIP ViT-B/32 image encoder: moves inputs to CUDA, rescales uint8
    images to [0, 1], and returns CLIP image embeddings."""

    def __init__(self):
        super().__init__()
        self.clip_model, _ = clip.load("ViT-B/32", device="cuda")
        self.transform = T.Compose(
            [T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]
        )
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def forward(self, image_tensor):
        if not image_tensor.is_cuda:
            image_tensor = image_tensor.to("cuda")
        if image_tensor.dtype == torch.uint8:
            image_tensor = image_tensor / 255
        resized_tensor = self.face_pool(image_tensor)
        renormed_tensor = self.transform(resized_tensor)
        return self.clip_model.encode_image(renormed_tensor)
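

# Illustrative usage sketch (not part of the original module): requires a CUDA
# device and the CLIP ViT-B/32 weights. uint8 inputs are rescaled to [0, 1]
# inside forward, so a raw image batch can be passed directly.
def _demo_clip_model():
    model = ClipModel()
    images = torch.randint(0, 256, (2, 3, 256, 256), dtype=torch.uint8)
    embeds = model(images)
    assert embeds.shape == (2, 512)  # ViT-B/32 produces 512-dim image embeddings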