shunk031 committed on
Commit
64c10d9
·
verified ·
1 Parent(s): 8cc9b4f

Upload pipeline.py with huggingface_hub
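
For reference, a commit like this is typically produced programmatically with the `huggingface_hub` client; a minimal sketch, where the `repo_id` is a placeholder rather than the actual target repo:

from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="pipeline.py",
    path_in_repo="pipeline.py",
    repo_id="user/repo",  # placeholder; substitute the actual Hub repo
    commit_message="Upload pipeline.py with huggingface_hub",
)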

Files changed (1)
  1. pipeline.py +90 -0
pipeline.py ADDED
@@ -0,0 +1,90 @@
+ from typing import List
+
+ import torch
+ from diffusers import StableDiffusionModelEditingPipeline as SDTIME
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.pipelines.deprecated.stable_diffusion_variants.pipeline_stable_diffusion_model_editing import (
+     AUGS_CONST,
+ )
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
+     StableDiffusionSafetyChecker,
+ )
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+
+ class StableDiffusionModelEditingPipeline(SDTIME):
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: SchedulerMixin,
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPImageProcessor,
+         requires_safety_checker: bool = True,
+         with_to_k: bool = True,
+         with_augs: List[str] = AUGS_CONST,
+     ) -> None:
+         super().__init__(
+             vae,
+             text_encoder,
+             tokenizer,
+             unet,
+             scheduler,
+             safety_checker,
+             feature_extractor,
+             requires_safety_checker,
+             with_to_k,
+             with_augs,
+         )
+
+         # get cross-attention layers
+         ca_layers = []
+
+         def append_ca(net_):
+             # In diffusers v0.15.0 and later, `CrossAttention` was renamed to `Attention`.
+             # See the upstream pipeline implementation:
+             # https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py#L135
+             if net_.__class__.__name__ == "Attention":
+                 ca_layers.append(net_)
+             elif hasattr(net_, "children"):
+                 for net__ in net_.children():
+                     append_ca(net__)
+
+         # recursively find all cross-attention layers in the unet
+         for net in self.unet.named_children():
+             if "down" in net[0]:
+                 append_ca(net[1])
+             elif "up" in net[0]:
+                 append_ca(net[1])
+             elif "mid" in net[0]:
+                 append_ca(net[1])
+
+         # get projection matrices; keep only the layers conditioned on the
+         # 768-dim CLIP text embeddings (i.e. the cross-attention layers in SD v1.x)
+         self.ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768]
+         assert len(self.ca_clip_layers) > 0
+         self.projection_matrices = [l.to_v for l in self.ca_clip_layers]
+         assert len(self.projection_matrices) > 0
+
+         if self.with_to_k:
+             projection_matrices = [l.to_k for l in self.ca_clip_layers]
+             self.projection_matrices = self.projection_matrices + projection_matrices
+             assert len(self.projection_matrices) > 0
+
+     @torch.no_grad()
+     def edit_model(
+         self,
+         source_prompt: str,
+         destination_prompt: str,
+         lamb: float = 0.1,
+         **kwargs,
+     ) -> None:
+         # `restart_params` creates a copy of the object when restoring the original weights,
+         # which can lead to problems such as the device not being set correctly
+         # when exiting the pipeline. For these reasons, `restart_params` is set to `False`.
+         # If you want to restore the original weights, it is recommended to reload the pipeline.
+         super().edit_model(
+             source_prompt, destination_prompt, lamb, restart_params=False
+         )
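
A minimal usage sketch for the class above, assuming the file is loaded as a diffusers community pipeline; the base checkpoint and the `custom_pipeline` repo id are placeholders, not confirmed by this commit:

import torch
from diffusers import DiffusionPipeline

# Load an SD v1.x checkpoint and swap in pipeline.py as the pipeline class.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # placeholder base checkpoint
    custom_pipeline="user/repo",      # placeholder repo containing pipeline.py
    torch_dtype=torch.float16,
).to("cuda")

# Apply the TIME edit in place. `restart_params` is forced to False internally,
# so reload the pipeline if you need the original weights back.
pipe.edit_model(
    source_prompt="A pack of roses",
    destination_prompt="A pack of blue roses",
    lamb=0.1,
)
image = pipe("A field of roses").images[0]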