ElixirRod and JamesXu committed
Commit 3ac711b · 0 Parent(s)

Duplicate from shi-labs/Versatile-Diffusion


Co-authored-by: Xingqian Xu <JamesXu@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +7 -0
  3. README.md +15 -0
  4. app.py +845 -0
  5. assets/benz.jpg +3 -0
  6. assets/boy_and_girl.jpg +3 -0
  7. assets/church.jpg +3 -0
  8. assets/figures/share_instruction.png +0 -0
  9. assets/firework.jpg +3 -0
  10. assets/ghibli.jpg +3 -0
  11. assets/horse.png +3 -0
  12. assets/house_by_lake.jpg +3 -0
  13. assets/matisse.jpg +3 -0
  14. assets/night_light.jpg +3 -0
  15. assets/penguin.png +3 -0
  16. assets/san_diego.jpg +3 -0
  17. assets/scream.jpg +3 -0
  18. assets/space.jpg +3 -0
  19. assets/tiger.jpg +3 -0
  20. assets/train.jpg +3 -0
  21. assets/vermeer.jpg +3 -0
  22. configs/model/clip.yaml +50 -0
  23. configs/model/openai_unet.yaml +72 -0
  24. configs/model/optimus.yaml +103 -0
  25. configs/model/sd.yaml +69 -0
  26. configs/model/vd.yaml +61 -0
  27. lib/__init__.py +0 -0
  28. lib/cfg_helper.py +664 -0
  29. lib/cfg_holder.py +28 -0
  30. lib/log_service.py +166 -0
  31. lib/model_zoo/__init__.py +4 -0
  32. lib/model_zoo/attention.py +435 -0
  33. lib/model_zoo/autoencoder.py +428 -0
  34. lib/model_zoo/bert.py +142 -0
  35. lib/model_zoo/clip.py +178 -0
  36. lib/model_zoo/clip_justin/__init__.py +1 -0
  37. lib/model_zoo/clip_justin/clip.py +237 -0
  38. lib/model_zoo/clip_justin/model.py +436 -0
  39. lib/model_zoo/clip_justin/simple_tokenizer.py +132 -0
  40. lib/model_zoo/common/get_model.py +128 -0
  41. lib/model_zoo/common/get_optimizer.py +47 -0
  42. lib/model_zoo/common/get_scheduler.py +262 -0
  43. lib/model_zoo/common/utils.py +292 -0
  44. lib/model_zoo/ddim.py +216 -0
  45. lib/model_zoo/ddim_dualcontext.py +144 -0
  46. lib/model_zoo/ddim_dualmodel.py +244 -0
  47. lib/model_zoo/ddim_vd.py +290 -0
  48. lib/model_zoo/diffusion_modules.py +835 -0
  49. lib/model_zoo/diffusion_utils.py +250 -0
  50. lib/model_zoo/distributions.py +92 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__
+ .vscode/
+ src/
+ data/
+ data
+ log/
+ log
README.md ADDED
@@ -0,0 +1,15 @@
+ ---
+ title: Versatile Diffusion
+ emoji: null
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.9.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ python_version: 3.8.5
+ duplicated_from: shi-labs/Versatile-Diffusion
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,845 @@
1
+ import gradio as gr
2
+ import os
3
+ import PIL
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ import numpy as np
7
+ import numpy.random as npr
8
+ from contextlib import nullcontext
9
+
10
+ import torch
11
+ import torchvision.transforms as tvtrans
12
+ from lib.cfg_helper import model_cfg_bank
13
+ from lib.model_zoo import get_model
14
+
15
+ n_sample_image_default = 2
16
+ n_sample_text_default = 4
17
+ cache_examples = True
18
+ hfm_repo_id = 'shi-labs/versatile-diffusion-model'
19
+ hfm_filename = 'pretrained_pth/vd-four-flow-v1-0-fp16.pth'
20
+
21
+ def highlight_print(info):
22
+ print('')
23
+ print(''.join(['#']*(len(info)+4)))
24
+ print('# '+info+' #')
25
+ print(''.join(['#']*(len(info)+4)))
26
+ print('')
27
+
28
+ class color_adjust(object):
29
+ def __init__(self, ref_from, ref_to):
30
+ x0, m0, std0 = self.get_data_and_stat(ref_from)
31
+ x1, m1, std1 = self.get_data_and_stat(ref_to)
32
+ self.ref_from_stat = (m0, std0)
33
+ self.ref_to_stat = (m1, std1)
34
+ self.ref_from = self.preprocess(x0).reshape(-1, 3)
35
+ self.ref_to = x1.reshape(-1, 3)
36
+
37
+ def get_data_and_stat(self, x):
38
+ if isinstance(x, str):
39
+ x = np.array(PIL.Image.open(x))
40
+ elif isinstance(x, PIL.Image.Image):
41
+ x = np.array(x)
42
+ elif isinstance(x, torch.Tensor):
43
+ x = torch.clamp(x, min=0.0, max=1.0)
44
+ x = np.array(tvtrans.ToPILImage()(x))
45
+ elif isinstance(x, np.ndarray):
46
+ pass
47
+ else:
48
+ raise ValueError
49
+ x = x.astype(float)
50
+ m = np.reshape(x, (-1, 3)).mean(0)
51
+ s = np.reshape(x, (-1, 3)).std(0)
52
+ return x, m, s
53
+
54
+ def preprocess(self, x):
55
+ m0, s0 = self.ref_from_stat
56
+ m1, s1 = self.ref_to_stat
57
+ y = ((x-m0)/s0)*s1 + m1
58
+ return y
59
+
60
+ def __call__(self, xin, keep=0, simple=False):
61
+ xin, _, _ = self.get_data_and_stat(xin)
62
+ x = self.preprocess(xin)
63
+ if simple:
64
+ y = (x*(1-keep) + xin*keep)
65
+ y = np.clip(y, 0, 255).astype(np.uint8)
66
+ return y
67
+
68
+ h, w = x.shape[:2]
69
+ x = x.reshape(-1, 3)
70
+ y = []
71
+ for chi in range(3):
72
+ yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
73
+ y.append(yi)
74
+
75
+ y = np.stack(y, axis=1)
76
+ y = y.reshape(h, w, 3)
77
+ y = (y.astype(float)*(1-keep) + xin.astype(float)*keep)
78
+ y = np.clip(y, 0, 255).astype(np.uint8)
79
+ return y
80
+
81
+ def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
82
+ arr = np.concatenate((arr_fo, arr_to))
83
+ min_v = arr.min() - 1e-6
84
+ max_v = arr.max() + 1e-6
85
+ min_vto = arr_to.min() - 1e-6
86
+ max_vto = arr_to.max() + 1e-6
87
+ xs = np.array(
88
+ [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
89
+ hist_fo, _ = np.histogram(arr_fo, xs)
90
+ hist_to, _ = np.histogram(arr_to, xs)
91
+ xs = xs[:-1]
92
+ # compute probability distribution
93
+ cum_fo = np.cumsum(hist_fo)
94
+ cum_to = np.cumsum(hist_to)
95
+ d_fo = cum_fo / cum_fo[-1]
96
+ d_to = cum_to / cum_to[-1]
97
+ # transfer
98
+ t_d = np.interp(d_fo, d_to, xs)
99
+ t_d[d_fo <= d_to[ 0]] = min_vto
100
+ t_d[d_fo >= d_to[-1]] = max_vto
101
+ arr_out = np.interp(arr_in, xs, t_d)
102
+ return arr_out
103
+
104
+ class vd_inference(object):
105
+ def __init__(self, pth=None, hfm_repo=None, fp16=False, device=0):
106
+ cfgm_name = 'vd_noema'
107
+ cfgm = model_cfg_bank()('vd_noema')
108
+ net = get_model()(cfgm)
109
+ if fp16:
110
+ highlight_print('Running in FP16')
111
+ net.clip.fp16 = True
112
+ net = net.half()
113
+ if pth is not None:
114
+ sd = torch.load(pth, map_location='cpu')
115
+ print('Load pretrained weight from {}'.format(pth))
116
+ else:
117
+ from huggingface_hub import hf_hub_download
118
+ temppath = hf_hub_download(hfm_repo[0], hfm_repo[1])
119
+ sd = torch.load(temppath, map_location='cpu')
120
+ print('Load pretrained weight from {}/{}'.format(*hfm_repo))
121
+
122
+ net.load_state_dict(sd, strict=False)
123
+ net.to(device)
124
+
125
+ self.device = device
126
+ self.model_name = cfgm_name
127
+ self.net = net
128
+ self.fp16 = fp16
129
+ from lib.model_zoo.ddim_vd import DDIMSampler_VD
130
+ self.sampler = DDIMSampler_VD(net)
131
+
132
+ def regularize_image(self, x):
133
+ BICUBIC = PIL.Image.Resampling.BICUBIC
134
+ if isinstance(x, str):
135
+ x = Image.open(x).resize([512, 512], resample=BICUBIC)
136
+ x = tvtrans.ToTensor()(x)
137
+ elif isinstance(x, PIL.Image.Image):
138
+ x = x.resize([512, 512], resample=BICUBIC)
139
+ x = tvtrans.ToTensor()(x)
140
+ elif isinstance(x, np.ndarray):
141
+ x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
142
+ x = tvtrans.ToTensor()(x)
143
+ elif isinstance(x, torch.Tensor):
144
+ pass
145
+ else:
146
+ assert False, 'Unknown image type'
147
+
148
+ assert (x.shape[1]==512) & (x.shape[2]==512), \
149
+ 'Wrong image size'
150
+ x = x.to(self.device)
151
+ if self.fp16:
152
+ x = x.half()
153
+ return x
154
+
155
+ def decode(self, z, xtype, ctype, color_adj='None', color_adj_to=None):
156
+ net = self.net
157
+ if xtype == 'image':
158
+ x = net.autokl_decode(z)
159
+
160
+ color_adj_flag = (color_adj!='none') and (color_adj!='None') and (color_adj is not None)
161
+ color_adj_simple = (color_adj=='Simple') or color_adj=='simple'
162
+ color_adj_keep_ratio = 0.5
163
+
164
+ if color_adj_flag and (ctype=='vision'):
165
+ x_adj = []
166
+ for xi in x:
167
+ color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
168
+ xi_adj = color_adj_f((xi+1)/2, keep=color_adj_keep_ratio, simple=color_adj_simple)
169
+ x_adj.append(xi_adj)
170
+ x = x_adj
171
+ else:
172
+ x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
173
+ x = [tvtrans.ToPILImage()(xi) for xi in x]
174
+ return x
175
+
176
+ elif xtype == 'text':
177
+ prompt_temperature = 1.0
178
+ prompt_merge_same_adj_word = True
179
+ x = net.optimus_decode(z, temperature=prompt_temperature)
180
+ if prompt_merge_same_adj_word:
181
+ xnew = []
182
+ for xi in x:
183
+ xi_split = xi.split()
184
+ xinew = []
185
+ for idxi, wi in enumerate(xi_split):
186
+ if idxi!=0 and wi==xi_split[idxi-1]:
187
+ continue
188
+ xinew.append(wi)
189
+ xnew.append(' '.join(xinew))
190
+ x = xnew
191
+ return x
192
+
193
+ def inference(self, xtype, cin, ctype, scale=7.5, n_samples=None, color_adj=None,):
194
+ net = self.net
195
+ sampler = self.sampler
196
+ ddim_steps = 50
197
+ ddim_eta = 0.0
198
+
199
+ if xtype == 'image':
200
+ n_samples = n_sample_image_default if n_samples is None else n_samples
201
+ elif xtype == 'text':
202
+ n_samples = n_sample_text_default if n_samples is None else n_samples
203
+
204
+ if ctype in ['prompt', 'text']:
205
+ c = net.clip_encode_text(n_samples * [cin])
206
+ u = None
207
+ if scale != 1.0:
208
+ u = net.clip_encode_text(n_samples * [""])
209
+
210
+ elif ctype in ['vision', 'image']:
211
+ cin = self.regularize_image(cin)
212
+ ctemp = cin*2 - 1
213
+ ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
214
+ c = net.clip_encode_vision(ctemp)
215
+ u = None
216
+ if scale != 1.0:
217
+ dummy = torch.zeros_like(ctemp)
218
+ u = net.clip_encode_vision(dummy)
219
+
220
+ u, c = [u.half(), c.half()] if self.fp16 else [u, c]
221
+
222
+ if xtype == 'image':
223
+ h, w = [512, 512]
224
+ shape = [n_samples, 4, h//8, w//8]
225
+ z, _ = sampler.sample(
226
+ steps=ddim_steps,
227
+ shape=shape,
228
+ conditioning=c,
229
+ unconditional_guidance_scale=scale,
230
+ unconditional_conditioning=u,
231
+ xtype=xtype, ctype=ctype,
232
+ eta=ddim_eta,
233
+ verbose=False,)
234
+ x = self.decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
235
+ return x
236
+
237
+ elif xtype == 'text':
238
+ n = 768
239
+ shape = [n_samples, n]
240
+ z, _ = sampler.sample(
241
+ steps=ddim_steps,
242
+ shape=shape,
243
+ conditioning=c,
244
+ unconditional_guidance_scale=scale,
245
+ unconditional_conditioning=u,
246
+ xtype=xtype, ctype=ctype,
247
+ eta=ddim_eta,
248
+ verbose=False,)
249
+ x = self.decode(z, xtype, ctype)
250
+ return x
251
+
252
+ def application_disensemble(self, cin, n_samples=None, level=0, color_adj=None,):
253
+ net = self.net
254
+ scale = 7.5
255
+ sampler = self.sampler
256
+ ddim_steps = 50
257
+ ddim_eta = 0.0
258
+ n_samples = n_sample_image_default if n_samples is None else n_samples
259
+
260
+ cin = self.regularize_image(cin)
261
+ ctemp = cin*2 - 1
262
+ ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
263
+ c = net.clip_encode_vision(ctemp)
264
+ u = None
265
+ if scale != 1.0:
266
+ dummy = torch.zeros_like(ctemp)
267
+ u = net.clip_encode_vision(dummy)
268
+ u, c = [u.half(), c.half()] if self.fp16 else [u, c]
269
+
270
+ if level == 0:
271
+ pass
272
+ else:
273
+ c_glb = c[:, 0:1]
274
+ c_loc = c[:, 1: ]
275
+ u_glb = u[:, 0:1]
276
+ u_loc = u[:, 1: ]
277
+
278
+ if level == -1:
279
+ c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=1)
280
+ u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=1)
281
+ if level == -2:
282
+ c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=2)
283
+ u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=2)
284
+ if level == 1:
285
+ c_loc = self.find_low_rank(c_loc, demean=True, q=10)
286
+ u_loc = self.find_low_rank(u_loc, demean=True, q=10)
287
+ if level == 2:
288
+ c_loc = self.find_low_rank(c_loc, demean=True, q=2)
289
+ u_loc = self.find_low_rank(u_loc, demean=True, q=2)
290
+
291
+ c = torch.cat([c_glb, c_loc], dim=1)
292
+ u = torch.cat([u_glb, u_loc], dim=1)
293
+
294
+ h, w = [512, 512]
295
+ shape = [n_samples, 4, h//8, w//8]
296
+ z, _ = sampler.sample(
297
+ steps=ddim_steps,
298
+ shape=shape,
299
+ conditioning=c,
300
+ unconditional_guidance_scale=scale,
301
+ unconditional_conditioning=u,
302
+ xtype='image', ctype='vision',
303
+ eta=ddim_eta,
304
+ verbose=False,)
305
+ x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)
306
+ return x
307
+
308
+ def find_low_rank(self, x, demean=True, q=20, niter=10):
309
+ if demean:
310
+ x_mean = x.mean(-1, keepdim=True)
311
+ x_input = x - x_mean
312
+ else:
313
+ x_input = x
314
+
315
+ if x_input.dtype == torch.float16:
316
+ fp16 = True
317
+ x_input = x_input.float()
318
+ else:
319
+ fp16 = False
320
+
321
+ u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
322
+ ss = torch.stack([torch.diag(si) for si in s])
323
+ x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))
324
+
325
+ if fp16:
326
+ x_lowrank = x_lowrank.half()
327
+
328
+ if demean:
329
+ x_lowrank += x_mean
330
+ return x_lowrank
331
+
332
+ def remove_low_rank(self, x, demean=True, q=20, niter=10, q_remove=10):
333
+ if demean:
334
+ x_mean = x.mean(-1, keepdim=True)
335
+ x_input = x - x_mean
336
+ else:
337
+ x_input = x
338
+
339
+ if x_input.dtype == torch.float16:
340
+ fp16 = True
341
+ x_input = x_input.float()
342
+ else:
343
+ fp16 = False
344
+
345
+ u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
346
+ s[:, 0:q_remove] = 0
347
+ ss = torch.stack([torch.diag(si) for si in s])
348
+ x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))
349
+
350
+ if fp16:
351
+ x_lowrank = x_lowrank.half()
352
+
353
+ if demean:
354
+ x_lowrank += x_mean
355
+ return x_lowrank
356
+
357
+ def application_dualguided(self, cim, ctx, n_samples=None, mixing=0.5, color_adj=None, ):
358
+ net = self.net
359
+ scale = 7.5
360
+ sampler = self.sampler
361
+ ddim_steps = 50
362
+ ddim_eta = 0.0
363
+ n_samples = n_sample_image_default if n_samples is None else n_samples
364
+
365
+ ctemp0 = self.regularize_image(cim)
366
+ ctemp1 = ctemp0*2 - 1
367
+ ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
368
+ cim = net.clip_encode_vision(ctemp1)
369
+ uim = None
370
+ if scale != 1.0:
371
+ dummy = torch.zeros_like(ctemp1)
372
+ uim = net.clip_encode_vision(dummy)
373
+
374
+ ctx = net.clip_encode_text(n_samples * [ctx])
375
+ utx = None
376
+ if scale != 1.0:
377
+ utx = net.clip_encode_text(n_samples * [""])
378
+
379
+ uim, cim = [uim.half(), cim.half()] if self.fp16 else [uim, cim]
380
+ utx, ctx = [utx.half(), ctx.half()] if self.fp16 else [utx, ctx]
381
+
382
+ h, w = [512, 512]
383
+ shape = [n_samples, 4, h//8, w//8]
384
+
385
+ z, _ = sampler.sample_dc(
386
+ steps=ddim_steps,
387
+ shape=shape,
388
+ first_conditioning=[uim, cim],
389
+ second_conditioning=[utx, ctx],
390
+ unconditional_guidance_scale=scale,
391
+ xtype='image',
392
+ first_ctype='vision',
393
+ second_ctype='prompt',
394
+ eta=ddim_eta,
395
+ verbose=False,
396
+ mixed_ratio=(1-mixing), )
397
+ x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
398
+ return x
399
+
400
+ def application_i2t2i(self, cim, ctx_n, ctx_p, n_samples=None, color_adj=None,):
401
+ net = self.net
402
+ scale = 7.5
403
+ sampler = self.sampler
404
+ ddim_steps = 50
405
+ ddim_eta = 0.0
406
+ prompt_temperature = 1.0
407
+ n_samples = n_sample_image_default if n_samples is None else n_samples
408
+
409
+ ctemp0 = self.regularize_image(cim)
410
+ ctemp1 = ctemp0*2 - 1
411
+ ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
412
+ cim = net.clip_encode_vision(ctemp1)
413
+ uim = None
414
+ if scale != 1.0:
415
+ dummy = torch.zeros_like(ctemp1)
416
+ uim = net.clip_encode_vision(dummy)
417
+
418
+ uim, cim = [uim.half(), cim.half()] if self.fp16 else [uim, cim]
419
+
420
+ n = 768
421
+ shape = [n_samples, n]
422
+ zt, _ = sampler.sample(
423
+ steps=ddim_steps,
424
+ shape=shape,
425
+ conditioning=cim,
426
+ unconditional_guidance_scale=scale,
427
+ unconditional_conditioning=uim,
428
+ xtype='text', ctype='vision',
429
+ eta=ddim_eta,
430
+ verbose=False,)
431
+ ztn = net.optimus_encode([ctx_n])
432
+ ztp = net.optimus_encode([ctx_p])
433
+
434
+ ztn_norm = ztn / ztn.norm(dim=1)
435
+ zt_proj_mag = torch.matmul(zt, ztn_norm[0])
436
+ zt_perp = zt - zt_proj_mag[:, None] * ztn_norm
437
+ zt_newd = zt_perp + ztp
438
+ ctx_new = net.optimus_decode(zt_newd, temperature=prompt_temperature)
439
+
440
+ ctx_new = net.clip_encode_text(ctx_new)
441
+ ctx_p = net.clip_encode_text([ctx_p])
442
+ ctx_new = torch.cat([ctx_new, ctx_p.repeat(n_samples, 1, 1)], dim=1)
443
+ utx_new = net.clip_encode_text(n_samples * [""])
444
+ utx_new = torch.cat([utx_new, utx_new], dim=1)
445
+
446
+ cim_loc = cim[:, 1: ]
447
+ cim_loc_new = self.find_low_rank(cim_loc, demean=True, q=10)
448
+ cim_new = cim_loc_new
449
+ uim_new = uim[:, 1:]
450
+
451
+ h, w = [512, 512]
452
+ shape = [n_samples, 4, h//8, w//8]
453
+ z, _ = sampler.sample_dc(
454
+ steps=ddim_steps,
455
+ shape=shape,
456
+ first_conditioning=[uim_new, cim_new],
457
+ second_conditioning=[utx_new, ctx_new],
458
+ unconditional_guidance_scale=scale,
459
+ xtype='image',
460
+ first_ctype='vision',
461
+ second_ctype='prompt',
462
+ eta=ddim_eta,
463
+ verbose=False,
464
+ mixed_ratio=0.33, )
465
+
466
+ x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
467
+ return x
468
+
469
+ vd_inference = vd_inference(hfm_repo=[hfm_repo_id, hfm_filename], fp16=True, device='cuda')
470
+
471
+ def main(mode,
472
+ image=None,
473
+ prompt=None,
474
+ nprompt=None,
475
+ pprompt=None,
476
+ color_adj=None,
477
+ disentanglement_level=None,
478
+ dual_guided_mixing=None,
479
+ seed=0,):
480
+
481
+ if seed<0:
482
+ seed = 0
483
+ np.random.seed(seed)
484
+ torch.manual_seed(seed+100)
485
+
486
+ if mode == 'Text-to-Image':
487
+ if (prompt is None) or (prompt == ""):
488
+ return None, None
489
+ with torch.no_grad():
490
+ rv = vd_inference.inference(
491
+ xtype = 'image',
492
+ cin = prompt,
493
+ ctype = 'prompt', )
494
+ return rv, None
495
+ elif mode == 'Image-Variation':
496
+ if image is None:
497
+ return None, None
498
+ with torch.no_grad():
499
+ rv = vd_inference.inference(
500
+ xtype = 'image',
501
+ cin = image,
502
+ ctype = 'vision',
503
+ color_adj = color_adj,)
504
+ return rv, None
505
+ elif mode == 'Image-to-Text':
506
+ if image is None:
507
+ return None, None
508
+ with torch.no_grad():
509
+ rv = vd_inference.inference(
510
+ xtype = 'text',
511
+ cin = image,
512
+ ctype = 'vision',)
513
+ return None, '\n'.join(rv)
514
+ elif mode == 'Text-Variation':
515
+ if prompt is None:
516
+ return None, None
517
+ with torch.no_grad():
518
+ rv = vd_inference.inference(
519
+ xtype = 'text',
520
+ cin = prompt,
521
+ ctype = 'prompt',)
522
+ return None, '\n'.join(rv)
523
+ elif mode == 'Disentanglement':
524
+ if image is None:
525
+ return None, None
526
+ with torch.no_grad():
527
+ rv = vd_inference.application_disensemble(
528
+ cin = image,
529
+ level = disentanglement_level,
530
+ color_adj = color_adj,)
531
+ return rv, None
532
+ elif mode == 'Dual-Guided':
533
+ if (image is None) or (prompt is None) or (prompt==""):
534
+ return None, None
535
+ with torch.no_grad():
536
+ rv = vd_inference.application_dualguided(
537
+ cim = image,
538
+ ctx = prompt,
539
+ mixing = dual_guided_mixing,
540
+ color_adj = color_adj,)
541
+ return rv, None
542
+ elif mode == 'Latent-I2T2I':
543
+ if (image is None) or (nprompt is None) or (nprompt=="") \
544
+ or (pprompt is None) or (pprompt==""):
545
+ return None, None
546
+ with torch.no_grad():
547
+ rv = vd_inference.application_i2t2i(
548
+ cim = image,
549
+ ctx_n = nprompt,
550
+ ctx_p = pprompt,
551
+ color_adj = color_adj,)
552
+ return rv, None
553
+ else:
554
+ assert False, "No such mode!"
555
+
556
+ def get_instruction(mode):
557
+ t2i_instruction = ["Generate image from text prompt."]
558
+ i2i_instruction = [
559
+ "Generate image conditioned on reference image.",
560
+ "Color Calibration provide an opinion to adjust image color according to reference image.", ]
561
+ i2t_instruction = ["Generate text from reference image."]
562
+ t2t_instruction = ["Generate text from reference text prompt. (Model insufficiently trained, thus results are still experimental)"]
563
+ dis_instruction = [
564
+ "Generate a variation of reference image that disentangled for semantic or style.",
565
+ "Color Calibration provide an opinion to adjust image color according to reference image.",
566
+ "Disentanglement level controls the level of focus towards semantic (-2, -1) or style (1 2). Level 0 serves as Image-Variation.", ]
567
+ dug_instruction = [
568
+ "Generate image from dual guidance of reference image and text prompt.",
569
+ "Color Calibration provide an opinion to adjust image color according to reference image.",
570
+ "Guidance Mixing provides linear balances between image and text context. (0 towards image, 1 towards text)", ]
571
+ iti_instruction = [
572
+ "Generate image variations via image-to-text, text-latent-editing, and then text-to-image. (Still under exploration)",
573
+ "Color Calibration provide an opinion to adjust image color according to reference image.",
574
+ "Input prompt that will be substract from text/text latent code.",
575
+ "Input prompt that will be added to text/text latent code.", ]
576
+
577
+ if mode == "Text-to-Image":
578
+ return '\n'.join(t2i_instruction)
579
+ elif mode == "Image-Variation":
580
+ return '\n'.join(i2i_instruction)
581
+ elif mode == "Image-to-Text":
582
+ return '\n'.join(i2t_instruction)
583
+ elif mode == "Text-Variation":
584
+ return '\n'.join(t2t_instruction)
585
+ elif mode == "Disentanglement":
586
+ return '\n'.join(dis_instruction)
587
+ elif mode == "Dual-Guided":
588
+ return '\n'.join(dug_instruction)
589
+ elif mode == "Latent-I2T2I":
590
+ return '\n'.join(iti_instruction)
591
+
592
+ #############
593
+ # Interface #
594
+ #############
595
+
596
+ if True:
597
+ img_output = gr.Gallery(label="Image Result").style(grid=n_sample_image_default)
598
+ txt_output = gr.Textbox(lines=4, label='Text Result', visible=False)
599
+
600
+ with gr.Blocks() as demo:
601
+ gr.HTML(
602
+ """
603
+ <div style="position: relative; float: left; text-align: center; width: 60%; min-width:600px; height: 160px; margin: 20px 0 20px 20%;">
604
+ <h1 style="font-weight: 900; font-size: 3rem;">
605
+ Versatile Diffusion
606
+ </h1>
607
+ <br>
608
+ <h2 style="font-weight: 450; font-size: 1rem;">
609
+ We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
610
+ VD can natively support image-to-text, image-variation, text-to-image, and text-variation,
611
+ and can be further extended to other applications such as
612
+ semantic-style disentanglement, image-text dual-guided generation, latent image-to-text-to-image editing, and more.
613
+ Future versions will support more modalities such as speech, music, video and 3D.
614
+ </h2>
615
+ <br>
616
+ <h3>Xingqian Xu, Atlas Wang, Eric Zhang, Kai Wang,
617
+ and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
618
+ [<a href="https://arxiv.org/abs/2211.08332" style="color:blue;">arXiv</a>]
619
+ [<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
620
+ </h3>
621
+ </div>
622
+ <div style="position: relative; float: right; width: 19.9%; min-width:200px; margin: 20px auto;">
623
+ <img src="https://huggingface.co/spaces/shi-labs/Versatile-Diffusion/resolve/main/assets/figures/share_instruction.png">
624
+ </div>
625
+ """)
626
+ mode_input = gr.Radio([
627
+ "Text-to-Image", "Image-Variation", "Image-to-Text", "Text-Variation",
628
+ "Disentanglement", "Dual-Guided", "Latent-I2T2I"], value='Text-to-Image', label="VD Flows and Applications")
629
+
630
+ instruction = gr.Textbox(get_instruction("Text-to-Image"), label='Info')
631
+
632
+ with gr.Row():
633
+ with gr.Column():
634
+ img_input = gr.Image(label='Image Input', visible=False)
635
+ txt_input = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
636
+ ntxt_input = gr.Textbox(label='Remove Prompt', visible=False)
637
+ ptxt_input = gr.Textbox(label='Add Prompt', visible=False)
638
+ coladj_input = gr.Radio(["None", "Simple"], value='Simple', label="Color Calibration", visible=False)
639
+ dislvl_input = gr.Slider(-2, 2, value=0, step=1, label="Disentanglement level", visible=False)
640
+ dguide_input = gr.Slider(0, 1, value=0.5, step=0.01, label="Guidance Mixing", visible=False)
641
+ seed_input = gr.Number(100, label="Seed", precision=0)
642
+
643
+ btn = gr.Button("Run")
644
+ btn.click(
645
+ main,
646
+ inputs=[
647
+ mode_input,
648
+ img_input,
649
+ txt_input,
650
+ ntxt_input,
651
+ ptxt_input,
652
+ coladj_input,
653
+ dislvl_input,
654
+ dguide_input,
655
+ seed_input, ],
656
+ outputs=[img_output, txt_output])
657
+
658
+ with gr.Column():
659
+ img_output.render()
660
+ txt_output.render()
661
+
662
+ example_mode = [
663
+ "Text-to-Image",
664
+ "Image-Variation",
665
+ "Image-to-Text",
666
+ "Text-Variation",
667
+ "Disentanglement",
668
+ "Dual-Guided",
669
+ "Latent-I2T2I"]
670
+
671
+ def get_example(mode):
672
+ if mode == 'Text-to-Image':
673
+ case = [
674
+ ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 23],
675
+ ['a beautiful grand nebula in the universe', 24],
676
+ ['heavy arms gundam penguin mech', 25],
677
+ ]
678
+ elif mode == "Image-Variation":
679
+ case = [
680
+ ['assets/space.jpg', 'None', 26],
681
+ ['assets/train.jpg', 'Simple', 27],
682
+ ]
683
+ elif mode == "Image-to-Text":
684
+ case = [
685
+ ['assets/boy_and_girl.jpg' , 28],
686
+ ['assets/house_by_lake.jpg', 29],
687
+ ]
688
+ elif mode == "Text-Variation":
689
+ case = [
690
+ ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ' , 32],
691
+ ['a beautiful grand nebula in the universe' , 33],
692
+ ['heavy arms gundam penguin mech', 34],
693
+ ]
694
+ elif mode == "Disentanglement":
695
+ case = [
696
+ ['assets/vermeer.jpg', 'Simple', -2, 30],
697
+ ['assets/matisse.jpg', 'Simple', 2, 31],
698
+ ]
699
+ elif mode == "Dual-Guided":
700
+ case = [
701
+ ['assets/benz.jpg', 'cyberpunk 2077', 'Simple', 0.75, 22],
702
+ ['assets/vermeer.jpg', 'a girl with a diamond necklace', 'Simple', 0.66, 21],
703
+ ]
704
+ elif mode == "Latent-I2T2I":
705
+ case = [
706
+ ['assets/ghibli.jpg', 'white house', 'tall castle', 'Simple', 20],
707
+ ['assets/matisse.jpg', 'fruits and bottles on the table', 'flowers on the table', 'Simple', 21],
708
+ ]
709
+ else:
710
+ raise ValueError
711
+ case = [[mode] + casei for casei in case]
712
+ return case
713
+
714
+ def get_example_iof(mode):
715
+ if mode == 'Text-to-Image':
716
+ inps = [txt_input, seed_input]
717
+ oups = [img_output]
718
+ fn = lambda m, x, y: \
719
+ main(mode=m, prompt=x, seed=y)[0]
720
+ elif mode == "Image-Variation":
721
+ inps = [img_input, coladj_input, seed_input]
722
+ oups = [img_output]
723
+ fn = lambda m, x, y, z: \
724
+ main(mode=m, image=x, color_adj=y, seed=z)[0]
725
+ elif mode == "Image-to-Text":
726
+ inps = [img_input, seed_input]
727
+ oups = [txt_output]
728
+ fn = lambda m, x, y: \
729
+ main(mode=m, image=x, seed=y)[1]
730
+ elif mode == "Text-Variation":
731
+ inps = [txt_input, seed_input]
732
+ oups = [txt_output]
733
+ fn = lambda m, x, y: \
734
+ main(mode=m, prompt=x, seed=y)[1]
735
+ elif mode == "Disentanglement":
736
+ inps = [img_input, coladj_input, dislvl_input, seed_input]
737
+ oups = [img_output]
738
+ fn = lambda m, x, y, z, w: \
739
+ main(mode=m, image=x, color_adj=y, disentanglement_level=z, seed=w)[0]
740
+ elif mode == "Dual-Guided":
741
+ inps = [img_input, txt_input, coladj_input, dguide_input, seed_input]
742
+ oups = [img_output]
743
+ fn = lambda m, x, y, z, w, u: \
744
+ main(mode=m, image=x, prompt=y, color_adj=z, dual_guided_mixing=w, seed=u)[0]
745
+ elif mode == "Latent-I2T2I":
746
+ inps = [img_input, ntxt_input, ptxt_input, coladj_input, seed_input]
747
+ oups = [img_output]
748
+ fn = lambda m, x, y, z, w, u: \
749
+ main(mode=m, image=x, nprompt=y, pprompt=z, color_adj=w, seed=u)[0]
750
+ else:
751
+ raise ValueError
752
+ return [mode_input]+inps, oups, fn
753
+
754
+ with gr.Row():
755
+ for emode in example_mode[0:4]:
756
+ with gr.Column():
757
+ gr.Examples(
758
+ label=emode+' Examples',
759
+ examples=get_example(emode),
760
+ inputs=get_example_iof(emode)[0],
761
+ outputs=get_example_iof(emode)[1],
762
+ fn = get_example_iof(emode)[2],
763
+ cache_examples=cache_examples),
764
+ with gr.Row():
765
+ for emode in example_mode[4:7]:
766
+ with gr.Column():
767
+ gr.Examples(
768
+ label=emode+' Examples',
769
+ examples=get_example(emode),
770
+ inputs=get_example_iof(emode)[0],
771
+ outputs=get_example_iof(emode)[1],
772
+ fn = get_example_iof(emode)[2],
773
+ cache_examples=cache_examples),
774
+
775
+ mode_input.change(
776
+ fn=lambda x: gr.update(value=get_instruction(x)),
777
+ inputs=mode_input,
778
+ outputs=instruction,)
779
+
780
+ mode_input.change(
781
+ fn=lambda x: gr.update(visible=(x not in ['Text-to-Image', 'Text-Variation'])),
782
+ inputs=mode_input,
783
+ outputs=img_input,)
784
+
785
+ mode_input.change(
786
+ fn=lambda x: gr.update(visible=(x in ['Text-to-Image', 'Text-Variation', 'Dual-Guided'])),
787
+ inputs=mode_input,
788
+ outputs=txt_input,)
789
+
790
+ mode_input.change(
791
+ fn=lambda x: gr.update(visible=(x in ['Latent-I2T2I'])),
792
+ inputs=mode_input,
793
+ outputs=ntxt_input,)
794
+ mode_input.change(
795
+ fn=lambda x: gr.update(visible=(x in ['Latent-I2T2I'])),
796
+ inputs=mode_input,
797
+ outputs=ptxt_input,)
798
+
799
+ mode_input.change(
800
+ fn=lambda x: gr.update(visible=(x not in ['Text-to-Image', 'Image-to-Text', 'Text-Variation'])),
801
+ inputs=mode_input,
802
+ outputs=coladj_input,)
803
+
804
+ mode_input.change(
805
+ fn=lambda x: gr.update(visible=(x=='Disentanglement')),
806
+ inputs=mode_input,
807
+ outputs=dislvl_input,)
808
+
809
+ mode_input.change(
810
+ fn=lambda x: gr.update(visible=(x=='Dual-Guided')),
811
+ inputs=mode_input,
812
+ outputs=dguide_input,)
813
+
814
+ mode_input.change(
815
+ fn=lambda x: gr.update(visible=(x not in ['Image-to-Text', 'Text-Variation'])),
816
+ inputs=mode_input,
817
+ outputs=img_output,)
818
+ mode_input.change(
819
+ fn=lambda x: gr.update(visible=(x in ['Image-to-Text', 'Text-Variation'])),
820
+ inputs=mode_input,
821
+ outputs=txt_output,)
822
+
823
+ gr.HTML(
824
+ """
825
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
826
+ <h3>
827
+ <b>Caution</b>:
828
+ We would like the raise the awareness of users of this demo of its potential issues and concerns.
829
+ Like previous large foundation models, Versatile Diffusion could be problematic in some cases, partially due to the imperfect training data and pretrained network (VAEs / context encoders) with limited scope.
830
+ In its future research phase, VD may do better on tasks such as text-to-image, image-to-text, etc., with the help of more powerful VAEs, more sophisticated network designs, and more cleaned data.
831
+ So far, we keep all features available for research testing both to show the great potential of the VD framework and to collect important feedback to improve the model in the future.
832
+ We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
833
+ </h3>
834
+ <br>
835
+ <h3>
836
+ <b>Biases and content acknowledgement</b>:
837
+ Beware that VD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
838
+ VD was trained on the LAION-2B dataset, which scraped non-curated online images and text, and may contained unintended exceptions as we removed illegal content.
839
+ VD in this demo is meant only for research purposes.
840
+ </h3>
841
+ </div>
842
+ """)
843
+
844
+ # demo.launch(share=True)
845
+ demo.launch(debug=True)
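For reference (not part of this commit), below is a minimal headless sketch of driving the vd_inference wrapper defined in app.py above without the Gradio UI. The module name vd_core is hypothetical, since importing app.py directly would also build and launch the demo; everything else mirrors the code shown in the diff.

    # Hedged sketch: `vd_core` is a hypothetical module holding the vd_inference
    # class defined above; app.py itself launches the Gradio demo at import time.
    import torch
    from vd_core import vd_inference  # hypothetical import path

    vd = vd_inference(hfm_repo=['shi-labs/versatile-diffusion-model',
                                'pretrained_pth/vd-four-flow-v1-0-fp16.pth'],
                      fp16=True, device='cuda')

    torch.manual_seed(123)
    with torch.no_grad():
        images = vd.inference(
            xtype='image',                                   # produce images ...
            cin='a beautiful grand nebula in the universe',  # ... from this text prompt
            ctype='prompt',
            scale=7.5)                                       # classifier-free guidance scale

    for i, im in enumerate(images):  # inference() returns PIL images for xtype='image'
        im.save('sample_{}.png'.format(i))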
assets/benz.jpg ADDED

Git LFS Details

  • SHA256: bdfdfb603af2179878013b08500fdc78c5f20d70efd581f2ebfed1b65321f9a2
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
assets/boy_and_girl.jpg ADDED

Git LFS Details

  • SHA256: aba3f4834a4f82fb65ff8e6c5e5a1b60d248d2e83d97321b98a0d24ba999390c
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
assets/church.jpg ADDED

Git LFS Details

  • SHA256: ec3be4a83b1ceb43cfee1c5bd125f564e0b42a71c440e731fa9cecc2b761263d
  • Pointer size: 131 Bytes
  • Size of remote file: 338 kB
assets/figures/share_instruction.png ADDED
assets/firework.jpg ADDED

Git LFS Details

  • SHA256: 6040aeca347b2896de63b3bf9145e307ad06fa4ab0435609e1d7df5587c29bd6
  • Pointer size: 131 Bytes
  • Size of remote file: 279 kB
assets/ghibli.jpg ADDED

Git LFS Details

  • SHA256: 153e34326ce625f2a6c41d6922549ad690b63d8e18de43532e7fd9808cb9de8b
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
assets/horse.png ADDED

Git LFS Details

  • SHA256: 27c5ba007e2984f2e8128df6418c780fefcbd940025ccece14a5a13894065457
  • Pointer size: 131 Bytes
  • Size of remote file: 395 kB
assets/house_by_lake.jpg ADDED

Git LFS Details

  • SHA256: 3d3dcc9f8d8eb90b69fb0a17967440b960bb1d545bcf85de37c4d08f9e5d4606
  • Pointer size: 131 Bytes
  • Size of remote file: 189 kB
assets/matisse.jpg ADDED

Git LFS Details

  • SHA256: ea0428092cbca5224b72a3665c140e96142b5df9c78b36b66f910f42093ecd4f
  • Pointer size: 131 Bytes
  • Size of remote file: 271 kB
assets/night_light.jpg ADDED

Git LFS Details

  • SHA256: 5103ce525e00f0f8ff3c83bcbf954ebda5deb869377a8ccd2b5b6362f7b0aa4a
  • Pointer size: 131 Bytes
  • Size of remote file: 213 kB
assets/penguin.png ADDED

Git LFS Details

  • SHA256: e22e87eec01455b342a849d868ea6cf893b8ae7d81da54eba2f620dbccf972ac
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
assets/san_diego.jpg ADDED

Git LFS Details

  • SHA256: 491e93d1b0e99ae2223d85beb4cc98aa790c377b51a938721c3d0c62645a81e4
  • Pointer size: 131 Bytes
  • Size of remote file: 235 kB
assets/scream.jpg ADDED

Git LFS Details

  • SHA256: e154dbe35cb10c4f65c022e97ba9d6ccd7f6ebf3285ff56ebcd9c43b0246e309
  • Pointer size: 131 Bytes
  • Size of remote file: 246 kB
assets/space.jpg ADDED

Git LFS Details

  • SHA256: 7cb01b250297f088ecb1310d746a18b141ab26cd6f3f46e41c52285bb4c7e3d4
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
assets/tiger.jpg ADDED

Git LFS Details

  • SHA256: a4b58a11be073fad21a218bcf6478a4da61532b8520176a6454542eb9368081d
  • Pointer size: 131 Bytes
  • Size of remote file: 272 kB
assets/train.jpg ADDED

Git LFS Details

  • SHA256: 50b45524dc627d0042ed789e4495250e71aae7f2935b9a0879d09cf73d8aff37
  • Pointer size: 131 Bytes
  • Size of remote file: 310 kB
assets/vermeer.jpg ADDED

Git LFS Details

  • SHA256: d884ddfe302572f6c4eb8607942cc59807b3dc8321f27ff9358b8e5a3657d015
  • Pointer size: 130 Bytes
  • Size of remote file: 65.5 kB
configs/model/clip.yaml ADDED
@@ -0,0 +1,50 @@
+
+ clip:
+   symbol: clip
+   args: {}
+
+ clip_frozen:
+   super_cfg: clip
+   type: clip_frozen
+   args: {}
+
+ clip_text_frozen:
+   super_cfg: clip
+   type: clip_text_frozen
+   args: {}
+
+ clip_vision_frozen:
+   super_cfg: clip
+   type: clip_vision_frozen
+   args: {}
+
+ ############################
+ # clip with focused encode #
+ ############################
+
+ clip_frozen_encode_text:
+   super_cfg: clip
+   type: clip_frozen
+   args:
+     encode_type: encode_text
+
+ clip_frozen_encode_vision:
+   super_cfg: clip
+   type: clip_frozen
+   args:
+     encode_type: encode_vision
+
+ clip_frozen_encode_text_noproj:
+   super_cfg: clip
+   type: clip_frozen
+   args:
+     encode_type: encode_text_noproj
+
+ #####################################
+ # clip vision frozen justin version #
+ #####################################
+
+ clip_vision_frozen_justin:
+   super_cfg: clip
+   type: clip_vision_frozen_justin
+   args: {}
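For reference (not part of this commit), a short sketch of how entries in this file are resolved: model_cfg_bank in lib/cfg_helper.py loads configs/model/clip.yaml, follows super_cfg, and merges (rather than replaces) the child's args, so clip_frozen_encode_text keeps its own encode_type while inheriting fields such as symbol from the parent clip entry.

    # Hedged sketch of the super_cfg / args-merge behavior implemented in lib/cfg_helper.py.
    from lib.cfg_helper import model_cfg_bank

    cfg = model_cfg_bank()('clip_frozen_encode_text')
    print(cfg.type)              # 'clip_frozen', set directly on the child entry
    print(cfg.symbol)            # 'clip', inherited from the parent `clip` entry
    print(cfg.args.encode_type)  # 'encode_text', from the child-level args merge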
configs/model/openai_unet.yaml ADDED
@@ -0,0 +1,72 @@
+ openai_unet_sd:
+   type: openai_unet
+   args:
+     image_size: null # no use
+     in_channels: 4
+     out_channels: 4
+     model_channels: 320
+     attention_resolutions: [ 4, 2, 1 ]
+     num_res_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     # disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
+     num_heads: 8
+     use_spatial_transformer: True
+     transformer_depth: 1
+     context_dim: 768
+     use_checkpoint: True
+     legacy: False
+
+ openai_unet_dual_context:
+   super_cfg: openai_unet_sd
+   type: openai_unet_dual_context
+
+ ########################
+ # Code cleaned version #
+ ########################
+
+ openai_unet_2d:
+   type: openai_unet_2d
+   args:
+     input_channels: 4
+     model_channels: 320
+     output_channels: 4
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     with_attn: [true, true, true, false]
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+
+ openai_unet_0d:
+   type: openai_unet_0d
+   args:
+     input_channels: 768
+     model_channels: 320
+     output_channels: 768
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     with_attn: [true, true, true, false]
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+
+ openai_unet_0dmd:
+   type: openai_unet_0dmd
+   args:
+     input_channels: 768
+     model_channels: 320
+     output_channels: 768
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     second_dim: [ 4, 4, 4, 4 ]
+     with_attn: [true, true, true, false]
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+
+ openai_unet_vd:
+   type: openai_unet_vd
+   args:
+     unet_image_cfg: MODEL(openai_unet_2d)
+     unet_text_cfg: MODEL(openai_unet_0dmd)
+
configs/model/optimus.yaml ADDED
@@ -0,0 +1,103 @@
+
+ optimus:
+   symbol: optimus
+   find_unused_parameters: false
+   args: {}
+
+ optimus_bert_encoder:
+   super_cfg: optimus
+   type: optimus_bert_connector
+   # pth: pretrained/optimus_bert_encoder.pth
+   args:
+     config:
+       architectures:
+         - BertForMaskedLM
+       attention_probs_dropout_prob: 0.1
+       finetuning_task: null
+       hidden_act: gelu
+       hidden_dropout_prob: 0.1
+       hidden_size: 768
+       initializer_range: 0.02
+       intermediate_size: 3072
+       layer_norm_eps: 1.e-12
+       max_position_embeddings: 512
+       num_attention_heads: 12
+       num_hidden_layers: 12
+       num_labels: 2
+       output_attentions: false
+       output_hidden_states: false
+       pruned_heads: {}
+       torchscript: false
+       type_vocab_size: 2
+       vocab_size: 28996
+     latent_size: 768
+
+ optimus_bert_tokenizer:
+   super_cfg: optimus
+   type: optimus_bert_tokenizer
+   args:
+     do_lower_case: false
+     max_len: 512
+     vocab_file: lib/model_zoo/optimus_models/vocab/bert-base-cased-vocab.txt
+
+ optimus_gpt2_decoder:
+   super_cfg: optimus
+   type: optimus_gpt2_connector
+   # pth: pretrained/optimus_gpt2_decoder.pth
+   args:
+     config:
+       architectures:
+         - GPT2LMHeadModel
+       attn_pdrop: 0.1
+       embd_pdrop: 0.1
+       finetuning_task: null
+       hidden_size: 768
+       initializer_range: 0.02
+       latent_size: 768
+       layer_norm_epsilon: 1.e-05
+       max_position_embeddings: 1024
+       n_ctx: 1024
+       n_embd: 768
+       n_head: 12
+       n_layer: 12
+       n_positions: 1024
+       num_attention_heads: 12
+       num_hidden_layers: 12
+       num_labels: 1
+       output_attentions: false
+       output_hidden_states: false
+       pretrained_config_archive_map:
+         gpt2: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json
+         gpt2-medium: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json
+         gpt2-large: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json
+       pruned_heads: {}
+       resid_pdrop: 0.1
+       summary_activation: null
+       summary_first_dropout: 0.1
+       summary_proj_to_labels: true
+       summary_type: cls_index
+       summary_use_proj: true
+       torchscript: false
+       vocab_size: 50260
+
+ optimus_gpt2_tokenizer:
+   super_cfg: optimus
+   type: optimus_gpt2_tokenizer
+   args:
+     do_lower_case: false
+     max_len: 1024
+     vocab_file: lib/model_zoo/optimus_models/vocab/gpt2-vocab.json
+     merges_file: lib/model_zoo/optimus_models/vocab/gpt2-merges.txt
+
+ optimus_vae:
+   super_cfg: optimus
+   type: optimus_vae
+   args:
+     encoder: MODEL(optimus_bert_encoder)
+     decoder: MODEL(optimus_gpt2_decoder)
+     tokenizer_encoder: MODEL(optimus_bert_tokenizer)
+     tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
+     args:
+       latent_size: 768
+   hfm: ['shi-labs/versatile-diffusion-model', 'pretrained_pth/optimus-vae.pth']
+   # pth: pretrained/optimus-vae.pth
configs/model/sd.yaml ADDED
@@ -0,0 +1,69 @@
+ sd_base:
+   symbol: sd
+   find_unused_parameters: true
+
+ sd_autoencoder:
+   type: autoencoderkl
+   args:
+     embed_dim: 4
+     monitor: val/rec_loss
+     ddconfig:
+       double_z: true
+       z_channels: 4
+       resolution: 256
+       in_channels: 3
+       out_ch: 3
+       ch: 128
+       ch_mult: [1, 2, 4, 4]
+       num_res_blocks: 2
+       attn_resolutions: []
+       dropout: 0.0
+     lossconfig:
+       target: torch.nn.Identity
+   # pth: pretrained/kl-f8.pth
+   hfm: ['shi-labs/versatile-diffusion-model', 'pretrained_pth/kl-f8.pth']
+
+ sd_t2i:
+   super_cfg: sd_base
+   type: sd_t2i
+   args:
+     first_stage_config: MODEL(sd_autoencoder)
+     cond_stage_config: MODEL(clip_text_frozen)
+     unet_config: MODEL(openai_unet_sd)
+     beta_linear_start: 0.00085
+     beta_linear_end: 0.012
+     num_timesteps_cond: 1
+     timesteps: 1000
+     scale_factor: 0.18215
+     use_ema: true
+
+ sd_t2i_noema:
+   super_cfg: sd
+   args:
+     use_ema: false
+
+ #####################
+ # sd with full clip #
+ #####################
+
+ sd_t2i_fullclip_backward_compatible:
+   super_cfg: sd_t2i
+   args:
+     cond_stage_config: MODEL(clip_frozen_encode_text_noproj)
+
+ sd_t2i_fullclip_backward_compatible_noema:
+   super_cfg: sd_t2i_noema
+   args:
+     cond_stage_config: MODEL(clip_frozen_encode_text_noproj)
+
+ sd_t2i_fullclip:
+   super_cfg: sd_t2i
+   args:
+     cond_stage_config: MODEL(clip_frozen_encode_text)
+
+ sd_variation:
+   super_cfg: sd_t2i
+   type: sd_variation
+   args:
+     cond_stage_config: MODEL(clip_vision_frozen_justin)
+
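The hfm entry above names a weight file on the Hugging Face Hub ('shi-labs/versatile-diffusion-model', 'pretrained_pth/kl-f8.pth'). How the repo's loader consumes the hfm field is not visible in this truncated diff; the sketch below only illustrates the same huggingface_hub download pattern that app.py uses for the full checkpoint.

    # Hedged illustration only: fetch the file named by `hfm` and load its state dict,
    # mirroring the hf_hub_download + torch.load pattern shown in app.py.
    import torch
    from huggingface_hub import hf_hub_download

    path = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/kl-f8.pth')
    sd = torch.load(path, map_location='cpu')  # KL-f8 autoencoder weights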
configs/model/vd.yaml ADDED
@@ -0,0 +1,61 @@
+ # vd_base:
+ #   symbol: vd
+ #   find_unused_parameters: true
+
+ ############
+ # vd basic #
+ ############
+
+ vd_basic:
+   super_cfg: sd_t2i
+   type: vd_basic
+   symbol: vd
+   find_unused_parameters: true
+   args:
+     cond_stage_config: MODEL(clip_frozen_encode_vision)
+
+ vd_basic_noema:
+   super_cfg: vd_basic
+   args:
+     use_ema: false
+
+ ###################
+ # vd dual-context #
+ ###################
+
+ vd_dc:
+   super_cfg: sd_t2i_fullclip
+   type: vd_dc
+   symbol: vd
+   find_unused_parameters: true
+   args:
+     unet_config: MODEL(openai_unet_dual_context)
+
+ vd_dc_noema:
+   super_cfg: vd_dc
+   args:
+     use_ema: false
+
+ ######
+ # vd #
+ ######
+
+ vd:
+   type: vd
+   symbol: vd
+   find_unused_parameters: true
+   args:
+     autokl_cfg: MODEL(sd_autoencoder)
+     optimus_cfg: MODEL(optimus_vae)
+     clip_cfg: MODEL(clip_frozen)
+     unet_config: MODEL(openai_unet_vd)
+     beta_linear_start: 0.00085
+     beta_linear_end: 0.012
+     timesteps: 1000
+     scale_factor: 0.18215
+     use_ema: true
+
+ vd_noema:
+   super_cfg: vd
+   args:
+     use_ema: false
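For reference (not part of this commit), a minimal sketch of how these entries are consumed, mirroring the setup in app.py above: model_cfg_bank resolves super_cfg inheritance and the MODEL(...) cross-references, and get_model instantiates the network from the resolved config.

    # Hedged sketch mirroring vd_inference.__init__ in app.py.
    from lib.cfg_helper import model_cfg_bank
    from lib.model_zoo import get_model

    cfgm = model_cfg_bank()('vd_noema')       # 'vd' with use_ema overridden to false
    net = get_model()(cfgm)                   # builds the multi-flow Versatile Diffusion network
    print(cfgm.type, list(cfgm.args.keys()))  # 'vd' plus autokl/optimus/clip/unet sub-configs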
lib/__init__.py ADDED
File without changes
lib/cfg_helper.py ADDED
@@ -0,0 +1,664 @@
1
+ import os
2
+ import os.path as osp
3
+ import shutil
4
+ import copy
5
+ import time
6
+ import pprint
7
+ import numpy as np
8
+ import torch
9
+ import matplotlib
10
+ import argparse
11
+ import json
12
+ import yaml
13
+ from easydict import EasyDict as edict
14
+
15
+ from .model_zoo import get_model
16
+
17
+ ############
18
+ # cfg_bank #
19
+ ############
20
+
21
+ def cfg_solvef(cmd, root):
22
+ if not isinstance(cmd, str):
23
+ return cmd
24
+
25
+ if cmd.find('SAME')==0:
26
+ zoom = root
27
+ p = cmd[len('SAME'):].strip('()').split('.')
28
+ p = [pi.strip() for pi in p]
29
+ for pi in p:
30
+ try:
31
+ pi = int(pi)
32
+ except:
33
+ pass
34
+
35
+ try:
36
+ zoom = zoom[pi]
37
+ except:
38
+ return cmd
39
+ return cfg_solvef(zoom, root)
40
+
41
+ if cmd.find('SEARCH')==0:
42
+ zoom = root
43
+ p = cmd[len('SEARCH'):].strip('()').split('.')
44
+ p = [pi.strip() for pi in p]
45
+ find = True
46
+ # Depth first search
47
+ for pi in p:
48
+ try:
49
+ pi = int(pi)
50
+ except:
51
+ pass
52
+
53
+ try:
54
+ zoom = zoom[pi]
55
+ except:
56
+ find = False
57
+ break
58
+
59
+ if find:
60
+ return cfg_solvef(zoom, root)
61
+ else:
62
+ if isinstance(root, dict):
63
+ for ri in root:
64
+ rv = cfg_solvef(cmd, root[ri])
65
+ if rv != cmd:
66
+ return rv
67
+ if isinstance(root, list):
68
+ for ri in root:
69
+ rv = cfg_solvef(cmd, ri)
70
+ if rv != cmd:
71
+ return rv
72
+ return cmd
73
+
74
+ if cmd.find('MODEL')==0:
75
+ goto = cmd[len('MODEL'):].strip('()')
76
+ return model_cfg_bank()(goto)
77
+
78
+ if cmd.find('DATASET')==0:
79
+ goto = cmd[len('DATASET'):].strip('()')
80
+ return dataset_cfg_bank()(goto)
81
+
82
+ return cmd
83
+
84
+ def cfg_solve(cfg, cfg_root):
85
+ # The function solve cfg element such that
86
+ # all sorrogate input are settled.
87
+ # (i.e. SAME(***) )
88
+ if isinstance(cfg, list):
89
+ for i in range(len(cfg)):
90
+ if isinstance(cfg[i], (list, dict)):
91
+ cfg[i] = cfg_solve(cfg[i], cfg_root)
92
+ else:
93
+ cfg[i] = cfg_solvef(cfg[i], cfg_root)
94
+ if isinstance(cfg, dict):
95
+ for k in cfg:
96
+ if isinstance(cfg[k], (list, dict)):
97
+ cfg[k] = cfg_solve(cfg[k], cfg_root)
98
+ else:
99
+ cfg[k] = cfg_solvef(cfg[k], cfg_root)
100
+ return cfg
101
+
102
+ class model_cfg_bank(object):
103
+ def __init__(self):
104
+ self.cfg_dir = osp.join('configs', 'model')
105
+ self.cfg_bank = edict()
106
+
107
+ def __call__(self, name):
108
+ if name not in self.cfg_bank:
109
+ cfg_path = self.get_yaml_path(name)
110
+ with open(cfg_path, 'r') as f:
111
+ cfg_new = yaml.load(
112
+ f, Loader=yaml.FullLoader)
113
+ cfg_new = edict(cfg_new)
114
+ self.cfg_bank.update(cfg_new)
115
+
116
+ cfg = self.cfg_bank[name]
117
+ cfg.name = name
118
+ if 'super_cfg' not in cfg:
119
+ cfg = cfg_solve(cfg, cfg)
120
+ self.cfg_bank[name] = cfg
121
+ return copy.deepcopy(cfg)
122
+
123
+ super_cfg = self.__call__(cfg.super_cfg)
124
+ # unlike other field,
125
+ # args will not be replaced but update.
126
+ if 'args' in cfg:
127
+ if 'args' in super_cfg:
128
+ super_cfg.args.update(cfg.args)
129
+ else:
130
+ super_cfg.args = cfg.args
131
+ cfg.pop('args')
132
+
133
+ super_cfg.update(cfg)
134
+ super_cfg.pop('super_cfg')
135
+ cfg = super_cfg
136
+ try:
137
+ delete_args = cfg.pop('delete_args')
138
+ except:
139
+ delete_args = []
140
+
141
+ for dargs in delete_args:
142
+ cfg.args.pop(dargs)
143
+
144
+ cfg = cfg_solve(cfg, cfg)
145
+ self.cfg_bank[name] = cfg
146
+ return copy.deepcopy(cfg)
147
+
148
+ def get_yaml_path(self, name):
149
+ if name.find('ldm')==0:
150
+ return osp.join(
151
+ self.cfg_dir, 'ldm.yaml')
152
+ elif name.find('comodgan')==0:
153
+ return osp.join(
154
+ self.cfg_dir, 'comodgan.yaml')
155
+ elif name.find('stylegan')==0:
156
+ return osp.join(
157
+ self.cfg_dir, 'stylegan.yaml')
158
+ elif name.find('absgan')==0:
159
+ return osp.join(
160
+ self.cfg_dir, 'absgan.yaml')
161
+ elif name.find('ashgan')==0:
162
+ return osp.join(
163
+ self.cfg_dir, 'ashgan.yaml')
164
+ elif name.find('sr3')==0:
165
+ return osp.join(
166
+ self.cfg_dir, 'sr3.yaml')
167
+ elif name.find('specdiffsr')==0:
168
+ return osp.join(
169
+ self.cfg_dir, 'specdiffsr.yaml')
170
+ elif name.find('openai_unet')==0:
171
+ return osp.join(
172
+ self.cfg_dir, 'openai_unet.yaml')
173
+ elif name.find('clip')==0:
174
+ return osp.join(
175
+ self.cfg_dir, 'clip.yaml')
176
+ elif name.find('sd')==0:
177
+ return osp.join(
178
+ self.cfg_dir, 'sd.yaml')
179
+ elif name.find('vd')==0:
180
+ return osp.join(
181
+ self.cfg_dir, 'vd.yaml')
182
+ elif name.find('optimus')==0:
183
+ return osp.join(
184
+ self.cfg_dir, 'optimus.yaml')
185
+ else:
186
+ raise ValueError
187
+
188
+ class dataset_cfg_bank(object):
189
+ def __init__(self):
190
+ self.cfg_dir = osp.join('configs', 'dataset')
191
+ self.cfg_bank = edict()
192
+
193
+ def __call__(self, name):
194
+ if name not in self.cfg_bank:
195
+ cfg_path = self.get_yaml_path(name)
196
+ with open(cfg_path, 'r') as f:
197
+ cfg_new = yaml.load(
198
+ f, Loader=yaml.FullLoader)
199
+ cfg_new = edict(cfg_new)
200
+ self.cfg_bank.update(cfg_new)
201
+
202
+ cfg = self.cfg_bank[name]
203
+ cfg.name = name
204
+ if cfg.get('super_cfg', None) is None:
205
+ cfg = cfg_solve(cfg, cfg)
206
+ self.cfg_bank[name] = cfg
207
+ return copy.deepcopy(cfg)
208
+
209
+ super_cfg = self.__call__(cfg.super_cfg)
210
+ super_cfg.update(cfg)
211
+ cfg = super_cfg
212
+ cfg.super_cfg = None
213
+ try:
214
+ delete = cfg.pop('delete')
215
+ except:
216
+ delete = []
217
+
218
+ for dargs in delete:
219
+ cfg.pop(dargs)
220
+
221
+ cfg = cfg_solve(cfg, cfg)
222
+ self.cfg_bank[name] = cfg
223
+ return copy.deepcopy(cfg)
224
+
225
+ def get_yaml_path(self, name):
226
+ if name.find('cityscapes')==0:
227
+ return osp.join(
228
+ self.cfg_dir, 'cityscapes.yaml')
229
+ elif name.find('div2k')==0:
230
+ return osp.join(
231
+ self.cfg_dir, 'div2k.yaml')
232
+ elif name.find('gandiv2k')==0:
233
+ return osp.join(
234
+ self.cfg_dir, 'gandiv2k.yaml')
235
+ elif name.find('srbenchmark')==0:
236
+ return osp.join(
237
+ self.cfg_dir, 'srbenchmark.yaml')
238
+ elif name.find('imagedir')==0:
239
+ return osp.join(
240
+ self.cfg_dir, 'imagedir.yaml')
241
+ elif name.find('places2')==0:
242
+ return osp.join(
243
+ self.cfg_dir, 'places2.yaml')
244
+ elif name.find('ffhq')==0:
245
+ return osp.join(
246
+ self.cfg_dir, 'ffhq.yaml')
247
+ elif name.find('imcpt')==0:
248
+ return osp.join(
249
+ self.cfg_dir, 'imcpt.yaml')
250
+ elif name.find('texture')==0:
251
+ return osp.join(
252
+ self.cfg_dir, 'texture.yaml')
253
+ elif name.find('openimages')==0:
254
+ return osp.join(
255
+ self.cfg_dir, 'openimages.yaml')
256
+ elif name.find('laion2b')==0:
257
+ return osp.join(
258
+ self.cfg_dir, 'laion2b.yaml')
259
+ elif name.find('laionart')==0:
260
+ return osp.join(
261
+ self.cfg_dir, 'laionart.yaml')
262
+ elif name.find('celeba')==0:
263
+ return osp.join(
264
+ self.cfg_dir, 'celeba.yaml')
265
+ elif name.find('coyo')==0:
266
+ return osp.join(
267
+ self.cfg_dir, 'coyo.yaml')
268
+ elif name.find('pafc')==0:
269
+ return osp.join(
270
+ self.cfg_dir, 'pafc.yaml')
271
+ elif name.find('coco')==0:
272
+ return osp.join(
273
+ self.cfg_dir, 'coco.yaml')
274
+ else:
275
+ raise ValueError
276
+
277
+ class experiment_cfg_bank(object):
278
+ def __init__(self):
279
+ self.cfg_dir = osp.join('configs', 'experiment')
280
+ self.cfg_bank = edict()
281
+
282
+ def __call__(self, name):
283
+ if name not in self.cfg_bank:
284
+ cfg_path = self.get_yaml_path(name)
285
+ with open(cfg_path, 'r') as f:
286
+ cfg = yaml.load(
287
+ f, Loader=yaml.FullLoader)
288
+ cfg = edict(cfg)
289
+
290
+ cfg = cfg_solve(cfg, cfg)
291
+ cfg = cfg_solve(cfg, cfg)
292
+ # run cfg_solve twice so that SEARCH entries are fully resolved
293
+ self.cfg_bank[name] = cfg
294
+ return copy.deepcopy(cfg)
295
+
296
+ def get_yaml_path(self, name):
297
+ return osp.join(
298
+ self.cfg_dir, name+'.yaml')
299
+
300
+ def load_cfg_yaml(path):
301
+ if osp.isfile(path):
302
+ cfg_path = path
303
+ elif osp.isfile(osp.join('configs', 'experiment', path)):
304
+ cfg_path = osp.join('configs', 'experiment', path)
305
+ elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')):
306
+ cfg_path = osp.join('configs', 'experiment', path+'.yaml')
307
+ else:
308
+ assert False, 'No such config!'
309
+
310
+ with open(cfg_path, 'r') as f:
311
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
312
+ cfg = edict(cfg)
313
+ cfg = cfg_solve(cfg, cfg)
314
+ cfg = cfg_solve(cfg, cfg)
315
+ return cfg
316
+
317
+ ##############
318
+ # cfg_helper #
319
+ ##############
320
+
321
+ def get_experiment_id(ref=None):
322
+ if ref is None:
323
+ time.sleep(0.5)
324
+ return int(time.time()*100)
325
+ else:
326
+ try:
327
+ return int(ref)
328
+ except:
329
+ pass
330
+
331
+ _, ref = osp.split(ref)
332
+ ref = ref.split('_')[0]
333
+ try:
334
+ return int(ref)
335
+ except:
336
+ assert False, 'Invalid experiment ID!'
337
+
338
+ def record_resume_cfg(path):
339
+ cnt = 0
340
+ while True:
341
+ if osp.exists(path+'.{:04d}'.format(cnt)):
342
+ cnt += 1
343
+ continue
344
+ shutil.copyfile(path, path+'.{:04d}'.format(cnt))
345
+ break
346
+
347
+ def get_command_line_args():
348
+ parser = argparse.ArgumentParser()
349
+ parser.add_argument('--debug', action='store_true', default=False)
350
+ parser.add_argument('--config', type=str)
351
+ parser.add_argument('--gpu', nargs='+', type=int)
352
+
353
+ parser.add_argument('--node_rank', type=int, default=0)
354
+ parser.add_argument('--nodes', type=int, default=1)
355
+ parser.add_argument('--addr', type=str, default='127.0.0.1')
356
+ parser.add_argument('--port', type=int, default=11233)
357
+
358
+ parser.add_argument('--signature', nargs='+', type=str)
359
+ parser.add_argument('--seed', type=int)
360
+
361
+ parser.add_argument('--eval', type=str)
362
+ parser.add_argument('--eval_subdir', type=str)
363
+ parser.add_argument('--pretrained', type=str)
364
+
365
+ parser.add_argument('--resume_dir', type=str)
366
+ parser.add_argument('--resume_step', type=int)
367
+ parser.add_argument('--resume_weight', type=str)
368
+
369
+ args = parser.parse_args()
370
+
371
+ # Special handling for resume
372
+ if args.resume_dir is not None:
373
+ cfg = edict()
374
+ cfg.env = edict()
375
+ cfg.env.debug = args.debug
376
+ cfg.env.resume = edict()
377
+ cfg.env.resume.dir = args.resume_dir
378
+ cfg.env.resume.step = args.resume_step
379
+ cfg.env.resume.weight = args.resume_weight
380
+ return cfg
381
+
382
+ cfg = load_cfg_yaml(args.config)
383
+ cfg.env.debug = args.debug
384
+ cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu)
385
+ cfg.env.master_addr = args.addr
386
+ cfg.env.master_port = args.port
387
+ cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port)
388
+ cfg.env.node_rank = args.node_rank
389
+ cfg.env.nodes = args.nodes
390
+
391
+ istrain = False if args.eval is not None else True
392
+ isdebug = cfg.env.debug
393
+
394
+ if istrain:
395
+ if isdebug:
396
+ cfg.env.experiment_id = 999999999999
397
+ cfg.train.signature = ['debug']
398
+ else:
399
+ cfg.env.experiment_id = get_experiment_id()
400
+ if args.signature is not None:
401
+ cfg.train.signature = args.signature
402
+ else:
403
+ if 'train' in cfg:
404
+ cfg.pop('train')
405
+ cfg.env.experiment_id = get_experiment_id(args.eval)
406
+ if args.signature is not None:
407
+ cfg.eval.signature = args.signature
408
+
409
+ if isdebug and (args.eval is None):
410
+ cfg.env.experiment_id = 999999999999
411
+ cfg.eval.signature = ['debug']
412
+
413
+ if args.eval_subdir is not None:
414
+ if isdebug:
415
+ cfg.eval.eval_subdir = 'debug'
416
+ else:
417
+ cfg.eval.eval_subdir = args.eval_subdir
418
+ if args.pretrained is not None:
419
+ cfg.eval.pretrained = args.pretrained
420
+ # --pretrained overrides the pretrained weight specified in cfg.model
421
+
422
+ if args.seed is not None:
423
+ cfg.env.rnd_seed = args.seed
424
+
425
+ return cfg
426
+
427
+ def cfg_initiates(cfg):
428
+ cfge = cfg.env
429
+ isdebug = cfge.debug
430
+ isresume = 'resume' in cfge
431
+ istrain = 'train' in cfg
432
+ haseval = 'eval' in cfg
433
+ cfgt = cfg.train if istrain else None
434
+ cfgv = cfg.eval if haseval else None
435
+
436
+ ###############################
437
+ # get some environment params #
438
+ ###############################
439
+
440
+ cfge.computer = os.uname()
441
+ cfge.torch_version = str(torch.__version__)
442
+
443
+ ##########
444
+ # resume #
445
+ ##########
446
+
447
+ if isresume:
448
+ resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml')
449
+ record_resume_cfg(resume_cfg_path)
450
+ with open(resume_cfg_path, 'r') as f:
451
+ cfg_resume = yaml.load(f, Loader=yaml.FullLoader)
452
+ cfg_resume = edict(cfg_resume)
453
+ cfg_resume.env.update(cfge)
454
+ cfg = cfg_resume
455
+ cfge = cfg.env
456
+ log_file = cfg.train.log_file
457
+
458
+ print('')
459
+ print('##########')
460
+ print('# resume #')
461
+ print('##########')
462
+ print('')
463
+ with open(log_file, 'a') as f:
464
+ print('', file=f)
465
+ print('##########', file=f)
466
+ print('# resume #', file=f)
467
+ print('##########', file=f)
468
+ print('', file=f)
469
+
470
+ pprint.pprint(cfg)
471
+ with open(log_file, 'a') as f:
472
+ pprint.pprint(cfg, f)
473
+
474
+ ####################
475
+ # node distributed #
476
+ ####################
477
+
478
+ if cfg.env.master_addr!='127.0.0.1':
479
+ os.environ['MASTER_ADDR'] = cfge.master_addr
480
+ os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port)
481
+ if cfg.env.dist_backend=='nccl':
482
+ os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET'
483
+ if cfg.env.dist_backend=='gloo':
484
+ os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET'
485
+
486
+ #######################
487
+ # cuda visible device #
488
+ #######################
489
+
490
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
491
+ [str(gid) for gid in cfge.gpu_device])
492
+
493
+ #####################
494
+ # return resume cfg #
495
+ #####################
496
+
497
+ if isresume:
498
+ return cfg
499
+
500
+ #############################################
501
+ # some misc settings not needed for resume  #
502
+ #############################################
503
+
504
+ cfgm = cfg.model
505
+ cfge.gpu_count = len(cfge.gpu_device)
506
+
507
+ ##########################################
508
+ # align batch size and num worker config #
509
+ ##########################################
510
+
511
+ gpu_n = cfge.gpu_count * cfge.nodes
512
+ def align_batch_size(bs, bs_per_gpu):
513
+ assert (bs is not None) or (bs_per_gpu is not None)
514
+ bs = bs_per_gpu * gpu_n if bs is None else bs
515
+ bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu
516
+ assert (bs == bs_per_gpu * gpu_n)
517
+ return bs, bs_per_gpu
518
+
519
+ if istrain:
520
+ cfgt.batch_size, cfgt.batch_size_per_gpu = \
521
+ align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu)
522
+ cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \
523
+ align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu)
524
+ if haseval:
525
+ cfgv.batch_size, cfgv.batch_size_per_gpu = \
526
+ align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu)
527
+ cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \
528
+ align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu)
529
+
530
+ ##################
531
+ # create log dir #
532
+ ##################
533
+
534
+ if istrain:
535
+ if not isdebug:
536
+ sig = cfgt.get('signature', [])
537
+ version = get_model().get_version(cfgm.type)
538
+ sig = sig + ['v{}'.format(version), 's{}'.format(cfge.rnd_seed)]
539
+ else:
540
+ sig = ['debug']
541
+
542
+ log_dir = [
543
+ cfge.log_root_dir,
544
+ '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol),
545
+ '_'.join([str(cfge.experiment_id)] + sig)
546
+ ]
547
+ log_dir = osp.join(*log_dir)
548
+ log_file = osp.join(log_dir, 'train.log')
549
+ if not osp.exists(log_file):
550
+ os.makedirs(osp.dirname(log_file))
551
+ cfgt.log_dir = log_dir
552
+ cfgt.log_file = log_file
553
+
554
+ if haseval:
555
+ cfgv.log_dir = log_dir
556
+ cfgv.log_file = log_file
557
+ else:
558
+ model_symbol = cfgm.symbol
559
+ if cfgv.get('dataset', None) is None:
560
+ dataset_symbol = 'nodataset'
561
+ else:
562
+ dataset_symbol = cfgv.dataset.symbol
563
+
564
+ log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol))
565
+ exp_dir = search_experiment_folder(log_dir, cfge.experiment_id)
566
+ if exp_dir is None:
567
+ if not isdebug:
568
+ sig = cfgv.get('signature', []) + ['evalonly']
569
+ else:
570
+ sig = ['debug']
571
+ exp_dir = '_'.join([str(cfge.experiment_id)] + sig)
572
+
573
+ eval_subdir = cfgv.get('eval_subdir', None)
574
+ # override subdir in debug mode (if eval_subdir is set)
575
+ eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir
576
+
577
+ if eval_subdir is not None:
578
+ log_dir = osp.join(log_dir, exp_dir, eval_subdir)
579
+ else:
580
+ log_dir = osp.join(log_dir, exp_dir)
581
+
582
+ disable_log_override = cfgv.get('disable_log_override', False)
583
+ if osp.isdir(log_dir):
584
+ if disable_log_override:
585
+ assert False, 'Overriding an existing log_dir is disabled at [{}]'.format(log_dir)
586
+ else:
587
+ os.makedirs(log_dir)
588
+
589
+ log_file = osp.join(log_dir, 'eval.log')
590
+ cfgv.log_dir = log_dir
591
+ cfgv.log_file = log_file
592
+
593
+ ######################
594
+ # print and save cfg #
595
+ ######################
596
+
597
+ pprint.pprint(cfg)
598
+ with open(log_file, 'w') as f:
599
+ pprint.pprint(cfg, f)
600
+ with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
601
+ yaml.dump(edict_2_dict(cfg), f)
602
+
603
+ #############
604
+ # save code #
605
+ #############
606
+
607
+ save_code = False
608
+ if istrain:
609
+ save_code = cfgt.get('save_code', False)
610
+ elif haseval:
611
+ save_code = cfgv.get('save_code', False)
612
+
613
+ if save_code:
614
+ codedir = osp.join(log_dir, 'code')
615
+ if osp.exists(codedir):
616
+ shutil.rmtree(codedir)
617
+ for d in ['configs', 'lib']:
618
+ fromcodedir = d
619
+ tocodedir = osp.join(codedir, d)
620
+ shutil.copytree(
621
+ fromcodedir, tocodedir,
622
+ ignore=shutil.ignore_patterns(
623
+ '*__pycache__*', '*build*'))
624
+ for codei in os.listdir('.'):
625
+ if osp.splitext(codei)[1] == '.py':
626
+ shutil.copy(codei, codedir)
627
+
628
+ #######################
629
+ # set matplotlib mode #
630
+ #######################
631
+
632
+ if 'matplotlib_mode' in cfge:
633
+ try:
634
+ matplotlib.use(cfge.matplotlib_mode)
635
+ except:
636
+ print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode))
637
+
638
+ return cfg
639
+
640
+ def edict_2_dict(x):
641
+ if isinstance(x, dict):
642
+ xnew = {}
643
+ for k in x:
644
+ xnew[k] = edict_2_dict(x[k])
645
+ return xnew
646
+ elif isinstance(x, list):
647
+ xnew = []
648
+ for i in range(len(x)):
649
+ xnew.append( edict_2_dict(x[i]) )
650
+ return xnew
651
+ else:
652
+ return x
653
+
654
+ def search_experiment_folder(root, exid):
655
+ target = None
656
+ for fi in os.listdir(root):
657
+ if not osp.isdir(osp.join(root, fi)):
658
+ continue
659
+ if int(fi.split('_')[0]) == exid:
660
+ if target is not None:
661
+ return None # duplicated
662
+ elif target is None:
663
+ target = fi
664
+ return target
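For orientation, a minimal sketch of how the config machinery above is meant to be driven. The YAML files it touches are assumed to exist under `configs/` (this Space only ships the `configs/model` ones), and the names used below are hypothetical:

```python
# Illustrative only; 'laion2b' and 'vd_dc_train' are placeholder names.
from lib.cfg_helper import dataset_cfg_bank, load_cfg_yaml

cfg_data = dataset_cfg_bank()('laion2b')   # reads configs/dataset/laion2b.yaml, applies
                                           # super_cfg/delete handling, returns a deep copy
cfg_exp = load_cfg_yaml('vd_dc_train')     # resolves configs/experiment/vd_dc_train.yaml
                                           # and runs cfg_solve twice on it
```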
lib/cfg_holder.py ADDED
@@ -0,0 +1,28 @@
1
+ import copy
2
+
3
+ def singleton(class_):
4
+ instances = {}
5
+ def getinstance(*args, **kwargs):
6
+ if class_ not in instances:
7
+ instances[class_] = class_(*args, **kwargs)
8
+ return instances[class_]
9
+ return getinstance
10
+
11
+ ##############
12
+ # cfg_holder #
13
+ ##############
14
+
15
+ @singleton
16
+ class cfg_unique_holder(object):
17
+ def __init__(self):
18
+ self.cfg = None
19
+ # this is used to track the main code files.
20
+ self.code = set()
21
+ def save_cfg(self, cfg):
22
+ self.cfg = copy.deepcopy(cfg)
23
+ def add_code(self, code):
24
+ """
25
+ A new main code is reached and
26
+ its name is added.
27
+ """
28
+ self.code.add(code)
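A quick sketch of how the singleton holder above behaves; nothing here is part of the commit:

```python
from lib.cfg_holder import cfg_unique_holder

# The @singleton decorator makes every call return the same instance.
assert cfg_unique_holder() is cfg_unique_holder()

cfg_unique_holder().save_cfg({'env': {'debug': True}})  # stores a deep copy of the cfg
cfg_unique_holder().add_code('main_train')              # records a main-code name
print(cfg_unique_holder().cfg, cfg_unique_holder().code)
```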
lib/log_service.py ADDED
@@ -0,0 +1,166 @@
1
+ import timeit
2
+ import numpy as np
3
+ import os
4
+ import os.path as osp
5
+ import shutil
6
+ import copy
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.distributed as dist
10
+ from .cfg_holder import cfg_unique_holder as cfguh
11
+ from . import sync
12
+
13
+ print_console_local_rank0_only = True
14
+
15
+ def print_log(*console_info):
16
+ local_rank = sync.get_rank('local')
17
+ if print_console_local_rank0_only and (local_rank!=0):
18
+ return
19
+ console_info = [str(i) for i in console_info]
20
+ console_info = ' '.join(console_info)
21
+ print(console_info)
22
+
23
+ if local_rank!=0:
24
+ return
25
+
26
+ log_file = None
27
+ try:
28
+ log_file = cfguh().cfg.train.log_file
29
+ except:
30
+ try:
31
+ log_file = cfguh().cfg.eval.log_file
32
+ except:
33
+ return
34
+ if log_file is not None:
35
+ with open(log_file, 'a') as f:
36
+ f.write(console_info + '\n')
37
+
38
+ class distributed_log_manager(object):
39
+ def __init__(self):
40
+ self.sum = {}
41
+ self.cnt = {}
42
+ self.time_check = timeit.default_timer()
43
+
44
+ cfgt = cfguh().cfg.train
45
+ use_tensorboard = getattr(cfgt, 'log_tensorboard', False)
46
+
47
+ self.ddp = sync.is_ddp()
48
+ self.rank = sync.get_rank('local')
49
+ self.world_size = sync.get_world_size('local')
50
+
51
+ self.tb = None
52
+ if use_tensorboard and (self.rank==0):
53
+ import tensorboardX
54
+ monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
55
+ self.tb = tensorboardX.SummaryWriter(osp.join(monitoring_dir))
56
+
57
+ def accumulate(self, n, **data):
58
+ if n < 0:
59
+ raise ValueError
60
+
61
+ for itemn, di in data.items():
62
+ if itemn in self.sum:
63
+ self.sum[itemn] += di * n
64
+ self.cnt[itemn] += n
65
+ else:
66
+ self.sum[itemn] = di * n
67
+ self.cnt[itemn] = n
68
+
69
+ def get_mean_value_dict(self):
70
+ value_gather = [
71
+ self.sum[itemn]/self.cnt[itemn] \
72
+ for itemn in sorted(self.sum.keys()) ]
73
+
74
+ value_gather_tensor = torch.FloatTensor(value_gather).to(self.rank)
75
+ if self.ddp:
76
+ dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
77
+ value_gather_tensor /= self.world_size
78
+
79
+ mean = {}
80
+ for idx, itemn in enumerate(sorted(self.sum.keys())):
81
+ mean[itemn] = value_gather_tensor[idx].item()
82
+ return mean
83
+
84
+ def tensorboard_log(self, step, data, mode='train', **extra):
85
+ if self.tb is None:
86
+ return
87
+ if mode == 'train':
88
+ self.tb.add_scalar('other/epochn', extra['epochn'], step)
89
+ if 'lr' in extra:
90
+ self.tb.add_scalar('other/lr', extra['lr'], step)
91
+ for itemn, di in data.items():
92
+ if itemn.find('loss') == 0:
93
+ self.tb.add_scalar('loss/'+itemn, di, step)
94
+ elif itemn == 'Loss':
95
+ self.tb.add_scalar('Loss', di, step)
96
+ else:
97
+ self.tb.add_scalar('other/'+itemn, di, step)
98
+ elif mode == 'eval':
99
+ if isinstance(data, dict):
100
+ for itemn, di in data.items():
101
+ self.tb.add_scalar('eval/'+itemn, di, step)
102
+ else:
103
+ self.tb.add_scalar('eval', data, step)
104
+ return
105
+
106
+ def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
107
+ console_info = [
108
+ 'Iter:{}'.format(itern),
109
+ 'Epoch:{}'.format(epochn),
110
+ 'Sample:{}'.format(samplen),]
111
+
112
+ if lr is not None:
113
+ console_info += ['LR:{:.4E}'.format(lr)]
114
+
115
+ mean = self.get_mean_value_dict()
116
+
117
+ tbstep = itern if tbstep is None else tbstep
118
+ self.tensorboard_log(
119
+ tbstep, mean, mode='train',
120
+ itern=itern, epochn=epochn, lr=lr)
121
+
122
+ loss = mean.pop('Loss')
123
+ mean_info = ['Loss:{:.4f}'.format(loss)] + [
124
+ '{}:{:.4f}'.format(itemn, mean[itemn]) \
125
+ for itemn in sorted(mean.keys()) \
126
+ if itemn.find('loss') == 0
127
+ ]
128
+ console_info += mean_info
129
+ console_info.append('Time:{:.2f}s'.format(
130
+ timeit.default_timer() - self.time_check))
131
+ return ' , '.join(console_info)
132
+
133
+ def clear(self):
134
+ self.sum = {}
135
+ self.cnt = {}
136
+ self.time_check = timeit.default_timer()
137
+
138
+ def tensorboard_close(self):
139
+ if self.tb is not None:
140
+ self.tb.close()
141
+
142
+ # ----- also include some small utils -----
143
+
144
+ def torch_to_numpy(*argv):
145
+ if len(argv) > 1:
146
+ data = list(argv)
147
+ else:
148
+ data = argv[0]
149
+
150
+ if isinstance(data, torch.Tensor):
151
+ return data.to('cpu').detach().numpy()
152
+
153
+ elif isinstance(data, (list, tuple)):
154
+ out = []
155
+ for di in data:
156
+ out.append(torch_to_numpy(di))
157
+ return out
158
+
159
+ elif isinstance(data, dict):
160
+ out = {}
161
+ for ni, di in data.items():
162
+ out[ni] = torch_to_numpy(di)
163
+ return out
164
+
165
+ else:
166
+ return data
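A small sketch of the `torch_to_numpy` helper above: it recurses through lists, tuples, and dicts, converts tensors, and leaves everything else untouched (the sample data is made up):

```python
import torch
from lib.log_service import torch_to_numpy

batch = {'loss': torch.tensor(0.25), 'feats': [torch.zeros(2, 3)], 'step': 10}
out = torch_to_numpy(batch)
# out['loss'] and out['feats'][0] are now numpy arrays; out['step'] stays an int
```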
lib/model_zoo/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .common.get_model import get_model
2
+ from .common.get_optimizer import get_optimizer
3
+ from .common.get_scheduler import get_scheduler
4
+ from .common.utils import get_unit
lib/model_zoo/attention.py ADDED
@@ -0,0 +1,435 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from .diffusion_utils import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return {el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = nn.Sequential(
53
+ nn.Linear(dim, inner_dim),
54
+ nn.GELU()
55
+ ) if not glu else GEGLU(dim, inner_dim)
56
+
57
+ self.net = nn.Sequential(
58
+ project_in,
59
+ nn.Dropout(dropout),
60
+ nn.Linear(inner_dim, dim_out)
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.net(x)
65
+
66
+
67
+ def zero_module(module):
68
+ """
69
+ Zero out the parameters of a module and return it.
70
+ """
71
+ for p in module.parameters():
72
+ p.detach().zero_()
73
+ return module
74
+
75
+
76
+ def Normalize(in_channels):
77
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+
79
+
80
+ class LinearAttention(nn.Module):
81
+ def __init__(self, dim, heads=4, dim_head=32):
82
+ super().__init__()
83
+ self.heads = heads
84
+ hidden_dim = dim_head * heads
85
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87
+
88
+ def forward(self, x):
89
+ b, c, h, w = x.shape
90
+ qkv = self.to_qkv(x)
91
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92
+ k = k.softmax(dim=-1)
93
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
94
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
95
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96
+ return self.to_out(out)
97
+
98
+
99
+ class SpatialSelfAttention(nn.Module):
100
+ def __init__(self, in_channels):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+
104
+ self.norm = Normalize(in_channels)
105
+ self.q = torch.nn.Conv2d(in_channels,
106
+ in_channels,
107
+ kernel_size=1,
108
+ stride=1,
109
+ padding=0)
110
+ self.k = torch.nn.Conv2d(in_channels,
111
+ in_channels,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0)
115
+ self.v = torch.nn.Conv2d(in_channels,
116
+ in_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+ self.proj_out = torch.nn.Conv2d(in_channels,
121
+ in_channels,
122
+ kernel_size=1,
123
+ stride=1,
124
+ padding=0)
125
+
126
+ def forward(self, x):
127
+ h_ = x
128
+ h_ = self.norm(h_)
129
+ q = self.q(h_)
130
+ k = self.k(h_)
131
+ v = self.v(h_)
132
+
133
+ # compute attention
134
+ b,c,h,w = q.shape
135
+ q = rearrange(q, 'b c h w -> b (h w) c')
136
+ k = rearrange(k, 'b c h w -> b c (h w)')
137
+ w_ = torch.einsum('bij,bjk->bik', q, k)
138
+
139
+ w_ = w_ * (int(c)**(-0.5))
140
+ w_ = torch.nn.functional.softmax(w_, dim=2)
141
+
142
+ # attend to values
143
+ v = rearrange(v, 'b c h w -> b c (h w)')
144
+ w_ = rearrange(w_, 'b i j -> b j i')
145
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
146
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147
+ h_ = self.proj_out(h_)
148
+
149
+ return x+h_
150
+
151
+
152
+ class CrossAttention(nn.Module):
153
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154
+ super().__init__()
155
+ inner_dim = dim_head * heads
156
+ context_dim = default(context_dim, query_dim)
157
+
158
+ self.scale = dim_head ** -0.5
159
+ self.heads = heads
160
+
161
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164
+
165
+ self.to_out = nn.Sequential(
166
+ nn.Linear(inner_dim, query_dim),
167
+ nn.Dropout(dropout)
168
+ )
169
+
170
+ def forward(self, x, context=None, mask=None):
171
+ h = self.heads
172
+
173
+ q = self.to_q(x)
174
+ context = default(context, x)
175
+ k = self.to_k(context)
176
+ v = self.to_v(context)
177
+
178
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179
+
180
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181
+
182
+ if exists(mask):
183
+ mask = rearrange(mask, 'b ... -> b (...)')
184
+ max_neg_value = -torch.finfo(sim.dtype).max
185
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
186
+ sim.masked_fill_(~mask, max_neg_value)
187
+
188
+ # attention, what we cannot get enough of
189
+ attn = sim.softmax(dim=-1)
190
+
191
+ out = einsum('b i j, b j d -> b i d', attn, v)
192
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193
+ return self.to_out(out)
194
+
195
+
196
+ class BasicTransformerBlock(nn.Module):
197
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
198
+ disable_self_attn=False):
199
+ super().__init__()
200
+ self.disable_self_attn = disable_self_attn
201
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
202
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
203
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
204
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
205
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
206
+ self.norm1 = nn.LayerNorm(dim)
207
+ self.norm2 = nn.LayerNorm(dim)
208
+ self.norm3 = nn.LayerNorm(dim)
209
+ self.checkpoint = checkpoint
210
+
211
+ def forward(self, x, context=None):
212
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
213
+
214
+ def _forward(self, x, context=None):
215
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
216
+ x = self.attn2(self.norm2(x), context=context) + x
217
+ x = self.ff(self.norm3(x)) + x
218
+ return x
219
+
220
+
221
+ class SpatialTransformer(nn.Module):
222
+ """
223
+ Transformer block for image-like data.
224
+ First, project the input (aka embedding)
225
+ and reshape to b, t, d.
226
+ Then apply standard transformer action.
227
+ Finally, reshape to image
228
+ """
229
+ def __init__(self, in_channels, n_heads, d_head,
230
+ depth=1, dropout=0., context_dim=None,
231
+ disable_self_attn=False):
232
+ super().__init__()
233
+ self.in_channels = in_channels
234
+ inner_dim = n_heads * d_head
235
+ self.norm = Normalize(in_channels)
236
+
237
+ self.proj_in = nn.Conv2d(in_channels,
238
+ inner_dim,
239
+ kernel_size=1,
240
+ stride=1,
241
+ padding=0)
242
+
243
+ self.transformer_blocks = nn.ModuleList(
244
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
245
+ disable_self_attn=disable_self_attn)
246
+ for d in range(depth)]
247
+ )
248
+
249
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
250
+ in_channels,
251
+ kernel_size=1,
252
+ stride=1,
253
+ padding=0))
254
+
255
+ def forward(self, x, context=None):
256
+ # note: if no context is given, cross-attention defaults to self-attention
257
+ b, c, h, w = x.shape
258
+ x_in = x
259
+ x = self.norm(x)
260
+ x = self.proj_in(x)
261
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
262
+ for block in self.transformer_blocks:
263
+ x = block(x, context=context)
264
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
265
+ x = self.proj_out(x)
266
+ return x + x_in
267
+
268
+
269
+ ##########################
270
+ # transformer no context #
271
+ ##########################
272
+
273
+ class BasicTransformerBlockNoContext(nn.Module):
274
+ def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):
275
+ super().__init__()
276
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
277
+ dropout=dropout, context_dim=None)
278
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
279
+ self.attn2 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
280
+ dropout=dropout, context_dim=None)
281
+ self.norm1 = nn.LayerNorm(dim)
282
+ self.norm2 = nn.LayerNorm(dim)
283
+ self.norm3 = nn.LayerNorm(dim)
284
+ self.checkpoint = checkpoint
285
+
286
+ def forward(self, x):
287
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
288
+
289
+ def _forward(self, x):
290
+ x = self.attn1(self.norm1(x)) + x
291
+ x = self.attn2(self.norm2(x)) + x
292
+ x = self.ff(self.norm3(x)) + x
293
+ return x
294
+
295
+ class SpatialTransformerNoContext(nn.Module):
296
+ """
297
+ Transformer block for image-like data.
298
+ First, project the input (aka embedding)
299
+ and reshape to b, t, d.
300
+ Then apply standard transformer action.
301
+ Finally, reshape to image
302
+ """
303
+ def __init__(self, in_channels, n_heads, d_head,
304
+ depth=1, dropout=0.,):
305
+ super().__init__()
306
+ self.in_channels = in_channels
307
+ inner_dim = n_heads * d_head
308
+ self.norm = Normalize(in_channels)
309
+
310
+ self.proj_in = nn.Conv2d(in_channels,
311
+ inner_dim,
312
+ kernel_size=1,
313
+ stride=1,
314
+ padding=0)
315
+
316
+ self.transformer_blocks = nn.ModuleList(
317
+ [BasicTransformerBlockNoContext(inner_dim, n_heads, d_head, dropout=dropout)
318
+ for d in range(depth)]
319
+ )
320
+
321
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
322
+ in_channels,
323
+ kernel_size=1,
324
+ stride=1,
325
+ padding=0))
326
+
327
+ def forward(self, x):
328
+ # note: if no context is given, cross-attention defaults to self-attention
329
+ b, c, h, w = x.shape
330
+ x_in = x
331
+ x = self.norm(x)
332
+ x = self.proj_in(x)
333
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
334
+ for block in self.transformer_blocks:
335
+ x = block(x)
336
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
337
+ x = self.proj_out(x)
338
+ return x + x_in
339
+
340
+
341
+ #######################################
342
+ # Spatial Transformer with Two Branch #
343
+ #######################################
344
+
345
+ class DualSpatialTransformer(nn.Module):
346
+ def __init__(self, in_channels, n_heads, d_head,
347
+ depth=1, dropout=0., context_dim=None,
348
+ disable_self_attn=False):
349
+ super().__init__()
350
+ self.in_channels = in_channels
351
+ inner_dim = n_heads * d_head
352
+
353
+ # First crossattn
354
+ self.norm_0 = Normalize(in_channels)
355
+ self.proj_in_0 = nn.Conv2d(
356
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
357
+ self.transformer_blocks_0 = nn.ModuleList(
358
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
359
+ disable_self_attn=disable_self_attn)
360
+ for d in range(depth)]
361
+ )
362
+ self.proj_out_0 = zero_module(nn.Conv2d(
363
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
364
+
365
+ # Second crossattn
366
+ self.norm_1 = Normalize(in_channels)
367
+ self.proj_in_1 = nn.Conv2d(
368
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
369
+ self.transformer_blocks_1 = nn.ModuleList(
370
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
371
+ disable_self_attn=disable_self_attn)
372
+ for d in range(depth)]
373
+ )
374
+ self.proj_out_1 = zero_module(nn.Conv2d(
375
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
376
+
377
+ def forward(self, x, context=None, which=None):
378
+ # note: if no context is given, cross-attention defaults to self-attention
379
+ b, c, h, w = x.shape
380
+ x_in = x
381
+ if which==0:
382
+ norm, proj_in, blocks, proj_out = \
383
+ self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
384
+ elif which==1:
385
+ norm, proj_in, blocks, proj_out = \
386
+ self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
387
+ else:
388
+ # assert False, 'DualSpatialTransformer forward with a invalid which branch!'
389
+ # import numpy.random as npr
390
+ # rwhich = 0 if npr.rand() < which else 1
391
+ # context = context[rwhich]
392
+ # if rwhich==0:
393
+ # norm, proj_in, blocks, proj_out = \
394
+ # self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
395
+ # elif rwhich==1:
396
+ # norm, proj_in, blocks, proj_out = \
397
+ # self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
398
+
399
+ # import numpy.random as npr
400
+ # rwhich = 0 if npr.rand() < 0.33 else 1
401
+ # if rwhich==0:
402
+ # context = context[rwhich]
403
+ # norm, proj_in, blocks, proj_out = \
404
+ # self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
405
+ # else:
406
+
407
+ norm, proj_in, blocks, proj_out = \
408
+ self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
409
+ x0 = norm(x)
410
+ x0 = proj_in(x0)
411
+ x0 = rearrange(x0, 'b c h w -> b (h w) c').contiguous()
412
+ for block in blocks:
413
+ x0 = block(x0, context=context[0])
414
+ x0 = rearrange(x0, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
415
+ x0 = proj_out(x0)
416
+
417
+ norm, proj_in, blocks, proj_out = \
418
+ self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
419
+ x1 = norm(x)
420
+ x1 = proj_in(x1)
421
+ x1 = rearrange(x1, 'b c h w -> b (h w) c').contiguous()
422
+ for block in blocks:
423
+ x1 = block(x1, context=context[1])
424
+ x1 = rearrange(x1, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
425
+ x1 = proj_out(x1)
426
+ return x0*which + x1*(1-which) + x_in
427
+
428
+ x = norm(x)
429
+ x = proj_in(x)
430
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
431
+ for block in blocks:
432
+ x = block(x, context=context)
433
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
434
+ x = proj_out(x)
435
+ return x + x_in
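To make the tensor shapes concrete, a sanity-check sketch for the attention blocks above; the dimensions are arbitrary examples rather than the values used by the Versatile Diffusion configs:

```python
import torch
from lib.model_zoo.attention import CrossAttention, SpatialTransformer

# Token-level cross-attention: queries come from x, keys/values from context.
attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 64, 320)         # (batch, query tokens, query_dim)
ctx = torch.randn(2, 77, 768)       # (batch, context tokens, context_dim)
print(attn(x, context=ctx).shape)   # torch.Size([2, 64, 320])

# Image-like data: project, flatten to tokens, attend, reshape back, add the residual.
st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768)
fmap = torch.randn(2, 320, 16, 16)
print(st(fmap, context=ctx).shape)  # torch.Size([2, 320, 16, 16])
```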
lib/model_zoo/autoencoder.py ADDED
@@ -0,0 +1,428 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+ from lib.model_zoo.common.get_model import get_model, register
6
+
7
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
8
+
9
+ from .diffusion_modules import Encoder, Decoder
10
+ from .distributions import DiagonalGaussianDistribution
11
+
12
+
13
+ class VQModel(nn.Module):
14
+ def __init__(self,
15
+ ddconfig,
16
+ lossconfig,
17
+ n_embed,
18
+ embed_dim,
19
+ ckpt_path=None,
20
+ ignore_keys=[],
21
+ image_key="image",
22
+ colorize_nlabels=None,
23
+ monitor=None,
24
+ batch_resize_range=None,
25
+ scheduler_config=None,
26
+ lr_g_factor=1.0,
27
+ remap=None,
28
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
29
+ use_ema=False
30
+ ):
31
+ super().__init__()
32
+ self.embed_dim = embed_dim
33
+ self.n_embed = n_embed
34
+ self.image_key = image_key
35
+ self.encoder = Encoder(**ddconfig)
36
+ self.decoder = Decoder(**ddconfig)
37
+ self.loss = instantiate_from_config(lossconfig)
38
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
39
+ remap=remap,
40
+ sane_index_shape=sane_index_shape)
41
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
42
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
43
+ if colorize_nlabels is not None:
44
+ assert type(colorize_nlabels)==int
45
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
46
+ if monitor is not None:
47
+ self.monitor = monitor
48
+ self.batch_resize_range = batch_resize_range
49
+ if self.batch_resize_range is not None:
50
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
51
+
52
+ self.use_ema = use_ema
53
+ if self.use_ema:
54
+ self.model_ema = LitEma(self)
55
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
56
+
57
+ if ckpt_path is not None:
58
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
59
+ self.scheduler_config = scheduler_config
60
+ self.lr_g_factor = lr_g_factor
61
+
62
+ @contextmanager
63
+ def ema_scope(self, context=None):
64
+ if self.use_ema:
65
+ self.model_ema.store(self.parameters())
66
+ self.model_ema.copy_to(self)
67
+ if context is not None:
68
+ print(f"{context}: Switched to EMA weights")
69
+ try:
70
+ yield None
71
+ finally:
72
+ if self.use_ema:
73
+ self.model_ema.restore(self.parameters())
74
+ if context is not None:
75
+ print(f"{context}: Restored training weights")
76
+
77
+ def init_from_ckpt(self, path, ignore_keys=list()):
78
+ sd = torch.load(path, map_location="cpu")["state_dict"]
79
+ keys = list(sd.keys())
80
+ for k in keys:
81
+ for ik in ignore_keys:
82
+ if k.startswith(ik):
83
+ print("Deleting key {} from state_dict.".format(k))
84
+ del sd[k]
85
+ missing, unexpected = self.load_state_dict(sd, strict=False)
86
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
87
+ if len(missing) > 0:
88
+ print(f"Missing Keys: {missing}")
89
+ print(f"Unexpected Keys: {unexpected}")
90
+
91
+ def on_train_batch_end(self, *args, **kwargs):
92
+ if self.use_ema:
93
+ self.model_ema(self)
94
+
95
+ def encode(self, x):
96
+ h = self.encoder(x)
97
+ h = self.quant_conv(h)
98
+ quant, emb_loss, info = self.quantize(h)
99
+ return quant, emb_loss, info
100
+
101
+ def encode_to_prequant(self, x):
102
+ h = self.encoder(x)
103
+ h = self.quant_conv(h)
104
+ return h
105
+
106
+ def decode(self, quant):
107
+ quant = self.post_quant_conv(quant)
108
+ dec = self.decoder(quant)
109
+ return dec
110
+
111
+ def decode_code(self, code_b):
112
+ quant_b = self.quantize.embed_code(code_b)
113
+ dec = self.decode(quant_b)
114
+ return dec
115
+
116
+ def forward(self, input, return_pred_indices=False):
117
+ quant, diff, (_,_,ind) = self.encode(input)
118
+ dec = self.decode(quant)
119
+ if return_pred_indices:
120
+ return dec, diff, ind
121
+ return dec, diff
122
+
123
+ def get_input(self, batch, k):
124
+ x = batch[k]
125
+ if len(x.shape) == 3:
126
+ x = x[..., None]
127
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
128
+ if self.batch_resize_range is not None:
129
+ lower_size = self.batch_resize_range[0]
130
+ upper_size = self.batch_resize_range[1]
131
+ if self.global_step <= 4:
132
+ # do the first few batches with max size to avoid later oom
133
+ new_resize = upper_size
134
+ else:
135
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
136
+ if new_resize != x.shape[2]:
137
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
138
+ x = x.detach()
139
+ return x
140
+
141
+ def training_step(self, batch, batch_idx, optimizer_idx):
142
+ # https://github.com/pytorch/pytorch/issues/37142
143
+ # try not to fool the heuristics
144
+ x = self.get_input(batch, self.image_key)
145
+ xrec, qloss, ind = self(x, return_pred_indices=True)
146
+
147
+ if optimizer_idx == 0:
148
+ # autoencode
149
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
150
+ last_layer=self.get_last_layer(), split="train",
151
+ predicted_indices=ind)
152
+
153
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
154
+ return aeloss
155
+
156
+ if optimizer_idx == 1:
157
+ # discriminator
158
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
159
+ last_layer=self.get_last_layer(), split="train")
160
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
161
+ return discloss
162
+
163
+ def validation_step(self, batch, batch_idx):
164
+ log_dict = self._validation_step(batch, batch_idx)
165
+ with self.ema_scope():
166
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
167
+ return log_dict
168
+
169
+ def _validation_step(self, batch, batch_idx, suffix=""):
170
+ x = self.get_input(batch, self.image_key)
171
+ xrec, qloss, ind = self(x, return_pred_indices=True)
172
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
173
+ self.global_step,
174
+ last_layer=self.get_last_layer(),
175
+ split="val"+suffix,
176
+ predicted_indices=ind
177
+ )
178
+
179
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
180
+ self.global_step,
181
+ last_layer=self.get_last_layer(),
182
+ split="val"+suffix,
183
+ predicted_indices=ind
184
+ )
185
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
186
+ self.log(f"val{suffix}/rec_loss", rec_loss,
187
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
188
+ self.log(f"val{suffix}/aeloss", aeloss,
189
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
190
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
191
+ del log_dict_ae[f"val{suffix}/rec_loss"]
192
+ self.log_dict(log_dict_ae)
193
+ self.log_dict(log_dict_disc)
194
+ return self.log_dict
195
+
196
+ def configure_optimizers(self):
197
+ lr_d = self.learning_rate
198
+ lr_g = self.lr_g_factor*self.learning_rate
199
+ print("lr_d", lr_d)
200
+ print("lr_g", lr_g)
201
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
202
+ list(self.decoder.parameters())+
203
+ list(self.quantize.parameters())+
204
+ list(self.quant_conv.parameters())+
205
+ list(self.post_quant_conv.parameters()),
206
+ lr=lr_g, betas=(0.5, 0.9))
207
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
208
+ lr=lr_d, betas=(0.5, 0.9))
209
+
210
+ if self.scheduler_config is not None:
211
+ scheduler = instantiate_from_config(self.scheduler_config)
212
+
213
+ print("Setting up LambdaLR scheduler...")
214
+ scheduler = [
215
+ {
216
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
217
+ 'interval': 'step',
218
+ 'frequency': 1
219
+ },
220
+ {
221
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
222
+ 'interval': 'step',
223
+ 'frequency': 1
224
+ },
225
+ ]
226
+ return [opt_ae, opt_disc], scheduler
227
+ return [opt_ae, opt_disc], []
228
+
229
+ def get_last_layer(self):
230
+ return self.decoder.conv_out.weight
231
+
232
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
233
+ log = dict()
234
+ x = self.get_input(batch, self.image_key)
235
+ x = x.to(self.device)
236
+ if only_inputs:
237
+ log["inputs"] = x
238
+ return log
239
+ xrec, _ = self(x)
240
+ if x.shape[1] > 3:
241
+ # colorize with random projection
242
+ assert xrec.shape[1] > 3
243
+ x = self.to_rgb(x)
244
+ xrec = self.to_rgb(xrec)
245
+ log["inputs"] = x
246
+ log["reconstructions"] = xrec
247
+ if plot_ema:
248
+ with self.ema_scope():
249
+ xrec_ema, _ = self(x)
250
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
251
+ log["reconstructions_ema"] = xrec_ema
252
+ return log
253
+
254
+ def to_rgb(self, x):
255
+ assert self.image_key == "segmentation"
256
+ if not hasattr(self, "colorize"):
257
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
258
+ x = F.conv2d(x, weight=self.colorize)
259
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
260
+ return x
261
+
262
+
263
+ class VQModelInterface(VQModel):
264
+ def __init__(self, embed_dim, *args, **kwargs):
265
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
266
+ self.embed_dim = embed_dim
267
+
268
+ def encode(self, x):
269
+ h = self.encoder(x)
270
+ h = self.quant_conv(h)
271
+ return h
272
+
273
+ def decode(self, h, force_not_quantize=False):
274
+ # also go through quantization layer
275
+ if not force_not_quantize:
276
+ quant, emb_loss, info = self.quantize(h)
277
+ else:
278
+ quant = h
279
+ quant = self.post_quant_conv(quant)
280
+ dec = self.decoder(quant)
281
+ return dec
282
+
283
+
284
+ @register('autoencoderkl')
285
+ class AutoencoderKL(nn.Module):
286
+ def __init__(self,
287
+ ddconfig,
288
+ lossconfig,
289
+ embed_dim,
290
+ ckpt_path=None,
291
+ ignore_keys=[],
292
+ image_key="image",
293
+ colorize_nlabels=None,
294
+ monitor=None,):
295
+ super().__init__()
296
+ self.image_key = image_key
297
+ self.encoder = Encoder(**ddconfig)
298
+ self.decoder = Decoder(**ddconfig)
299
+ assert ddconfig["double_z"]
300
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
301
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
302
+ self.embed_dim = embed_dim
303
+ if colorize_nlabels is not None:
304
+ assert type(colorize_nlabels)==int
305
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
306
+ if monitor is not None:
307
+ self.monitor = monitor
308
+
309
+ def encode(self, x):
310
+ h = self.encoder(x)
311
+ moments = self.quant_conv(h)
312
+ posterior = DiagonalGaussianDistribution(moments)
313
+ return posterior
314
+
315
+ def decode(self, z):
316
+ z = self.post_quant_conv(z)
317
+ dec = self.decoder(z)
318
+ return dec
319
+
320
+ def forward(self, input, sample_posterior=True):
321
+ posterior = self.encode(input)
322
+ if sample_posterior:
323
+ z = posterior.sample()
324
+ else:
325
+ z = posterior.mode()
326
+ dec = self.decode(z)
327
+ return dec, posterior
328
+
329
+ def get_input(self, batch, k):
330
+ x = batch[k]
331
+ if len(x.shape) == 3:
332
+ x = x[..., None]
333
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
334
+ return x
335
+
336
+ def training_step(self, batch, batch_idx, optimizer_idx):
337
+ inputs = self.get_input(batch, self.image_key)
338
+ reconstructions, posterior = self(inputs)
339
+
340
+ if optimizer_idx == 0:
341
+ # train encoder+decoder+logvar
342
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
343
+ last_layer=self.get_last_layer(), split="train")
344
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
345
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
346
+ return aeloss
347
+
348
+ if optimizer_idx == 1:
349
+ # train the discriminator
350
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
351
+ last_layer=self.get_last_layer(), split="train")
352
+
353
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
354
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
355
+ return discloss
356
+
357
+ def validation_step(self, batch, batch_idx):
358
+ inputs = self.get_input(batch, self.image_key)
359
+ reconstructions, posterior = self(inputs)
360
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
361
+ last_layer=self.get_last_layer(), split="val")
362
+
363
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
364
+ last_layer=self.get_last_layer(), split="val")
365
+
366
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
367
+ self.log_dict(log_dict_ae)
368
+ self.log_dict(log_dict_disc)
369
+ return self.log_dict
370
+
371
+ def configure_optimizers(self):
372
+ lr = self.learning_rate
373
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
374
+ list(self.decoder.parameters())+
375
+ list(self.quant_conv.parameters())+
376
+ list(self.post_quant_conv.parameters()),
377
+ lr=lr, betas=(0.5, 0.9))
378
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
379
+ lr=lr, betas=(0.5, 0.9))
380
+ return [opt_ae, opt_disc], []
381
+
382
+ def get_last_layer(self):
383
+ return self.decoder.conv_out.weight
384
+
385
+ @torch.no_grad()
386
+ def log_images(self, batch, only_inputs=False, **kwargs):
387
+ log = dict()
388
+ x = self.get_input(batch, self.image_key)
389
+ x = x.to(self.device)
390
+ if not only_inputs:
391
+ xrec, posterior = self(x)
392
+ if x.shape[1] > 3:
393
+ # colorize with random projection
394
+ assert xrec.shape[1] > 3
395
+ x = self.to_rgb(x)
396
+ xrec = self.to_rgb(xrec)
397
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
398
+ log["reconstructions"] = xrec
399
+ log["inputs"] = x
400
+ return log
401
+
402
+ def to_rgb(self, x):
403
+ assert self.image_key == "segmentation"
404
+ if not hasattr(self, "colorize"):
405
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
406
+ x = F.conv2d(x, weight=self.colorize)
407
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
408
+ return x
409
+
410
+
411
+ class IdentityFirstStage(nn.Module):
412
+ def __init__(self, *args, vq_interface=False, **kwargs):
413
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
414
+ super().__init__()
415
+
416
+ def encode(self, x, *args, **kwargs):
417
+ return x
418
+
419
+ def decode(self, x, *args, **kwargs):
420
+ return x
421
+
422
+ def quantize(self, x, *args, **kwargs):
423
+ if self.vq_interface:
424
+ return x, None, [None, None, None]
425
+ return x
426
+
427
+ def forward(self, x, *args, **kwargs):
428
+ return x
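A rough round-trip sketch for the registered `autoencoderkl` model above. The `ddconfig` values are illustrative assumptions rather than the ones shipped in `configs/model/`, and the `taming` package must be installed because this module imports it at load time:

```python
import torch
from lib.model_zoo.autoencoder import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
                attn_resolutions=[], dropout=0.0)
ae = AutoencoderKL(ddconfig=ddconfig, lossconfig=None, embed_dim=4).eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    posterior = ae.encode(x)   # DiagonalGaussianDistribution over the latent
    z = posterior.sample()     # here (1, 4, 32, 32): spatially downsampled by 8
    rec = ae.decode(z)         # back to (1, 3, 256, 256)
```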
lib/model_zoo/bert.py ADDED
@@ -0,0 +1,142 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+
5
+ # from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
6
+
7
+
8
+ class AbstractEncoder(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def encode(self, *args, **kwargs):
13
+ raise NotImplementedError
14
+
15
+
16
+
17
+ class ClassEmbedder(nn.Module):
18
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
19
+ super().__init__()
20
+ self.key = key
21
+ self.embedding = nn.Embedding(n_classes, embed_dim)
22
+
23
+ def forward(self, batch, key=None):
24
+ if key is None:
25
+ key = self.key
26
+ # this is for use in crossattn
27
+ c = batch[key][:, None]
28
+ c = self.embedding(c)
29
+ return c
30
+
31
+
32
+ class TransformerEmbedder(AbstractEncoder):
33
+ """Some transformer encoder layers"""
34
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77):
35
+ super().__init__()
36
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
37
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
38
+
39
+ def forward(self, tokens):
40
+ z = self.transformer(tokens, return_embeddings=True)
41
+ return z
42
+
43
+ def encode(self, x):
44
+ return self(x)
45
+
46
+
47
+ class BERTTokenizer(AbstractEncoder):
48
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
49
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
50
+ super().__init__()
51
+ from transformers import BertTokenizerFast # TODO: add to requirements
52
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
53
+ self.vq_interface = vq_interface
54
+ self.max_length = max_length
55
+
56
+ def forward(self, text):
57
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
58
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
59
+ tokens = batch_encoding["input_ids"]
60
+ return tokens
61
+
62
+ @torch.no_grad()
63
+ def encode(self, text):
64
+ tokens = self(text)
65
+ if not self.vq_interface:
66
+ return tokens
67
+ return None, None, [None, None, tokens]
68
+
69
+ def decode(self, text):
70
+ return text
71
+
72
+
73
+ class BERTEmbedder(AbstractEncoder):
74
+ """Uses the BERT tokenizer and adds some transformer encoder layers"""
75
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
76
+ ckpt_path=None, ignore_keys=[], device="cuda", use_tokenizer=True, embedding_dropout=0.0):
77
+ super().__init__()
78
+ self.use_tknz_fn = use_tokenizer
79
+ if self.use_tknz_fn:
80
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
81
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
82
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
83
+ emb_dropout=embedding_dropout)
84
+ if ckpt_path is not None:
85
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
86
+
87
+ def init_from_ckpt(self, path, ignore_keys=list()):
88
+ sd = torch.load(path, map_location="cpu")
89
+ keys = list(sd.keys())
90
+ for k in keys:
91
+ for ik in ignore_keys:
92
+ if k.startswith(ik):
93
+ print("Deleting key {} from state_dict.".format(k))
94
+ del sd[k]
95
+ missing, unexpected = self.load_state_dict(sd, strict=False)
96
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
97
+
98
+ def forward(self, text):
99
+ if self.use_tknz_fn:
100
+ tokens = self.tknz_fn(text)
101
+ else:
102
+ tokens = text
103
+ device = self.transformer.token_emb.weight.device # a trick to get device
104
+ tokens = tokens.to(device)
105
+ z = self.transformer(tokens, return_embeddings=True)
106
+ return z
107
+
108
+ def encode(self, text):
109
+ # output of length 77
110
+ return self(text)
111
+
112
+
113
+ class SpatialRescaler(nn.Module):
114
+ def __init__(self,
115
+ n_stages=1,
116
+ method='bilinear',
117
+ multiplier=0.5,
118
+ in_channels=3,
119
+ out_channels=None,
120
+ bias=False):
121
+ super().__init__()
122
+ self.n_stages = n_stages
123
+ assert self.n_stages >= 0
124
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
125
+ self.multiplier = multiplier
126
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
127
+ self.remap_output = out_channels is not None
128
+ if self.remap_output:
129
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
130
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
131
+
132
+ def forward(self,x):
133
+ for stage in range(self.n_stages):
134
+ x = self.interpolator(x, scale_factor=self.multiplier)
135
+
136
+
137
+ if self.remap_output:
138
+ x = self.channel_mapper(x)
139
+ return x
140
+
141
+ def encode(self, x):
142
+ return self(x)
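A brief sketch of the `BERTTokenizer` wrapper above (it downloads `bert-base-uncased` on first use; the prompt is arbitrary). Note that `TransformerEmbedder` and `BERTEmbedder` still depend on the commented-out `TransformerWrapper`/`Encoder` import, so only the tokenizer is exercised here:

```python
from lib.model_zoo.bert import BERTTokenizer

tok = BERTTokenizer(vq_interface=False, max_length=77)
tokens = tok(["a photo of a cat"])  # LongTensor of shape (1, 77), padded to max_length
```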
lib/model_zoo/clip.py ADDED
@@ -0,0 +1,178 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+ from lib.model_zoo.common.get_model import register
6
+
7
+ version = '0'
8
+ symbol = 'clip'
9
+
10
+ class AbstractEncoder(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def encode(self, *args, **kwargs):
15
+ raise NotImplementedError
16
+
17
+ from transformers import CLIPTokenizer, CLIPTextModel
18
+
19
+ def disabled_train(self, mode=True):
20
+ """Overwrite model.train with this function to make sure train/eval mode
21
+ does not change anymore."""
22
+ return self
23
+
24
+ @register('clip_text_frozen', version)
25
+ class FrozenCLIPTextEmbedder(AbstractEncoder):
26
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
27
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
28
+ super().__init__()
29
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
30
+ self.transformer = CLIPTextModel.from_pretrained(version)
31
+ self.device = device
32
+ self.max_length = max_length # TODO: typical value?
33
+ self.freeze()
34
+
35
+ def freeze(self):
36
+ self.transformer = self.transformer.eval()
37
+ #self.train = disabled_train
38
+ for param in self.parameters():
39
+ param.requires_grad = False
40
+
41
+ def forward(self, text):
42
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
43
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
44
+ tokens = batch_encoding["input_ids"].to(self.device)
45
+ outputs = self.transformer(input_ids=tokens)
46
+ z = outputs.last_hidden_state
47
+ return z
48
+
49
+ def encode(self, text):
50
+ return self(text)
51
+
52
+ from transformers import CLIPProcessor, CLIPModel
53
+
54
+ @register('clip_frozen', version)
55
+ class FrozenCLIP(AbstractEncoder):
56
+ def __init__(self,
57
+ version="openai/clip-vit-large-patch14",
58
+ max_length=77,
59
+ encode_type='encode_text',
60
+ fp16=False, ):
61
+ super().__init__()
62
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
63
+ self.processor = CLIPProcessor.from_pretrained(version)
64
+ self.model = CLIPModel.from_pretrained(version)
65
+ self.max_length = max_length # TODO: typical value?
66
+ self.encode_type = encode_type
67
+ self.fp16 = fp16
68
+ self.freeze()
69
+
70
+ def get_device(self):
71
+ # A trick to get device
72
+ return self.model.text_projection.weight.device
73
+
74
+ def freeze(self):
75
+ self.model = self.model.eval()
76
+ self.train = disabled_train
77
+ for param in self.parameters():
78
+ param.requires_grad = False
79
+
80
+ def encode_text_pooled(self, text):
81
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
82
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
83
+ tokens = batch_encoding["input_ids"].to(self.get_device())
84
+ outputs = self.model.get_text_features(input_ids=tokens)
85
+ return outputs
86
+
87
+ def encode_vision_pooled(self, images):
88
+ inputs = self.processor(images=images, return_tensors="pt")
89
+ pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
90
+ pixels = pixels.to(self.get_device())
91
+ return self.model.get_image_features(pixel_values=pixels)
92
+
93
+ def encode_text_noproj(self, text):
94
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
95
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
96
+ tokens = batch_encoding["input_ids"].to(self.get_device())
97
+ outputs = self.model.text_model(input_ids=tokens)
98
+ return outputs.last_hidden_state
99
+
100
+ def encode_vision_noproj(self, images):
101
+ inputs = self.processor(images=images, return_tensors="pt")
102
+ pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
103
+ pixels = pixels.to(self.get_device())
104
+ outputs = self.model.vision_model(pixel_values=pixels)
105
+ return outputs.last_hidden_state
106
+
107
+ def encode_text(self, text):
108
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
109
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
110
+ tokens = batch_encoding["input_ids"].to(self.get_device())
111
+ outputs = self.model.text_model(input_ids=tokens)
112
+ z = self.model.text_projection(outputs.last_hidden_state)
113
+ z_pooled = self.model.text_projection(outputs.pooler_output)
114
+ z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
115
+ return z
116
+
117
+ def encode_vision(self, images):
118
+ z = self.encode_vision_noproj(images)
119
+ z = self.model.vision_model.post_layernorm(z)
120
+ z = self.model.visual_projection(z)
121
+ z_pooled = z[:, 0:1]
122
+ # z_pooled_normed = z_pooled / z_pooled.norm(dim=-1, keepdim=True)
123
+ z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
124
+ return z
125
+
126
+ def encode(self, *args, **kwargs):
127
+ return getattr(self, self.encode_type)(*args, **kwargs)
128
+
129
+ #############################
130
+ # copied from justin's code #
131
+ #############################
132
+
133
+ @register('clip_vision_frozen_justin', version)
134
+ class FrozenCLIPVisionEmbedder_Justin(AbstractEncoder):
135
+ """
136
+ Uses the CLIP image encoder.
137
+ """
138
+ def __init__(
139
+ self,
140
+ model='ViT-L/14',
141
+ jit=False,
142
+ device='cuda' if torch.cuda.is_available() else 'cpu',
143
+ antialias=False,
144
+ ):
145
+ super().__init__()
146
+ from . import clip_justin
147
+ self.model, _ = clip_justin.load(name=model, device=device, jit=jit)
148
+ self.device = device
149
+ self.antialias = antialias
150
+
151
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
152
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
153
+
154
+ # I didn't call this originally, but seems like it was frozen anyway
155
+ self.freeze()
156
+
157
+ def freeze(self):
158
+ self.model = self.model.eval()
159
+ for param in self.parameters():
160
+ param.requires_grad = False
161
+
162
+ def preprocess(self, x):
163
+ import kornia
164
+ # Expects inputs in the range -1, 1
165
+ x = kornia.geometry.resize(x, (224, 224),
166
+ interpolation='bicubic',align_corners=True,
167
+ antialias=self.antialias)
168
+ x = (x + 1.) / 2.
169
+ # renormalize according to clip
170
+ x = kornia.enhance.normalize(x, self.mean, self.std)
171
+ return x
172
+
173
+ def forward(self, x):
174
+ # x is assumed to be in range [-1,1]
175
+ return self.model.encode_image(self.preprocess(x)).float()
176
+
177
+ def encode(self, im):
178
+ return self(im).unsqueeze(1)
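A hedged usage sketch (not repo code) of the `encode_type` dispatch in `FrozenCLIP` above. The HF weights are downloaded on first use, and `dog.jpg` is a placeholder path.

```python
import torch
from PIL import Image
from lib.model_zoo.clip import FrozenCLIP

text_encoder = FrozenCLIP(encode_type='encode_text')      # projected 77-token context embeddings
vision_encoder = FrozenCLIP(encode_type='encode_vision')  # projected 257 patch tokens for ViT-L/14

with torch.no_grad():
    z_text = text_encoder.encode(['a red car parked by a lake'])  # -> [1, 77, 768]
    z_image = vision_encoder.encode(Image.open('dog.jpg'))        # -> [1, 257, 768]
```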
lib/model_zoo/clip_justin/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .clip import load
lib/model_zoo/clip_justin/clip.py ADDED
@@ -0,0 +1,237 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ # from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["available_models", "load", "tokenize"]
28
+ # _tokenizer = _Tokenizer()
29
+
30
+ _MODELS = {
31
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
+ }
41
+
42
+
43
+ def _download(url: str, root: str):
44
+ os.makedirs(root, exist_ok=True)
45
+ filename = os.path.basename(url)
46
+
47
+ expected_sha256 = url.split("/")[-2]
48
+ download_target = os.path.join(root, filename)
49
+
50
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
51
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
52
+
53
+ if os.path.isfile(download_target):
54
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
+ return download_target
56
+ else:
57
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
+
59
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
71
+
72
+ return download_target
73
+
74
+
75
+ def _convert_image_to_rgb(image):
76
+ return image.convert("RGB")
77
+
78
+
79
+ def _transform(n_px):
80
+ return Compose([
81
+ Resize(n_px, interpolation=BICUBIC),
82
+ CenterCrop(n_px),
83
+ _convert_image_to_rgb,
84
+ ToTensor(),
85
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86
+ ])
87
+
88
+
89
+ def available_models() -> List[str]:
90
+ """Returns the names of available CLIP models"""
91
+ return list(_MODELS.keys())
92
+
93
+
94
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95
+ """Load a CLIP model
96
+
97
+ Parameters
98
+ ----------
99
+ name : str
100
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101
+
102
+ device : Union[str, torch.device]
103
+ The device to put the loaded model
104
+
105
+ jit : bool
106
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
107
+
108
+ download_root: str
109
+ path to download the model files; by default, it uses "~/.cache/clip"
110
+
111
+ Returns
112
+ -------
113
+ model : torch.nn.Module
114
+ The CLIP model
115
+
116
+ preprocess : Callable[[PIL.Image], torch.Tensor]
117
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118
+ """
119
+ if name in _MODELS:
120
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121
+ elif os.path.isfile(name):
122
+ model_path = name
123
+ else:
124
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125
+
126
+ with open(model_path, 'rb') as opened_file:
127
+ try:
128
+ # loading JIT archive
129
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130
+ state_dict = None
131
+ except RuntimeError:
132
+ # loading saved state dict
133
+ if jit:
134
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135
+ jit = False
136
+ state_dict = torch.load(opened_file, map_location="cpu")
137
+
138
+ if not jit:
139
+ model = build_model(state_dict or model.state_dict()).to(device)
140
+ if str(device) == "cpu":
141
+ model.float()
142
+ return model, _transform(model.visual.input_resolution)
143
+
144
+ # patch the device names
145
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147
+
148
+ def patch_device(module):
149
+ try:
150
+ graphs = [module.graph] if hasattr(module, "graph") else []
151
+ except RuntimeError:
152
+ graphs = []
153
+
154
+ if hasattr(module, "forward1"):
155
+ graphs.append(module.forward1.graph)
156
+
157
+ for graph in graphs:
158
+ for node in graph.findAllNodes("prim::Constant"):
159
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
160
+ node.copyAttributes(device_node)
161
+
162
+ model.apply(patch_device)
163
+ patch_device(model.encode_image)
164
+ patch_device(model.encode_text)
165
+
166
+ # patch dtype to float32 on CPU
167
+ if str(device) == "cpu":
168
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
169
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
170
+ float_node = float_input.node()
171
+
172
+ def patch_float(module):
173
+ try:
174
+ graphs = [module.graph] if hasattr(module, "graph") else []
175
+ except RuntimeError:
176
+ graphs = []
177
+
178
+ if hasattr(module, "forward1"):
179
+ graphs.append(module.forward1.graph)
180
+
181
+ for graph in graphs:
182
+ for node in graph.findAllNodes("aten::to"):
183
+ inputs = list(node.inputs())
184
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
185
+ if inputs[i].node()["value"] == 5:
186
+ inputs[i].node().copyAttributes(float_node)
187
+
188
+ model.apply(patch_float)
189
+ patch_float(model.encode_image)
190
+ patch_float(model.encode_text)
191
+
192
+ model.float()
193
+
194
+ return model, _transform(model.input_resolution.item())
195
+
196
+
197
+ # def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
198
+ # """
199
+ # Returns the tokenized representation of given input string(s)
200
+
201
+ # Parameters
202
+ # ----------
203
+ # texts : Union[str, List[str]]
204
+ # An input string or a list of input strings to tokenize
205
+
206
+ # context_length : int
207
+ # The context length to use; all CLIP models use 77 as the context length
208
+
209
+ # truncate: bool
210
+ # Whether to truncate the text in case its encoding is longer than the context length
211
+
212
+ # Returns
213
+ # -------
214
+ # A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215
+ # We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216
+ # """
217
+ # if isinstance(texts, str):
218
+ # texts = [texts]
219
+
220
+ # sot_token = _tokenizer.encoder["<|startoftext|>"]
221
+ # eot_token = _tokenizer.encoder["<|endoftext|>"]
222
+ # all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223
+ # if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224
+ # result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225
+ # else:
226
+ # result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227
+
228
+ # for i, tokens in enumerate(all_tokens):
229
+ # if len(tokens) > context_length:
230
+ # if truncate:
231
+ # tokens = tokens[:context_length]
232
+ # tokens[-1] = eot_token
233
+ # else:
234
+ # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235
+ # result[i, :len(tokens)] = torch.tensor(tokens)
236
+
237
+ # return result
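A short, hedged sketch of how the vendored loader above is typically driven (an illustration, not repo code): the ViT-L/14 checkpoint is fetched to ~/.cache/clip on first use, and `cat.png` is a placeholder path.

```python
import torch
from PIL import Image
from lib.model_zoo.clip_justin import load

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = load('ViT-L/14', device=device, jit=False)

image = preprocess(Image.open('cat.png')).unsqueeze(0).to(device)
with torch.no_grad():
    feats = model.encode_image(image)   # pooled image embedding, [1, 768] for ViT-L/14
```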
lib/model_zoo/clip_justin/model.py ADDED
@@ -0,0 +1,436 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x[:1], key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+ return x.squeeze(0)
92
+
93
+
94
+ class ModifiedResNet(nn.Module):
95
+ """
96
+ A ResNet class that is similar to torchvision's but contains the following changes:
97
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99
+ - The final pooling layer is a QKV attention instead of an average pool
100
+ """
101
+
102
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103
+ super().__init__()
104
+ self.output_dim = output_dim
105
+ self.input_resolution = input_resolution
106
+
107
+ # the 3-layer stem
108
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109
+ self.bn1 = nn.BatchNorm2d(width // 2)
110
+ self.relu1 = nn.ReLU(inplace=True)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.relu2 = nn.ReLU(inplace=True)
114
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
+ self.bn3 = nn.BatchNorm2d(width)
116
+ self.relu3 = nn.ReLU(inplace=True)
117
+ self.avgpool = nn.AvgPool2d(2)
118
+
119
+ # residual layers
120
+ self._inplanes = width # this is a *mutable* variable used during construction
121
+ self.layer1 = self._make_layer(width, layers[0])
122
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125
+
126
+ embed_dim = width * 32 # the ResNet feature dimension
127
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128
+
129
+ def _make_layer(self, planes, blocks, stride=1):
130
+ layers = [Bottleneck(self._inplanes, planes, stride)]
131
+
132
+ self._inplanes = planes * Bottleneck.expansion
133
+ for _ in range(1, blocks):
134
+ layers.append(Bottleneck(self._inplanes, planes))
135
+
136
+ return nn.Sequential(*layers)
137
+
138
+ def forward(self, x):
139
+ def stem(x):
140
+ x = self.relu1(self.bn1(self.conv1(x)))
141
+ x = self.relu2(self.bn2(self.conv2(x)))
142
+ x = self.relu3(self.bn3(self.conv3(x)))
143
+ x = self.avgpool(x)
144
+ return x
145
+
146
+ x = x.type(self.conv1.weight.dtype)
147
+ x = stem(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ x = self.attnpool(x)
153
+
154
+ return x
155
+
156
+
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ orig_type = x.dtype
162
+ ret = super().forward(x.type(torch.float32))
163
+ return ret.type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = nn.MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ def attention(self, x: torch.Tensor):
186
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ x = x + self.attention(self.ln_1(x))
191
+ x = x + self.mlp(self.ln_2(x))
192
+ return x
193
+
194
+
195
+ class Transformer(nn.Module):
196
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
+ super().__init__()
198
+ self.width = width
199
+ self.layers = layers
200
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ return self.resblocks(x)
204
+
205
+
206
+ class VisionTransformer(nn.Module):
207
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
+ super().__init__()
209
+ self.input_resolution = input_resolution
210
+ self.output_dim = output_dim
211
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
+
213
+ scale = width ** -0.5
214
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
+ self.ln_pre = LayerNorm(width)
217
+
218
+ self.transformer = Transformer(width, layers, heads)
219
+
220
+ self.ln_post = LayerNorm(width)
221
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
+
223
+ def forward(self, x: torch.Tensor):
224
+ x = self.conv1(x) # shape = [*, width, grid, grid]
225
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
+ x = x + self.positional_embedding.to(x.dtype)
229
+ x = self.ln_pre(x)
230
+
231
+ x = x.permute(1, 0, 2) # NLD -> LND
232
+ x = self.transformer(x)
233
+ x = x.permute(1, 0, 2) # LND -> NLD
234
+
235
+ x = self.ln_post(x[:, 0, :])
236
+
237
+ if self.proj is not None:
238
+ x = x @ self.proj
239
+
240
+ return x
241
+
242
+
243
+ class CLIP(nn.Module):
244
+ def __init__(self,
245
+ embed_dim: int,
246
+ # vision
247
+ image_resolution: int,
248
+ vision_layers: Union[Tuple[int, int, int, int], int],
249
+ vision_width: int,
250
+ vision_patch_size: int,
251
+ # text
252
+ context_length: int,
253
+ vocab_size: int,
254
+ transformer_width: int,
255
+ transformer_heads: int,
256
+ transformer_layers: int
257
+ ):
258
+ super().__init__()
259
+
260
+ self.context_length = context_length
261
+
262
+ if isinstance(vision_layers, (tuple, list)):
263
+ vision_heads = vision_width * 32 // 64
264
+ self.visual = ModifiedResNet(
265
+ layers=vision_layers,
266
+ output_dim=embed_dim,
267
+ heads=vision_heads,
268
+ input_resolution=image_resolution,
269
+ width=vision_width
270
+ )
271
+ else:
272
+ vision_heads = vision_width // 64
273
+ self.visual = VisionTransformer(
274
+ input_resolution=image_resolution,
275
+ patch_size=vision_patch_size,
276
+ width=vision_width,
277
+ layers=vision_layers,
278
+ heads=vision_heads,
279
+ output_dim=embed_dim
280
+ )
281
+
282
+ self.transformer = Transformer(
283
+ width=transformer_width,
284
+ layers=transformer_layers,
285
+ heads=transformer_heads,
286
+ attn_mask=self.build_attention_mask()
287
+ )
288
+
289
+ self.vocab_size = vocab_size
290
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
291
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
292
+ self.ln_final = LayerNorm(transformer_width)
293
+
294
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
295
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
296
+
297
+ self.initialize_parameters()
298
+
299
+ def initialize_parameters(self):
300
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
301
+ nn.init.normal_(self.positional_embedding, std=0.01)
302
+
303
+ if isinstance(self.visual, ModifiedResNet):
304
+ if self.visual.attnpool is not None:
305
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
306
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
307
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
308
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
309
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
310
+
311
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
312
+ for name, param in resnet_block.named_parameters():
313
+ if name.endswith("bn3.weight"):
314
+ nn.init.zeros_(param)
315
+
316
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
317
+ attn_std = self.transformer.width ** -0.5
318
+ fc_std = (2 * self.transformer.width) ** -0.5
319
+ for block in self.transformer.resblocks:
320
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
321
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
322
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
323
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
324
+
325
+ if self.text_projection is not None:
326
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
327
+
328
+ def build_attention_mask(self):
329
+ # lazily create causal attention mask, with full attention between the vision tokens
330
+ # pytorch uses additive attention mask; fill with -inf
331
+ mask = torch.empty(self.context_length, self.context_length)
332
+ mask.fill_(float("-inf"))
333
+ mask.triu_(1) # zero out the lower diagonal
334
+ return mask
335
+
336
+ @property
337
+ def dtype(self):
338
+ return self.visual.conv1.weight.dtype
339
+
340
+ def encode_image(self, image):
341
+ return self.visual(image.type(self.dtype))
342
+
343
+ def encode_text(self, text):
344
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
345
+
346
+ x = x + self.positional_embedding.type(self.dtype)
347
+ x = x.permute(1, 0, 2) # NLD -> LND
348
+ x = self.transformer(x)
349
+ x = x.permute(1, 0, 2) # LND -> NLD
350
+ x = self.ln_final(x).type(self.dtype)
351
+
352
+ # x.shape = [batch_size, n_ctx, transformer.width]
353
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
354
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
355
+
356
+ return x
357
+
358
+ def forward(self, image, text):
359
+ image_features = self.encode_image(image)
360
+ text_features = self.encode_text(text)
361
+
362
+ # normalized features
363
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
364
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
365
+
366
+ # cosine similarity as logits
367
+ logit_scale = self.logit_scale.exp()
368
+ logits_per_image = logit_scale * image_features @ text_features.t()
369
+ logits_per_text = logits_per_image.t()
370
+
371
+ # shape = [global_batch_size, global_batch_size]
372
+ return logits_per_image, logits_per_text
373
+
374
+
375
+ def convert_weights(model: nn.Module):
376
+ """Convert applicable model parameters to fp16"""
377
+
378
+ def _convert_weights_to_fp16(l):
379
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
380
+ l.weight.data = l.weight.data.half()
381
+ if l.bias is not None:
382
+ l.bias.data = l.bias.data.half()
383
+
384
+ if isinstance(l, nn.MultiheadAttention):
385
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
386
+ tensor = getattr(l, attr)
387
+ if tensor is not None:
388
+ tensor.data = tensor.data.half()
389
+
390
+ for name in ["text_projection", "proj"]:
391
+ if hasattr(l, name):
392
+ attr = getattr(l, name)
393
+ if attr is not None:
394
+ attr.data = attr.data.half()
395
+
396
+ model.apply(_convert_weights_to_fp16)
397
+
398
+
399
+ def build_model(state_dict: dict):
400
+ vit = "visual.proj" in state_dict
401
+
402
+ if vit:
403
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
404
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
405
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
406
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
407
+ image_resolution = vision_patch_size * grid_size
408
+ else:
409
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
410
+ vision_layers = tuple(counts)
411
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
412
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
413
+ vision_patch_size = None
414
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
415
+ image_resolution = output_width * 32
416
+
417
+ embed_dim = state_dict["text_projection"].shape[1]
418
+ context_length = state_dict["positional_embedding"].shape[0]
419
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
420
+ transformer_width = state_dict["ln_final.weight"].shape[0]
421
+ transformer_heads = transformer_width // 64
422
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
423
+
424
+ model = CLIP(
425
+ embed_dim,
426
+ image_resolution, vision_layers, vision_width, vision_patch_size,
427
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
428
+ )
429
+
430
+ for key in ["input_resolution", "context_length", "vocab_size"]:
431
+ if key in state_dict:
432
+ del state_dict[key]
433
+
434
+ convert_weights(model)
435
+ model.load_state_dict(state_dict)
436
+ return model.eval()
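Since `build_model` above infers every hyperparameter from checkpoint tensor shapes, it can help to see the `CLIP` constructor called directly. The toy sizes below are arbitrary and only meant to show the contrastive forward pass; this is a sketch, not repo code.

```python
import torch
from lib.model_zoo.clip_justin.model import CLIP

# an int vision_layers selects the VisionTransformer branch; a tuple selects ModifiedResNet
toy = CLIP(embed_dim=64,
           image_resolution=32, vision_layers=2, vision_width=64, vision_patch_size=8,
           context_length=16, vocab_size=1000,
           transformer_width=64, transformer_heads=4, transformer_layers=2)

images = torch.randn(2, 3, 32, 32)
texts = torch.randint(0, 1000, (2, 16))
logits_per_image, logits_per_text = toy(images, texts)
print(logits_per_image.shape)   # torch.Size([2, 2]) cosine-similarity logits
```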
lib/model_zoo/clip_justin/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
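A hedged round-trip sketch for the tokenizer above (not repo code); it assumes the gzipped BPE vocabulary `bpe_simple_vocab_16e6.txt.gz` ships next to this module, as `default_bpe()` expects.

```python
from lib.model_zoo.clip_justin.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("A Shiba Inu wearing a beret")   # BPE ids only; no SOT/EOT or padding here
print(ids)
print(tokenizer.decode(ids))   # lower-cased, whitespace-normalized reconstruction
```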
lib/model_zoo/common/get_model.py ADDED
@@ -0,0 +1,128 @@
1
+ from email.policy import strict
2
+ import torch
3
+ import torchvision.models
4
+ import os.path as osp
5
+ import copy
6
+ from ...log_service import print_log
7
+ from .utils import \
8
+ get_total_param, get_total_param_sum, \
9
+ get_unit
10
+
11
+ # def load_state_dict(net, model_path):
12
+ # if isinstance(net, dict):
13
+ # for ni, neti in net.items():
14
+ # paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
15
+ # new_paras = neti.state_dict()
16
+ # new_paras.update(paras)
17
+ # neti.load_state_dict(new_paras)
18
+ # else:
19
+ # paras = torch.load(model_path, map_location=torch.device('cpu'))
20
+ # new_paras = net.state_dict()
21
+ # new_paras.update(paras)
22
+ # net.load_state_dict(new_paras)
23
+ # return
24
+
25
+ # def save_state_dict(net, path):
26
+ # if isinstance(net, (torch.nn.DataParallel,
27
+ # torch.nn.parallel.DistributedDataParallel)):
28
+ # torch.save(net.module.state_dict(), path)
29
+ # else:
30
+ # torch.save(net.state_dict(), path)
31
+
32
+ def singleton(class_):
33
+ instances = {}
34
+ def getinstance(*args, **kwargs):
35
+ if class_ not in instances:
36
+ instances[class_] = class_(*args, **kwargs)
37
+ return instances[class_]
38
+ return getinstance
39
+
40
+ def preprocess_model_args(args):
41
+ # If args has layer_units, get the corresponding
42
+ # units.
43
+ # If args has backbone, build the backbone model.
44
+ args = copy.deepcopy(args)
45
+ if 'layer_units' in args:
46
+ layer_units = [
47
+ get_unit()(i) for i in args.layer_units
48
+ ]
49
+ args.layer_units = layer_units
50
+ if 'backbone' in args:
51
+ args.backbone = get_model()(args.backbone)
52
+ return args
53
+
54
+ @singleton
55
+ class get_model(object):
56
+ def __init__(self):
57
+ self.model = {}
58
+ self.version = {}
59
+
60
+ def register(self, model, name, version='x'):
61
+ self.model[name] = model
62
+ self.version[name] = version
63
+
64
+ def __call__(self, cfg, verbose=True):
65
+ """
66
+ Construct model based on the config.
67
+ """
68
+ t = cfg.type
69
+
70
+ # the register is in each file
71
+ if t.find('ldm')==0:
72
+ from .. import ldm
73
+ elif t=='autoencoderkl':
74
+ from .. import autoencoder
75
+ elif t.find('clip')==0:
76
+ from .. import clip
77
+ elif t.find('sd')==0:
78
+ from .. import sd
79
+ elif t.find('vd')==0:
80
+ from .. import vd
81
+ elif t.find('openai_unet')==0:
82
+ from .. import openaimodel
83
+ elif t.find('optimus')==0:
84
+ from .. import optimus
85
+
86
+ args = preprocess_model_args(cfg.args)
87
+ net = self.model[t](**args)
88
+
89
+ if 'ckpt' in cfg:
90
+ checkpoint = torch.load(cfg.ckpt, map_location='cpu')
91
+ strict_sd = cfg.get('strict_sd', True)
92
+ net.load_state_dict(checkpoint['state_dict'], strict=strict_sd)
93
+ if verbose:
94
+ print_log('Load ckpt from {}'.format(cfg.ckpt))
95
+ elif 'pth' in cfg:
96
+ sd = torch.load(cfg.pth, map_location='cpu')
97
+ strict_sd = cfg.get('strict_sd', True)
98
+ net.load_state_dict(sd, strict=strict_sd)
99
+ if verbose:
100
+ print_log('Load pth from {}'.format(cfg.pth))
101
+ elif 'hfm' in cfg:
102
+ from huggingface_hub import hf_hub_download
103
+ temppath = hf_hub_download(cfg.hfm[0], cfg.hfm[1])
104
+ sd = torch.load(temppath, map_location='cpu')
105
+ strict_sd = cfg.get('strict_sd', True)
106
+ net.load_state_dict(sd, strict=strict_sd)
107
+ if verbose:
108
+ print_log('Load pth from {}/{}'.format(*cfg.hfm))
109
+
110
+ # display param_num & param_sum
111
+ if verbose:
112
+ print_log(
113
+ 'Load {} with total {} parameters,'
114
+ '{:.3f} parameter sum.'.format(
115
+ t,
116
+ get_total_param(net),
117
+ get_total_param_sum(net) ))
118
+
119
+ return net
120
+
121
+ def get_version(self, name):
122
+ return self.version[name]
123
+
124
+ def register(name, version='x'):
125
+ def wrapper(class_):
126
+ get_model().register(class_, name, version)
127
+ return class_
128
+ return wrapper
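A small sketch (assumptions flagged inline) of how the `@register` decorator and the `get_model` singleton above cooperate; `toy_net` and `ToyNet` are hypothetical names, not part of the repo.

```python
import torch.nn as nn
from lib.model_zoo.common.get_model import register, get_model

@register('toy_net', version='0')   # adds the class to the shared registry at import time
class ToyNet(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return self.fc(x)

# get_model() always returns the same registry instance (see the singleton decorator above),
# so a class registered here is visible wherever the config-driven __call__ is used.
net = get_model().model['toy_net'](dim=8)
print(get_model().get_version('toy_net'))   # '0'
```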
lib/model_zoo/common/get_optimizer.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import torch.optim as optim
3
+ import numpy as np
4
+ import itertools
5
+
6
+ def singleton(class_):
7
+ instances = {}
8
+ def getinstance(*args, **kwargs):
9
+ if class_ not in instances:
10
+ instances[class_] = class_(*args, **kwargs)
11
+ return instances[class_]
12
+ return getinstance
13
+
14
+ class get_optimizer(object):
15
+ def __init__(self):
16
+ self.optimizer = {}
17
+ self.register(optim.SGD, 'sgd')
18
+ self.register(optim.Adam, 'adam')
19
+ self.register(optim.AdamW, 'adamw')
20
+
21
+ def register(self, optim, name):
22
+ self.optimizer[name] = optim
23
+
24
+ def __call__(self, net, cfg):
25
+ if cfg is None:
26
+ return None
27
+ t = cfg.type
28
+ if isinstance(net, (torch.nn.DataParallel,
29
+ torch.nn.parallel.DistributedDataParallel)):
30
+ netm = net.module
31
+ else:
32
+ netm = net
33
+ pg = getattr(netm, 'parameter_group', None)
34
+
35
+ if pg is not None:
36
+ params = []
37
+ for group_name, module_or_para in pg.items():
38
+ if not isinstance(module_or_para, list):
39
+ module_or_para = [module_or_para]
40
+
41
+ grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para]
42
+ grouped_params = itertools.chain(*grouped_params)
43
+ pg_dict = {'params':grouped_params, 'name':group_name}
44
+ params.append(pg_dict)
45
+ else:
46
+ params = net.parameters()
47
+ return self.optimizer[t](params, lr=0, **cfg.args)
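A hedged sketch of the optional `parameter_group` hook that `get_optimizer` looks for; the two-group split and the `Cfg` stand-in below are hypothetical, standing in for the repo's attribute-style config objects.

```python
import torch.nn as nn
from lib.model_zoo.common.get_optimizer import get_optimizer

class TwoPartNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.head = nn.Linear(16, 4)
        # keys become the 'name' field of each torch param group, which schedulers can scale per group
        self.parameter_group = {'backbone': self.backbone, 'head': self.head}

    def forward(self, x):
        return self.head(self.backbone(x))

class Cfg:                        # minimal stand-in for cfg.type / cfg.args (an assumption)
    type = 'adamw'
    args = {'weight_decay': 0.01}

opt = get_optimizer()(TwoPartNet(), Cfg())
print([g['name'] for g in opt.param_groups])   # ['backbone', 'head']
```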
lib/model_zoo/common/get_scheduler.py ADDED
@@ -0,0 +1,262 @@
1
+ import torch
2
+ import torch.optim as optim
3
+ import numpy as np
4
+ import copy
5
+ from ... import sync
6
+ from ...cfg_holder import cfg_unique_holder as cfguh
7
+
8
+ def singleton(class_):
9
+ instances = {}
10
+ def getinstance(*args, **kwargs):
11
+ if class_ not in instances:
12
+ instances[class_] = class_(*args, **kwargs)
13
+ return instances[class_]
14
+ return getinstance
15
+
16
+ @singleton
17
+ class get_scheduler(object):
18
+ def __init__(self):
19
+ self.lr_scheduler = {}
20
+
21
+ def register(self, lrsf, name):
22
+ self.lr_scheduler[name] = lrsf
23
+
24
+ def __call__(self, cfg):
25
+ if cfg is None:
26
+ return None
27
+ if isinstance(cfg, list):
28
+ schedulers = []
29
+ for ci in cfg:
30
+ t = ci.type
31
+ schedulers.append(
32
+ self.lr_scheduler[t](**ci.args))
33
+ if len(schedulers) == 0:
34
+ raise ValueError
35
+ else:
36
+ return compose_scheduler(schedulers)
37
+ t = cfg.type
38
+ return self.lr_scheduler[t](**cfg.args)
39
+
40
+
41
+ def register(name):
42
+ def wrapper(class_):
43
+ get_scheduler().register(class_, name)
44
+ return class_
45
+ return wrapper
46
+
47
+ class template_scheduler(object):
48
+ def __init__(self, step):
49
+ self.step = step
50
+
51
+ def __getitem__(self, idx):
52
+ raise ValueError
53
+
54
+ def set_lr(self, optim, new_lr, pg_lrscale=None):
55
+ """
56
+ Set each parameter group in optim to new_lr.
57
+ new_lr can be found according to the idx.
58
+ pg_lrscale tells how to scale each pg.
59
+ """
60
+ # new_lr = self.__getitem__(idx)
61
+ pg_lrscale = copy.deepcopy(pg_lrscale)
62
+ for pg in optim.param_groups:
63
+ if pg_lrscale is None:
64
+ pg['lr'] = new_lr
65
+ else:
66
+ pg['lr'] = new_lr * pg_lrscale.pop(pg['name'])
67
+ assert (pg_lrscale is None) or (len(pg_lrscale)==0), \
68
+ "pg_lrscale doesn't match pg"
69
+
70
+ @register('constant')
71
+ class constant_scheduler(template_scheduler):
72
+ def __init__(self, lr, step):
73
+ super().__init__(step)
74
+ self.lr = lr
75
+
76
+ def __getitem__(self, idx):
77
+ if idx >= self.step:
78
+ raise ValueError
79
+ return self.lr
80
+
81
+ @register('poly')
82
+ class poly_scheduler(template_scheduler):
83
+ def __init__(self, start_lr, end_lr, power, step):
84
+ super().__init__(step)
85
+ self.start_lr = start_lr
86
+ self.end_lr = end_lr
87
+ self.power = power
88
+
89
+ def __getitem__(self, idx):
90
+ if idx >= self.step:
91
+ raise ValueError
92
+ a, b = self.start_lr, self.end_lr
93
+ p, n = self.power, self.step
94
+ return b + (a-b)*((1-idx/n)**p)
95
+
96
+ @register('linear')
97
+ class linear_scheduler(template_scheduler):
98
+ def __init__(self, start_lr, end_lr, step):
99
+ super().__init__(step)
100
+ self.start_lr = start_lr
101
+ self.end_lr = end_lr
102
+
103
+ def __getitem__(self, idx):
104
+ if idx >= self.step:
105
+ raise ValueError
106
+ a, b, n = self.start_lr, self.end_lr, self.step
107
+ return b + (a-b)*(1-idx/n)
108
+
109
+ @register('multistage')
110
+ class multistage_scheduler(template_scheduler):
111
+ def __init__(self, start_lr, milestones, gamma, step):
112
+ super().__init__(step)
113
+ self.start_lr = start_lr
114
+ m = [0] + milestones + [step]
115
+ lr_iter = start_lr
116
+ self.lr = []
117
+ for ms, me in zip(m[0:-1], m[1:]):
118
+ for _ in range(ms, me):
119
+ self.lr.append(lr_iter)
120
+ lr_iter *= gamma
121
+
122
+ def __getitem__(self, idx):
123
+ if idx >= self.step:
124
+ raise ValueError
125
+ return self.lr[idx]
126
+
127
+ class compose_scheduler(template_scheduler):
128
+ def __init__(self, schedulers):
129
+ self.schedulers = schedulers
130
+ self.step = [si.step for si in schedulers]
131
+ self.step_milestone = []
132
+ acc = 0
133
+ for i in self.step:
134
+ acc += i
135
+ self.step_milestone.append(acc)
136
+ self.step = sum(self.step)
137
+
138
+ def __getitem__(self, idx):
139
+ if idx >= self.step:
140
+ raise ValueError
141
+ ms = [0] + self.step_milestone
142
+ for i, (mi, mj) in enumerate(zip(ms[:-1], ms[1:])):
143
+ if mi <= idx < mj:
144
+ return self.schedulers[i][idx-mi]
145
+ raise ValueError
146
+
147
+ ####################
148
+ # lambda scheduler #
149
+ ####################
150
+
151
+ class LambdaWarmUpCosineScheduler(template_scheduler):
152
+ """
153
+ note: use with a base_lr of 1.0
154
+ """
155
+ def __init__(self,
156
+ base_lr,
157
+ warm_up_steps,
158
+ lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
159
+ cfgt = cfguh().cfg.train
160
+ bs = cfgt.batch_size
161
+ if 'gradacc_every' not in cfgt:
162
+ print('Warning: gradacc_every not found in the train config, using 1 as default.')
163
+ acc = cfgt.get('gradacc_every', 1)
164
+ self.lr_multi = base_lr * bs * acc
165
+ self.lr_warm_up_steps = warm_up_steps
166
+ self.lr_start = lr_start
167
+ self.lr_min = lr_min
168
+ self.lr_max = lr_max
169
+ self.lr_max_decay_steps = max_decay_steps
170
+ self.last_lr = 0.
171
+ self.verbosity_interval = verbosity_interval
172
+
173
+ def schedule(self, n):
174
+ if self.verbosity_interval > 0:
175
+ if n % self.verbosity_interval == 0:
176
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
177
+ if n < self.lr_warm_up_steps:
178
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
179
+ self.last_lr = lr
180
+ return lr
181
+ else:
182
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
183
+ t = min(t, 1.0)
184
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
185
+ 1 + np.cos(t * np.pi))
186
+ self.last_lr = lr
187
+ return lr
188
+
189
+ def __getitem__(self, idx):
190
+ return self.schedule(idx) * self.lr_multi
191
+
192
+ class LambdaWarmUpCosineScheduler2(template_scheduler):
193
+ """
194
+ supports repeated iterations, configurable via lists
195
+ note: use with a base_lr of 1.0.
196
+ """
197
+ def __init__(self,
198
+ base_lr,
199
+ warm_up_steps,
200
+ f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
201
+ cfgt = cfguh().cfg.train
202
+ # bs = cfgt.batch_size
203
+ # if 'gradacc_every' not in cfgt:
204
+ # print('Warning, gradacc_every is not found in xml, use 1 as default.')
205
+ # acc = cfgt.get('gradacc_every', 1)
206
+ # self.lr_multi = base_lr * bs * acc
207
+ self.lr_multi = base_lr
208
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
209
+ self.lr_warm_up_steps = warm_up_steps
210
+ self.f_start = f_start
211
+ self.f_min = f_min
212
+ self.f_max = f_max
213
+ self.cycle_lengths = cycle_lengths
214
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
215
+ self.last_f = 0.
216
+ self.verbosity_interval = verbosity_interval
217
+
218
+ def find_in_interval(self, n):
219
+ interval = 0
220
+ for cl in self.cum_cycles[1:]:
221
+ if n <= cl:
222
+ return interval
223
+ interval += 1
224
+
225
+ def schedule(self, n):
226
+ cycle = self.find_in_interval(n)
227
+ n = n - self.cum_cycles[cycle]
228
+ if self.verbosity_interval > 0:
229
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
230
+ f"current cycle {cycle}")
231
+ if n < self.lr_warm_up_steps[cycle]:
232
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
233
+ self.last_f = f
234
+ return f
235
+ else:
236
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
237
+ t = min(t, 1.0)
238
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
239
+ 1 + np.cos(t * np.pi))
240
+ self.last_f = f
241
+ return f
242
+
243
+ def __getitem__(self, idx):
244
+ return self.schedule(idx) * self.lr_multi
245
+
246
+ @register('stable_diffusion_linear')
247
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
248
+ def schedule(self, n):
249
+ cycle = self.find_in_interval(n)
250
+ n = n - self.cum_cycles[cycle]
251
+ if self.verbosity_interval > 0:
252
+ if n % self.verbosity_interval == 0:
253
+ print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
254
+ f"current cycle {cycle}")
255
+ if n < self.lr_warm_up_steps[cycle]:
256
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
257
+ self.last_f = f
258
+ return f
259
+ else:
260
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
261
+ self.last_f = f
262
+ return f
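A hedged sketch of driving an optimizer with the registered 'poly' scheduler above; the numbers are illustrative and the training loop is elided. Importing this module pulls in the repo's `sync` and `cfg_holder` helpers, so it assumes the full `lib` package is on the path.

```python
import torch
import torch.nn as nn
from lib.model_zoo.common.get_scheduler import poly_scheduler

net = nn.Linear(8, 8)
opt = torch.optim.SGD(net.parameters(), lr=0.0)
sched = poly_scheduler(start_lr=1e-2, end_lr=1e-4, power=0.9, step=1000)

for it in range(1000):
    sched.set_lr(opt, sched[it])   # look up the lr for this step and push it into every param group
    # ... forward / backward / opt.step() would go here ...
```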
lib/model_zoo/common/utils.py ADDED
@@ -0,0 +1,292 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import copy
6
+ import functools
7
+ import itertools
8
+
9
+ import matplotlib.pyplot as plt
10
+
11
+ ########
12
+ # unit #
13
+ ########
14
+
15
+ def singleton(class_):
16
+ instances = {}
17
+ def getinstance(*args, **kwargs):
18
+ if class_ not in instances:
19
+ instances[class_] = class_(*args, **kwargs)
20
+ return instances[class_]
21
+ return getinstance
22
+
23
+ def str2value(v):
24
+ v = v.strip()
25
+ try:
26
+ return int(v)
27
+ except:
28
+ pass
29
+ try:
30
+ return float(v)
31
+ except:
32
+ pass
33
+ if v in ('True', 'true'):
34
+ return True
35
+ elif v in ('False', 'false'):
36
+ return False
37
+ else:
38
+ return v
39
+
40
+ @singleton
41
+ class get_unit(object):
42
+ def __init__(self):
43
+ self.unit = {}
44
+ self.register('none', None)
45
+
46
+ # general convolution
47
+ self.register('conv' , nn.Conv2d)
48
+ self.register('bn' , nn.BatchNorm2d)
49
+ self.register('relu' , nn.ReLU)
50
+ self.register('relu6' , nn.ReLU6)
51
+ self.register('lrelu' , nn.LeakyReLU)
52
+ self.register('dropout' , nn.Dropout)
53
+ self.register('dropout2d', nn.Dropout2d)
54
+ self.register('sine', Sine)
55
+ self.register('relusine', ReLUSine)
56
+
57
+ def register(self,
58
+ name,
59
+ unitf,):
60
+
61
+ self.unit[name] = unitf
62
+
63
+ def __call__(self, name):
64
+ if name is None:
65
+ return None
66
+ i = name.find('(')
67
+ i = len(name) if i==-1 else i
68
+ t = name[:i]
69
+ f = self.unit[t]
70
+ args = name[i:].strip('()')
71
+ if len(args) == 0:
72
+ args = {}
73
+ return f
74
+ else:
75
+ args = args.split('=')
76
+ args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
77
+ args = list(itertools.chain.from_iterable(args))
78
+ args = [i.strip() for i in args if len(i)>0]
79
+ kwargs = {}
80
+ for k, v in zip(args[::2], args[1::2]):
81
+ if v[0]=='(' and v[-1]==')':
82
+ kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
83
+ elif v[0]=='[' and v[-1]==']':
84
+ kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
85
+ else:
86
+ kwargs[k] = str2value(v)
87
+ return functools.partial(f, **kwargs)
88
+
89
+ def register(name):
90
+ def wrapper(class_):
91
+ get_unit().register(name, class_)
92
+ return class_
93
+ return wrapper
94
+
95
+ class Sine(object):
96
+ def __init__(self, freq, gain=1):
97
+ self.freq = freq
98
+ self.gain = gain
99
+ self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
100
+
101
+ def __call__(self, x, gain=1):
102
+ act_gain = self.gain * gain
103
+ return torch.sin(self.freq * x) * act_gain
104
+
105
+ def __repr__(self,):
106
+ return self.repr
107
+
108
+ class ReLUSine(nn.Module):
109
+ def __init__(self):
110
+ super().__init__()
111
+
112
+ def forward(self, input):
113
+ a = torch.sin(30 * input)
114
+ b = nn.ReLU(inplace=False)(input)
115
+ return a+b
116
+
117
+ @register('lrelu_agc')
118
+ # class lrelu_agc(nn.Module):
119
+ class lrelu_agc(object):
120
+ """
121
+ The lrelu layer with alpha, gain and clamp
122
+ """
123
+ def __init__(self, alpha=0.1, gain=1, clamp=None):
124
+ # super().__init__()
125
+ self.alpha = alpha
126
+ if gain == 'sqrt_2':
127
+ self.gain = np.sqrt(2)
128
+ else:
129
+ self.gain = gain
130
+ self.clamp = clamp
131
+ self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
132
+ alpha, gain, clamp)
133
+
134
+ # def forward(self, x, gain=1):
135
+ def __call__(self, x, gain=1):
136
+ x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
137
+ act_gain = self.gain * gain
138
+ act_clamp = self.clamp * gain if self.clamp is not None else None
139
+ if act_gain != 1:
140
+ x = x * act_gain
141
+ if act_clamp is not None:
142
+ x = x.clamp(-act_clamp, act_clamp)
143
+ return x
144
+
145
+ def __repr__(self,):
146
+ return self.repr
147
+
148
+ ####################
149
+ # spatial encoding #
150
+ ####################
151
+
152
+ @register('se')
153
+ class SpatialEncoding(nn.Module):
154
+ def __init__(self,
155
+ in_dim,
156
+ out_dim,
157
+ sigma = 6,
158
+ cat_input=True,
159
+ require_grad=False,):
160
+
161
+ super().__init__()
162
+ assert out_dim % (2*in_dim) == 0, "out_dim must be divisible by 2*in_dim"
163
+
164
+ n = out_dim // 2 // in_dim
165
+ m = 2**np.linspace(0, sigma, n)
166
+ m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1)
167
+ m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
168
+ self.emb = torch.FloatTensor(m)
169
+ if require_grad:
170
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
171
+ self.in_dim = in_dim
172
+ self.out_dim = out_dim
173
+ self.sigma = sigma
174
+ self.cat_input = cat_input
175
+ self.require_grad = require_grad
176
+
177
+ def forward(self, x, format='[n x c]'):
178
+ """
179
+ Args:
180
+ x: [n x m1],
181
+ m1 usually is 2
182
+ Outputs:
183
+ y: [n x m2]
184
+ m2 dimention number
185
+ """
186
+ if format == '[bs x c x 2D]':
187
+ xshape = x.shape
188
+ x = x.permute(0, 2, 3, 1).contiguous()
189
+ x = x.view(-1, x.size(-1))
190
+ elif format == '[n x c]':
191
+ pass
192
+ else:
193
+ raise ValueError
194
+
195
+ if not self.require_grad:
196
+ self.emb = self.emb.to(x.device)
197
+ y = torch.mm(x, self.emb.T)
198
+ if self.cat_input:
199
+ z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
200
+ else:
201
+ z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
202
+
203
+ if format == '[bs x c x 2D]':
204
+ z = z.view(xshape[0], xshape[2], xshape[3], -1)
205
+ z = z.permute(0, 3, 1, 2).contiguous()
206
+ return z
207
+
208
+ def extra_repr(self):
209
+ outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
210
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
211
+ return outstr
212
+
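# Illustrative usage sketch (assumed, not taken from the original callers): with
# the defaults above, SpatialEncoding(in_dim=2, out_dim=64) projects [n x 2]
# coordinates onto 32 log-spaced frequency rows and, with cat_input=True,
# returns the input concatenated with the sin/cos terms:
#   enc = SpatialEncoding(2, 64)
#   z = enc(torch.rand(100, 2))   # -> torch.Size([100, 66]) = 2 + 32 + 32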
213
+ @register('rffe')
214
+ class RFFEncoding(SpatialEncoding):
215
+ """
216
+ Random Fourier Features
217
+ """
218
+ def __init__(self,
219
+ in_dim,
220
+ out_dim,
221
+ sigma = 6,
222
+ cat_input=True,
223
+ require_grad=False,):
224
+
225
+ super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
226
+ n = out_dim // 2
227
+ m = np.random.normal(0, sigma, size=(n, in_dim))
228
+ self.emb = torch.FloatTensor(m)
229
+ if require_grad:
230
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
231
+
232
+ def extra_repr(self):
233
+ outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
234
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
235
+ return outstr
236
+
237
+ ##########
238
+ # helper #
239
+ ##########
240
+
241
+ def freeze(net):
242
+ for m in net.modules():
243
+ if isinstance(m, (
244
+ nn.BatchNorm2d,
245
+ nn.SyncBatchNorm,)):
246
+ # inplace_abn not supported
247
+ m.eval()
248
+ for pi in net.parameters():
249
+ pi.requires_grad = False
250
+ return net
251
+
252
+ def common_init(m):
253
+ if isinstance(m, (
254
+ nn.Conv2d,
255
+ nn.ConvTranspose2d,)):
256
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
257
+ if m.bias is not None:
258
+ nn.init.constant_(m.bias, 0)
259
+ elif isinstance(m, (
260
+ nn.BatchNorm2d,
261
+ nn.SyncBatchNorm,)):
262
+ nn.init.constant_(m.weight, 1)
263
+ nn.init.constant_(m.bias, 0)
264
+ else:
265
+ pass
266
+
267
+ def init_module(module):
268
+ """
269
+ Args:
270
+ module: [nn.module] list or nn.module
271
+ a list of module to be initialized.
272
+ """
273
+ if isinstance(module, (list, tuple)):
274
+ module = list(module)
275
+ else:
276
+ module = [module]
277
+
278
+ for mi in module:
279
+ for mii in mi.modules():
280
+ common_init(mii)
281
+
282
+ def get_total_param(net):
283
+ if getattr(net, 'parameters', None) is None:
284
+ return 0
285
+ return sum(p.numel() for p in net.parameters())
286
+
287
+ def get_total_param_sum(net):
288
+ if getattr(net, 'parameters', None) is None:
289
+ return 0
290
+ with torch.no_grad():
291
+ s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
292
+ return s
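The `get_unit` singleton above resolves layer spec strings of the form `name(key=value, ...)` into registered classes or `functools.partial` constructors. A minimal usage sketch (the spec strings below are illustrative assumptions, not taken from the shipped configs):

import torch
from lib.model_zoo.common.utils import get_unit

relu_cls = get_unit()('relu')                                   # plain entry: returns nn.ReLU itself
act_ctor = get_unit()('lrelu_agc(alpha=0.2, gain=sqrt_2, clamp=256)')
act = act_ctor()                                                # functools.partial -> lrelu_agc(alpha=0.2, gain=sqrt(2), clamp=256)
y = act(torch.randn(1, 8, 4, 4))                                # leaky ReLU scaled by sqrt(2), clamped to +/-256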
lib/model_zoo/ddim.py ADDED
@@ -0,0 +1,216 @@
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class DDIMSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
26
+ num_ddim_timesteps=ddim_num_steps,
27
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
28
+ verbose=verbose)
29
+ alphas_cumprod = self.model.alphas_cumprod
30
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
+
33
+ self.register_buffer('betas', to_torch(self.model.betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
+
44
+ # ddim sampling parameters
45
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
46
+ alphacums=alphas_cumprod.cpu(),
47
+ ddim_timesteps=self.ddim_timesteps,
48
+ eta=ddim_eta,verbose=verbose)
49
+
50
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
51
+ self.register_buffer('ddim_alphas', ddim_alphas)
52
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
53
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
54
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
55
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
56
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
57
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
58
+
59
+ @torch.no_grad()
60
+ def sample(self,
61
+ S,
62
+ batch_size,
63
+ shape,
64
+ conditioning=None,
65
+ callback=None,
66
+ normals_sequence=None,
67
+ img_callback=None,
68
+ quantize_x0=False,
69
+ eta=0.,
70
+ mask=None,
71
+ x0=None,
72
+ temperature=1.,
73
+ noise_dropout=0.,
74
+ score_corrector=None,
75
+ corrector_kwargs=None,
76
+ verbose=True,
77
+ x_T=None,
78
+ log_every_t=100,
79
+ unconditional_guidance_scale=1.,
80
+ unconditional_conditioning=None,
81
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
82
+ **kwargs
83
+ ):
84
+ if conditioning is not None:
85
+ if isinstance(conditioning, dict):
86
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87
+ if cbs != batch_size:
88
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89
+ else:
90
+ if conditioning.shape[0] != batch_size:
91
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92
+
93
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94
+ # sampling
95
+ C, H, W = shape
96
+ size = (batch_size, C, H, W)
97
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
98
+
99
+ samples, intermediates = self.ddim_sampling(conditioning, size,
100
+ callback=callback,
101
+ img_callback=img_callback,
102
+ quantize_denoised=quantize_x0,
103
+ mask=mask, x0=x0,
104
+ ddim_use_original_steps=False,
105
+ noise_dropout=noise_dropout,
106
+ temperature=temperature,
107
+ score_corrector=score_corrector,
108
+ corrector_kwargs=corrector_kwargs,
109
+ x_T=x_T,
110
+ log_every_t=log_every_t,
111
+ unconditional_guidance_scale=unconditional_guidance_scale,
112
+ unconditional_conditioning=unconditional_conditioning,
113
+ )
114
+ return samples, intermediates
115
+
116
+ @torch.no_grad()
117
+ def ddim_sampling(self,
118
+ cond, shape,
119
+ x_T=None,
120
+ ddim_use_original_steps=False,
121
+ callback=None,
122
+ timesteps=None,
123
+ quantize_denoised=False,
124
+ mask=None, x0=None,
125
+ img_callback=None, log_every_t=100,
126
+ temperature=1.,
127
+ noise_dropout=0.,
128
+ score_corrector=None,
129
+ corrector_kwargs=None,
130
+ unconditional_guidance_scale=1.,
131
+ unconditional_conditioning=None,):
132
+ device = self.model.betas.device
133
+ b = shape[0]
134
+ if x_T is None:
135
+ img = torch.randn(shape, device=device)
136
+ else:
137
+ img = x_T
138
+
139
+ if timesteps is None:
140
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
141
+ elif timesteps is not None and not ddim_use_original_steps:
142
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
143
+ timesteps = self.ddim_timesteps[:subset_end]
144
+
145
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
146
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
147
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
148
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
149
+
150
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
151
+
152
+ for i, step in enumerate(iterator):
153
+ index = total_steps - i - 1
154
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
155
+
156
+ if mask is not None:
157
+ assert x0 is not None
158
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
159
+ img = img_orig * mask + (1. - mask) * img
160
+
161
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
162
+ quantize_denoised=quantize_denoised, temperature=temperature,
163
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
164
+ corrector_kwargs=corrector_kwargs,
165
+ unconditional_guidance_scale=unconditional_guidance_scale,
166
+ unconditional_conditioning=unconditional_conditioning)
167
+ img, pred_x0 = outs
168
+ if callback: callback(i)
169
+ if img_callback: img_callback(pred_x0, i)
170
+
171
+ if index % log_every_t == 0 or index == total_steps - 1:
172
+ intermediates['x_inter'].append(img)
173
+ intermediates['pred_x0'].append(pred_x0)
174
+
175
+ return img, intermediates
176
+
177
+ @torch.no_grad()
178
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
181
+ b, *_, device = *x.shape, x.device
182
+
183
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
184
+ e_t = self.model.apply_model(x, t, c)
185
+ else:
186
+ x_in = torch.cat([x] * 2)
187
+ t_in = torch.cat([t] * 2)
188
+ c_in = torch.cat([unconditional_conditioning, c])
189
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
190
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
191
+
192
+ if score_corrector is not None:
193
+ assert self.model.parameterization == "eps"
194
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
195
+
196
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
197
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
198
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
199
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
200
+ # select parameters corresponding to the currently considered timestep
201
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
+
206
+ # current prediction for x_0
207
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208
+ if quantize_denoised:
209
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210
+ # direction pointing to x_t
211
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212
+ noise = sigma_t * noise_like(x, repeat_noise) * temperature
213
+ if noise_dropout > 0.:
214
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216
+ return x_prev, pred_x0
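A hedged usage sketch for the sampler above; `model`, `c`, and `uc` are placeholders for any LDM-style wrapper exposing `num_timesteps`, `betas`, the `alphas_cumprod*` buffers, and `apply_model` as used in this file:

sampler = DDIMSampler(model)                       # model: assumed latent diffusion wrapper
samples, intermediates = sampler.sample(
    S=50,                                          # number of DDIM steps
    batch_size=4,
    shape=(4, 64, 64),                             # latent (C, H, W)
    conditioning=c,                                # e.g. text embeddings, shape [4, 77, 768]
    unconditional_conditioning=uc,                 # embeddings of the empty prompt
    unconditional_guidance_scale=7.5,
    eta=0.0)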
lib/model_zoo/ddim_dualcontext.py ADDED
@@ -0,0 +1,144 @@
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from functools import partial
5
+
6
+ from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
+
8
+ from .ddim import DDIMSampler
9
+
10
+ class DDIMSampler_DualContext(DDIMSampler):
11
+ @torch.no_grad()
12
+ def sample_text(self, *args, **kwargs):
13
+ self.cond_type = 'prompt'
14
+ return self.sample(*args, **kwargs)
15
+
16
+ @torch.no_grad()
17
+ def sample_vision(self, *args, **kwargs):
18
+ self.cond_type = 'vision'
19
+ return self.sample(*args, **kwargs)
20
+
21
+ @torch.no_grad()
22
+ def sample_mixed(self, *args, **kwargs):
23
+ self.cond_type = kwargs.pop('cond_mixed_p')
24
+ return self.sample(*args, **kwargs)
25
+
26
+ @torch.no_grad()
27
+ def sample(self,
28
+ steps,
29
+ shape,
30
+ xt=None,
31
+ conditioning=None,
32
+ eta=0.,
33
+ temperature=1.,
34
+ noise_dropout=0.,
35
+ verbose=True,
36
+ log_every_t=100,
37
+ unconditional_guidance_scale=1.,
38
+ unconditional_conditioning=None,):
39
+
40
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
41
+ # sampling
42
+ print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
43
+
44
+ samples, intermediates = self.ddim_sampling(
45
+ conditioning,
46
+ shape,
47
+ xt=xt,
48
+ ddim_use_original_steps=False,
49
+ noise_dropout=noise_dropout,
50
+ temperature=temperature,
51
+ log_every_t=log_every_t,
52
+ unconditional_guidance_scale=unconditional_guidance_scale,
53
+ unconditional_conditioning=unconditional_conditioning,)
54
+ return samples, intermediates
55
+
56
+ @torch.no_grad()
57
+ def ddim_sampling(self,
58
+ conditioning,
59
+ shape,
60
+ xt=None,
61
+ ddim_use_original_steps=False,
62
+ timesteps=None,
63
+ log_every_t=100,
64
+ temperature=1.,
65
+ noise_dropout=0.,
66
+ unconditional_guidance_scale=1.,
67
+ unconditional_conditioning=None,):
68
+ device = self.model.betas.device
69
+ bs = shape[0]
70
+ if xt is None:
71
+ img = torch.randn(shape, device=device)
72
+ else:
73
+ img = xt
74
+
75
+ if timesteps is None:
76
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
77
+ elif timesteps is not None and not ddim_use_original_steps:
78
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
79
+ timesteps = self.ddim_timesteps[:subset_end]
80
+
81
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
82
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
83
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
84
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
85
+
86
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
87
+
88
+ for i, step in enumerate(iterator):
89
+ index = total_steps - i - 1
90
+ ts = torch.full((bs,), step, device=device, dtype=torch.long)
91
+
92
+ outs = self.p_sample_ddim(img, conditioning, ts, index=index, use_original_steps=ddim_use_original_steps,
93
+ temperature=temperature,
94
+ noise_dropout=noise_dropout,
95
+ unconditional_guidance_scale=unconditional_guidance_scale,
96
+ unconditional_conditioning=unconditional_conditioning)
97
+ img, pred_x0 = outs
98
+
99
+ if index % log_every_t == 0 or index == total_steps - 1:
100
+ intermediates['x_inter'].append(img)
101
+ intermediates['pred_x0'].append(pred_x0)
102
+
103
+ return img, intermediates
104
+
105
+ @torch.no_grad()
106
+ def p_sample_ddim(self, x, conditioning, t, index, repeat_noise=False, use_original_steps=False,
107
+ temperature=1., noise_dropout=0.,
108
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
109
+ b, *_, device = *x.shape, x.device
110
+
111
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
112
+ e_t = self.model.apply_model(x, t, conditioning, cond_type=self.cond_type)
113
+ else:
114
+ x_in = torch.cat([x] * 2)
115
+ t_in = torch.cat([t] * 2)
116
+ # c_in = torch.cat([unconditional_conditioning, conditioning])
117
+
118
+ # Added for vd-dc dual guidance
119
+ if isinstance(unconditional_conditioning, list):
120
+ c_in = [torch.cat([ui, ci]) for ui, ci in zip(unconditional_conditioning, conditioning)]
121
+ else:
122
+ c_in = torch.cat([unconditional_conditioning, conditioning])
123
+
124
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, cond_type=self.cond_type).chunk(2)
125
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
126
+
127
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
128
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
129
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
130
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
131
+ # select parameters corresponding to the currently considered timestep
132
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
133
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
134
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
135
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
136
+
137
+ # current prediction for x_0
138
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
139
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
140
+ noise = sigma_t * noise_like(x, repeat_noise) * temperature
141
+ if noise_dropout > 0.:
142
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
143
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
144
+ return x_prev, pred_x0
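The dual-context subclass only changes how conditioning is routed: `sample_text`, `sample_vision`, and `sample_mixed` set `self.cond_type` before delegating to `sample`, and `p_sample_ddim` forwards it to `apply_model`. A hedged sketch with placeholder model and embeddings:

sampler = DDIMSampler_DualContext(model)           # model: assumed dual-context VD network
x, _ = sampler.sample_text(
    steps=50,
    shape=(4, 4, 64, 64),                          # full latent shape (B, C, H, W)
    conditioning=c_text,                           # placeholder text embedding
    unconditional_guidance_scale=7.5,
    unconditional_conditioning=uc_text)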
lib/model_zoo/ddim_dualmodel.py ADDED
@@ -0,0 +1,244 @@
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from functools import partial
5
+
6
+ from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
+
8
+ from .ddim import DDIMSampler
9
+
10
+ class DDIMSampler_DualModel(DDIMSampler):
11
+ def __init__(self, model_t2i, model_v2i, schedule="linear", **kwargs):
12
+ self.model = model_t2i
13
+ self.model_t2i = model_t2i
14
+ self.model_v2i = model_v2i
15
+ self.device = self.model_t2i.device
16
+ self.ddpm_num_timesteps = model_t2i.num_timesteps
17
+ self.schedule = schedule
18
+
19
+ @torch.no_grad()
20
+ def sample_text(self, *args, **kwargs):
21
+ self.cond_type = 'prompt'
22
+ self.p_sample_model_type = 't2i'
23
+ return self.sample(*args, **kwargs)
24
+
25
+ @torch.no_grad()
26
+ def sample_vision(self, *args, **kwargs):
27
+ self.cond_type = 'vision'
28
+ self.p_sample_model_type = 'v2i'
29
+ return self.sample(*args, **kwargs)
30
+
31
+ @torch.no_grad()
32
+ def sample(self,
33
+ steps,
34
+ shape,
35
+ xt=None,
36
+ conditioning=None,
37
+ eta=0.,
38
+ temperature=1.,
39
+ noise_dropout=0.,
40
+ verbose=True,
41
+ log_every_t=100,
42
+ unconditional_guidance_scale=1.,
43
+ unconditional_conditioning=None,):
44
+
45
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
46
+ # sampling
47
+ print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
48
+
49
+ samples, intermediates = self.ddim_sampling(
50
+ conditioning,
51
+ shape,
52
+ xt=xt,
53
+ ddim_use_original_steps=False,
54
+ noise_dropout=noise_dropout,
55
+ temperature=temperature,
56
+ log_every_t=log_every_t,
57
+ unconditional_guidance_scale=unconditional_guidance_scale,
58
+ unconditional_conditioning=unconditional_conditioning,)
59
+ return samples, intermediates
60
+
61
+ @torch.no_grad()
62
+ def ddim_sampling(self,
63
+ conditioning,
64
+ shape,
65
+ xt=None,
66
+ ddim_use_original_steps=False,
67
+ timesteps=None,
68
+ log_every_t=100,
69
+ temperature=1.,
70
+ noise_dropout=0.,
71
+ unconditional_guidance_scale=1.,
72
+ unconditional_conditioning=None,):
73
+ device = self.model.betas.device
74
+ bs = shape[0]
75
+ if xt is None:
76
+ img = torch.randn(shape, device=device)
77
+ else:
78
+ img = xt
79
+
80
+ if timesteps is None:
81
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
82
+ elif timesteps is not None and not ddim_use_original_steps:
83
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
84
+ timesteps = self.ddim_timesteps[:subset_end]
85
+
86
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
87
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
88
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
89
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
90
+
91
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
92
+
93
+ for i, step in enumerate(iterator):
94
+ index = total_steps - i - 1
95
+ ts = torch.full((bs,), step, device=device, dtype=torch.long)
96
+
97
+ outs = self.p_sample_ddim(img, conditioning, ts, index=index, use_original_steps=ddim_use_original_steps,
98
+ temperature=temperature,
99
+ noise_dropout=noise_dropout,
100
+ unconditional_guidance_scale=unconditional_guidance_scale,
101
+ unconditional_conditioning=unconditional_conditioning)
102
+ img, pred_x0 = outs
103
+
104
+ if index % log_every_t == 0 or index == total_steps - 1:
105
+ intermediates['x_inter'].append(img)
106
+ intermediates['pred_x0'].append(pred_x0)
107
+
108
+ return img, intermediates
109
+
110
+ @torch.no_grad()
111
+ def p_sample_ddim(self, x, conditioning, t, index, repeat_noise=False, use_original_steps=False,
112
+ temperature=1., noise_dropout=0.,
113
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
114
+ b, *_, device = *x.shape, x.device
115
+
116
+ if self.p_sample_model_type == 't2i':
117
+ apply_model = self.model_t2i.apply_model
118
+ elif self.p_sample_model_type == 'v2i':
119
+ apply_model = self.model_v2i.apply_model
120
+
121
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
122
+ e_t = apply_model(x, t, conditioning)
123
+ else:
124
+ x_in = torch.cat([x] * 2)
125
+ t_in = torch.cat([t] * 2)
126
+ c_in = torch.cat([unconditional_conditioning, conditioning])
127
+ e_t_uncond, e_t = apply_model(x_in, t_in, c_in).chunk(2)
128
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
129
+
130
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
131
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
132
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
133
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
134
+ # select parameters corresponding to the currently considered timestep
135
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
136
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
137
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
138
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
139
+
140
+ # current prediction for x_0
141
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
142
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
143
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
144
+ if noise_dropout > 0.:
145
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
146
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
147
+ return x_prev, pred_x0
148
+
149
+ @torch.no_grad()
150
+ def sample_mixed(self,
151
+ steps,
152
+ steps_t2i,
153
+ steps_v2i,
154
+ shape,
155
+ xt=None,
156
+ c_prompt=None,
157
+ c_vision=None,
158
+ eta=0.,
159
+ temperature=1.,
160
+ noise_dropout=0.,
161
+ verbose=True,
162
+ log_every_t=100,
163
+ uc_scale=1.,
164
+ uc_prompt=None,
165
+ uc_vision=None,):
166
+
167
+ print(f'DDIM mixed sampling with shape {shape}, eta {eta}')
168
+ print(f'steps_t2i {steps_t2i}')
169
+ print(f'steps_v2i {steps_v2i}')
170
+
171
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
172
+ self.ddim_timesteps_t2i = self.ddim_timesteps[steps_t2i]
173
+ self.ddim_timesteps_v2i = self.ddim_timesteps[steps_v2i]
174
+
175
+ samples, intermediates = self.ddim_sampling_mixed(
176
+ c_prompt,
177
+ c_vision,
178
+ shape,
179
+ xt=xt,
180
+ noise_dropout=noise_dropout,
181
+ temperature=temperature,
182
+ log_every_t=log_every_t,
183
+ uc_scale=uc_scale,
184
+ uc_prompt=uc_prompt,
185
+ uc_vision=uc_vision, )
186
+ return samples, intermediates
187
+
188
+ @torch.no_grad()
189
+ def ddim_sampling_mixed(self,
190
+ c_prompt,
191
+ c_vision,
192
+ shape,
193
+ xt=None,
194
+ log_every_t=100,
195
+ temperature=1.,
196
+ noise_dropout=0.,
197
+ uc_scale=1.,
198
+ uc_prompt=None,
199
+ uc_vision=None, ):
200
+ device = self.device
201
+ bs = shape[0]
202
+ if xt is None:
203
+ img = torch.randn(shape, device=device)
204
+ else:
205
+ img = xt
206
+
207
+ timesteps = self.ddim_timesteps
208
+ intermediates = {'x_inter': [], 'pred_x0': []}
209
+ time_range = np.flip(timesteps)
210
+ total_steps = timesteps.shape[0]
211
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
212
+
213
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
214
+
215
+ for i, step in enumerate(iterator):
216
+ if step in self.ddim_timesteps_t2i:
217
+ self.p_sample_model_type = 't2i'
218
+ conditioning = c_prompt
219
+ unconditional_conditioning = uc_prompt
220
+ elif step in self.ddim_timesteps_v2i:
221
+ self.p_sample_model_type = 'v2i'
222
+ conditioning = c_vision
223
+ unconditional_conditioning = uc_vision
224
+ else:
225
+ raise ValueError  # should never be reached: every step must be in steps_t2i or steps_v2i
226
+
227
+ index = total_steps - i - 1
228
+ ts = torch.full((bs,), step, device=device, dtype=torch.long)
229
+ outs = self.p_sample_ddim(
230
+ img, conditioning, ts,
231
+ index=index,
232
+ temperature=temperature,
233
+ noise_dropout=noise_dropout,
234
+ unconditional_guidance_scale=uc_scale,
235
+ unconditional_conditioning=unconditional_conditioning)
236
+ img, pred_x0 = outs
237
+
238
+ if index % log_every_t == 0 or index == total_steps - 1:
239
+ intermediates['x_inter'].append(img)
240
+ intermediates['pred_x0'].append(pred_x0)
241
+
242
+ return img, intermediates
243
+
244
+
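For `sample_mixed`, `steps_t2i` and `steps_v2i` are index lists into the DDIM schedule and together must cover every step (otherwise the `ValueError` above fires). An illustrative even/odd split between the two models, with placeholder models and embeddings:

sampler = DDIMSampler_DualModel(model_t2i, model_v2i)   # both models are placeholders
steps = 50
steps_t2i = list(range(0, steps, 2))                    # even schedule indices -> text-to-image model
steps_v2i = list(range(1, steps, 2))                    # odd schedule indices  -> image-variation model
x, _ = sampler.sample_mixed(
    steps, steps_t2i, steps_v2i,
    shape=(4, 4, 64, 64),
    c_prompt=c_text, c_vision=c_img,
    uc_scale=7.5, uc_prompt=uc_text, uc_vision=uc_img)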
lib/model_zoo/ddim_vd.py ADDED
@@ -0,0 +1,290 @@
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from functools import partial
5
+
6
+ from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
+
8
+ from .ddim import DDIMSampler
9
+
10
+ class DDIMSampler_VD(DDIMSampler):
11
+ @torch.no_grad()
12
+ def sample(self,
13
+ steps,
14
+ shape,
15
+ xt=None,
16
+ conditioning=None,
17
+ unconditional_guidance_scale=1.,
18
+ unconditional_conditioning=None,
19
+ xtype='image',
20
+ ctype='prompt',
21
+ eta=0.,
22
+ temperature=1.,
23
+ noise_dropout=0.,
24
+ verbose=True,
25
+ log_every_t=100,):
26
+
27
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
28
+ print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
29
+ samples, intermediates = self.ddim_sampling(
30
+ shape,
31
+ xt=xt,
32
+ conditioning=conditioning,
33
+ unconditional_guidance_scale=unconditional_guidance_scale,
34
+ unconditional_conditioning=unconditional_conditioning,
35
+ xtype=xtype,
36
+ ctype=ctype,
37
+ ddim_use_original_steps=False,
38
+ noise_dropout=noise_dropout,
39
+ temperature=temperature,
40
+ log_every_t=log_every_t,)
41
+ return samples, intermediates
42
+
43
+ @torch.no_grad()
44
+ def ddim_sampling(self,
45
+ shape,
46
+ xt=None,
47
+ conditioning=None,
48
+ unconditional_guidance_scale=1.,
49
+ unconditional_conditioning=None,
50
+ xtype='image',
51
+ ctype='prompt',
52
+ ddim_use_original_steps=False,
53
+ timesteps=None,
54
+ noise_dropout=0.,
55
+ temperature=1.,
56
+ log_every_t=100,):
57
+
58
+ device = self.model.device
59
+ bs = shape[0]
60
+ if xt is None:
61
+ xt = torch.randn(shape, device=device, dtype=conditioning.dtype)
62
+
63
+ if timesteps is None:
64
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
65
+ elif timesteps is not None and not ddim_use_original_steps:
66
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
67
+ timesteps = self.ddim_timesteps[:subset_end]
68
+
69
+ intermediates = {'pred_xt': [], 'pred_x0': []}
70
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
71
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
72
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
73
+
74
+ pred_xt = xt
75
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
76
+ for i, step in enumerate(iterator):
77
+ index = total_steps - i - 1
78
+ ts = torch.full((bs,), step, device=device, dtype=torch.long)
79
+
80
+ outs = self.p_sample_ddim(
81
+ pred_xt, conditioning, ts, index,
82
+ unconditional_guidance_scale=unconditional_guidance_scale,
83
+ unconditional_conditioning=unconditional_conditioning,
84
+ xtype=xtype,
85
+ ctype=ctype,
86
+ use_original_steps=ddim_use_original_steps,
87
+ noise_dropout=noise_dropout,
88
+ temperature=temperature,)
89
+ pred_xt, pred_x0 = outs
90
+
91
+ if index % log_every_t == 0 or index == total_steps - 1:
92
+ intermediates['pred_xt'].append(pred_xt)
93
+ intermediates['pred_x0'].append(pred_x0)
94
+
95
+ return pred_xt, intermediates
96
+
97
+ @torch.no_grad()
98
+ def p_sample_ddim(self, x, conditioning, t, index,
99
+ unconditional_guidance_scale=1.,
100
+ unconditional_conditioning=None,
101
+ xtype='image',
102
+ ctype='prompt',
103
+ repeat_noise=False,
104
+ use_original_steps=False,
105
+ noise_dropout=0.,
106
+ temperature=1.,):
107
+
108
+ b, *_, device = *x.shape, x.device
109
+
110
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
111
+ e_t = self.model.apply_model(x, t, conditioning, xtype=xtype, ctype=ctype)
112
+ else:
113
+ x_in = torch.cat([x] * 2)
114
+ t_in = torch.cat([t] * 2)
115
+ c_in = torch.cat([unconditional_conditioning, conditioning])
116
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, xtype=xtype, ctype=ctype).chunk(2)
117
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
118
+
119
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
120
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
121
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
122
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
123
+ # select parameters corresponding to the currently considered timestep
124
+
125
+ if xtype == 'image':
126
+ extended_shape = (b, 1, 1, 1)
127
+ elif xtype == 'text':
128
+ extended_shape = (b, 1)
129
+
130
+ a_t = torch.full(extended_shape, alphas[index], device=device, dtype=x.dtype)
131
+ a_prev = torch.full(extended_shape, alphas_prev[index], device=device, dtype=x.dtype)
132
+ sigma_t = torch.full(extended_shape, sigmas[index], device=device, dtype=x.dtype)
133
+ sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device, dtype=x.dtype)
134
+
135
+ # current prediction for x_0
136
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
137
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
138
+ noise = sigma_t * noise_like(x, repeat_noise) * temperature
139
+ if noise_dropout > 0.:
140
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
141
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
142
+ return x_prev, pred_x0
143
+
144
+ @torch.no_grad()
145
+ def sample_dc(self,
146
+ steps,
147
+ shape,
148
+ xt=None,
149
+ first_conditioning=None,
150
+ second_conditioning=None,
151
+ unconditional_guidance_scale=1.,
152
+ xtype='image',
153
+ first_ctype='prompt',
154
+ second_ctype='prompt',
155
+ eta=0.,
156
+ temperature=1.,
157
+ mixed_ratio=0.5,
158
+ noise_dropout=0.,
159
+ verbose=True,
160
+ log_every_t=100,):
161
+
162
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
163
+ print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
164
+ samples, intermediates = self.ddim_sampling_dc(
165
+ shape,
166
+ xt=xt,
167
+ first_conditioning=first_conditioning,
168
+ second_conditioning=second_conditioning,
169
+ unconditional_guidance_scale=unconditional_guidance_scale,
170
+ xtype=xtype,
171
+ first_ctype=first_ctype,
172
+ second_ctype=second_ctype,
173
+ ddim_use_original_steps=False,
174
+ noise_dropout=noise_dropout,
175
+ temperature=temperature,
176
+ log_every_t=log_every_t,
177
+ mixed_ratio=mixed_ratio, )
178
+ return samples, intermediates
179
+
180
+ @torch.no_grad()
181
+ def ddim_sampling_dc(self,
182
+ shape,
183
+ xt=None,
184
+ first_conditioning=None,
185
+ second_conditioning=None,
186
+ unconditional_guidance_scale=1.,
187
+ xtype='image',
188
+ first_ctype='prompt',
189
+ second_ctype='prompt',
190
+ ddim_use_original_steps=False,
191
+ timesteps=None,
192
+ noise_dropout=0.,
193
+ temperature=1.,
194
+ mixed_ratio=0.5,
195
+ log_every_t=100,):
196
+
197
+ device = self.model.device
198
+ bs = shape[0]
199
+ if xt is None:
200
+ xt = torch.randn(shape, device=device, dtype=first_conditioning[1].dtype)
201
+
202
+ if timesteps is None:
203
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
204
+ elif timesteps is not None and not ddim_use_original_steps:
205
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
206
+ timesteps = self.ddim_timesteps[:subset_end]
207
+
208
+ intermediates = {'pred_xt': [], 'pred_x0': []}
209
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
210
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
211
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
212
+
213
+ pred_xt = xt
214
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
215
+ for i, step in enumerate(iterator):
216
+ index = total_steps - i - 1
217
+ ts = torch.full((bs,), step, device=device, dtype=torch.long)
218
+
219
+ outs = self.p_sample_ddim_dc(
220
+ pred_xt,
221
+ first_conditioning,
222
+ second_conditioning,
223
+ ts, index,
224
+ unconditional_guidance_scale=unconditional_guidance_scale,
225
+ xtype=xtype,
226
+ first_ctype=first_ctype,
227
+ second_ctype=second_ctype,
228
+ use_original_steps=ddim_use_original_steps,
229
+ noise_dropout=noise_dropout,
230
+ temperature=temperature,
231
+ mixed_ratio=mixed_ratio,)
232
+ pred_xt, pred_x0 = outs
233
+
234
+ if index % log_every_t == 0 or index == total_steps - 1:
235
+ intermediates['pred_xt'].append(pred_xt)
236
+ intermediates['pred_x0'].append(pred_x0)
237
+
238
+ return pred_xt, intermediates
239
+
240
+ @torch.no_grad()
241
+ def p_sample_ddim_dc(self, x,
242
+ first_conditioning,
243
+ second_conditioning,
244
+ t, index,
245
+ unconditional_guidance_scale=1.,
246
+ xtype='image',
247
+ first_ctype='prompt',
248
+ second_ctype='prompt',
249
+ repeat_noise=False,
250
+ use_original_steps=False,
251
+ noise_dropout=0.,
252
+ temperature=1.,
253
+ mixed_ratio=0.5,):
254
+
255
+ b, *_, device = *x.shape, x.device
256
+
257
+ x_in = torch.cat([x] * 2)
258
+ t_in = torch.cat([t] * 2)
259
+ first_c = torch.cat(first_conditioning)
260
+ second_c = torch.cat(second_conditioning)
261
+
262
+ e_t_uncond, e_t = self.model.apply_model_dc(
263
+ x_in, t_in, first_c, second_c, xtype=xtype, first_ctype=first_ctype, second_ctype=second_ctype, mixed_ratio=mixed_ratio).chunk(2)
264
+
265
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
266
+
267
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
268
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
269
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
270
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
271
+ # select parameters corresponding to the currently considered timestep
272
+
273
+ if xtype == 'image':
274
+ extended_shape = (b, 1, 1, 1)
275
+ elif xtype == 'text':
276
+ extended_shape = (b, 1)
277
+
278
+ a_t = torch.full(extended_shape, alphas[index], device=device, dtype=x.dtype)
279
+ a_prev = torch.full(extended_shape, alphas_prev[index], device=device, dtype=x.dtype)
280
+ sigma_t = torch.full(extended_shape, sigmas[index], device=device, dtype=x.dtype)
281
+ sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device, dtype=x.dtype)
282
+
283
+ # current prediction for x_0
284
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
285
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
286
+ noise = sigma_t * noise_like(x, repeat_noise) * temperature
287
+ if noise_dropout > 0.:
288
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
289
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
290
+ return x_prev, pred_x0
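In `sample_dc` each conditioning argument is a `[unconditional, conditional]` pair (they are concatenated via `torch.cat(...)` in `p_sample_ddim_dc`), and `mixed_ratio` blends the two contexts inside `apply_model_dc`. A hedged sketch; the model and all embeddings are placeholders, and the ctype strings are assumptions:

sampler = DDIMSampler_VD(vd_model)                 # vd_model: assumed Versatile Diffusion network
z, _ = sampler.sample_dc(
    steps=50,
    shape=(1, 4, 64, 64),
    first_conditioning=[uc_text, c_text],          # [unconditional, conditional]
    second_conditioning=[uc_image, c_image],
    unconditional_guidance_scale=7.5,
    xtype='image',
    first_ctype='prompt',
    second_ctype='vision',
    mixed_ratio=0.5)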
lib/model_zoo/diffusion_modules.py ADDED
@@ -0,0 +1,835 @@
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ # from .diffusion_utils import instantiate_from_config
9
+ from .attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
+ return emb
31
+
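# For example (illustrative), get_timestep_embedding(torch.tensor([0, 1, 2, 3]), 64)
# returns a [4, 64] tensor: the first 32 columns are sin(t / 10000^(k/31)) for
# k = 0..31 and the last 32 columns are the matching cos terms, i.e. the usual
# transformer-style sinusoidal encoding applied to the diffusion timestep t.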
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x*torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
+
41
+
42
+ class Upsample(nn.Module):
43
+ def __init__(self, in_channels, with_conv):
44
+ super().__init__()
45
+ self.with_conv = with_conv
46
+ if self.with_conv:
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=1,
51
+ padding=1)
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels,
67
+ in_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=0)
71
+
72
+ def forward(self, x):
73
+ if self.with_conv:
74
+ pad = (0,1,0,1)
75
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
+ x = self.conv(x)
77
+ else:
78
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
+ return x
80
+
81
+
82
+ class ResnetBlock(nn.Module):
83
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
+ dropout, temb_channels=512):
85
+ super().__init__()
86
+ self.in_channels = in_channels
87
+ out_channels = in_channels if out_channels is None else out_channels
88
+ self.out_channels = out_channels
89
+ self.use_conv_shortcut = conv_shortcut
90
+
91
+ self.norm1 = Normalize(in_channels)
92
+ self.conv1 = torch.nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels,
99
+ out_channels)
100
+ self.norm2 = Normalize(out_channels)
101
+ self.dropout = torch.nn.Dropout(dropout)
102
+ self.conv2 = torch.nn.Conv2d(out_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if self.in_channels != self.out_channels:
108
+ if self.use_conv_shortcut:
109
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ else:
115
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+
121
+ def forward(self, x, temb):
122
+ h = x
123
+ h = self.norm1(h)
124
+ h = nonlinearity(h)
125
+ h = self.conv1(h)
126
+
127
+ if temb is not None:
128
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
+
130
+ h = self.norm2(h)
131
+ h = nonlinearity(h)
132
+ h = self.dropout(h)
133
+ h = self.conv2(h)
134
+
135
+ if self.in_channels != self.out_channels:
136
+ if self.use_conv_shortcut:
137
+ x = self.conv_shortcut(x)
138
+ else:
139
+ x = self.nin_shortcut(x)
140
+
141
+ return x+h
142
+
143
+
144
+ class LinAttnBlock(LinearAttention):
145
+ """to match AttnBlock usage"""
146
+ def __init__(self, in_channels):
147
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
+
149
+
150
+ class AttnBlock(nn.Module):
151
+ def __init__(self, in_channels):
152
+ super().__init__()
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = Normalize(in_channels)
156
+ self.q = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.k = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.v = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+ self.proj_out = torch.nn.Conv2d(in_channels,
172
+ in_channels,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0)
176
+
177
+
178
+ def forward(self, x):
179
+ h_ = x
180
+ h_ = self.norm(h_)
181
+ q = self.q(h_)
182
+ k = self.k(h_)
183
+ v = self.v(h_)
184
+
185
+ # compute attention
186
+ b,c,h,w = q.shape
187
+ q = q.reshape(b,c,h*w)
188
+ q = q.permute(0,2,1) # b,hw,c
189
+ k = k.reshape(b,c,h*w) # b,c,hw
190
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
+ w_ = w_ * (int(c)**(-0.5))
192
+ w_ = torch.nn.functional.softmax(w_, dim=2)
193
+
194
+ # attend to values
195
+ v = v.reshape(b,c,h*w)
196
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
+ h_ = h_.reshape(b,c,h,w)
199
+
200
+ h_ = self.proj_out(h_)
201
+
202
+ return x+h_
203
+
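# Note (illustrative): this is plain single-head dot-product self-attention over
# the h*w spatial positions, w = softmax(q k^T / sqrt(c)), added back to the
# input through the residual x + h_. For instance, AttnBlock(64) maps a
# [1, 64, 16, 16] tensor to a tensor of the same shape.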
204
+
205
+ def make_attn(in_channels, attn_type="vanilla"):
206
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
+ if attn_type == "vanilla":
209
+ return AttnBlock(in_channels)
210
+ elif attn_type == "none":
211
+ return nn.Identity(in_channels)
212
+ else:
213
+ return LinAttnBlock(in_channels)
214
+
215
+
216
+ class Model(nn.Module):
217
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
+ super().__init__()
221
+ if use_linear_attn: attn_type = "linear"
222
+ self.ch = ch
223
+ self.temb_ch = self.ch*4
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.in_channels = in_channels
228
+
229
+ self.use_timestep = use_timestep
230
+ if self.use_timestep:
231
+ # timestep embedding
232
+ self.temb = nn.Module()
233
+ self.temb.dense = nn.ModuleList([
234
+ torch.nn.Linear(self.ch,
235
+ self.temb_ch),
236
+ torch.nn.Linear(self.temb_ch,
237
+ self.temb_ch),
238
+ ])
239
+
240
+ # downsampling
241
+ self.conv_in = torch.nn.Conv2d(in_channels,
242
+ self.ch,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ curr_res = resolution
248
+ in_ch_mult = (1,)+tuple(ch_mult)
249
+ self.down = nn.ModuleList()
250
+ for i_level in range(self.num_resolutions):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_in = ch*in_ch_mult[i_level]
254
+ block_out = ch*ch_mult[i_level]
255
+ for i_block in range(self.num_res_blocks):
256
+ block.append(ResnetBlock(in_channels=block_in,
257
+ out_channels=block_out,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout))
260
+ block_in = block_out
261
+ if curr_res in attn_resolutions:
262
+ attn.append(make_attn(block_in, attn_type=attn_type))
263
+ down = nn.Module()
264
+ down.block = block
265
+ down.attn = attn
266
+ if i_level != self.num_resolutions-1:
267
+ down.downsample = Downsample(block_in, resamp_with_conv)
268
+ curr_res = curr_res // 2
269
+ self.down.append(down)
270
+
271
+ # middle
272
+ self.mid = nn.Module()
273
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
+ out_channels=block_in,
275
+ temb_channels=self.temb_ch,
276
+ dropout=dropout)
277
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
+ out_channels=block_in,
280
+ temb_channels=self.temb_ch,
281
+ dropout=dropout)
282
+
283
+ # upsampling
284
+ self.up = nn.ModuleList()
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ block = nn.ModuleList()
287
+ attn = nn.ModuleList()
288
+ block_out = ch*ch_mult[i_level]
289
+ skip_in = ch*ch_mult[i_level]
290
+ for i_block in range(self.num_res_blocks+1):
291
+ if i_block == self.num_res_blocks:
292
+ skip_in = ch*in_ch_mult[i_level]
293
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
294
+ out_channels=block_out,
295
+ temb_channels=self.temb_ch,
296
+ dropout=dropout))
297
+ block_in = block_out
298
+ if curr_res in attn_resolutions:
299
+ attn.append(make_attn(block_in, attn_type=attn_type))
300
+ up = nn.Module()
301
+ up.block = block
302
+ up.attn = attn
303
+ if i_level != 0:
304
+ up.upsample = Upsample(block_in, resamp_with_conv)
305
+ curr_res = curr_res * 2
306
+ self.up.insert(0, up) # prepend to get consistent order
307
+
308
+ # end
309
+ self.norm_out = Normalize(block_in)
310
+ self.conv_out = torch.nn.Conv2d(block_in,
311
+ out_ch,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ def forward(self, x, t=None, context=None):
317
+ #assert x.shape[2] == x.shape[3] == self.resolution
318
+ if context is not None:
319
+ # assume aligned context, cat along channel axis
320
+ x = torch.cat((x, context), dim=1)
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ assert t is not None
324
+ temb = get_timestep_embedding(t, self.ch)
325
+ temb = self.temb.dense[0](temb)
326
+ temb = nonlinearity(temb)
327
+ temb = self.temb.dense[1](temb)
328
+ else:
329
+ temb = None
330
+
331
+ # downsampling
332
+ hs = [self.conv_in(x)]
333
+ for i_level in range(self.num_resolutions):
334
+ for i_block in range(self.num_res_blocks):
335
+ h = self.down[i_level].block[i_block](hs[-1], temb)
336
+ if len(self.down[i_level].attn) > 0:
337
+ h = self.down[i_level].attn[i_block](h)
338
+ hs.append(h)
339
+ if i_level != self.num_resolutions-1:
340
+ hs.append(self.down[i_level].downsample(hs[-1]))
341
+
342
+ # middle
343
+ h = hs[-1]
344
+ h = self.mid.block_1(h, temb)
345
+ h = self.mid.attn_1(h)
346
+ h = self.mid.block_2(h, temb)
347
+
348
+ # upsampling
349
+ for i_level in reversed(range(self.num_resolutions)):
350
+ for i_block in range(self.num_res_blocks+1):
351
+ h = self.up[i_level].block[i_block](
352
+ torch.cat([h, hs.pop()], dim=1), temb)
353
+ if len(self.up[i_level].attn) > 0:
354
+ h = self.up[i_level].attn[i_block](h)
355
+ if i_level != 0:
356
+ h = self.up[i_level].upsample(h)
357
+
358
+ # end
359
+ h = self.norm_out(h)
360
+ h = nonlinearity(h)
361
+ h = self.conv_out(h)
362
+ return h
363
+
364
+ def get_last_layer(self):
365
+ return self.conv_out.weight
366
+
367
+
368
+ class Encoder(nn.Module):
369
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
+ **ignore_kwargs):
373
+ super().__init__()
374
+ if use_linear_attn: attn_type = "linear"
375
+ self.ch = ch
376
+ self.temb_ch = 0
377
+ self.num_resolutions = len(ch_mult)
378
+ self.num_res_blocks = num_res_blocks
379
+ self.resolution = resolution
380
+ self.in_channels = in_channels
381
+
382
+ # downsampling
383
+ self.conv_in = torch.nn.Conv2d(in_channels,
384
+ self.ch,
385
+ kernel_size=3,
386
+ stride=1,
387
+ padding=1)
388
+
389
+ curr_res = resolution
390
+ in_ch_mult = (1,)+tuple(ch_mult)
391
+ self.in_ch_mult = in_ch_mult
392
+ self.down = nn.ModuleList()
393
+ for i_level in range(self.num_resolutions):
394
+ block = nn.ModuleList()
395
+ attn = nn.ModuleList()
396
+ block_in = ch*in_ch_mult[i_level]
397
+ block_out = ch*ch_mult[i_level]
398
+ for i_block in range(self.num_res_blocks):
399
+ block.append(ResnetBlock(in_channels=block_in,
400
+ out_channels=block_out,
401
+ temb_channels=self.temb_ch,
402
+ dropout=dropout))
403
+ block_in = block_out
404
+ if curr_res in attn_resolutions:
405
+ attn.append(make_attn(block_in, attn_type=attn_type))
406
+ down = nn.Module()
407
+ down.block = block
408
+ down.attn = attn
409
+ if i_level != self.num_resolutions-1:
410
+ down.downsample = Downsample(block_in, resamp_with_conv)
411
+ curr_res = curr_res // 2
412
+ self.down.append(down)
413
+
414
+ # middle
415
+ self.mid = nn.Module()
416
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
+ out_channels=block_in,
418
+ temb_channels=self.temb_ch,
419
+ dropout=dropout)
420
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
+ out_channels=block_in,
423
+ temb_channels=self.temb_ch,
424
+ dropout=dropout)
425
+
426
+ # end
427
+ self.norm_out = Normalize(block_in)
428
+ self.conv_out = torch.nn.Conv2d(block_in,
429
+ 2*z_channels if double_z else z_channels,
430
+ kernel_size=3,
431
+ stride=1,
432
+ padding=1)
433
+
434
+ def forward(self, x):
435
+ # timestep embedding
436
+ temb = None
437
+
438
+ # downsampling
439
+ hs = [self.conv_in(x)]
440
+ for i_level in range(self.num_resolutions):
441
+ for i_block in range(self.num_res_blocks):
442
+ h = self.down[i_level].block[i_block](hs[-1], temb)
443
+ if len(self.down[i_level].attn) > 0:
444
+ h = self.down[i_level].attn[i_block](h)
445
+ hs.append(h)
446
+ if i_level != self.num_resolutions-1:
447
+ hs.append(self.down[i_level].downsample(hs[-1]))
448
+
449
+ # middle
450
+ h = hs[-1]
451
+ h = self.mid.block_1(h, temb)
452
+ h = self.mid.attn_1(h)
453
+ h = self.mid.block_2(h, temb)
454
+
455
+ # end
456
+ h = self.norm_out(h)
457
+ h = nonlinearity(h)
458
+ h = self.conv_out(h)
459
+ return h
460
+
461
+
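# Shape note (illustrative, assuming the usual SD-style autoencoder settings
# ch=128, ch_mult=(1, 2, 4, 4), resolution=256, in_channels=3, z_channels=4,
# double_z=True): the encoder maps [B, 3, 256, 256] images to [B, 8, 32, 32]
# moments, i.e. spatial size divided by 2**(len(ch_mult)-1) with 2*z_channels
# output channels.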
462
+ class Decoder(nn.Module):
463
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
+ attn_type="vanilla", **ignorekwargs):
467
+ super().__init__()
468
+ if use_linear_attn: attn_type = "linear"
469
+ self.ch = ch
470
+ self.temb_ch = 0
471
+ self.num_resolutions = len(ch_mult)
472
+ self.num_res_blocks = num_res_blocks
473
+ self.resolution = resolution
474
+ self.in_channels = in_channels
475
+ self.give_pre_end = give_pre_end
476
+ self.tanh_out = tanh_out
477
+
478
+ # compute in_ch_mult, block_in and curr_res at lowest res
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ block_in = ch*ch_mult[self.num_resolutions-1]
481
+ curr_res = resolution // 2**(self.num_resolutions-1)
482
+ self.z_shape = (1,z_channels,curr_res,curr_res)
483
+ print("Working with z of shape {} = {} dimensions.".format(
484
+ self.z_shape, np.prod(self.z_shape)))
485
+
486
+ # z to block_in
487
+ self.conv_in = torch.nn.Conv2d(z_channels,
488
+ block_in,
489
+ kernel_size=3,
490
+ stride=1,
491
+ padding=1)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
+ out_channels=block_in,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout)
499
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch*ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks+1):
512
+ block.append(ResnetBlock(in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout))
516
+ block_in = block_out
517
+ if curr_res in attn_resolutions:
518
+ attn.append(make_attn(block_in, attn_type=attn_type))
519
+ up = nn.Module()
520
+ up.block = block
521
+ up.attn = attn
522
+ if i_level != 0:
523
+ up.upsample = Upsample(block_in, resamp_with_conv)
524
+ curr_res = curr_res * 2
525
+ self.up.insert(0, up) # prepend to get consistent order
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(block_in,
530
+ out_ch,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1)
534
+
535
+ def forward(self, z):
536
+ #assert z.shape[1:] == self.z_shape[1:]
537
+ self.last_z_shape = z.shape
538
+
539
+ # timestep embedding
540
+ temb = None
541
+
542
+ # z to block_in
543
+ h = self.conv_in(z)
544
+
545
+ # middle
546
+ h = self.mid.block_1(h, temb)
547
+ h = self.mid.attn_1(h)
548
+ h = self.mid.block_2(h, temb)
549
+
550
+ # upsampling
551
+ for i_level in reversed(range(self.num_resolutions)):
552
+ for i_block in range(self.num_res_blocks+1):
553
+ h = self.up[i_level].block[i_block](h, temb)
554
+ if len(self.up[i_level].attn) > 0:
555
+ h = self.up[i_level].attn[i_block](h)
556
+ if i_level != 0:
557
+ h = self.up[i_level].upsample(h)
558
+
559
+ # end
560
+ if self.give_pre_end:
561
+ return h
562
+
563
+ h = self.norm_out(h)
564
+ h = nonlinearity(h)
565
+ h = self.conv_out(h)
566
+ if self.tanh_out:
567
+ h = torch.tanh(h)
568
+ return h
569
+
570
+
571
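For orientation, here is a minimal round-trip sketch of the Encoder/Decoder pair above. It assumes the classes (and the supporting blocks such as ResnetBlock and Downsample defined earlier in this file) are importable from lib/model_zoo/autoencoder.py, and the hyperparameters are illustrative placeholders rather than the project's shipped configs:

import torch
from lib.model_zoo.autoencoder import Encoder, Decoder

# shared keyword arguments; Encoder ignores out_ch and Decoder ignores in_channels
cfg = dict(ch=64, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[],
           in_channels=3, resolution=256, z_channels=4, out_ch=3)
enc = Encoder(double_z=True, **cfg)   # double_z -> outputs mean and logvar channels
dec = Decoder(**cfg)

x = torch.randn(1, 3, 256, 256)
moments = enc(x)                      # [1, 2*z_channels, 32, 32] after 3 downsamplings
mean, logvar = torch.chunk(moments, 2, dim=1)
x_rec = dec(mean)                     # decode the mode: [1, 3, 256, 256]
print(moments.shape, x_rec.shape)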
+ class SimpleDecoder(nn.Module):
572
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
573
+ super().__init__()
574
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
575
+ ResnetBlock(in_channels=in_channels,
576
+ out_channels=2 * in_channels,
577
+ temb_channels=0, dropout=0.0),
578
+ ResnetBlock(in_channels=2 * in_channels,
579
+ out_channels=4 * in_channels,
580
+ temb_channels=0, dropout=0.0),
581
+ ResnetBlock(in_channels=4 * in_channels,
582
+ out_channels=2 * in_channels,
583
+ temb_channels=0, dropout=0.0),
584
+ nn.Conv2d(2*in_channels, in_channels, 1),
585
+ Upsample(in_channels, with_conv=True)])
586
+ # end
587
+ self.norm_out = Normalize(in_channels)
588
+ self.conv_out = torch.nn.Conv2d(in_channels,
589
+ out_channels,
590
+ kernel_size=3,
591
+ stride=1,
592
+ padding=1)
593
+
594
+ def forward(self, x):
595
+ for i, layer in enumerate(self.model):
596
+ if i in [1,2,3]:
597
+ x = layer(x, None)
598
+ else:
599
+ x = layer(x)
600
+
601
+ h = self.norm_out(x)
602
+ h = nonlinearity(h)
603
+ x = self.conv_out(h)
604
+ return x
605
+
606
+
607
+ class UpsampleDecoder(nn.Module):
608
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
609
+ ch_mult=(2,2), dropout=0.0):
610
+ super().__init__()
611
+ # upsampling
612
+ self.temb_ch = 0
613
+ self.num_resolutions = len(ch_mult)
614
+ self.num_res_blocks = num_res_blocks
615
+ block_in = in_channels
616
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
617
+ self.res_blocks = nn.ModuleList()
618
+ self.upsample_blocks = nn.ModuleList()
619
+ for i_level in range(self.num_resolutions):
620
+ res_block = []
621
+ block_out = ch * ch_mult[i_level]
622
+ for i_block in range(self.num_res_blocks + 1):
623
+ res_block.append(ResnetBlock(in_channels=block_in,
624
+ out_channels=block_out,
625
+ temb_channels=self.temb_ch,
626
+ dropout=dropout))
627
+ block_in = block_out
628
+ self.res_blocks.append(nn.ModuleList(res_block))
629
+ if i_level != self.num_resolutions - 1:
630
+ self.upsample_blocks.append(Upsample(block_in, True))
631
+ curr_res = curr_res * 2
632
+
633
+ # end
634
+ self.norm_out = Normalize(block_in)
635
+ self.conv_out = torch.nn.Conv2d(block_in,
636
+ out_channels,
637
+ kernel_size=3,
638
+ stride=1,
639
+ padding=1)
640
+
641
+ def forward(self, x):
642
+ # upsampling
643
+ h = x
644
+ for k, i_level in enumerate(range(self.num_resolutions)):
645
+ for i_block in range(self.num_res_blocks + 1):
646
+ h = self.res_blocks[i_level][i_block](h, None)
647
+ if i_level != self.num_resolutions - 1:
648
+ h = self.upsample_blocks[k](h)
649
+ h = self.norm_out(h)
650
+ h = nonlinearity(h)
651
+ h = self.conv_out(h)
652
+ return h
653
+
654
+
655
+ class LatentRescaler(nn.Module):
656
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
657
+ super().__init__()
658
+ # residual block, interpolate, residual block
659
+ self.factor = factor
660
+ self.conv_in = nn.Conv2d(in_channels,
661
+ mid_channels,
662
+ kernel_size=3,
663
+ stride=1,
664
+ padding=1)
665
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
666
+ out_channels=mid_channels,
667
+ temb_channels=0,
668
+ dropout=0.0) for _ in range(depth)])
669
+ self.attn = AttnBlock(mid_channels)
670
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
671
+ out_channels=mid_channels,
672
+ temb_channels=0,
673
+ dropout=0.0) for _ in range(depth)])
674
+
675
+ self.conv_out = nn.Conv2d(mid_channels,
676
+ out_channels,
677
+ kernel_size=1,
678
+ )
679
+
680
+ def forward(self, x):
681
+ x = self.conv_in(x)
682
+ for block in self.res_block1:
683
+ x = block(x, None)
684
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
685
+ x = self.attn(x)
686
+ for block in self.res_block2:
687
+ x = block(x, None)
688
+ x = self.conv_out(x)
689
+ return x
690
+
691
+
692
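A quick shape check for LatentRescaler, again assuming it is importable from this module; the channel counts are arbitrary illustration values. The spatial size is multiplied by `factor`, while the channel count follows in_channels -> mid_channels -> out_channels:

import torch
from lib.model_zoo.autoencoder import LatentRescaler

rescaler = LatentRescaler(factor=0.5, in_channels=16, mid_channels=32,
                          out_channels=8, depth=2)
z = torch.randn(2, 16, 64, 64)
print(rescaler(z).shape)   # torch.Size([2, 8, 32, 32]); spatial dims scaled by factor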
+ class MergedRescaleEncoder(nn.Module):
693
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
694
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
695
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
696
+ super().__init__()
697
+ intermediate_chn = ch * ch_mult[-1]
698
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
699
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
700
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
701
+ out_ch=None)
702
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
703
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
704
+
705
+ def forward(self, x):
706
+ x = self.encoder(x)
707
+ x = self.rescaler(x)
708
+ return x
709
+
710
+
711
+ class MergedRescaleDecoder(nn.Module):
712
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
713
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
714
+ super().__init__()
715
+ tmp_chn = z_channels*ch_mult[-1]
716
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
717
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
718
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
719
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
720
+ out_channels=tmp_chn, depth=rescale_module_depth)
721
+
722
+ def forward(self, x):
723
+ x = self.rescaler(x)
724
+ x = self.decoder(x)
725
+ return x
726
+
727
+
728
+ class Upsampler(nn.Module):
729
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
730
+ super().__init__()
731
+ assert out_size >= in_size
732
+ num_blocks = int(np.log2(out_size//in_size))+1
733
+ factor_up = 1.+ (out_size % in_size)
734
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
735
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
736
+ out_channels=in_channels)
737
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
738
+ attn_resolutions=[], in_channels=None, ch=in_channels,
739
+ ch_mult=[ch_mult for _ in range(num_blocks)])
740
+
741
+ def forward(self, x):
742
+ x = self.rescaler(x)
743
+ x = self.decoder(x)
744
+ return x
745
+
746
+
747
+ class Resize(nn.Module):
748
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
749
+ super().__init__()
750
+ self.with_conv = learned
751
+ self.mode = mode
752
+ if self.with_conv:
753
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
754
+ raise NotImplementedError()
755
+ assert in_channels is not None
756
+ # no asymmetric padding in torch conv, must do it ourselves
757
+ self.conv = torch.nn.Conv2d(in_channels,
758
+ in_channels,
759
+ kernel_size=4,
760
+ stride=2,
761
+ padding=1)
762
+
763
+ def forward(self, x, scale_factor=1.0):
764
+ if scale_factor==1.0:
765
+ return x
766
+ else:
767
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
+ return x
769
+
770
+ class FirstStagePostProcessor(nn.Module):
771
+
772
+ def __init__(self, ch_mult:list, in_channels,
773
+ pretrained_model:nn.Module=None,
774
+ reshape=False,
775
+ n_channels=None,
776
+ dropout=0.,
777
+ pretrained_config=None):
778
+ super().__init__()
779
+ if pretrained_config is None:
780
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
+ self.pretrained_model = pretrained_model
782
+ else:
783
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
+ self.instantiate_pretrained(pretrained_config)
785
+
786
+ self.do_reshape = reshape
787
+
788
+ if n_channels is None:
789
+ n_channels = self.pretrained_model.encoder.ch
790
+
791
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
+ stride=1,padding=1)
794
+
795
+ blocks = []
796
+ downs = []
797
+ ch_in = n_channels
798
+ for m in ch_mult:
799
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
+ ch_in = m * n_channels
801
+ downs.append(Downsample(ch_in, with_conv=False))
802
+
803
+ self.model = nn.ModuleList(blocks)
804
+ self.downsampler = nn.ModuleList(downs)
805
+
806
+
807
+ def instantiate_pretrained(self, config):
808
+ model = instantiate_from_config(config)
809
+ self.pretrained_model = model.eval()
810
+ # self.pretrained_model.train = False
811
+ for param in self.pretrained_model.parameters():
812
+ param.requires_grad = False
813
+
814
+
815
+ @torch.no_grad()
816
+ def encode_with_pretrained(self,x):
817
+ c = self.pretrained_model.encode(x)
818
+ if isinstance(c, DiagonalGaussianDistribution):
819
+ c = c.mode()
820
+ return c
821
+
822
+ def forward(self,x):
823
+ z_fs = self.encode_with_pretrained(x)
824
+ z = self.proj_norm(z_fs)
825
+ z = self.proj(z)
826
+ z = nonlinearity(z)
827
+
828
+ for submodel, downmodel in zip(self.model,self.downsampler):
829
+ z = submodel(z,temb=None)
830
+ z = downmodel(z)
831
+
832
+ if self.do_reshape:
833
+ z = rearrange(z,'b c h w -> b (h w) c')
834
+ return z
835
+
lib/model_zoo/diffusion_utils.py ADDED
@@ -0,0 +1,250 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import repeat
7
+
8
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
9
+ if schedule == "linear":
10
+ betas = (
11
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
12
+ )
13
+
14
+ elif schedule == "cosine":
15
+ timesteps = (
16
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
17
+ )
18
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
19
+ alphas = torch.cos(alphas).pow(2)
20
+ alphas = alphas / alphas[0]
21
+ betas = 1 - alphas[1:] / alphas[:-1]
22
+ betas = np.clip(betas, a_min=0, a_max=0.999)
23
+
24
+ elif schedule == "sqrt_linear":
25
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
26
+ elif schedule == "sqrt":
27
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
28
+ else:
29
+ raise ValueError(f"schedule '{schedule}' unknown.")
30
+ return betas.numpy()
31
+
32
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
33
+ if ddim_discr_method == 'uniform':
34
+ c = num_ddpm_timesteps // num_ddim_timesteps
35
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
36
+ elif ddim_discr_method == 'quad':
37
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
38
+ else:
39
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
40
+
41
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
42
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
43
+ steps_out = ddim_timesteps + 1
44
+ if verbose:
45
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
46
+ return steps_out
47
+
48
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
49
+ # select alphas for computing the variance schedule
50
+ alphas = alphacums[ddim_timesteps]
51
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
52
+
53
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
54
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
55
+ if verbose:
56
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
57
+ print(f'For the chosen value of eta, which is {eta}, '
58
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
59
+ return sigmas, alphas, alphas_prev
60
+
61
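The three schedule helpers above are typically chained: build the betas, accumulate them into alpha-bar, pick the DDIM subsequence, then derive the sampling parameters. A small sketch, assuming the module path lib/model_zoo/diffusion_utils.py shown in this diff:

import numpy as np
from lib.model_zoo.diffusion_utils import (
    make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters)

betas = make_beta_schedule("linear", n_timestep=1000)           # (1000,) float64
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)
sigmas, a_t, a_prev = make_ddim_sampling_parameters(
    alphas_cumprod, ddim_steps, eta=0.0, verbose=False)         # eta=0 -> deterministic DDIM
print(ddim_steps.shape, sigmas.shape)                           # (50,) (50,)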
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
62
+ """
63
+ Create a beta schedule that discretizes the given alpha_t_bar function,
64
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
65
+ :param num_diffusion_timesteps: the number of betas to produce.
66
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
67
+ produces the cumulative product of (1-beta) up to that
68
+ part of the diffusion process.
69
+ :param max_beta: the maximum beta to use; use values lower than 1 to
70
+ prevent singularities.
71
+ """
72
+ betas = []
73
+ for i in range(num_diffusion_timesteps):
74
+ t1 = i / num_diffusion_timesteps
75
+ t2 = (i + 1) / num_diffusion_timesteps
76
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
77
+ return np.array(betas)
78
+
79
+ def extract_into_tensor(a, t, x_shape):
80
+ b, *_ = t.shape
81
+ out = a.gather(-1, t)
82
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
83
+
84
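extract_into_tensor gathers one scalar per batch element from a per-timestep table and reshapes it so it broadcasts against an image-shaped tensor; a small check:

import torch
from lib.model_zoo.diffusion_utils import extract_into_tensor

alphas = torch.linspace(0.1, 0.9, 1000)   # any 1-D per-timestep table
t = torch.tensor([0, 499, 999])           # one timestep index per batch element
x = torch.randn(3, 4, 32, 32)
coeff = extract_into_tensor(alphas, t, x.shape)
print(coeff.shape)                        # torch.Size([3, 1, 1, 1]); broadcasts over x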
+ def checkpoint(func, inputs, params, flag):
85
+ """
86
+ Evaluate a function without caching intermediate activations, allowing for
87
+ reduced memory at the expense of extra compute in the backward pass.
88
+ :param func: the function to evaluate.
89
+ :param inputs: the argument sequence to pass to `func`.
90
+ :param params: a sequence of parameters `func` depends on but does not
91
+ explicitly take as arguments.
92
+ :param flag: if False, disable gradient checkpointing.
93
+ """
94
+ if flag:
95
+ args = tuple(inputs) + tuple(params)
96
+ return CheckpointFunction.apply(func, len(inputs), *args)
97
+ else:
98
+ return func(*inputs)
99
+
100
+ class CheckpointFunction(torch.autograd.Function):
101
+ @staticmethod
102
+ def forward(ctx, run_function, length, *args):
103
+ ctx.run_function = run_function
104
+ ctx.input_tensors = list(args[:length])
105
+ ctx.input_params = list(args[length:])
106
+
107
+ with torch.no_grad():
108
+ output_tensors = ctx.run_function(*ctx.input_tensors)
109
+ return output_tensors
110
+
111
+ @staticmethod
112
+ def backward(ctx, *output_grads):
113
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
114
+ with torch.enable_grad():
115
+ # Fixes a bug where the first op in run_function modifies the
116
+ # Tensor storage in place, which is not allowed for detach()'d
117
+ # Tensors.
118
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
119
+ output_tensors = ctx.run_function(*shallow_copies)
120
+ input_grads = torch.autograd.grad(
121
+ output_tensors,
122
+ ctx.input_tensors + ctx.input_params,
123
+ output_grads,
124
+ allow_unused=True,
125
+ )
126
+ del ctx.input_tensors
127
+ del ctx.input_params
128
+ del output_tensors
129
+ return (None, None) + input_grads
130
+
131
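A usage sketch for the gradient-checkpointing helper above; the tiny Linear module is made up purely for illustration. With flag=True the function is re-run during the backward pass instead of storing intermediate activations:

import torch
import torch.nn as nn
from lib.model_zoo.diffusion_utils import checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# inputs are passed as a sequence, parameters the function depends on as a second sequence
y = checkpoint(layer, (x,), tuple(layer.parameters()), flag=True)
y.sum().backward()
print(x.grad.shape, layer.weight.grad.shape)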
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
132
+ """
133
+ Create sinusoidal timestep embeddings.
134
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
135
+ These may be fractional.
136
+ :param dim: the dimension of the output.
137
+ :param max_period: controls the minimum frequency of the embeddings.
138
+ :return: an [N x dim] Tensor of positional embeddings.
139
+ """
140
+ if not repeat_only:
141
+ half = dim // 2
142
+ freqs = torch.exp(
143
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
144
+ ).to(device=timesteps.device)
145
+ args = timesteps[:, None].float() * freqs[None]
146
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
147
+ if dim % 2:
148
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
149
+ else:
150
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
151
+ return embedding
152
+
153
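A quick sanity check of the timestep_embedding output shape; the timestep values and dimension are arbitrary:

import torch
from lib.model_zoo.diffusion_utils import timestep_embedding

t = torch.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=320)
print(emb.shape)   # torch.Size([3, 320]); cos/sin pairs over log-spaced frequencies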
+ def zero_module(module):
154
+ """
155
+ Zero out the parameters of a module and return it.
156
+ """
157
+ for p in module.parameters():
158
+ p.detach().zero_()
159
+ return module
160
+
161
+ def scale_module(module, scale):
162
+ """
163
+ Scale the parameters of a module and return it.
164
+ """
165
+ for p in module.parameters():
166
+ p.detach().mul_(scale)
167
+ return module
168
+
169
+ def mean_flat(tensor):
170
+ """
171
+ Take the mean over all non-batch dimensions.
172
+ """
173
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
174
+
175
+ def normalization(channels):
176
+ """
177
+ Make a standard normalization layer.
178
+ :param channels: number of input channels.
179
+ :return: an nn.Module for normalization.
180
+ """
181
+ return GroupNorm32(32, channels)
182
+
183
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
184
+ class SiLU(nn.Module):
185
+ def forward(self, x):
186
+ return x * torch.sigmoid(x)
187
+
188
+ class GroupNorm32(nn.GroupNorm):
189
+ def forward(self, x):
190
+ # return super().forward(x.float()).type(x.dtype)
191
+ return super().forward(x)
192
+
193
+ def conv_nd(dims, *args, **kwargs):
194
+ """
195
+ Create a 1D, 2D, or 3D convolution module.
196
+ """
197
+ if dims == 1:
198
+ return nn.Conv1d(*args, **kwargs)
199
+ elif dims == 2:
200
+ return nn.Conv2d(*args, **kwargs)
201
+ elif dims == 3:
202
+ return nn.Conv3d(*args, **kwargs)
203
+ raise ValueError(f"unsupported dimensions: {dims}")
204
+
205
+ def linear(*args, **kwargs):
206
+ """
207
+ Create a linear module.
208
+ """
209
+ return nn.Linear(*args, **kwargs)
210
+
211
+ def avg_pool_nd(dims, *args, **kwargs):
212
+ """
213
+ Create a 1D, 2D, or 3D average pooling module.
214
+ """
215
+ if dims == 1:
216
+ return nn.AvgPool1d(*args, **kwargs)
217
+ elif dims == 2:
218
+ return nn.AvgPool2d(*args, **kwargs)
219
+ elif dims == 3:
220
+ return nn.AvgPool3d(*args, **kwargs)
221
+ raise ValueError(f"unsupported dimensions: {dims}")
222
+
223
+ class HybridConditioner(nn.Module):
224
+
225
+ def __init__(self, c_concat_config, c_crossattn_config):
226
+ super().__init__()
227
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
228
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
229
+
230
+ def forward(self, c_concat, c_crossattn):
231
+ c_concat = self.concat_conditioner(c_concat)
232
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
233
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
234
+
235
+ def noise_like(x, repeat=False):
236
+ noise = torch.randn_like(x)
237
+ if repeat:
238
+ bs = x.shape[0]
239
+ noise = noise[0:1].repeat(bs, *((1,) * (len(x.shape) - 1)))
240
+ return noise
241
+
242
+ ##########################
243
+ # inherit from ldm.utils #
244
+ ##########################
245
+
246
+ def count_params(model, verbose=False):
247
+ total_params = sum(p.numel() for p in model.parameters())
248
+ if verbose:
249
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
250
+ return total_params
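zero_module and conv_nd are commonly combined to build an output projection whose contribution starts at exactly zero; a minimal sketch of that pattern (not code from this repo):

import torch
from lib.model_zoo.diffusion_utils import conv_nd, zero_module, count_params

proj = zero_module(conv_nd(2, 320, 4, 3, padding=1))   # 2-D conv with zeroed weights and bias
x = torch.randn(1, 320, 32, 32)
print(proj(x).abs().max().item())                      # 0.0 until training updates the weights
print(count_params(proj))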
lib/model_zoo/distributions.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
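DiagonalGaussianDistribution expects mean and log-variance concatenated along the channel dimension, as produced by an Encoder with double_z=True; a short sketch with random moments standing in for encoder output:

import torch
from lib.model_zoo.distributions import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 32, 32)      # 4 mean channels + 4 logvar channels
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                   # [2, 4, 32, 32]
print(z.shape, posterior.kl().shape)     # KL against a standard normal, one value per sample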
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
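normal_kl broadcasts its arguments, so a batch of Gaussians can be compared against scalar parameters; a small check, with the convention that the inputs are means and log-variances:

import torch
from lib.model_zoo.distributions import normal_kl

mean1 = torch.randn(4, 3)
logvar1 = torch.zeros(4, 3)
kl = normal_kl(mean1, logvar1, 0.0, 0.0)   # compare against a standard normal N(0, 1)
print(kl.shape)                            # torch.Size([4, 3]); elementwise, not reduced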