# Captured from a Hugging Face Spaces page (Space status at capture time: Paused).
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, EulerAncestralDiscreteScheduler | |
from PIL import Image, PngImagePlugin, ImageFilter | |
from datetime import datetime | |
import os | |
import gc | |
import time | |
import spaces | |
from typing import Optional, Tuple, Dict, Any | |
from huggingface_hub import hf_hub_download | |
import tempfile | |
import random | |
import logging | |
import torch.nn.functional as F | |
from transformers import CLIPProcessor, CLIPModel | |
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model sources
MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
NSFW_MODEL_ID = "openai/clip-vit-base-patch32"  # CLIP checkpoint used for NSFW detection

# Upper bound for seeds: the largest unsigned 32-bit value.
MAX_SEED = (1 << 32) - 1

# Half precision on CUDA, full precision on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

NSFW_THRESHOLD = 0.25  # Unsafe similarity must exceed this for an image to be flagged
# Global pipeline state | |
class PipelineManager:
    """Holds the SDXL pipelines and the CLIP NSFW detector, loading each on first use.

    ``__init__`` only creates empty slots. ``load_models`` and
    ``load_nsfw_detector`` populate them and record success in ``model_loaded``
    and ``nsfw_detector_loaded`` so later calls return immediately.
    """

    # Text prompts the image embedding is compared against in ``is_nsfw``.
    _SAFE_PROMPTS = (
        "a safe family-friendly image",
        "a general photo",
        "appropriate content",
        "artistic photography",
    )
    _UNSAFE_PROMPTS = (
        "explicit adult content",
        "nudity",
        "inappropriate sexual content",
        "pornographic material",
    )
    # Substrings that mark a prompt as NSFW when the CLIP detector cannot be used.
    _NSFW_KEYWORDS = (
        'nude', 'naked', 'nsfw', 'explicit', 'sexual', 'erotic', 'porn',
        'adult', 'xxx', 'sex', 'breast', 'nipple', 'genital', 'provocative',
    )

    def __init__(self):
        self.txt2img_pipe = None
        self.img2img_pipe = None
        self.nsfw_detector_model = None
        self.nsfw_detector_processor = None
        self.model_loaded = False
        self.nsfw_detector_loaded = False

    def clear_memory(self):
        """Run the garbage collector and, when CUDA is available, empty its cache."""
        # Collect first so tensors that were only kept alive by unreachable
        # Python objects are released before the CUDA cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def load_nsfw_detector(self) -> bool:
        """Load the CLIP processor and model used for NSFW detection.

        Returns:
            True when the detector is ready (including when it was already
            loaded), False when loading raised an exception.
        """
        if self.nsfw_detector_loaded:
            return True
        try:
            logger.info("Loading NSFW detector...")
            self.nsfw_detector_processor = CLIPProcessor.from_pretrained(NSFW_MODEL_ID)
            self.nsfw_detector_model = CLIPModel.from_pretrained(NSFW_MODEL_ID, use_safetensors=True)
            if DEVICE == "cuda":
                self.nsfw_detector_model = self.nsfw_detector_model.to(DEVICE)
            self.nsfw_detector_loaded = True
            logger.info("NSFW detector loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Failed to load NSFW detector: {e}")
            self.nsfw_detector_loaded = False
            return False

    def is_nsfw(self, image: "Image.Image", prompt: str = "") -> Tuple[bool, float]:
        """Classify ``image`` as NSFW or not with CLIP zero-shot similarity.

        The image embedding is compared (cosine similarity, averaged over each
        prompt set) with a "safe" and an "unsafe" set of text prompts. The image
        is flagged when the unsafe mean exceeds both the safe mean and
        ``NSFW_THRESHOLD``.

        Args:
            image: Image to classify.
            prompt: Generation prompt; only used by the keyword fallback.

        Returns:
            ``(flagged, confidence)`` where confidence is the mean similarity of
            the winning side. If the detector cannot be loaded or raises, the
            result of ``_fallback_nsfw_detection(prompt)`` is returned instead.
        """
        try:
            if not self.nsfw_detector_loaded and not self.load_nsfw_detector():
                return self._fallback_nsfw_detection(prompt)

            processor = self.nsfw_detector_processor
            model = self.nsfw_detector_model

            # Every forward pass runs without autograd: this is inference only.
            with torch.no_grad():
                image_inputs = processor(images=image, return_tensors="pt").to(DEVICE)
                image_features = model.get_image_features(**image_inputs)

                safe_inputs = processor(
                    text=list(self._SAFE_PROMPTS), return_tensors="pt", padding=True
                ).to(DEVICE)
                unsafe_inputs = processor(
                    text=list(self._UNSAFE_PROMPTS), return_tensors="pt", padding=True
                ).to(DEVICE)
                safe_features = model.get_text_features(**safe_inputs)
                unsafe_features = model.get_text_features(**unsafe_inputs)

                # Unit-normalise so the matrix products below are cosine similarities.
                image_features = F.normalize(image_features, p=2, dim=-1)
                safe_features = F.normalize(safe_features, p=2, dim=-1)
                unsafe_features = F.normalize(unsafe_features, p=2, dim=-1)

                safe_similarity = (image_features @ safe_features.T).mean().item()
                unsafe_similarity = (image_features @ unsafe_features.T).mean().item()

            is_nsfw_result = (
                unsafe_similarity > safe_similarity and
                unsafe_similarity > NSFW_THRESHOLD
            )
            confidence = unsafe_similarity if is_nsfw_result else safe_similarity
            if is_nsfw_result:
                logger.warning(f"π¨ NSFW content detected (CLIP-based: {unsafe_similarity:.3f} > {safe_similarity:.3f})")
            return is_nsfw_result, confidence
        except Exception as e:
            logger.error(f"NSFW detection error (CLIP model failed): {e}")
            # Degrade to the prompt-based check rather than failing the request.
            return self._fallback_nsfw_detection(prompt)

    def _fallback_nsfw_detection(self, prompt: str = "") -> Tuple[bool, float]:
        """Decide NSFW status from prompt keywords when CLIP is unavailable.

        The check is a case-insensitive substring match against
        ``_NSFW_KEYWORDS`` and is fully deterministic: no random flagging and
        no randomised confidence values.

        Args:
            prompt: Generation prompt; ``None`` or empty is treated as clean.

        Returns:
            ``(True, 1.0)`` when a keyword is found, otherwise ``(False, 0.0)``.
        """
        prompt_lower = (prompt or "").lower()
        for keyword in self._NSFW_KEYWORDS:
            if keyword in prompt_lower:
                logger.warning(f"π¨ NSFW content detected (prompt-based: '{keyword}' found)")
                return True, 1.0
        return False, 0.0

    def load_models(self) -> bool:
        """Download the checkpoint and build the txt2img and img2img pipelines.

        The img2img pipeline is constructed from the txt2img pipeline's
        components, so both share the same weights and scheduler.

        Returns:
            True when both pipelines are ready (including when they were already
            loaded), False when any step raised an exception.
        """
        if self.model_loaded:
            return True
        try:
            logger.info("Loading CyberRealistic Pony models...")
            model_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                cache_dir=os.environ.get("HF_CACHE_DIR", "/tmp/hf_cache"),
                resume_download=True
            )
            logger.info(f"Model downloaded to: {model_path}")

            self.txt2img_pipe = StableDiffusionXLPipeline.from_single_file(
                model_path,
                torch_dtype=DTYPE,
                use_safetensors=True,
                variant="fp16" if DEVICE == "cuda" else None,
                safety_checker=None,  # Disable for faster loading, using custom NSFW check
                requires_safety_checker=False
            )
            # The generation metadata records the sampler as "Euler Ancestral",
            # so install that scheduler instead of keeping the checkpoint default.
            self.txt2img_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.txt2img_pipe.scheduler.config
            )
            self._optimize_pipeline(self.txt2img_pipe)

            # Reuse the already-loaded components instead of loading them twice.
            self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
                vae=self.txt2img_pipe.vae,
                text_encoder=self.txt2img_pipe.text_encoder,
                text_encoder_2=self.txt2img_pipe.text_encoder_2,
                tokenizer=self.txt2img_pipe.tokenizer,
                tokenizer_2=self.txt2img_pipe.tokenizer_2,
                unet=self.txt2img_pipe.unet,
                scheduler=self.txt2img_pipe.scheduler,
            )
            self._optimize_pipeline(self.img2img_pipe)

            self.model_loaded = True
            logger.info("Models loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            # Drop any half-built pipeline so a retry starts from a clean state.
            self.txt2img_pipe = None
            self.img2img_pipe = None
            self.model_loaded = False
            return False

    def _optimize_pipeline(self, pipeline):
        """Enable memory-saving options on ``pipeline``.

        Attention and VAE slicing are always enabled. On CUDA the pipeline also
        gets sequential CPU offload and, when it can be enabled, xformers
        attention; otherwise the pipeline is moved to ``DEVICE``.
        """
        pipeline.enable_attention_slicing()
        pipeline.enable_vae_slicing()
        if DEVICE == "cuda":
            pipeline.enable_sequential_cpu_offload()
            try:
                pipeline.enable_xformers_memory_efficient_attention()
            except Exception:  # xformers missing or not usable in this environment
                logger.info("xformers not available, using default attention")
        else:
            pipeline.to(DEVICE)
# Global pipeline manager instance.
# Constructing it only initialises empty attributes; the generation functions
# call pipe_manager.load_models() before use, so models load on first request.
pipe_manager = PipelineManager()
# Prompt building blocks.
# Quality tags that enhance_prompt() puts in front of a user prompt.
QUALITY_TAGS = ", ".join(
    (
        "score_9",
        "score_8_up",
        "score_7_up",
        "masterpiece",
        "best quality",
        "ultra detailed",
        "8k",
    )
)

# Negative prompt used whenever the user leaves the negative prompt empty.
DEFAULT_NEGATIVE = (
    "(worst quality:1.4), (low quality:1.4), (normal quality:1.2),\n"
    "lowres, bad anatomy, bad hands, signature, watermarks, ugly, imperfect eyes,\n"
    "skewed eyes, unnatural face, unnatural body, error, extra limb, missing limbs,\n"
    "painting by bad-artist, 3d, render"
)

# Pool of sample prompts offered by the "Random" buttons.
EXAMPLE_PROMPTS = [
    "beautiful anime girl with long flowing silver hair, sakura petals, soft morning light",
    "cyberpunk street scene, neon lights reflecting on wet pavement, futuristic cityscape",
    "majestic dragon soaring through storm clouds, lightning, epic fantasy scene",
    "cute anthropomorphic fox girl, fluffy tail, forest clearing, magical sparkles",
    "elegant Victorian lady in ornate dress, portrait, vintage photography style",
    "futuristic mech suit, glowing energy core, sci-fi laboratory background",
    "mystical unicorn with rainbow mane, enchanted forest, ethereal atmosphere",
    "steampunk inventor's workshop, brass gears, mechanical contraptions, warm lighting",
]
def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
    """
    Prefix ``prompt`` with QUALITY_TAGS when requested and not already tagged.

    A blank prompt yields "". A prompt that already contains "score_",
    "masterpiece" or "best quality" (case-insensitive) is returned unchanged,
    as is any prompt when ``add_quality`` is false.
    """
    if prompt.strip() == "":
        return ""

    lowered = prompt.lower()
    already_tagged = False
    for marker in ("score_", "masterpiece", "best quality"):
        if marker in lowered:
            already_tagged = True
            break

    if already_tagged or not add_quality:
        return prompt
    return f"{QUALITY_TAGS}, {prompt}"
def validate_and_fix_dimensions(width: int, height: int) -> Tuple[int, int]:
    """
    Snap a requested size to SDXL-compatible dimensions.

    Each side is rounded to the nearest multiple of 64 (exact midpoints round
    down) and then clamped to the range 512..1024. Because both sides end up in
    that range, the aspect ratio is always between 0.5 and 2.0, so no separate
    aspect-ratio correction is needed.

    Args:
        width: Requested width in pixels; floats are truncated to int.
        height: Requested height in pixels; floats are truncated to int.

    Returns:
        ``(width, height)`` as plain ints.
    """
    def _snap(value) -> int:
        # int() guards against float inputs, which would otherwise propagate
        # float dimensions to the caller.
        rounded = ((int(value) + 31) // 64) * 64
        return max(512, min(1024, rounded))

    return _snap(width), _snap(height)
def create_metadata_png(image: "Image.Image", params: Dict[str, Any]) -> str:
    """
    Save ``image`` as a temporary PNG with the generation parameters embedded.

    Every non-None entry of ``params`` is stored as a PNG text chunk, together
    with a "Generated" timestamp (UTC) and a "Model" entry.

    Args:
        image: Image to write.
        params: Generation parameters; values are stored via ``str(value)``.

    Returns:
        Path of the PNG file. The file is not deleted by this function.

    Raises:
        Whatever ``image.save`` raises; the partially written file is removed
        before the exception propagates.
    """
    # Local import: the file-level import only brings in the `datetime` class.
    from datetime import timezone

    meta = PngImagePlugin.PngInfo()
    for key, value in params.items():
        if value is not None:
            meta.add_text(key, str(value))
    # The label says UTC, so take the time in UTC rather than local time.
    meta.add_text("Generated", datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"))
    meta.add_text("Model", f"{MODEL_REPO}/{MODEL_FILENAME}")

    # NamedTemporaryFile creates the file atomically, unlike the deprecated
    # tempfile.mktemp, which only returns a name that another process could claim.
    tmp = tempfile.NamedTemporaryFile(
        suffix=".png", prefix="cyberrealistic_", delete=False
    )
    try:
        with tmp:
            image.save(tmp, "PNG", pnginfo=meta, optimize=True)
    except Exception:
        os.unlink(tmp.name)
        raise
    return tmp.name
def format_generation_info(params: Dict[str, Any], generation_time: float) -> str:
    """
    Build the multi-line summary shown in the "Generation Info" box.

    The prompt is cut to 60 characters and the negative prompt to 40, each
    followed by "..." when longer. A strength line is added only when
    ``params`` contains a 'strength' key.
    """
    prompt_text = params.get('prompt', '')
    prompt_tail = '...' if len(prompt_text) > 60 else ''

    # The shown text defaults to 'None', while the length check defaults to ''.
    negative_shown = params.get('negative_prompt', 'None')
    negative_tail = '...' if len(params.get('negative_prompt', '')) > 40 else ''

    width = params.get('width', 'N/A')
    height = params.get('height', 'N/A')
    steps = params.get('steps', 'N/A')
    cfg = params.get('guidance_scale', 'N/A')

    summary = []
    summary.append(f"β Generated in {generation_time:.1f}s")
    summary.append(f"π Resolution: {width}Γ{height}")
    summary.append(f"π― Prompt: {prompt_text[:60]}{prompt_tail}")
    summary.append(f"π« Negative: {negative_shown[:40]}{negative_tail}")
    summary.append(f"π² Seed: {params.get('seed', 'N/A')}")
    summary.append(f"π Steps: {steps} | CFG: {cfg}")
    if 'strength' in params:
        summary.append(f"πͺ Strength: {params['strength']}")
    return "\n".join(summary)
# Increased duration for model loading and generation | |
def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_scale: float,
                     width: int, height: int, seed: int, add_quality: bool) -> Tuple:
    """
    Generate an image from text, run the NSFW check and save a PNG with metadata.

    Args:
        prompt: User prompt; a blank prompt is rejected.
        negative_prompt: Negative prompt; DEFAULT_NEGATIVE is used when empty.
        steps: Inference steps, clamped to 10..50.
        guidance_scale: CFG scale, clamped to 1.0..20.0.
        width: Requested width, normalised by validate_and_fix_dimensions.
        height: Requested height, normalised by validate_and_fix_dimensions.
        seed: Seed for the generator; -1 selects a random seed.
        add_quality: Whether enhance_prompt may prepend quality tags.

    Returns:
        ``(image, png_path, info_text)``. When the image is flagged as NSFW the
        returned and saved image is a blurred copy and the info text starts with
        a warning. On any failure the result is ``(None, None, message)``.
    """
    if not prompt.strip():
        return None, None, "β Please enter a prompt."
    # Lazy load models if not already loaded
    if not pipe_manager.load_models():
        return None, None, "β Failed to load model. Please try again."
    try:
        pipe_manager.clear_memory()  # Clear memory before generation

        # Numeric inputs are coerced to int because the seed and the step
        # count are used as integers below.
        width, height = validate_and_fix_dimensions(int(width), int(height))
        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, MAX_SEED)
        enhanced_prompt = enhance_prompt(prompt, add_quality)
        negative = negative_prompt or DEFAULT_NEGATIVE
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        gen_params = {
            "prompt": enhanced_prompt,
            "negative_prompt": negative,
            "num_inference_steps": min(max(int(steps), 10), 50),
            "guidance_scale": max(1.0, min(guidance_scale, 20.0)),
            "width": width,
            "height": height,
            "generator": generator,
            "output_type": "pil"
        }
        logger.info(f"Generating: {enhanced_prompt[:50]}...")
        start_time = time.time()
        with torch.inference_mode():
            result = pipe_manager.txt2img_pipe(**gen_params)
        generation_time = time.time() - start_time
        image = result.images[0]

        # One metadata dict serves both outcomes; NSFW keys are appended below.
        metadata = {
            "prompt": enhanced_prompt,
            "negative_prompt": negative,
            "steps": gen_params["num_inference_steps"],
            "guidance_scale": gen_params["guidance_scale"],
            "width": width,
            "height": height,
            "seed": seed,
            "sampler": "Euler Ancestral",
            "model_hash": "cyberrealistic_pony_v110"
        }

        info_prefix = ""
        is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(image, enhanced_prompt)
        if is_nsfw_result:
            # Flagged output is blurred; the unblurred image is never returned or saved.
            image = image.filter(ImageFilter.GaussianBlur(radius=20))
            warning_msg = f"β οΈ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
            metadata["nsfw_filtered"] = "true"
            metadata["nsfw_confidence"] = f"{nsfw_confidence:.3f}"
            info_prefix = f"{warning_msg}\n\n"

        png_path = create_metadata_png(image, metadata)
        info_text = f"{info_prefix}{format_generation_info(metadata, generation_time)}"
        return image, png_path, info_text
    except torch.cuda.OutOfMemoryError:
        # Memory is released by the finally clause below.
        return None, None, "β GPU out of memory. Try smaller dimensions or fewer steps."
    except Exception as e:
        logger.exception(f"Generation error: {e}")  # keeps the traceback in the log
        return None, None, f"β Generation failed: {str(e)}"
    finally:
        pipe_manager.clear_memory()  # Ensure memory is cleared even if an error occurs
# Increased duration for model loading and generation | |
def generate_img2img(input_image: "Image.Image", prompt: str, negative_prompt: str,
                     steps: int, guidance_scale: float, strength: float, seed: int,
                     add_quality: bool) -> Tuple:
    """
    Transform an input image with a prompt, run the NSFW check and save a PNG.

    Args:
        input_image: Source image; required. It is not modified.
        prompt: User prompt; a blank prompt is rejected.
        negative_prompt: Negative prompt; DEFAULT_NEGATIVE is used when empty.
        steps: Inference steps, clamped to 10..50.
        guidance_scale: CFG scale, clamped to 1.0..20.0.
        strength: Transformation strength, clamped to 0.1..1.0.
        seed: Seed for the generator; -1 selects a random seed.
        add_quality: Whether enhance_prompt may prepend quality tags.

    Returns:
        ``(image, png_path, info_text)``. When the image is flagged as NSFW the
        returned and saved image is a blurred copy and the info text starts with
        a warning. On any failure the result is ``(None, None, message)``.
    """
    if input_image is None:
        return None, None, "β Please upload an input image."
    if not prompt.strip():
        return None, None, "β Please enter a prompt."
    # Lazy load models if not already loaded
    if not pipe_manager.load_models():
        return None, None, "β Failed to load model. Please try again."
    try:
        pipe_manager.clear_memory()  # Clear memory before generation

        # convert() returns a new image even when the mode is already RGB, so
        # the in-place thumbnail() below never touches the caller's object.
        input_image = input_image.convert('RGB')
        max_dimension = 1024
        if max(input_image.size) > max_dimension:
            input_image.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)
        # Ensure SDXL compatible dimensions (multiples of 64)
        w, h = validate_and_fix_dimensions(*input_image.size)
        input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)

        # The seed and step count are used as integers below.
        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, MAX_SEED)
        enhanced_prompt = enhance_prompt(prompt, add_quality)
        negative = negative_prompt or DEFAULT_NEGATIVE
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        gen_params = {
            "prompt": enhanced_prompt,
            "negative_prompt": negative,
            "image": input_image,
            "num_inference_steps": min(max(int(steps), 10), 50),
            "guidance_scale": max(1.0, min(guidance_scale, 20.0)),
            "strength": max(0.1, min(strength, 1.0)),
            "generator": generator,
            "output_type": "pil"
        }
        logger.info(f"Transforming: {enhanced_prompt[:50]}...")
        start_time = time.time()
        with torch.inference_mode():
            result = pipe_manager.img2img_pipe(**gen_params)
        generation_time = time.time() - start_time
        image = result.images[0]

        # One metadata dict serves both outcomes; NSFW keys are appended below.
        metadata = {
            "prompt": enhanced_prompt,
            "negative_prompt": negative,
            "steps": gen_params["num_inference_steps"],
            "guidance_scale": gen_params["guidance_scale"],
            "strength": gen_params["strength"],
            "width": w,
            "height": h,
            "seed": seed,
            "sampler": "Euler Ancestral",
            "model_hash": "cyberrealistic_pony_v110"
        }

        info_prefix = ""
        is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(image, enhanced_prompt)
        if is_nsfw_result:
            # Flagged output is blurred; the unblurred image is never returned or saved.
            image = image.filter(ImageFilter.GaussianBlur(radius=20))
            warning_msg = f"β οΈ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
            metadata["nsfw_filtered"] = "true"
            metadata["nsfw_confidence"] = f"{nsfw_confidence:.3f}"
            info_prefix = f"{warning_msg}\n\n"

        png_path = create_metadata_png(image, metadata)
        info_text = f"{info_prefix}{format_generation_info(metadata, generation_time)}"
        return image, png_path, info_text
    except torch.cuda.OutOfMemoryError:
        # Memory is released by the finally clause below.
        return None, None, "β GPU out of memory. Try lower strength or fewer steps."
    except Exception as e:
        logger.exception(f"Generation error: {e}")  # keeps the traceback in the log
        return None, None, f"β Generation failed: {str(e)}"
    finally:
        pipe_manager.clear_memory()  # Ensure memory is cleared even if an error occurs
