Freiburg-AI-Research committed on
Commit 62e125d · 1 Parent(s): 1b65a4a

Upload 10 files

glide_text2im/__init__.py ADDED
@@ -0,0 +1,3 @@
+"""
+A codebase for performing model inference with a text-conditional diffusion model.
+"""
glide_text2im/download.py ADDED
@@ -0,0 +1,71 @@
+import os
+from functools import lru_cache
+from typing import Dict, Optional
+
+import requests
+import torch as th
+from filelock import FileLock
+from tqdm.auto import tqdm
+
+MODEL_PATHS = {
+    "base": "https://huggingface.co/datasets/asifhugs/weights/blob/main/base.pt",
+    "upsample": "https://huggingface.co/datasets/asifhugs/weights/blob/main/upsample.pt",
+    "base-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base_inpaint.pt",
+    "upsample-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample_inpaint.pt",
+    "clip/image-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_image_enc.pt",
+    "clip/text-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_text_enc.pt",
+}
+
+
+@lru_cache()
+def default_cache_dir() -> str:
+    return os.path.join(os.path.abspath(os.getcwd()), "glide_model_cache")
+
+
+def fetch_file_cached(
+    url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
+) -> str:
+    """
+    Download the file at the given URL into a local file and return the path.
+
+    If cache_dir is specified, it will be used to download the files.
+    Otherwise, default_cache_dir() is used.
+    """
+    if cache_dir is None:
+        cache_dir = default_cache_dir()
+    os.makedirs(cache_dir, exist_ok=True)
+    response = requests.get(url, stream=True)
+    size = int(response.headers.get("content-length", "0"))
+    local_path = os.path.join(cache_dir, url.split("/")[-1])
+    with FileLock(local_path + ".lock"):
+        if os.path.exists(local_path):
+            return local_path
+        if progress:
+            pbar = tqdm(total=size, unit="iB", unit_scale=True)
+        tmp_path = local_path + ".tmp"
+        with open(tmp_path, "wb") as f:
+            for chunk in response.iter_content(chunk_size):
+                if progress:
+                    pbar.update(len(chunk))
+                f.write(chunk)
+        os.rename(tmp_path, local_path)
+        if progress:
+            pbar.close()
+        return local_path
+
+
+def load_checkpoint(
+    checkpoint_name: str,
+    device: th.device,
+    progress: bool = True,
+    cache_dir: Optional[str] = None,
+    chunk_size: int = 4096,
+) -> Dict[str, th.Tensor]:
+    if checkpoint_name not in MODEL_PATHS:
+        raise ValueError(
+            f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
+        )
+    path = fetch_file_cached(
+        MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
+    )
+    return th.load(path, map_location=device)
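For readers skimming the diff, a minimal sketch (not part of the uploaded files) of how load_checkpoint above might be used; the checkpoint name and device are illustrative:

    import torch as th
    from glide_text2im.download import load_checkpoint

    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    # The first call downloads base.pt into ./glide_model_cache/; later calls reuse the cached file.
    state_dict = load_checkpoint("base", device, progress=True)

One caveat worth noting: the "base" and "upsample" entries in MODEL_PATHS point at Hugging Face blob/ page URLs, so depending on how that dataset is served, a resolve/ URL may be required to fetch the raw .pt file rather than an HTML page.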
glide_text2im/fp16_util.py ADDED
@@ -0,0 +1,25 @@
+"""
+Helpers for inference with 16-bit precision.
+"""
+
+import torch.nn as nn
+
+
+def convert_module_to_f16(l):
+    """
+    Convert primitive modules to float16.
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.half()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+    """
+    Convert primitive modules to float32, undoing convert_module_to_f16().
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.float()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.float()
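These converters are meant to be passed to nn.Module.apply, as the model classes later in this upload do; a small self-contained sketch with an arbitrary toy module (not part of the commit):

    import torch.nn as nn
    from glide_text2im.fp16_util import convert_module_to_f16, convert_module_to_f32

    toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SiLU(), nn.Conv2d(8, 3, 3, padding=1))
    toy.apply(convert_module_to_f16)  # casts only the conv weights/biases to float16, in place
    toy.apply(convert_module_to_f32)  # undoes the cast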
glide_text2im/gaussian_diffusion.py ADDED
@@ -0,0 +1,639 @@
1
+ """
2
+ Simplified from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py.
3
+ """
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+ import torch as th
9
+
10
+
11
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
12
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
13
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
14
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
15
+ return betas
16
+
17
+
18
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
19
+ """
20
+ This is the deprecated API for creating beta schedules.
21
+
22
+ See get_named_beta_schedule() for the new library of schedules.
23
+ """
24
+ if beta_schedule == "quad":
25
+ betas = (
26
+ np.linspace(
27
+ beta_start ** 0.5,
28
+ beta_end ** 0.5,
29
+ num_diffusion_timesteps,
30
+ dtype=np.float64,
31
+ )
32
+ ** 2
33
+ )
34
+ elif beta_schedule == "linear":
35
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
36
+ elif beta_schedule == "warmup10":
37
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
38
+ elif beta_schedule == "warmup50":
39
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
40
+ elif beta_schedule == "const":
41
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
42
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
43
+ betas = 1.0 / np.linspace(
44
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
45
+ )
46
+ else:
47
+ raise NotImplementedError(beta_schedule)
48
+ assert betas.shape == (num_diffusion_timesteps,)
49
+ return betas
50
+
51
+
52
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
53
+ """
54
+ Get a pre-defined beta schedule for the given name.
55
+
56
+ The beta schedule library consists of beta schedules which remain similar
57
+ in the limit of num_diffusion_timesteps.
58
+ Beta schedules may be added, but should not be removed or changed once
59
+ they are committed to maintain backwards compatibility.
60
+ """
61
+ if schedule_name == "linear":
62
+ # Linear schedule from Ho et al, extended to work for any number of
63
+ # diffusion steps.
64
+ scale = 1000 / num_diffusion_timesteps
65
+ return get_beta_schedule(
66
+ "linear",
67
+ beta_start=scale * 0.0001,
68
+ beta_end=scale * 0.02,
69
+ num_diffusion_timesteps=num_diffusion_timesteps,
70
+ )
71
+ elif schedule_name == "squaredcos_cap_v2":
72
+ return betas_for_alpha_bar(
73
+ num_diffusion_timesteps,
74
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
75
+ )
76
+ else:
77
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
78
+
79
+
80
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
81
+ """
82
+ Create a beta schedule that discretizes the given alpha_t_bar function,
83
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
84
+
85
+ :param num_diffusion_timesteps: the number of betas to produce.
86
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
87
+ produces the cumulative product of (1-beta) up to that
88
+ part of the diffusion process.
89
+ :param max_beta: the maximum beta to use; use values lower than 1 to
90
+ prevent singularities.
91
+ """
92
+ betas = []
93
+ for i in range(num_diffusion_timesteps):
94
+ t1 = i / num_diffusion_timesteps
95
+ t2 = (i + 1) / num_diffusion_timesteps
96
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
97
+ return np.array(betas)
98
+
99
+
100
+ class GaussianDiffusion:
101
+ """
102
+ Utilities for training and sampling diffusion models.
103
+
104
+ Original ported from this codebase:
105
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
106
+
107
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
108
+ starting at T and going to 1.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ *,
114
+ betas,
115
+ ):
116
+ # Use float64 for accuracy.
117
+ betas = np.array(betas, dtype=np.float64)
118
+ self.betas = betas
119
+ assert len(betas.shape) == 1, "betas must be 1-D"
120
+ assert (betas > 0).all() and (betas <= 1).all()
121
+
122
+ self.num_timesteps = int(betas.shape[0])
123
+
124
+ alphas = 1.0 - betas
125
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
126
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
127
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
128
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
129
+
130
+ # calculations for diffusion q(x_t | x_{t-1}) and others
131
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
132
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
133
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
134
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
135
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
136
+
137
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
138
+ self.posterior_variance = (
139
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
140
+ )
141
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
142
+ self.posterior_log_variance_clipped = np.log(
143
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
144
+ )
145
+ self.posterior_mean_coef1 = (
146
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
147
+ )
148
+ self.posterior_mean_coef2 = (
149
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
150
+ )
151
+
152
+ def q_mean_variance(self, x_start, t):
153
+ """
154
+ Get the distribution q(x_t | x_0).
155
+
156
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
157
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
158
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
159
+ """
160
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
161
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
162
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
163
+ return mean, variance, log_variance
164
+
165
+ def q_sample(self, x_start, t, noise=None):
166
+ """
167
+ Diffuse the data for a given number of diffusion steps.
168
+
169
+ In other words, sample from q(x_t | x_0).
170
+
171
+ :param x_start: the initial data batch.
172
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
173
+ :param noise: if specified, the split-out normal noise.
174
+ :return: A noisy version of x_start.
175
+ """
176
+ if noise is None:
177
+ noise = th.randn_like(x_start)
178
+ assert noise.shape == x_start.shape
179
+ return (
180
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
181
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
182
+ )
183
+
184
+ def q_posterior_mean_variance(self, x_start, x_t, t):
185
+ """
186
+ Compute the mean and variance of the diffusion posterior:
187
+
188
+ q(x_{t-1} | x_t, x_0)
189
+
190
+ """
191
+ assert x_start.shape == x_t.shape
192
+ posterior_mean = (
193
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
194
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
195
+ )
196
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
197
+ posterior_log_variance_clipped = _extract_into_tensor(
198
+ self.posterior_log_variance_clipped, t, x_t.shape
199
+ )
200
+ assert (
201
+ posterior_mean.shape[0]
202
+ == posterior_variance.shape[0]
203
+ == posterior_log_variance_clipped.shape[0]
204
+ == x_start.shape[0]
205
+ )
206
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
207
+
208
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
209
+ """
210
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
211
+ the initial x, x_0.
212
+
213
+ :param model: the model, which takes a signal and a batch of timesteps
214
+ as input.
215
+ :param x: the [N x C x ...] tensor at time t.
216
+ :param t: a 1-D Tensor of timesteps.
217
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
218
+ :param denoised_fn: if not None, a function which applies to the
219
+ x_start prediction before it is used to sample. Applies before
220
+ clip_denoised.
221
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
222
+ pass to the model. This can be used for conditioning.
223
+ :return: a dict with the following keys:
224
+ - 'mean': the model mean output.
225
+ - 'variance': the model variance output.
226
+ - 'log_variance': the log of 'variance'.
227
+ - 'pred_xstart': the prediction for x_0.
228
+ """
229
+ if model_kwargs is None:
230
+ model_kwargs = {}
231
+
232
+ B, C = x.shape[:2]
233
+ assert t.shape == (B,)
234
+ model_output = model(x, t, **model_kwargs)
235
+ if isinstance(model_output, tuple):
236
+ model_output, extra = model_output
237
+ else:
238
+ extra = None
239
+
240
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
241
+ model_output, model_var_values = th.split(model_output, C, dim=1)
242
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
243
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
244
+ # The model_var_values is [-1, 1] for [min_var, max_var].
245
+ frac = (model_var_values + 1) / 2
246
+ model_log_variance = frac * max_log + (1 - frac) * min_log
247
+ model_variance = th.exp(model_log_variance)
248
+
249
+ def process_xstart(x):
250
+ if denoised_fn is not None:
251
+ x = denoised_fn(x)
252
+ if clip_denoised:
253
+ return x.clamp(-1, 1)
254
+ return x
255
+
256
+ pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
257
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
258
+
259
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
260
+ return {
261
+ "mean": model_mean,
262
+ "variance": model_variance,
263
+ "log_variance": model_log_variance,
264
+ "pred_xstart": pred_xstart,
265
+ "extra": extra,
266
+ }
267
+
268
+ def _predict_xstart_from_eps(self, x_t, t, eps):
269
+ assert x_t.shape == eps.shape
270
+ return (
271
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
272
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
273
+ )
274
+
275
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
276
+ return (
277
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
278
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
279
+
280
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
281
+ """
282
+ Compute the mean for the previous step, given a function cond_fn that
283
+ computes the gradient of a conditional log probability with respect to
284
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
285
+ condition on y.
286
+
287
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
288
+ """
289
+ gradient = cond_fn(x, t, **model_kwargs)
290
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
291
+ return new_mean
292
+
293
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
294
+ """
295
+ Compute what the p_mean_variance output would have been, should the
296
+ model's score function be conditioned by cond_fn.
297
+
298
+ See condition_mean() for details on cond_fn.
299
+
300
+ Unlike condition_mean(), this instead uses the conditioning strategy
301
+ from Song et al (2020).
302
+ """
303
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
304
+
305
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
306
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
307
+
308
+ out = p_mean_var.copy()
309
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
310
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
311
+ return out
312
+
313
+ def p_sample(
314
+ self,
315
+ model,
316
+ x,
317
+ t,
318
+ clip_denoised=True,
319
+ denoised_fn=None,
320
+ cond_fn=None,
321
+ model_kwargs=None,
322
+ ):
323
+ """
324
+ Sample x_{t-1} from the model at the given timestep.
325
+
326
+ :param model: the model to sample from.
327
+ :param x: the current tensor at x_{t-1}.
328
+ :param t: the value of t, starting at 0 for the first diffusion step.
329
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
330
+ :param denoised_fn: if not None, a function which applies to the
331
+ x_start prediction before it is used to sample.
332
+ :param cond_fn: if not None, this is a gradient function that acts
333
+ similarly to the model.
334
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
335
+ pass to the model. This can be used for conditioning.
336
+ :return: a dict containing the following keys:
337
+ - 'sample': a random sample from the model.
338
+ - 'pred_xstart': a prediction of x_0.
339
+ """
340
+ out = self.p_mean_variance(
341
+ model,
342
+ x,
343
+ t,
344
+ clip_denoised=clip_denoised,
345
+ denoised_fn=denoised_fn,
346
+ model_kwargs=model_kwargs,
347
+ )
348
+ noise = th.randn_like(x)
349
+ nonzero_mask = (
350
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
351
+ ) # no noise when t == 0
352
+ if cond_fn is not None:
353
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
354
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
355
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
356
+
357
+ def p_sample_loop(
358
+ self,
359
+ model,
360
+ shape,
361
+ noise=None,
362
+ clip_denoised=True,
363
+ denoised_fn=None,
364
+ cond_fn=None,
365
+ model_kwargs=None,
366
+ device=None,
367
+ progress=False,
368
+ ):
369
+ """
370
+ Generate samples from the model.
371
+
372
+ :param model: the model module.
373
+ :param shape: the shape of the samples, (N, C, H, W).
374
+ :param noise: if specified, the noise from the encoder to sample.
375
+ Should be of the same shape as `shape`.
376
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
377
+ :param denoised_fn: if not None, a function which applies to the
378
+ x_start prediction before it is used to sample.
379
+ :param cond_fn: if not None, this is a gradient function that acts
380
+ similarly to the model.
381
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
382
+ pass to the model. This can be used for conditioning.
383
+ :param device: if specified, the device to create the samples on.
384
+ If not specified, use a model parameter's device.
385
+ :param progress: if True, show a tqdm progress bar.
386
+ :return: a non-differentiable batch of samples.
387
+ """
388
+ final = None
389
+ for sample in self.p_sample_loop_progressive(
390
+ model,
391
+ shape,
392
+ noise=noise,
393
+ clip_denoised=clip_denoised,
394
+ denoised_fn=denoised_fn,
395
+ cond_fn=cond_fn,
396
+ model_kwargs=model_kwargs,
397
+ device=device,
398
+ progress=progress,
399
+ ):
400
+ final = sample
401
+ return final["sample"]
402
+
403
+ def p_sample_loop_progressive(
404
+ self,
405
+ model,
406
+ shape,
407
+ noise=None,
408
+ clip_denoised=True,
409
+ denoised_fn=None,
410
+ cond_fn=None,
411
+ model_kwargs=None,
412
+ device=None,
413
+ progress=False,
414
+ ):
415
+ """
416
+ Generate samples from the model and yield intermediate samples from
417
+ each timestep of diffusion.
418
+
419
+ Arguments are the same as p_sample_loop().
420
+ Returns a generator over dicts, where each dict is the return value of
421
+ p_sample().
422
+ """
423
+ if device is None:
424
+ device = next(model.parameters()).device
425
+ assert isinstance(shape, (tuple, list))
426
+ if noise is not None:
427
+ img = noise
428
+ else:
429
+ img = th.randn(*shape, device=device)
430
+ indices = list(range(self.num_timesteps))[::-1]
431
+
432
+ if progress:
433
+ # Lazy import so that we don't depend on tqdm.
434
+ from tqdm.auto import tqdm
435
+
436
+ indices = tqdm(indices)
437
+
438
+ for i in indices:
439
+ t = th.tensor([i] * shape[0], device=device)
440
+ with th.no_grad():
441
+ out = self.p_sample(
442
+ model,
443
+ img,
444
+ t,
445
+ clip_denoised=clip_denoised,
446
+ denoised_fn=denoised_fn,
447
+ cond_fn=cond_fn,
448
+ model_kwargs=model_kwargs,
449
+ )
450
+ yield out
451
+ img = out["sample"]
452
+
453
+ def ddim_sample(
454
+ self,
455
+ model,
456
+ x,
457
+ t,
458
+ clip_denoised=True,
459
+ denoised_fn=None,
460
+ cond_fn=None,
461
+ model_kwargs=None,
462
+ eta=0.0,
463
+ ):
464
+ """
465
+ Sample x_{t-1} from the model using DDIM.
466
+
467
+ Same usage as p_sample().
468
+ """
469
+ out = self.p_mean_variance(
470
+ model,
471
+ x,
472
+ t,
473
+ clip_denoised=clip_denoised,
474
+ denoised_fn=denoised_fn,
475
+ model_kwargs=model_kwargs,
476
+ )
477
+ if cond_fn is not None:
478
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
479
+
480
+ # Usually our model outputs epsilon, but we re-derive it
481
+ # in case we used x_start or x_prev prediction.
482
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
483
+
484
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
485
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
486
+ sigma = (
487
+ eta
488
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
489
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
490
+ )
491
+ # Equation 12.
492
+ noise = th.randn_like(x)
493
+ mean_pred = (
494
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
495
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
496
+ )
497
+ nonzero_mask = (
498
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
499
+ ) # no noise when t == 0
500
+ sample = mean_pred + nonzero_mask * sigma * noise
501
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
502
+
503
+ def ddim_reverse_sample(
504
+ self,
505
+ model,
506
+ x,
507
+ t,
508
+ clip_denoised=True,
509
+ denoised_fn=None,
510
+ cond_fn=None,
511
+ model_kwargs=None,
512
+ eta=0.0,
513
+ ):
514
+ """
515
+ Sample x_{t+1} from the model using DDIM reverse ODE.
516
+ """
517
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
518
+ out = self.p_mean_variance(
519
+ model,
520
+ x,
521
+ t,
522
+ clip_denoised=clip_denoised,
523
+ denoised_fn=denoised_fn,
524
+ model_kwargs=model_kwargs,
525
+ )
526
+ if cond_fn is not None:
527
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
528
+ # Usually our model outputs epsilon, but we re-derive it
529
+ # in case we used x_start or x_prev prediction.
530
+ eps = (
531
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
532
+ - out["pred_xstart"]
533
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
534
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
535
+
536
+ # Equation 12. reversed
537
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
538
+
539
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
540
+
541
+ def ddim_sample_loop(
542
+ self,
543
+ model,
544
+ shape,
545
+ noise=None,
546
+ clip_denoised=True,
547
+ denoised_fn=None,
548
+ cond_fn=None,
549
+ model_kwargs=None,
550
+ device=None,
551
+ progress=False,
552
+ eta=0.0,
553
+ ):
554
+ """
555
+ Generate samples from the model using DDIM.
556
+
557
+ Same usage as p_sample_loop().
558
+ """
559
+ final = None
560
+ for sample in self.ddim_sample_loop_progressive(
561
+ model,
562
+ shape,
563
+ noise=noise,
564
+ clip_denoised=clip_denoised,
565
+ denoised_fn=denoised_fn,
566
+ cond_fn=cond_fn,
567
+ model_kwargs=model_kwargs,
568
+ device=device,
569
+ progress=progress,
570
+ eta=eta,
571
+ ):
572
+ final = sample
573
+ return final["sample"]
574
+
575
+ def ddim_sample_loop_progressive(
576
+ self,
577
+ model,
578
+ shape,
579
+ noise=None,
580
+ clip_denoised=True,
581
+ denoised_fn=None,
582
+ cond_fn=None,
583
+ model_kwargs=None,
584
+ device=None,
585
+ progress=False,
586
+ eta=0.0,
587
+ ):
588
+ """
589
+ Use DDIM to sample from the model and yield intermediate samples from
590
+ each timestep of DDIM.
591
+
592
+ Same usage as p_sample_loop_progressive().
593
+ """
594
+ if device is None:
595
+ device = next(model.parameters()).device
596
+ assert isinstance(shape, (tuple, list))
597
+ if noise is not None:
598
+ img = noise
599
+ else:
600
+ img = th.randn(*shape, device=device)
601
+ indices = list(range(self.num_timesteps))[::-1]
602
+
603
+ if progress:
604
+ # Lazy import so that we don't depend on tqdm.
605
+ from tqdm.auto import tqdm
606
+
607
+ indices = tqdm(indices)
608
+
609
+ for i in indices:
610
+ t = th.tensor([i] * shape[0], device=device)
611
+ with th.no_grad():
612
+ out = self.ddim_sample(
613
+ model,
614
+ img,
615
+ t,
616
+ clip_denoised=clip_denoised,
617
+ denoised_fn=denoised_fn,
618
+ cond_fn=cond_fn,
619
+ model_kwargs=model_kwargs,
620
+ eta=eta,
621
+ )
622
+ yield out
623
+ img = out["sample"]
624
+
625
+
626
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
627
+ """
628
+ Extract values from a 1-D numpy array for a batch of indices.
629
+
630
+ :param arr: the 1-D numpy array.
631
+ :param timesteps: a tensor of indices into the array to extract.
632
+ :param broadcast_shape: a larger shape of K dimensions with the batch
633
+ dimension equal to the length of timesteps.
634
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
635
+ """
636
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
637
+ while len(res.shape) < len(broadcast_shape):
638
+ res = res[..., None]
639
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
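A minimal sketch of exercising the pieces above on their own, without any model (the schedule length, shapes, and zero "data" are illustrative only; this is not part of the uploaded files):

    import torch as th
    from glide_text2im.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule

    betas = get_named_beta_schedule("squaredcos_cap_v2", 100)
    diffusion = GaussianDiffusion(betas=betas)

    x_start = th.zeros(2, 3, 64, 64)
    t = th.tensor([0, 99])
    x_t = diffusion.q_sample(x_start, t)  # forward process: add noise according to the schedule
    print(x_t.shape)                      # torch.Size([2, 3, 64, 64])

Sampling with p_sample_loop or ddim_sample_loop additionally requires a model that returns the (eps, variance) channels expected by p_mean_variance.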
glide_text2im/model_creation.py ADDED
@@ -0,0 +1,195 @@
+from glide_text2im.gaussian_diffusion import get_named_beta_schedule
+from glide_text2im.respace import SpacedDiffusion, space_timesteps
+from glide_text2im.text2im_model import (
+    InpaintText2ImUNet,
+    SuperResInpaintText2ImUnet,
+    SuperResText2ImUNet,
+    Text2ImUNet,
+)
+from glide_text2im.tokenizer.bpe import get_encoder
+
+
+def model_and_diffusion_defaults():
+    return dict(
+        image_size=64,
+        num_channels=192,
+        num_res_blocks=3,
+        channel_mult="",
+        num_heads=1,
+        num_head_channels=64,
+        num_heads_upsample=-1,
+        attention_resolutions="32,16,8",
+        dropout=0.1,
+        text_ctx=128,
+        xf_width=512,
+        xf_layers=16,
+        xf_heads=8,
+        xf_final_ln=True,
+        xf_padding=True,
+        diffusion_steps=1000,
+        noise_schedule="squaredcos_cap_v2",
+        timestep_respacing="",
+        use_scale_shift_norm=True,
+        resblock_updown=True,
+        use_fp16=True,
+        cache_text_emb=False,
+        inpaint=False,
+        super_res=False,
+    )
+
+
+def model_and_diffusion_defaults_upsampler():
+    result = model_and_diffusion_defaults()
+    result.update(
+        dict(
+            image_size=256,
+            num_res_blocks=2,
+            noise_schedule="linear",
+            super_res=True,
+        )
+    )
+    return result
+
+
+def create_model_and_diffusion(
+    image_size,
+    num_channels,
+    num_res_blocks,
+    channel_mult,
+    num_heads,
+    num_head_channels,
+    num_heads_upsample,
+    attention_resolutions,
+    dropout,
+    text_ctx,
+    xf_width,
+    xf_layers,
+    xf_heads,
+    xf_final_ln,
+    xf_padding,
+    diffusion_steps,
+    noise_schedule,
+    timestep_respacing,
+    use_scale_shift_norm,
+    resblock_updown,
+    use_fp16,
+    cache_text_emb,
+    inpaint,
+    super_res,
+):
+    model = create_model(
+        image_size,
+        num_channels,
+        num_res_blocks,
+        channel_mult=channel_mult,
+        attention_resolutions=attention_resolutions,
+        num_heads=num_heads,
+        num_head_channels=num_head_channels,
+        num_heads_upsample=num_heads_upsample,
+        use_scale_shift_norm=use_scale_shift_norm,
+        dropout=dropout,
+        text_ctx=text_ctx,
+        xf_width=xf_width,
+        xf_layers=xf_layers,
+        xf_heads=xf_heads,
+        xf_final_ln=xf_final_ln,
+        xf_padding=xf_padding,
+        resblock_updown=resblock_updown,
+        use_fp16=use_fp16,
+        cache_text_emb=cache_text_emb,
+        inpaint=inpaint,
+        super_res=super_res,
+    )
+    diffusion = create_gaussian_diffusion(
+        steps=diffusion_steps,
+        noise_schedule=noise_schedule,
+        timestep_respacing=timestep_respacing,
+    )
+    return model, diffusion
+
+
+def create_model(
+    image_size,
+    num_channels,
+    num_res_blocks,
+    channel_mult,
+    attention_resolutions,
+    num_heads,
+    num_head_channels,
+    num_heads_upsample,
+    use_scale_shift_norm,
+    dropout,
+    text_ctx,
+    xf_width,
+    xf_layers,
+    xf_heads,
+    xf_final_ln,
+    xf_padding,
+    resblock_updown,
+    use_fp16,
+    cache_text_emb,
+    inpaint,
+    super_res,
+):
+    if channel_mult == "":
+        if image_size == 256:
+            channel_mult = (1, 1, 2, 2, 4, 4)
+        elif image_size == 128:
+            channel_mult = (1, 1, 2, 3, 4)
+        elif image_size == 64:
+            channel_mult = (1, 2, 3, 4)
+        else:
+            raise ValueError(f"unsupported image size: {image_size}")
+    else:
+        channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
+        assert 2 ** (len(channel_mult) + 2) == image_size
+
+    attention_ds = []
+    for res in attention_resolutions.split(","):
+        attention_ds.append(image_size // int(res))
+
+    if inpaint and super_res:
+        model_cls = SuperResInpaintText2ImUnet
+    elif inpaint:
+        model_cls = InpaintText2ImUNet
+    elif super_res:
+        model_cls = SuperResText2ImUNet
+    else:
+        model_cls = Text2ImUNet
+    return model_cls(
+        text_ctx=text_ctx,
+        xf_width=xf_width,
+        xf_layers=xf_layers,
+        xf_heads=xf_heads,
+        xf_final_ln=xf_final_ln,
+        tokenizer=get_encoder(),
+        xf_padding=xf_padding,
+        in_channels=3,
+        model_channels=num_channels,
+        out_channels=6,
+        num_res_blocks=num_res_blocks,
+        attention_resolutions=tuple(attention_ds),
+        dropout=dropout,
+        channel_mult=channel_mult,
+        use_fp16=use_fp16,
+        num_heads=num_heads,
+        num_head_channels=num_head_channels,
+        num_heads_upsample=num_heads_upsample,
+        use_scale_shift_norm=use_scale_shift_norm,
+        resblock_updown=resblock_updown,
+        cache_text_emb=cache_text_emb,
+    )
+
+
+def create_gaussian_diffusion(
+    steps,
+    noise_schedule,
+    timestep_respacing,
+):
+    betas = get_named_beta_schedule(noise_schedule, steps)
+    if not timestep_respacing:
+        timestep_respacing = [steps]
+    return SpacedDiffusion(
+        use_timesteps=space_timesteps(steps, timestep_respacing),
+        betas=betas,
+    )
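To show how these factories fit together, a minimal sketch in the spirit of the upstream GLIDE notebooks (the respacing value "100" and the checkpoint name are illustrative; nothing here is part of the uploaded files):

    import torch as th
    from glide_text2im.download import load_checkpoint
    from glide_text2im.model_creation import (
        create_model_and_diffusion,
        model_and_diffusion_defaults,
    )

    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    options = model_and_diffusion_defaults()
    options["timestep_respacing"] = "100"  # sample with 100 of the 1000 trained steps
    model, diffusion = create_model_and_diffusion(**options)
    model.eval()
    if options["use_fp16"]:
        model.convert_to_fp16()
    model.to(device)
    model.load_state_dict(load_checkpoint("base", device))

The upsampler is built the same way from model_and_diffusion_defaults_upsampler() with the "upsample" checkpoint.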
glide_text2im/nn.py ADDED
@@ -0,0 +1,105 @@
+"""
+Various utilities for neural networks.
+"""
+
+import math
+
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GroupNorm32(nn.GroupNorm):
+    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
+        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
+        self.swish = swish
+
+    def forward(self, x):
+        y = super().forward(x.float()).to(x.dtype)
+        if self.swish == 1.0:
+            y = F.silu(y)
+        elif self.swish:
+            y = y * F.sigmoid(y * float(self.swish))
+        return y
+
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """
+    Create a linear module.
+    """
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def normalization(channels, swish=0.0):
+    """
+    Make a standard normalization layer, with an optional swish activation.
+
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = th.exp(
+        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+    ).to(device=timesteps.device)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
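As a quick orientation, the shape contract of timestep_embedding above (the values are arbitrary; this snippet is not part of the commit):

    import torch as th
    from glide_text2im.nn import timestep_embedding

    t = th.tensor([0, 10, 999])           # one (possibly fractional) timestep per batch element
    emb = timestep_embedding(t, dim=128)  # sinusoids from slow to fast frequencies
    print(emb.shape)                      # torch.Size([3, 128])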
glide_text2im/respace.py ADDED
@@ -0,0 +1,117 @@
+"""
+Utilities for changing sampling schedules of a trained model.
+
+Simplified from: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
+"""
+
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+    """
+    Create a list of timesteps to use from an original diffusion process,
+    given the number of timesteps we want to take from equally-sized portions
+    of the original process.
+
+    For example, if there are 300 timesteps and the section counts are [10,15,20]
+    then the first 100 timesteps are strided to be 10 timesteps, the second 100
+    are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+    :param num_timesteps: the number of diffusion steps in the original
+                          process to divide up.
+    :param section_counts: either a list of numbers, or a string containing
+                           comma-separated numbers, indicating the step count
+                           per section. As a special case, use "ddimN" where N
+                           is a number of steps to use the striding from the
+                           DDIM paper.
+    :return: a set of diffusion steps from the original process to use.
+    """
+    if isinstance(section_counts, str):
+        if section_counts.startswith("ddim"):
+            desired_count = int(section_counts[len("ddim") :])
+            for i in range(1, num_timesteps):
+                if len(range(0, num_timesteps, i)) == desired_count:
+                    return set(range(0, num_timesteps, i))
+            raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
+        elif section_counts == "fast27":
+            steps = space_timesteps(num_timesteps, "10,10,3,2,2")
+            # Help reduce DDIM artifacts from noisiest timesteps.
+            steps.remove(num_timesteps - 1)
+            steps.add(num_timesteps - 3)
+            return steps
+        section_counts = [int(x) for x in section_counts.split(",")]
+    size_per = num_timesteps // len(section_counts)
+    extra = num_timesteps % len(section_counts)
+    start_idx = 0
+    all_steps = []
+    for i, section_count in enumerate(section_counts):
+        size = size_per + (1 if i < extra else 0)
+        if size < section_count:
+            raise ValueError(f"cannot divide section of {size} steps into {section_count}")
+        if section_count <= 1:
+            frac_stride = 1
+        else:
+            frac_stride = (size - 1) / (section_count - 1)
+        cur_idx = 0.0
+        taken_steps = []
+        for _ in range(section_count):
+            taken_steps.append(start_idx + round(cur_idx))
+            cur_idx += frac_stride
+        all_steps += taken_steps
+        start_idx += size
+    return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+    """
+    A diffusion process which can skip steps in a base diffusion process.
+
+    :param use_timesteps: a collection (sequence or set) of timesteps from the
+                          original diffusion process to retain.
+    :param kwargs: the kwargs to create the base diffusion process.
+    """
+
+    def __init__(self, use_timesteps, **kwargs):
+        self.use_timesteps = set(use_timesteps)
+        self.timestep_map = []
+        self.original_num_steps = len(kwargs["betas"])
+
+        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
+        last_alpha_cumprod = 1.0
+        new_betas = []
+        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+            if i in self.use_timesteps:
+                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                last_alpha_cumprod = alpha_cumprod
+                self.timestep_map.append(i)
+        kwargs["betas"] = np.array(new_betas)
+        super().__init__(**kwargs)
+
+    def p_mean_variance(self, model, *args, **kwargs):
+        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+    def condition_mean(self, cond_fn, *args, **kwargs):
+        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+    def condition_score(self, cond_fn, *args, **kwargs):
+        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+    def _wrap_model(self, model):
+        if isinstance(model, _WrappedModel):
+            return model
+        return _WrappedModel(model, self.timestep_map, self.original_num_steps)
+
+
+class _WrappedModel:
+    def __init__(self, model, timestep_map, original_num_steps):
+        self.model = model
+        self.timestep_map = timestep_map
+        self.original_num_steps = original_num_steps
+
+    def __call__(self, x, ts, **kwargs):
+        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+        new_ts = map_tensor[ts]
+        return self.model(x, new_ts, **kwargs)
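A small sketch of the respacing helper on its own (the step counts are illustrative; not part of the uploaded files):

    from glide_text2im.respace import space_timesteps

    # Keep 100 evenly strided steps out of a 1000-step process.
    steps = space_timesteps(1000, "100")
    print(len(steps))  # 100

    # DDIM-style striding: exactly 25 steps with a constant integer stride.
    print(sorted(space_timesteps(1000, "ddim25"))[:3])  # [0, 40, 80]

SpacedDiffusion then recomputes betas over just those retained steps and remaps the model's timesteps through _WrappedModel, which is why create_gaussian_diffusion in model_creation.py can hand it an unmodified base schedule.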
glide_text2im/text2im_model.py ADDED
@@ -0,0 +1,233 @@
1
+ import torch as th
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .nn import timestep_embedding
6
+ from .unet import UNetModel
7
+ from .xf import LayerNorm, Transformer, convert_module_to_f16
8
+
9
+
10
+ class Text2ImUNet(UNetModel):
11
+ """
12
+ A UNetModel that conditions on text with an encoding transformer.
13
+
14
+ Expects an extra kwarg `tokens` of text.
15
+
16
+ :param text_ctx: number of text tokens to expect.
17
+ :param xf_width: width of the transformer.
18
+ :param xf_layers: depth of the transformer.
19
+ :param xf_heads: heads in the transformer.
20
+ :param xf_final_ln: use a LayerNorm after the output layer.
21
+ :param tokenizer: the text tokenizer for sampling/vocab size.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ text_ctx,
27
+ xf_width,
28
+ xf_layers,
29
+ xf_heads,
30
+ xf_final_ln,
31
+ tokenizer,
32
+ *args,
33
+ cache_text_emb=False,
34
+ xf_ar=0.0,
35
+ xf_padding=False,
36
+ share_unemb=False,
37
+ **kwargs,
38
+ ):
39
+ self.text_ctx = text_ctx
40
+ self.xf_width = xf_width
41
+ self.xf_ar = xf_ar
42
+ self.xf_padding = xf_padding
43
+ self.tokenizer = tokenizer
44
+
45
+ if not xf_width:
46
+ super().__init__(*args, **kwargs, encoder_channels=None)
47
+ else:
48
+ super().__init__(*args, **kwargs, encoder_channels=xf_width)
49
+ if self.xf_width:
50
+ self.transformer = Transformer(
51
+ text_ctx,
52
+ xf_width,
53
+ xf_layers,
54
+ xf_heads,
55
+ )
56
+ if xf_final_ln:
57
+ self.final_ln = LayerNorm(xf_width)
58
+ else:
59
+ self.final_ln = None
60
+
61
+ self.token_embedding = nn.Embedding(self.tokenizer.n_vocab, xf_width)
62
+ self.positional_embedding = nn.Parameter(th.empty(text_ctx, xf_width, dtype=th.float32))
63
+ self.transformer_proj = nn.Linear(xf_width, self.model_channels * 4)
64
+
65
+ if self.xf_padding:
66
+ self.padding_embedding = nn.Parameter(
67
+ th.empty(text_ctx, xf_width, dtype=th.float32)
68
+ )
69
+ if self.xf_ar:
70
+ self.unemb = nn.Linear(xf_width, self.tokenizer.n_vocab)
71
+ if share_unemb:
72
+ self.unemb.weight = self.token_embedding.weight
73
+
74
+ self.cache_text_emb = cache_text_emb
75
+ self.cache = None
76
+
77
+ def convert_to_fp16(self):
78
+ super().convert_to_fp16()
79
+ if self.xf_width:
80
+ self.transformer.apply(convert_module_to_f16)
81
+ self.transformer_proj.to(th.float16)
82
+ self.token_embedding.to(th.float16)
83
+ self.positional_embedding.to(th.float16)
84
+ if self.xf_padding:
85
+ self.padding_embedding.to(th.float16)
86
+ if self.xf_ar:
87
+ self.unemb.to(th.float16)
88
+
89
+ def get_text_emb(self, tokens, mask):
90
+ assert tokens is not None
91
+
92
+ if self.cache_text_emb and self.cache is not None:
93
+ assert (
94
+ tokens == self.cache["tokens"]
95
+ ).all(), f"Tokens {tokens.cpu().numpy().tolist()} do not match cache {self.cache['tokens'].cpu().numpy().tolist()}"
96
+ return self.cache
97
+
98
+ xf_in = self.token_embedding(tokens.long())
99
+ xf_in = xf_in + self.positional_embedding[None]
100
+ if self.xf_padding:
101
+ assert mask is not None
102
+ xf_in = th.where(mask[..., None], xf_in, self.padding_embedding[None])
103
+ xf_out = self.transformer(xf_in.to(self.dtype))
104
+ if self.final_ln is not None:
105
+ xf_out = self.final_ln(xf_out)
106
+ xf_proj = self.transformer_proj(xf_out[:, -1])
107
+ xf_out = xf_out.permute(0, 2, 1) # NLC -> NCL
108
+
109
+ outputs = dict(xf_proj=xf_proj, xf_out=xf_out)
110
+
111
+ if self.cache_text_emb:
112
+ self.cache = dict(
113
+ tokens=tokens,
114
+ xf_proj=xf_proj.detach(),
115
+ xf_out=xf_out.detach() if xf_out is not None else None,
116
+ )
117
+
118
+ return outputs
119
+
120
+ def del_cache(self):
121
+ self.cache = None
122
+
123
+ def forward(self, x, timesteps, tokens=None, mask=None):
124
+ hs = []
125
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
126
+ if self.xf_width:
127
+ text_outputs = self.get_text_emb(tokens, mask)
128
+ xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
129
+ emb = emb + xf_proj.to(emb)
130
+ else:
131
+ xf_out = None
132
+ h = x.type(self.dtype)
133
+ for module in self.input_blocks:
134
+ h = module(h, emb, xf_out)
135
+ hs.append(h)
136
+ h = self.middle_block(h, emb, xf_out)
137
+ for module in self.output_blocks:
138
+ h = th.cat([h, hs.pop()], dim=1)
139
+ h = module(h, emb, xf_out)
140
+ h = h.type(x.dtype)
141
+ h = self.out(h)
142
+ return h
143
+
144
+
145
+ class SuperResText2ImUNet(Text2ImUNet):
146
+ """
147
+ A text2im model that performs super-resolution.
148
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
149
+ """
150
+
151
+ def __init__(self, *args, **kwargs):
152
+ if "in_channels" in kwargs:
153
+ kwargs = dict(kwargs)
154
+ kwargs["in_channels"] = kwargs["in_channels"] * 2
155
+ else:
156
+ # Curse you, Python. Or really, just curse positional arguments :|.
157
+ args = list(args)
158
+ args[1] = args[1] * 2
159
+ super().__init__(*args, **kwargs)
160
+
161
+ def forward(self, x, timesteps, low_res=None, **kwargs):
162
+ _, _, new_height, new_width = x.shape
163
+ upsampled = F.interpolate(
164
+ low_res, (new_height, new_width), mode="bilinear", align_corners=False
165
+ )
166
+ x = th.cat([x, upsampled], dim=1)
167
+ return super().forward(x, timesteps, **kwargs)
168
+
169
+
170
+ class InpaintText2ImUNet(Text2ImUNet):
171
+ """
172
+ A text2im model which can perform inpainting.
173
+ """
174
+
175
+ def __init__(self, *args, **kwargs):
176
+ if "in_channels" in kwargs:
177
+ kwargs = dict(kwargs)
178
+ kwargs["in_channels"] = kwargs["in_channels"] * 2 + 1
179
+ else:
180
+ # Curse you, Python. Or really, just curse positional arguments :|.
181
+ args = list(args)
182
+ args[1] = args[1] * 2 + 1
183
+ super().__init__(*args, **kwargs)
184
+
185
+ def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs):
186
+ if inpaint_image is None:
187
+ inpaint_image = th.zeros_like(x)
188
+ if inpaint_mask is None:
189
+ inpaint_mask = th.zeros_like(x[:, :1])
190
+ return super().forward(
191
+ th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1),
192
+ timesteps,
193
+ **kwargs,
194
+ )
195
+
196
+
197
+ class SuperResInpaintText2ImUnet(Text2ImUNet):
198
+ """
199
+ A text2im model which can perform both upsampling and inpainting.
200
+ """
201
+
202
+ def __init__(self, *args, **kwargs):
203
+ if "in_channels" in kwargs:
204
+ kwargs = dict(kwargs)
205
+ kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1
206
+ else:
207
+ # Curse you, Python. Or really, just curse positional arguments :|.
208
+ args = list(args)
209
+ args[1] = args[1] * 3 + 1
210
+ super().__init__(*args, **kwargs)
211
+
212
+ def forward(
213
+ self,
214
+ x,
215
+ timesteps,
216
+ inpaint_image=None,
217
+ inpaint_mask=None,
218
+ low_res=None,
219
+ **kwargs,
220
+ ):
221
+ if inpaint_image is None:
222
+ inpaint_image = th.zeros_like(x)
223
+ if inpaint_mask is None:
224
+ inpaint_mask = th.zeros_like(x[:, :1])
225
+ _, _, new_height, new_width = x.shape
226
+ upsampled = F.interpolate(
227
+ low_res, (new_height, new_width), mode="bilinear", align_corners=False
228
+ )
229
+ return super().forward(
230
+ th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1),
231
+ timesteps,
232
+ **kwargs,
233
+ )
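The docstring above notes that Text2ImUNet expects `tokens` (and, with xf_padding, a `mask`) as extra kwargs. A hedged sketch of producing those inputs, assuming the bundled BPE tokenizer exposes encode and padded_tokens_and_mask as in the upstream glide-text2im repo, and reusing `model`, `options`, and `device` from the model_creation sketch earlier (not part of the uploaded files):

    import torch as th

    prompt = "an oil painting of a corgi"
    tokens = model.tokenizer.encode(prompt)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(tokens, options["text_ctx"])
    model_kwargs = dict(
        tokens=th.tensor([tokens], device=device),
        mask=th.tensor([mask], dtype=th.bool, device=device),
    )
    # model_kwargs is then passed through diffusion.p_sample_loop(model, ..., model_kwargs=model_kwargs).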
glide_text2im/unet.py ADDED
@@ -0,0 +1,635 @@
1
+ import math
2
+ from abc import abstractmethod
3
+
4
+ import torch as th
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
9
+ from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module
10
+
11
+
12
+ class TimestepBlock(nn.Module):
13
+ """
14
+ Any module where forward() takes timestep embeddings as a second argument.
15
+ """
16
+
17
+ @abstractmethod
18
+ def forward(self, x, emb):
19
+ """
20
+ Apply the module to `x` given `emb` timestep embeddings.
21
+ """
22
+
23
+
24
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
25
+ """
26
+ A sequential module that passes timestep embeddings to the children that
27
+ support it as an extra input.
28
+ """
29
+
30
+ def forward(self, x, emb, encoder_out=None):
31
+ for layer in self:
32
+ if isinstance(layer, TimestepBlock):
33
+ x = layer(x, emb)
34
+ elif isinstance(layer, AttentionBlock):
35
+ x = layer(x, encoder_out)
36
+ else:
37
+ x = layer(x)
38
+ return x
39
+
40
+
41
+ class Upsample(nn.Module):
42
+ """
43
+ An upsampling layer with an optional convolution.
44
+
45
+ :param channels: channels in the inputs and outputs.
46
+ :param use_conv: a bool determining if a convolution is applied.
47
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
48
+ upsampling occurs in the inner-two dimensions.
49
+ """
50
+
51
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
52
+ super().__init__()
53
+ self.channels = channels
54
+ self.out_channels = out_channels or channels
55
+ self.use_conv = use_conv
56
+ self.dims = dims
57
+ if use_conv:
58
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
59
+
60
+ def forward(self, x):
61
+ assert x.shape[1] == self.channels
62
+ if self.dims == 3:
63
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
64
+ else:
65
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
66
+ if self.use_conv:
67
+ x = self.conv(x)
68
+ return x
69
+
70
+
71
+ class Downsample(nn.Module):
72
+ """
73
+ A downsampling layer with an optional convolution.
74
+
75
+ :param channels: channels in the inputs and outputs.
76
+ :param use_conv: a bool determining if a convolution is applied.
77
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
78
+ downsampling occurs in the inner-two dimensions.
79
+ """
80
+
81
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
82
+ super().__init__()
83
+ self.channels = channels
84
+ self.out_channels = out_channels or channels
85
+ self.use_conv = use_conv
86
+ self.dims = dims
87
+ stride = 2 if dims != 3 else (1, 2, 2)
88
+ if use_conv:
89
+ self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
90
+ else:
91
+ assert self.channels == self.out_channels
92
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
93
+
94
+ def forward(self, x):
95
+ assert x.shape[1] == self.channels
96
+ return self.op(x)
97
+
98
+
99
+ class ResBlock(TimestepBlock):
100
+ """
101
+ A residual block that can optionally change the number of channels.
102
+
103
+ :param channels: the number of input channels.
104
+ :param emb_channels: the number of timestep embedding channels.
105
+ :param dropout: the rate of dropout.
106
+ :param out_channels: if specified, the number of out channels.
107
+ :param use_conv: if True and out_channels is specified, use a spatial
108
+ convolution instead of a smaller 1x1 convolution to change the
109
+ channels in the skip connection.
110
+ :param dims: determines if the signal is 1D, 2D, or 3D.
111
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
112
+ :param up: if True, use this block for upsampling.
113
+ :param down: if True, use this block for downsampling.
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ channels,
119
+ emb_channels,
120
+ dropout,
121
+ out_channels=None,
122
+ use_conv=False,
123
+ use_scale_shift_norm=False,
124
+ dims=2,
125
+ use_checkpoint=False,
126
+ up=False,
127
+ down=False,
128
+ ):
129
+ super().__init__()
130
+ self.channels = channels
131
+ self.emb_channels = emb_channels
132
+ self.dropout = dropout
133
+ self.out_channels = out_channels or channels
134
+ self.use_conv = use_conv
135
+ self.use_checkpoint = use_checkpoint
136
+ self.use_scale_shift_norm = use_scale_shift_norm
137
+
138
+ self.in_layers = nn.Sequential(
139
+ normalization(channels, swish=1.0),
140
+ nn.Identity(),
141
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
142
+ )
143
+
144
+ self.updown = up or down
145
+
146
+ if up:
147
+ self.h_upd = Upsample(channels, False, dims)
148
+ self.x_upd = Upsample(channels, False, dims)
149
+ elif down:
150
+ self.h_upd = Downsample(channels, False, dims)
151
+ self.x_upd = Downsample(channels, False, dims)
152
+ else:
153
+ self.h_upd = self.x_upd = nn.Identity()
154
+
155
+ self.emb_layers = nn.Sequential(
156
+ nn.SiLU(),
157
+ linear(
158
+ emb_channels,
159
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
160
+ ),
161
+ )
162
+ self.out_layers = nn.Sequential(
163
+ normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
164
+ nn.SiLU() if use_scale_shift_norm else nn.Identity(),
165
+ nn.Dropout(p=dropout),
166
+ zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
167
+ )
168
+
169
+ if self.out_channels == channels:
170
+ self.skip_connection = nn.Identity()
171
+ elif use_conv:
172
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
173
+ else:
174
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
175
+
176
+ def forward(self, x, emb):
177
+ """
178
+ Apply the block to a Tensor, conditioned on a timestep embedding.
179
+
180
+ :param x: an [N x C x ...] Tensor of features.
181
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
182
+ :return: an [N x C x ...] Tensor of outputs.
183
+ """
184
+ if self.updown:
185
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
186
+ h = in_rest(x)
187
+ h = self.h_upd(h)
188
+ x = self.x_upd(x)
189
+ h = in_conv(h)
190
+ else:
191
+ h = self.in_layers(x)
192
+ emb_out = self.emb_layers(emb).type(h.dtype)
193
+ while len(emb_out.shape) < len(h.shape):
194
+ emb_out = emb_out[..., None]
195
+ if self.use_scale_shift_norm:
196
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
197
+ scale, shift = th.chunk(emb_out, 2, dim=1)
198
+ h = out_norm(h) * (1 + scale) + shift
199
+ h = out_rest(h)
200
+ else:
201
+ h = h + emb_out
202
+ h = self.out_layers(h)
203
+ return self.skip_connection(x) + h
204
+
205
+
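The scale-shift ("FiLM-like") path above projects the timestep embedding to 2 * out_channels and splits it into a per-channel scale and shift applied after normalization, instead of simply adding the embedding to the features. A minimal usage sketch, assuming this file is importable as glide_text2im.unet (an assumed module path):

import torch as th
from glide_text2im.unet import ResBlock  # assumed module path

# Hedged sketch: a block that changes 64 -> 128 channels and uses scale-shift norm.
block = ResBlock(
    channels=64,
    emb_channels=256,
    dropout=0.0,
    out_channels=128,           # skip connection becomes a 1x1 conv
    use_scale_shift_norm=True,  # emb projects to 2 * out_channels -> (scale, shift)
)
x = th.randn(2, 64, 32, 32)     # [N x C x H x W] features
emb = th.randn(2, 256)          # [N x emb_channels] timestep embeddings
h = block(x, emb)
print(h.shape)                  # torch.Size([2, 128, 32, 32])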
206
+ class AttentionBlock(nn.Module):
207
+ """
208
+ An attention block that allows spatial positions to attend to each other.
209
+
210
+ Originally ported from here, but adapted to the N-d case.
211
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
212
+ """
213
+
214
+ def __init__(
215
+ self,
216
+ channels,
217
+ num_heads=1,
218
+ num_head_channels=-1,
219
+ use_checkpoint=False,
220
+ encoder_channels=None,
221
+ ):
222
+ super().__init__()
223
+ self.channels = channels
224
+ if num_head_channels == -1:
225
+ self.num_heads = num_heads
226
+ else:
227
+ assert (
228
+ channels % num_head_channels == 0
229
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
230
+ self.num_heads = channels // num_head_channels
231
+ self.use_checkpoint = use_checkpoint
232
+ self.norm = normalization(channels, swish=0.0)
233
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
234
+ self.attention = QKVAttention(self.num_heads)
235
+
236
+ if encoder_channels is not None:
237
+ self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
238
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
239
+
240
+ def forward(self, x, encoder_out=None):
241
+ b, c, *spatial = x.shape
242
+ qkv = self.qkv(self.norm(x).view(b, c, -1))
243
+ if encoder_out is not None:
244
+ encoder_out = self.encoder_kv(encoder_out)
245
+ h = self.attention(qkv, encoder_out)
246
+ else:
247
+ h = self.attention(qkv)
248
+ h = self.proj_out(h)
249
+ return x + h.reshape(b, c, *spatial)
250
+
251
+
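When encoder_channels is given, the block also projects an external sequence (in GLIDE, the text-encoder outputs) into extra keys and values, so spatial positions can attend to tokens as well as to each other. A hedged sketch with illustrative sizes, again assuming the glide_text2im.unet module path:

import torch as th
from glide_text2im.unet import AttentionBlock  # assumed module path

attn = AttentionBlock(channels=64, num_heads=4, encoder_channels=512)
x = th.randn(2, 64, 16, 16)        # [N x C x H x W] spatial features
tokens = th.randn(2, 512, 128)     # [N x encoder_channels x T] encoder outputs
out = attn(x, encoder_out=tokens)  # token keys/values are prepended inside QKVAttention
print(out.shape)                   # torch.Size([2, 64, 16, 16])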
252
+ class QKVAttention(nn.Module):
253
+ """
254
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
255
+ """
256
+
257
+ def __init__(self, n_heads):
258
+ super().__init__()
259
+ self.n_heads = n_heads
260
+
261
+ def forward(self, qkv, encoder_kv=None):
262
+ """
263
+ Apply QKV attention.
264
+
265
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :param encoder_kv: an optional [N x (H * 2 * C) x S] tensor of extra Ks and Vs to attend over.
266
+ :return: an [N x (H * C) x T] tensor after attention.
267
+ """
268
+ bs, width, length = qkv.shape
269
+ assert width % (3 * self.n_heads) == 0
270
+ ch = width // (3 * self.n_heads)
271
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
272
+ if encoder_kv is not None:
273
+ assert encoder_kv.shape[1] == self.n_heads * ch * 2
274
+ ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
275
+ k = th.cat([ek, k], dim=-1)
276
+ v = th.cat([ev, v], dim=-1)
277
+ scale = 1 / math.sqrt(math.sqrt(ch))
278
+ weight = th.einsum(
279
+ "bct,bcs->bts", q * scale, k * scale
280
+ ) # More stable with f16 than dividing afterwards
281
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
282
+ a = th.einsum("bts,bcs->bct", weight, v)
283
+ return a.reshape(bs, -1, length)
284
+
285
+
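The 1 / sqrt(sqrt(ch)) factor applied to both q and k is algebraically the same as the usual division of the attention logits by sqrt(ch); applying it before the einsum keeps both operands small, which is the float16-stability point made in the inline comment. A self-contained check:

import math
import torch as th

ch = 64
q = th.randn(8, ch, 10)  # [B*H x C x T]
k = th.randn(8, ch, 12)  # [B*H x C x S]
scale = 1 / math.sqrt(math.sqrt(ch))
pre = th.einsum("bct,bcs->bts", q * scale, k * scale)   # scale q and k first
post = th.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)  # scale the logits afterwards
print(th.allclose(pre, post, atol=1e-5))                # True, up to float rounding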
286
+ class UNetModel(nn.Module):
287
+ """
288
+ The full UNet model with attention and timestep embedding.
289
+
290
+ :param in_channels: channels in the input Tensor.
291
+ :param model_channels: base channel count for the model.
292
+ :param out_channels: channels in the output Tensor.
293
+ :param num_res_blocks: number of residual blocks per downsample.
294
+ :param attention_resolutions: a collection of downsample rates at which
295
+ attention will take place. May be a set, list, or tuple.
296
+ For example, if this contains 4, then at 4x downsampling, attention
297
+ will be used.
298
+ :param dropout: the dropout probability.
299
+ :param channel_mult: channel multiplier for each level of the UNet.
300
+ :param conv_resample: if True, use learned convolutions for upsampling and
301
+ downsampling.
302
+ :param dims: determines if the signal is 1D, 2D, or 3D.
303
+ :param num_classes: if specified (as an int), then this model will be
304
+ class-conditional with `num_classes` classes.
305
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
306
+ :param num_heads: the number of attention heads in each attention layer.
307
+ :param num_head_channels: if specified, ignore num_heads and instead use
308
+ a fixed channel width per attention head.
309
+ :param num_heads_upsample: works with num_heads to set a different number
310
+ of heads for upsampling. Deprecated.
311
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
312
+ :param resblock_updown: use residual blocks for up/downsampling.
313
+ """
314
+
315
+ def __init__(
316
+ self,
317
+ in_channels,
318
+ model_channels,
319
+ out_channels,
320
+ num_res_blocks,
321
+ attention_resolutions,
322
+ dropout=0,
323
+ channel_mult=(1, 2, 4, 8),
324
+ conv_resample=True,
325
+ dims=2,
326
+ num_classes=None,
327
+ use_checkpoint=False,
328
+ use_fp16=False,
329
+ num_heads=1,
330
+ num_head_channels=-1,
331
+ num_heads_upsample=-1,
332
+ use_scale_shift_norm=False,
333
+ resblock_updown=False,
334
+ encoder_channels=None,
335
+ ):
336
+ super().__init__()
337
+
338
+ if num_heads_upsample == -1:
339
+ num_heads_upsample = num_heads
340
+
341
+ self.in_channels = in_channels
342
+ self.model_channels = model_channels
343
+ self.out_channels = out_channels
344
+ self.num_res_blocks = num_res_blocks
345
+ self.attention_resolutions = attention_resolutions
346
+ self.dropout = dropout
347
+ self.channel_mult = channel_mult
348
+ self.conv_resample = conv_resample
349
+ self.num_classes = num_classes
350
+ self.use_checkpoint = use_checkpoint
351
+ self.dtype = th.float16 if use_fp16 else th.float32
352
+ self.num_heads = num_heads
353
+ self.num_head_channels = num_head_channels
354
+ self.num_heads_upsample = num_heads_upsample
355
+
356
+ time_embed_dim = model_channels * 4
357
+ self.time_embed = nn.Sequential(
358
+ linear(model_channels, time_embed_dim),
359
+ nn.SiLU(),
360
+ linear(time_embed_dim, time_embed_dim),
361
+ )
362
+
363
+ if self.num_classes is not None:
364
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
365
+
366
+ ch = input_ch = int(channel_mult[0] * model_channels)
367
+ self.input_blocks = nn.ModuleList(
368
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
369
+ )
370
+ self._feature_size = ch
371
+ input_block_chans = [ch]
372
+ ds = 1
373
+ for level, mult in enumerate(channel_mult):
374
+ for _ in range(num_res_blocks):
375
+ layers = [
376
+ ResBlock(
377
+ ch,
378
+ time_embed_dim,
379
+ dropout,
380
+ out_channels=int(mult * model_channels),
381
+ dims=dims,
382
+ use_checkpoint=use_checkpoint,
383
+ use_scale_shift_norm=use_scale_shift_norm,
384
+ )
385
+ ]
386
+ ch = int(mult * model_channels)
387
+ if ds in attention_resolutions:
388
+ layers.append(
389
+ AttentionBlock(
390
+ ch,
391
+ use_checkpoint=use_checkpoint,
392
+ num_heads=num_heads,
393
+ num_head_channels=num_head_channels,
394
+ encoder_channels=encoder_channels,
395
+ )
396
+ )
397
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
398
+ self._feature_size += ch
399
+ input_block_chans.append(ch)
400
+ if level != len(channel_mult) - 1:
401
+ out_ch = ch
402
+ self.input_blocks.append(
403
+ TimestepEmbedSequential(
404
+ ResBlock(
405
+ ch,
406
+ time_embed_dim,
407
+ dropout,
408
+ out_channels=out_ch,
409
+ dims=dims,
410
+ use_checkpoint=use_checkpoint,
411
+ use_scale_shift_norm=use_scale_shift_norm,
412
+ down=True,
413
+ )
414
+ if resblock_updown
415
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
416
+ )
417
+ )
418
+ ch = out_ch
419
+ input_block_chans.append(ch)
420
+ ds *= 2
421
+ self._feature_size += ch
422
+
423
+ self.middle_block = TimestepEmbedSequential(
424
+ ResBlock(
425
+ ch,
426
+ time_embed_dim,
427
+ dropout,
428
+ dims=dims,
429
+ use_checkpoint=use_checkpoint,
430
+ use_scale_shift_norm=use_scale_shift_norm,
431
+ ),
432
+ AttentionBlock(
433
+ ch,
434
+ use_checkpoint=use_checkpoint,
435
+ num_heads=num_heads,
436
+ num_head_channels=num_head_channels,
437
+ encoder_channels=encoder_channels,
438
+ ),
439
+ ResBlock(
440
+ ch,
441
+ time_embed_dim,
442
+ dropout,
443
+ dims=dims,
444
+ use_checkpoint=use_checkpoint,
445
+ use_scale_shift_norm=use_scale_shift_norm,
446
+ ),
447
+ )
448
+ self._feature_size += ch
449
+
450
+ self.output_blocks = nn.ModuleList([])
451
+ for level, mult in list(enumerate(channel_mult))[::-1]:
452
+ for i in range(num_res_blocks + 1):
453
+ ich = input_block_chans.pop()
454
+ layers = [
455
+ ResBlock(
456
+ ch + ich,
457
+ time_embed_dim,
458
+ dropout,
459
+ out_channels=int(model_channels * mult),
460
+ dims=dims,
461
+ use_checkpoint=use_checkpoint,
462
+ use_scale_shift_norm=use_scale_shift_norm,
463
+ )
464
+ ]
465
+ ch = int(model_channels * mult)
466
+ if ds in attention_resolutions:
467
+ layers.append(
468
+ AttentionBlock(
469
+ ch,
470
+ use_checkpoint=use_checkpoint,
471
+ num_heads=num_heads_upsample,
472
+ num_head_channels=num_head_channels,
473
+ encoder_channels=encoder_channels,
474
+ )
475
+ )
476
+ if level and i == num_res_blocks:
477
+ out_ch = ch
478
+ layers.append(
479
+ ResBlock(
480
+ ch,
481
+ time_embed_dim,
482
+ dropout,
483
+ out_channels=out_ch,
484
+ dims=dims,
485
+ use_checkpoint=use_checkpoint,
486
+ use_scale_shift_norm=use_scale_shift_norm,
487
+ up=True,
488
+ )
489
+ if resblock_updown
490
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
491
+ )
492
+ ds //= 2
493
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
494
+ self._feature_size += ch
495
+
496
+ self.out = nn.Sequential(
497
+ normalization(ch, swish=1.0),
498
+ nn.Identity(),
499
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
500
+ )
501
+ self.use_fp16 = use_fp16
502
+
503
+ def convert_to_fp16(self):
504
+ """
505
+ Convert the torso of the model to float16.
506
+ """
507
+ self.input_blocks.apply(convert_module_to_f16)
508
+ self.middle_block.apply(convert_module_to_f16)
509
+ self.output_blocks.apply(convert_module_to_f16)
510
+
511
+ def convert_to_fp32(self):
512
+ """
513
+ Convert the torso of the model to float32.
514
+ """
515
+ self.input_blocks.apply(convert_module_to_f32)
516
+ self.middle_block.apply(convert_module_to_f32)
517
+ self.output_blocks.apply(convert_module_to_f32)
518
+
519
+ def forward(self, x, timesteps, y=None):
520
+ """
521
+ Apply the model to an input batch.
522
+
523
+ :param x: an [N x C x ...] Tensor of inputs.
524
+ :param timesteps: a 1-D batch of timesteps.
525
+ :param y: an [N] Tensor of labels, if class-conditional.
526
+ :return: an [N x C x ...] Tensor of outputs.
527
+ """
528
+ assert (y is not None) == (
529
+ self.num_classes is not None
530
+ ), "must specify y if and only if the model is class-conditional"
531
+
532
+ hs = []
533
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
534
+
535
+ if self.num_classes is not None:
536
+ assert y.shape == (x.shape[0],)
537
+ emb = emb + self.label_emb(y)
538
+
539
+ h = x.type(self.dtype)
540
+ for module in self.input_blocks:
541
+ h = module(h, emb)
542
+ hs.append(h)
543
+ h = self.middle_block(h, emb)
544
+ for module in self.output_blocks:
545
+ h = th.cat([h, hs.pop()], dim=1)
546
+ h = module(h, emb)
547
+ h = h.type(x.dtype)
548
+ return self.out(h)
549
+
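A minimal construction-and-forward sketch for the class-conditional case; the hyperparameters below are illustrative only, not the released GLIDE settings, and the module path glide_text2im.unet is assumed:

import torch as th
from glide_text2im.unet import UNetModel  # assumed module path

model = UNetModel(
    in_channels=3,
    model_channels=64,
    out_channels=6,               # e.g. predicted noise plus variance channels
    num_res_blocks=2,
    attention_resolutions=(4,),   # attend once the signal is 4x downsampled
    channel_mult=(1, 2, 4),
    num_classes=10,               # makes the model class-conditional
)
x = th.randn(2, 3, 32, 32)
t = th.randint(0, 1000, (2,))     # diffusion timesteps
y = th.randint(0, 10, (2,))       # labels are required iff num_classes is set
out = model(x, t, y=y)
print(out.shape)                  # torch.Size([2, 6, 32, 32])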
550
+ class SuperResUNetModel(UNetModel):
551
+ """
552
+ A UNetModel that performs super-resolution.
553
+
554
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
555
+ """
556
+
557
+ def __init__(self, *args, **kwargs):
558
+ if "in_channels" in kwargs:
559
+ kwargs = dict(kwargs)
560
+ kwargs["in_channels"] = kwargs["in_channels"] * 2
561
+ else:
562
+ # Curse you, Python. Or really, just curse positional arguments :|.
563
+ args = list(args)
564
+ args[1] = args[1] * 2
565
+ super().__init__(*args, **kwargs)
566
+
567
+ def forward(self, x, timesteps, low_res=None, **kwargs):
568
+ _, _, new_height, new_width = x.shape
569
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
570
+ x = th.cat([x, upsampled], dim=1)
571
+ return super().forward(x, timesteps, **kwargs)
572
+
573
+
574
+ class InpaintUNetModel(UNetModel):
575
+ """
576
+ A UNetModel which can perform inpainting.
577
+ """
578
+
579
+ def __init__(self, *args, **kwargs):
580
+ if "in_channels" in kwargs:
581
+ kwargs = dict(kwargs)
582
+ kwargs["in_channels"] = kwargs["in_channels"] * 2 + 1
583
+ else:
584
+ # Curse you, Python. Or really, just curse positional arguments :|.
585
+ args = list(args)
586
+ args[1] = args[1] * 2 + 1
587
+ super().__init__(*args, **kwargs)
588
+
589
+ def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs):
590
+ if inpaint_image is None:
591
+ inpaint_image = th.zeros_like(x)
592
+ if inpaint_mask is None:
593
+ inpaint_mask = th.zeros_like(x[:, :1])
594
+ return super().forward(
595
+ th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1),
596
+ timesteps,
597
+ **kwargs,
598
+ )
599
+
600
+
601
+ class SuperResInpaintUNetModel(UNetModel):
602
+ """
603
+ A UNetModel which can perform both upsampling and inpainting.
604
+ """
605
+
606
+ def __init__(self, *args, **kwargs):
607
+ if "in_channels" in kwargs:
608
+ kwargs = dict(kwargs)
609
+ kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1
610
+ else:
611
+ # Curse you, Python. Or really, just curse positional arguments :|.
612
+ args = list(args)
613
+ args[1] = args[1] * 3 + 1
614
+ super().__init__(*args, **kwargs)
615
+
616
+ def forward(
617
+ self,
618
+ x,
619
+ timesteps,
620
+ inpaint_image=None,
621
+ inpaint_mask=None,
622
+ low_res=None,
623
+ **kwargs,
624
+ ):
625
+ if inpaint_image is None:
626
+ inpaint_image = th.zeros_like(x)
627
+ if inpaint_mask is None:
628
+ inpaint_mask = th.zeros_like(x[:, :1])
629
+ _, _, new_height, new_width = x.shape
630
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
631
+ return super().forward(
632
+ th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1),
633
+ timesteps,
634
+ **kwargs,
635
+ )
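All three subclasses condition the base UNet the same way: the extra inputs (an upsampled low-resolution image, a masked reference image, and/or the mask itself) are concatenated to x along the channel axis, and __init__ widens in_channels to match. A hedged sketch for InpaintUNetModel with illustrative sizes and the same assumed module path:

import torch as th
from glide_text2im.unet import InpaintUNetModel  # assumed module path

model = InpaintUNetModel(
    in_channels=3,               # widened internally to 3 * 2 + 1 = 7
    model_channels=64,
    out_channels=3,
    num_res_blocks=1,
    attention_resolutions=(4,),
    channel_mult=(1, 2, 4),
)
x = th.randn(1, 3, 64, 64)                # noisy image being denoised
reference = th.randn(1, 3, 64, 64)        # known pixels to keep
keep = th.zeros(1, 1, 64, 64)
keep[..., :32] = 1.0                      # keep the left half, inpaint the right half
t = th.randint(0, 1000, (1,))
out = model(x, t, inpaint_image=reference, inpaint_mask=keep)
print(out.shape)                          # torch.Size([1, 3, 64, 64])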
glide_text2im/xf.py ADDED
@@ -0,0 +1,130 @@
1
+ """
2
+ Transformer implementation adapted from CLIP ViT:
3
+ https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py
4
+ """
5
+
6
+ import math
7
+
8
+ import torch as th
9
+ import torch.nn as nn
10
+
11
+
12
+ def convert_module_to_f16(l):
13
+ """
14
+ Convert primitive modules to float16.
15
+ """
16
+ if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
17
+ l.weight.data = l.weight.data.half()
18
+ if l.bias is not None:
19
+ l.bias.data = l.bias.data.half()
20
+
21
+
22
+ class LayerNorm(nn.LayerNorm):
23
+ """
24
+ Implementation that supports fp16 inputs but fp32 gains/biases.
25
+ """
26
+
27
+ def forward(self, x: th.Tensor):
28
+ return super().forward(x.float()).to(x.dtype)
29
+
30
+
31
+ class MultiheadAttention(nn.Module):
32
+ def __init__(self, n_ctx, width, heads):
33
+ super().__init__()
34
+ self.n_ctx = n_ctx
35
+ self.width = width
36
+ self.heads = heads
37
+ self.c_qkv = nn.Linear(width, width * 3)
38
+ self.c_proj = nn.Linear(width, width)
39
+ self.attention = QKVMultiheadAttention(heads, n_ctx)
40
+
41
+ def forward(self, x):
42
+ x = self.c_qkv(x)
43
+ x = self.attention(x)
44
+ x = self.c_proj(x)
45
+ return x
46
+
47
+
48
+ class MLP(nn.Module):
49
+ def __init__(self, width):
50
+ super().__init__()
51
+ self.width = width
52
+ self.c_fc = nn.Linear(width, width * 4)
53
+ self.c_proj = nn.Linear(width * 4, width)
54
+ self.gelu = nn.GELU()
55
+
56
+ def forward(self, x):
57
+ return self.c_proj(self.gelu(self.c_fc(x)))
58
+
59
+
60
+ class QKVMultiheadAttention(nn.Module):
61
+ def __init__(self, n_heads: int, n_ctx: int):
62
+ super().__init__()
63
+ self.n_heads = n_heads
64
+ self.n_ctx = n_ctx
65
+
66
+ def forward(self, qkv):
67
+ bs, n_ctx, width = qkv.shape
68
+ attn_ch = width // self.n_heads // 3
69
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
70
+ qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
71
+ q, k, v = th.split(qkv, attn_ch, dim=-1)
72
+ weight = th.einsum(
73
+ "bthc,bshc->bhts", q * scale, k * scale
74
+ ) # More stable with f16 than dividing afterwards
75
+ wdtype = weight.dtype
76
+ weight = th.softmax(weight.float(), dim=-1).type(wdtype)
77
+ return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
78
+
79
+
80
+ class ResidualAttentionBlock(nn.Module):
81
+ def __init__(
82
+ self,
83
+ n_ctx: int,
84
+ width: int,
85
+ heads: int,
86
+ ):
87
+ super().__init__()
88
+
89
+ self.attn = MultiheadAttention(
90
+ n_ctx,
91
+ width,
92
+ heads,
93
+ )
94
+ self.ln_1 = LayerNorm(width)
95
+ self.mlp = MLP(width)
96
+ self.ln_2 = LayerNorm(width)
97
+
98
+ def forward(self, x: th.Tensor):
99
+ x = x + self.attn(self.ln_1(x))
100
+ x = x + self.mlp(self.ln_2(x))
101
+ return x
102
+
103
+
104
+ class Transformer(nn.Module):
105
+ def __init__(
106
+ self,
107
+ n_ctx: int,
108
+ width: int,
109
+ layers: int,
110
+ heads: int,
111
+ ):
112
+ super().__init__()
113
+ self.n_ctx = n_ctx
114
+ self.width = width
115
+ self.layers = layers
116
+ self.resblocks = nn.ModuleList(
117
+ [
118
+ ResidualAttentionBlock(
119
+ n_ctx,
120
+ width,
121
+ heads,
122
+ )
123
+ for _ in range(layers)
124
+ ]
125
+ )
126
+
127
+ def forward(self, x: th.Tensor):
128
+ for block in self.resblocks:
129
+ x = block(x)
130
+ return x
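A short usage sketch for the transformer defined above, with illustrative sizes (importable as glide_text2im.xf per the file header):

import torch as th
from glide_text2im.xf import Transformer

xf = Transformer(n_ctx=128, width=512, layers=4, heads=8)
tokens = th.randn(2, 128, 512)   # [N x n_ctx x width] token embeddings
out = xf(tokens)                 # each block: pre-LN attention then pre-LN MLP, both residual
print(out.shape)                 # torch.Size([2, 128, 512])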