import os
import os.path as osp
import PIL
from PIL import Image
from pathlib import Path
import numpy as np
import numpy.random as npr

import torch
import torchvision.transforms as tvtrans
from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model
from lib.model_zoo.ddim_dualcontext import DDIMSampler_DualContext
from lib.experiments.sd_default import color_adjust, auto_merge_imlist

import argparse

n_sample_image_default = 2
n_sample_text_default = 4

def highlight_print(info):
    print('')
    print('#'*(len(info)+4))
    print('# '+info+' #')
    print('#'*(len(info)+4))
    print('')

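# vd_inference wraps the Versatile Diffusion network: it builds the
# 'vd_noema' model config, loads the pretrained four-flow weights,
# optionally casts the network to FP16, and attaches a DDIM sampler.
# The application_* methods below implement the individual demos.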
class vd_inference(object):
    def __init__(self, pth='pretrained/vd1.0-four-flow.pth', fp16=False, device=0):
        cfgm_name = 'vd_noema'
        cfgm = model_cfg_bank()(cfgm_name)
        device_str = device if isinstance(device, str) else 'cuda:{}'.format(device)
        cfgm.args.autokl_cfg.map_location = device_str
        cfgm.args.optimus_cfg.map_location = device_str
        net = get_model()(cfgm)
        if fp16:
            highlight_print('Running in FP16')
            net.clip.fp16 = True
            net = net.half()
        sd = torch.load(pth, map_location=device_str)
        net.load_state_dict(sd, strict=False)
        print('Load pretrained weight from {}'.format(pth))
        net.to(device)

        self.device = device
        self.model_name = cfgm_name
        self.net = net
        self.fp16 = fp16
        from lib.model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD(net)

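    # Normalize any supported image input (path, PIL image, numpy array, or
    # tensor) to a [0, 1] float tensor of size 512x512 (bicubic resize),
    # moved to the model device and cast to half precision under fp16.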
    def regularize_image(self, x):
        BICUBIC = PIL.Image.Resampling.BICUBIC
        if isinstance(x, str):
            x = Image.open(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, PIL.Image.Image):
            x = x.resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, np.ndarray):
            x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, torch.Tensor):
            pass
        else:
            assert False, 'Unknown image type'

        assert (x.shape[1]==512) and (x.shape[2]==512), \
            'Wrong image size'
        x = x.to(self.device)
        if self.fp16:
            x = x.half()
        return x

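    # Decode sampled latents back to outputs. Image latents go through the
    # autoencoder-KL decoder, optionally followed by a reference-based color
    # adjustment when conditioning on vision; text latents go through the
    # Optimus decoder, with immediately repeated words collapsed.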
    def decode(self, z, xtype, ctype, color_adj='None', color_adj_to=None):
        net = self.net
        if xtype == 'image':
            x = net.autokl_decode(z)

            color_adj_flag = color_adj not in [None, 'none', 'None']
            color_adj_simple = color_adj in ['Simple', 'simple']
            color_adj_keep_ratio = 0.5

            if color_adj_flag and (ctype=='vision'):
                x_adj = []
                for xi in x:
                    color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
                    xi_adj = color_adj_f((xi+1)/2, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                x = x_adj
            else:
                x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
                x = [tvtrans.ToPILImage()(xi) for xi in x]
            return x

        elif xtype == 'text':
            prompt_temperature = 1.0
            prompt_merge_same_adj_word = True
            x = net.optimus_decode(z, temperature=prompt_temperature)
            if prompt_merge_same_adj_word:
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi!=0 and wi==xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(' '.join(xinew))
                x = xnew
            return x

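    # Single-context generation with classifier-free guidance: the
    # conditional encoding c is paired with an unconditional encoding u
    # (empty prompt, or an all-zero image), and the sampler combines the two
    # noise predictions with weight `scale`. At scale == 1.0 the
    # unconditional branch is skipped and u stays None.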
    def inference(self, xtype, cin, ctype, scale=7.5, n_samples=None, color_adj=None):
        net = self.net
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0

        if xtype == 'image':
            n_samples = n_sample_image_default if n_samples is None else n_samples
        elif xtype == 'text':
            n_samples = n_sample_text_default if n_samples is None else n_samples

        if ctype in ['prompt', 'text']:
            c = net.clip_encode_text(n_samples * [cin])
            u = None
            if scale != 1.0:
                u = net.clip_encode_text(n_samples * [""])

        elif ctype in ['vision', 'image']:
            cin = self.regularize_image(cin)
            ctemp = cin*2 - 1
            ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
            c = net.clip_encode_vision(ctemp)
            u = None
            if scale != 1.0:
                dummy = torch.zeros_like(ctemp)
                u = net.clip_encode_vision(dummy)

        if self.fp16:
            # u stays None when scale == 1.0, so guard before casting
            c = c.half()
            u = u.half() if u is not None else None

        if xtype == 'image':
            h, w = 512, 512
            shape = [n_samples, 4, h//8, w//8]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
            return x

        elif xtype == 'text':
            n = 768
            shape = [n_samples, n]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype)
            return x

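    # Image variation with semantic/style disentanglement: the local CLIP
    # vision tokens (everything after the first, global token) are edited
    # before sampling. Positive levels keep only a low-rank PCA
    # approximation of the local tokens; negative levels remove the top
    # principal components instead.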
    def application_disensemble(self, cin, n_samples=None, level=0, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        cin = self.regularize_image(cin)
        ctemp = cin*2 - 1
        ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
        c = net.clip_encode_vision(ctemp)
        u = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp)
            u = net.clip_encode_vision(dummy)
        u, c = [u.half(), c.half()] if self.fp16 else [u, c]

        if level != 0:
            c_glb = c[:, 0:1]
            c_loc = c[:, 1: ]
            u_glb = u[:, 0:1]
            u_loc = u[:, 1: ]

            if level == -1:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=1)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=1)
            if level == -2:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=2)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=2)
            if level == 1:
                c_loc = self.find_low_rank(c_loc, demean=True, q=10)
                u_loc = self.find_low_rank(u_loc, demean=True, q=10)
            if level == 2:
                c_loc = self.find_low_rank(c_loc, demean=True, q=2)
                u_loc = self.find_low_rank(u_loc, demean=True, q=2)

            c = torch.cat([c_glb, c_loc], dim=1)
            u = torch.cat([u_glb, u_loc], dim=1)

        h, w = 512, 512
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=u,
            xtype='image', ctype='vision',
            eta=ddim_eta,
            verbose=False,)
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)
        return x

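    # Low-rank projection via PCA: reconstruct x ~= U S V^T from the top-q
    # components returned by torch.pca_lowrank, optionally after removing
    # the per-token mean. The decomposition runs in FP32 even under fp16,
    # then the result is cast back to half precision.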
    def find_low_rank(self, x, demean=True, q=20, niter=10):
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()

        if demean:
            x_lowrank += x_mean
        return x_lowrank

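    # The complement of find_low_rank: zero out the q_remove largest
    # singular values before reconstruction, discarding the dominant
    # components and keeping the residual detail.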
    def remove_low_rank(self, x, demean=True, q=20, niter=10, q_remove=10):
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        s[:, 0:q_remove] = 0
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()

        if demean:
            x_lowrank += x_mean
        return x_lowrank

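    # Dual-guided generation: condition on an image and a text prompt at
    # once via sample_dc, which blends the two context streams. The sampler
    # receives mixed_ratio = 1 - mixing for the first (vision) context, so a
    # larger `mixing` presumably shifts weight toward the text prompt.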
    def application_dualguided(self, cim, ctx, n_samples=None, mixing=0.5, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)

        ctx = net.clip_encode_text(n_samples * [ctx])
        utx = None
        if scale != 1.0:
            utx = net.clip_encode_text(n_samples * [""])

        uim, cim = [uim.half(), cim.half()] if self.fp16 else [uim, cim]
        utx, ctx = [utx.half(), ctx.half()] if self.fp16 else [utx, ctx]

        h, w = 512, 512
        shape = [n_samples, 4, h//8, w//8]

        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim, cim],
            second_conditioning=[utx, ctx],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=(1-mixing),)
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x

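    # Image-to-text-to-image editing: (1) sample a text latent zt from the
    # image, (2) project out zt's component along the negative prompt's
    # Optimus latent and add the positive prompt's latent, (3) decode the
    # edited latent back into a prompt, and (4) generate images dual-guided
    # by that prompt and a low-rank version of the original vision context.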
    def application_i2t2i(self, cim, ctx_n, ctx_p, n_samples=None, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        prompt_temperature = 1.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)

        uim, cim = [uim.half(), cim.half()] if self.fp16 else [uim, cim]

        n = 768
        shape = [n_samples, n]
        zt, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=cim,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uim,
            xtype='text', ctype='vision',
            eta=ddim_eta,
            verbose=False,)
        ztn = net.optimus_encode([ctx_n])
        ztp = net.optimus_encode([ctx_p])

        # remove the component of zt along the negative-prompt direction,
        # then shift by the positive-prompt latent
        ztn_norm = ztn / ztn.norm(dim=1)
        zt_proj_mag = torch.matmul(zt, ztn_norm[0])
        zt_perp = zt - zt_proj_mag[:, None] * ztn_norm
        zt_newd = zt_perp + ztp
        ctx_new = net.optimus_decode(zt_newd, temperature=prompt_temperature)

        ctx_new = net.clip_encode_text(ctx_new)
        ctx_p = net.clip_encode_text([ctx_p])
        ctx_new = torch.cat([ctx_new, ctx_p.repeat(n_samples, 1, 1)], dim=1)
        utx_new = net.clip_encode_text(n_samples * [""])
        utx_new = torch.cat([utx_new, utx_new], dim=1)

        cim_loc = cim[:, 1: ]
        cim_loc_new = self.find_low_rank(cim_loc, demean=True, q=10)
        cim_new = cim_loc_new
        uim_new = uim[:, 1:]

        h, w = 512, 512
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim_new, cim_new],
            second_conditioning=[utx_new, ctx_new],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=0.33,)

        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x

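# Dispatch one application per call. Returns an (images, text) pair where at
# most one entry is populated, depending on the output modality: image
# applications return a list of PIL images, text applications a string.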
def main(netwrapper,
         app,
         image=None,
         prompt=None,
         nprompt=None,
         pprompt=None,
         color_adj=None,
         disentanglement_level=None,
         dual_guided_mixing=None,
         n_samples=4,
         seed=0,):

    if seed is not None:
        seed = 0 if seed<0 else seed
        np.random.seed(seed)
        torch.manual_seed(seed+100)

    if app == 'text-to-image':
        print('Running [{}] with prompt [{}], n_samples [{}], seed [{}].'.format(
            app, prompt, n_samples, seed))
        if (prompt is None) or (prompt == ""):
            return None, None
        with torch.no_grad():
            rv = netwrapper.inference(
                xtype = 'image',
                cin = prompt,
                ctype = 'prompt',
                n_samples = n_samples, )
        return rv, None

    elif app == 'image-variation':
        print('Running [{}] with image [{}], color_adj [{}], n_samples [{}], seed [{}].'.format(
            app, image, color_adj, n_samples, seed))
        if image is None:
            return None, None
        with torch.no_grad():
            rv = netwrapper.inference(
                xtype = 'image',
                cin = image,
                ctype = 'vision',
                color_adj = color_adj,
                n_samples = n_samples, )
        return rv, None

    elif app == 'image-to-text':
        print('Running [{}] with image [{}], n_samples [{}], seed [{}].'.format(
            app, image, n_samples, seed))
        if image is None:
            return None, None
        with torch.no_grad():
            rv = netwrapper.inference(
                xtype = 'text',
                cin = image,
                ctype = 'vision',
                n_samples = n_samples, )
        return None, '\n'.join(rv)

    elif app == 'text-variation':
        print('Running [{}] with prompt [{}], n_samples [{}], seed [{}].'.format(
            app, prompt, n_samples, seed))
        if prompt is None:
            return None, None
        with torch.no_grad():
            rv = netwrapper.inference(
                xtype = 'text',
                cin = prompt,
                ctype = 'prompt',
                n_samples = n_samples, )
        return None, '\n'.join(rv)

    elif app == 'disentanglement':
        print('Running [{}] with image [{}], color_adj [{}], disentanglement_level [{}], n_samples [{}], seed [{}].'.format(
            app, image, color_adj, disentanglement_level, n_samples, seed))
        if image is None:
            return None, None
        with torch.no_grad():
            rv = netwrapper.application_disensemble(
                cin = image,
                level = disentanglement_level,
                color_adj = color_adj,
                n_samples = n_samples, )
        return rv, None

    elif app == 'dual-guided':
        print('Running [{}] with image [{}], prompt [{}], color_adj [{}], dual_guided_mixing [{}], n_samples [{}], seed [{}].'.format(
            app, image, prompt, color_adj, dual_guided_mixing, n_samples, seed))
        if (image is None) or (prompt is None) or (prompt==""):
            return None, None
        with torch.no_grad():
            rv = netwrapper.application_dualguided(
                cim = image,
                ctx = prompt,
                mixing = dual_guided_mixing,
                color_adj = color_adj,
                n_samples = n_samples, )
        return rv, None

    elif app == 'i2t2i':
        print('Running [{}] with image [{}], nprompt [{}], pprompt [{}], color_adj [{}], n_samples [{}], seed [{}].'.format(
            app, image, nprompt, pprompt, color_adj, n_samples, seed))
        if (image is None) or (nprompt is None) or (nprompt=="") \
                or (pprompt is None) or (pprompt==""):
            return None, None
        with torch.no_grad():
            rv = netwrapper.application_i2t2i(
                cim = image,
                ctx_n = nprompt,
                ctx_p = pprompt,
                color_adj = color_adj,
                n_samples = n_samples, )
        return rv, None

    else:
        assert False, "No such mode!"

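# Command-line entry point. A typical invocation (hypothetical file name;
# assumes this script is saved as inference.py and the pretrained weights
# have been downloaded into pretrained/):
#   python inference.py --app text-to-image \
#       --prompt "a dream of a village in china" --nsample 4 --save log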
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--app", type=str, default="text-to-image",
        help="Choose the application from ["\
            "text-to-image, image-variation, "\
            "image-to-text, text-variation, "\
            "disentanglement, dual-guided, i2t2i]")

    parser.add_argument(
        "--model", type=str, default="official",
        help="Choose the model type from ["\
            "dc, official]")

    parser.add_argument(
        "--prompt", type=str,
        default="a dream of a village in china, by Caspar "\
            "David Friedrich, matte painting trending on artstation HQ")

    parser.add_argument("--image", type=str)
    parser.add_argument("--nprompt", type=str)
    parser.add_argument("--pprompt", type=str)
    parser.add_argument("--coloradj", type=str, default='simple')
    parser.add_argument("--dislevel", type=int, default=0)
    parser.add_argument("--dgmixing", type=float, default=0.7)
    parser.add_argument("--nsample", type=int, default=4)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--save", type=str, default='log',
                        help="The directory or file the result will be saved to")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--fp16", action="store_true")

    args = parser.parse_args()

    assert args.app in [
        "text-to-image", "image-variation",
        "image-to-text", "text-variation",
        "disentanglement", "dual-guided", "i2t2i"], \
        "Unknown app! Select from [text-to-image, image-variation, "\
        "image-to-text, text-variation, "\
        "disentanglement, dual-guided, i2t2i]"

    device = args.gpu if torch.cuda.is_available() else 'cpu'

    if args.model in ['4-flow', 'official']:
        if args.fp16:
            pth = 'pretrained/vd-four-flow-v1-0-fp16.pth'
        else:
            pth = 'pretrained/vd-four-flow-v1-0.pth'
        vd_wrapper = vd_inference(pth=pth, fp16=args.fp16, device=device)
    elif args.model in ['2-flow', 'dc']:
        raise NotImplementedError
    elif args.model in ['1-flow', 'basic']:
        raise NotImplementedError
    else:
        assert False, "No such model! Select model from [4-flow(official), 2-flow(dc), 1-flow(basic)]"

    imout, txtout = main(
        netwrapper=vd_wrapper,
        app=args.app,
        image=args.image,
        prompt=args.prompt,
        nprompt=args.nprompt,
        pprompt=args.pprompt,
        color_adj=args.coloradj,
        disentanglement_level=args.dislevel,
        dual_guided_mixing=args.dgmixing,
        n_samples=args.nsample,
        seed=args.seed,)

    if imout is not None:
        imout = auto_merge_imlist([np.array(i) for i in imout])
        imout = PIL.Image.fromarray(imout)
        if osp.isdir(args.save):
            imout.save(osp.join(args.save, 'imout.png'))
            print('Output image saved to {}.'.format(osp.join(args.save, 'imout.png')))
        else:
            imout.save(args.save)
            print('Output image saved to {}.'.format(args.save))

    if txtout is not None:
        print(txtout)