# generate.py
import sys
sys.path.append('versatile_diffusion')

import os

import numpy as np
import PIL
from PIL import Image
import torch
import torchvision.transforms as tvtrans

from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model
from lib.model_zoo.ddim_vd import DDIMSampler_VD
from lib.experiments.sd_default import color_adjust


def regularize_image(x):
    """Convert a path, PIL image, numpy array, or tensor into a 512x512 tensor in [0, 1]."""
    BICUBIC = PIL.Image.Resampling.BICUBIC
    if isinstance(x, str):
        x = Image.open(x).resize([512, 512], resample=BICUBIC)
        x = tvtrans.ToTensor()(x)
    elif isinstance(x, PIL.Image.Image):
        x = x.resize([512, 512], resample=BICUBIC)
        x = tvtrans.ToTensor()(x)
    elif isinstance(x, np.ndarray):
        x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
        x = tvtrans.ToTensor()(x)
    elif isinstance(x, torch.Tensor):
        pass
    else:
        assert False, 'Unknown image type'
    assert (x.shape[1] == 512) and (x.shape[2] == 512), 'Wrong image size'
    return x


# Load the Versatile Diffusion model once at import time
cfgm_name = 'vd_noema'
pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth'
cfgm = model_cfg_bank()(cfgm_name)
net = get_model()(cfgm)
sd = torch.load(pth, map_location='cpu')
net.load_state_dict(sd, strict=False)

# Use cuda:0 for all tensor assignments when a GPU is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Move the CLIP and autoencoder submodules to the GPU
net.clip.cuda(0)
net.autokl.cuda(0)

sampler = DDIMSampler_VD(net)
sampler.model.model.diffusion_model.device = device
sampler.model.model.diffusion_model.half().to(device)

# Sampling configuration
batch_size = 1
n_samples = 1
ddim_steps = 50
ddim_eta = 0
scale = 7.5
xtype = 'image'
ctype = 'prompt'

net.autokl.half()
torch.manual_seed(0)
net.clip = net.clip.to(device)


def generate_image(sub, image_id, annot, strength=0.75, mixing=0.4):
    # Predicted CLIP features decoded from fMRI (one row per test image)
    pred_text = np.load(f'data/predicted_features/subj{sub:02d}/nsd_cliptext_predtest_nsdgeneral.npy')
    pred_vision = np.load(f'data/predicted_features/subj{sub:02d}/nsd_clipvision_predtest_nsdgeneral.npy')

    # Replace the predicted text feature for this image with the CLIP encoding
    # of the user-supplied annotation, then move both feature sets to the GPU
    with torch.no_grad():
        pred_text[image_id] = net.clip_encode_text([annot]).to('cpu').numpy().mean(0)
    pred_text = torch.tensor(pred_text).half().to(device)
    pred_vision = torch.tensor(pred_vision).half().to(device)

    # Low-level VDVAE reconstruction used to initialize the latent,
    # and the ground-truth stimulus image for side-by-side display
    zim = Image.open(f'results/vdvae/subj{sub:02d}/{image_id}.png')
    test_img = Image.open(f'data/nsddata_stimuli/test_images/{image_id}.png')
    os.makedirs('scripts/images', exist_ok=True)
    test_img_path = 'scripts/images/original_image.png'
    test_img.save(test_img_path)

    # Encode the init image into the autoencoder latent space
    zim = regularize_image(zim)
    zin = zim * 2 - 1
    zin = zin.unsqueeze(0).to(device).half()
    init_latent = net.autokl_encode(zin)

    # Noise the init latent up to t_enc steps (img2img-style partial diffusion)
    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
    t_enc = int(strength * ddim_steps)
    z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]).to(device))

    # Unconditional embeddings for classifier-free guidance
    dummy = ''
    utx = net.clip_encode_text(dummy).to(device).half()
    dummy = torch.zeros((1, 3, 224, 224)).to(device)
    uim = net.clip_encode_vision(dummy).to(device).half()
    z_enc = z_enc.to(device)

    # Output resolution (the latent is 1/8 of the image size)
    h, w = 512, 512
    shape = [n_samples, 4, h // 8, w // 8]

    # Conditioning: predicted (annotation-corrected) features for this image
    ctx = pred_text[image_id].unsqueeze(0).to(device)
    cim = pred_vision[image_id].unsqueeze(0).to(device)

    # Dual-conditioned DDIM decoding: mix vision and text guidance
    z = sampler.decode_dc(
        x_latent=z_enc,
        first_conditioning=[uim, cim],
        second_conditioning=[utx, ctx],
        t_start=t_enc,
        unconditional_guidance_scale=scale,
        xtype=xtype,
        first_ctype='vision',
        second_ctype='prompt',
        mixed_ratio=(1 - mixing),
    )
    z = z.to(device).half()
    x = net.autokl_decode(z)

    # Optional color adjustment; inactive here (color_adj='None' and ctype='prompt'),
    # so the branch below never runs. To enable it, set color_adj and define
    # color_adj_to (a reference image to match colors against).
    color_adj = 'None'
    color_adj_flag = (color_adj != 'none') and (color_adj != 'None') and (color_adj is not None)
    color_adj_simple = (color_adj == 'Simple') or (color_adj == 'simple')
    color_adj_keep_ratio = 0.5
    if color_adj_flag and (ctype == 'vision'):
        x_adj = []
        for xi in x:
            color_adj_f = color_adjust(ref_from=(xi + 1) / 2, ref_to=color_adj_to)
            xi_adj = color_adj_f((xi + 1) / 2, keep=color_adj_keep_ratio, simple=color_adj_simple)
            x_adj.append(xi_adj)
        x = x_adj
    else:
        x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
        x = [tvtrans.ToPILImage()(xi) for xi in x]

    # Save the reconstruction and return both paths
    output_path = 'scripts/images/reconstructed.png'
    x[0].save(output_path)
    return test_img_path, output_path
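
# Minimal usage sketch: reconstruct one test image. The subject number, image
# index, and caption below are hypothetical placeholders; adjust them to match
# your own predicted-feature files and test images.
if __name__ == '__main__':
    orig_path, recon_path = generate_image(
        sub=1,                         # NSD subject number (placeholder)
        image_id=0,                    # index into the test set (placeholder)
        annot='a photo of an animal',  # caption mixed into the text conditioning (placeholder)
        strength=0.75,                 # how strongly to noise the VDVAE init latent
        mixing=0.4,                    # vision/text conditioning mix
    )
    print('Original image saved to:', orig_path)
    print('Reconstruction saved to:', recon_path)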