import torch
import spaces
import os
import gc
from diffusers.utils import load_image
from diffusers import (
    AutoencoderKL,
    FluxControlNetModel,
    FluxControlNetPipeline,
    FluxTransformer2DModel,
)
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel
import gradio as gr
# FluxControlNet upscaling demo: 8-bit quantized transformer and T5 text encoder,
# full-precision VAE, served through a Gradio interface.
huggingface_token = os.getenv("HUGGINFACE_TOKEN")
device = "cuda"
torch_dtype = torch.bfloat16
MAX_SEED = 1000000

def self_attention_slicing(decode_fn, slice_size=3):
    """Manual fallback for VAE slicing, adapted from Diffusers for Flux compatibility.

    Wraps ``vae.decode`` so the latents are split along the batch dimension,
    decoded chunk by chunk, and concatenated again, keeping peak VRAM low.
    """
    def sliced_decode(z, *args, **kwargs):
        if slice_size == "auto" or z.shape[0] <= slice_size:
            # Nothing to slice: call the original decode directly.
            return decode_fn(z, *args, **kwargs)

        outputs = [decode_fn(chunk, *args, **kwargs) for chunk in z.split(slice_size)]
        if hasattr(outputs[0], "sample"):
            # Diffusers returns a DecoderOutput; rebuild it from the concatenated samples.
            return outputs[0].__class__(sample=torch.cat([out.sample for out in outputs]))
        # return_dict=False path: decode() returns a plain tuple of tensors.
        return (torch.cat([out[0] for out in outputs]),)

    return sliced_decode
    
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)
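# Load the T5 text encoder in 8-bit via bitsandbytes to cut its VRAM footprint.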
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token
)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
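# Load the Flux transformer, the largest component of the pipeline, in 8-bit as well.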
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token
)
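# Keep the VAE un-quantized: load it in bfloat16 from the base FLUX.1-dev repo.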
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic mapping
    token=huggingface_token
).to(device)
# 2. Main Pipeline Initialization WITH VAE SCOPE
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16
    ),
    vae=good_vae,  # Now defined in scope
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token
)
pipe.to(device)
# 3. Optimization steps, applied in a strict order
# A. CPU offloading would come first (currently disabled)
# pipe.enable_sequential_cpu_offload()  # takes no arguments in the current API
# Then apply custom VAE slicing
if getattr(pipe, "vae", None) is not None:
    # Method 1: Use official implementation if available
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        # Method 2: fall back to a manual batch-sliced decode for Flux compatibility
        print("Falling back to manual VAE decode slicing.")
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)

pipe.enable_attention_slicing(1)
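# Computing attention in slices of size 1 trades some speed for a lower peak-VRAM footprint.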
# B. Enable Memory Optimizations
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()

# C. Unified Precision Handling
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)

print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
@spaces.GPU
def generate_image(prompt, scale, steps, seed, control_image, controlnet_conditioning_scale, guidance_scale, guidance_start, guidance_end):
    print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")

    # Load the control image and compute the output size from the requested scale,
    # rounded down to a multiple of 16 so the Flux pipeline does not have to adjust it.
    control_image = load_image(control_image)
    w, h = control_image.size
    w = int(w * scale) // 16 * 16
    h = int(h * scale) // 16 * 16
    control_image = control_image.resize((w, h))
    print(f"Output size: {w} x {h}")
    generator = torch.Generator().manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        height=h,
        width=w,
        control_guidance_start=guidance_start,
        control_guidance_end=guidance_end,
        generator=generator
    ).images[0]
    return image
# Create Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, label="Steps"),
        gr.Slider(0, MAX_SEED, value=42, label="Seed"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
        gr.Slider(0, 1, value=0.0, label="Control Guidance Start"),
        gr.Slider(0, 1, value=1.0, label="Control Guidance End"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
gc.enable()
gc.collect()
# Launch the app
iface.launch(show_error=True, share=True)