ironjr committed
Commit ce3f8a9 · verified · 1 Parent(s): 13124ab

Update model.py

Files changed (1)
  1. model.py +102 -79
model.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (c) 2024 Jaerin Lee
+ # Copyright (c) 2025 Jaerin Lee
 
  # Permission is hereby granted, free of charge, to any person obtaining a copy
  # of this software and associated documentation files (the "Software"), to deal
@@ -55,7 +55,7 @@ from typing import Tuple, List, Literal, Optional, Union
  from tqdm import tqdm
  from PIL import Image
 
- from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
+ from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
 
 
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
@@ -73,7 +73,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
  return noise_cfg
 
 
- class StableMultiDiffusionSDXLPipeline(nn.Module):
+ class SemanticDrawSDXLPipeline(nn.Module):
  def __init__(
  self,
  device: torch.device,
@@ -93,7 +93,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  has_i2t: bool = True,
  lora_weight: float = 1.0,
  ) -> None:
- r"""Stabilized MultiDiffusion for fast sampling.
+ r"""Stabilized regionally assigned texts-to-image generation for fast sampling.
 
  Accelrated region-based text-to-image synthesis with Latent Consistency
  Model while preserving mask fidelity and quality.
@@ -131,7 +131,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  default_preprocess_mask_cover_alpha (float): Optional preprocessing
  where each mask covered by other masks is reduced in its alpha
  value by this specified factor.
- t_index_list (List[int]): The default scheduling for LCM scheduler.
+ t_index_list (List[int]): The default scheduling for scheduler.
  mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
  defines the mask quantization modes. Details in the codes of
  `self.process_mask`. Basically, this (subtly) controls the
@@ -170,10 +170,10 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  model_key = hf_key
  lora_ckpt = 'sdxl_lightning_4step_lora.safetensors'
 
- self.pipe = StableDiffusionXLPipeline.from_pretrained(model_key, variant=variant, torch_dtype=self.dtype).to(self.device)
+ self.pipe = load_model(model_key, 'xl', self.device, self.dtype)
  self.pipe.load_lora_weights(hf_hub_download(lightning_repo, lora_ckpt), adapter_name='lightning')
  self.pipe.set_adapters(["lightning"], adapter_weights=[lora_weight])
- # self.pipe.fuse_lora()
+ self.pipe.fuse_lora()
  else:
  model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
  variant = 'fp16'
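Note: this hunk routes model loading through the `util.load_model` helper and enables `fuse_lora()`, previously commented out. A minimal, hypothetical sketch of the same load-attach-fuse sequence at the plain diffusers level (standalone values; not the `util.load_model` internals):

```python
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

# Hypothetical standalone sketch of the load-attach-fuse sequence above.
pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    variant='fp16', torch_dtype=torch.float16,
).to('cuda')
ckpt = hf_hub_download('ByteDance/SDXL-Lightning', 'sdxl_lightning_4step_lora.safetensors')
pipe.load_lora_weights(ckpt, adapter_name='lightning')
pipe.set_adapters(['lightning'], adapter_weights=[1.0])
pipe.fuse_lora()  # bake the adapter into the base weights, as now done above
```

Fusing trades the ability to hot-swap adapters for faster inference, which fits the few-step Lightning setting.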
@@ -212,7 +212,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  self.vae_scale_factor = self.pipe.vae_scale_factor
 
  # Prepare white background for bootstrapping.
- # self.get_white_background(1024, 1024)
+ self.get_white_background(1024, 1024)
 
  print(f'[INFO] Model is loaded!')
@@ -691,7 +691,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  25, 37], the masks are split into binary masks whose values are
  greater than these levels. This results in tradual increase of mask
  region as the timesteps increase. Details are described in our
- paper at https://arxiv.org/pdf/2403.09055.pdf.
+ paper.
 
  On the Three Modes of `mask_type`:
  `self.mask_type` is predefined at the initialization stage of this
@@ -949,6 +949,9 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  boostrap_mix_steps: Optional[float] = None,
  bootstrap_leak_sensitivity: Optional[float] = None,
  preprocess_mask_cover_alpha: Optional[float] = None,
+ # SDXL Pipeline setting.
+ guidance_rescale: float = 0.7,
+ output_type = 'pil',
  ) -> Image.Image:
  r"""Arbitrary-size image generation from multiple pairs of (regional)
  text prompt-mask pairs.
@@ -957,7 +960,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
 
  Example:
  >>> device = torch.device('cuda:0')
- >>> smd = StableMultiDiffusionPipeline(device)
+ >>> smd = SemanticDrawPipeline(device)
  >>> prompts = {... specify prompts}
  >>> masks = {... specify mask tensors}
  >>> height, width = masks.shape[-2:]
@@ -1046,7 +1049,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
 
  # prompts is None: return background.
  # masks is None but prompts is not None: return prompts
- # masks is not None and prompts is not None: Do StableMultiDiffusion.
+ # masks is not None and prompts is not None: Do SemanticDraw.
 
  if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
  if background is None and background_prompt is not None:
@@ -1157,27 +1160,22 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
 
  # SDXL pipeline settings.
  batch_size = 1
- output_type = 'pil'
-
- guidance_rescale = 0.7
-
- prompt_2 = None
- device = self.device
  num_images_per_prompt = 1
- negative_prompt_2 = None
 
  original_size = (height, width)
  target_size = (height, width)
  crops_coords_top_left = (0, 0)
- negative_crops_coords_top_left = (0, 0)
  negative_original_size = None
  negative_target_size = None
- pooled_prompt_embeds = None
- negative_pooled_prompt_embeds = None
- text_encoder_lora_scale = None
+ negative_crops_coords_top_left = (0, 0)
+
+ prompt_2 = None
+ negative_prompt_2 = None
  prompt_embeds = None
  negative_prompt_embeds = None
+ pooled_prompt_embeds = None
+ negative_pooled_prompt_embeds = None
+ text_encoder_lora_scale = None
 
  (
  prompt_embeds,
@@ -1187,7 +1185,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  ) = self.encode_prompt(
  prompt=prompts,
  prompt_2=prompt_2,
- device=device,
+ device=self.device,
  num_images_per_prompt=num_images_per_prompt,
  do_classifier_free_guidance=do_classifier_free_guidance,
  negative_prompt=negative_prompts,
@@ -1199,30 +1197,6 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  lora_scale=text_encoder_lora_scale,
  )
 
- add_text_embeds = pooled_prompt_embeds
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
- add_time_ids = self._get_add_time_ids(
- original_size,
- crops_coords_top_left,
- target_size,
- dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
- )
- if negative_original_size is not None and negative_target_size is not None:
- negative_add_time_ids = self._get_add_time_ids(
- negative_original_size,
- negative_crops_coords_top_left,
- negative_target_size,
- dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
- )
- else:
- negative_add_time_ids = add_time_ids
-
  if has_background:
  # First channel is background prompt text embeds. Background prompt itself is not used for generation.
  s = prompt_strengths
@@ -1248,10 +1222,26 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  assert fu.shape[0] == 1 and fe.shape == num_prompts
  fu = fu.repeat(num_prompts, 1, 1)
  negative_prompt_embeds = torch.lerp(bu, fu, s) # (n, 77, 1024)
+
+ be = pooled_prompt_embeds[:1]
+ fe = pooled_prompt_embeds[1:]
+ pooled_prompt_embeds = torch.lerp(be, fe, s[..., 0]) # (p, 1280)
+
+ if negative_pooled_prompt_embeds is not None:
+ bu = negative_pooled_prompt_embeds[:1]
+ fu = negative_pooled_prompt_embeds[1:]
+ if num_prompts > num_nprompts:
+ # # negative prompts = 1; # prompts > 1.
+ assert fu.shape[0] == 1 and fe.shape == num_prompts
+ fu = fu.repeat(num_prompts, 1)
+ negative_pooled_prompt_embeds = torch.lerp(bu, fu, s[..., 0]) # (n, 1280)
  elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
  # # negative prompts = 1; # prompts > 1.
  assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
  negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
+
+ assert negative_pooled_prompt_embeds.shape[0] == 1 and pooled_prompt_embeds.shape[0] == num_prompts
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_prompts, 1)
  # assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
  if num_masks > num_prompts:
  assert masks.shape[0] == num_masks and num_prompts == 1
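Note: the added block extends the per-region prompt interpolation to SDXL's pooled text embeddings. A minimal sketch of the broadcast semantics, with hypothetical stand-in tensors (shapes follow the inline comments: token embeds `(n, 77, 1024)`, pooled embeds `(n, 1280)`):

```python
import torch

n = 3                                    # number of foreground prompts
s = torch.rand(n, 1, 1)                  # per-prompt strengths, broadcast over token axis
be = torch.randn(1, 77, 1024)            # background token embeds
fe = torch.randn(n, 77, 1024)            # foreground token embeds
prompt_embeds = torch.lerp(be, fe, s)    # (n, 77, 1024), the pre-existing code path

bp = torch.randn(1, 1280)                # pooled embeds have no token axis
fp = torch.randn(n, 1280)
pooled = torch.lerp(bp, fp, s[..., 0])   # (n, 1280); s[..., 0] drops the token dim, as above
```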
@@ -1259,6 +1249,34 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  if negative_prompt_embeds is not None:
  negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)
 
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_masks, 1)
+ if negative_pooled_prompt_embeds is not None:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_masks, 1)
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
  # SDXL pipeline settings.
  if do_classifier_free_guidance:
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -1266,19 +1284,25 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
  del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids
 
- prompt_embeds = prompt_embeds.to(device)
- add_text_embeds = add_text_embeds.to(device)
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.to(self.device)
+ add_text_embeds = add_text_embeds.to(self.device)
+ add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
 
 
  ### Run
 
  # Latent initialization.
+ noise = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
  if self.timesteps[0] < 999 and has_background:
- latents = self.scheduler_add_noise(bg_latents, None, 0, initial=True)
+ latent = self.scheduler_add_noise(bg_latent, noise, 0, initial=True)
  else:
- latents = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
- latents = latents * self.scheduler.init_noise_sigma
+ noise = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
+ latent = noise * self.scheduler.init_noise_sigma
+
+ if has_background:
+ noise_bg_latents = [
+ self.scheduler_add_noise(bg_latent, noise, i, initial=True) for i in range(len(self.timesteps))
+ ] + [bg_latent]
 
  # Tiling (if needed).
  if height > tile_size or width > tile_size:
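Note: latent initialization now draws a single `noise` tensor and, when a background is given, precomputes its noised versions for every timestep (`noise_bg_latents`) instead of passing `None` and letting each `scheduler_add_noise` call draw fresh noise. A sketch of the idea, assuming `scheduler_add_noise(x, noise, i, initial=True)` forward-noises `x` to timestep index `i` (the signature is inferred from the diff, not shown in full):

```python
# Sketch only; scheduler_add_noise is this file's own helper.
noise = torch.randn_like(bg_latent)                   # one shared noise tensor
noise_bg_latents = [
    self.scheduler_add_noise(bg_latent, noise, i, initial=True)
    for i in range(len(self.timesteps))
] + [bg_latent]                                        # final entry: the clean background
# Step i later mixes with noise_bg_latents[i + 1], so the background follows one
# consistent noise trajectory across denoising steps.
```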
@@ -1287,9 +1311,9 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  tile_masks = tile_masks.to(self.device)
  else:
  views = [(0, h, 0, w)]
- tile_masks = latents.new_ones((1, 1, h, w))
- value = torch.zeros_like(latents)
- count_all = torch.zeros_like(latents)
+ tile_masks = latent.new_ones((1, 1, h, w))
+ value = torch.zeros_like(latent)
+ count_all = torch.zeros_like(latent)
 
  with torch.autocast('cuda'):
  for i, t in enumerate(tqdm(self.timesteps)):
@@ -1300,7 +1324,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  count_all.zero_()
  for j, (h_start, h_end, w_start, w_end) in enumerate(views):
  fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
- latents_ = latents[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
+ latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
 
  # Additional arguments for the SDXL pipeline.
  add_time_ids_input = add_time_ids.clone()
@@ -1312,16 +1336,16 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  if i < bootstrap_steps:
  mix_ratio = min(1, max(0, boostrap_mix_steps - i))
  # Treat the first foreground latent as the background latent if one does not exist.
- bg_latents_ = bg_latents[..., h_start:h_end, w_start:w_end] if has_background else latents_[:1]
+ bg_latent_ = noise_bg_latents[i][..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
  white_ = white[..., h_start:h_end, w_start:w_end]
- white_ = self.scheduler_add_noise(white_, None, i, initial=True)
- bg_latents_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latents_
- latents_ = (1.0 - fg_mask_) * bg_latents_ + fg_mask_ * latents_
+ white_ = self.scheduler_add_noise(white_, noise[..., h_start:h_end, w_start:w_end], i, initial=True)
+ bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
+ latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_
 
  # Centering.
- latents_ = shift_to_mask_bbox_center(latents_, fg_mask_, reverse=True)
+ latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)
 
- latent_model_input = torch.cat([latents_] * 2) if do_classifier_free_guidance else latents_
+ latent_model_input = torch.cat([latent_] * 2) if do_classifier_free_guidance else latent_
  latent_model_input = self.scheduler_scale_model_input(latent_model_input, i)
 
  # Perform one step of the reverse diffusion.
@@ -1341,33 +1365,32 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
 
  if do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
 
- latents_ = self.scheduler_step(noise_pred, i, latents_)
+ latent_ = self.scheduler_step(noise_pred, i, latent_)
 
  if i < bootstrap_steps:
  # Uncentering.
- latents_ = shift_to_mask_bbox_center(latents_, fg_mask_)
+ latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)
 
  # Remove leakage (optional).
- leak = (latents_ - bg_latents_).pow(2).mean(dim=1, keepdim=True)
+ leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
  leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
  fg_mask_ = fg_mask_ * leak_sigmoid
 
  # Mix the latents.
  fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
- value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latents_).sum(dim=0, keepdim=True)
+ value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
  count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
 
- latents = torch.where(count_all > 0, value / count_all, value)
+ latent = torch.where(count_all > 0, value / count_all, value)
  bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
  if has_background:
- latents = (1 - bg_mask) * latents + bg_mask * bg_latents
+ latent = (1 - bg_mask) * latent + bg_mask * noise_bg_latents[i + 1] # bg_latent
 
  # Noise is added after mixing.
  if i < len(self.timesteps) - 1:
- latents = self.scheduler_add_noise(latents, None, i + 1)
+ latent = self.scheduler_add_noise(latent, None, i + 1)
 
  if not output_type == "latent":
  # make sure the VAE is in float32 mode, as it overflows in float16
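Note: two pieces of this hunk are easy to miss. First, the leak-removal step rescales each foreground mask by a sigmoid of the per-pixel squared distance between the regional latent and the background latent, mapped to [0, 1), which zeroes mask weight where the region's content is indistinguishable from the background. Second, the per-region latents are merged MultiDiffusion-style by a mask-weighted average. A compact standalone sketch of both, with hypothetical shapes `(n, C, h, w)`:

```python
import torch

n, C, h, w = 2, 4, 8, 8
latent_ = torch.randn(n, C, h, w)        # per-region denoised latents
bg_latent_ = torch.randn(1, C, h, w)     # background latent
fg_mask_ = torch.rand(n, 1, h, w)        # soft region masks
sensitivity = 0.2                        # stand-in for bootstrap_leak_sensitivity

# Leak suppression: per-pixel MSE over channels, squashed to [0, 1).
leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
fg_mask_ = fg_mask_ * (torch.sigmoid(leak / sensitivity) * 2 - 1)

# Mask-weighted aggregation; torch.where avoids dividing by zero outside all masks.
value = (fg_mask_ * latent_).sum(dim=0, keepdim=True)
count = fg_mask_.sum(dim=0, keepdim=True)
latent = torch.where(count > 0, value / count, value)
```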
@@ -1375,7 +1398,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
 
  if needs_upcasting:
  self.upcast_vae()
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ latent = latent.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
  # unscale/denormalize the latents
  # denormalize with the mean and std if available and not None
@@ -1383,22 +1406,22 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
  if has_latents_mean and has_latents_std:
  latents_mean = (
- torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latent.device, latent.dtype)
  )
  latents_std = (
- torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latent.device, latent.dtype)
  )
- latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ latent = latent * latents_std / self.vae.config.scaling_factor + latents_mean
  else:
- latents = latents / self.vae.config.scaling_factor
+ latent = latent / self.vae.config.scaling_factor
 
- image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.vae.decode(latent, return_dict=False)[0]
 
  # cast back to fp16 if needed
  if needs_upcasting:
  self.vae.to(dtype=torch.float16)
  else:
- image = latents
+ image = latent
 
  # Return PIL Image.
  image = image[0].clip_(-1, 1) * 0.5 + 0.5
@@ -1407,4 +1430,4 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
  image = blend(image, background[0], fg_mask)
  else:
  image = T.ToPILImage()(image)
- return image
+ return image
 
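For context, the renamed class is constructed as in the updated docstring example. A minimal, hypothetical instantiation (the docstring itself shows the non-XL `SemanticDrawPipeline`; this file defines the SDXL variant, whose constructor's first positional argument is the device):

```python
import torch
from model import SemanticDrawSDXLPipeline  # the file changed in this commit

device = torch.device('cuda:0')
smd = SemanticDrawSDXLPipeline(device)  # SDXL-Lightning LoRA is loaded and fused in __init__
# Regional prompts and mask tensors are then supplied per region,
# as sketched (and partly elided) in the docstring example above.
```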