|
|
|
|
|
import torch
|
|
from diffusers import StableDiffusionPipeline
|
|
from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG
|
|
from functools import partial
|
|
from BeamDiffusionModel.models.diffusionModel.Latents_Singleton import Latents
|
|
|
|
|
|
class StableDiffusion:
    """Wrap a diffusers ``StableDiffusionPipeline`` and record per-step latents.

    The model id, CUDA preference, precision and number of inference steps are
    all read from the ``"stable_diffusion"`` section of ``CONFIG``.
    """

    def __init__(self):
        """Load the configured pipeline and move it to the selected device."""
        # Read the section once: a missing section fails here, before any
        # work is done, instead of later at the first direct index.
        sd_config = CONFIG["stable_diffusion"]
        model_id = sd_config["id"]
        requested_precision = sd_config.get("precision")

        self.device = self._resolve_device(sd_config.get("use_cuda", True))
        self.torch_dtype = self._resolve_dtype(self.device, requested_precision)
        if requested_precision == "float16" and self.torch_dtype is not torch.float16:
            # Half-precision pipelines are not supported on CPU by diffusers,
            # so the request is downgraded rather than failing at inference.
            print("float16 precision is not supported on cpu; using float32 instead")

        print(f"Loading model: {model_id} on {self.device}")

        self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=self.torch_dtype)
        self.pipe.to(self.device)

        # Convenience handles to the pipeline's sub-models.
        self.unet = self.pipe.unet
        self.vae = self.pipe.vae

        print("Model loaded successfully!")

    @staticmethod
    def _resolve_device(use_cuda):
        """Return ``"cuda"`` when requested and available, otherwise ``"cpu"``."""
        return "cuda" if use_cuda and torch.cuda.is_available() else "cpu"

    @staticmethod
    def _resolve_dtype(device, precision):
        """Return the torch dtype for *precision*, never float16 on CPU.

        Only the exact string ``"float16"`` selects half precision; any other
        value (including ``None``) selects ``torch.float32``.
        """
        if precision == "float16" and device != "cpu":
            return torch.float16
        return torch.float32

    def capture_latents(self, latents_store: "Latents", pipe, step, timestep, callback_kwargs):
        """Step-end callback that appends the current latents to *latents_store*.

        ``generate_image`` binds *latents_store* with ``functools.partial``, so
        the pipeline supplies the remaining arguments. ``pipe``, ``step`` and
        ``timestep`` are accepted but unused.

        Returns:
            ``callback_kwargs`` unchanged.
        """
        latents_store.add_latents(callback_kwargs["latents"])
        return callback_kwargs

    def generate_image(self, prompt: str, latent=None, generator=None):
        """Run the pipeline for *prompt* and collect the latents of every step.

        Args:
            prompt: Text prompt passed to the pipeline.
            latent: Optional value forwarded as the pipeline's ``latents``
                argument.
            generator: Optional value forwarded as the pipeline's
                ``generator`` argument.

        Returns:
            A tuple ``(image, latents)``: the first image of the pipeline
            output and whatever ``Latents.dump_and_clear()`` returns after
            the run.
        """
        latents = Latents()
        # NOTE(review): Latents comes from a module named Latents_Singleton; if
        # its state is shared between instances, an exception raised by the
        # pipeline would skip dump_and_clear() and leave entries behind for
        # the next call. Confirm against the Latents implementation.
        on_step_end = partial(self.capture_latents, latents)
        num_steps = CONFIG["stable_diffusion"]["diffusion_settings"]["steps"]

        result = self.pipe(
            prompt,
            latents=latent,
            callback_on_step_end=on_step_end,
            generator=generator,
            callback_on_step_end_tensor_inputs=["latents"],
            num_inference_steps=num_steps,
        )
        img = result.images[0]

        return img, latents.dump_and_clear()