# Hugging Face Space: SDXL image-to-image demo ("Running on Zero", i.e. ZeroGPU hardware)
import os
import time
import spaces  # Must be imported before torch/CUDA
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json
import gradio as gr
import tempfile

# Set environment variable to reduce memory fragmentation
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Initialize pipeline as None - will be loaded in GPU function
pipe = None


def load_pipeline():
    """Load the pipeline on GPU when needed"""
    global pipe
    if pipe is None:
        print("Loading pipeline...")
        pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
            device_map="balanced",
            attn_implementation="eager",
        )
        # Enable memory optimizations
        # pipe.enable_model_cpu_offload()

        # Try to enable memory efficient attention
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except (ModuleNotFoundError, ImportError):
            print("xformers not available, using attention slicing")
            pipe.enable_attention_slicing()
        print("Pipeline loaded successfully!")
    return pipe
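

# On ZeroGPU Spaces ("Running on Zero") CUDA is only attached while a function
# decorated with @spaces.GPU is executing, so the inference entry point requests
# a GPU allocation here (the `spaces` import above is otherwise unused).
@spaces.GPU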
def img2img(
    uploaded_image,
    image_url: str,
    prompt: str,
    negative_prompt: str = "",
    strength: float = 0.7,
    guidance_scale: float = 3.5,
    num_inference_steps: int = 50,
    seed: int = -1,
):
    # Load pipeline inside GPU context
    try:
        pipe = load_pipeline()
    except Exception as e:
        return None, f"❌ Failed to load model: {str(e)}", None

    try:
        # Choose image source
        if uploaded_image is not None:
            init_image = Image.open(uploaded_image).convert("RGB")
        elif image_url.strip() != "":
            try:
                init_image = load_image(image_url).convert("RGB")
            except Exception as e:
                return None, f"❌ Failed to load image from URL: {str(e)}", None
        else:
            return None, "❌ Please upload an image or enter a valid URL", None

        # Resize image (keeping aspect ratio consideration for better results)
        init_image.thumbnail((1024, 1024), Image.Resampling.LANCZOS)

        # Ensure dimensions are multiples of 8 for SDXL
        width, height = init_image.size
        width = (width // 8) * 8
        height = (height // 8) * 8
        init_image = init_image.resize((width, height))
        # Set seed and generator (a fresh Generator starts from a fixed default
        # seed, so explicitly draw a random one when the user asks for -1)
        generator = torch.Generator(device="cuda")
        if int(seed) == -1:
            generator.seed()  # non-deterministic seed so each "random" run differs
        else:
            generator.manual_seed(int(seed))
        # Validate inputs
        if not prompt.strip():
            return None, "❌ Please enter a prompt", None

        # Run inference
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt.strip() else None,
                image=init_image,
                strength=max(0.1, min(1.0, strength)),  # Clamp strength
                guidance_scale=max(1.0, min(20.0, guidance_scale)),  # Clamp guidance
                num_inference_steps=max(10, min(100, num_inference_steps)),  # Clamp steps
                generator=generator,
            ).images[0]

        used_seed = generator.initial_seed()

        # Create metadata dictionary
        metadata = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": used_seed,
            "model": "stabilityai/stable-diffusion-xl-refiner-1.0",
            "pipeline": "StableDiffusionXLImg2ImgPipeline",
            "guidance_scale": guidance_scale,
            "strength": strength,
            "steps": num_inference_steps,
            "width": result.width,
            "height": result.height,
            "device": "cuda",
        }

        # Save metadata into the PNG as a text chunk
        png_info = PngInfo()
        png_info.add_text("parameters", json.dumps(metadata))

        # Use a temporary file for HF Spaces
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
            output_path = tmp_file.name
            result.save(output_path, format="PNG", pnginfo=png_info)

        # Build markdown preview of metadata
        metadata_str = (
            f"**Prompt:** {metadata['prompt']}\n\n"
            f"**Negative Prompt:** {metadata['negative_prompt']}\n\n"
            f"**Seed:** {metadata['seed']}\n\n"
            f"**Model:** {metadata['model']}\n\n"
            f"**Guidance Scale:** {metadata['guidance_scale']}\n\n"
            f"**Strength:** {metadata['strength']}\n\n"
            f"**Steps:** {metadata['steps']}\n\n"
            f"**Dimensions:** {metadata['width']}x{metadata['height']}\n\n"
            f"**Device:** {metadata['device']}"
        )

        return output_path, f"✅ **Generation Complete!**\n\n{metadata_str}", output_path

    except torch.cuda.OutOfMemoryError:
        return None, "❌ GPU out of memory. Try reducing image size or inference steps.", None
    except Exception as e:
        return None, f"❌ Error during generation: {str(e)}", None
# Define UI components with better styling
title = "🎨 SDXL Image-to-Image Editor"
description = """
Transform your images with AI! Upload an image and describe the changes you want to make.

**Tips:**
- Use detailed prompts for better results
- Lower strength values preserve more of the original image
- Higher guidance scale follows your prompt more closely
"""

# Custom CSS for better appearance
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    background: linear-gradient(90deg, #4f46e5, #7c3aed);
    border: none;
}
.gr-button:hover {
    background: linear-gradient(90deg, #4338ca, #6d28d9);
}
"""

with gr.Blocks(title=title, css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📸 Input Image")
            uploaded_image = gr.Image(
                label="Upload Image",
                type="filepath",
                height=300,
            )
            gr.Markdown("**Or**")
            image_url = gr.Textbox(
                label="Image URL",
                placeholder="https://example.com/image.jpg",
                info="Paste a direct link to an image",
            )

            gr.Markdown("### ✏️ Prompts")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="a beautiful sunset over mountains, photorealistic, detailed",
                lines=3,
                info="Describe what you want to see",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="blurry, low quality, distorted",
                lines=2,
                info="What to avoid in the image",
            )

            gr.Markdown("### ⚙️ Settings")
            with gr.Row():
                strength = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.7, step=0.05,
                    label="Transformation Strength",
                    info="0.1 = subtle changes, 1.0 = major changes",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                    label="Guidance Scale",
                    info="How closely to follow the prompt",
                )
            with gr.Row():
                num_inference_steps = gr.Slider(
                    minimum=10, maximum=50, step=5, value=30,
                    label="Quality Steps",
                    info="More steps = higher quality but slower",
                )
                seed = gr.Slider(
                    minimum=-1, maximum=999999, step=1, value=-1,
                    label="Seed",
                    info="-1 for random",
                )

            submit_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Result")
            image_output = gr.Image(label="Generated Image", height=400)
            download_button = gr.File(label="📥 Download Full Resolution", visible=False)
            gr.Markdown("### 📋 Generation Details")
            metadata_output = gr.Markdown()
    # Event handlers
    submit_btn.click(
        fn=img2img,
        inputs=[
            uploaded_image,
            image_url,
            prompt,
            negative_prompt,
            strength,
            guidance_scale,
            num_inference_steps,
            seed,
        ],
        outputs=[image_output, metadata_output, download_button],
    ).then(
        lambda x: gr.update(visible=x is not None),
        inputs=[image_output],
        outputs=[download_button],
    )

    # Examples
    gr.Markdown("### 🎯 Examples")
    gr.Examples(
        examples=[
            [None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png", "make it a van gogh painting", "blurry, low quality", 0.8, 7.5, 30, 42],
            [None, "https://picsum.photos/512/512?random=1", "turn into a cyberpunk cityscape", "blurry, distorted", 0.9, 8.0, 30, 123],
        ],
        inputs=[uploaded_image, image_url, prompt, negative_prompt, strength, guidance_scale, num_inference_steps, seed],
    )

# Launch configuration for HF Spaces
if __name__ == "__main__":
    demo.queue(max_size=20)  # Enable queuing for better performance
    demo.launch(
        show_error=True,
        share=False,      # Don't create gradio.live links in HF Spaces
        inbrowser=False,  # Don't try to open browser in cloud environment
        quiet=False,
    )
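
# Note: the generation parameters embedded via PngInfo above can be read back
# later with Pillow; a minimal sketch (the file name is a placeholder):
#
#   from PIL import Image
#   import json
#
#   img = Image.open("generated.png")             # path returned by img2img()
#   params = json.loads(img.text["parameters"])   # same keys as the metadata dict
#   print(params["seed"], params["prompt"])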