File size: 6,136 Bytes
6067469
94be4c7
 
ab11bdd
 
81ccbca
 
 
 
 
6067469
fc73e59
 
 
 
6067469
81ccbca
ab11bdd
0002379
81ccbca
fc73e59
 
 
 
 
 
b58675c
fc73e59
 
 
 
 
 
 
 
ab11bdd
fc73e59
 
81ccbca
fc73e59
81ccbca
fc73e59
 
 
81ccbca
fc73e59
 
 
0002379
fc73e59
 
 
 
94be4c7
 
c8aa68b
94be4c7
6067469
 
 
fc73e59
 
 
 
81ccbca
fc73e59
 
 
94be4c7
fc73e59
 
 
 
 
 
94be4c7
 
fc73e59
 
 
 
 
ab11bdd
 
 
81ccbca
 
ab11bdd
 
fc73e59
 
 
 
ab11bdd
 
fc73e59
94be4c7
81ccbca
6067469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81ccbca
7c89716
 
 
fd9afda
81ccbca
 
 
 
 
 
0002379
 
81ccbca
 
 
 
 
 
 
94be4c7
 
 
 
 
 
81ccbca
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import random

from accelerate.utils import set_seed
from torch.cuda.amp import autocast

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm

from isolate_rng import isolate_rng
from memory_efficiency import MemoryEfficiencyWrapper


def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1, save_every=-1):
    """Fine-tune selected UNet modules so the model "unlearns" `prompt` (ESD-style concept erasure).

    The negatively-guided target `neutral - negative_guidance * (positive - neutral)` steers the
    trained model's prediction away from the concept while the frozen base model provides the
    positive/neutral anchors.

    Args:
        repo_id_or_path: HF repo id or local path of the Stable Diffusion checkpoint.
        img_size: square latent-image size used for sampling (e.g. 512).
        prompt: text describing the concept to erase.
        modules: module name patterns to train (forwarded to FineTunedModel).
        freeze_modules: module name patterns to keep frozen.
        iterations: number of optimization steps.
        negative_guidance: strength of the erasure objective.
        lr: learning rate.
        save_path: path for the final finetuner state_dict.
        use_adamw8bit: use bitsandbytes AdamW8bit instead of torch Adam.
        use_xformers / use_amp / use_gradient_checkpointing: memory-efficiency toggles,
            applied via MemoryEfficiencyWrapper.
        seed: RNG seed; -1 picks a random seed.
        save_every: if > 0, also save a checkpoint every `save_every` steps.

    Side effects: writes checkpoint file(s), prints progress, frees CUDA memory on exit.
    """

    nsteps = 50  # DDIM sampling steps used to pick a partial-diffusion starting point
    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')

    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing )
    with memory_efficiency_wrapper:
        diffuser.train()
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8
                                            )
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        # Text embeddings are fixed for the whole run, so compute them once without grad.
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)

        # VAE / text encoder / tokenizer are no longer needed (we train purely in latent
        # space with precomputed embeddings) — drop them to free VRAM.
        del diffuser.vae
        del diffuser.text_encoder
        del diffuser.tokenizer

        torch.cuda.empty_cache()

        print(f"using img_size of {img_size}")

        if seed == -1:
            seed = random.randint(0, 2 ** 30)
        set_seed(int(seed))

        prev_losses = []
        start_loss = None
        max_prev_loss_count = 10  # window length of the moving-average loss report
        for i in pbar:
            with torch.no_grad():
                diffuser.set_scheduler_timesteps(nsteps)
                optimizer.zero_grad()

                # Diffuse (with the current finetuned weights) up to a random intermediate step.
                iteration = torch.randint(1, nsteps - 1, (1,)).item()
                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)

                with finetuner:
                    latents_steps, _ = diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False,
                        use_amp=use_amp
                    )

                # Map the DDIM step index onto the scheduler's full 1000-step scale.
                diffuser.set_scheduler_timesteps(1000)
                iteration = int(iteration / nsteps * 1000)

                # Frozen-model anchors: predictions with and without the concept prompt.
                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

            # Trainable prediction: same prompt, but with the finetuned weights active (grad on).
            with finetuner:
                with autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            # loss = criteria(e_n, e_0) works the best try 5000 epochs
            loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
            memory_efficiency_wrapper.step(optimizer, loss)
            # NOTE(review): redundant with the zero_grad at the top of the loop, but harmless.
            optimizer.zero_grad()

            # print moving average loss
            prev_losses.append(loss.detach().clone())
            if len(prev_losses) > max_prev_loss_count:
                prev_losses.pop(0)
            if start_loss is None:
                start_loss = prev_losses[-1]
            if len(prev_losses) >= max_prev_loss_count:
                moving_average_loss = sum(prev_losses) / len(prev_losses)
                # fix: the original message opened "(avg=..." without a closing parenthesis
                print(
                    f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()})")
            else:
                print(f"step {i}: loss={loss.item()}")

            if save_every > 0 and ((i % save_every) == (save_every-1)):
                torch.save(finetuner.state_dict(), save_path + f"__step_{i}.pt")

    torch.save(finetuner.state_dict(), save_path)

    # Drop all large references before emptying the CUDA cache so the memory is actually freed.
    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()
if __name__ == '__main__':

    import argparse

    # CLI entry point: every flag maps 1:1 onto a train() keyword argument and is
    # forwarded via **vars(...), so argument names here must match the signature.
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    parser.add_argument('--seed', type=int, required=False, default=-1,
                        help='Training seed for reproducible results, or -1 to pick a random seed')
    # fix: train() supports periodic checkpointing via save_every, but the CLI never
    # exposed it, making the feature unreachable; -1 (the train() default) disables it.
    parser.add_argument('--save_every', type=int, required=False, default=-1,
                        help='Save an intermediate checkpoint every N steps, or -1 to save only at the end')
    parser.add_argument('--use_adamw8bit', action='store_true')
    parser.add_argument('--use_xformers', action='store_true')
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--use_gradient_checkpointing', action='store_true')

    train(**vars(parser.parse_args()))