import argparse

import clip
import torch
import torch.nn as nn
from torch.nn import Linear, LayerNorm, LeakyReLU, Sequential
from torchvision import transforms as T

from models.Net import FeatureEncoderMult, IBasicBlock, conv1x1
from models.stylegan2.model import PixelNorm


class ModulationModule(nn.Module):
    """Modulates a normalized latent with a scale (gamma) and shift (beta)
    predicted from a conditioning embedding: out = x * (1 + gamma) + beta."""

    def __init__(self, layernum, last=False, inp=512, middle=512):
        super().__init__()
        self.layernum = layernum
        self.last = last
        self.fc = Linear(512, 512)
        self.norm = LayerNorm([self.layernum, 512], elementwise_affine=False)
        self.gamma_function = Sequential(Linear(inp, middle), LayerNorm([middle]), LeakyReLU(), Linear(middle, 512))
        self.beta_function = Sequential(Linear(inp, middle), LayerNorm([middle]), LeakyReLU(), Linear(middle, 512))
        self.leakyrelu = LeakyReLU()

    def forward(self, x, embedding):
        x = self.fc(x)
        x = self.norm(x)
        gamma = self.gamma_function(embedding)
        beta = self.beta_function(embedding)
        out = x * (1 + gamma) + beta
        # The last module in a stack skips the non-linearity so the output
        # remains an unbounded latent offset.
        if not self.last:
            out = self.leakyrelu(out)
        return out


class FeatureiResnet(nn.Module):
    """Sequential stack of IBasicBlocks; `blocks` is a list of [planes, num_blocks]
    pairs. A 1x1-conv downsample branch is inserted whenever the channel count changes."""

    def __init__(self, blocks, inplanes=1024):
        super().__init__()
        res_blocks = {}
        for n, (planes, num_blocks) in enumerate(blocks, start=1):
            for k in range(1, num_blocks + 1):
                downsample = None
                if inplanes != planes:
                    downsample = nn.Sequential(
                        conv1x1(inplanes, planes, 1),
                        nn.BatchNorm2d(planes, eps=1e-05),
                    )
                res_blocks[f'res_block_{n}_{k}'] = IBasicBlock(inplanes, planes, 1, downsample, 1, 64, 1)
                inplanes = planes
        self.res_blocks = nn.ModuleDict(res_blocks)

    def forward(self, x):
        for module in self.res_blocks.values():
            x = module(x)
        return x


class RotateModel(nn.Module):
    """Predicts a small residual that moves `latent_from` toward the pose of `latent_to`."""

    def __init__(self):
        super().__init__()
        self.pixelnorm = PixelNorm()
        # Five modulation modules over 6 W+ layers; only the last one is linear.
        self.modulation_module_list = nn.ModuleList([ModulationModule(6, i == 4) for i in range(5)])

    def forward(self, latent_from, latent_to):
        dt_latent = self.pixelnorm(latent_from)
        for modulation_module in self.modulation_module_list:
            dt_latent = modulation_module(dt_latent, latent_to)
        # Residual update with a small step size keeps the edit close to the source latent.
        return latent_from + 0.1 * dt_latent
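# Usage sketch for RotateModel (illustrative only; the [2, 6, 512] W+ shape and
# the CPU device are assumptions for the example, chosen to match ModulationModule(6, ...)):
#
#     model = RotateModel()
#     w_from = torch.randn(2, 6, 512)
#     w_to = torch.randn(2, 6, 512)
#     w_rotated = model(w_from, w_to)  # same shape as w_from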
class ClipBlendingModel(nn.Module):
    """Blends a face latent with hair-color information, conditioned on frozen
    CLIP embeddings of the target face and hair-color reference images."""

    def __init__(self, clip_model="ViT-B/32"):
        super().__init__()
        self.pixelnorm = PixelNorm()
        self.clip_model, _ = clip.load(clip_model, device="cuda")
        # CLIP's canonical normalization; inputs must be in [0, 1] beforehand.
        self.transform = T.Compose(
            [T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        # Conditioning is [latent_color | face embed | color embed], hence inp=512 * 3.
        self.modulation_module_list = nn.ModuleList(
            [ModulationModule(12, i == 4, inp=512 * 3, middle=1024) for i in range(5)]
        )
        # CLIP serves as a frozen feature extractor.
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def get_image_embed(self, image_tensor):
        resized_tensor = self.face_pool(image_tensor)
        # Map generator output from [-1, 1] to [0, 1] before CLIP normalization.
        renormed_tensor = self.transform(resized_tensor * 0.5 + 0.5)
        return self.clip_model.encode_image(renormed_tensor)

    def forward(self, latent_face, latent_color, target_face, hair_color):
        # Broadcast each CLIP embedding across the 12 modulated W+ layers.
        embed_face = self.get_image_embed(target_face).unsqueeze(1).expand(-1, 12, -1)
        embed_color = self.get_image_embed(hair_color).unsqueeze(1).expand(-1, 12, -1)
        latent_in = torch.cat((latent_color, embed_face, embed_color), dim=-1)
        dt_latent = self.pixelnorm(latent_face)
        for modulation_module in self.modulation_module_list:
            dt_latent = modulation_module(dt_latent, latent_in)
        return latent_face + 0.1 * dt_latent


class PostProcessModel(nn.Module):
    """Fuses the encoded source (face) and target (hair) images into a blended
    W+ latent and a fused feature map for the generator."""

    def __init__(self):
        super().__init__()
        self.encoder_face = FeatureEncoderMult(
            fs_layers=[9],
            opts=argparse.Namespace(**{'arcface_model_path': "pretrained_models/ArcFace/backbone_ir50.pth"}))
        self.latent_avg = torch.load('pretrained_models/PostProcess/latent_avg.pt',
                                     map_location=torch.device('cuda'))
        # Fuse the channel-concatenated face/hair features back down (1024 -> 768 -> 512).
        self.to_feature = FeatureiResnet([[1024, 2], [768, 2], [512, 2]])
        self.to_latent_1 = nn.ModuleList([ModulationModule(18, i == 4) for i in range(5)])
        self.to_latent_2 = nn.ModuleList([ModulationModule(18, i == 4) for i in range(5)])
        self.pixelnorm = PixelNorm()

    def forward(self, source, target):
        s_face, [f_face] = self.encoder_face(source)
        s_hair, [f_hair] = self.encoder_face(target)
        dt_latent_face = self.pixelnorm(s_face)
        dt_latent_hair = self.pixelnorm(s_hair)
        # Each branch is modulated by the other branch's latent before the residual sum.
        for mod_module in self.to_latent_1:
            dt_latent_face = mod_module(dt_latent_face, s_hair)
        for mod_module in self.to_latent_2:
            dt_latent_hair = mod_module(dt_latent_hair, s_face)
        final_s = self.latent_avg + 0.1 * (dt_latent_face + dt_latent_hair)
        cat_f = torch.cat((f_face, f_hair), dim=1)
        final_f = self.to_feature(cat_f)
        return final_s, final_f


class ClipModel(nn.Module):
    """Frozen CLIP image encoder with input handling for device placement,
    uint8 inputs, and CLIP's normalization."""

    def __init__(self):
        super().__init__()
        self.clip_model, _ = clip.load("ViT-B/32", device="cuda")
        self.transform = T.Compose(
            [T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]
        )
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def forward(self, image_tensor):
        if not image_tensor.is_cuda:
            image_tensor = image_tensor.to("cuda")
        # uint8 images are assumed to be in [0, 255]; true division promotes to float.
        if image_tensor.dtype == torch.uint8:
            image_tensor = image_tensor / 255
        resized_tensor = self.face_pool(image_tensor)
        renormed_tensor = self.transform(resized_tensor)
        return self.clip_model.encode_image(renormed_tensor)
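
# Minimal smoke test (a sketch, not part of the pipeline): checks the modulation
# arithmetic on CPU and, when CUDA is available, runs the frozen CLIP encoder on a
# random batch. All shapes and the [0, 1] image range are assumptions for the
# example, not values prescribed by this module.
if __name__ == "__main__":
    mod = ModulationModule(layernum=6)
    x = torch.randn(2, 6, 512)
    cond = torch.randn(2, 6, 512)
    print("ModulationModule output:", tuple(mod(x, cond).shape))  # expected: (2, 6, 512)

    if torch.cuda.is_available():
        # ClipModel downloads/loads frozen ViT-B/32 weights on first use.
        images = torch.rand(2, 3, 256, 256)  # assumed float images in [0, 1]
        print("ClipModel embeddings:", tuple(ClipModel()(images).shape))  # expected: (2, 512)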