from collections import defaultdict

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader

from datasets.image_dataset import ImagesDataset, image_collate
from models.FeatureStyleEncoder import FSencoder
from models.Net import Net, get_segmentation
from models.encoder4editing.utils.model_utils import setup_model, get_latents
from utils.bicubic import BicubicDownSample
from utils.save_utils import save_gen_image, save_latents


class Embedding(nn.Module):
"""
Module for image embedding
"""
def __init__(self, opts, net=None):
super().__init__()
self.opts = opts
if net is None:
self.net = Net(self.opts)
else:
self.net = net
        # FS encoder: produces S latents and intermediate generator features
        self.encoder = FSencoder.get_trainer(self.opts.device)
        # e4e encoder: inverts 256px face images into W+ latents
        self.e4e, _ = setup_model('pretrained_models/encoder4editing/e4e_ffhq_encode.pt', self.opts.device)
        # Map images from [0, 1] to [-1, 1] as the encoders expect
        self.normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # ImageNet statistics expected by the BiSeNet segmentation network
        self.to_bisenet = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        # Bicubic downsampling of the 1024px inputs to 512px and 256px
        self.downsample_512 = BicubicDownSample(factor=2)
        self.downsample_256 = BicubicDownSample(factor=4)

    def setup_dataloader(self, images: dict[torch.Tensor, list[str]] | list[torch.Tensor], batch_size=None):
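        """Build a DataLoader over `images`; batch size defaults to opts.batch_size."""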
self.dataset = ImagesDataset(images)
self.dataloader = DataLoader(self.dataset, collate_fn=image_collate, shuffle=False,
batch_size=batch_size or self.opts.batch_size)

    @torch.inference_mode()
def get_e4e_embed(self, images: list[torch.Tensor]) -> dict[str, torch.Tensor]:
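        """
        Invert `images` with e4e: returns the W+ latents and the F feature
        tensor produced by running the generator through layers 0-3 from them.
        """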
device = self.opts.device
self.setup_dataloader(images, batch_size=len(images))
        # A single batch covers all images (batch_size=len(images) above)
        for image, _ in self.dataloader:
image = image.to(device)
latent_W = get_latents(self.e4e, image)
latent_F, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False,
start_layer=0, end_layer=3)
return {"F": latent_F, "W": latent_W}

    @torch.inference_mode()
    def embedding_images(self, images_to_name: dict[torch.Tensor, list[str]], **kwargs) -> dict[str, dict[str, torch.Tensor]]:
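        """
        Full embedding pipeline. For each image: compute W+ (e4e), S and F
        (FS encoder + generator layer 3) latents and a BiSeNet segmentation
        mask; when several images are given, mix F towards the W+-derived
        features inside the hair region. Returns name -> dict of tensors.
        """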
device = self.opts.device
self.setup_dataloader(images_to_name)
name_to_embed = defaultdict(dict)
for image, names in self.dataloader:
image = image.to(device)
im_512 = self.downsample_512(image)
im_256 = self.downsample_256(image)
im_256_norm = self.normalize(im_256)
            # E4E inversion: W+ latents from the normalized 256px image
latent_W = get_latents(self.e4e, im_256_norm)
            # FS encoder: S latents and the intermediate feature tensor
output = self.encoder.test(img=self.normalize(image), return_latent=True)
latent = output.pop() # [bs, 512, 16, 16]
latent_S = output.pop() # [bs, 18, 512]
latent_F, _ = self.net.generator([latent_S], input_is_latent=True, return_latents=False,
start_layer=3, end_layer=3, layer_in=latent) # [bs, 512, 32, 32]
            # BiSeNet: segmentation masks computed per image at 512px
            masks = torch.cat([get_segmentation(img.unsqueeze(0)) for img in self.to_bisenet(im_512)])
            # With several images (color and/or shape change), mix F towards the
            # W+-derived features inside the hair region, weighted by opts.mixing
if len(images_to_name) > 1:
                # Binary hair mask: BiSeNet class 13 corresponds to hair
                hair_mask = torch.where(masks == 13, torch.ones_like(masks, device=device),
                                        torch.zeros_like(masks, device=device))
hair_mask = F.interpolate(hair_mask.float(), size=(32, 32), mode='bicubic')
latent_F_from_W = self.net.generator([latent_W], input_is_latent=True, return_latents=False,
start_layer=0, end_layer=3)[0]
latent_F = latent_F + self.opts.mixing * hair_mask * (latent_F_from_W - latent_F)
            for k, image_names in enumerate(names):
                for name in image_names:
name_to_embed[name]['W'] = latent_W[k].unsqueeze(0)
name_to_embed[name]['F'] = latent_F[k].unsqueeze(0)
name_to_embed[name]['S'] = latent_S[k].unsqueeze(0)
name_to_embed[name]['mask'] = masks[k].unsqueeze(0)
name_to_embed[name]['image_256'] = im_256[k].unsqueeze(0)
name_to_embed[name]['image_norm_256'] = im_256_norm[k].unsqueeze(0)
if self.opts.save_all:
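                # Save reconstructions from W+ alone and from (S, F) for inspection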
gen_W_im, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False)
gen_FS_im, _ = self.net.generator([latent_S], input_is_latent=True, return_latents=False,
start_layer=4, end_layer=8, layer_in=latent_F)
                exp_name = kwargs.get('exp_name') or ""
output_dir = self.opts.save_all_dir / exp_name
for name, im_W, lat_W in zip(names, gen_W_im, latent_W):
save_gen_image(output_dir, 'W+', f'{name}.png', im_W)
save_latents(output_dir, 'W+', f'{name}.npz', latent_W=lat_W)
for name, im_F, lat_S, lat_F in zip(names, gen_FS_im, latent_S, latent_F):
save_gen_image(output_dir, 'FS', f'{name}.png', im_F)
save_latents(output_dir, 'FS', f'{name}.npz', latent_S=lat_S, latent_F=lat_F)
return name_to_embed
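

# Usage sketch, assuming `opts` exposes the attributes referenced above
# (device, batch_size, mixing, save_all, save_all_dir) and that input images
# are [3, 1024, 1024] tensors in [0, 1]:
#
#   embedder = Embedding(opts)
#   embeds = embedder.embedding_images({face_image: ['face'], shape_image: ['shape']})
#   w_plus = embeds['face']['W']  # [1, 18, 512] W+ latent of the face image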