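"""Command-line inference for the four-flow 'vd' model.

Supported applications: text-to-image, image-variation, image-to-text,
text-variation, disentanglement, dual-guided, and i2t2i.
"""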
import os.path as osp
import PIL
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as tvtrans
from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model
from lib.experiments.sd_default import color_adjust, auto_merge_imlist
import argparse
n_sample_image_default = 2
n_sample_text_default = 4
def highlight_print(info):
print('')
print(''.join(['#']*(len(info)+4)))
print('# '+info+' #')
print(''.join(['#']*(len(info)+4)))
print('')
class vd_inference(object):
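    """Wrapper around the four-flow 'vd' network.

    Loads the checkpoint, optionally casts to fp16, and exposes the base
    `inference` entry point plus the disentanglement, dual-guided, and
    i2t2i applications on top of a DDIM sampler."""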
    def __init__(self, pth='pretrained/vd-four-flow-v1-0.pth', fp16=False, device=0):
        cfgm_name = 'vd_noema'
        cfgm = model_cfg_bank()(cfgm_name)
device_str = device if isinstance(device, str) else 'cuda:{}'.format(device)
cfgm.args.autokl_cfg.map_location = device_str
cfgm.args.optimus_cfg.map_location = device_str
net = get_model()(cfgm)
if fp16:
highlight_print('Running in FP16')
net.clip.fp16 = True
net = net.half()
sd = torch.load(pth, map_location=device_str)
net.load_state_dict(sd, strict=False)
print('Load pretrained weight from {}'.format(pth))
net.to(device)
self.device = device
self.model_name = cfgm_name
self.net = net
self.fp16 = fp16
from lib.model_zoo.ddim_vd import DDIMSampler_VD
self.sampler = DDIMSampler_VD(net)
def regularize_image(self, x):
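        """Normalize an input (file path, PIL image, ndarray, or tensor) to a
        512x512 float tensor on the model device, half precision if fp16."""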
BICUBIC = PIL.Image.Resampling.BICUBIC
if isinstance(x, str):
x = Image.open(x).resize([512, 512], resample=BICUBIC)
x = tvtrans.ToTensor()(x)
elif isinstance(x, PIL.Image.Image):
x = x.resize([512, 512], resample=BICUBIC)
x = tvtrans.ToTensor()(x)
elif isinstance(x, np.ndarray):
x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
x = tvtrans.ToTensor()(x)
elif isinstance(x, torch.Tensor):
pass
else:
assert False, 'Unknown image type'
        assert (x.shape[1] == 512) and (x.shape[2] == 512), 'Wrong image size'
x = x.to(self.device)
if self.fp16:
x = x.half()
return x
def decode(self, z, xtype, ctype, color_adj='None', color_adj_to=None):
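        """Decode latents `z` into final outputs.

        Images go through the AutoKL decoder, optionally color-matched
        against `color_adj_to`; text goes through the Optimus decoder,
        with repeated adjacent words collapsed."""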
net = self.net
if xtype == 'image':
x = net.autokl_decode(z)
            color_adj_flag = color_adj not in (None, 'none', 'None')
            color_adj_simple = color_adj in ('Simple', 'simple')
color_adj_keep_ratio = 0.5
if color_adj_flag and (ctype=='vision'):
x_adj = []
for xi in x:
color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
xi_adj = color_adj_f((xi+1)/2, keep=color_adj_keep_ratio, simple=color_adj_simple)
x_adj.append(xi_adj)
x = x_adj
else:
x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
x = [tvtrans.ToPILImage()(xi) for xi in x]
return x
elif xtype == 'text':
prompt_temperature = 1.0
prompt_merge_same_adj_word = True
x = net.optimus_decode(z, temperature=prompt_temperature)
if prompt_merge_same_adj_word:
xnew = []
for xi in x:
xi_split = xi.split()
xinew = []
for idxi, wi in enumerate(xi_split):
if idxi!=0 and wi==xi_split[idxi-1]:
continue
xinew.append(wi)
xnew.append(' '.join(xinew))
x = xnew
return x
def inference(self, xtype, cin, ctype, scale=7.5, n_samples=None, color_adj=None,):
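        """Sample `n_samples` outputs of modality `xtype` ('image' or 'text')
        conditioned on `cin` of modality `ctype` ('prompt'/'text' or
        'vision'/'image'), via 50-step DDIM with classifier-free guidance."""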
net = self.net
sampler = self.sampler
ddim_steps = 50
ddim_eta = 0.0
if xtype == 'image':
n_samples = n_sample_image_default if n_samples is None else n_samples
elif xtype == 'text':
n_samples = n_sample_text_default if n_samples is None else n_samples
if ctype in ['prompt', 'text']:
c = net.clip_encode_text(n_samples * [cin])
u = None
if scale != 1.0:
u = net.clip_encode_text(n_samples * [""])
elif ctype in ['vision', 'image']:
cin = self.regularize_image(cin)
ctemp = cin*2 - 1
ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
c = net.clip_encode_vision(ctemp)
u = None
if scale != 1.0:
dummy = torch.zeros_like(ctemp)
u = net.clip_encode_vision(dummy)
        if self.fp16:
            c = c.half()
            u = u.half() if u is not None else None
if xtype == 'image':
h, w = [512, 512]
shape = [n_samples, 4, h//8, w//8]
z, _ = sampler.sample(
steps=ddim_steps,
shape=shape,
conditioning=c,
unconditional_guidance_scale=scale,
unconditional_conditioning=u,
xtype=xtype, ctype=ctype,
eta=ddim_eta,
verbose=False,)
x = self.decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
return x
elif xtype == 'text':
n = 768
shape = [n_samples, n]
z, _ = sampler.sample(
steps=ddim_steps,
shape=shape,
conditioning=c,
unconditional_guidance_scale=scale,
unconditional_conditioning=u,
xtype=xtype, ctype=ctype,
eta=ddim_eta,
verbose=False,)
x = self.decode(z, xtype, ctype)
return x
def application_disensemble(self, cin, n_samples=None, level=0, color_adj=None,):
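        """Image variation with disentangled CLIP vision context.

        The context is split into a global token and local tokens. Positive
        levels keep only a rank-q PCA approximation of the local tokens
        (q=10 at level 1, q=2 at level 2); negative levels zero the leading
        principal components instead (1 at level -1, 2 at level -2);
        level 0 uses the context unchanged."""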
net = self.net
scale = 7.5
sampler = self.sampler
ddim_steps = 50
ddim_eta = 0.0
n_samples = n_sample_image_default if n_samples is None else n_samples
cin = self.regularize_image(cin)
ctemp = cin*2 - 1
ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
c = net.clip_encode_vision(ctemp)
u = None
if scale != 1.0:
dummy = torch.zeros_like(ctemp)
u = net.clip_encode_vision(dummy)
        if self.fp16:
            c = c.half()
            u = u.half() if u is not None else None
if level == 0:
pass
else:
c_glb = c[:, 0:1]
c_loc = c[:, 1: ]
u_glb = u[:, 0:1]
u_loc = u[:, 1: ]
if level == -1:
c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=1)
u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=1)
if level == -2:
c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=2)
u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=2)
if level == 1:
c_loc = self.find_low_rank(c_loc, demean=True, q=10)
u_loc = self.find_low_rank(u_loc, demean=True, q=10)
if level == 2:
c_loc = self.find_low_rank(c_loc, demean=True, q=2)
u_loc = self.find_low_rank(u_loc, demean=True, q=2)
c = torch.cat([c_glb, c_loc], dim=1)
u = torch.cat([u_glb, u_loc], dim=1)
h, w = [512, 512]
shape = [n_samples, 4, h//8, w//8]
z, _ = sampler.sample(
steps=ddim_steps,
shape=shape,
conditioning=c,
unconditional_guidance_scale=scale,
unconditional_conditioning=u,
xtype='image', ctype='vision',
eta=ddim_eta,
verbose=False,)
x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)
return x
def find_low_rank(self, x, demean=True, q=20, niter=10):
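        """Reconstruct each matrix in the batch from its top-q principal
        components (a rank-q approximation). The per-row mean is optionally
        removed first and restored afterwards; fp16 input is upcast to fp32
        for the decomposition, then cast back."""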
if demean:
x_mean = x.mean(-1, keepdim=True)
x_input = x - x_mean
else:
x_input = x
if x_input.dtype == torch.float16:
fp16 = True
x_input = x_input.float()
else:
fp16 = False
u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
ss = torch.stack([torch.diag(si) for si in s])
x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))
if fp16:
x_lowrank = x_lowrank.half()
if demean:
x_lowrank += x_mean
return x_lowrank
def remove_low_rank(self, x, demean=True, q=20, niter=10, q_remove=10):
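        """Counterpart of find_low_rank: zero the leading q_remove singular
        values before reconstruction, so the result keeps components
        q_remove..q of the rank-q decomposition, suppressing the dominant
        directions."""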
if demean:
x_mean = x.mean(-1, keepdim=True)
x_input = x - x_mean
else:
x_input = x
if x_input.dtype == torch.float16:
fp16 = True
x_input = x_input.float()
else:
fp16 = False
u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
s[:, 0:q_remove] = 0
ss = torch.stack([torch.diag(si) for si in s])
x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))
if fp16:
x_lowrank = x_lowrank.half()
if demean:
x_lowrank += x_mean
return x_lowrank
def application_dualguided(self, cim, ctx, n_samples=None, mixing=0.5, color_adj=None, ):
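        """Generate images jointly guided by an image `cim` and a prompt
        `ctx`; `mixing` balances the two contexts (forwarded to the sampler
        as mixed_ratio = 1 - mixing)."""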
net = self.net
scale = 7.5
sampler = self.sampler
ddim_steps = 50
ddim_eta = 0.0
n_samples = n_sample_image_default if n_samples is None else n_samples
ctemp0 = self.regularize_image(cim)
ctemp1 = ctemp0*2 - 1
ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
cim = net.clip_encode_vision(ctemp1)
uim = None
if scale != 1.0:
dummy = torch.zeros_like(ctemp1)
uim = net.clip_encode_vision(dummy)
ctx = net.clip_encode_text(n_samples * [ctx])
utx = None
if scale != 1.0:
utx = net.clip_encode_text(n_samples * [""])
        if self.fp16:
            cim = cim.half()
            uim = uim.half() if uim is not None else None
            ctx = ctx.half()
            utx = utx.half() if utx is not None else None
h, w = [512, 512]
shape = [n_samples, 4, h//8, w//8]
z, _ = sampler.sample_dc(
steps=ddim_steps,
shape=shape,
first_conditioning=[uim, cim],
second_conditioning=[utx, ctx],
unconditional_guidance_scale=scale,
xtype='image',
first_ctype='vision',
second_ctype='prompt',
eta=ddim_eta,
verbose=False,
mixed_ratio=(1-mixing), )
x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
return x
def application_i2t2i(self, cim, ctx_n, ctx_p, n_samples=None, color_adj=None,):
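        """Image editing through the text latent space.

        Encode the image to a text latent, remove its projection onto the
        negative prompt `ctx_n`, add the positive prompt `ctx_p`, decode to
        a new prompt, then sample an image dual-guided by the low-rank local
        vision context and the edited text context."""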
net = self.net
scale = 7.5
sampler = self.sampler
ddim_steps = 50
ddim_eta = 0.0
prompt_temperature = 1.0
n_samples = n_sample_image_default if n_samples is None else n_samples
ctemp0 = self.regularize_image(cim)
ctemp1 = ctemp0*2 - 1
ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
cim = net.clip_encode_vision(ctemp1)
uim = None
if scale != 1.0:
dummy = torch.zeros_like(ctemp1)
uim = net.clip_encode_vision(dummy)
        if self.fp16:
            cim = cim.half()
            uim = uim.half() if uim is not None else None
n = 768
shape = [n_samples, n]
zt, _ = sampler.sample(
steps=ddim_steps,
shape=shape,
conditioning=cim,
unconditional_guidance_scale=scale,
unconditional_conditioning=uim,
xtype='text', ctype='vision',
eta=ddim_eta,
verbose=False,)
ztn = net.optimus_encode([ctx_n])
ztp = net.optimus_encode([ctx_p])
ztn_norm = ztn / ztn.norm(dim=1)
zt_proj_mag = torch.matmul(zt, ztn_norm[0])
zt_perp = zt - zt_proj_mag[:, None] * ztn_norm
zt_newd = zt_perp + ztp
ctx_new = net.optimus_decode(zt_newd, temperature=prompt_temperature)
ctx_new = net.clip_encode_text(ctx_new)
ctx_p = net.clip_encode_text([ctx_p])
ctx_new = torch.cat([ctx_new, ctx_p.repeat(n_samples, 1, 1)], dim=1)
utx_new = net.clip_encode_text(n_samples * [""])
utx_new = torch.cat([utx_new, utx_new], dim=1)
cim_loc = cim[:, 1: ]
cim_loc_new = self.find_low_rank(cim_loc, demean=True, q=10)
cim_new = cim_loc_new
uim_new = uim[:, 1:]
h, w = [512, 512]
shape = [n_samples, 4, h//8, w//8]
z, _ = sampler.sample_dc(
steps=ddim_steps,
shape=shape,
first_conditioning=[uim_new, cim_new],
second_conditioning=[utx_new, ctx_new],
unconditional_guidance_scale=scale,
xtype='image',
first_ctype='vision',
second_ctype='prompt',
eta=ddim_eta,
verbose=False,
mixed_ratio=0.33, )
x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
return x
def main(netwrapper,
app,
image=None,
prompt=None,
nprompt=None,
pprompt=None,
color_adj=None,
disentanglement_level=None,
dual_guided_mixing=None,
n_samples=4,
seed=0,):
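    """Dispatch `app` to the matching wrapper call.

    Returns a pair (images, text): a list of PIL images or None, and a
    newline-joined string of generated prompts or None."""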
if seed is not None:
seed = 0 if seed<0 else seed
np.random.seed(seed)
torch.manual_seed(seed+100)
if app == 'text-to-image':
print('Running [{}] with prompt [{}], n_samples [{}], seed [{}].'.format(
app, prompt, n_samples, seed))
if (prompt is None) or (prompt == ""):
return None, None
with torch.no_grad():
rv = netwrapper.inference(
xtype = 'image',
cin = prompt,
ctype = 'prompt',
n_samples = n_samples, )
return rv, None
elif app == 'image-variation':
print('Running [{}] with image [{}], color_adj [{}], n_samples [{}], seed [{}].'.format(
app, image, color_adj, n_samples, seed))
if image is None:
return None, None
with torch.no_grad():
rv = netwrapper.inference(
xtype = 'image',
cin = image,
ctype = 'vision',
color_adj = color_adj,
n_samples = n_samples, )
return rv, None
elif app == 'image-to-text':
        print('Running [{}] with image [{}], n_samples [{}], seed [{}].'.format(
app, image, n_samples, seed))
if image is None:
return None, None
with torch.no_grad():
rv = netwrapper.inference(
xtype = 'text',
cin = image,
ctype = 'vision',
n_samples = n_samples, )
return None, '\n'.join(rv)
elif app == 'text-variation':
print('Running [{}] with prompt [{}], n_samples [{}], seed [{}].'.format(
app, prompt, n_samples, seed))
if prompt is None:
return None, None
with torch.no_grad():
rv = netwrapper.inference(
xtype = 'text',
cin = prompt,
ctype = 'prompt',
n_samples = n_samples, )
return None, '\n'.join(rv)
elif app == 'disentanglement':
print('Running [{}] with image [{}], color_adj [{}], disentanglement_level [{}], n_samples [{}], seed [{}].'.format(
app, image, color_adj, disentanglement_level, n_samples, seed))
if image is None:
return None, None
with torch.no_grad():
rv = netwrapper.application_disensemble(
cin = image,
level = disentanglement_level,
color_adj = color_adj,
n_samples = n_samples, )
return rv, None
elif app == 'dual-guided':
print('Running [{}] with image [{}], prompt [{}], color_adj [{}], dual_guided_mixing [{}], n_samples [{}], seed [{}].'.format(
app, image, prompt, color_adj, dual_guided_mixing, n_samples, seed))
if (image is None) or (prompt is None) or (prompt==""):
return None, None
with torch.no_grad():
rv = netwrapper.application_dualguided(
cim = image,
ctx = prompt,
mixing = dual_guided_mixing,
color_adj = color_adj,
n_samples = n_samples, )
return rv, None
elif app == 'i2t2i':
print('Running [{}] with image [{}], nprompt [{}], pprompt [{}], color_adj [{}], n_samples [{}], seed [{}].'.format(
app, image, nprompt, pprompt, color_adj, n_samples, seed))
if (image is None) or (nprompt is None) or (nprompt=="") \
or (pprompt is None) or (pprompt==""):
return None, None
with torch.no_grad():
rv = netwrapper.application_i2t2i(
cim = image,
ctx_n = nprompt,
ctx_p = pprompt,
color_adj = color_adj,
n_samples = n_samples, )
return rv, None
else:
assert False, "No such mode!"
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--app", type=str, default="text-to-image",
help="Choose the application from ["\
"text-to-image, image-variation, "\
"image-to-text, text-variation, "\
"disentanglement, dual-guided, i2t2i]")
parser.add_argument(
"--model", type=str, default="official",
help="Choose the model type from ["\
"dc, official]")
parser.add_argument(
"--prompt", type=str,
default="a dream of a village in china, by Caspar "\
"David Friedrich, matte painting trending on artstation HQ")
parser.add_argument("--image", type=str)
parser.add_argument("--nprompt", type=str)
parser.add_argument("--pprompt", type=str)
parser.add_argument("--coloradj", type=str, default='simple')
parser.add_argument("--dislevel", type=int, default=0)
parser.add_argument("--dgmixing", type=float, default=0.7)
parser.add_argument("--nsample", type=int, default=4)
parser.add_argument("--seed", type=int)
parser.add_argument("--save", type=str, default='log',
help="The path or file the result will save into")
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--fp16", action="store_true")
# parser.add_argument("--pth", type=str, default='pretrained/vd-four-flow-v1-0.pth')
args = parser.parse_args()
assert args.app in [
"text-to-image", "image-variation",
"image-to-text", "text-variation",
"disentanglement", "dual-guided", "i2t2i"], \
"Unknown app! Select from [text-to-image, image-variation, "\
"image-to-text, text-variation, "\
"disentanglement, dual-guided, i2t2i]"
    device = args.gpu if torch.cuda.is_available() else 'cpu'
if args.model in ['4-flow', 'official']:
if args.fp16:
pth='pretrained/vd-four-flow-v1-0-fp16.pth'
else:
pth='pretrained/vd-four-flow-v1-0.pth'
vd_wrapper = vd_inference(pth=pth, fp16=args.fp16, device=device)
elif args.model in ['2-flow', 'dc']:
raise NotImplementedError
# vd_wrapper = vd_dc_inference(args.model, pth=args.pth, device=device)
elif args.model in ['1-flow', 'basic']:
raise NotImplementedError
# vd_wrapper = vd_basic_inference(args.model, pth=args.pth, device=device)
else:
assert False, "No such model! Select model from [4-flow(official), 2-flow(dc), 1-flow(basic)]"
imout, txtout = main(
netwrapper=vd_wrapper,
app=args.app,
image=args.image,
prompt=args.prompt,
nprompt=args.nprompt,
pprompt=args.pprompt,
color_adj=args.coloradj,
disentanglement_level=args.dislevel,
dual_guided_mixing=args.dgmixing,
n_samples=args.nsample,
seed=args.seed,)
if imout is not None:
imout = auto_merge_imlist([np.array(i) for i in imout])
imout = PIL.Image.fromarray(imout)
if osp.isdir(args.save):
imout.save(osp.join(args.save, 'imout.png'))
print('Output image saved to {}.'.format(osp.join(args.save, 'imout.png')))
else:
            imout.save(args.save)
print('Output image saved to {}.'.format(args.save))
if txtout is not None:
print(txtout)