# NOTE: removed non-Python web-scrape residue that preceded this file
# (hosting-page header, git blame hashes, and a line-number gutter).
import random
from accelerate.utils import set_seed
from torch.cuda.amp import autocast
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm
from isolate_rng import isolate_rng
from memory_efficiency import MemoryEfficiencyWrapper
def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1, save_every=-1):
    """Erase a concept from a Stable Diffusion model (ESD-style fine-tuning).

    Trains only the submodules named in `modules` (those in `freeze_modules`
    stay frozen) so that the model's prediction for `prompt` is steered away
    from the concept, using the neutral ('' prompt) prediction as anchor:
        target = neutral - negative_guidance * (positive - neutral)

    Args:
        repo_id_or_path: HF repo id or local path of the base model.
        img_size: square latent/image size used for the sampled trajectories.
        prompt: text describing the concept to erase.
        modules: module name patterns to fine-tune.
        freeze_modules: module name patterns to keep frozen.
        iterations: number of optimization steps.
        negative_guidance: strength of the push away from the concept.
        lr: learning rate.
        save_path: where the fine-tuned state_dict is written.
        use_adamw8bit: use bitsandbytes AdamW8bit instead of torch Adam.
        use_xformers / use_amp / use_gradient_checkpointing: memory options
            forwarded to MemoryEfficiencyWrapper.
        seed: RNG seed; -1 picks a random one.
        save_every: also save a checkpoint every N steps; -1 disables.

    Side effects: runs on CUDA, writes checkpoint file(s) at `save_path`.
    """
    nsteps = 50  # DDIM steps used when sampling partial trajectories
    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing)
    with memory_efficiency_wrapper:
        diffuser.train()
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8
                                            )
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()
        pbar = tqdm(range(iterations))

        # Pre-compute both text conditionings once; afterwards the text
        # encoder / VAE / tokenizer are no longer needed, so free them.
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)
        del diffuser.vae
        del diffuser.text_encoder
        del diffuser.tokenizer
        torch.cuda.empty_cache()

        print(f"using img_size of {img_size}")
        if seed == -1:
            seed = random.randint(0, 2 ** 30)
        set_seed(int(seed))

        prev_losses = []
        start_loss = None
        max_prev_loss_count = 10  # moving-average window

        # Pre-bind loop-local names so the cleanup `del` below cannot raise
        # NameError when iterations == 0.
        loss = latents = latents_steps = None
        positive_latents = neutral_latents = negative_latents = None

        for i in pbar:
            with torch.no_grad():
                diffuser.set_scheduler_timesteps(nsteps)
                optimizer.zero_grad()
                # Random intermediate step of the 50-step trajectory.
                iteration = torch.randint(1, nsteps - 1, (1,)).item()
                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
                # Diffuse up to `iteration` WITH the fine-tuned weights active.
                with finetuner:
                    latents_steps, _ = diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False,
                        use_amp=use_amp
                    )
                # Re-express the sampled step on the full 1000-step schedule.
                diffuser.set_scheduler_timesteps(1000)
                iteration = int(iteration / nsteps * 1000)
                # Reference predictions from the ORIGINAL (unpatched) weights.
                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
            # Prediction from the fine-tuned weights — this one gets gradients.
            with finetuner:
                with autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False
            # loss = criteria(e_n, e_0) works the best try 5000 epochs
            loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))
            memory_efficiency_wrapper.step(optimizer, loss)
            optimizer.zero_grad()
            # print moving average loss
            prev_losses.append(loss.detach().clone())
            if len(prev_losses) > max_prev_loss_count:
                prev_losses.pop(0)
            if start_loss is None:
                start_loss = prev_losses[-1]
            if len(prev_losses) >= max_prev_loss_count:
                moving_average_loss = sum(prev_losses) / len(prev_losses)
                # (fixed: the format string previously left the '(' unbalanced)
                print(
                    f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()})")
            else:
                print(f"step {i}: loss={loss.item()}")
            if save_every > 0 and ((i % save_every) == (save_every - 1)):
                torch.save(finetuner.state_dict(), save_path + f"__step_{i}.pt")

    torch.save(finetuner.state_dict(), save_path)
    # Release the big GPU objects before returning to keep peak memory down.
    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
    torch.cuda.empty_cache()
if __name__ == '__main__':
    import argparse

    # Command-line entry point: every flag maps 1:1 onto a train() keyword.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--repo_id_or_path", required=True)
    arg_parser.add_argument("--img_size", type=int, required=False, default=512)
    arg_parser.add_argument('--prompt', required=True)
    arg_parser.add_argument('--modules', required=True)
    arg_parser.add_argument('--freeze_modules', nargs='+', required=True)
    arg_parser.add_argument('--save_path', required=True)
    arg_parser.add_argument('--iterations', type=int, required=True)
    arg_parser.add_argument('--lr', type=float, required=True)
    arg_parser.add_argument('--negative_guidance', type=float, required=True)
    arg_parser.add_argument('--seed', type=int, required=False, default=-1,
                            help='Training seed for reproducible results, or -1 to pick a random seed')
    # Memory/performance toggles (all default to off when run from the CLI).
    arg_parser.add_argument('--use_adamw8bit', action='store_true')
    arg_parser.add_argument('--use_xformers', action='store_true')
    arg_parser.add_argument('--use_amp', action='store_true')
    arg_parser.add_argument('--use_gradient_checkpointing', action='store_true')

    cli_args = arg_parser.parse_args()
    train(**vars(cli_args))