# Hugging Face Space — status: Running on Zero
import gc
import os
import random
import tempfile
import time
import traceback
from datetime import datetime
from typing import Optional, Tuple

import gradio as gr
import spaces
import torch
from diffusers import (
    DPMSolverMultistepScheduler,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)
from huggingface_hub import hf_hub_download
from PIL import Image, PngImagePlugin
# Pipelines start unset and are filled in lazily by load_models().
txt2img_pipe = None
img2img_pipe = None

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Location of the single-file checkpoint on the Hugging Face Hub.
MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
# "<repo>/<filename>" identifier that is written into the PNG metadata.
model_id = "/".join((MODEL_REPO, MODEL_FILENAME))

# Descriptive generation settings; add_metadata_and_save() copies the
# "vae" and "sampler" entries into each saved PNG.
generation_config = dict(
    vae="SDXL VAE",
    sampler="DPM++ 2M Karras",
    steps=20,
)
def clear_memory():
    """Free unreachable Python objects, then release cached CUDA memory.

    Garbage collection runs first. Tensors kept alive only by reference cycles
    are then actually freed before ``torch.cuda.empty_cache()`` hands the
    allocator's unoccupied cached blocks back. With the opposite order, memory
    freed by the collection would stay cached until the next call.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def add_metadata_and_save(image: Image.Image, prompt: str, negative_prompt: str, seed: int, steps: int, guidance: float, strength: Optional[float] = None):
    """Embed generation metadata into a PNG and save it to a new temporary file.

    Args:
        image: The image to save.
        prompt: Prompt text, stored verbatim in the "Prompt" chunk.
        negative_prompt: Negative prompt, stored verbatim in "NegativePrompt".
        seed: Seed, stored via ``str()``.
        steps: Step count, stored via ``str()``.
        guidance: CFG scale, stored via ``str()`` as "CFG_Scale".
        strength: Img2img strength; the "Strength" chunk is written only when
            this is not None.

    Returns:
        Path of the written ``.png`` file. It is created with
        ``delete=False``, so it persists until something else removes it.

    Raises:
        Whatever ``Image.save`` raises; the temporary file is removed first.
    """
    meta = PngImagePlugin.PngInfo()
    meta.add_text("Prompt", prompt)
    meta.add_text("NegativePrompt", negative_prompt)
    meta.add_text("Model", model_id)
    meta.add_text("VAE", generation_config["vae"])
    meta.add_text("Sampler", generation_config["sampler"])
    meta.add_text("Steps", str(steps))
    meta.add_text("CFG_Scale", str(guidance))
    if strength is not None:
        meta.add_text("Strength", str(strength))
    meta.add_text("Seed", str(seed))
    meta.add_text("Date", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    # NamedTemporaryFile creates the file atomically. The deprecated
    # tempfile.mktemp only returns a name, leaving a window in which another
    # process could create that path first.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        with tmp:
            image.save(tmp, format="PNG", pnginfo=meta)
    except Exception:
        # Don't leave an empty or partially written file behind.
        os.unlink(tmp.name)
        raise
    return tmp.name
def _apply_memory_optimizations(pipe) -> None:
    """Apply the shared low-memory settings (slicing, CPU offload on CUDA) to *pipe*."""
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
    if device == "cuda":
        # Model offload and sequential offload are alternative strategies.
        # Each call resets the pipeline's offload hooks, so enabling both only
        # wastes time. Sequential offload gives the smallest VRAM footprint.
        pipe.enable_sequential_cpu_offload()


def load_models():
    """Build the text-to-image and image-to-image SDXL pipelines if not built yet.

    The checkpoint is fetched from ``MODEL_REPO`` only while the text-to-image
    pipeline still has to be built. Its scheduler is set to DPM++ 2M Karras
    so that the sampler written into PNG metadata matches what runs. The
    image-to-image pipeline is assembled from the text-to-image pipeline's
    modules, so the weights are held in memory once.

    Returns:
        True on success; False if any step raised (the error and traceback are
        printed). A pipeline is stored in its global only after it is fully
        configured, so a failed attempt can be retried.
    """
    global txt2img_pipe, img2img_pipe
    try:
        print("Loading CyberRealistic Pony models...")

        if txt2img_pipe is None:
            # Only touch the Hub when the pipeline actually has to be built.
            print(f"Downloading model from {MODEL_REPO}...")
            model_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                cache_dir="/tmp/hf_cache",  # Use tmp for Spaces
            )
            print(f"Model downloaded to: {model_path}")

            pipe = StableDiffusionXLPipeline.from_single_file(
                model_path,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                use_safetensors=True,
                variant="fp16" if device == "cuda" else None,
            )
            # generation_config / PNG metadata advertise "DPM++ 2M Karras";
            # configure that sampler instead of the checkpoint's default one.
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                pipe.scheduler.config,
                algorithm_type="dpmsolver++",
                use_karras_sigmas=True,
            )
            _apply_memory_optimizations(pipe)
            if device != "cuda":
                pipe = pipe.to(device)
            txt2img_pipe = pipe

        if img2img_pipe is None:
            # Share components with the text-to-image pipeline to save memory.
            pipe = StableDiffusionXLImg2ImgPipeline(
                vae=txt2img_pipe.vae,
                text_encoder=txt2img_pipe.text_encoder,
                text_encoder_2=txt2img_pipe.text_encoder_2,
                tokenizer=txt2img_pipe.tokenizer,
                tokenizer_2=txt2img_pipe.tokenizer_2,
                unet=txt2img_pipe.unet,
                scheduler=txt2img_pipe.scheduler,
            )
            _apply_memory_optimizations(pipe)
            img2img_pipe = pipe

        print("Models loaded successfully!")
        return True
    except Exception as e:
        # Boundary handler: callers only check the boolean, so report fully here.
        print(f"Error loading models: {e}")
        traceback.print_exc()
        return False
def enhance_prompt(prompt: str, add_quality_tags: bool = True) -> str:
    """Prefix *prompt* with Pony-style quality tags.

    The prompt is returned unchanged when it is None, empty or blank, when
    ``add_quality_tags`` is False, or when it already begins (ignoring leading
    whitespace) with a ``score_`` tag. This avoids duplicating the tags.
    """
    if not prompt or not prompt.strip():
        return prompt
    if not add_quality_tags or prompt.lstrip().startswith("score_"):
        return prompt
    quality_tags = "score_9, score_8_up, score_7_up, masterpiece, best quality, highly detailed"
    return f"{quality_tags}, {prompt}"
def validate_dimensions(width: int, height: int, multiple: int = 64, min_size: int = 512, max_size: int = 1024) -> Tuple[int, int]:
    """Snap *width* and *height* to SDXL-friendly integer sizes.

    Each side is converted to ``int`` (UI widgets may hand over floats such as
    ``768.0``, which would otherwise propagate into tensor shapes). It is then
    rounded UP to the next multiple of ``multiple`` and clamped to
    ``[min_size, max_size]``, a conservative range for Spaces. With the
    defaults both bounds are multiples of 64, so every result is too.

    Returns:
        ``(width, height)`` as ints.
    """
    def _snap(value) -> int:
        # Ceiling to the next multiple, then clamp to the allowed range.
        rounded = -(-int(value) // multiple) * multiple
        return max(min_size, min(max_size, rounded))

    return _snap(width), _snap(height)
def format_status_with_metadata(generation_time: float, width: int, height: int, prompt: str, negative_prompt: str, seed: int, steps: int, guidance: float, strength: Optional[float] = None):
    """Build the multi-line status text shown after a generation.

    Long texts are truncated for display: the prompt to 50 characters and the
    negative prompt to 30, each followed by "...". An empty or None negative
    prompt is shown as ``None``. The strength line is appended only when
    ``strength`` is given (img2img).

    Returns:
        The status lines joined with newlines.
    """
    def _preview(text: str, limit: int) -> str:
        # Truncate to *limit* characters, marking the cut with an ellipsis.
        return f"{text[:limit]}..." if len(text) > limit else text

    status_parts = [
        f"✅ Generated in {generation_time:.1f}s ({width}×{height})",
        f"🎯 Prompt: {_preview(prompt, 50)}",
        f"🚫 Negative: {_preview(negative_prompt, 30) if negative_prompt else 'None'}",
        f"🎲 Seed: {seed}",
        f"📊 Steps: {steps}",
        f"🎛️ CFG: {guidance}",
    ]
    if strength is not None:
        status_parts.append(f"💪 Strength: {strength}")
    return "\n".join(status_parts)
# GPU decorator for Spaces: requests a ZeroGPU slot for the duration of each
# call (the decorator has no effect on non-ZeroGPU hardware).
@spaces.GPU
def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, add_quality_tags):
    """Generate an image from a text prompt and save it as a PNG with metadata.

    Args:
        prompt: User prompt; quality tags are prepended when requested.
        negative_prompt: Negative prompt (None/empty means none).
        num_steps: Requested inference steps; capped at 30.
        guidance_scale: CFG scale.
        width, height: Requested size, snapped by ``validate_dimensions``.
        seed: RNG seed; -1 picks a random one.
        add_quality_tags: Whether ``enhance_prompt`` adds Pony quality tags.

    Returns:
        ``(png_path, status_text)`` on success, or ``(None, message)`` when
        input is missing, loading fails, or generation raises.
    """
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt"

    # Lazy load models
    if txt2img_pipe is None and not load_models():
        return None, "Failed to load models. Please try again."

    try:
        clear_memory()

        width, height = validate_dimensions(width, height)

        # Normalize UI values once so the pipeline, metadata and status all
        # record exactly what was used (e.g. "42" rather than "42.0").
        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, 2147483647)
        steps = min(int(num_steps), 30)  # Limit steps for Spaces
        guidance = float(guidance_scale)
        negative = negative_prompt or ""

        generator = torch.Generator(device=device).manual_seed(seed)
        enhanced_prompt = enhance_prompt(prompt, add_quality_tags)
        print(f"Generating: {enhanced_prompt[:100]}...")

        start_time = time.perf_counter()
        with torch.no_grad():
            result = txt2img_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative,
                num_inference_steps=steps,
                guidance_scale=guidance,
                width=width,
                height=height,
                generator=generator,
            )
        generation_time = time.perf_counter() - start_time

        # Save with metadata - returns file path
        png_path = add_metadata_and_save(
            result.images[0], enhanced_prompt, negative, seed, steps, guidance
        )
        status = format_status_with_metadata(
            generation_time, width, height, enhanced_prompt,
            negative, seed, steps, guidance,
        )
        return png_path, status
    except Exception as e:
        return None, f"Generation failed: {str(e)}"
    finally:
        clear_memory()
# GPU decorator for Spaces: requests a ZeroGPU slot for the duration of each
# call (the decorator has no effect on non-ZeroGPU hardware).
@spaces.GPU
def generate_img2img(input_image, prompt, negative_prompt, num_steps, guidance_scale, strength, seed, add_quality_tags):
    """Transform *input_image* guided by a text prompt and save it as a PNG with metadata.

    Args:
        input_image: Source image (PIL expected; array-like input is converted).
        prompt: User prompt; quality tags are prepended when requested.
        negative_prompt: Negative prompt (None/empty means none).
        num_steps: Requested inference steps; capped at 30.
        guidance_scale: CFG scale.
        strength: Img2img strength passed to the pipeline.
        seed: RNG seed; -1 picks a random one.
        add_quality_tags: Whether ``enhance_prompt`` adds Pony quality tags.

    Returns:
        ``(png_path, status_text)`` on success, or ``(None, message)`` when
        input is missing, loading fails, or generation raises.
    """
    if input_image is None:
        return None, "Please upload an input image"
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt"

    # Lazy load models
    if img2img_pipe is None and not load_models():
        return None, "Failed to load models. Please try again."

    try:
        clear_memory()

        # Normalize UI values once so the pipeline, metadata and status agree.
        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, 2147483647)
        steps = min(int(num_steps), 30)  # Limit steps
        guidance = float(guidance_scale)
        strength = float(strength)
        negative = negative_prompt or ""

        generator = torch.Generator(device=device).manual_seed(seed)
        enhanced_prompt = enhance_prompt(prompt, add_quality_tags)

        # The UI requests PIL images. Accept raw arrays too, rather than
        # skipping preprocessing and failing later on undefined dimensions.
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        # convert() returns a new image, so the in-place thumbnail() below
        # never mutates the caller's object.
        image = input_image.convert("RGB")
        # Conservative resize for Spaces
        max_size = 768
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        w, h = validate_dimensions(*image.size)
        image = image.resize((w, h), Image.Resampling.LANCZOS)

        print(f"Transforming: {enhanced_prompt[:100]}...")
        start_time = time.perf_counter()
        with torch.no_grad():
            result = img2img_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative,
                image=image,
                num_inference_steps=steps,
                guidance_scale=guidance,
                strength=strength,
                generator=generator,
            )
        generation_time = time.perf_counter() - start_time

        # Save with metadata - returns file path
        png_path = add_metadata_and_save(
            result.images[0], enhanced_prompt, negative, seed, steps, guidance, strength
        )
        status = format_status_with_metadata(
            generation_time, w, h, enhanced_prompt,
            negative, seed, steps, guidance, strength,
        )
        return png_path, status
    except Exception as e:
        return None, f"Transformation failed: {str(e)}"
    finally:
        clear_memory()
# Sample prompts offered by the "Random Example" buttons.
EXAMPLE_PROMPTS = [
    "beautiful anime girl with long flowing hair, cherry blossoms, soft lighting",
    "cyberpunk cityscape at night, neon lights, rain reflections, detailed architecture",
    "majestic dragon flying over mountains, fantasy landscape, dramatic clouds",
    "cute anthropomorphic fox character, forest background, magical atmosphere",
    "elegant woman in Victorian dress, portrait, ornate background, vintage style",
    "futuristic robot with glowing eyes, metallic surface, sci-fi environment",
    "mystical unicorn in enchanted forest, rainbow mane, sparkles, ethereal lighting",
    "steampunk airship floating in sky, gears and brass, adventure scene",
]


def set_example_prompt():
    """Pick one entry of EXAMPLE_PROMPTS uniformly at random and return it."""
    index = random.randrange(len(EXAMPLE_PROMPTS))
    return EXAMPLE_PROMPTS[index]
# Simplified negative prompt for better performance.
# No surrounding newlines, so the textbox and the status preview start with
# the first tag instead of a blank line.
# NOTE(review): stock diffusers pipelines do not interpret A1111-style
# "(text:weight)" emphasis. Without a prompt-weighting helper (e.g. Compel),
# the parentheses and numbers are tokenized as plain text. Confirm whether
# weighting was intended here.
DEFAULT_NEGATIVE = (
    "(low quality:1.3), (worst quality:1.3), (bad quality:1.2), blurry, noisy, ugly, deformed,\n"
    "(text, watermark:1.4), (extra limbs:1.3), (bad hands:1.3), (bad anatomy:1.2)"
)
# Gradio interface optimized for Spaces.
# Layout: two tabs (text-to-image, image-to-image), each with inputs in the
# left column and the PNG file plus a status box in the right column.
with gr.Blocks(
    title="CyberRealistic Pony Generator",
    theme=gr.themes.Soft(),
) as demo:
    # Blank lines keep the note paragraphs from merging into one in Markdown.
    gr.Markdown(
        """
