Upload pipeline.py
pipeline.py  +65 -38  CHANGED
@@ -20,50 +20,30 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import (
-    CLIPImageProcessor,
-    CLIPTextModel,
-    CLIPTokenizer,
-    CLIPVisionModelWithProjection,
-)
+from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
+                          CLIPVisionModelWithProjection)
 
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import (
-    FromSingleFileMixin,
-    IPAdapterMixin,
-    StableDiffusionLoraLoaderMixin,
-    TextualInversionLoaderMixin,
-)
-from diffusers.models import (
-    AutoencoderKL,
-    ControlNetModel,
-    ImageProjection,
-    UNet2DConditionModel,
-)
+from diffusers.loaders import (FromSingleFileMixin, IPAdapterMixin,
+                               StableDiffusionLoraLoaderMixin,
+                               TextualInversionLoaderMixin)
+from diffusers.models import (AutoencoderKL, ControlNetModel, ImageProjection,
+                              UNet2DConditionModel)
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, \
-    StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion.pipeline_output import (
-    StableDiffusionPipelineOutput)
-from diffusers.pipelines.stable_diffusion.safety_checker import (
-    StableDiffusionSafetyChecker
-)
+from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
+                                                StableDiffusionMixin)
+from diffusers.pipelines.stable_diffusion.pipeline_output import \
+    StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import \
+    StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
-    USE_PEFT_BACKEND,
-    deprecate,
-    logging,
-    replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
-)
-from diffusers.utils.torch_utils import (
-    is_compiled_module,
-    is_torch_version,
-    randn_tensor,
-)
+from diffusers.utils import (USE_PEFT_BACKEND, deprecate, logging,
+                             replace_example_docstring, scale_lora_layers,
+                             unscale_lora_layers)
+from diffusers.utils.torch_utils import (is_compiled_module, is_torch_version,
+                                         randn_tensor)
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -691,6 +671,7 @@ class StableDiffusionControlNetPipeline(
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        effective_region_mask=None,
     ):
         if callback_steps is not None and (
             not isinstance(callback_steps, int) or callback_steps <= 0
@@ -853,6 +834,9 @@ class StableDiffusionControlNetPipeline(
                     f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )
 
+        if effective_region_mask is not None:
+            self.check_mask(effective_region_mask)
+
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
@@ -894,6 +878,16 @@ class StableDiffusionControlNetPipeline(
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )
 
+    def check_mask(self, mask):
+        image_is_pil = isinstance(mask, PIL.Image.Image)
+        image_is_tensor = isinstance(mask, torch.Tensor)
+        image_is_np = isinstance(mask, np.ndarray)
+
+        if not image_is_pil and not image_is_tensor and not image_is_np:
+            raise TypeError(
+                f"mask must be passed and be one of PIL image, numpy array, or torch tensor, but is {type(mask)}"
+            )
+
     def prepare_image(
         self,
         image,
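Note that check_mask only validates the container type; shape, dtype, and value range are left to the caller. And although a PIL image or NumPy array passes this check, the apply_effective_region_mask method added below calls Tensor.to and F.interpolate on the mask, which need a 4D float tensor, so a PIL or NumPy mask would have to be converted first. A minimal conversion sketch (the file path is a placeholder, not part of the commit):

    import numpy as np
    import PIL.Image
    import torch

    # Load a binary region mask and reshape it to the (1, 1, H, W) float
    # layout that F.interpolate can resize downstream.
    pil_mask = PIL.Image.open("region_mask.png").convert("L")     # placeholder path
    mask = torch.from_numpy(np.array(pil_mask)).float() / 255.0   # (H, W), values in [0, 1]
    mask = mask[None, None]                                       # (1, 1, H, W)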
@@ -995,6 +989,20 @@ class StableDiffusionControlNetPipeline(
         assert emb.shape == (w.shape[0], embedding_dim)
         return emb
 
+    def apply_effective_region_mask(
+        self, effective_region_mask: torch.Tensor, out: torch.Tensor
+    ) -> torch.Tensor:
+        if effective_region_mask is None:
+            return out
+
+        B, C, H, W = out.shape
+        mask = F.interpolate(
+            effective_region_mask.to(out.device),
+            size=(H, W),
+            mode="bilinear",
+        )
+        return out * mask
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
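This method is the core of the feature: the mask is bilinearly resized to each residual's spatial size and multiplied in elementwise, so ControlNet guidance fades to zero outside the masked region. A standalone sketch of the same operation on dummy tensors (apply_region_mask is a local stand-in, and the dtype cast is an extra safeguard the pipeline itself does not perform):

    import torch
    import torch.nn.functional as F

    def apply_region_mask(mask: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        # mask: (1, 1, h, w) float in [0, 1]; out: (B, C, H, W) ControlNet residual
        _, _, H, W = out.shape
        resized = F.interpolate(mask.to(out.device, out.dtype), size=(H, W), mode="bilinear")
        # Broadcasting over batch and channels suppresses residuals outside the region.
        return out * resized

    mask = torch.zeros(1, 1, 64, 64)
    mask[..., :32, :] = 1.0                      # keep guidance in the top half only
    residual = torch.randn(2, 320, 16, 16)       # dummy down-block residual
    masked = apply_region_mask(mask, residual)
    assert masked[:, :, 8:, :].abs().sum() == 0  # bottom half fully suppressed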
@@ -1173,6 +1181,8 @@ class StableDiffusionControlNetPipeline(
         callback = kwargs.pop("callback", None)
         callback_steps = kwargs.pop("callback_steps", None)
 
+        effective_region_mask = kwargs.pop("effective_region_mask", None)
+
         if callback is not None:
             deprecate(
                 "callback",
@@ -1233,6 +1243,7 @@ class StableDiffusionControlNetPipeline(
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            effective_region_mask,
         )
 
         self._guidance_scale = guidance_scale
@@ -1439,6 +1450,7 @@ class StableDiffusionControlNetPipeline(
                     controlnet_cond_scale = controlnet_cond_scale[0]
                 cond_scale = controlnet_cond_scale * controlnet_keep[i]
 
+                # ControlNet returns the residuals to be added to the SD UNet blocks here
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     control_model_input,
                     t,
@@ -1449,6 +1461,21 @@ class StableDiffusionControlNetPipeline(
                     return_dict=False,
                 )
 
+                # Apply the effective region mask to the ControlNet residuals.
+                # Note that the down-block residuals are ordered from largest to smallest.
+                if effective_region_mask is not None:
+                    masked_down_block_res_samples = ()
+                    for down_block_res_sample in down_block_res_samples:
+                        down_block_res_sample = self.apply_effective_region_mask(
+                            effective_region_mask, down_block_res_sample
+                        )
+                        masked_down_block_res_samples = (
+                            masked_down_block_res_samples + (down_block_res_sample,)
+                        )
+                    down_block_res_samples = masked_down_block_res_samples
+                    mid_block_res_sample = self.apply_effective_region_mask(
+                        effective_region_mask, mid_block_res_sample
+                    )
+
                 if guess_mode and self.do_classifier_free_guidance:
                     # Inferred ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
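Taken together, the commit threads one new keyword, effective_region_mask, through __call__ (popped from kwargs), check_inputs (type validation via check_mask), and the denoising loop (residual masking), restricting ControlNet guidance to a user-chosen region of the image. A hypothetical end-to-end usage sketch, assuming this file is loaded as a diffusers custom pipeline; the model IDs, prompt, and conditioning image are placeholders, not part of the commit:

    import PIL.Image
    import torch
    from diffusers import ControlNetModel, DiffusionPipeline

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",   # placeholder base checkpoint
        controlnet=controlnet,
        custom_pipeline="./pipeline.py",    # the file uploaded in this commit
    )

    # Confine ControlNet guidance to the left half of the image.
    mask = torch.zeros(1, 1, 512, 512)
    mask[..., :, :256] = 1.0

    canny_image = PIL.Image.new("RGB", (512, 512))  # stand-in; use a real Canny map

    image = pipe(
        prompt="a bird on a branch",
        image=canny_image,
        effective_region_mask=mask,         # new kwarg introduced by this commit
    ).images[0]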