Diffusers
Safetensors
ImageReFL
TorchRik committed on
Commit e7aec62 · verified · 1 Parent(s): a977b83

Upload combined_stable_diffusion.py with huggingface_hub

Files changed (1)
  1. combined_stable_diffusion.py +397 -0
combined_stable_diffusion.py ADDED
@@ -0,0 +1,397 @@
+ import math
+ import random
+
+ import torch
+ from diffusers import DiffusionPipeline, DDPMScheduler
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+ from diffusers.image_processor import VaeImageProcessor
+ from huggingface_hub import PyTorchModelHubMixin
+ from PIL import Image
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
+
+
+ class CombinedStableDiffusion(
+     DiffusionPipeline,
+     PyTorchModelHubMixin,
+ ):
+     """
+     A Stable Diffusion model wrapper that provides functionality for text-to-image synthesis,
+     noise scheduling, latent space manipulation, and image decoding.
+     """
+
+     def __init__(
+         self,
+         original_unet: torch.nn.Module,
+         fine_tuned_unet: torch.nn.Module,
+         scheduler: DDPMScheduler,
+         vae: torch.nn.Module,
+         tokenizer: CLIPTokenizer,
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPImageProcessor,
+         text_encoder: CLIPTextModel,
+     ) -> None:
+         super().__init__()
+
+         self.register_modules(
+             tokenizer=tokenizer,
+             text_encoder=text_encoder,
+             original_unet=original_unet,
+             fine_tuned_unet=fine_tuned_unet,
+             scheduler=scheduler,
+             vae=vae,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
+         )
+
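+         # The VAE halves the spatial resolution once per encoder block; a standard
+         # SD VAE has four blocks, so the scale factor is 8 (512px -> 64x64 latents).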
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.image_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_scale_factor
+         )
+
+     def _get_negative_prompts(self, batch_size: int) -> torch.Tensor:
+         return self.tokenizer(
+             [""] * batch_size,
+             max_length=self.tokenizer.model_max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+         ).input_ids
+
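+     # For classifier-free guidance, embeddings for empty (unconditional) prompts are
+     # concatenated in front of the conditional ones, matching the chunk(2) split used
+     # later in get_noise_prediction.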
+     def _get_encoder_hidden_states(
+         self, tokenized_prompts: torch.Tensor, do_classifier_free_guidance: bool = False
+     ) -> torch.Tensor:
+         if do_classifier_free_guidance:
+             tokenized_prompts = torch.cat(
+                 [
+                     self._get_negative_prompts(tokenized_prompts.shape[0]).to(
+                         tokenized_prompts.device
+                     ),
+                     tokenized_prompts,
+                 ]
+             )
+
+         return self.text_encoder(tokenized_prompts)[0]
+
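+     # self._use_original_unet is a flag set in __call__ to switch between the
+     # original and the fine-tuned UNet during different phases of sampling.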
+     def _get_unet_prediction(
+         self,
+         latent_model_input: torch.Tensor,
+         timestep: int,
+         encoder_hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Return the UNet noise prediction.
+
+         Args:
+             latent_model_input (torch.Tensor): UNet latents input
+             timestep (int): noise scheduler timestep
+             encoder_hidden_states (torch.Tensor): text encoder hidden states
+
+         Returns:
+             torch.Tensor: noise prediction
+         """
+         unet = self.original_unet if self._use_original_unet else self.fine_tuned_unet
+
+         return unet(
+             latent_model_input,
+             timestep=timestep,
+             encoder_hidden_states=encoder_hidden_states,
+         ).sample
+
+     def get_noise_prediction(
+         self,
+         latents: torch.Tensor,
+         timestep_index: int,
+         encoder_hidden_states: torch.Tensor,
+         do_classifier_free_guidance: bool = False,
+         detach_main_path: bool = False,
+     ):
+         """
+         Return the noise prediction.
+
+         Args:
+             latents (torch.Tensor): image latents
+             timestep_index (int): noise scheduler timestep index
+             encoder_hidden_states (torch.Tensor): text encoder hidden states
+             do_classifier_free_guidance (bool): whether to do classifier-free guidance
+             detach_main_path (bool): whether to detach the gradient of the text-conditioned path
+
+         Returns:
+             torch.Tensor: noise prediction
+         """
+         timestep = self.scheduler.timesteps[timestep_index]
+
+         latent_model_input = self.scheduler.scale_model_input(
+             sample=torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
+             timestep=timestep,
+         )
+
+         noise_pred = self._get_unet_prediction(
+             latent_model_input=latent_model_input,
+             timestep=timestep,
+             encoder_hidden_states=encoder_hidden_states,
+         )
+
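+         # Classifier-free guidance combines the two predictions:
+         # eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)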
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             if detach_main_path:
+                 noise_pred_text = noise_pred_text.detach()
+
+             noise_pred = noise_pred_uncond + self.guidance_scale * (
+                 noise_pred_text - noise_pred_uncond
+             )
+         return noise_pred
+
+     def sample_next_latents(
+         self,
+         latents: torch.Tensor,
+         timestep_index: int,
+         noise_pred: torch.Tensor,
+         return_pred_original: bool = False,
+     ) -> torch.Tensor:
+         """
+         Return the next latents prediction.
+
+         Args:
+             latents (torch.Tensor): image latents
+             timestep_index (int): noise scheduler timestep index
+             noise_pred (torch.Tensor): noise prediction
+             return_pred_original (bool): whether to return the predicted original sample
+
+         Returns:
+             torch.Tensor: latent prediction
+         """
+         timestep = self.scheduler.timesteps[timestep_index]
+         sample = self.scheduler.step(
+             model_output=noise_pred, timestep=timestep, sample=latents
+         )
+         return (
+             sample.pred_original_sample if return_pred_original else sample.prev_sample
+         )
+
+     def predict_next_latents(
+         self,
+         latents: torch.Tensor,
+         timestep_index: int,
+         encoder_hidden_states: torch.Tensor,
+         return_pred_original: bool = False,
+         do_classifier_free_guidance: bool = False,
+         detach_main_path: bool = False,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Predicts the next latent states during the diffusion process.
+
+         Args:
+             latents (torch.Tensor): current latent states
+             timestep_index (int): index of the current timestep
+             encoder_hidden_states (torch.Tensor): encoder hidden states from the text encoder
+             return_pred_original (bool): whether to return the predicted original sample
+             do_classifier_free_guidance (bool): whether to do classifier-free guidance
+             detach_main_path (bool): whether to detach the gradient of the text-conditioned path
+
+         Returns:
+             tuple: next latents and predicted noise tensor
+         """
+         noise_pred = self.get_noise_prediction(
+             latents=latents,
+             timestep_index=timestep_index,
+             encoder_hidden_states=encoder_hidden_states,
+             do_classifier_free_guidance=do_classifier_free_guidance,
+             detach_main_path=detach_main_path,
+         )
+
+         latents = self.sample_next_latents(
+             latents=latents,
+             noise_pred=noise_pred,
+             timestep_index=timestep_index,
+             return_pred_original=return_pred_original,
+         )
+
+         return latents, noise_pred
+
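+     # Note: self.resolution is never assigned in this file; like the reward-model
+     # helpers below it is apparently expected to be set externally before use.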
+     def get_latents(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         latent_resolution = int(self.resolution) // self.vae_scale_factor
+         return torch.randn(
+             (
+                 batch_size,
+                 self.original_unet.config.in_channels,
+                 latent_resolution,
+                 latent_resolution,
+             ),
+             device=device,
+         )
+
+     def do_k_diffusion_steps(
+         self,
+         start_timestep_index: int,
+         end_timestep_index: int,
+         latents: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         return_pred_original: bool = False,
+         do_classifier_free_guidance: bool = False,
+         detach_main_path: bool = False,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Performs multiple diffusion steps between specified timesteps.
+
+         Args:
+             start_timestep_index (int): starting timestep index
+             end_timestep_index (int): ending timestep index
+             latents (torch.Tensor): initial latents
+             encoder_hidden_states (torch.Tensor): encoder hidden states
+             return_pred_original (bool): whether to return the predicted original sample
+             do_classifier_free_guidance (bool): whether to do classifier-free guidance
+             detach_main_path (bool): whether to detach the gradient of the text-conditioned path
+
+         Returns:
+             tuple: resulting latents and encoder hidden states
+         """
+         assert start_timestep_index <= end_timestep_index
+
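+         # All but the last step return prev_sample; the final step is taken separately
+         # so that return_pred_original can yield the predicted x_0 instead.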
+         for timestep_index in range(start_timestep_index, end_timestep_index - 1):
+             latents, _ = self.predict_next_latents(
+                 latents=latents,
+                 timestep_index=timestep_index,
+                 encoder_hidden_states=encoder_hidden_states,
+                 return_pred_original=False,
+                 do_classifier_free_guidance=do_classifier_free_guidance,
+                 detach_main_path=detach_main_path,
+             )
+         res, _ = self.predict_next_latents(
+             latents=latents,
+             timestep_index=end_timestep_index - 1,
+             encoder_hidden_states=encoder_hidden_states,
+             return_pred_original=return_pred_original,
+             do_classifier_free_guidance=do_classifier_free_guidance,
+         )
+         return res, encoder_hidden_states
+
+     def get_pil_image(self, raw_images: torch.Tensor) -> list[Image.Image]:
+         do_denormalize = [True] * raw_images.shape[0]
+         images = self.inference_image_processor.postprocess(
+             raw_images, output_type="pil", do_denormalize=do_denormalize
+         )
+         return images
+
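+     # Note: self.inference_image_processor, self.use_image_shifting,
+     # self.reward_image_processor, and _shift_tensor_batch are referenced but not
+     # defined in this file; they appear to belong to external (training-time) code.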
+     def get_reward_image(self, raw_images: torch.Tensor) -> torch.Tensor:
+         reward_images = (raw_images / 2 + 0.5).clamp(0, 1)
+
+         if self.use_image_shifting:
+             self._shift_tensor_batch(
+                 reward_images,
+                 dx=random.randint(0, math.ceil(self.resolution / 224)),
+                 dy=random.randint(0, math.ceil(self.resolution / 224)),
+             )
+
+         return self.reward_image_processor(reward_images)
+
+     def run_safety_checker(self, image, device, dtype):
+         if self.safety_checker is None:
+             has_nsfw_concept = None
+         else:
+             if torch.is_tensor(image):
+                 feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+             else:
+                 feature_extractor_input = self.image_processor.numpy_to_pil(image)
+             safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+             image, has_nsfw_concept = self.safety_checker(
+                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+             )
+         return image, has_nsfw_concept
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: str | list[str],
+         num_inference_steps=40,
+         original_unet_steps=30,
+         resolution=512,
+         guidance_scale=7.5,
+         output_type: str = "pil",
+         return_dict: bool = True,
+         generator=None,
+     ):
+         self.guidance_scale = guidance_scale
+         batch_size = 1 if isinstance(prompt, str) else len(prompt)
+
+         tokenized_prompts = self.tokenizer(
+             prompt,
+             return_tensors="pt",
+             padding="max_length",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+         ).input_ids.to(self.device)
+         original_encoder_hidden_states = self._get_encoder_hidden_states(
+             tokenized_prompts=tokenized_prompts,
+             do_classifier_free_guidance=True,
+         )
+         fine_tuned_encoder_hidden_states = self._get_encoder_hidden_states(
+             tokenized_prompts=tokenized_prompts,
+             do_classifier_free_guidance=False,
+         )
+
+         latent_resolution = int(resolution) // self.vae_scale_factor
+         latents = torch.randn(
+             (
+                 batch_size,
+                 self.original_unet.config.in_channels,
+                 latent_resolution,
+                 latent_resolution,
+             ),
+             device=self.device,
+         )
+
+         self.scheduler.set_timesteps(
+             num_inference_steps,
+             device=self.device,
+         )
+
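+         # Stage 1: denoise the first original_unet_steps steps with the original UNet,
+         # under classifier-free guidance.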
+         self._use_original_unet = True
+         latents, _ = self.do_k_diffusion_steps(
+             start_timestep_index=0,
+             end_timestep_index=original_unet_steps,
+             latents=latents,
+             encoder_hidden_states=original_encoder_hidden_states,
+             return_pred_original=False,
+             do_classifier_free_guidance=True,
+         )
+
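+         # Stage 2: finish the remaining steps with the fine-tuned UNet, without
+         # classifier-free guidance.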
+         self._use_original_unet = False
+         latents, _ = self.do_k_diffusion_steps(
+             start_timestep_index=original_unet_steps,
+             end_timestep_index=num_inference_steps,
+             latents=latents,
+             encoder_hidden_states=fine_tuned_encoder_hidden_states,
+             return_pred_original=False,
+             do_classifier_free_guidance=False,
+         )
+
+         if output_type != "latent":
+             image = self.vae.decode(
+                 latents / self.vae.config.scaling_factor, return_dict=False, generator=generator
+             )[0]
+             image, has_nsfw_concept = self.run_safety_checker(
+                 image, self.device, original_encoder_hidden_states.dtype
+             )
+         else:
+             image = latents
+             has_nsfw_concept = None
+
+         if has_nsfw_concept is None:
+             do_denormalize = [True] * image.shape[0]
+         else:
+             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+         image = self.image_processor.postprocess(
+             image,
+             output_type=output_type,
+             do_denormalize=do_denormalize,
+         )
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return image, has_nsfw_concept
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
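
For reference, a minimal usage sketch. The repository id and import path below are placeholders, and loading assumes the pipeline was saved in a layout that from_pretrained (inherited from DiffusionPipeline / PyTorchModelHubMixin) can resolve:

import torch

# Hypothetical import path; the class lives in the uploaded combined_stable_diffusion.py.
from combined_stable_diffusion import CombinedStableDiffusion

# Hypothetical repository id; substitute the Hub repo this file was uploaded to.
pipeline = CombinedStableDiffusion.from_pretrained("user/combined-stable-diffusion")
pipeline.to("cuda")

output = pipeline(
    "a photo of an astronaut riding a horse",
    num_inference_steps=40,
    original_unet_steps=30,  # steps run with the original UNet before switching
    guidance_scale=7.5,
)
output.images[0].save("sample.png")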