# 🎨 CyberRealistic Pony Image Generator

Generate high-quality images using the CyberRealistic Pony SDXL model.

⚠️ **Note**: First generation may take longer as the model loads. GPU time is limited on Spaces.

📋 **Metadata**: All generated images include embedded metadata (prompt, settings, seed, etc.)
"""
    )

    with gr.Tabs():
        with gr.TabItem("🎨 Text to Image"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        txt2img_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="beautiful landscape, mountains, sunset",
                            lines=2,
                            scale=4,
                        )
                        txt2img_example_btn = gr.Button("🎲 Random Example", scale=1)

                    with gr.Accordion("Advanced Settings", open=False):
                        txt2img_negative = gr.Textbox(
                            label="Negative Prompt",
                            value=DEFAULT_NEGATIVE,
                            lines=2,
                        )
                        txt2img_quality_tags = gr.Checkbox(
                            label="Add Quality Tags",
                            value=True,
                        )
                        with gr.Row():
                            txt2img_steps = gr.Slider(10, 30, 20, step=1, label="Steps")
                            txt2img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="Guidance")
                        with gr.Row():
                            txt2img_width = gr.Slider(512, 1024, 768, step=64, label="Width")
                            txt2img_height = gr.Slider(512, 1024, 768, step=64, label="Height")
                        txt2img_seed = gr.Slider(
                            minimum=-1, maximum=2147483647, value=-1, step=1,
                            label="Seed (-1 for random)",
                        )

                    txt2img_btn = gr.Button("🎨 Generate", variant="primary", size="lg")

                with gr.Column():
                    txt2img_output = gr.File(label="Generated PNG with Metadata", file_types=[".png"])
                    txt2img_status = gr.Textbox(label="Generation Info", interactive=False, lines=6)

        with gr.TabItem("🖼️ Image to Image"):
            with gr.Row():
                with gr.Column():
                    img2img_input = gr.Image(label="Input Image", type="pil", height=250)
                    with gr.Row():
                        img2img_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="digital painting style, vibrant colors",
                            lines=2,
                            scale=4,
                        )
                        img2img_example_btn = gr.Button("🎲 Random Example", scale=1)

                    with gr.Accordion("Advanced Settings", open=False):
                        img2img_negative = gr.Textbox(
                            label="Negative Prompt",
                            value=DEFAULT_NEGATIVE,
                            lines=2,
                        )
                        img2img_quality_tags = gr.Checkbox(
                            label="Add Quality Tags",
                            value=True,
                        )
                        with gr.Row():
                            img2img_steps = gr.Slider(10, 30, 20, step=1, label="Steps")
                            img2img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="Guidance")
                        img2img_strength = gr.Slider(
                            0.1, 1.0, 0.75, step=0.05,
                            label="Strength (Higher = more creative)",
                        )
                        img2img_seed = gr.Slider(
                            minimum=-1, maximum=2147483647, value=-1, step=1,
                            label="Seed (-1 for random)",
                        )

                    img2img_btn = gr.Button("🖼️ Transform", variant="primary", size="lg")

                with gr.Column():
                    img2img_output = gr.File(label="Generated PNG with Metadata", file_types=[".png"])
                    img2img_status = gr.Textbox(label="Generation Info", interactive=False, lines=6)

    # Event handlers (must be registered inside the Blocks context)
    txt2img_btn.click(
        fn=generate_txt2img,
        inputs=[txt2img_prompt, txt2img_negative, txt2img_steps, txt2img_guidance,
                txt2img_width, txt2img_height, txt2img_seed, txt2img_quality_tags],
        outputs=[txt2img_output, txt2img_status],
    )
    img2img_btn.click(
        fn=generate_img2img,
        inputs=[img2img_input, img2img_prompt, img2img_negative, img2img_steps, img2img_guidance,
                img2img_strength, img2img_seed, img2img_quality_tags],
        outputs=[img2img_output, img2img_status],
    )

    # Example prompt buttons fill the corresponding prompt box
    txt2img_example_btn.click(
        fn=set_example_prompt,
        outputs=[txt2img_prompt],
    )
    img2img_example_btn.click(
        fn=set_example_prompt,
        outputs=[img2img_prompt],
    )

print(f"🚀 CyberRealistic Pony Generator initialized on {device}")

if __name__ == "__main__":
    demo.launch()