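"""FLUX.1 ControlNet upscaler Space.

Loads an 8-bit-quantized FLUX.1 pipeline with the jasperai upscaler
ControlNet and serves it through a Gradio interface on ZeroGPU.
"""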
import gc
import os

import gradio as gr
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    FluxControlNetModel,
    FluxControlNetPipeline,
    FluxTransformer2DModel,
)
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.utils import load_image
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel
# FluxControlNet upscaling pipeline with 8-bit quantized components
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
device = "cuda"
torch_dtype = torch.bfloat16
MAX_SEED = 1000000
def self_attention_slicing(module, slice_size=3):
    """Manual slicing fallback (adapted from Diffusers) for Flux compatibility.

    Wraps a callable such as ``vae.decode`` so that its first tensor argument
    is processed in batch-dimension slices, trading speed for lower peak VRAM.
    """
    def sliced_forward(*args, **kwargs):
        if slice_size == "auto" or not args or not torch.is_tensor(args[0]):
            # Nothing sensible to slice; defer to the wrapped callable.
            return module(*args, **kwargs)
        outputs = []
        for chunk in args[0].split(slice_size, dim=0):
            out = module(chunk, *args[1:], **kwargs)
            # Unwrap tuple / DecoderOutput results to the raw tensor.
            if isinstance(out, tuple):
                out = out[0]
            outputs.append(getattr(out, "sample", out))
        merged = torch.cat(outputs, dim=0)
        # Preserve the tuple convention of vae.decode(..., return_dict=False).
        return (merged,)
    return sliced_forward
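# Illustrative usage (the actual assignment happens after the pipeline is built):
#   vae.decode = self_attention_slicing(vae.decode, 2)
#   image = vae.decode(latents, return_dict=False)[0]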
text_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="text_encoder_2",
    quantization_config=text_quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)
transformer_quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="transformer",
    quantization_config=transformer_quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)
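# 8-bit weights roughly halve the footprint of the two largest components
# (the T5 text encoder and the Flux transformer) relative to bfloat16.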
# 1. Load the VAE separately, in full bfloat16 precision
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # disable automatic device mapping
    token=huggingface_token,
).to(device)
# 2. Main pipeline initialization, reusing the quantized components and VAE above
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16,
    ),
    vae=good_vae,
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token,
)
pipe.to(device)
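# Note: bitsandbytes places the 8-bit modules on the GPU at load time, so
# .to(device) here mainly moves the remaining bfloat16 components
# (VAE, CLIP text encoder, ControlNet).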
# 3. Memory optimizations, applied in a strict order:
# A. Sequential CPU offload (no arguments in the current API), currently disabled:
# pipe.enable_sequential_cpu_offload()
# B. VAE slicing: prefer the official implementation, falling back to the
#    manual wrapper defined above for Flux compatibility.
if getattr(pipe, "vae", None) is not None:
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        print("Falling back to manual attention slicing.")
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)
pipe.enable_attention_slicing(1)
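# With slice_size=1, attention is computed one slice at a time, trading some
# speed for the smallest possible peak memory.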
# C. Further memory optimizations, currently disabled:
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()
# D. Unified precision handling (all components are already bfloat16):
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)
print(f"VRAM used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
@spaces.GPU
def generate_image(prompt, scale, steps, seed, control_image, controlnet_conditioning_scale, guidance_scale, guidance_start, guidance_end):
    print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
    # Load and upscale the control image; Flux latents are packed 2x2,
    # so output dimensions must be divisible by 16.
    control_image = load_image(control_image)
    w, h = control_image.size
    w = int(w * scale) // 16 * 16
    h = int(h * scale) // 16 * 16
    control_image = control_image.resize((w, h))
    print(f"Size to: {w}, {h}")
    generator = torch.Generator().manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        height=h,
        width=w,
        control_guidance_start=guidance_start,
        control_guidance_end=guidance_end,
        generator=generator,
    ).images[0]
    return image
# Create the Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, step=1, label="Steps"),
        gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
        gr.Slider(0, 1, value=0.0, label="Control Guidance Start"),
        gr.Slider(0, 1, value=1.0, label="Control Guidance End"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
gc.enable()
gc.collect()
# Launch the app
iface.launch(show_error=True, share=True) |