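"""CyberRealistic Pony Image Generator (Hugging Face Edition).

A Gradio app that wraps the CyberRealistic Pony SDXL checkpoint in
text-to-image and image-to-image pipelines. The checkpoint is downloaded
from the Hugging Face Hub on first run and cached locally.
"""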
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from PIL import Image
import os
import gc
import time
from typing import Optional, Tuple
from huggingface_hub import hf_hub_download
import requests

# Global pipeline variables
txt2img_pipe = None
img2img_pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face model configuration
MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
LOCAL_MODEL_PATH = "./models/cyberrealisticPony_v110.safetensors"
def clear_memory():
    """Clear GPU memory"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
def download_model():
    """Download model from Hugging Face if not already cached"""
    try:
        # Create models directory if it doesn't exist
        os.makedirs("./models", exist_ok=True)

        # Check if model already exists locally
        if os.path.exists(LOCAL_MODEL_PATH):
            print(f"Model already exists at {LOCAL_MODEL_PATH}")
            return LOCAL_MODEL_PATH

        print(f"Downloading model from {MODEL_REPO}/{MODEL_FILENAME}...")
        print("This may take a while on first run...")

        # Download the model file into ./models
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            local_dir="./models",
            local_dir_use_symlinks=False,
            resume_download=True
        )

        print(f"Model downloaded successfully to {model_path}")
        return model_path

    except Exception as e:
        print(f"Error downloading model: {e}")
        print("Attempting to use cached version or fallback...")

        # Try to use the Hugging Face cache directly
        try:
            cached_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                resume_download=True
            )
            print(f"Using cached model at {cached_path}")
            return cached_path
        except Exception as cache_error:
            print(f"Cache fallback failed: {cache_error}")
            return None
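# Note: when hf_hub_download() is called without local_dir (the fallback above),
# the file lands in the shared Hugging Face cache (by default
# ~/.cache/huggingface/hub, overridable via HF_HOME) rather than ./models.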
def load_models():
    """Load both text2img and img2img pipelines with Hugging Face integration"""
    global txt2img_pipe, img2img_pipe

    # Download model if needed
    model_path = download_model()
    if model_path is None:
        print("Failed to download or locate model file")
        return None, None

    if not os.path.exists(model_path):
        print(f"Model file not found after download: {model_path}")
        return None, None

    if txt2img_pipe is None:
        try:
            print("Loading CyberRealistic Pony Text2Img model...")
            txt2img_pipe = StableDiffusionXLPipeline.from_single_file(
                model_path,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                use_safetensors=True,
                variant="fp16" if device == "cuda" else None
            )

            # Memory optimizations
            txt2img_pipe.enable_attention_slicing()

            if device == "cuda":
                try:
                    txt2img_pipe.enable_model_cpu_offload()
                    print("Text2Img CPU offload enabled")
                except Exception as e:
                    print(f"Text2Img CPU offload failed: {e}")
                    txt2img_pipe = txt2img_pipe.to(device)
            else:
                txt2img_pipe = txt2img_pipe.to(device)

            print("Text2Img model loaded successfully!")
        except Exception as e:
            print(f"Error loading Text2Img model: {e}")
            return None, None

    if img2img_pipe is None:
        try:
            print("Loading CyberRealistic Pony Img2Img model...")
            img2img_pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(
                model_path,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                use_safetensors=True,
                variant="fp16" if device == "cuda" else None
            )

            # Memory optimizations
            img2img_pipe.enable_attention_slicing()

            if device == "cuda":
                try:
                    img2img_pipe.enable_model_cpu_offload()
                    print("Img2Img CPU offload enabled")
                except Exception as e:
                    print(f"Img2Img CPU offload failed: {e}")
                    img2img_pipe = img2img_pipe.to(device)
            else:
                img2img_pipe = img2img_pipe.to(device)

            print("Img2Img model loaded successfully!")
        except Exception as e:
            print(f"Error loading Img2Img model: {e}")
            return txt2img_pipe, None

    return txt2img_pipe, img2img_pipe
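# Possible memory optimization (not applied here): both pipelines wrap the same
# checkpoint, so the img2img pipeline could likely reuse the already-loaded
# components instead of reading the weights a second time, e.g.
#   img2img_pipe = StableDiffusionXLImg2ImgPipeline(**txt2img_pipe.components)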
def enhance_prompt(prompt: str, add_quality_tags: bool = True) -> str:
    """Enhance prompt with Pony-style quality tags"""
    if not prompt.strip():
        return prompt

    # Don't add tags if already present
    if prompt.startswith("score_") or not add_quality_tags:
        return prompt

    quality_tags = "score_9, score_8_up, score_7_up, masterpiece, best quality, highly detailed"
    return f"{quality_tags}, {prompt}"
def validate_dimensions(width: int, height: int) -> Tuple[int, int]:
    """Ensure dimensions are valid for SDXL"""
    # SDXL works best with dimensions divisible by 64
    width = ((width + 63) // 64) * 64
    height = ((height + 63) // 64) * 64

    # Ensure reasonable limits
    width = max(512, min(1536, width))
    height = max(512, min(1536, height))

    return width, height
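# Example: validate_dimensions(1000, 700) rounds each side up to the next
# multiple of 64 and clamps to [512, 1536], returning (1024, 704).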
def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, add_quality_tags):
    """Generate image from text prompt with enhanced error handling"""
    global txt2img_pipe

    if not prompt.strip():
        return None, "Please enter a prompt"

    # Load models if not already loaded
    if txt2img_pipe is None:
        txt2img_pipe, _ = load_models()
        if txt2img_pipe is None:
            return None, "Failed to load Text2Img model. Please check your internet connection and try again."

    try:
        # Clear memory before generation
        clear_memory()

        # Validate and fix dimensions
        width, height = validate_dimensions(width, height)

        # Set seed for reproducibility
        generator = None
        if seed != -1:
            generator = torch.Generator(device=device).manual_seed(int(seed))

        # Enhance prompt
        enhanced_prompt = enhance_prompt(prompt, add_quality_tags)

        print(f"Generating with prompt: {enhanced_prompt[:100]}...")
        start_time = time.time()

        # Generate image
        with torch.no_grad():
            result = txt2img_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt or "",
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance_scale),
                width=width,
                height=height,
                generator=generator
            )

        generation_time = time.time() - start_time
        status = f"Text2Img: Generated successfully in {generation_time:.1f}s (Size: {width}x{height})"

        return result.images[0], status

    except Exception as e:
        error_msg = f"Text2Img generation failed: {str(e)}"
        print(error_msg)
        return None, error_msg
    finally:
        clear_memory()
def generate_img2img(input_image, prompt, negative_prompt, num_steps, guidance_scale, strength, seed, add_quality_tags):
    """Generate image from input image + text prompt with enhanced error handling"""
    global img2img_pipe

    if input_image is None:
        return None, "Please upload an input image for Img2Img"

    if not prompt.strip():
        return None, "Please enter a prompt"

    # Load models if not already loaded
    if img2img_pipe is None:
        _, img2img_pipe = load_models()
        if img2img_pipe is None:
            return None, "Failed to load Img2Img model. Please check your internet connection and try again."

    try:
        # Clear memory before generation
        clear_memory()

        # Set seed for reproducibility
        generator = None
        if seed != -1:
            generator = torch.Generator(device=device).manual_seed(int(seed))

        # Enhance prompt
        enhanced_prompt = enhance_prompt(prompt, add_quality_tags)

        # Process input image
        if isinstance(input_image, Image.Image):
            # Ensure RGB format
            if input_image.mode != 'RGB':
                input_image = input_image.convert('RGB')

            # Resize to reasonable dimensions while maintaining aspect ratio
            max_size = 1024
            input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            # Ensure dimensions are divisible by 64
            w, h = input_image.size
            w, h = validate_dimensions(w, h)
            input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)

        print(f"Generating with prompt: {enhanced_prompt[:100]}...")
        start_time = time.time()

        # Generate image
        with torch.no_grad():
            result = img2img_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt or "",
                image=input_image,
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance_scale),
                strength=float(strength),
                generator=generator
            )

        generation_time = time.time() - start_time
        status = f"Img2Img: Generated successfully in {generation_time:.1f}s (Strength: {strength})"

        return result.images[0], status

    except Exception as e:
        error_msg = f"Img2Img generation failed: {str(e)}"
        print(error_msg)
        return None, error_msg
    finally:
        clear_memory()
# Default negative prompt
DEFAULT_NEGATIVE = """
(low quality:1.4), (worst quality:1.4), (bad quality:1.3), (normal quality:1.2), lowres, jpeg artifacts, blurry, noisy, ugly, deformed, disfigured, malformed, poorly drawn, bad art, amateur, render, 3D, cgi,
(text, signature, watermark, username, copyright:1.5),
(extra limbs:1.5), (missing limbs:1.5), (extra fingers:1.5), (missing fingers:1.5), (mutated hands:1.5), (bad hands:1.4), (poorly drawn hands:1.3), (ugly hands:1.2),
(bad anatomy:1.4), (deformed body:1.3), (unnatural body:1.2), (cross-eyed:1.3), (skewed eyes:1.3), (imperfect eyes:1.2), (ugly eyes:1.2), (asymmetrical face:1.2), (unnatural face:1.2),
(blush:1.1), (shadow on skin:1.1), (shaded skin:1.1), (dark skin:1.1),
abstract, simplified, unrealistic, impressionistic, cartoon, anime, drawing, sketch, illustration, painting, censored, grayscale, monochrome, out of frame, cropped, distorted.
"""
# Create Gradio interface with enhanced styling
with gr.Blocks(
    title="CyberRealistic Pony Image Generator",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .tab-nav button {
        font-size: 16px;
        font-weight: bold;
    }
    """
) as demo:
gr.Markdown(""" | |
# 🎨 CyberRealistic Pony Image Generator (Hugging Face Edition) | |
Generate high-quality images using the CyberRealistic Pony SDXL model from Hugging Face. | |
**Features:** | |
- 🎨 Text-to-Image generation | |
- 🖼️ Image-to-Image transformation | |
- 🎯 Automatic quality tag enhancement | |
- ⚡ Memory optimized for better performance | |
- 🤗 Auto-downloads model from Hugging Face | |
**Note:** On first run, the model will be downloaded from Hugging Face (this may take a few minutes). | |
""") | |
    with gr.Tabs():
        # Text2Image Tab
        with gr.TabItem("🎨 Text to Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls for Text2Img
                    txt2img_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your image description...",
                        value="beautiful landscape with mountains and lake at sunset",
                        lines=3
                    )
                    txt2img_negative = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE,
                        lines=3
                    )
                    txt2img_quality_tags = gr.Checkbox(
                        label="Add Quality Tags",
                        value=True
                    )

                    with gr.Row():
                        txt2img_steps = gr.Slider(
                            minimum=10,
                            maximum=50,
                            value=25,
                            step=1,
                            label="Inference Steps"
                        )
                        txt2img_guidance = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance Scale"
                        )

                    with gr.Row():
                        txt2img_width = gr.Slider(
                            minimum=512,
                            maximum=1536,
                            value=1024,
                            step=64,
                            label="Width"
                        )
                        txt2img_height = gr.Slider(
                            minimum=512,
                            maximum=1536,
                            value=1024,
                            step=64,
                            label="Height"
                        )

                    txt2img_seed = gr.Number(
                        label="Seed (-1 for random)",
                        value=-1,
                        precision=0
                    )

                    txt2img_btn = gr.Button("🎨 Generate Image", variant="primary")

                with gr.Column(scale=2):
                    # Output for Text2Img
                    txt2img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=600
                    )
                    txt2img_status = gr.Textbox(label="Status", interactive=False)
        # Image2Image Tab
        with gr.TabItem("🖼️ Image to Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls for Img2Img
                    img2img_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=300
                    )
                    img2img_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe how to modify the image...",
                        value="in the style of a digital painting, vibrant colors",
                        lines=3
                    )
                    img2img_negative = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE,
                        lines=3
                    )
                    img2img_quality_tags = gr.Checkbox(
                        label="Add Quality Tags",
                        value=True
                    )

                    with gr.Row():
                        img2img_steps = gr.Slider(
                            minimum=10,
                            maximum=50,
                            value=25,
                            step=1,
                            label="Inference Steps"
                        )
                        img2img_guidance = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance Scale"
                        )

                    img2img_strength = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.75,
                        step=0.05,
                        label="Denoising Strength (lower = closer to input, higher = more creative)"
                    )
                    img2img_seed = gr.Number(
                        label="Seed (-1 for random)",
                        value=-1,
                        precision=0
                    )

                    img2img_btn = gr.Button("🖼️ Transform Image", variant="primary")

                with gr.Column(scale=2):
                    # Output for Img2Img
                    img2img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=600
                    )
                    img2img_status = gr.Textbox(label="Status", interactive=False)
    # Event handlers
    txt2img_btn.click(
        fn=generate_txt2img,
        inputs=[txt2img_prompt, txt2img_negative, txt2img_steps, txt2img_guidance,
                txt2img_width, txt2img_height, txt2img_seed, txt2img_quality_tags],
        outputs=[txt2img_output, txt2img_status]
    )

    img2img_btn.click(
        fn=generate_img2img,
        # Fixed: the steps input previously pointed at txt2img_steps instead of img2img_steps
        inputs=[img2img_input, img2img_prompt, img2img_negative, img2img_steps, img2img_guidance,
                img2img_strength, img2img_seed, img2img_quality_tags],
        outputs=[img2img_output, img2img_status]
    )
# Load models on startup
print("Initializing CyberRealistic Pony Generator (Hugging Face Edition)...")
print(f"Device: {device}")
print(f"Model Repository: {MODEL_REPO}")
print(f"Model File: {MODEL_FILENAME}")

# Pre-load models in a separate thread to avoid blocking startup
import threading

def preload_models():
    """Pre-load models in the background"""
    try:
        print("Starting background model loading...")
        load_models()
        print("Background model loading completed!")
    except Exception as e:
        print(f"Background model loading failed: {e}")

# Start background loading
loading_thread = threading.Thread(target=preload_models, daemon=True)
loading_thread.start()
if __name__ == "__main__": | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=False, | |
show_error=True | |
) |