# FluxM-Lightning-Upscaler / optimized.py
import gc
import os

import gradio as gr
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    FluxControlNetModel,
    FluxControlNetPipeline,
    FluxTransformer2DModel,
)
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.utils import load_image
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel
# Corrected and optimized FluxControlNet implementation
huggingface_token = os.getenv("HUGGINFACE_TOKEN")  # env var name matches the Space secret (note the spelling)
device = "cuda"
torch_dtype = torch.bfloat16
MAX_SEED = 1000000
def self_attention_slicing(module, slice_size=3):
    """Chunked fallback adapted from Diffusers' attention slicing for Flux compatibility.

    Wraps a callable so its tensor arguments are processed in slices along `dim`
    (the batch dimension by default) and the partial outputs are concatenated back.
    """
    def sliced_attention(*args, dim=0, **kwargs):
        if slice_size == "auto":
            # Automatic slicing: defer to the wrapped module unchanged
            return module(*args, **kwargs)
        total = args[0].shape[dim]
        outputs = []
        for i in range(0, total, slice_size):
            length = min(slice_size, total - i)
            # Slice every tensor argument along `dim`; pass everything else through unchanged
            sliced_args = [
                arg.narrow(dim, i, length) if torch.is_tensor(arg) else arg
                for arg in args
            ]
            sliced_kwargs = {
                k: v.narrow(dim, i, length) if torch.is_tensor(v) else v
                for k, v in kwargs.items()
            }
            outputs.append(module(*sliced_args, **sliced_kwargs))
        return torch.cat(outputs, dim=dim)
    return sliced_attention
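# Example use of the wrapper (assumption: the wrapped callable returns a plain tensor
# that can be concatenated back along `dim`), chunking a matmul over the batch:
#   sliced_matmul = self_attention_slicing(lambda x: x @ weight, slice_size=4)
#   out = sliced_matmul(big_batch, dim=0)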
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)
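# Optional sketch (assumption, not used here): if 8-bit weights still exceed the
# available VRAM, the text encoder could be loaded in 4-bit NF4 instead, e.g.:
#   quant_config = TransformersBitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_quant_type="nf4",
#       bnb_4bit_compute_dtype=torch.bfloat16,
#   )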
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)
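# The VAE is loaded unquantized in bf16 from the base FLUX.1-dev checkpoint and moved
# to the GPU up front; device_map is disabled so the pipeline controls placement itself.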
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # disable automatic device mapping
    token=huggingface_token,
).to(device)
# 2. Main pipeline initialization, reusing the VAE defined above
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16,
    ),
    vae=good_vae,
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token,
)
pipe.to(device)
# 3. Optimization steps, applied in a strict order
# A. CPU offloading would come first (disabled: the pipeline is already moved to the GPU above)
# pipe.enable_sequential_cpu_offload()  # no arguments in the current API
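# Alternative sketch (assumption, not enabled): model-level offloading keeps whole
# sub-models on the CPU until needed; if used, skip the eager pipe.to(device) above,
# since enable_model_cpu_offload() manages device placement itself:
#   pipe.enable_model_cpu_offload()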
# B. Then apply VAE slicing
if getattr(pipe, "vae", None) is not None:
    try:
        # Method 1: use the official AutoencoderKL implementation if available
        pipe.vae.enable_slicing()
    except AttributeError:
        # Method 2: manual slicing fallback for Flux compatibility (see pipeline_flux_controlnet.py).
        # Assumes the wrapped decode returns a tensor that can be re-concatenated.
        print("Falling back to manual attention slicing.")
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)
pipe.enable_attention_slicing(1)
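# Note: slice_size=1 above gives the largest attention-memory savings at the cost of speed.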
# C. Additional memory optimizations (optional)
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()
# D. Unified precision handling (Flux has no `unet`; its denoiser is `pipe.transformer`)
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)
print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
@spaces.GPU
def generate_image(prompt, scale, steps, seed, control_image, controlnet_conditioning_scale, guidance_scale, guidance_start, guidance_end):
    print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
    # Load the control image and compute the upscaled target size, rounded down to multiples of 8
    control_image = load_image(control_image)
    w, h = control_image.size
    w = int(w * scale) - int(w * scale) % 8
    h = int(h * scale) - int(h * scale) % 8
    control_image = control_image.resize((w, h))
    print(f"Size to: {w}, {h}")
    generator = torch.Generator().manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        height=h,  # generate at the upscaled resolution so it matches the control image
        width=w,
        control_guidance_start=guidance_start,
        control_guidance_end=guidance_end,
        generator=generator,
    ).images[0]
    return image
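# Example direct call, bypassing Gradio (hypothetical file names and values):
#   img = generate_image(
#       "a sharp, detailed photograph", scale=2, steps=8, seed=42,
#       control_image="low_res_input.png", controlnet_conditioning_scale=0.6,
#       guidance_scale=3.5, guidance_start=0.0, guidance_end=1.0,
#   )
#   img.save("upscaled.png")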
# Create the Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, step=1, label="Steps"),
        gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
        gr.Slider(0, 1, value=0.0, label="Control Guidance Start"),
        gr.Slider(0, 1, value=1.0, label="Control Guidance End"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
gc.enable()
gc.collect()
# Launch the app
iface.launch(show_error=True, share=True)