import torch
from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
from typing import List

def AdamBmixer(order, ets, b=1):
    """Blend the stored model outputs `ets` (newest last) with Adams-Bashforth-style weights.

    `b` is the momentum coefficient (the GHVB scheduler passes its beta here); with b = 1 the
    weights reduce to the classical Adams-Bashforth coefficients of the given order.
    """
    cur_order = min(order, len(ets))
    if cur_order == 1:
        prime = b * ets[-1]  
    elif cur_order == 2:
        prime = ((2+b) * ets[-1] - (2-b)*ets[-2]) / 2
    elif cur_order == 3:
        prime = ((18+5*b) * ets[-1] - (24-8*b) * ets[-2] + (6-1*b) * ets[-3]) / 12
    elif cur_order == 4:
        prime = ((46+9*b) * ets[-1] - (78-19*b) * ets[-2] + (42-5*b) * ets[-3] - (10-b) * ets[-4]) / 24
    elif cur_order == 5:
        prime = ((1650+251*b) * ets[-1] - (3420-646*b) * ets[-2]
                     + (2880-264*b) * ets[-3] - (1380-106*b) * ets[-4]
                     + (270-19*b)* ets[-5]) / 720
    else:
        raise NotImplementedError
    
    prime = prime/b
    return prime
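
# Quick sanity check (illustrative sketch, not part of the original module): with b = 1 the
# mixer reduces to the classical Adams-Bashforth weights (e.g. order 2 gives
# 1.5*ets[-1] - 0.5*ets[-2]), and for any b the weights sum to 1 after the final division by b.
def _adamb_mixer_sanity_check():
    ets = [torch.tensor(1.0), torch.tensor(1.0)]  # two identical history entries
    # Because every history entry equals 1.0, a consistent weighting must return exactly 1.0.
    assert torch.isclose(AdamBmixer(2, ets, b=1), torch.tensor(1.0))
    assert torch.isclose(AdamBmixer(2, ets, b=0.5), torch.tensor(1.0))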

class PLMSWithHBScheduler():
    """
    PLMS with Polyak's Heavy Ball Momentum (HB) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)
    
    When order is an integer, this method is equivalent to PLMS without momentum.
    """
    def __init__(self, scheduler, order):
        self.scheduler = scheduler
        self.ets = []
        self.update_order(order)
        self.mixer = AdamBmixer
        
    def update_order(self, order):
        # A fractional order such as 2.5 is split into an integer multistep order (3) and a
        # momentum coefficient beta (0.5); an integer order gives beta = 1, i.e. no momentum.
        self.order = order // 1 + 1 if order % 1 > 0 else order // 1
        self.beta = order % 1 if order % 1 > 0 else 1
        self.vel = None
 
    def clear(self):
        self.ets = []
        self.vel = None

    def update_ets(self, val):
        self.ets.append(val)
        if len(self.ets) > self.order:
            self.ets.pop(0)

    def _step_with_momentum(self, grads):
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)
        self.vel = (1 - self.beta) * self.vel + self.beta * prime
        return self.vel

    def step(
        self,
        grads: torch.FloatTensor,
        timestep: int,
        latents: torch.FloatTensor,
        output_mode: str = "scale",
    ):
        if self.vel is None: self.vel = grads
 
        if hasattr(self.scheduler, 'sigmas'):
            step_index = (self.scheduler.timesteps == timestep).nonzero().item()
            sigma = self.scheduler.sigmas[step_index]
            sigma_next = self.scheduler.sigmas[step_index + 1]
            del_g = sigma_next - sigma
 
            update_val = self._step_with_momentum(grads)
            return latents + del_g * update_val

        elif isinstance(self.scheduler, DPMSolverMultistepScheduler):
            step_index = (self.scheduler.timesteps == timestep).nonzero().item()
            current_timestep = self.scheduler.timesteps[step_index]
            prev_timestep = 0 if step_index == len(self.scheduler.timesteps) - 1 else self.scheduler.timesteps[step_index + 1]

            alpha_prod_t = self.scheduler.alphas_cumprod[current_timestep]
            alpha_bar_prev = self.scheduler.alphas_cumprod[prev_timestep]

            s0 = torch.sqrt(alpha_prod_t)
            s_1 = torch.sqrt(alpha_bar_prev)
            g0 = torch.sqrt(1-alpha_prod_t)/s0
            g_1 = torch.sqrt(1-alpha_bar_prev)/s_1
            del_g = g_1 - g0
 
            update_val = self._step_with_momentum(grads)
            if output_mode in ["scale"]:
                return (latents/s0  + del_g * update_val) * s_1
            elif output_mode in ["back"]:
                return latents + del_g * update_val * s_1
            elif output_mode in ["front"]:
                return latents + del_g * update_val * s0
            else:
                return latents + del_g * update_val
        else:
            raise NotImplementedError
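
# Illustrative check (added for clarity, not part of the original module) of the bookkeeping
# in update_order: a fractional order such as 2.5 becomes an integer multistep order (3) plus
# a momentum coefficient beta (0.5), while an integer order keeps beta = 1.
def _fractional_order_check():
    base = DPMSolverMultistepScheduler()  # default config; used only as a wrapper target here
    hb = PLMSWithHBScheduler(base, order=2.5)
    assert hb.order == 3 and abs(hb.beta - 0.5) < 1e-8
    hb.update_order(4)
    assert hb.order == 4 and hb.beta == 1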

class GHVBScheduler(PLMSWithHBScheduler):
    """
    Generalizing Polyak's Heavy Ball (GHVB) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)
    
    When order is an integer, this method is equivalent to PLMS without momentum.
    """
    def _step_with_momentum(self, grads):
        self.vel = (1 - self.beta) * self.vel + self.beta * grads
        self.update_ets(self.vel)
        prime = self.mixer(self.order, self.ets, self.beta)
        return prime

class PLMSWithNTScheduler(PLMSWithHBScheduler):
    """
    PLMS with Nesterov Momentum (NT) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)
    
    When order is an integer, this method is equivalent to PLMS without momentum.
    """
    def _step_with_momentum(self, grads):
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0) # update v^{(2)}
        self.vel = (1 - self.beta) * self.vel + self.beta * prime # update v^{(1)}
        update_val = (1 - self.beta) * self.vel + self.beta * prime # update x
        return update_val
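
# Hedged usage sketch (illustrative; the wiring below is an assumption, not prescribed by this
# module): the three wrapper classes above share one interface, so the caller runs its own
# denoising loop, passes the noise prediction to `step`, and receives the updated latents.
# `unet`, `text_emb`, and `latents` are hypothetical caller-supplied objects, as in a standard
# diffusers text-to-image loop; classifier-free guidance, latent scaling, and VAE decoding are
# omitted for brevity.
def _wrapper_sampling_loop_sketch(unet, text_emb, latents, num_inference_steps=20, order=2.5):
    base = DPMSolverMultistepScheduler()               # any supported diffusers scheduler
    base.set_timesteps(num_inference_steps)
    sampler = GHVBScheduler(base, order)               # or PLMSWithHBScheduler / PLMSWithNTScheduler
    sampler.clear()                                    # reset history and velocity between runs
    for t in base.timesteps:
        noise_pred = unet(latents, t, encoder_hidden_states=text_emb).sample
        latents = sampler.step(noise_pred, t, latents)
    return latents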

class MomentumDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
    """
    DPM-Solver++2M with HB momentum.
    Currently supports only algorithm_type = "dpmsolver++" and solver_type = "midpoint".

    When beta = 1.0, this method is equivalent to DPM-Solver++2M without momentum.
    """
    def initialize_momentum(self, beta):
        self.vel = None
        self.beta = beta

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        timestep_list: List[int],
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        
        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                diff = (D0 + 0.5 * D1)

                if self.vel is None:
                    self.vel = diff
                else:
                    self.vel = (1-self.beta)*self.vel + self.beta * diff
                
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * self.vel
                )
            elif self.config.solver_type == "heun":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
            elif self.config.solver_type == "heun":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
        return x_t
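
# Hedged usage sketch (illustrative; the model id and beta value are assumptions, not values
# prescribed by this module): because the class subclasses the diffusers scheduler, it can be
# swapped into an existing pipeline via from_config and then given a momentum coefficient.
# This assumes a diffusers version whose second-order update hook matches the signature
# overridden above.
def _momentum_dpm_solver_pipeline_sketch():
    from diffusers import StableDiffusionPipeline
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    scheduler = MomentumDPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    scheduler.initialize_momentum(beta=0.9)   # beta = 1.0 recovers plain DPM-Solver++2M
    pipe.scheduler = scheduler
    return pipe("an astronaut riding a horse", num_inference_steps=15).images[0]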

class MomentumUniPCMultistepScheduler(UniPCMultistepScheduler):
    """
    UniPC with HB momentum.
    Currently supports only self.predict_x0 = True.

    When beta = 1.0, this method is equivalent to UniPC without momentum.
    """
    def initialize_momentum(self, beta):
        self.vel_p = None
        self.vel_c = None
        self.beta = beta
 
    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
 
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs
 
        s0, t = self.timestep_list[-1], prev_timestep
        m0 = model_output_list[-1]
        x = sample
 
        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t
 
        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
 
        h = lambda_t - lambda_s0
        device = sample.device
 
        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
 
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)
 
        R = []
        b = []
 
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1
 
        factorial_i = 1
 
        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()
 
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i
 
        R = torch.stack(R)
        b = torch.tensor(b, device=device)
 
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None
 
        if self.predict_x0:
            if D1s is not None:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0
 
            val = (h_phi_1 * m0 + B_h * pred_res) / sigma_t / h_phi_1
            # Heavy-ball momentum on the predictor update; beta = 1.0 recovers plain UniPC.
            if self.vel_p is None:
                self.vel_p = val
            else:
                self.vel_p = (1 - self.beta) * self.vel_p + self.beta * val

            x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_p * h_phi_1)
        else:
            raise NotImplementedError
 
        x_t = x_t.to(x.dtype)
        return x_t
 
    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
 
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs
 
        s0, t = timestep_list[-1], this_timestep
        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output
 
        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
 
        h = lambda_t - lambda_s0
        device = this_sample.device
 
        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
 
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)
 
        R = []
        b = []
 
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1
 
        factorial_i = 1
 
        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()
 
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i
 
        R = torch.stack(R)
        b = torch.tensor(b, device=device)
 
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None
 
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b)
 
        if self.predict_x0:
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
 
            val = (h_phi_1 * m0 + B_h * (corr_res + rhos_c[-1] * D1_t)) / sigma_t / h_phi_1
            # Heavy-ball momentum on the corrector update; beta = 1.0 recovers plain UniPC.
            if self.vel_c is None:
                self.vel_c = val
            else:
                self.vel_c = (1 - self.beta) * self.vel_c + self.beta * val

            x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_c * h_phi_1)
        else:
            raise NotImplementedError
        
        x_t = x_t.to(x.dtype)
        return x_t
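
# Hedged usage note (illustrative): the UniPC variant is swapped into a pipeline exactly like
# the DPM-Solver sketch above, assuming the installed diffusers version calls the
# multistep_uni_p/uni_c hooks with the signatures overridden here and predict_x0 is enabled.
# One practical point: vel_p / vel_c persist on the scheduler object, so initialize_momentum
# should be called again before each new generation to reset the momentum state. beta = 0.8
# is an arbitrary illustrative value.
def _momentum_unipc_pipeline_sketch(pipe, prompt):
    scheduler = MomentumUniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.scheduler = scheduler
    scheduler.initialize_momentum(beta=0.8)   # beta = 1.0 recovers plain UniPC
    return pipe(prompt, num_inference_steps=15).images[0]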