import os
import time
import spaces # Must be imported before torch/CUDA
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json
import gradio as gr
import tempfile
# Set environment variable to reduce memory fragmentation
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
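# (expandable_segments lets the CUDA caching allocator grow existing memory
# segments instead of reserving new ones, which can reduce fragmentation-related
# OOMs; uncomment the line above if those show up in the logs)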
# Initialize pipeline as None - will be loaded in GPU function
pipe = None
def load_pipeline():
"""Load the pipeline on GPU when needed"""
global pipe
if pipe is None:
print("Loading pipeline...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
device_map="balanced",
attn_implementation="eager"
)
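        # fp16 weights halve download size and VRAM use; device_map="balanced"
        # lets accelerate spread the pipeline's submodules across the visible devices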
        # Enable memory optimizations (model CPU offload is left disabled because
        # it conflicts with loading the pipeline under a device_map)
        #pipe.enable_model_cpu_offload()
# Try to enable memory efficient attention
try:
pipe.enable_xformers_memory_efficient_attention()
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
print("xformers not available, using attention slicing")
pipe.enable_attention_slicing()
print("Pipeline loaded successfully!")
return pipe
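# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only for the duration of the
# decorated call, which is why the pipeline is loaded lazily inside the GPU
# context instead of at import time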
@spaces.GPU
def img2img(
uploaded_image,
image_url: str,
prompt: str,
negative_prompt: str = "",
strength: float = 0.7,
guidance_scale: float = 3.5,
num_inference_steps: int = 50,
seed: int = -1,
):
# Load pipeline inside GPU context
try:
pipe = load_pipeline()
except Exception as e:
return None, f"❌ Failed to load model: {str(e)}", None
try:
# Choose image source
if uploaded_image is not None:
init_image = Image.open(uploaded_image).convert("RGB")
        elif image_url and image_url.strip():
try:
init_image = load_image(image_url).convert("RGB")
except Exception as e:
return None, f"❌ Failed to load image from URL: {str(e)}", None
else:
return None, "❌ Please upload an image or enter a valid URL", None
# Resize image (keeping aspect ratio consideration for better results)
init_image.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
# Ensure dimensions are multiples of 8 for SDXL
width, height = init_image.size
width = (width // 8) * 8
height = (height // 8) * 8
init_image = init_image.resize((width, height))
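        # (thumbnail() only ever downscales, so smaller inputs keep their
        # native size, still snapped to multiples of 8)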
        # Seed handling: a fresh torch.Generator is NOT randomly seeded (it uses
        # a fixed default), so explicitly randomize when seed == -1 so the seed
        # reported in the metadata is the one actually used
        generator = torch.Generator(device="cuda")
        if seed == -1:
            generator.seed()  # draws a random seed; initial_seed() reports it
        else:
            generator.manual_seed(seed)
# Validate inputs
if not prompt.strip():
return None, "❌ Please enter a prompt", None
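        # img2img runs roughly strength * num_inference_steps denoising steps,
        # so lower strength values also mean a shorter run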
        # Run inference with autograd disabled (inference_mode lowers memory use)
with torch.inference_mode():
result = pipe(
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt.strip() else None,
image=init_image,
strength=max(0.1, min(1.0, strength)), # Clamp strength
guidance_scale=max(1.0, min(20.0, guidance_scale)), # Clamp guidance
num_inference_steps=max(10, min(100, num_inference_steps)), # Clamp steps
generator=generator
).images[0]
used_seed = generator.initial_seed()
# Create metadata dictionary
metadata = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": used_seed,
"model": "stabilityai/stable-diffusion-xl-refiner-1.0",
"pipeline": "StableDiffusionXLImg2ImgPipeline",
"guidance_scale": guidance_scale,
"strength": strength,
"steps": num_inference_steps,
"width": result.width,
"height": result.height,
"device": "cuda"
}
# Save metadata into PNG
png_info = PngInfo()
png_info.add_text("parameters", json.dumps(metadata))
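        # (PIL exposes PNG tEXt chunks via the .text dict, so
        # json.loads(Image.open(path).text["parameters"]) recovers these settings)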
# Use temporary file for HF Spaces
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
output_path = tmp_file.name
result.save(output_path, format="PNG", pnginfo=png_info)
# Build markdown preview of metadata
metadata_str = (
f"**Prompt:** {metadata['prompt']}\n\n"
f"**Negative Prompt:** {metadata['negative_prompt']}\n\n"
f"**Seed:** {metadata['seed']}\n\n"
f"**Model:** {metadata['model']}\n\n"
f"**Guidance Scale:** {metadata['guidance_scale']}\n\n"
f"**Strength:** {metadata['strength']}\n\n"
f"**Steps:** {metadata['steps']}\n\n"
f"**Dimensions:** {metadata['width']}x{metadata['height']}\n\n"
f"**Device:** {metadata['device']}"
)
        return output_path, f"✅ **Generation Complete!**\n\n{metadata_str}", output_path
except torch.cuda.OutOfMemoryError:
return None, "❌ GPU out of memory. Try reducing image size or inference steps.", None
except Exception as e:
return None, f"❌ Error during generation: {str(e)}", None
# Define UI components with better styling
title = "🎨 SDXL Image-to-Image Editor"
description = """
Transform your images with AI! Upload an image and describe the changes you want to make.
**Tips:**
- Use detailed prompts for better results
- Lower strength values preserve more of the original image
- Higher guidance scale follows your prompt more closely
"""
# Custom CSS for better appearance
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
background: linear-gradient(90deg, #4f46e5, #7c3aed);
border: none;
}
.gr-button:hover {
background: linear-gradient(90deg, #4338ca, #6d28d9);
}
"""
with gr.Blocks(title=title, css=css, theme=gr.themes.Soft()) as demo:
gr.Markdown(f"# {title}")
gr.Markdown(description)
with gr.Row():
with gr.Column(scale=1):
            gr.Markdown("### 📸 Input Image")
uploaded_image = gr.Image(
label="Upload Image",
type="filepath",
height=300
)
gr.Markdown("**Or**")
image_url = gr.Textbox(
label="Image URL",
placeholder="https://example.com/image.jpg",
info="Paste a direct link to an image"
)
gr.Markdown("### ✍️ Prompts")
prompt = gr.Textbox(
label="Prompt",
placeholder="a beautiful sunset over mountains, photorealistic, detailed",
lines=3,
info="Describe what you want to see"
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="blurry, low quality, distorted",
lines=2,
info="What to avoid in the image"
)
            gr.Markdown("### ⚙️ Settings")
with gr.Row():
strength = gr.Slider(
minimum=0.1, maximum=1.0, value=0.7, step=0.05,
label="Transformation Strength",
info="0.1 = subtle changes, 1.0 = major changes"
)
guidance_scale = gr.Slider(
minimum=1.0, maximum=20.0, value=7.5, step=0.5,
label="Guidance Scale",
info="How closely to follow the prompt"
)
with gr.Row():
num_inference_steps = gr.Slider(
minimum=10, maximum=50, step=5, value=30,
label="Quality Steps",
info="More steps = higher quality but slower"
)
seed = gr.Slider(
minimum=-1, maximum=999999, step=1, value=-1,
label="Seed",
info="-1 for random"
)
            submit_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")
with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Result")
            image_output = gr.Image(label="Generated Image", height=400)
            download_button = gr.File(label="📥 Download Full Resolution", visible=False)
            gr.Markdown("### 📊 Generation Details")
metadata_output = gr.Markdown()
# Event handlers
submit_btn.click(
fn=img2img,
inputs=[
uploaded_image,
image_url,
prompt,
negative_prompt,
strength,
guidance_scale,
num_inference_steps,
seed
],
outputs=[image_output, metadata_output, download_button]
).then(
lambda x: gr.update(visible=x is not None),
inputs=[image_output],
outputs=[download_button]
)
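    # .then() runs after img2img completes, revealing the download component
    # only when an image was actually produced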
# Examples
gr.Markdown("### 🎯 Examples")
gr.Examples(
examples=[
[None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png", "make it a van gogh painting", "blurry, low quality", 0.8, 7.5, 30, 42],
[None, "https://picsum.photos/512/512?random=1", "turn into a cyberpunk cityscape", "blurry, distorted", 0.9, 8.0, 30, 123],
],
inputs=[uploaded_image, image_url, prompt, negative_prompt, strength, guidance_scale, num_inference_steps, seed],
)
# Launch configuration for HF Spaces
if __name__ == "__main__":
demo.queue(max_size=20) # Enable queuing for better performance
demo.launch(
show_error=True,
share=False, # Don't create gradio.live links in HF Spaces
inbrowser=False, # Don't try to open browser in cloud environment
quiet=False
)