|
import torch |
|
from ldm.models.diffusion.ddim import DDIMSampler |
|
from ldm.models.diffusion.plms import PLMSSampler |
|
from ldm.util import instantiate_from_config |
|
import numpy as np |
|
import random |
|
import time |
|
from dataset.concat_dataset import ConCatDataset |
|
from torch.utils.data.distributed import DistributedSampler |
|
from torch.utils.data import DataLoader |
|
from torch.utils.tensorboard import SummaryWriter |
|
import os |
|
import shutil |
|
import torchvision |
|
import math |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from tqdm import tqdm |
|
from distributed import get_rank, synchronize, get_world_size |
|
from transformers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup |
|
from copy import deepcopy |
|
try: |
|
from apex import amp |
|
except: |
|
pass |
|
|
|
|
|
class ImageCaptionSaver:
    """Persists visual-inspection artifacts during training: a grid of
    generated samples, a grid of the corresponding real images, and the
    captions, all keyed by the number of images seen so far."""

    def __init__(self, base_path, nrow=8, normalize=True, scale_each=True, range=(-1,1) ):
        # NOTE(review): `range` shadows the builtin but is kept for caller
        # compatibility; it is forwarded to torchvision's save_image.
        self.base_path = base_path
        self.nrow = nrow
        self.normalize = normalize
        self.scale_each = scale_each
        self.range = range

    def __call__(self, images, real, captions, seen):
        """Write <seen>.png (samples), <seen>_real.png and append captions.

        `seen` is zero-padded to 8 digits so files sort chronologically.
        """
        tag = str(seen).zfill(8)

        torchvision.utils.save_image(
            images,
            os.path.join(self.base_path, tag + '.png'),
            nrow=self.nrow,
            normalize=self.normalize,
            scale_each=self.scale_each,
            range=self.range,
        )

        torchvision.utils.save_image(
            real,
            os.path.join(self.base_path, tag + '_real.png'),
            nrow=self.nrow,
        )

        # One caption per sample image is required.
        assert images.shape[0] == len(captions)

        # Captions accumulate in a single append-mode text file.
        with open(os.path.join(self.base_path, 'captions.txt'), "a") as f:
            f.write(tag + ':\n')
            for cap in captions:
                f.write(cap + '\n')
            f.write('\n')
|
|
|
|
|
|
|
def read_official_ckpt(ckpt_path):
    """Load the official pretrained checkpoint and split its state_dict
    into the per-module buckets used by this codebase.

    Returns a dict with keys:
        "model"        -- UNet weights (``model.diffusion_model.`` stripped)
        "text_encoder" -- ``cond_stage_model.`` weights
        "autoencoder"  -- ``first_stage_model.`` weights
        "unexpected"   -- EMA bookkeeping scalars we don't consume
        "diffusion"    -- everything else (schedule buffers etc.)
    """
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    out = {
        "model": {},
        "text_encoder": {},
        "autoencoder": {},
        "unexpected": {},
        "diffusion": {},
    }

    # (prefix to detect, prefix to strip, destination bucket) — checked in order.
    prefix_table = [
        ("model.diffusion_model", "model.diffusion_model.", "model"),
        ("cond_stage_model", "cond_stage_model.", "text_encoder"),
        ("first_stage_model", "first_stage_model.", "autoencoder"),
    ]

    for k, v in state_dict.items():
        for detect, strip, bucket in prefix_table:
            if k.startswith(detect):
                out[bucket][k.replace(strip, "")] = v
                break
        else:
            if k in ["model_ema.decay", "model_ema.num_updates"]:
                out["unexpected"][k] = v
            else:
                out["diffusion"][k] = v
    return out
