File size: 4,334 Bytes
81ccbca
 
 
 
 
fc73e59
 
 
 
 
81ccbca
 
 
0002379
81ccbca
fc73e59
 
 
81ccbca
fc73e59
81ccbca
fc73e59
81ccbca
fc73e59
 
 
 
 
 
 
 
 
 
 
81ccbca
fc73e59
81ccbca
fc73e59
7021212
fc73e59
 
81ccbca
fc73e59
 
 
0002379
fc73e59
 
 
 
 
 
 
 
81ccbca
fc73e59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81ccbca
 
fc73e59
 
 
 
 
 
 
 
81ccbca
 
7c89716
 
 
fd9afda
81ccbca
 
 
 
 
 
0002379
 
81ccbca
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm

from memory_efficiency import MemoryEfficiencyWrapper


def train(repo_id_or_path: str, img_size: int, prompt: str, modules, freeze_modules, iterations: int, negative_guidance: float, lr: float, save_path: str,
          use_adamw8bit: bool = True, use_xformers: bool = True, use_amp: bool = True, use_gradient_checkpointing: bool = False):
    """Finetune a Stable Diffusion model to suppress the concept described by *prompt*.

    Implements an erasure-style objective: the finetuned UNet's prediction for the
    prompt is pushed toward ``neutral - negative_guidance * (positive - neutral)``,
    where positive/neutral are the frozen model's noise predictions with and
    without the prompt.

    Args:
        repo_id_or_path: HF repo id or local path passed to StableDiffuser.
        img_size: square latent image size (pixels) for training rollouts.
        prompt: text describing the concept to erase.
        modules: module spec handed to FineTunedModel (selects trainable modules).
        freeze_modules: modules to keep frozen inside FineTunedModel.
        iterations: number of optimizer updates.
        negative_guidance: strength of the erasure objective (see loss below).
        lr: learning rate.
        save_path: where the finetuned state_dict is written via torch.save.
        use_adamw8bit: use bitsandbytes AdamW8bit instead of torch Adam.
        use_xformers: enable xformers attention via MemoryEfficiencyWrapper.
        use_amp: enable mixed precision via MemoryEfficiencyWrapper.
        use_gradient_checkpointing: trade compute for VRAM in the UNet.
    """
    # Number of DDIM steps used for the partial denoising rollout below.
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')

    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing )
    with memory_efficiency_wrapper:

        diffuser.train()

        # Presumably wraps the selected modules with trainable copies that are
        # activated inside `with finetuner:` blocks — confirm in FineTunedModel.
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

        if use_adamw8bit:
            # Imported lazily so bitsandbytes is only required when requested.
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8
                                            )
        else:
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        # Text embeddings are fixed for the whole run, so compute them once
        # without building a graph.
        with torch.no_grad():

            neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)

        # Embeddings are cached above; drop the text/VAE stacks to free VRAM.
        # NOTE(review): assumes diffuser.diffusion/predict_noise below never
        # touch these attributes — verify in StableDiffuser.
        del diffuser.vae
        del diffuser.text_encoder
        del diffuser.tokenizer

        torch.cuda.empty_cache()

        print(f"using img_size of {img_size}")

        for i in pbar:
            with torch.no_grad():
                diffuser.set_scheduler_timesteps(nsteps)
                optimizer.zero_grad()

                # Random mid-trajectory step in [1, nsteps-2] (randint's high is
                # exclusive), avoiding the endpoints of the 50-step schedule.
                iteration = torch.randint(1, nsteps - 1, (1,)).item()
                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)

                # Roll the (finetuned) model partway through denoising to get a
                # realistic latent at the sampled step; no grad is needed here.
                with finetuner:
                    latents_steps, _ = diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False
                    )

                # Switch to the full 1000-step schedule and map the sampled
                # 50-step index onto it for predict_noise.
                diffuser.set_scheduler_timesteps(1000)

                iteration = int(iteration / nsteps * 1000)

                # Frozen-model predictions: with the prompt ("positive") and
                # with the empty prompt ("neutral"). Both are training targets.
                positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

            # Same prediction but with the finetuned weights active — this one
            # carries gradients.
            with finetuner:
                negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

            # Already produced under no_grad, so these are likely redundant;
            # kept for safety against graph leaks.
            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            # Erasure objective: steer the finetuned prediction away from the
            # concept direction (positive - neutral), scaled by negative_guidance.
            loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
            # NOTE(review): if MemoryEfficiencyWrapper.step already calls
            # optimizer.step() (the usual AMP GradScaler pattern), the explicit
            # optimizer.step() below is a double step — confirm against
            # memory_efficiency.py.
            memory_efficiency_wrapper.step(optimizer, loss)
            optimizer.step()

    # Only the finetuned modules' weights are persisted, not the full pipeline.
    torch.save(finetuner.state_dict(), save_path)

    # Explicitly release everything CUDA-resident before returning.
    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()
if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    # Memory-efficiency toggles. train() defaults these to on (except gradient
    # checkpointing), so negative flags keep the CLI backward-compatible:
    # omitting them reproduces the previous behavior exactly.
    parser.add_argument('--no_adamw8bit', action='store_true',
                        help='use torch.optim.Adam instead of bitsandbytes AdamW8bit')
    parser.add_argument('--no_xformers', action='store_true',
                        help='disable xformers memory-efficient attention')
    parser.add_argument('--no_amp', action='store_true',
                        help='disable automatic mixed precision')
    parser.add_argument('--use_gradient_checkpointing', action='store_true',
                        help='enable gradient checkpointing (saves VRAM, slower)')

    # Convert the negative CLI flags into the positive keyword arguments that
    # train() expects.
    args = vars(parser.parse_args())
    args['use_adamw8bit'] = not args.pop('no_adamw8bit')
    args['use_xformers'] = not args.pop('no_xformers')
    args['use_amp'] = not args.pop('no_amp')

    train(**args)