Erasing-Concepts-In-Diffusion / StableDiffuser.py
import argparse
import torch
from baukit import TraceDict
from diffusers import StableDiffusionPipeline
from PIL import Image
from tqdm.auto import tqdm
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
import util


def default_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('prompts', type=str, nargs='+')
    parser.add_argument('outpath', type=str)
    parser.add_argument('--images', type=str, nargs='+', default=None)
    parser.add_argument('--nsteps', type=int, default=1000)
    parser.add_argument('--nimgs', type=int, default=1)
    parser.add_argument('--start_itr', type=int, default=0)
    parser.add_argument('--return_steps', action='store_true', default=False)
    parser.add_argument('--pred_x0', action='store_true', default=False)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    return parser
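
# Illustrative command line for the parser above (the prompt text and output
# filename are placeholders, not files in this repository):
#   python StableDiffuser.py "a photograph of an astronaut" out.png --nsteps 50 --device cuda:0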


class StableDiffuser(torch.nn.Module):

    def __init__(self,
                 scheduler='LMS',
                 repo_id_or_path="CompVis/stable-diffusion-v1-4",
                 variant='fp16'
                 ):
        super().__init__()

        self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path, variant=variant)

        self.vae = self.pipeline.vae
        self.unet = self.pipeline.unet
        self.tokenizer = self.pipeline.tokenizer
        self.text_encoder = self.pipeline.text_encoder

        if scheduler == 'LMS':
            self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
        elif scheduler == 'DDIM':
            self.scheduler = DDIMScheduler.from_pretrained(repo_id_or_path, subfolder="scheduler")
        elif scheduler == 'DDPM':
            self.scheduler = DDPMScheduler.from_pretrained(repo_id_or_path, subfolder="scheduler")
        else:
            raise ValueError(f"Unknown scheduler '{scheduler}'; expected 'LMS', 'DDIM' or 'DDPM'")

        self.eval()

    @property
    def safety_checker(self):
        return self.pipeline.safety_checker

    def get_noise(self, batch_size, width, height, generator=None):
        param = list(self.parameters())[0]
        # latents are laid out (batch, channels, height // 8, width // 8), so height
        # must come before width here
        return torch.randn(
            (batch_size, self.unet.in_channels, height // 8, width // 8),
            generator=generator).type(param.dtype).to(param.device)
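
    # For Stable Diffusion v1.x the UNet works on 4-channel latents at 1/8 of the
    # pixel resolution, which is why width and height are divided by 8 above.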

    def add_noise(self, latents, noise, step):
        return self.scheduler.add_noise(latents, noise, torch.tensor([self.scheduler.timesteps[step]]))

    def text_tokenize(self, prompts):
        return self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")

    def text_detokenize(self, tokens):
        # skip the end-of-text/padding token (the last id in the CLIP vocabulary)
        return [self.tokenizer.decode(token) for token in tokens if token != self.tokenizer.vocab_size - 1]

    def text_encode(self, tokens):
        return self.text_encoder(tokens.input_ids.to(self.unet.device))[0]

    def decode(self, latents):
        return self.vae.decode(1 / self.vae.config.scaling_factor * latents).sample

    def encode(self, tensors):
        # use the same scaling factor as decode() instead of the hard-coded 0.18215
        # (they are identical for SD v1.x checkpoints)
        return self.vae.encode(tensors).latent_dist.mode() * self.vae.config.scaling_factor

    def to_image(self, image):
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return pil_images
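
    # to_image maps decoded VAE output from [-1, 1] to [0, 255] uint8 and returns
    # one PIL image per item in the batch.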

    def set_scheduler_timesteps(self, n_steps):
        self.scheduler.set_timesteps(n_steps, device=self.unet.device)

    def get_initial_latents(self, n_imgs, width, height, n_prompts, generator=None):
        noise = self.get_noise(n_imgs, width, height, generator=generator).repeat(n_prompts, 1, 1, 1)
        latents = noise * self.scheduler.init_noise_sigma
        return latents

    def get_text_embeddings(self, prompts, negative_prompts=None, n_imgs=1):
        text_tokens = self.text_tokenize(prompts)
        text_embeddings = self.text_encode(text_tokens)

        if negative_prompts is None:
            negative_prompts = [""] * len(prompts)
        unconditional_tokens = self.text_tokenize(negative_prompts)
        unconditional_embeddings = self.text_encode(unconditional_tokens)

        text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
        return text_embeddings
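
    # The returned tensor stacks the unconditional (negative) embeddings first and
    # the conditional embeddings second; predict_noise() relies on this ordering
    # when it splits the batched prediction for classifier-free guidance.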

    def predict_noise(self,
                      iteration,
                      latents,
                      text_embeddings,
                      guidance_scale=7.5
                      ):

        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latents = torch.cat([latents] * 2)
        latents = self.scheduler.scale_model_input(
            latents, self.scheduler.timesteps[iteration])

        # predict the noise residual
        noise_prediction = self.unet(
            latents, self.scheduler.timesteps[iteration], encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_prediction_uncond, noise_prediction_text = noise_prediction.chunk(2)
        noise_prediction = noise_prediction_uncond + guidance_scale * \
            (noise_prediction_text - noise_prediction_uncond)

        return noise_prediction

    @torch.no_grad()
    def diffusion(self,
                  latents,
                  text_embeddings,
                  end_iteration=1000,
                  start_iteration=0,
                  return_steps=False,
                  pred_x0=False,
                  trace_args=None,
                  show_progress=True,
                  **kwargs):

        latents_steps = []
        trace_steps = []
        trace = None

        for iteration in tqdm(range(start_iteration, end_iteration), disable=not show_progress):

            if trace_args:
                trace = TraceDict(self, **trace_args)

            noise_pred = self.predict_noise(
                iteration,
                latents,
                text_embeddings,
                **kwargs)

            # compute the previous noisy sample x_t -> x_t-1
            output = self.scheduler.step(noise_pred, self.scheduler.timesteps[iteration], latents)

            if trace_args:
                trace.close()
                trace_steps.append(trace)

            latents = output.prev_sample

            if return_steps or iteration == end_iteration - 1:
                output = output.pred_original_sample if pred_x0 else latents
                if return_steps:
                    latents_steps.append(output.cpu())
                else:
                    latents_steps.append(output)

        return latents_steps, trace_steps
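
    # diffusion() returns the final latents (or every intermediate step when
    # return_steps=True); with pred_x0=True the scheduler's predicted clean sample
    # is recorded instead of the noisy latents.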

    @torch.no_grad()
    def __call__(self,
                 prompts,
                 negative_prompts=None,
                 width=512,
                 height=512,
                 n_steps=50,
                 n_imgs=1,
                 end_iteration=None,
                 generator=None,
                 **kwargs
                 ):

        assert 0 <= n_steps <= 1000

        if not isinstance(prompts, list):
            prompts = [prompts]

        self.set_scheduler_timesteps(n_steps)

        latents = self.get_initial_latents(n_imgs, width, height, len(prompts), generator=generator)
        text_embeddings = self.get_text_embeddings(prompts, negative_prompts, n_imgs=n_imgs)

        end_iteration = end_iteration or n_steps

        latents_steps, trace_steps = self.diffusion(
            latents,
            text_embeddings,
            end_iteration=end_iteration,
            **kwargs
        )

        latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
        images_steps = [self.to_image(latents) for latents in latents_steps]

        if self.safety_checker is not None:
            # .float() converts the checker in place; safety_checker is a read-only
            # property, so it must not be reassigned
            self.safety_checker.float()
            for i in range(len(images_steps)):
                safety_checker_input = self.pipeline.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
                image, has_nsfw_concept = self.safety_checker(
                    images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
                )
                images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]

        images_steps = list(zip(*images_steps))

        if trace_steps:
            return images_steps, trace_steps

        return images_steps

    def save_pretrained(self, path, **kwargs):
        self.pipeline.save_pretrained(path, **kwargs)


if __name__ == '__main__':

    parser = default_parser()
    args = parser.parse_args()

    diffuser = StableDiffuser(scheduler='DDIM').to(torch.device(args.device)).half()

    # seed the initial latents so that runs with the same --seed are reproducible
    generator = torch.Generator().manual_seed(args.seed)

    images = diffuser(args.prompts,
                      n_steps=args.nsteps,
                      n_imgs=args.nimgs,
                      generator=generator,
                      start_iteration=args.start_itr,
                      return_steps=args.return_steps,
                      pred_x0=args.pred_x0
                      )

    util.image_grid(images, args.outpath)
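
# Minimal library-style usage sketch, assuming a CUDA device is available; the
# prompt text and output filename are placeholders:
#
#   diffuser = StableDiffuser(scheduler='DDIM').to('cuda').half()
#   images = diffuser('a painting of a lighthouse', n_steps=50, n_imgs=1)
#   images[0][0].save('lighthouse.png')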