import os
import sys
import gc
import warnings

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import tqdm
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

# make the repository root importable before loading the local modules
sys.path.append(os.path.abspath(os.path.join("", "..")))

from utils import load_models
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w

warnings.filterwarnings("ignore")

# download the pretrained w2w space (PCA statistics, projections, identity dataframe)
models_path = snapshot_download(repo_id="Snapchat/w2w")


class CustomImageDataset(Dataset):
    """Single-image dataset used by `invert` below."""

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image


class main():
    def __init__(self):
        super(main, self).__init__()
        device = "cuda"
        mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device)
        std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device)
        v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device)
        proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
        df = torch.load(f"{models_path}/files/identity_df.pt")
        weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
        pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)

        self.device = device
        self.mean = mean
        self.std = std
        self.v = v
        self.proj = proj
        self.df = df
        self.weight_dimensions = weight_dimensions
        self.pinverse = pinverse

        self.unet, self.vae, self.text_encoder, self.tokenizer, self.noise_scheduler = load_models(device)
        self.network = None

        # editing directions, each debiased against correlated attributes
        young = get_direction(df, "Young", pinverse, 1000, device)
        young = debias(young, "Male", df, pinverse, device)
        young = debias(young, "Pointy_Nose", df, pinverse, device)
        young = debias(young, "Wavy_Hair", df, pinverse, device)
        young = debias(young, "Chubby", df, pinverse, device)
        young = debias(young, "No_Beard", df, pinverse, device)
        young = debias(young, "Mustache", df, pinverse, device)
        self.young = young

        pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
        pointy = debias(pointy, "Young", df, pinverse, device)
        pointy = debias(pointy, "Male", df, pinverse, device)
        pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
        pointy = debias(pointy, "Chubby", df, pinverse, device)
        pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
        self.pointy = pointy

        wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
        wavy = debias(wavy, "Young", df, pinverse, device)
        wavy = debias(wavy, "Male", df, pinverse, device)
        wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
        wavy = debias(wavy, "Chubby", df, pinverse, device)
        wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
        self.wavy = wavy

        thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
        thick = debias(thick, "Male", df, pinverse, device)
        thick = debias(thick, "Young", df, pinverse, device)
        thick = debias(thick, "Pointy_Nose", df, pinverse, device)
        thick = debias(thick, "Wavy_Hair", df, pinverse, device)
        thick = debias(thick, "Mustache", df, pinverse, device)
        thick = debias(thick, "No_Beard", df, pinverse, device)
        thick = debias(thick, "Sideburns", df, pinverse, device)
        thick = debias(thick, "Big_Nose", df, pinverse, device)
        thick = debias(thick, "Big_Lips", df, pinverse, device)
        thick = debias(thick, "Black_Hair", df, pinverse, device)
        thick = debias(thick, "Brown_Hair", df, pinverse, device)
        thick = debias(thick, "Pale_Skin", df, pinverse, device)
        thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
        self.thick = thick
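    # Each direction above is one row of PCA coefficients in the w2w space;
    # `debias` removes its component along another labeled CelebA attribute so
    # that, e.g., sliding "Young" does not also drift gender or facial hair.
    # A new slider could be built the same way; a sketch, assuming "Smiling"
    # is among the attributes labeled in `df`:
    #
    #   smile = get_direction(self.df, "Smiling", self.pinverse, 1000, self.device)
    #   smile = debias(smile, "Male", self.df, self.pinverse, self.device)
    #   self.smile = smile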
    def sample_model(self):
        # reload a clean UNet, then sample a new identity from the w2w distribution
        self.unet, _, _, _, _ = load_models(self.device)
        self.network = sample_weights(self.unet, self.proj, self.mean, self.std,
                                      self.v[:, :1000], self.device, factor=1.00)

    @torch.no_grad()
    @spaces.GPU
    def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        latents = torch.randn(
            (1, self.unet.in_channels, 512 // 8, 512 // 8),
            generator=generator,
            device=self.device
        ).bfloat16()

        text_input = self.tokenizer(prompt, padding="max_length",
                                    max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors="pt")
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer([negative_prompt], padding="max_length",
                                      max_length=max_length, return_tensors="pt")
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        self.noise_scheduler.set_timesteps(ddim_steps)
        latents = latents * self.noise_scheduler.init_noise_sigma

        for i, t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)

            # predict noise with the sampled identity weights applied to the UNet
            with self.network:
                noise_pred = self.unet(latent_model_input, t,
                                       encoder_hidden_states=text_embeddings,
                                       timestep_cond=None).sample

            # classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

        # decode latents to an image
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8"))
        return image
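    # A hypothetical driver for the two methods above (the demo itself calls
    # them through Gradio):
    #
    #   model = main()
    #   model.sample_model()   # draw a random identity from the w2w distribution
    #   img = model.inference("sks person", "low quality, blurry", 3.0, 25, 5)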
    @torch.no_grad()
    @spaces.GPU
    def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed,
                       start_noise, a1, a2, a3, a4):
        original_weights = self.network.proj.clone()

        # pad the edit directions to the same number of PCs as the identity
        pcs_original = original_weights.shape[1]
        pcs_edits = self.young.shape[1]
        padding = torch.zeros((1, pcs_original - pcs_edits)).to(self.device)
        young_pad = torch.cat((self.young, padding), 1)
        pointy_pad = torch.cat((self.pointy, padding), 1)
        wavy_pad = torch.cat((self.wavy, padding), 1)
        thick_pad = torch.cat((self.thick, padding), 1)

        edited_weights = (original_weights + a1 * 1e6 * young_pad + a2 * 1e6 * pointy_pad
                          + a3 * 1e6 * wavy_pad + a4 * 2e6 * thick_pad)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        latents = torch.randn(
            (1, self.unet.in_channels, 512 // 8, 512 // 8),
            generator=generator,
            device=self.device
        ).bfloat16()

        text_input = self.tokenizer(prompt, padding="max_length",
                                    max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors="pt")
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer([negative_prompt], padding="max_length",
                                      max_length=max_length, return_tensors="pt")
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        self.noise_scheduler.set_timesteps(ddim_steps)
        latents = latents * self.noise_scheduler.init_noise_sigma

        for i, t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)

            # keep the original identity for the high-noise steps; switch to the
            # edited weights once the timestep drops to start_noise, which
            # preserves overall structure while changing the attribute
            if t <= start_noise:
                self.network.proj = torch.nn.Parameter(edited_weights)
                self.network.reset()

            with self.network:
                noise_pred = self.unet(latent_model_input, t,
                                       encoder_hidden_states=text_embeddings,
                                       timestep_cond=None).sample

            # classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8"))

        # reset weights back to the original identity
        self.network.proj = torch.nn.Parameter(original_weights)
        self.network.reset()
        return image

    @spaces.GPU
    def sample_then_run(self):
        self.sample_model()
        prompt = "sks person"
        negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
        seed = 5
        cfg = 3.0
        steps = 25
        image = self.inference(prompt, negative_prompt, cfg, steps, seed)
        torch.save(self.network.proj, "model.pt")
        return image, "model.pt"
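    # Inversion, per the w2w setup: a real face is encoded by optimizing only
    # the principal-component coordinates `proj` of a rank-1 LoRA (LoRAw2w maps
    # those coordinates onto the UNet weights) against the standard denoising
    # objective restricted to the drawn mask:
    #
    #   loss = || mask * (eps_pred - eps) ||^2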
    @spaces.GPU
    def invert(self, image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
        # start from a fresh UNet with zero-initialized w2w coordinates
        del self.unet
        del self.network
        self.unet, _, _, _, _ = load_models(self.device)

        proj = torch.zeros(1, pcs).bfloat16().to(self.device)
        self.network = LoRAw2w(proj, self.mean, self.std, self.v[:, :pcs],
                               self.unet,
                               rank=1,
                               multiplier=1.0,
                               alpha=27.0,
                               train_method="xattn-strict"
                               ).to(self.device, torch.bfloat16)

        ### load mask
        mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
        mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(self.device).bfloat16()[:, 0, :, :].unsqueeze(1)
        ### check if an actual mask was drawn, otherwise the mask is all ones
        if torch.sum(mask) == 0:
            mask = torch.ones((1, 1, 64, 64)).to(self.device).bfloat16()

        ### single-image dataset
        image_transforms = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        train_dataset = CustomImageDataset(image, transform=image_transforms)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

        ### optimizer
        optim = torch.optim.Adam(self.network.parameters(), lr=lr, weight_decay=weight_decay)

        ### training loop
        self.unet.train()
        for epoch in tqdm.tqdm(range(epochs)):
            for batch in train_dataloader:
                ### prepare inputs
                batch = batch.to(self.device).bfloat16()
                latents = self.vae.encode(batch).latent_dist.sample()
                latents = latents * 0.18215
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()
                noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
                text_input = self.tokenizer("sks person", padding="max_length",
                                            max_length=self.tokenizer.model_max_length,
                                            truncation=True, return_tensors="pt")
                text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

                ### loss + sgd step
                with self.network:
                    model_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
                loss = torch.nn.functional.mse_loss(mask * model_pred.float(),
                                                    mask * noise.float(), reduction="mean")
                optim.zero_grad()
                loss.backward()
                optim.step()

        ### return optimized network
        return self.network

    @spaces.GPU
    def run_inversion(self, dict, pcs, epochs, weight_decay, lr):
        init_image = dict["image"].convert("RGB").resize((512, 512))
        mask = dict["mask"].convert("RGB").resize((512, 512))
        network = self.invert([init_image], mask, pcs, epochs, weight_decay, lr)

        # sample an image with the inverted identity
        prompt = "sks person"
        negative_prompt = "low quality, blurry, unfinished, nudity"
        seed = 5
        cfg = 3.0
        steps = 25
        image = self.inference(prompt, negative_prompt, cfg, steps, seed)
        torch.save(network.proj, "model.pt")
        return image, "model.pt"

    @spaces.GPU
    def file_upload(self, file):
        del self.unet
        del self.network
        proj = torch.load(file.name).to(self.device)

        # pad to 10000 principal components to keep everything consistent
        pcs = proj.shape[1]
        padding = torch.zeros((1, 10000 - pcs)).to(self.device)
        proj = torch.cat((proj, padding), 1)

        self.unet, _, _, _, _ = load_models(self.device)
        self.network = LoRAw2w(proj, self.mean, self.std, self.v[:, :10000],
                               self.unet,
                               rank=1,
                               multiplier=1.0,
                               alpha=27.0,
                               train_method="xattn-strict"
                               ).to(self.device, torch.bfloat16)

        prompt = "sks person"
        negative_prompt = "low quality, blurry, unfinished, nudity"
        seed = 5
        cfg = 3.0
        steps = 25
        image = self.inference(prompt, negative_prompt, cfg, steps, seed)
        return image
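# `run_inversion` and `file_upload` are written as Gradio callbacks: the former
# expects the {"image": ..., "mask": ...} dict produced by a sketch-enabled image
# input, the latter a file widget's tempfile handle. A hypothetical direct call
# (face_pil and mask_pil are placeholder PIL images):
#
#   model = main()
#   img, ckpt = model.run_inversion({"image": face_pil, "mask": mask_pil},
#                                   pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1)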
intro = """
Project Page | Paper | Code
"""
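# A minimal sketch of how the pieces above might be wired into the UI; the
# actual layout is not part of this file, so the components below are assumptions:
#
#   model = main()
#   with gr.Blocks() as demo:
#       gr.HTML(intro)
#       output_image = gr.Image(label="Generated Image")
#       file_output = gr.File(label="model.pt")
#       sample_btn = gr.Button("Sample New Model")
#       sample_btn.click(fn=model.sample_then_run, inputs=[],
#                        outputs=[output_image, file_output])
#   demo.queue().launch()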