import gc
import os

import gradio as gr
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    FluxControlNetModel,
    FluxControlNetPipeline,
    FluxTransformer2DModel,
)
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.utils import load_image
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel

# Corrected and optimized FluxControlNet implementation
huggingface_token = os.getenv("HUGGINFACE_TOKEN")  # env var name as configured for this Space

device = "cuda"
torch_dtype = torch.bfloat16
MAX_SEED = 1000000


def self_attention_slicing(module, slice_size=3):
    """Modified from Diffusers' original for Flux compatibility.

    Wraps `module` so its first tensor argument is processed in `slice_size`
    chunks along `dim` (the batch dimension by default) and the results are
    concatenated, trading a little speed for lower peak memory.
    """

    def sliced_attention(*args, **kwargs):
        dim = kwargs.pop("dim", 0)
        if slice_size == "auto":
            # Defer to the module's own (automatic) slicing
            return module(*args, **kwargs)

        length = args[0].shape[dim]
        outputs = []
        for i in range(0, length, slice_size):
            chunk = min(slice_size, length - i)
            sliced_args = [
                arg.narrow(dim, i, chunk) if torch.is_tensor(arg) else arg
                for arg in args
            ]
            outputs.append(module(*sliced_args, **kwargs))

        first = outputs[0]
        if torch.is_tensor(first):
            return torch.cat(outputs, dim=dim)
        if isinstance(first, tuple):
            # e.g. vae.decode(..., return_dict=False) returns a 1-tuple of tensors
            return (torch.cat([o[0] for o in outputs], dim=dim),)
        # e.g. a DecoderOutput-style object with a `.sample` tensor
        return type(first)(sample=torch.cat([o.sample for o in outputs], dim=dim))

    return sliced_attention


# 1. Load the heavy components in 8-bit so the pipeline fits in VRAM
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)

good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic device mapping
    token=huggingface_token,
).to(device)

# 2. Main pipeline initialization, with the VAE already in scope
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16,
    ),
    vae=good_vae,
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token,
)
pipe.to(device)

# 3. Strict order for the optimization steps
# A. Apply CPU offloading first (disabled: it is mutually exclusive with the
#    pipe.to(device) call above)
# pipe.enable_sequential_cpu_offload()  # No arguments in the new API

# B. Then apply VAE slicing
if getattr(pipe, "vae", None) is not None:
    try:
        # Method 1: use the official implementation if available
        pipe.vae.enable_slicing()
    except AttributeError:
        # Method 2: apply manual slicing for Flux compatibility
        # (see pipeline_flux_controlnet.py)
        print("Falling back to manual attention slicing.")
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)

pipe.enable_attention_slicing(1)

# C. Further memory optimizations (left disabled)
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()

# D. Unified precision handling (already handled at load time; note that Flux
#    exposes `transformer` rather than `unet`)
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)
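# --- Optional sanity check (a sketch, not part of the original script): verify
# that the possibly monkey-patched `pipe.vae.decode` still returns image-shaped
# output. The latent shape assumes Flux's 16-channel VAE and a 512x512 target
# (64x64 latents); adjust before uncommenting.
# with torch.no_grad():
#     _latents = torch.randn(1, 16, 64, 64, dtype=torch.bfloat16, device=device)
#     _decoded = pipe.vae.decode(_latents).sample
#     print("Sanity-check decode shape:", tuple(_decoded.shape))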
print(f"VRAM used: {torch.cuda.memory_allocated() / 1e9:.2f}GB")


@spaces.GPU
def generate_image(prompt, scale, steps, seed, control_image,
                   controlnet_conditioning_scale, guidance_scale,
                   guidance_start, guidance_end):
    print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")

    # Load the control image and round the upscaled size down to a multiple of 8
    control_image = load_image(control_image)
    w, h = control_image.size
    w = int(w * scale) - int(w * scale) % 8
    h = int(h * scale) - int(h * scale) % 8
    control_image = control_image.resize((w, h))
    print(f"Size to: {w}, {h}")

    generator = torch.Generator().manual_seed(int(seed))

    # Generate at the control image's (upscaled) resolution
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        height=h,
        width=w,
        control_guidance_start=guidance_start,
        control_guidance_end=guidance_end,
        generator=generator,
    ).images[0]
    return image


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, step=1, label="Steps"),
        gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
        gr.Slider(0, 1, value=0.0, label="Control Guidance Start"),
        gr.Slider(0, 1, value=1.0, label="Control Guidance End"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)

print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
gc.enable()
gc.collect()

# Launch the app
iface.launch(show_error=True, share=True)
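# --- Optional local smoke test (a sketch, not part of the Gradio app): the
# input and output file names below are placeholders. Note that iface.launch()
# above blocks, so this would only run after the server stops.
# if __name__ == "__main__":
#     result = generate_image(
#         prompt="a sharp, high-resolution photograph",
#         scale=2, steps=8, seed=42,
#         control_image="input.png",
#         controlnet_conditioning_scale=0.6,
#         guidance_scale=3.5, guidance_start=0.0, guidance_end=1.0,
#     )
#     result.save("upscaled.png")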