def get_random_prompt():
    """Pick one entry of EXAMPLE_PROMPTS at random and return it."""
    picked = random.choice(EXAMPLE_PROMPTS)
    return picked
# Enhanced Gradio interface | |
def create_interface():
    """
    Build the Gradio Blocks UI for the CyberRealistic Pony Generator.

    The interface has two tabs, Text to Image and Image to Image. Each tab has
    an input column (prompt, advanced settings, generate button) and an output
    column (image, downloadable PNG, generation info). Button clicks are wired
    to generate_txt2img, generate_img2img and get_random_prompt.

    Returns:
        The ``gr.Blocks`` object; the caller is responsible for launching it.
    """
    # Components are created inside nested context managers, so statement order
    # and nesting determine the on-screen layout.
    with gr.Blocks(
        title="CyberRealistic Pony - SDXL Generator",
        theme=gr.themes.Soft(primary_hue="blue"),
        css="""
.generate-btn {
background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(0,0,0,0.2);
}
"""
    ) as demo:
        # Page header
        gr.Markdown("""
# π¨ CyberRealistic Pony Generator
**High-quality SDXL image generation** β’ Optimized for HuggingFace Spaces β’ **NSFW Content Filter Enabled**
> β‘ **First generation takes longer** (model loading) β’ π **Metadata embedded** in all outputs β’ π‘οΈ **Content filtered for safety**
""")
        with gr.Tabs():
            # Text to Image Tab
            with gr.TabItem("π¨ Text to Image", id="txt2img"):
                with gr.Row():
                    # Left column: inputs
                    with gr.Column(scale=1):
                        with gr.Group():
                            txt_prompt = gr.Textbox(
                                label="β¨ Prompt",
                                placeholder="A beautiful landscape with mountains and sunset...",
                                lines=3,
                                max_lines=5
                            )
                            with gr.Row():
                                txt_example_btn = gr.Button("π² Random", size="sm")
                                txt_clear_btn = gr.Button("ποΈ Clear", size="sm")
                        with gr.Accordion("βοΈ Advanced Settings", open=False):
                            txt_negative = gr.Textbox(
                                label="β Negative Prompt",
                                value=DEFAULT_NEGATIVE,
                                lines=2,
                                max_lines=3
                            )
                            txt_quality = gr.Checkbox(
                                label="β¨ Add Quality Tags",
                                value=True,
                                info="Automatically enhance prompt with quality tags"
                            )
                            with gr.Row():
                                txt_steps = gr.Slider(
                                    10, 50, 25, step=1,
                                    label="π Steps",
                                    info="More steps = better quality, slower generation"
                                )
                                txt_guidance = gr.Slider(
                                    1.0, 15.0, 7.5, step=0.5,
                                    label="ποΈ CFG Scale",
                                    info="How closely to follow the prompt"
                                )
                            with gr.Row():
                                txt_width = gr.Slider(
                                    512, 1024, 768, step=64,
                                    label="π Width"
                                )
                                txt_height = gr.Slider(
                                    512, 1024, 768, step=64,
                                    label="π Height"
                                )
                            # generate_txt2img treats a seed of -1 as "pick a random seed".
                            txt_seed = gr.Slider(
                                -1, MAX_SEED, -1, step=1,
                                label="π² Seed (-1 = random)",
                                info="Use same seed for reproducible results"
                            )
                        # The "generate-btn" class is styled by the css string above.
                        txt_generate_btn = gr.Button(
                            "π¨ Generate Image",
                            variant="primary",
                            size="lg",
                            elem_classes=["generate-btn"]
                        )
                    # Right column: outputs
                    with gr.Column(scale=1):
                        txt_output_image = gr.Image(
                            label="πΌοΈ Generated Image",
                            height=500,
                            show_download_button=True
                        )
                        txt_download_file = gr.File(
                            label="π₯ Download PNG (with metadata)",
                            file_types=[".png"]
                        )
                        txt_info = gr.Textbox(
                            label="βΉοΈ Generation Info",
                            lines=6,
                            max_lines=8,
                            interactive=False
                        )
            # Image to Image Tab
            with gr.TabItem("πΌοΈ Image to Image", id="img2img"):
                with gr.Row():
                    # Left column: inputs
                    with gr.Column(scale=1):
                        # type="pil" matches the Image.Image parameter of generate_img2img.
                        img_input = gr.Image(
                            label="π€ Input Image",
                            type="pil",
                            height=300
                        )
                        with gr.Group():
                            img_prompt = gr.Textbox(
                                label="β¨ Transformation Prompt",
                                placeholder="digital art style, vibrant colors...",
                                lines=3
                            )
                            with gr.Row():
                                img_example_btn = gr.Button("π² Random", size="sm")
                                img_clear_btn = gr.Button("ποΈ Clear", size="sm")
                        with gr.Accordion("βοΈ Advanced Settings", open=False):
                            img_negative = gr.Textbox(
                                label="β Negative Prompt",
                                value=DEFAULT_NEGATIVE,
                                lines=2
                            )
                            img_quality = gr.Checkbox(
                                label="β¨ Add Quality Tags",
                                value=True
                            )
                            with gr.Row():
                                img_steps = gr.Slider(10, 50, 25, step=1, label="π Steps")
                                img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="ποΈ CFG")
                            img_strength = gr.Slider(
                                0.1, 1.0, 0.75, step=0.05,
                                label="πͺ Transformation Strength",
                                info="Higher = more creative, lower = more faithful to input"
                            )
                            # generate_img2img treats a seed of -1 as "pick a random seed".
                            img_seed = gr.Slider(-1, MAX_SEED, -1, step=1, label="π² Seed")
                        img_generate_btn = gr.Button(
                            "πΌοΈ Transform Image",
                            variant="primary",
                            size="lg",
                            elem_classes=["generate-btn"]
                        )
                    # Right column: outputs
                    with gr.Column(scale=1):
                        img_output_image = gr.Image(
                            label="πΌοΈ Transformed Image",
                            height=500,
                            show_download_button=True
                        )
                        img_download_file = gr.File(
                            label="π₯ Download PNG (with metadata)",
                            file_types=[".png"]
                        )
                        img_info = gr.Textbox(
                            label="βΉοΈ Generation Info",
                            lines=6,
                            interactive=False
                        )
        # Event handlers
        # The inputs lists follow the parameter order of the generate functions,
        # and the three outputs receive their (image, png_path, info_text) tuple.
        txt_generate_btn.click(
            fn=generate_txt2img,
            inputs=[txt_prompt, txt_negative, txt_steps, txt_guidance,
                    txt_width, txt_height, txt_seed, txt_quality],
            outputs=[txt_output_image, txt_download_file, txt_info],
            show_progress=True
        )
        img_generate_btn.click(
            fn=generate_img2img,
            inputs=[img_input, img_prompt, img_negative, img_steps, img_guidance,
                    img_strength, img_seed, img_quality],
            outputs=[img_output_image, img_download_file, img_info],
            show_progress=True
        )
        # Example prompt buttons
        txt_example_btn.click(fn=get_random_prompt, outputs=[txt_prompt])
        img_example_btn.click(fn=get_random_prompt, outputs=[img_prompt])
        # Clear buttons: write an empty string back into the prompt box
        txt_clear_btn.click(lambda: "", outputs=[txt_prompt])
        img_clear_btn.click(lambda: "", outputs=[img_prompt])
    return demo
# Script entry point: log the runtime environment, build the UI and serve it.
if __name__ == "__main__":
    startup_messages = (
        f"π Initializing CyberRealistic Pony Generator on {DEVICE}",
        f"π± PyTorch version: {torch.__version__}",
        "π‘οΈ NSFW Content Filter: Enabled",
    )
    for message in startup_messages:
        logger.info(message)

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,  # Set to True if you want a public link
    }

    demo = create_interface()
    demo.queue(max_size=20)  # Queue incoming requests, at most 20 waiting
    demo.launch(**launch_options)