import os.path
import random
from collections import deque

import torch
from accelerate.utils import set_seed
from diffusers import StableDiffusionPipeline
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
from isolate_rng import isolate_rng
from memory_efficiency import MemoryEfficiencyWrapper

# Checked at the top of every training step; when truthy, train() stops and
# returns None.
training_should_cancel = False


def validate(diffuser: StableDiffuser,
             finetuner: FineTunedModel,
             validation_embeddings: torch.FloatTensor,
             neutral_embeddings: torch.FloatTensor,
             sample_embeddings: torch.FloatTensor,
             logger: SummaryWriter,
             use_amp: bool,
             global_step: int,
             validation_seed: int = 555,
             ):
    """Log validation losses and sample images for the current fine-tune state.

    For every consecutive pair of rows in ``validation_embeddings`` the same
    loss as in ``train`` is averaged over a few randomly chosen diffusion
    steps and written to ``loss/val_<i>``. Then, for every consecutive pair of
    rows in ``sample_embeddings``, one image is generated with the fine-tuned
    weights active and written to ``samples/<i>``.

    Args:
        diffuser: model wrapper providing ``predict_noise`` and the pipeline
            components (vae, text_encoder, tokenizer, unet, scheduler).
        finetuner: context manager; the fine-tuned prediction is made inside
            ``with finetuner``.
        validation_embeddings: rows are consumed two at a time (rows 2i and
            2i+1 form prompt i) -- presumably an (uncond, cond) pair as
            produced by ``get_cond_and_uncond_embeddings``; TODO confirm.
        neutral_embeddings: embeddings used for the "neutral" prediction.
        sample_embeddings: rows are consumed two at a time; row 2i is passed
            as ``negative_prompt_embeds`` and row 2i+1 as ``prompt_embeds``.
        logger: TensorBoard writer that receives scalars and images.
        use_amp: enables ``torch.cuda.amp.autocast`` around noise prediction.
        global_step: step value attached to everything that is logged.
        validation_seed: seed applied inside the isolated RNG scope.
    """
    print("validating...")
    # isolate_rng presumably restores the RNG state on exit so that seeding
    # here does not disturb the training run -- verify in isolate_rng.
    with isolate_rng(include_cuda=True), torch.no_grad():
        set_seed(validation_seed)
        criteria = torch.nn.MSELoss()
        negative_guidance = 1
        val_count = 5
        nsteps = 50

        num_validation_prompts = validation_embeddings.shape[0] // 2
        for i in range(num_validation_prompts):
            accumulated_loss = 0.0
            this_validation_embeddings = validation_embeddings[i * 2:i * 2 + 2]
            for _ in range(val_count):
                # Same range as train() draws with torch.randint(1, nsteps - 1):
                # 1 .. nsteps - 2 inclusive.
                step = random.randint(1, nsteps - 2)
                diffused_latents = get_diffused_latents(
                    diffuser, nsteps, this_validation_embeddings, step, use_amp)
                # get_diffused_latents leaves the scheduler at 1000 timesteps,
                # so the nsteps-based index must be rescaled before it is used
                # by predict_noise (train() does the same conversion).
                iteration = int(step / nsteps * 1000)

                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(
                        iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(
                        iteration, diffused_latents, neutral_embeddings, guidance_scale=1)

                with finetuner, autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(
                        iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)

                loss = criteria(
                    negative_latents,
                    neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))
                accumulated_loss += loss.item()

            logger.add_scalar(f"loss/val_{i}", accumulated_loss / val_count, global_step=global_step)

        num_samples = sample_embeddings.shape[0] // 2
        for i in range(num_samples):
            print(f'making sample {i}...')
            with finetuner:
                pipeline = StableDiffusionPipeline(vae=diffuser.vae,
                                                   text_encoder=diffuser.text_encoder,
                                                   tokenizer=diffuser.tokenizer,
                                                   unet=diffuser.unet,
                                                   scheduler=diffuser.scheduler,
                                                   safety_checker=None,
                                                   feature_extractor=None,
                                                   requires_safety_checker=False)
                images = pipeline(prompt_embeds=sample_embeddings[i * 2 + 1:i * 2 + 2],
                                  negative_prompt_embeds=sample_embeddings[i * 2:i * 2 + 1],
                                  num_inference_steps=50)
            image_tensor = transforms.ToTensor()(images.images[0])
            logger.add_image(f"samples/{i}", img_tensor=image_tensor, global_step=global_step)

    torch.cuda.empty_cache()


