Spaces:
Runtime error
Runtime error
File size: 4,334 Bytes
81ccbca fc73e59 81ccbca 0002379 81ccbca fc73e59 81ccbca fc73e59 81ccbca fc73e59 81ccbca fc73e59 81ccbca fc73e59 81ccbca fc73e59 7021212 fc73e59 81ccbca fc73e59 0002379 fc73e59 81ccbca fc73e59 81ccbca fc73e59 81ccbca 7c89716 fd9afda 81ccbca 0002379 81ccbca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm
from memory_efficiency import MemoryEfficiencyWrapper
def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False):
    """Fine-tune selected modules of a diffusion model against ``prompt`` and save the weights.

    Each iteration partially denoises fresh latents with the fine-tuned model,
    predicts noise for the prompt and for the empty prompt, and regresses the
    fine-tuned prediction towards
    ``neutral - negative_guidance * (positive - neutral)`` with an MSE loss.

    Args:
        repo_id_or_path: Forwarded to ``StableDiffuser(repo_id_or_path=...)``.
        img_size: Used as both width and height of the initial latents request.
        prompt: Text whose embedding is used as the "positive" conditioning.
        modules: Forwarded to ``FineTunedModel`` as the modules to fine-tune.
        freeze_modules: Forwarded to ``FineTunedModel`` as ``frozen_modules``.
        iterations: Number of passes through the training loop. ``0`` is
            accepted: the untouched fine-tuner state is saved.
        negative_guidance: Scale applied to ``(positive - neutral)`` in the
            regression target.
        lr: Learning rate handed to the optimizer.
        save_path: Destination passed to ``torch.save`` for the fine-tuner
            ``state_dict``.
        use_adamw8bit: When true, use ``bitsandbytes.optim.AdamW8bit``
            (imported lazily); otherwise ``torch.optim.Adam``.
        use_xformers: Forwarded to ``MemoryEfficiencyWrapper``.
        use_amp: Forwarded to ``MemoryEfficiencyWrapper``.
        use_gradient_checkpointing: Forwarded to ``MemoryEfficiencyWrapper``.

    Returns:
        None. The trained weights are written to ``save_path``.
    """
    # Number of scheduler steps used for the partial denoising pass.
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')

    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing)

    # Pre-bind the per-iteration names so the final `del` statement cannot
    # raise UnboundLocalError when the loop body never runs (iterations == 0).
    loss = negative_latents = neutral_latents = positive_latents = None
    latents_steps = latents = None

    with memory_efficiency_wrapper:
        diffuser.train()

        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

        if use_adamw8bit:
            # Imported lazily so bitsandbytes is only required when requested.
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8
                                            )
        else:
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        # The text embeddings are fixed for the whole run, so compute them once.
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

        # These components are not referenced again below; drop them and
        # release cached CUDA memory before the training loop starts.
        del diffuser.vae
        del diffuser.text_encoder
        del diffuser.tokenizer

        torch.cuda.empty_cache()

        print(f"using img_size of {img_size}")

        for _ in pbar:
            with torch.no_grad():
                diffuser.set_scheduler_timesteps(nsteps)

                optimizer.zero_grad()

                # Random stopping point in [1, nsteps - 2] (randint's upper bound is exclusive).
                iteration = torch.randint(1, nsteps - 1, (1,)).item()

                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)

                # Partially denoise inside the fine-tuner context.
                with finetuner:
                    latents_steps, _ = diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False
                    )

                # Re-map the stopping point from the nsteps schedule onto a 1000-step schedule.
                diffuser.set_scheduler_timesteps(1000)

                iteration = int(iteration / nsteps * 1000)

                # Reference predictions, made outside the fine-tuner context and without gradients.
                positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

            # Prediction inside the fine-tuner context, with gradients enabled.
            with finetuner:
                negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            # Original author's note: `loss = criteria(e_n, e_0)` works the best, try 5000 epochs.
            loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))

            # NOTE(review): if MemoryEfficiencyWrapper.step() already calls
            # optimizer.step(), the explicit call below applies a second update
            # per iteration -- confirm against memory_efficiency.py.
            memory_efficiency_wrapper.step(optimizer, loss)
            optimizer.step()

    torch.save(finetuner.state_dict(), save_path)

    # Drop every reference to GPU-resident objects so empty_cache() can release their memory.
    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()
if __name__ == '__main__':
    import argparse

    # Command-line entry point: every parsed option is forwarded to train()
    # as a keyword argument of the same name.
    option_specs = (
        ("--repo_id_or_path", {"required": True}),
        ("--img_size", {"type": int, "required": False, "default": 512}),
        ("--prompt", {"required": True}),
        ("--modules", {"required": True}),
        ("--freeze_modules", {"nargs": "+", "required": True}),
        ("--save_path", {"required": True}),
        ("--iterations", {"type": int, "required": True}),
        ("--lr", {"type": float, "required": True}),
        ("--negative_guidance", {"type": float, "required": True}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    cli_options = vars(cli.parse_args())
    train(**cli_options)