import gc
import os

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Corrected and optimized FluxControlNet implementation

def self_attention_slicing(module, slice_size=3):
    """Manual slicing wrapper adapted from Diffusers' slicing helpers for Flux compatibility.

    Wraps a callable (here the VAE decode) so that its first tensor argument is
    processed in batch-dimension chunks, capping peak VRAM usage.
    """
    def sliced_forward(latents, *args, **kwargs):
        # Nothing to slice: fall back to the original call
        if slice_size == "auto" or latents.shape[0] <= slice_size:
            return module(latents, *args, **kwargs)

        # Run the module on batch-dimension chunks and stitch the results back together
        chunks = [
            module(latents[i:i + slice_size], *args, **kwargs)
            for i in range(0, latents.shape[0], slice_size)
        ]
        first = chunks[0]
        if isinstance(first, tuple):
            # return_dict=False path: the first tuple element holds the decoded sample
            return (torch.cat([c[0] for c in chunks], dim=0),)
        if hasattr(first, "sample"):
            # DecoderOutput path: concatenate the per-chunk samples
            return first.__class__(sample=torch.cat([c.sample for c in chunks], dim=0))
        return torch.cat(chunks, dim=0)

    return sliced_forward
device = "cuda"
    
huggingface_token = os.getenv("HUGGINFACE_TOKEN")
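# Optional sketch (assumption): fail fast if the token secret is missing so the gated
# FLUX.1-dev downloads below raise a clear error instead of an opaque 401.
# if huggingface_token is None:
#     raise RuntimeError("Set the HUGGINFACE_TOKEN secret/env var before starting the app.")

# 1. Load the bfloat16 VAE separately so it can be passed to the pipeline below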
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic mapping
    token=huggingface_token
).to(device)
# 2. Main pipeline initialization, with the VAE above passed in explicitly
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16
    ),
    vae=good_vae,  # Now defined in scope
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token  # Note corrected env var name
)
pipe.to(device)
# 3. Strict order for the optimization steps
# A. CPU offloading would have to be applied FIRST (left disabled here)
# pipe.enable_sequential_cpu_offload()  # no arguments in the current Diffusers API
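# Sketch (assumption): model-level offloading is usually a faster alternative when the GPU
# has enough headroom; it replaces the explicit pipe.to(device) call above, so enable at
# most one of the two offloading modes.
# pipe.enable_model_cpu_offload()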
# Then apply the custom VAE slicing before the other memory optimizations
if getattr(pipe, "vae", None) is not None:
    # Method 1: Use official implementation if available
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        # Method 2: apply the manual slicing wrapper for Flux compatibility
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2) 

pipe.enable_attention_slicing(1)
# B. Enable Memory Optimizations
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()

# C. Unified Precision Handling (Flux exposes a transformer, not a UNet)
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)

print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
@spaces.GPU
def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale):
    print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")

    # Load control image
    control_image = load_image(control_image)
    w, h = control_image.size
    # Upscale by the requested factor
    control_image = control_image.resize((int(w * scale), int(h * scale)))
    print(f"Size to: {control_image.size[0]}, {control_image.size[1]}")
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        height=control_image.size[1],
        width=control_image.size[0]
    ).images[0]
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    # Aggressive memory cleanup
    torch.cuda.empty_cache()
    # torch.cuda.ipc_collect()
    
    del control_image
    gc.collect()
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    return image
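
# Example (sketch): calling generate_image directly for a local smoke test, assuming an
# "input.png" control image exists next to this script.
# result = generate_image(
#     prompt="a sharp, highly detailed photograph",
#     scale=2,
#     steps=8,
#     control_image="input.png",
#     controlnet_conditioning_scale=0.6,
#     guidance_scale=3.5,
# )
# result.save("output.png")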
# Create Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, label="Steps"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
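# Sketch (assumption): on a Space, enabling the request queue keeps long generations from
# hitting client timeouts; uncomment to turn it on before launching.
# iface.queue(max_size=10)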
print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
gc.collect()
# Launch the app
iface.launch()