def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
          save_every_n_steps=-1, validate_every_n_steps=-1,
          validation_prompts=None, sample_positive_prompts=None, sample_negative_prompts=None):
    """Fine-tune the selected modules so that ``prompt`` is steered away from.

    Each step diffuses random latents for a random number of steps, predicts
    noise for the prompt and for the empty prompt with the base weights, then
    predicts noise for the prompt with the fine-tuned weights and regresses it
    towards ``neutral - negative_guidance * (positive - neutral)``.

    Args:
        repo_id_or_path: forwarded to ``StableDiffuser``.
        img_size: forwarded to ``StableDiffuser`` as ``native_img_size``.
        prompt: text whose prediction is pushed away from.
        modules: forwarded to ``FineTunedModel`` (selects what is trained).
        freeze_modules: forwarded to ``FineTunedModel`` as ``frozen_modules``.
        iterations: number of training steps.
        negative_guidance: scale of the (positive - neutral) term in the target.
        lr: optimizer learning rate.
        save_path: destination of the final state dict; intermediate saves go
            to ``save_path + "__step_<n>.pt"``. Its basename (without
            extension) also names the TensorBoard log directory.
        use_adamw8bit: use ``bitsandbytes`` AdamW8bit instead of ``torch.optim.Adam``.
        use_xformers, use_amp, use_gradient_checkpointing: forwarded to
            ``MemoryEfficiencyWrapper``; ``use_amp`` also enables autocast.
        seed: training seed, or -1 to pick one at random.
        save_every_n_steps: save an intermediate checkpoint every n steps (<= 0 disables).
        validate_every_n_steps: run ``validate`` every n steps (<= 0 disables).
        validation_prompts: prompts for validation losses (default: none).
        sample_positive_prompts: prompts for sample images (default: none).
        sample_negative_prompts: negative prompts for sample images (default: none).

    Returns:
        ``save_path`` on completion, or ``None`` if ``training_should_cancel``
        was set during training.
    """
    # None instead of [] in the signature avoids sharing one list across calls.
    validation_prompts = [] if validation_prompts is None else validation_prompts
    sample_positive_prompts = [] if sample_positive_prompts is None else sample_positive_prompts
    sample_negative_prompts = [] if sample_negative_prompts is None else sample_negative_prompts

    # Pre-declared so that the ``del`` in the finally clause never hits an
    # unbound name, whatever point the loop was left at.
    diffuser = None
    loss = None
    optimizer = None
    finetuner = None
    negative_latents = None
    neutral_latents = None
    positive_latents = None

    nsteps = 50

    print(f"using img_size of {img_size}")
    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
    logger = SummaryWriter(log_dir=f"logs/{os.path.splitext(os.path.basename(save_path))[0]}")

    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing)
    with memory_efficiency_wrapper:
        diffuser.train()
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8
                                            )
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        validation_enabled = validate_every_n_steps > 0
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_cond_and_uncond_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings([prompt], n_imgs=1)
            if validation_enabled:
                validation_embeddings = diffuser.get_cond_and_uncond_embeddings(validation_prompts, n_imgs=1)
                sample_embeddings = diffuser.get_cond_and_uncond_embeddings(
                    sample_positive_prompts, sample_negative_prompts, n_imgs=1)
            else:
                # Only validate() consumes these, so skip encoding (possibly
                # empty) prompt lists when validation is switched off.
                validation_embeddings = None
                sample_embeddings = None

        # NOTE: diffuser.text_encoder / diffuser.tokenizer must stay alive;
        # validate() hands them to StableDiffusionPipeline.
        torch.cuda.empty_cache()

        if seed == -1:
            seed = random.randint(0, 2 ** 30)
        set_seed(int(seed))

        # Sliding window of the most recent losses for the moving average.
        max_prev_loss_count = 10
        prev_losses = deque(maxlen=max_prev_loss_count)
        start_loss = None

        try:
            for i in pbar:
                if training_should_cancel:
                    print("received cancellation request")
                    return None

                with torch.no_grad():
                    optimizer.zero_grad()

                    iteration = torch.randint(1, nsteps - 1, (1,)).item()

                    with finetuner:
                        diffused_latents = get_diffused_latents(
                            diffuser, nsteps, positive_text_embeddings, iteration, use_amp)

                    # The scheduler is back at 1000 timesteps here, so rescale
                    # the nsteps-based index accordingly.
                    iteration = int(iteration / nsteps * 1000)

                    with autocast(enabled=use_amp):
                        positive_latents = diffuser.predict_noise(
                            iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
                        neutral_latents = diffuser.predict_noise(
                            iteration, diffused_latents, neutral_text_embeddings, guidance_scale=1)

                # Gradients are only tracked for the fine-tuned prediction.
                with finetuner:
                    with autocast(enabled=use_amp):
                        negative_latents = diffuser.predict_noise(
                            iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)

                positive_latents.requires_grad = False
                neutral_latents.requires_grad = False

                # loss = criteria(e_n, e_0) works the best try 5000 epochs
                loss = criteria(
                    negative_latents,
                    neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))
                # Presumably runs backward + optimizer step (with grad scaling
                # under AMP) -- see MemoryEfficiencyWrapper.step.
                memory_efficiency_wrapper.step(optimizer, loss)
                optimizer.zero_grad()

                logger.add_scalar("loss", loss.item(), global_step=i)

                # print moving average loss
                prev_losses.append(loss.detach().clone())
                if start_loss is None:
                    start_loss = prev_losses[-1]
                if len(prev_losses) >= max_prev_loss_count:
                    moving_average_loss = sum(prev_losses) / len(prev_losses)
                    print(
                        f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, "
                        f"start ∆={(moving_average_loss - start_loss).item()})")
                else:
                    print(f"step {i}: loss={loss.item()}")

                if save_every_n_steps > 0 and ((i + 1) % save_every_n_steps) == 0:
                    torch.save(finetuner.state_dict(), save_path + f"__step_{i+1}.pt")
                if validation_enabled and ((i + 1) % validate_every_n_steps) == 0:
                    validate(diffuser, finetuner,
                             validation_embeddings=validation_embeddings,
                             sample_embeddings=sample_embeddings,
                             neutral_embeddings=neutral_text_embeddings,
                             logger=logger, use_amp=False, global_step=i)

            torch.save(finetuner.state_dict(), save_path)
            return save_path
        finally:
            # Flush and release the TensorBoard event file.
            logger.close()
            del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents
            torch.cuda.empty_cache()


