Upload combined_stable_diffusion.py with huggingface_hub
Browse files
- combined_stable_diffusion.py +3 -18
combined_stable_diffusion.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from diffusers import DiffusionPipeline, DDPMScheduler
|
3 |
-
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
4 |
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
5 |
from diffusers.image_processor import VaeImageProcessor
|
6 |
from huggingface_hub import PyTorchModelHubMixin
|
7 |
-
from transformers import CLIPTextModel,
|
8 |
from diffusers.models.attention_processor import (
|
9 |
AttnProcessor2_0,
|
10 |
FusedAttnProcessor2_0,
|
@@ -358,7 +357,6 @@ class CombinedStableDiffusionXL(
|
|
358 |
guidance_scale=5,
|
359 |
output_type: str = "pil",
|
360 |
return_dict: bool = True,
|
361 |
-
generator=None,
|
362 |
):
|
363 |
self.guidance_scale = guidance_scale
|
364 |
self.resolution = resolution
|
@@ -440,20 +438,7 @@ class CombinedStableDiffusionXL(
|
|
440 |
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
441 |
self.vae = self.vae.to(latents.dtype)
|
442 |
|
443 |
-
|
444 |
-
# denormalize with the mean and std if available and not None
|
445 |
-
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
446 |
-
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
447 |
-
if has_latents_mean and has_latents_std:
|
448 |
-
latents_mean = (
|
449 |
-
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
450 |
-
)
|
451 |
-
latents_std = (
|
452 |
-
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
453 |
-
)
|
454 |
-
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
455 |
-
else:
|
456 |
-
latents = latents / self.vae.config.scaling_factor
|
457 |
|
458 |
image = self.vae.decode(latents, return_dict=False)[0]
|
459 |
|
|
|
1 |
import torch
|
2 |
+
from diffusers import DiffusionPipeline, DDPMScheduler, StableDiffusionPipeline
|
|
|
3 |
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
4 |
from diffusers.image_processor import VaeImageProcessor
|
5 |
from huggingface_hub import PyTorchModelHubMixin
|
6 |
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
7 |
from diffusers.models.attention_processor import (
|
8 |
AttnProcessor2_0,
|
9 |
FusedAttnProcessor2_0,
|
|
|
357 |
guidance_scale=5,
|
358 |
output_type: str = "pil",
|
359 |
return_dict: bool = True,
|
|
|
360 |
):
|
361 |
self.guidance_scale = guidance_scale
|
362 |
self.resolution = resolution
|
|
|
438 |
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
439 |
self.vae = self.vae.to(latents.dtype)
|
440 |
|
441 |
+
latents = latents / self.vae.config.scaling_factor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
442 |
|
443 |
image = self.vae.decode(latents, return_dict=False)[0]
|
444 |
|