Spaces: Running on Zero
Update model.py
model.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c)
+# Copyright (c) 2025 Jaerin Lee
 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -55,7 +55,7 @@ from typing import Tuple, List, Literal, Optional, Union
 from tqdm import tqdm
 from PIL import Image
 
-from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
+from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
@@ -73,7 +73,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg
 
 
-class StableMultiDiffusionSDXLPipeline(nn.Module):
+class SemanticDrawSDXLPipeline(nn.Module):
     def __init__(
         self,
         device: torch.device,
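For context, the body of `rescale_noise_cfg` named in the hunk header above is elided by the diff. A sketch matching the upstream diffusers helper it was copied from, which implements the guidance-rescale trick of §3.4 in arXiv:2305.08891:

```python
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Rescale the CFG output toward the std of the text-conditioned
    # prediction to counteract overexposure at high guidance scales.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Blend the rescaled and original CFG outputs by `guidance_rescale`.
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```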
@@ -93,7 +93,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
         has_i2t: bool = True,
         lora_weight: float = 1.0,
     ) -> None:
-        r"""Stabilized
+        r"""Stabilized regionally assigned texts-to-image generation for fast sampling.
 
         Accelerated region-based text-to-image synthesis with Latent Consistency
         Model while preserving mask fidelity and quality.
@@ -131,7 +131,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
         default_preprocess_mask_cover_alpha (float): Optional preprocessing
             where each mask covered by other masks is reduced in its alpha
             value by this specified factor.
-        t_index_list (List[int]): The default scheduling for
+        t_index_list (List[int]): The default timestep schedule for the scheduler.
         mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
             defines the mask quantization modes. Details in the codes of
             `self.process_mask`. Basically, this (subtly) controls the
@@ -170,10 +170,10 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
             model_key = hf_key
             lora_ckpt = 'sdxl_lightning_4step_lora.safetensors'
 
-            self.pipe =
+            self.pipe = load_model(model_key, 'xl', self.device, self.dtype)
             self.pipe.load_lora_weights(hf_hub_download(lightning_repo, lora_ckpt), adapter_name='lightning')
             self.pipe.set_adapters(["lightning"], adapter_weights=[lora_weight])
-
+            self.pipe.fuse_lora()
         else:
             model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
             variant = 'fp16'
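`load_model` is pulled in by the updated `util` import above; its body is not shown in this diff. A minimal sketch of the equivalent stock diffusers setup, assuming `lightning_repo` is `'ByteDance/SDXL-Lightning'` (the repository that hosts `sdxl_lightning_4step_lora.safetensors`):

```python
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

# Load the base SDXL pipeline in fp16.
pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    variant='fp16',
).to('cuda')

# Attach the SDXL-Lightning 4-step LoRA as a named adapter.
pipe.load_lora_weights(
    hf_hub_download('ByteDance/SDXL-Lightning', 'sdxl_lightning_4step_lora.safetensors'),
    adapter_name='lightning',
)
pipe.set_adapters(['lightning'], adapter_weights=[1.0])
pipe.fuse_lora()  # the new line in this commit: bake the adapter into the weights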
@@ -212,7 +212,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
         self.vae_scale_factor = self.pipe.vae_scale_factor
 
         # Prepare white background for bootstrapping.
-
+        self.get_white_background(1024, 1024)
 
         print(f'[INFO] Model is loaded!')
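`get_white_background` is defined elsewhere in the file; presumably it caches the VAE latent of a plain white canvas for the bootstrapping step used later in `__call__`. A hypothetical sketch (the function body and names here are assumed, not from the file):

```python
import torch

def get_white_background(pipe, height: int = 1024, width: int = 1024) -> torch.Tensor:
    # Hypothetical: encode a pure-white image (1.0 in [-1, 1] space) once
    # and cache the scaled latent for reuse during bootstrapping.
    white = torch.ones(1, 3, height, width, dtype=pipe.dtype, device=pipe.device)
    latent = pipe.vae.encode(white).latent_dist.mean
    return latent * pipe.vae.config.scaling_factor
```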
@@ -691,7 +691,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
             25, 37], the masks are split into binary masks whose values are
             greater than these levels. This results in gradual increase of mask
             region as the timesteps increase. Details are described in our
-            paper
+            paper.
 
         On the Three Modes of `mask_type`:
             `self.mask_type` is predefined at the initialization stage of this
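The threshold scheme this docstring describes can be pictured with a small sketch (the function name and threshold values are illustrative, not from the file):

```python
from typing import List

import torch

def quantize_mask(soft_mask: torch.Tensor, levels: List[float]) -> torch.Tensor:
    # soft_mask: (1, H, W) in [0, 1]; returns (T, 1, H, W), one binary mask
    # per threshold level, so lower thresholds at later timesteps gradually
    # enlarge the active region.
    return torch.stack([(soft_mask > level).float() for level in levels])
```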
@@ -949,6 +949,9 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
         boostrap_mix_steps: Optional[float] = None,
         bootstrap_leak_sensitivity: Optional[float] = None,
         preprocess_mask_cover_alpha: Optional[float] = None,
+        # SDXL Pipeline setting.
+        guidance_rescale: float = 0.7,
+        output_type = 'pil',
     ) -> Image.Image:
         r"""Arbitrary-size image generation from multiple pairs of (regional)
         text prompt-mask pairs.
@@ -957,7 +960,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
 
         Example:
             >>> device = torch.device('cuda:0')
-            >>> smd =
+            >>> smd = SemanticDrawSDXLPipeline(device)
             >>> prompts = {... specify prompts}
             >>> masks = {... specify mask tensors}
             >>> height, width = masks.shape[-2:]
@@ -1046,7 +1049,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
 
         # prompts is None: return background.
         # masks is None but prompts is not None: return prompts
-        # masks is not None and prompts is not None: Do
+        # masks is not None and prompts is not None: Do SemanticDraw.
 
         if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
             if background is None and background_prompt is not None:
@@ -1157,27 +1160,22 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
 
         # SDXL pipeline settings.
         batch_size = 1
-        output_type = 'pil'
-
-        guidance_rescale = 0.7
-
-        prompt_2 = None
-        device = self.device
         num_images_per_prompt = 1
-        negative_prompt_2 = None
 
         original_size = (height, width)
         target_size = (height, width)
         crops_coords_top_left = (0, 0)
-        negative_crops_coords_top_left = (0, 0)
         negative_original_size = None
         negative_target_size = None
-
-        negative_pooled_prompt_embeds = None
-        text_encoder_lora_scale = None
-
+        negative_crops_coords_top_left = (0, 0)
+
+        prompt_2 = None
+        negative_prompt_2 = None
         prompt_embeds = None
         negative_prompt_embeds = None
+        pooled_prompt_embeds = None
+        negative_pooled_prompt_embeds = None
+        text_encoder_lora_scale = None
 
         (
             prompt_embeds,
@@ -1187,7 +1185,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
         ) = self.encode_prompt(
             prompt=prompts,
             prompt_2=prompt_2,
-            device=device,
+            device=self.device,
             num_images_per_prompt=num_images_per_prompt,
             do_classifier_free_guidance=do_classifier_free_guidance,
             negative_prompt=negative_prompts,
@@ -1199,30 +1197,6 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
             lora_scale=text_encoder_lora_scale,
         )
 
-        add_text_embeds = pooled_prompt_embeds
-        if self.text_encoder_2 is None:
-            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
-        else:
-            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
-        add_time_ids = self._get_add_time_ids(
-            original_size,
-            crops_coords_top_left,
-            target_size,
-            dtype=prompt_embeds.dtype,
-            text_encoder_projection_dim=text_encoder_projection_dim,
-        )
-        if negative_original_size is not None and negative_target_size is not None:
-            negative_add_time_ids = self._get_add_time_ids(
-                negative_original_size,
-                negative_crops_coords_top_left,
-                negative_target_size,
-                dtype=prompt_embeds.dtype,
-                text_encoder_projection_dim=text_encoder_projection_dim,
-            )
-        else:
-            negative_add_time_ids = add_time_ids
-
         if has_background:
             # First channel is background prompt text embeds. Background prompt itself is not used for generation.
             s = prompt_strengths
@@ -1248,10 +1222,26 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
                     assert fu.shape[0] == 1 and fe.shape == num_prompts
                     fu = fu.repeat(num_prompts, 1, 1)
                 negative_prompt_embeds = torch.lerp(bu, fu, s)  # (n, 77, 1024)
+
+            be = pooled_prompt_embeds[:1]
+            fe = pooled_prompt_embeds[1:]
+            pooled_prompt_embeds = torch.lerp(be, fe, s[..., 0])  # (p, 1280)
+
+            if negative_pooled_prompt_embeds is not None:
+                bu = negative_pooled_prompt_embeds[:1]
+                fu = negative_pooled_prompt_embeds[1:]
+                if num_prompts > num_nprompts:
+                    # # negative prompts = 1; # prompts > 1.
+                    assert fu.shape[0] == 1 and fe.shape == num_prompts
+                    fu = fu.repeat(num_prompts, 1)
+                negative_pooled_prompt_embeds = torch.lerp(bu, fu, s[..., 0])  # (n, 1280)
         elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
             # # negative prompts = 1; # prompts > 1.
             assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
             negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
+
+            assert negative_pooled_prompt_embeds.shape[0] == 1 and pooled_prompt_embeds.shape[0] == num_prompts
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_prompts, 1)
         # assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
         if num_masks > num_prompts:
             assert masks.shape[0] == num_masks and num_prompts == 1
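The pooled-embedding interpolation added in this hunk mirrors the per-token lerp right above it: the background embedding is blended toward each foreground embedding by a per-prompt strength. A shape sketch with illustrative tensors:

```python
import torch

be = torch.randn(1, 1280)      # background pooled embedding
fe = torch.randn(3, 1280)      # three foreground pooled embeddings
s = torch.rand(3, 1)           # per-prompt strengths in [0, 1]
mixed = torch.lerp(be, fe, s)  # (3, 1280): be * (1 - s) + fe * s, broadcast over prompts
```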
@@ -1259,6 +1249,34 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
             if negative_prompt_embeds is not None:
                 negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)
 
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_masks, 1)
+            if negative_pooled_prompt_embeds is not None:
+                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_masks, 1)
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        if negative_original_size is not None and negative_target_size is not None:
+            negative_add_time_ids = self._get_add_time_ids(
+                negative_original_size,
+                negative_crops_coords_top_left,
+                negative_target_size,
+                dtype=prompt_embeds.dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
+            )
+        else:
+            negative_add_time_ids = add_time_ids
+
         # SDXL pipeline settings.
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
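The `_get_add_time_ids` call this hunk relocates builds SDXL's micro-conditioning vector. In diffusers it amounts to concatenating six integers that the UNet embeds alongside the pooled text embeddings (values below are illustrative defaults):

```python
import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
# tensor([[1024, 1024, 0, 0, 1024, 1024]])
```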
@@ -1266,19 +1284,25 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
         del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids
 
-        prompt_embeds = prompt_embeds.to(device)
-        add_text_embeds = add_text_embeds.to(device)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.to(self.device)
+        add_text_embeds = add_text_embeds.to(self.device)
+        add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
 
 
         ### Run
 
         # Latent initialization.
+        noise = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
         if self.timesteps[0] < 999 and has_background:
-
+            latent = self.scheduler_add_noise(bg_latent, noise, 0, initial=True)
         else:
-
-
+            noise = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
+            latent = noise * self.scheduler.init_noise_sigma
+
+        if has_background:
+            noise_bg_latents = [
+                self.scheduler_add_noise(bg_latent, noise, i, initial=True) for i in range(len(self.timesteps))
+            ] + [bg_latent]
 
         # Tiling (if needed).
         if height > tile_size or width > tile_size:
@@ -1287,9 +1311,9 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
             tile_masks = tile_masks.to(self.device)
         else:
             views = [(0, h, 0, w)]
-            tile_masks =
-        value = torch.zeros_like(
-        count_all = torch.zeros_like(
+            tile_masks = latent.new_ones((1, 1, h, w))
+        value = torch.zeros_like(latent)
+        count_all = torch.zeros_like(latent)
 
         with torch.autocast('cuda'):
             for i, t in enumerate(tqdm(self.timesteps)):
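`get_panorama_views` (imported from `util`) supplies the sliding windows used when the canvas exceeds `tile_size`; its body is not shown in this diff. A hypothetical stand-in showing the shape of a view:

```python
from typing import List, Tuple

def get_views(h: int, w: int, tile: int = 128, stride: int = 64) -> List[Tuple[int, int, int, int]]:
    # Hypothetical: enumerate overlapping (h_start, h_end, w_start, w_end)
    # windows over the latent grid, always including the far edge.
    th, tw = min(tile, h), min(tile, w)
    hs_list = list(range(0, h - th, stride)) + [h - th]
    ws_list = list(range(0, w - tw, stride)) + [w - tw]
    return [(hs, hs + th, ws, ws + tw) for hs in hs_list for ws in ws_list]
```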
@@ -1300,7 +1324,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
                 count_all.zero_()
                 for j, (h_start, h_end, w_start, w_end) in enumerate(views):
                     fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
-
+                    latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
 
                     # Additional arguments for the SDXL pipeline.
                     add_time_ids_input = add_time_ids.clone()
@@ -1312,16 +1336,16 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
                     if i < bootstrap_steps:
                        mix_ratio = min(1, max(0, boostrap_mix_steps - i))
                         # Treat the first foreground latent as the background latent if one does not exist.
-
+                        bg_latent_ = noise_bg_latents[i][..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
                         white_ = white[..., h_start:h_end, w_start:w_end]
-                        white_ = self.scheduler_add_noise(white_,
-
-
+                        white_ = self.scheduler_add_noise(white_, noise[..., h_start:h_end, w_start:w_end], i, initial=True)
+                        bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
+                        latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_
 
                         # Centering.
-
+                        latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)
 
-                    latent_model_input = torch.cat([
+                    latent_model_input = torch.cat([latent_] * 2) if do_classifier_free_guidance else latent_
                     latent_model_input = self.scheduler_scale_model_input(latent_model_input, i)
 
                     # Perform one step of the reverse diffusion.
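The bootstrap schedule in this hunk decays linearly with the step index: with `boostrap_mix_steps = 2.5`, the white-canvas mix ratio over steps i = 0, 1, 2, 3 comes out to 1.0, 1.0, 0.5, 0.0, so early steps draw each region against a clean white background and later steps hand over to the real latents:

```python
for i in range(4):
    print(i, min(1.0, max(0.0, 2.5 - i)))  # 1.0, 1.0, 0.5, 0.0
```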
@@ -1341,33 +1365,32 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
 
                     if do_classifier_free_guidance and guidance_rescale > 0.0:
-                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                         noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
 
-
+                    latent_ = self.scheduler_step(noise_pred, i, latent_)
 
                     if i < bootstrap_steps:
                         # Uncentering.
-
+                        latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)
 
                         # Remove leakage (optional).
-                        leak = (
+                        leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
                         leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
                         fg_mask_ = fg_mask_ * leak_sigmoid
 
                     # Mix the latents.
                     fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
-                    value[..., h_start:h_end, w_start:w_end] += (fg_mask_ *
+                    value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
                     count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
 
-
+                latent = torch.where(count_all > 0, value / count_all, value)
                 bg_mask = (1 - count_all).clip_(0, 1)  # (T, 1, h, w)
                 if has_background:
-
+                    latent = (1 - bg_mask) * latent + bg_mask * noise_bg_latents[i + 1]  # bg_latent
 
                 # Noise is added after mixing.
                 if i < len(self.timesteps) - 1:
-
+                    latent = self.scheduler_add_noise(latent, None, i + 1)
 
             if not output_type == "latent":
                 # make sure the VAE is in float32 mode, as it overflows in float16
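The `value`/`count_all` accumulation completed in this hunk is the standard MultiDiffusion aggregation: every region's denoised latent is averaged per pixel, weighted by its mask coverage. A compact sketch of the same rule:

```python
import torch

def aggregate(latents: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # latents: (N, C, h, w) region latents; masks: (N, 1, h, w) soft masks.
    value = (masks * latents).sum(dim=0, keepdim=True)
    count = masks.sum(dim=0, keepdim=True)
    # Where no mask covers a pixel, fall back to the raw accumulator (zero).
    return torch.where(count > 0, value / count, value)
```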
@@ -1375,7 +1398,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
 
                 if needs_upcasting:
                     self.upcast_vae()
-
+                    latent = latent.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
                 # unscale/denormalize the latents
                 # denormalize with the mean and std if available and not None
@@ -1383,22 +1406,22 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
                 has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
                 if has_latents_mean and has_latents_std:
                     latents_mean = (
-                        torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(
+                        torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latent.device, latent.dtype)
                     )
                     latents_std = (
-                        torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(
+                        torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latent.device, latent.dtype)
                     )
-
+                    latent = latent * latents_std / self.vae.config.scaling_factor + latents_mean
                 else:
-
+                    latent = latent / self.vae.config.scaling_factor
 
-                image = self.vae.decode(
+                image = self.vae.decode(latent, return_dict=False)[0]
 
                 # cast back to fp16 if needed
                 if needs_upcasting:
                     self.vae.to(dtype=torch.float16)
             else:
-                image =
+                image = latent
 
             # Return PIL Image.
             image = image[0].clip_(-1, 1) * 0.5 + 0.5
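The denormalization this hunk completes simply inverts the VAE's training-time latent normalization before decoding. Gathered into a single hypothetical helper for clarity (the function name is illustrative, not from the file):

```python
import torch
from diffusers import AutoencoderKL

def denormalize_and_decode(vae: AutoencoderKL, latent: torch.Tensor) -> torch.Tensor:
    cfg = vae.config
    if getattr(cfg, 'latents_mean', None) is not None and getattr(cfg, 'latents_std', None) is not None:
        mean = torch.tensor(cfg.latents_mean).view(1, 4, 1, 1).to(latent.device, latent.dtype)
        std = torch.tensor(cfg.latents_std).view(1, 4, 1, 1).to(latent.device, latent.dtype)
        latent = latent * std / cfg.scaling_factor + mean  # undo z = (x - mean) / std * sf
    else:
        latent = latent / cfg.scaling_factor               # undo z = x * sf
    return vae.decode(latent, return_dict=False)[0]
```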
@@ -1407,4 +1430,4 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
             image = blend(image, background[0], fg_mask)
         else:
             image = T.ToPILImage()(image)
-        return image
+        return image