def get_diffused_latents(diffuser, nsteps, text_embeddings, end_iteration, use_amp):
    """Diffuse fresh initial latents from step 0 up to ``end_iteration``.

    The scheduler is switched to ``nsteps`` timesteps for the diffusion and
    then set to 1000 timesteps before returning, so callers must rescale any
    ``nsteps``-based step index before using it with the scheduler again.

    Args:
        diffuser: object providing ``set_scheduler_timesteps``,
            ``get_initial_latents`` and ``diffusion``.
        nsteps: number of scheduler timesteps used while diffusing.
        text_embeddings: conditioning passed to ``diffuser.diffusion``.
        end_iteration: step index at which diffusion stops.
        use_amp: forwarded to ``diffuser.diffusion``.

    Returns:
        The last entry of the latents returned by ``diffuser.diffusion``.
    """
    diffuser.set_scheduler_timesteps(nsteps)
    latents = diffuser.get_initial_latents(1, n_prompts=1)
    latents_steps, _ = diffuser.diffusion(
        latents,
        text_embeddings,
        start_iteration=0,
        end_iteration=end_iteration,
        guidance_scale=3,
        show_progress=False,
        use_amp=use_amp
    )
    # because return_latents is not passed to diffuser.diffusion(), latents_steps should have only 1 entry
    # but we take the "last" (-1) entry because paranoia
    diffused_latents = latents_steps[-1]
    diffuser.set_scheduler_timesteps(1000)
    del latents_steps, latents
    return diffused_latents


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    parser.add_argument('--seed', type=int, required=False, default=-1,
                        help='Training seed for reproducible results, or -1 to pick a random seed')
    parser.add_argument('--use_adamw8bit', action='store_true')
    parser.add_argument('--use_xformers', action='store_true')
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--use_gradient_checkpointing', action='store_true')

    train(**vars(parser.parse_args()))