# superdiff-sdxl-v1-0 / pipeline.py
import random
from typing import Callable, Dict, List, Optional
import torch
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from tqdm import tqdm
# from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
# from diffusers import AutoencoderKL, UNet2DConditionModel
def get_scaled_coeffs():
    """Return (sqrt(beta_min), sqrt(beta_max) - sqrt(beta_min)) for the noise schedule."""
    beta_min = 0.85
    beta_max = 12.0
    return beta_min**0.5, beta_max**0.5 - beta_min**0.5


def beta(t):
    """Instantaneous noise rate beta(t) = (a + b*t)**2, with a, b from get_scaled_coeffs."""
    a, b = get_scaled_coeffs()
    return (a + t * b) ** 2


def int_beta(t):
    """Integral of beta from 0 to t: ((a + b*t)**3 - a**3) / (3*b)."""
    a, b = get_scaled_coeffs()
    return ((a + b * t) ** 3 - a**3) / (3 * b)


def sigma(t):
    """Noise scale sigma(t) = sqrt(exp(int_beta(t)) - 1), the noise-to-signal ratio of the VP schedule."""
    return torch.expm1(int_beta(t)) ** 0.5


def sigma_orig(t):
    """Original VP noise standard deviation sigma_orig(t) = sqrt(1 - exp(-int_beta(t)))."""
    return (-torch.expm1(-int_beta(t))) ** 0.5
class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
"""SuperDiffSDXLPipeline."""
    def __init__(
        self,
        unet: Callable,
        vae: Callable,
        text_encoder: Callable,
        text_encoder_2: Callable,
        tokenizer: Callable,
        tokenizer_2: Callable,
    ) -> None:
        """__init__.

        Parameters
        ----------
        unet : Callable
            SDXL UNet denoiser
        vae : Callable
            SDXL variational autoencoder
        text_encoder : Callable
            first CLIP text encoder
        text_encoder_2 : Callable
            second CLIP text encoder (with projection)
        tokenizer : Callable
            tokenizer paired with text_encoder
        tokenizer_2 : Callable
            tokenizer paired with text_encoder_2

        Returns
        -------
        None
        """
super().__init__()
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype=torch.float16
vae.to(device)
unet.to(device)
text_encoder.to(device)
text_encoder_2.to(device)
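        # register_modules exposes the components as attributes (self.unet, self.vae, ...)
        # and lets DiffusionPipeline infer self.device and self.dtype from them.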
self.register_modules(unet=unet,
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
)
    def prepare_prompt_input(self, prompt_o, prompt_b, batch_size, height, width):
        """Encode both prompts and build the conditioning inputs for the UNet.

        Parameters
        ----------
        prompt_o :
            first prompt
        prompt_b :
            second prompt
        batch_size :
            number of images to generate
        height :
            output image height in pixels
        width :
            output image width in pixels
        """
        text_input = self.tokenizer([prompt_o] * batch_size, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_input_2 = self.tokenizer_2([prompt_o] * batch_size, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True)
text_embeddings_2 = self.text_encoder_2(text_input_2.input_ids.to(self.device), output_hidden_states=True)
prompt_embeds_o = torch.concat((text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
pooled_prompt_embeds_o = text_embeddings_2[0]
negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds_o)
        text_input = self.tokenizer([prompt_b] * batch_size, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_input_2 = self.tokenizer_2([prompt_b] * batch_size, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True)
text_embeddings_2 = self.text_encoder_2(text_input_2.input_ids.to(self.device), output_hidden_states=True)
prompt_embeds_b = torch.concat((text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
pooled_prompt_embeds_b = text_embeddings_2[0]
add_time_ids_o = torch.tensor([(height,width,0,0,height,width)])
add_time_ids_b = torch.tensor([(height,width,0,0,height,width)])
negative_add_time_ids = torch.tensor([(height,width,0,0,height,width)])
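        # Stack the unconditional branch and the two prompt branches so a single
        # UNet call evaluates all three predictions needed for guided superposition.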
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_o, pooled_prompt_embeds_b], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0)
prompt_embeds = prompt_embeds.to(self.device)
add_text_embeds = add_text_embeds.to(self.device)
add_time_ids = add_time_ids.to(self.device).repeat(batch_size, 1)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
return prompt_embeds, added_cond_kwargs
@torch.no_grad
    def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
        """Decode latents with the VAE and return a uint8 image batch.

        Parameters
        ----------
        latents : Callable
            latents to decode
        nrow : int
            number of rows (currently unused)
        ncol : int
            number of columns (currently unused)

        Returns
        -------
        Callable
            uint8 images of shape (batch, height, width, 3)
        """
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
image = (image / 2 + 0.5).clamp(0, 1).squeeze()
if len(image.shape) < 4:
image = image.unsqueeze(0)
image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)
return image
@torch.no_grad
    def get_text_embedding(self, prompt: str) -> Callable:
        """Tokenize ``prompt`` and return the first text encoder's embeddings.

        Parameters
        ----------
        prompt : str
            text prompt to encode

        Returns
        -------
        Callable
            text embeddings
        """
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
return self.text_encoder(text_input.input_ids.to(self.device))[0]
@torch.no_grad
    def get_vel(self, t: float, sigma: float, latents: Callable, embeddings: Callable):
        """Evaluate the UNet velocity (noise prediction) for the given latents.

        Parameters
        ----------
        t : float
            current diffusion time
        sigma : float
            noise scale at time ``t``
        latents : Callable
            current latents
        embeddings : Callable
            text embeddings to condition on
        """

        def v(_x, _e):
            return self.unet(
                _x / ((sigma**2 + 1) ** 0.5), t, encoder_hidden_states=_e
            ).sample

        embeds = torch.cat(embeddings)
        latent_input = latents
        vel = v(latent_input, embeds)
        return vel
def preprocess(
self,
prompt_1: str,
prompt_2: str,
        seed: Optional[int] = None,
num_inference_steps: int = 200,
batch_size: int = 1,
height: int = 1024,
width: int = 1024,
guidance_scale: float = 7.5,
    ) -> Dict:
        """Prepare latents, prompt embeddings, and conditioning inputs for sampling.

        Parameters
        ----------
        prompt_1 : str
            first prompt to superimpose
        prompt_2 : str
            second prompt to superimpose
        seed : Optional[int]
            RNG seed; a random seed is drawn when None
        num_inference_steps : int
            number of sampling steps
        batch_size : int
            number of images to generate
        height : int
            output image height in pixels
        width : int
            output image width in pixels
        guidance_scale : float
            classifier-free guidance scale

        Returns
        -------
        Dict
            latents, prompt_embeds, and added_cond_kwargs for ``_forward``
        """
# Tokenize the input
self.batch_size = batch_size
self.num_inference_steps = num_inference_steps
self.guidance_scale = guidance_scale
self.seed = seed
if self.seed is None:
self.seed = random.randint(0, 2**32 - 1)
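        # The initial noise is drawn in the VAE latent space: SDXL's VAE downsamples
        # images by a factor of 8, hence the (height // 8, width // 8) latent grid.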
        self.generator = torch.Generator(device=self.device).manual_seed(
            self.seed
        )  # seeded generator for the initial latent noise and the sampler's noise
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=self.generator,
            dtype=self.dtype,
            device=self.device,
        )
        prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(
            prompt_1, prompt_2, batch_size, height, width
        )
return {
"latents": latents,
"prompt_embeds": prompt_embeds,
"added_cond_kwargs": added_cond_kwargs,
}
    def _forward(self, model_inputs: Dict) -> Callable:
        """Run the reverse-diffusion sampling loop and return the final latents.

        Parameters
        ----------
        model_inputs : Dict
            latents, prompt_embeds, and added_cond_kwargs from ``preprocess``

        Returns
        -------
        Callable
            denoised latents
        """
latents = model_inputs["latents"]
prompt_embeds = model_inputs["prompt_embeds"]
added_cond_kwargs = model_inputs["added_cond_kwargs"]
t = torch.tensor(1.0)
dt = 1.0/self.num_inference_steps
train_number_steps = 1000
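        # The UNet was trained on 1000 discrete timesteps, so the continuous time
        # t in [0, 1] is multiplied by train_number_steps when calling it. Latents
        # are scaled by sqrt(sigma(t)^2 + 1), mirroring the Euler-discrete scheduler's
        # init_noise_sigma / scale_model_input convention.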
latents = latents * (sigma(t)**2+1)**0.5
with torch.no_grad():
for i in tqdm(range(self.num_inference_steps)):
latent_model_input = torch.cat([latents] * 3)
sigma_t = sigma(t)
dsigma = sigma(t-dt) - sigma_t
latent_model_input /= (sigma_t**2+1)**0.5
with torch.no_grad():
noise_pred = self.unet(latent_model_input, t*train_number_steps, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0]
noise_pred_uncond, noise_pred_text_o, noise_pred_text_b = noise_pred.chunk(3)
# noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.randn_like(latents)
noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.empty_like(latents, device=self.device).normal_(generator=self.generator)
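                # Superposition step: dx_ind is the guided reverse-SDE increment that
                # follows prompt 2 alone; kappa (one scalar per sample) then tilts the
                # guidance direction between the two prompts so both conditionals are
                # followed simultaneously (the SuperDiff "AND" combination).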
                dx_ind = (
                    2 * dsigma * (noise_pred_uncond + self.guidance_scale * (noise_pred_text_b - noise_pred_uncond))
                    + noise
                )
                kappa = (
                    torch.abs(dsigma) * (noise_pred_text_b - noise_pred_text_o) * (noise_pred_text_b + noise_pred_text_o)
                ).sum((1, 2, 3)) - (dx_ind * (noise_pred_text_o - noise_pred_text_b)).sum((1, 2, 3))
                kappa /= 2 * dsigma * self.guidance_scale * ((noise_pred_text_o - noise_pred_text_b) ** 2).sum((1, 2, 3))
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    (noise_pred_text_b - noise_pred_uncond)
                    + kappa[:, None, None, None] * (noise_pred_text_o - noise_pred_text_b)
                )
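                # Reverse-time update: intermediate steps re-inject noise (stochastic
                # sampler); the final step is taken deterministically.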
if i < self.num_inference_steps - 1:
latents += 2*dsigma * noise_pred + noise
else:
latents += dsigma * noise_pred
t -= dt
return latents
def postprocess(self, latents: Callable) -> Callable:
"""postprocess.
Parameters
----------
latents : Callable
latents
Returns
-------
Callable
"""
latents = latents/self.vae.config.scaling_factor
latents = latents.to(torch.float32)
with torch.no_grad():
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
return images
def __call__(
self,
prompt_1: str,
prompt_2: str,
        seed: Optional[int] = None,
num_inference_steps: int = 200,
batch_size: int = 1,
height: int = 1024,
width: int = 1024,
guidance_scale: float = 7.5,
) -> Callable:
"""__call__.
Parameters
----------
prompt_1 : str
prompt_1
prompt_2 : str
prompt_2
seed : int
seed
num_inference_steps : int
num_inference_steps
batch_size : int
batch_size
height : int
height
width : int
width
guidance_scale : int
guidance_scale
Returns
-------
Callable
"""
# Preprocess inputs
model_inputs = self.preprocess(
prompt_1,
prompt_2,
seed,
num_inference_steps,
batch_size,
height,
width,
guidance_scale,
)
# Forward pass through the pipeline
latents = self._forward(model_inputs)
# Postprocess to generate the final output
images = self.postprocess(latents)
return images
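

# Minimal usage sketch (not part of the pipeline itself). It assumes the standard
# SDXL base checkpoint "stabilityai/stable-diffusion-xl-base-1.0" and the component
# classes hinted at in the commented-out imports at the top of this file; adjust
# model paths, devices, and dtypes to your own setup.
#
# from diffusers import AutoencoderKL, UNet2DConditionModel
# from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
#
# base = "stabilityai/stable-diffusion-xl-base-1.0"
# pipe = SuperDiffSDXLPipeline(
#     unet=UNet2DConditionModel.from_pretrained(base, subfolder="unet"),
#     vae=AutoencoderKL.from_pretrained(base, subfolder="vae"),
#     text_encoder=CLIPTextModel.from_pretrained(base, subfolder="text_encoder"),
#     text_encoder_2=CLIPTextModelWithProjection.from_pretrained(base, subfolder="text_encoder_2"),
#     tokenizer=CLIPTokenizer.from_pretrained(base, subfolder="tokenizer"),
#     tokenizer_2=CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2"),
# )
# images = pipe("a photo of a lion", "a photo of a tiger", num_inference_steps=200)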