Upload pipeline.py
Browse files- pipeline.py +54 -18
pipeline.py
CHANGED
@@ -20,30 +20,50 @@ import numpy as np
|
|
20 |
import PIL.Image
|
21 |
import torch
|
22 |
import torch.nn.functional as F
|
23 |
-
from transformers import (
|
24 |
-
|
|
|
|
|
|
|
|
|
25 |
|
26 |
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
27 |
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
28 |
-
from diffusers.loaders import (
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
34 |
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
35 |
-
from diffusers.pipelines.pipeline_utils import
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
from diffusers.pipelines.stable_diffusion.safety_checker import
|
40 |
-
StableDiffusionSafetyChecker
|
|
|
41 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
42 |
-
from diffusers.utils import (
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
49 |
|
@@ -269,6 +289,11 @@ class StableDiffusionControlNetPipeline(
|
|
269 |
do_convert_rgb=True,
|
270 |
do_normalize=False,
|
271 |
)
|
|
|
|
|
|
|
|
|
|
|
272 |
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
273 |
|
274 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
@@ -1349,6 +1374,17 @@ class StableDiffusionControlNetPipeline(
|
|
1349 |
else:
|
1350 |
assert False
|
1351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1352 |
# 5. Prepare timesteps
|
1353 |
timesteps, num_inference_steps = retrieve_timesteps(
|
1354 |
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
|
|
20 |
import PIL.Image
|
21 |
import torch
|
22 |
import torch.nn.functional as F
|
23 |
+
from transformers import (
|
24 |
+
CLIPImageProcessor,
|
25 |
+
CLIPTextModel,
|
26 |
+
CLIPTokenizer,
|
27 |
+
CLIPVisionModelWithProjection,
|
28 |
+
)
|
29 |
|
30 |
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
31 |
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
32 |
+
from diffusers.loaders import (
|
33 |
+
FromSingleFileMixin,
|
34 |
+
IPAdapterMixin,
|
35 |
+
StableDiffusionLoraLoaderMixin,
|
36 |
+
TextualInversionLoaderMixin,
|
37 |
+
)
|
38 |
+
from diffusers.models import (
|
39 |
+
AutoencoderKL,
|
40 |
+
ControlNetModel,
|
41 |
+
ImageProjection,
|
42 |
+
UNet2DConditionModel,
|
43 |
+
)
|
44 |
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
45 |
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
46 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
47 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import (
|
48 |
+
StableDiffusionPipelineOutput,
|
49 |
+
)
|
50 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
51 |
+
StableDiffusionSafetyChecker,
|
52 |
+
)
|
53 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
54 |
+
from diffusers.utils import (
|
55 |
+
USE_PEFT_BACKEND,
|
56 |
+
deprecate,
|
57 |
+
logging,
|
58 |
+
replace_example_docstring,
|
59 |
+
scale_lora_layers,
|
60 |
+
unscale_lora_layers,
|
61 |
+
)
|
62 |
+
from diffusers.utils.torch_utils import (
|
63 |
+
is_compiled_module,
|
64 |
+
is_torch_version,
|
65 |
+
randn_tensor,
|
66 |
+
)
|
67 |
|
68 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
69 |
|
|
|
289 |
do_convert_rgb=True,
|
290 |
do_normalize=False,
|
291 |
)
|
292 |
+
self.control_mask_processor = VaeImageProcessor(
|
293 |
+
vae_scale_factor=self.vae_scale_factor,
|
294 |
+
do_normalize=False,
|
295 |
+
do_convert_grayscale=True,
|
296 |
+
)
|
297 |
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
298 |
|
299 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
|
|
1374 |
else:
|
1375 |
assert False
|
1376 |
|
1377 |
+
if effective_region_mask is not None:
|
1378 |
+
effective_region_mask = self.control_mask_processor.preprocess(
|
1379 |
+
effective_region_mask, height=height, width=width
|
1380 |
+
).to(dtype=torch.float32)
|
1381 |
+
|
1382 |
+
print("mask shape:")
|
1383 |
+
print(effective_region_mask.shape)
|
1384 |
+
print()
|
1385 |
+
|
1386 |
+
print(effective_region_mask)
|
1387 |
+
|
1388 |
# 5. Prepare timesteps
|
1389 |
timesteps, num_inference_steps = retrieve_timesteps(
|
1390 |
self.scheduler, num_inference_steps, device, timesteps, sigmas
|