import sys sys.path.append('versatile_diffusion') import os import os.path as osp import PIL from PIL import Image from pathlib import Path import numpy as np import numpy.random as npr import torch import torchvision.transforms as tvtrans from lib.cfg_helper import model_cfg_bank from lib.model_zoo import get_model from lib.model_zoo.ddim_vd import DDIMSampler_VD from lib.experiments.sd_default import color_adjust, auto_merge_imlist from torch.utils.data import DataLoader, Dataset from lib.model_zoo.vd import VD from lib.cfg_holder import cfg_unique_holder as cfguh from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml import matplotlib.pyplot as plt from skimage.transform import resize, downscale_local_mean import argparse parser = argparse.ArgumentParser(description='Argument Parser') parser.add_argument("-sub", "--sub",help="Subject Number",default=1) parser.add_argument("-diff_str", "--diff_str",help="Diffusion Strength",default=0.75) parser.add_argument("-mix_str", "--mix_str",help="Mixing Strength",default=0.4) args = parser.parse_args() sub=int(args.sub) assert sub in [1,2,5,7] strength = float(args.diff_str) mixing = float(args.mix_str) def regularize_image(x): BICUBIC = PIL.Image.Resampling.BICUBIC if isinstance(x, str): x = Image.open(x).resize([512, 512], resample=BICUBIC) x = tvtrans.ToTensor()(x) elif isinstance(x, PIL.Image.Image): x = x.resize([512, 512], resample=BICUBIC) x = tvtrans.ToTensor()(x) elif isinstance(x, np.ndarray): x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC) x = tvtrans.ToTensor()(x) elif isinstance(x, torch.Tensor): pass else: assert False, 'Unknown image type' assert (x.shape[1]==512) & (x.shape[2]==512), \ 'Wrong image size' return x cfgm_name = 'vd_noema' sampler = DDIMSampler_VD pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth' cfgm = model_cfg_bank()(cfgm_name) net = get_model()(cfgm) sd = torch.load(pth, map_location='cpu') net.load_state_dict(sd, strict=False) # Ensuring proper GPU device assignment, using cuda:0 for all tensor assignments device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Move models and data to GPU (cuda:0) net.clip.cuda(0) net.autokl.cuda(0) sampler = sampler(net) sampler.model.model.diffusion_model.device = device sampler.model.model.diffusion_model.half().to(device) batch_size = 1 # Load predicted features and move them to GPU pred_text = np.load('data/predicted_features/subj{:02d}/nsd_cliptext_predtest_nsdgeneral.npy'.format(sub)) pred_text = torch.tensor(pred_text).half().to(device) pred_vision = np.load('data/predicted_features/subj{:02d}/nsd_clipvision_predtest_nsdgeneral.npy'.format(sub)) pred_vision = torch.tensor(pred_vision).half().to(device) n_samples = 1 ddim_steps = 50 ddim_eta = 0 scale = 7.5 xtype = 'image' ctype = 'prompt' net.autokl.half() torch.manual_seed(0) ## for clip text ######## # cfgm_name = 'vd_noema' # pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth' # cfgm = model_cfg_bank()(cfgm_name) # net = get_model()(cfgm) # sd = torch.load(pth, map_location='cpu') # net.load_state_dict(sd, strict=False) # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') net.clip = net.clip.to(device) # with torch.no_grad(): while True: image_id = int(input("Imagine a random object or scenery ")) for im_id in range(image_id,image_id+1): # print("surya theja -------------------------------") zim = Image.open('results/vdvae/subj{:02d}/{}.png'.format(sub,im_id)) test_img = Image.open('data/nsddata_stimuli/test_images/{}.png'.format(im_id)) # Display the test image test_img.save('scripts/images/imagined_image.png'.format(sub, im_id)) annot = input("Describe what you have imagined in few words : ") zim = regularize_image(zim) zin = zim*2 - 1 zin = zin.unsqueeze(0).to(device).half() init_latent = net.autokl_encode(zin) sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False) t_enc = int(strength * ddim_steps) # Encode the image using the sampler z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]).to(device)) # Encoding text and vision dummy = '' utx = net.clip_encode_text(dummy) utx = utx.to(device).half() dummy = torch.zeros((1,3,224,224)).to(device) uim = net.clip_encode_vision(dummy) uim = uim.to(device).half() z_enc = z_enc.to(device) # Sample configuration for diffusion h, w = 512,512 shape = [n_samples, 4, h//8, w//8] cim = pred_vision[im_id].unsqueeze(0).to(device) # ctx = pred_text[im_id].unsqueeze(0).to(device) pred_text = np.load('data/predicted_features/subj{:02d}/nsd_cliptext_predtest_nsdgeneral.npy'.format(sub)) with torch.no_grad(): pred_text[im_id] = net.clip_encode_text([annot]).to('cpu').numpy().mean(0) pred_text = torch.tensor(pred_text).half().to(device) ctx = pred_text[im_id].unsqueeze(0).to(device) # Decode using sampler z = sampler.decode_dc( x_latent=z_enc, first_conditioning=[uim, cim], second_conditioning=[utx, ctx], t_start=t_enc, unconditional_guidance_scale=scale, xtype='image', first_ctype='vision', second_ctype='prompt', mixed_ratio=(1-mixing), ) z = z.to(device).half() x = net.autokl_decode(z) # Adjust color if needed color_adj='None' color_adj_flag = (color_adj != 'none') and (color_adj != 'None') and (color_adj is not None) color_adj_simple = (color_adj == 'Simple') or color_adj == 'simple' color_adj_keep_ratio = 0.5 if color_adj_flag and (ctype == 'vision'): x_adj = [] for xi in x: color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to) xi_adj = color_adj_f((xi+1)/2, keep=color_adj_keep_ratio, simple=color_adj_simple) x_adj.append(xi_adj) x = x_adj else: x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0) x = [tvtrans.ToPILImage()(xi) for xi in x] # Save output image x[0].save('scripts/images/reconstructed.png'.format(sub, im_id))