import argparse
import os
import random
from copy import deepcopy
from functools import partial

import numpy as np
import torch
import torchvision.transforms.functional as F
from omegaconf import OmegaConf
from PIL import Image, ImageDraw
from transformers import CLIPProcessor, CLIPModel
import clip

from evaluator import Evaluator, set_alpha_scale, save_images, draw_masks_from_boxes
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
from trainer import read_official_ckpt, batch_to_device


device = "cuda"


def alpha_generator(length, type=[1, 0, 0]):
    """
    length is the total number of timesteps needed for sampling.
    type should be a list of three values whose sum is 1.

    It gives the fraction of steps spent in each of three stages:
        alpha=1 stage
        linear decay stage
        alpha=0 stage

    For example, if length=100 and type=[0.8, 0.1, 0.1], then alpha is 1 for the
    first 80 steps, linearly decays to 0 over the next 10 steps, and is 0 for the
    last 10 steps.
    """
    assert len(type) == 3
    assert abs(type[0] + type[1] + type[2] - 1) < 1e-6  # tolerate float rounding

    stage0_length = int(type[0] * length)
    stage1_length = int(type[1] * length)
    stage2_length = length - stage0_length - stage1_length

    if stage1_length != 0:
        decay_alphas = np.arange(start=0, stop=1, step=1 / stage1_length)[::-1]
        decay_alphas = list(decay_alphas)
    else:
        decay_alphas = []

    alphas = [1] * stage0_length + decay_alphas + [0] * stage2_length

    assert len(alphas) == length
    return alphas
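

# Usage sketch for alpha_generator (values are illustrative, not from any config):
#
#   alphas = alpha_generator(10, type=[0.5, 0.3, 0.2])
#   # -> [1, 1, 1, 1, 1, 0.666..., 0.333..., 0.0, 0, 0]
#
# The resulting schedule is consumed step by step by the sampler through
# alpha_generator_func / set_alpha_scale (see grounded_generation_box below).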


def draw_box(img, locations):
    # Draw each normalized box on the PIL image, cycling through a fixed color list.
    colors = ["red", "green", "blue", "olive", "orange", "brown", "cyan", "purple"]
    draw = ImageDraw.Draw(img)
    WW, HH = img.size
    for bid, box in enumerate(locations):
        draw.rectangle([box[0] * WW, box[1] * HH, box[2] * WW, box[3] * HH],
                       outline=colors[bid % len(colors)], width=5)
    return img


def load_common_ckpt(config, common_ckpt):
    # Instantiate the shared ("common") modules and load their weights.
    autoencoder = instantiate_from_config(config.autoencoder).to(device).eval()
    text_encoder = instantiate_from_config(config.text_encoder).to(device).eval()
    diffusion = instantiate_from_config(config.diffusion).to(device)

    autoencoder.load_state_dict(common_ckpt["autoencoder"])
    text_encoder.load_state_dict(common_ckpt["text_encoder"])
    diffusion.load_state_dict(common_ckpt["diffusion"])

    return [autoencoder, text_encoder, diffusion]


def load_ckpt(config, state_dict, common_instances):
    # Instantiate the model, load its weights, and bundle it with the
    # already-loaded common modules.
    model = instantiate_from_config(config.model).to(device).eval()

    model.load_state_dict(state_dict['model'])
    set_alpha_scale(model, config.alpha_scale)
    print("ckpt is loaded")

    return [model] + common_instances
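

# Loading sketch (the config and checkpoint paths below are illustrative, not
# defined in this file):
#
#   config = OmegaConf.load("configs/generation_box.yaml")
#   common_ckpt = torch.load("common.pth", map_location="cpu")
#   ckpt = torch.load("checkpoint_generation_text.pth", map_location="cpu")
#
#   common_instances = load_common_ckpt(config, common_ckpt)
#   loaded_model_list = load_ckpt(config, ckpt, common_instances)
#   # -> [model, autoencoder, text_encoder, diffusion], the order expected by
#   #    grounded_generation_box below.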


def project(x, projection_matrix):
    """
    x (Batch*768) should be the penultimate feature of CLIP (before projection).
    projection_matrix (768*768) is the CLIP projection matrix, which should be the
    weight.data of the Linear layer defined in CLIP (out_dim, in_dim); thus we need
    to apply the transpose below.
    This function returns the CLIP feature (without normalization).
    """
    return x @ torch.transpose(projection_matrix, 0, 1)
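

# Shape sketch for project (random tensors, purely illustrative):
#
#   x = torch.randn(2, 768)      # penultimate CLIP features
#   W = torch.randn(768, 768)    # Linear weight, stored as (out_dim, in_dim)
#   project(x, W).shape          # -> torch.Size([2, 768])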


def get_clip_feature(model, processor, input, is_image=False):
    # Which CLIP feature to use: the penultimate text feature ('before' projection)
    # and the re-projected, re-normalized image feature ('after_reproject').
    feature_type = ['before', 'after_reproject']

    if is_image:
        image = input
        inputs = processor(images=[image], return_tensors="pt", padding=True)
        inputs['pixel_values'] = inputs['pixel_values'].cuda()
        inputs['input_ids'] = torch.tensor([[0, 1, 2, 3]]).cuda()  # dummy text tokens
        outputs = model(**inputs)
        feature = outputs.image_embeds
        if feature_type[1] == 'after_renorm':
            feature = feature * 28.7
        if feature_type[1] == 'after_reproject':
            feature = project(feature, torch.load('gligen/projection_matrix.pth').cuda().T).squeeze(0)
            feature = (feature / feature.norm()) * 28.7
            feature = feature.unsqueeze(0)
    else:
        inputs = processor(text=input, return_tensors="pt", padding=True)
        inputs['input_ids'] = inputs['input_ids'].cuda()
        inputs['pixel_values'] = torch.ones(1, 3, 224, 224).cuda()  # dummy image
        inputs['attention_mask'] = inputs['attention_mask'].cuda()
        outputs = model(**inputs)
        feature = outputs.text_embeds if feature_type[0] == 'after' else outputs.text_model_output.pooler_output
    return feature
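

# Usage sketch for get_clip_feature (the variable names and image path are
# illustrative; any PIL image works):
#
#   version = "openai/clip-vit-large-patch14"
#   clip_model = CLIPModel.from_pretrained(version).cuda()
#   clip_processor = CLIPProcessor.from_pretrained(version)
#
#   text_feat = get_clip_feature(clip_model, clip_processor, "a red apple", is_image=False)
#   img_feat = get_clip_feature(clip_model, clip_processor, Image.open("reference.jpg"), is_image=True)
#   # both are (1, 768) tensors, matching the 768-dim slots filled in fire_clip below.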


def complete_mask(has_mask, max_objs):
    # Broadcast a scalar flag to all object slots, or copy a per-object list of flags.
    mask = torch.ones(1, max_objs)
    if type(has_mask) == int or type(has_mask) == float:
        return mask * has_mask
    else:
        for idx, value in enumerate(has_mask):
            mask[0, idx] = value
        return mask


@torch.no_grad()
def fire_clip(text_encoder, meta, batch=1, max_objs=30, clip_model=None):
    # Build the grounding inputs (boxes, masks, and per-object CLIP embeddings)
    # for up to max_objs objects, then repeat them along the batch dimension.
    phrases = meta["phrases"]
    images = meta["images"]

    if clip_model is None:
        version = "openai/clip-vit-large-patch14"
        model = CLIPModel.from_pretrained(version).cuda()
        processor = CLIPProcessor.from_pretrained(version)
    else:
        version = "openai/clip-vit-large-patch14"
        assert clip_model['version'] == version
        model = clip_model['model']
        processor = clip_model['processor']

    boxes = torch.zeros(max_objs, 4)
    masks = torch.zeros(max_objs)
    text_embeddings = torch.zeros(max_objs, 768)
    image_embeddings = torch.zeros(max_objs, 768)

    text_features = []
    image_features = []
    for phrase, image in zip(phrases, images):
        text_features.append(get_clip_feature(model, processor, phrase, is_image=False))
        image_features.append(get_clip_feature(model, processor, image, is_image=True))

    if len(text_features) > 0:
        text_features = torch.cat(text_features, dim=0)
        image_features = torch.cat(image_features, dim=0)

        for idx, (box, text_feature, image_feature) in enumerate(zip(meta['locations'], text_features, image_features)):
            boxes[idx] = torch.tensor(box)
            masks[idx] = 1
            text_embeddings[idx] = text_feature
            image_embeddings[idx] = image_feature

    out = {
        "boxes": boxes.unsqueeze(0).repeat(batch, 1, 1),
        "masks": masks.unsqueeze(0).repeat(batch, 1),
        "text_masks": masks.unsqueeze(0).repeat(batch, 1) * complete_mask(meta["has_text_mask"], max_objs),
        "image_masks": masks.unsqueeze(0).repeat(batch, 1) * complete_mask(meta["has_image_mask"], max_objs),
        "text_embeddings": text_embeddings.unsqueeze(0).repeat(batch, 1, 1),
        "image_embeddings": image_embeddings.unsqueeze(0).repeat(batch, 1, 1),
    }
    return batch_to_device(out, device)
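

# Sketch of the `meta` fields read by fire_clip (values and image paths are illustrative):
#
#   meta = dict(
#       phrases=["a corgi", "a birthday cake"],                      # one phrase per grounded object
#       images=[Image.open("corgi.jpg"), Image.open("cake.jpg")],    # reference images for image grounding
#       locations=[[0.1, 0.4, 0.5, 0.9], [0.6, 0.5, 0.95, 0.9]],     # normalized xyxy boxes
#       has_text_mask=1,                                             # scalar or per-object list
#       has_image_mask=0,                                            # scalar or per-object list
#   )
#   batch = fire_clip(text_encoder, meta, batch=2)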


@torch.no_grad()
def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs):

    model, autoencoder, text_encoder, diffusion = loaded_model_list
    batch_size = instruction["batch_size"]
    is_inpaint = "input_image" in instruction
    save_folder = os.path.join("create_samples", instruction["save_folder_name"])

    if instruction.get('fix_seed', False):
        random_seed = instruction['rand_seed']
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    # Grounding tokens (boxes + CLIP embeddings), text context, and unconditional context.
    batch = fire_clip(text_encoder, instruction, batch_size, clip_model=kwargs.get('clip_model', None))
    context = text_encoder.encode([instruction["prompt"]] * batch_size)
    uc = text_encoder.encode(batch_size * [""])

    input = dict(x=None,
                 timesteps=None,
                 context=context,
                 boxes=batch['boxes'],
                 masks=batch['masks'],
                 text_masks=batch['text_masks'],
                 image_masks=batch['image_masks'],
                 text_embeddings=batch["text_embeddings"],
                 image_embeddings=batch["image_embeddings"])

    # Inpainting: encode the input image to latents and build a latent-space mask,
    # either from a user-provided mask or from the grounding boxes.
    inpainting_mask = x0 = None
    if is_inpaint:
        input_image = F.pil_to_tensor(instruction["input_image"])
        input_image = (input_image.float().unsqueeze(0).cuda() / 255 - 0.5) / 0.5
        x0 = autoencoder.encode(input_image)
        if instruction["actual_mask"] is not None:
            inpainting_mask = instruction["actual_mask"][None, None].expand(batch['boxes'].shape[0], -1, -1, -1).cuda()
        else:
            actual_boxes = [instruction['inpainting_boxes_nodrop'] for _ in range(batch['boxes'].shape[0])]
            inpainting_mask = draw_masks_from_boxes(actual_boxes, (x0.shape[-2], x0.shape[-1])).cuda()

        masked_x0 = x0 * inpainting_mask
        inpainting_extra_input = torch.cat([masked_x0, inpainting_mask], dim=1)
        input["inpainting_extra_input"] = inpainting_extra_input

    # Sampler setup: alpha follows the three-stage schedule from alpha_generator.
    # PLMS with 50 steps is used by default; flip the condition to use DDIM with 250 steps.
    alpha_generator_func = partial(alpha_generator, type=instruction["alpha_type"])
    if False:
        sampler = DDIMSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
        steps = 250
    else:
        sampler = PLMSSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
        steps = 50

    shape = (batch_size, model.in_channels, model.image_size, model.image_size)
    samples_fake = sampler.sample(S=steps, shape=shape, input=input, uc=uc, guidance_scale=instruction['guidance_scale'], mask=inpainting_mask, x0=x0)
    samples_fake = autoencoder.decode(samples_fake)

    # Convert decoded samples from [-1, 1] to PIL images.
    sample_list = []
    for sample in samples_fake:
        sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
        sample = sample.cpu().numpy().transpose(1, 2, 0) * 255
        sample = Image.fromarray(sample.astype(np.uint8))
        sample_list.append(sample)

    return sample_list, None
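

# End-to-end sketch (all names, paths, and values here are illustrative; they are
# not defined in this file):
#
#   instruction = dict(
#       prompt="a corgi next to a birthday cake",
#       phrases=["a corgi", "a birthday cake"],
#       images=[Image.open("corgi.jpg"), Image.open("cake.jpg")],
#       locations=[[0.1, 0.4, 0.5, 0.9], [0.6, 0.5, 0.95, 0.9]],
#       has_text_mask=1,
#       has_image_mask=0,
#       batch_size=2,
#       save_folder_name="demo",
#       alpha_type=[0.3, 0.0, 0.7],
#       guidance_scale=7.5,
#       fix_seed=True,
#       rand_seed=123,
#   )
#   samples, _ = grounded_generation_box(loaded_model_list, instruction)
#   for i, im in enumerate(samples):
#       im.save(f"sample_{i}.png")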