import os

import torch
import spaces
import gradio as gr
from diffusers import AutoencoderKL, FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Corrected and optimized FluxControlNet implementation

def self_attention_slicing(module, slice_size=2):
    """Fallback decode slicing, loosely adapted from Diffusers' VAE slicing, for
    Flux compatibility: run the wrapped decode in chunks along the batch
    dimension and concatenate the results, trading compute for lower peak VRAM."""
    def sliced_decode(latents, *args, **kwargs):
        # "auto" or a batch no larger than the slice size: nothing to split.
        if slice_size == "auto" or latents.shape[0] <= slice_size:
            return module(latents, *args, **kwargs)
        # Decode each chunk separately; return_dict=False makes every call
        # return a (sample,) tuple so the sample tensors can be concatenated.
        kwargs["return_dict"] = False
        samples = [
            module(latents[i:i + slice_size], *args, **kwargs)[0]
            for i in range(0, latents.shape[0], slice_size)
        ]
        # The Flux pipeline calls vae.decode(..., return_dict=False)[0], so a
        # plain tuple keeps that call site working.
        return (torch.cat(samples, dim=0),)
    return sliced_decode
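# Illustrative (commented-out) usage of the fallback wrapper above; the real
# application happens after the pipeline is built further below, so this is only
# a sketch of how the decode would be swapped in and restored:
# original_decode = pipe.vae.decode
# pipe.vae.decode = self_attention_slicing(pipe.vae.decode, slice_size=2)
# ...  # run inference
# pipe.vae.decode = original_decode  # restore the unsliced decode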
device = "cuda"

# Accept both the standard spelling and the original misspelled secret name
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINFACE_TOKEN")

# 1. Load the shared VAE first so it is in scope for the pipeline below
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic mapping
    token=huggingface_token
).to(device)
# 2. Main Pipeline Initialization WITH VAE SCOPE
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16
    ),
    vae=good_vae,  # Now defined in scope
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token
)
pipe.to(device)
# 3. Apply optimization steps in a strict order
# A. CPU offloading would come first, but it is disabled here because the whole
#    pipeline is kept resident on the GPU.
# pipe.enable_sequential_cpu_offload()  # no arguments in the current API
# Then apply VAE slicing, preferring the built-in implementation
if getattr(pipe, "vae", None) is not None:
    # Method 1: Use official implementation if available
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        # Method 2: fall back to the manual slicing wrapper defined above
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)

pipe.enable_attention_slicing(1)
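# Quick sanity check: AutoencoderKL normally exposes a `use_slicing` flag once
# enable_slicing() succeeds; treated here as an informational probe only, since
# older diffusers versions may not have the attribute.
print(f"VAE slicing active: {getattr(pipe.vae, 'use_slicing', 'unknown')}")
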
# B. Additional memory optimizations (left disabled)
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()

# C. Unified precision handling (note: Flux pipelines expose a transformer, not a UNet)
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)

print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
@spaces.GPU
def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale):
    # Load the control image and upscale it by the requested factor
    control_image = load_image(control_image)
    w, h = control_image.size
    # Round the target size down to a multiple of 16 so it matches the Flux latent grid
    new_w, new_h = int(w * scale) // 16 * 16, int(h * scale) // 16 * 16
    control_image = control_image.resize((new_w, new_h))
    print(f"Size to: {control_image.size[0]}, {control_image.size[1]}")
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        height=control_image.size[1],
        width=control_image.size[0]
    ).images[0]
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    # Aggressive memory cleanup
    # torch.cuda.empty_cache()
    # torch.cuda.ipc_collect()
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    return image
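
# Optional cleanup sketch (assumption: only worth enabling on a memory-tight GPU,
# since emptying the CUDA cache on every request slows repeated calls):
# import gc
# def free_vram():
#     gc.collect()
#     torch.cuda.empty_cache()
#     torch.cuda.ipc_collect()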
# Create Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, label="Steps"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")

# Launch the app
iface.launch()
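
# For a shared Space, queuing serializes GPU requests; a commented-out alternative
# using standard Gradio queuing (left disabled to preserve the current behavior):
# iface.queue(max_size=10).launch()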