import os
from types import SimpleNamespace

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image
# Corrected and memory-optimized FluxControlNet upscaling implementation
def self_attention_slicing(module, slice_size=3, dim=0):
    """Modified from Diffusers' slicing utilities for Flux compatibility.

    Wraps `module` so its tensor arguments are processed in chunks of
    `slice_size` along `dim` (the batch axis by default), trading some
    throughput for a lower peak-memory footprint.
    """
    def sliced_attention(*args, **kwargs):
        if slice_size == "auto" or not torch.is_tensor(args[0]):
            # Nothing to slice: defer to the wrapped module unchanged
            return module(*args, **kwargs)
        outputs = []
        for i in range(0, args[0].shape[dim], slice_size):
            length = min(slice_size, args[0].shape[dim] - i)
            sliced_args = [
                arg.narrow(dim, i, length) if torch.is_tensor(arg) else arg
                for arg in args
            ]
            out = module(*sliced_args, **kwargs)
            # Unwrap DecoderOutput / tuple returns down to the raw tensor
            outputs.append(out[0] if isinstance(out, tuple) else getattr(out, "sample", out))
        joined = torch.cat(outputs, dim=dim)
        # Mirror the wrapped module's return convention for callers
        if kwargs.get("return_dict", True) is False:
            return (joined,)
        return SimpleNamespace(sample=joined)  # stand-in for diffusers' DecoderOutput
    return sliced_attention
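
# Hypothetical usage sketch (names and shapes are illustrative only): wrapping
# a decode call so a batch of 4 latents is decoded 2 at a time along the batch
# axis, roughly halving the peak decode memory:
#
#   sliced_decode = self_attention_slicing(vae.decode, slice_size=2, dim=0)
#   images = sliced_decode(latents, return_dict=False)[0]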
device = "cuda"
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# 1. Load the reference FLUX.1-dev VAE first
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic device mapping
    token=huggingface_token,
).to(device)
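# The VAE is pulled from the official FLUX.1-dev repo so the pipeline below
# decodes with the reference decoder rather than the merged checkpoint's copy.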
# 2. Main pipeline initialization, with the VAE now in scope
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16,
    ),
    vae=good_vae,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token,
)
pipe.to(device)
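# Note: on ZeroGPU Spaces, loading models at module level and moving them to
# CUDA here is the expected pattern; the GPU itself is only attached while a
# @spaces.GPU-decorated function is running.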
# 3. Strict order for the optimization steps
# A. Apply CPU offloading first (currently left disabled)
# pipe.enable_sequential_cpu_offload()  # takes no arguments in the newer API
# B. Then apply custom VAE slicing
if getattr(pipe, "vae", None) is not None:
    try:
        # Preferred: the official slicing implementation, when available
        pipe.vae.enable_slicing()
    except AttributeError:
        # Fallback: manual batch-wise slicing for Flux compatibility
        # (see pipeline_flux_controlnet.py in diffusers)
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)
pipe.enable_attention_slicing(1)
# C. Optional memory optimizations (left disabled)
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()
# D. Unified precision handling (redundant here: every component is already
#    loaded in bfloat16; note Flux pipelines expose `transformer`, not `unet`)
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)
print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
@spaces.GPU
def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale):
    # Load the control image and upscale it by the requested factor
    control_image = load_image(control_image)
    w, h = control_image.size
    control_image = control_image.resize((int(w * scale), int(h * scale)))
    print(f"Size to: {control_image.size[0]}, {control_image.size[1]}")
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        height=control_image.size[1],
        width=control_image.size[0],
    ).images[0]
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    # Aggressive memory cleanup (left disabled)
    # torch.cuda.empty_cache()
    # torch.cuda.ipc_collect()
    return image
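
# Hypothetical direct call, bypassing Gradio (argument values illustrative):
#   img = generate_image(
#       "a detailed photo of a castle at dusk", scale=2, steps=8,
#       control_image="low_res.png",
#       controlnet_conditioning_scale=0.6, guidance_scale=3.5,
#   )
#   img.save("upscaled.png")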
# Create the Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, step=1, label="Steps"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
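# Optional: enable request queuing before launch so concurrent users are
# serialized instead of competing for the GPU:
# iface.queue(max_size=10)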
print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
# Launch the app
iface.launch()