import os
import gc

import torch
import spaces
import gradio as gr
from PIL import Image
from diffusers import (
    AutoencoderKL,
    FluxControlNetModel,
    FluxControlNetPipeline,
    FluxTransformer2DModel,
)
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.utils import load_image
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel
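# Gradio app for a ZeroGPU Space: upscales a control image with the
# jasperai Flux.1-dev ControlNet upscaler on top of an 8-bit quantized
# FLUX.1 merge (LPX55/FLUX.1-merged_uncensored).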
# Corrected and optimized FluxControlNet implementation
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
device = "cuda"
torch_dtype = torch.bfloat16
MAX_SEED = 1000000
def self_attention_slicing(module, slice_size=3):
    """Modified from Diffusers' slicing helpers for Flux compatibility.

    Wraps a VAE ``decode`` call so the latent batch is processed in smaller
    slices, trading speed for a lower peak-memory footprint."""
    def sliced_attention(latents, *args, **kwargs):
        # Nothing to slice: fall through to the original call
        if slice_size == "auto" or latents.shape[0] <= slice_size:
            return module(latents, *args, **kwargs)
        # Decode each batch chunk separately, then stitch the samples together
        chunks = [
            module(latents[i:i + slice_size], *args, **kwargs)
            for i in range(0, latents.shape[0], slice_size)
        ]
        samples = torch.cat(
            [c[0] if isinstance(c, tuple) else c.sample for c in chunks], dim=0
        )
        # Return a tuple so callers that pass return_dict=False keep working
        return (samples,)
    return sliced_attention
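# The wrapper above is only used as a fallback when AutoencoderKL.enable_slicing()
# is unavailable in the installed diffusers version (see step 3 below).
# 1. Load the memory-heavy components (the T5 text encoder and the Flux
#    transformer) in 8-bit via bitsandbytes so the merged model fits in VRAM.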
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)
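# Full-precision bf16 VAE from the base FLUX.1-dev repository; it is kept
# un-quantized so that decoding quality is not degraded.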
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic device mapping
    token=huggingface_token,
).to(device)
# 2. Main Pipeline Initialization WITH VAE SCOPE
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16,
    ),
    vae=good_vae,  # Defined above, so it is in scope here
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token,
)
pipe.to(device)
# 3. Strict Order for Optimization Steps
# A. Apply CPU Offloading FIRST
#### pipe.enable_sequential_cpu_offload()  # No arguments for the new API
# Then apply custom VAE slicing
if getattr(pipe, "vae", None) is not None:
    # Method 1: Use the official implementation if available
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        # Method 2: Apply manual slicing for Flux compatibility
        print("Falling back to manual attention slicing.")
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)
pipe.enable_attention_slicing(1)
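# Attention slicing computes attention in sequential chunks instead of all at
# once; a slice size of 1 is the most aggressive setting, trading speed for
# the lowest peak-memory footprint.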
# B. Enable Memory Optimizations
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()
# C. Unified Precision Handling
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:  # Flux has no unet
#     comp.to(dtype=torch.bfloat16)
print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
@spaces.GPU  # On ZeroGPU Spaces a GPU is only attached inside decorated calls
def generate_image(prompt, scale, steps, seed, control_image, controlnet_conditioning_scale, guidance_scale, guidance_start, guidance_end):
    print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
    # Load the control image and upscale it by the requested factor,
    # rounding the target size down to a multiple of 8 as the VAE requires.
    control_image = load_image(control_image)
    w, h = control_image.size
    w = int(w * scale) - int(w * scale) % 8
    h = int(h * scale) - int(h * scale) % 8
    control_image = control_image.resize((w, h))
    print(f"Size to: {w}, {h}")
    generator = torch.Generator().manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        height=h,
        width=w,
        control_guidance_start=guidance_start,
        control_guidance_end=guidance_end,
        generator=generator,
    ).images[0]
    return image
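# Example call outside the Gradio UI (hypothetical path and values, matching
# the slider defaults below):
#   upscaled = generate_image("a sharp, detailed photo", 2, 8, 42,
#                             "low_res.png", 0.6, 3.5, 0.0, 1.0)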
# Create Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, step=1, label="Steps"),
        gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
        gr.Slider(0, 1, value=0.0, label="Control Guidance Start"),
        gr.Slider(0, 1, value=1.0, label="Control Guidance End"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}") | |
gc.enable() | |
gc.collect() | |
# Launch the app | |
iface.launch(show_error=True, share=True) |