codeysun committed on
Commit
e5d83a1
·
verified ·
1 Parent(s): bcadd8d

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +65 -38
pipeline.py CHANGED
@@ -20,50 +20,30 @@ import numpy as np
20
  import PIL.Image
21
  import torch
22
  import torch.nn.functional as F
23
- from transformers import (
24
- CLIPImageProcessor,
25
- CLIPTextModel,
26
- CLIPTokenizer,
27
- CLIPVisionModelWithProjection,
28
- )
29
 
30
  from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
31
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
32
- from diffusers.loaders import (
33
- FromSingleFileMixin,
34
- IPAdapterMixin,
35
- StableDiffusionLoraLoaderMixin,
36
- TextualInversionLoaderMixin,
37
- )
38
- from diffusers.models import (
39
- AutoencoderKL,
40
- ControlNetModel,
41
- ImageProjection,
42
- UNet2DConditionModel,
43
- )
44
  from diffusers.models.lora import adjust_lora_scale_text_encoder
45
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
46
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
47
- from diffusers.pipelines.stable_diffusion.pipeline_output import (
48
- StableDiffusionPipelineOutput,
49
- )
50
- from diffusers.pipelines.stable_diffusion.safety_checker import (
51
- StableDiffusionSafetyChecker,
52
- )
53
  from diffusers.schedulers import KarrasDiffusionSchedulers
54
- from diffusers.utils import (
55
- USE_PEFT_BACKEND,
56
- deprecate,
57
- logging,
58
- replace_example_docstring,
59
- scale_lora_layers,
60
- unscale_lora_layers,
61
- )
62
- from diffusers.utils.torch_utils import (
63
- is_compiled_module,
64
- is_torch_version,
65
- randn_tensor,
66
- )
67
 
68
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
69
 
@@ -691,6 +671,7 @@ class StableDiffusionControlNetPipeline(
691
  control_guidance_start=0.0,
692
  control_guidance_end=1.0,
693
  callback_on_step_end_tensor_inputs=None,
 
694
  ):
695
  if callback_steps is not None and (
696
  not isinstance(callback_steps, int) or callback_steps <= 0
@@ -853,6 +834,9 @@ class StableDiffusionControlNetPipeline(
853
  f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
854
  )
855
 
 
 
 
856
  def check_image(self, image, prompt, prompt_embeds):
857
  image_is_pil = isinstance(image, PIL.Image.Image)
858
  image_is_tensor = isinstance(image, torch.Tensor)
@@ -894,6 +878,16 @@ class StableDiffusionControlNetPipeline(
894
  f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
895
  )
896
 
 
 
 
 
 
 
 
 
 
 
897
  def prepare_image(
898
  self,
899
  image,
@@ -995,6 +989,20 @@ class StableDiffusionControlNetPipeline(
995
  assert emb.shape == (w.shape[0], embedding_dim)
996
  return emb
997
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
998
  @property
999
  def guidance_scale(self):
1000
  return self._guidance_scale
@@ -1173,6 +1181,8 @@ class StableDiffusionControlNetPipeline(
1173
  callback = kwargs.pop("callback", None)
1174
  callback_steps = kwargs.pop("callback_steps", None)
1175
 
 
 
1176
  if callback is not None:
1177
  deprecate(
1178
  "callback",
@@ -1233,6 +1243,7 @@ class StableDiffusionControlNetPipeline(
1233
  control_guidance_start,
1234
  control_guidance_end,
1235
  callback_on_step_end_tensor_inputs,
 
1236
  )
1237
 
1238
  self._guidance_scale = guidance_scale
@@ -1439,6 +1450,7 @@ class StableDiffusionControlNetPipeline(
1439
  controlnet_cond_scale = controlnet_cond_scale[0]
1440
  cond_scale = controlnet_cond_scale * controlnet_keep[i]
1441
 
 
1442
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1443
  control_model_input,
1444
  t,
@@ -1449,6 +1461,21 @@ class StableDiffusionControlNetPipeline(
1449
  return_dict=False,
1450
  )
1451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1452
  if guess_mode and self.do_classifier_free_guidance:
1453
  # Inferred ControlNet only for the conditional batch.
1454
  # To apply the output of ControlNet to both the unconditional and conditional batches,
 
20
  import PIL.Image
21
  import torch
22
  import torch.nn.functional as F
23
+ from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
24
+ CLIPVisionModelWithProjection)
 
 
 
 
25
 
26
  from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
28
+ from diffusers.loaders import (FromSingleFileMixin, IPAdapterMixin,
29
+ StableDiffusionLoraLoaderMixin,
30
+ TextualInversionLoaderMixin)
31
+ from diffusers.models import (AutoencoderKL, ControlNetModel, ImageProjection,
32
+ UNet2DConditionModel)
 
 
 
 
 
 
 
33
  from diffusers.models.lora import adjust_lora_scale_text_encoder
34
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
35
+ from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
36
+ StableDiffusionMixin)
37
+ from diffusers.pipelines.stable_diffusion.pipeline_output import \
38
+ StableDiffusionPipelineOutput
39
+ from diffusers.pipelines.stable_diffusion.safety_checker import \
40
+ StableDiffusionSafetyChecker
 
41
  from diffusers.schedulers import KarrasDiffusionSchedulers
42
+ from diffusers.utils import (USE_PEFT_BACKEND, deprecate, logging,
43
+ replace_example_docstring, scale_lora_layers,
44
+ unscale_lora_layers)
45
+ from diffusers.utils.torch_utils import (is_compiled_module, is_torch_version,
46
+ randn_tensor)
 
 
 
 
 
 
 
 
47
 
48
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
 
 
671
  control_guidance_start=0.0,
672
  control_guidance_end=1.0,
673
  callback_on_step_end_tensor_inputs=None,
674
+ effective_region_mask=None,
675
  ):
676
  if callback_steps is not None and (
677
  not isinstance(callback_steps, int) or callback_steps <= 0
 
834
  f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
835
  )
836
 
837
+ if effective_region_mask is not None:
838
+ self.check_mask(effective_region_mask)
839
+
840
  def check_image(self, image, prompt, prompt_embeds):
841
  image_is_pil = isinstance(image, PIL.Image.Image)
842
  image_is_tensor = isinstance(image, torch.Tensor)
 
878
  f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
879
  )
880
 
881
def check_mask(self, mask):
    """Validate the container type of an effective-region mask.

    Only the Python type of ``mask`` is inspected; its shape, dtype and
    value range are not checked here.

    Args:
        mask: Candidate mask object.

    Raises:
        TypeError: If ``mask`` is not a PIL image, a NumPy array, or a
            torch tensor.
    """
    accepted_types = (PIL.Image.Image, torch.Tensor, np.ndarray)
    if isinstance(mask, accepted_types):
        return

    raise TypeError(
        f"mask must be passed and be one of PIL image, numpy array, or torch tensor, but is {type(mask)}"
    )
890
+
891
  def prepare_image(
892
  self,
893
  image,
 
989
  assert emb.shape == (w.shape[0], embedding_dim)
990
  return emb
991
 
992
def apply_effective_region_mask(
    self, effective_region_mask: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """Scale a residual tensor by a mask resized to its spatial resolution.

    The mask is bilinearly resized to the last two dimensions of ``out``
    and multiplied element-wise with it (broadcasting over any singleton
    mask dimensions).

    Args:
        effective_region_mask: Mask tensor, or ``None`` to disable masking.
            Masks with fewer than four dimensions get leading singleton
            dimensions added, and non-floating (e.g. bool/int) masks are
            converted to float32, because bilinear interpolation requires
            a 4D floating-point input.
        out: 4D tensor to be masked; its last two dimensions define the
            target size of the resized mask.

    Returns:
        ``out`` itself when ``effective_region_mask`` is ``None``;
        otherwise the masked tensor, in the same dtype as ``out``.
    """
    if effective_region_mask is None:
        return out

    # Unpacking keeps the original requirement that `out` is exactly 4D.
    _, _, height, width = out.shape

    mask = effective_region_mask.to(out.device)
    if not torch.is_floating_point(mask):
        mask = mask.float()
    while mask.dim() < 4:
        mask = mask.unsqueeze(0)

    mask = F.interpolate(mask, size=(height, width), mode="bilinear")

    # Cast after resizing so a float32 mask does not silently promote
    # half-precision residuals to float32 via type promotion.
    return out * mask.to(out.dtype)
1005
+
1006
  @property
1007
  def guidance_scale(self):
1008
  return self._guidance_scale
 
1181
  callback = kwargs.pop("callback", None)
1182
  callback_steps = kwargs.pop("callback_steps", None)
1183
 
1184
+ effective_region_mask = kwargs.pop("effective_region_mask", None)
1185
+
1186
  if callback is not None:
1187
  deprecate(
1188
  "callback",
 
1243
  control_guidance_start,
1244
  control_guidance_end,
1245
  callback_on_step_end_tensor_inputs,
1246
+ effective_region_mask,
1247
  )
1248
 
1249
  self._guidance_scale = guidance_scale
 
1450
  controlnet_cond_scale = controlnet_cond_scale[0]
1451
  cond_scale = controlnet_cond_scale * controlnet_keep[i]
1452
 
1453
+ # Controlnet is returning the residuals to be added to SD here
1454
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1455
  control_model_input,
1456
  t,
 
1461
  return_dict=False,
1462
  )
1463
 
1464
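+ # Apply the effective-region mask to the ControlNet residuals. NOTE(review): in the lines shown, the masked down-block tuple built below is never assigned back to `down_block_res_samples` - confirm the masked values are actually used.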
1465
+ # Note that downblocks are ordered from largest->smallest
1466
+ if effective_region_mask is not None:
1467
+ masked_down_block_res_samples = ()
1468
+ for down_block_res_sample in down_block_res_samples:
1469
+ down_block_res_sample = self.apply_effective_region_mask(
1470
+ effective_region_mask, down_block_res_sample
1471
+ )
1472
+ masked_down_block_res_samples = (
1473
+ masked_down_block_res_samples + (down_block_res_sample,)
1474
+ )
1475
+ mid_block_res_sample = self.apply_effective_region_mask(
1476
+ effective_region_mask, mid_block_res_sample
1477
+ )
1478
+
1479
  if guess_mode and self.do_classifier_free_guidance:
1480
  # Inferred ControlNet only for the conditional batch.
1481
  # To apply the output of ControlNet to both the unconditional and conditional batches,