from collections import defaultdict

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader

from datasets.image_dataset import ImagesDataset, image_collate
from models.FeatureStyleEncoder import FSencoder
from models.Net import Net, get_segmentation
from models.encoder4editing.utils.model_utils import setup_model, get_latents
from utils.bicubic import BicubicDownSample
from utils.save_utils import save_gen_image, save_latents


class Embedding(nn.Module):
    """
    Embeds input face images into the latent representations used downstream:
    W+ latents via the e4e encoder, S/F latents via the FeatureStyle encoder,
    and BiSeNet segmentation masks.
    """

    def __init__(self, opts, net=None):
        super().__init__()
        self.opts = opts
        if net is None:
            self.net = Net(self.opts)
        else:
            self.net = net

        self.encoder = FSencoder.get_trainer(self.opts.device)
        self.e4e, _ = setup_model('pretrained_models/encoder4editing/e4e_ffhq_encode.pt', self.opts.device)

        # [-1, 1] normalization expected by the StyleGAN encoders
        self.normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # ImageNet normalization expected by BiSeNet
        self.to_bisenet = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        # Bicubic downsampling of 1024x1024 inputs to 512 and 256 resolutions
        self.downsample_512 = BicubicDownSample(factor=2)
        self.downsample_256 = BicubicDownSample(factor=4)

    def setup_dataloader(self, images: dict[torch.Tensor, list[str]] | list[torch.Tensor], batch_size=None):
        self.dataset = ImagesDataset(images)
        self.dataloader = DataLoader(self.dataset, collate_fn=image_collate, shuffle=False,
                                     batch_size=batch_size or self.opts.batch_size)

    def get_e4e_embed(self, images: list[torch.Tensor]) -> dict[str, torch.Tensor]:
        device = self.opts.device
        self.setup_dataloader(images, batch_size=len(images))

        # batch_size == len(images), so the dataloader yields exactly one batch
        for image, _ in self.dataloader:
            image = image.to(device)
            latent_W = get_latents(self.e4e, image)
            # Run only StyleGAN layers 0..3 to obtain the spatial feature tensor F
            latent_F, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False,
                                             start_layer=0, end_layer=3)
            return {"F": latent_F, "W": latent_W}

    def embedding_images(self, images_to_name: dict[torch.Tensor, list[str]],
                         **kwargs) -> dict[str, dict[str, torch.Tensor]]:
        device = self.opts.device
        self.setup_dataloader(images_to_name)

        name_to_embed = defaultdict(dict)
        for image, names in self.dataloader:
            image = image.to(device)
            im_512 = self.downsample_512(image)
            im_256 = self.downsample_256(image)
            im_256_norm = self.normalize(im_256)

            # E4E: W+ latents from the normalized 256x256 images
            latent_W = get_latents(self.e4e, im_256_norm)

            # FS encoder: style codes S plus an intermediate feature tensor
            output = self.encoder.test(img=self.normalize(image), return_latent=True)
            latent = output.pop()    # [bs, 512, 16, 16]
            latent_S = output.pop()  # [bs, 18, 512]
            latent_F, _ = self.net.generator([latent_S], input_is_latent=True, return_latents=False,
                                             start_layer=3, end_layer=3, layer_in=latent)  # [bs, 512, 32, 32]

            # BiSeNet segmentation masks (class 13 is hair)
            masks = torch.cat([get_segmentation(image.unsqueeze(0)) for image in self.to_bisenet(im_512)])

            # Mixing when changing the color or shape: inside the hair region,
            # blend the FS feature tensor towards the one generated from W+
            if len(images_to_name) > 1:
                hair_mask = torch.where(masks == 13, torch.ones_like(masks, device=device),
                                        torch.zeros_like(masks, device=device))
                hair_mask = F.interpolate(hair_mask.float(), size=(32, 32), mode='bicubic')

                latent_F_from_W = self.net.generator([latent_W], input_is_latent=True, return_latents=False,
                                                     start_layer=0, end_layer=3)[0]
                latent_F = latent_F + self.opts.mixing * hair_mask * (latent_F_from_W - latent_F)

            # names[k] is the list of names the k-th image was registered under;
            # use a distinct loop variable so `names` itself is not shadowed
            for k, names_k in enumerate(names):
                for name in names_k:
                    name_to_embed[name]['W'] = latent_W[k].unsqueeze(0)
                    name_to_embed[name]['F'] = latent_F[k].unsqueeze(0)
                    name_to_embed[name]['S'] = latent_S[k].unsqueeze(0)
                    name_to_embed[name]['mask'] = masks[k].unsqueeze(0)
                    name_to_embed[name]['image_256'] = im_256[k].unsqueeze(0)
                    name_to_embed[name]['image_norm_256'] = im_256_norm[k].unsqueeze(0)

            if self.opts.save_all:
                gen_W_im, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False)
                gen_FS_im, _ = self.net.generator([latent_S], input_is_latent=True, return_latents=False,
                                                  start_layer=4, end_layer=8, layer_in=latent_F)

                exp_name = kwargs.get('exp_name') or ""
                output_dir = self.opts.save_all_dir / exp_name
                # Save one image/latent file per registered name
                for names_k, im_W, lat_W in zip(names, gen_W_im, latent_W):
                    for name in names_k:
                        save_gen_image(output_dir, 'W+', f'{name}.png', im_W)
                        save_latents(output_dir, 'W+', f'{name}.npz', latent_W=lat_W)
                for names_k, im_F, lat_S, lat_F in zip(names, gen_FS_im, latent_S, latent_F):
                    for name in names_k:
                        save_gen_image(output_dir, 'FS', f'{name}.png', im_F)
                        save_latents(output_dir, 'FS', f'{name}.npz', latent_S=lat_S, latent_F=lat_F)

        return name_to_embed
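

# A minimal usage sketch (an assumption, not part of the original module):
# the `opts` namespace below is hypothetical; the real options object comes
# from the project's argument parser, and the pretrained checkpoints
# (StyleGAN, e4e, FS encoder, BiSeNet) must be available on disk.
if __name__ == '__main__':
    from argparse import Namespace
    from pathlib import Path

    opts = Namespace(device='cuda', batch_size=2, mixing=0.95,
                     save_all=False, save_all_dir=Path('output'))
    embedder = Embedding(opts)

    # Hypothetical 1024x1024 RGB tensors in [0, 1]; in practice these would
    # be aligned face crops loaded from disk.
    face = torch.rand(3, 1024, 1024)
    shape = torch.rand(3, 1024, 1024)

    # Each image may be registered under several names; tensors are valid
    # dict keys because they hash by identity.
    embeds = embedder.embedding_images({face: ['face'], shape: ['shape']})
    print(embeds['face']['W'].shape)  # expected: [1, 18, 512]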