|
|
|
|
|
def batch_to_device(batch, device):
    """Move every tensor-valued entry of `batch` onto `device`, in place.

    Non-tensor values (e.g. caption strings) are left untouched.
    Returns the same dict for call-chaining convenience.
    """
    tensor_keys = [key for key in batch if isinstance(batch[key], torch.Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].to(device)
    return batch
|
|
|
|
|
def sub_batch(batch, num=1):
    """Truncate every entry of `batch` to its first `num` items, in place.

    `num` is clamped to a minimum of 1 so the batch never becomes empty.
    Returns the same (mutated) dict.
    """
    keep = max(num, 1)
    for key in batch:
        batch[key] = batch[key][0:keep]
    return batch
|
|
|
|
|
def wrap_loader(loader):
    """Turn a finite (re-iterable) loader into an endless batch generator.

    Each pass re-iterates `loader` from the start, so a DataLoader's
    shuffling still takes effect every epoch.
    """
    while True:
        yield from loader
|
|
|
|
|
def disable_grads(model):
    """Freeze `model` by turning off gradient tracking on all parameters."""
    for param in model.parameters():
        param.requires_grad = False
|
|
|
|
|
def count_params(params):
    """Print and return the total number of elements across `params`.

    Args:
        params: iterable of tensors (typically the trainable parameters).

    Returns:
        int: the total element count. Returning it (instead of only
        printing) is backward-compatible — existing callers ignore the
        previous implicit ``None`` — and lets callers log it themselves.
    """
    total_trainable_params_count = sum(p.numel() for p in params)
    print("total_trainable_params_count is: ", total_trainable_params_count)
    return total_trainable_params_count
|
|
|
|
|
def update_ema(target_params, source_params, rate=0.99):
    """In-place exponential moving average:

        target <- rate * target + (1 - rate) * source

    Operates pairwise over the two parameter sequences; gradients are not
    tracked (`detach`) since EMA weights are never trained directly.
    """
    for ema_param, live_param in zip(target_params, source_params):
        ema_param.detach().mul_(rate).add_(live_param, alpha=1 - rate)
|
|
|
|
|
def create_expt_folder_with_auto_resuming(OUTPUT_ROOT, name):
    """Create a fresh 'tagNN' run folder under OUTPUT_ROOT/name.

    If earlier tag folders exist, scan them newest-first for a
    'checkpoint_latest.pth' so the caller can auto-resume, then allocate
    the next tag index. The TensorBoard writer and the directories are
    created on rank 0 only.

    Returns:
        (run_folder_path, SummaryWriter or None, checkpoint_path or None)
    """
    name = os.path.join( OUTPUT_ROOT, name )
    writer = None
    checkpoint = None

    if os.path.exists(name):
        # Tag names are zero-padded ('tag00', 'tag01', ...) so a plain
        # descending string sort yields newest-first.
        previous_tags = sorted(
            (tag for tag in os.listdir(name) if tag.startswith('tag')),
            reverse=True,
        )
        for previous_tag in previous_tags:
            potential_ckpt = os.path.join( name, previous_tag, 'checkpoint_latest.pth' )
            if os.path.exists(potential_ckpt):
                checkpoint = potential_ckpt
                if get_rank() == 0:
                    print('ckpt found '+ potential_ckpt)
                break
        # Next free tag index is simply the count of existing tags.
        name = os.path.join( name, 'tag'+str(len(previous_tags)).zfill(2) )
    else:
        name = os.path.join( name, 'tag00' )

    if get_rank() == 0:
        os.makedirs(name)
        os.makedirs( os.path.join(name,'Log') )
        writer = SummaryWriter( os.path.join(name,'Log') )

    return name, writer, checkpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Trainer:
    """End-to-end training driver.

    Instantiates the UNet / autoencoder / text-encoder / diffusion stack
    from config, loads the official pretrained weights, selects trainable
    parameters, wires up data loading, mixed precision, (optional) EMA and
    DDP, and runs the iteration-based training loop with periodic logging,
    sampling and checkpointing.
    """

    def __init__(self, config):
        # `config` is a namespace-like object; attributes are read directly.
        # Training assumes CUDA is available.
        self.config = config
        self.device = torch.device("cuda")

        # Weight on the simple (noise-MSE) diffusion objective.
        self.l_simple_weight = 1
        # `checkpoint` is a path to checkpoint_latest.pth if a previous run
        # exists under the same experiment name, else None.
        self.name, self.writer, checkpoint = create_expt_folder_with_auto_resuming(config.OUTPUT_ROOT, config.name)
        if get_rank() == 0:
            # Snapshot the configuration next to the experiment outputs.
            shutil.copyfile(config.yaml_file, os.path.join(self.name, "train_config_file.yaml") )
            torch.save( vars(config), os.path.join(self.name, "config_dict.pth") )

        self.model = instantiate_from_config(config.model).to(self.device)
        self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
        self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
        self.diffusion = instantiate_from_config(config.diffusion).to(self.device)

        # Load official pretrained weights. strict=False because this model
        # adds new components not present in the official UNet; missing keys
        # are expected, unexpected keys are a hard error.
        state_dict = read_official_ckpt( os.path.join(config.DATA_ROOT, config.official_ckpt_name) )
        missing_keys, unexpected_keys = self.model.load_state_dict( state_dict["model"], strict=False )
        assert unexpected_keys == []
        # Names coming from the official checkpoint, used below to verify
        # the trainable/pretrained split covers the whole model.
        original_params_names = list( state_dict["model"].keys() )
        self.autoencoder.load_state_dict( state_dict["autoencoder"] )
        self.text_encoder.load_state_dict( state_dict["text_encoder"] )
        self.diffusion.load_state_dict( state_dict["diffusion"] )

        # Autoencoder and text encoder are frozen feature extractors.
        self.autoencoder.eval()
        self.text_encoder.eval()
        disable_grads(self.autoencoder)
        disable_grads(self.text_encoder)

        # Optionally warm-start the UNet from a user-supplied checkpoint
        # (only the "model" entry is used here).
        if self.config.ckpt is not None:
            first_stage_ckpt = torch.load(self.config.ckpt, map_location="cpu")
            self.model.load_state_dict(first_stage_ckpt["model"])

        print(" ")
        print("IMPORTANT: following code decides which params trainable!")
        print(" ")

        if self.config.whole:
            print("Entire model is trainable")
            params = list(self.model.parameters())
        else:
            # Only newly-added components are trained: the "fuser" layers
            # inside transformer blocks, and "position_net".
            print("Only new added components will be updated")
            params = []
            trainable_names = []
            for name, p in self.model.named_parameters():
                if ("transformer_blocks" in name) and ("fuser" in name):
                    params.append(p)
                    trainable_names.append(name)
                elif  "position_net" in name:
                    params.append(p)
                    trainable_names.append(name)
                else:
                    # Any non-new parameter must come from the official ckpt.
                    assert name in original_params_names, name

            # Sanity check: trainable + pretrained names exactly cover the
            # model's state_dict.
            all_params_name = list( self.model.state_dict().keys() )
            assert set(all_params_name) == set(trainable_names + original_params_names)

        self.opt = torch.optim.AdamW(params, lr=config.base_learning_rate, weight_decay=config.weight_decay)
        count_params(params)

        # EMA mirrors the FULL model (not just the trainable subset), so
        # master_params covers all parameters.
        self.master_params = list(self.model.parameters())

        if config.enable_ema:
            self.ema = deepcopy(self.model)
            self.ema_params = list(self.ema.parameters())
            self.ema.eval()

        # Learning-rate schedule with linear warmup.
        if config.scheduler_type == "cosine":
            self.scheduler = get_cosine_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps, num_training_steps=config.total_iters)
        elif config.scheduler_type == "constant":
            self.scheduler = get_constant_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps)
        else:
            assert False

        # Data: concatenated datasets, optional DistributedSampler, and an
        # infinite wrapper since training is iteration- not epoch-based.
        train_dataset_repeats = config.train_dataset_repeats if 'train_dataset_repeats' in config else None
        dataset_train = ConCatDataset(config.train_dataset_names, config.DATA_ROOT, config.which_embedder, train=True, repeats=train_dataset_repeats)
        sampler = DistributedSampler(dataset_train) if config.distributed else None
        loader_train = DataLoader( dataset_train, batch_size=config.batch_size,
                                   shuffle=(sampler is None),
                                   num_workers=config.workers,
                                   pin_memory=True,
                                   sampler=sampler)
        self.dataset_train = dataset_train
        self.loader_train = wrap_loader(loader_train)

        if get_rank() == 0:
            total_image = dataset_train.total_images()
            print("Total training images: ", total_image)

        # Auto-resume: restore model/EMA/optimizer/scheduler state and the
        # iteration counter from the latest checkpoint, if one was found.
        self.starting_iter = 0
        if checkpoint is not None:
            checkpoint = torch.load(checkpoint, map_location="cpu")
            self.model.load_state_dict(checkpoint["model"])
            if config.enable_ema:
                self.ema.load_state_dict(checkpoint["ema"])
            self.opt.load_state_dict(checkpoint["opt"])
            self.scheduler.load_state_dict(checkpoint["scheduler"])
            self.starting_iter = checkpoint["iters"]
            if self.starting_iter >= config.total_iters:
                # Nothing left to train; exit after syncing all ranks.
                synchronize()
                print("Training finished. Start exiting")
                exit()

        if get_rank() == 0:
            print("Actual total need see images is: ", config.total_iters*config.total_batch_size)
            print("Equivalent training epoch is: ", (config.total_iters*config.total_batch_size) / len(dataset_train) )
            # Sample/caption saver exists on rank 0 only — callers must
            # guard on rank before using it.
            self.image_caption_saver = ImageCaptionSaver(self.name)

        # Apex O2 mixed precision rewraps model + optimizer; must happen
        # BEFORE DDP wrapping below.
        if config.use_o2:
            self.model, self.opt = amp.initialize(self.model, self.opt, opt_level="O2")
            self.model.use_o2 = True

        if config.distributed:
            self.model = DDP( self.model, device_ids=[config.local_rank], output_device=config.local_rank, broadcast_buffers=False )

    @torch.no_grad()
    def get_input(self, batch):
        """Encode a raw batch into (latent, timestep, text context).

        Timesteps are drawn via a power-law resampling of U(0,1) controlled
        by config.resample_step_gamma, scaled to [0, 1000] and clamped so
        the maximum index is 999.
        """
        z = self.autoencoder.encode( batch["image"] )

        context = self.text_encoder.encode( batch["caption"] )

        _t = torch.rand(z.shape[0]).to(z.device)
        t = (torch.pow(_t, self.config.resample_step_gamma) * 1000).long()
        # t == 1000 would be out of range for a 1000-step schedule.
        t = torch.where(t!=1000, t, 999)

        return z, t, context

    def run_one_step(self, batch):
        """One training step's forward pass; returns the weighted MSE loss
        between the model's prediction and the injected noise."""
        x_start, t, context = self.get_input(batch)
        noise = torch.randn_like(x_start)
        x_noisy = self.diffusion.q_sample(x_start=x_start, t=t, noise=noise)

        # Grounding inputs (boxes / masks / embeddings) ride along with the
        # standard (noisy latent, timestep, text context) triple.
        input = dict(x = x_noisy,
                     timesteps = t,
                     context = context,
                     boxes = batch['boxes'],
                     masks = batch['masks'],
                     text_masks = batch['text_masks'],
                     image_masks = batch['image_masks'],
                     text_embeddings = batch["text_embeddings"],
                     image_embeddings = batch["image_embeddings"] )
        model_output = self.model(input)

        loss = torch.nn.functional.mse_loss(model_output, noise) * self.l_simple_weight

        # Stashed as plain floats for log_loss().
        self.loss_dict = {"loss": loss.item()}

        return loss

    def start_training(self):
        """Main loop from self.starting_iter to config.total_iters.

        Two mixed-precision paths: apex O2 (config.use_o2) or native
        torch.cuda.amp (autocast enabled iff config.use_mixed).
        """
        if not self.config.use_o2:
            # GradScaler is used for the native-AMP path; when use_mixed is
            # False, autocast below is disabled and scaling is a no-op.
            scaler = torch.cuda.amp.GradScaler()

        iterator = tqdm(range(self.starting_iter, self.config.total_iters), desc='Training progress', disable=get_rank() != 0 )
        self.model.train()
        for iter_idx in iterator:
            self.iter_idx = iter_idx

            self.opt.zero_grad()
            batch = next(self.loader_train)
            batch_to_device(batch, self.device)

            if self.config.use_o2:
                # Apex handles the loss scaling.
                loss = self.run_one_step(batch)
                with amp.scale_loss(loss, self.opt) as scaled_loss:
                    scaled_loss.backward()
                self.opt.step()
            else:
                # Native AMP path.
                enabled = True if self.config.use_mixed else False
                with torch.cuda.amp.autocast(enabled=enabled):
                    loss = self.run_one_step(batch)
                scaler.scale(loss).backward()
                scaler.step(self.opt)
                scaler.update()

            self.scheduler.step()

            if self.config.enable_ema:
                update_ema(self.ema_params, self.master_params, self.config.ema_rate)

            # Rank 0 does all logging and checkpointing; every rank then
            # syncs so checkpoint writes can't race ahead of training.
            if (get_rank() == 0):
                if (iter_idx % 10 == 0):
                    self.log_loss()
                if (iter_idx == 0) or ( iter_idx % self.config.save_every_iters == 0 ) or (iter_idx == self.config.total_iters-1):
                    self.save_ckpt_and_result()
            synchronize()

        synchronize()
        print("Training finished. Start exiting")
        exit()

    def log_loss(self):
        # TensorBoard scalar logging; called on rank 0 every 10 iterations.
        for k, v in self.loss_dict.items():
            self.writer.add_scalar( k, v, self.iter_idx+1 )

    @torch.no_grad()
    def save_ckpt_and_result(self):
        """Run a sampling pass for visual inspection, then save checkpoints.

        Called on rank 0 only (guarded by the caller). Writes both a
        numbered checkpoint and checkpoint_latest.pth (the file the
        auto-resume logic looks for).
        """
        # Unwrap DDP so state_dict keys are not prefixed with "module.".
        model_wo_wrapper = self.model.module if self.config.distributed else self.model

        iter_name = self.iter_idx + 1

        if not self.config.disable_inference_in_training:

            # Sample from a training batch so generated images can be
            # compared against their grounded real counterparts.
            batch_here = self.config.batch_size
            batch = sub_batch( next(self.loader_train), batch_here)
            batch_to_device(batch, self.device)

            # Draw the grounding boxes onto the real images via the first
            # sub-dataset's visualization helper.
            real_images_with_box_drawing = []
            for i in range(batch_here):
                temp_data = {"image": batch["image"][i], "boxes":batch["boxes"][i]}
                im = self.dataset_train.datasets[0].vis_getitem_data(out=temp_data, return_tensor=True, print_caption=False)
                real_images_with_box_drawing.append(im)
            real_images_with_box_drawing = torch.stack(real_images_with_box_drawing)

            # Unconditional context for classifier-free guidance.
            uc = self.text_encoder.encode( batch_here*[""] )
            context = self.text_encoder.encode( batch["caption"] )

            # NOTE(review): despite the variable name, this is a PLMS
            # sampler, not DDIM.
            ddim_sampler = PLMSSampler(self.diffusion, model_wo_wrapper)
            shape = (batch_here, model_wo_wrapper.in_channels, model_wo_wrapper.image_size, model_wo_wrapper.image_size)
            # x/timesteps are None: the sampler drives the denoising loop.
            input = dict( x = None,
                          timesteps = None,
                          context = context,
                          boxes = batch['boxes'],
                          masks = batch['masks'],
                          text_masks = batch['text_masks'],
                          image_masks = batch['image_masks'],
                          text_embeddings = batch["text_embeddings"],
                          image_embeddings = batch["image_embeddings"] )
            samples = ddim_sampler.sample(S=50, shape=shape, input=input, uc=uc, guidance_scale=5)

            # Decode latents back to image space on CPU for saving.
            autoencoder_wo_wrapper = self.autoencoder
            samples = autoencoder_wo_wrapper.decode(samples).cpu()

            self.image_caption_saver(samples, real_images_with_box_drawing, batch["caption"], iter_name)

        ckpt = dict(model = model_wo_wrapper.state_dict(),
                    opt = self.opt.state_dict(),
                    scheduler= self.scheduler.state_dict(),
                    iters = self.iter_idx+1 )
        if self.config.enable_ema:
            ckpt["ema"] = self.ema.state_dict()
        torch.save( ckpt, os.path.join(self.name, "checkpoint_"+str(iter_name).zfill(8)+".pth") )
        torch.save( ckpt, os.path.join(self.name, "checkpoint_latest.pth") )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|