ajsbsd committed on
Commit
a50d483
·
verified ·
1 Parent(s): 9d1a533

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +642 -355
app.py CHANGED
@@ -1,458 +1,745 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
4
- from PIL import Image, PngImagePlugin
5
  from datetime import datetime
6
  import os
7
  import gc
8
  import time
9
  import spaces
10
- from typing import Optional, Tuple
11
  from huggingface_hub import hf_hub_download
12
  import tempfile
13
  import random
 
 
 
14
 
15
- # Global pipeline variables
16
- txt2img_pipe = None
17
- img2img_pipe = None
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- # Hugging Face model configuration
21
  MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
22
  MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
23
- model_id = f"{MODEL_REPO}/{MODEL_FILENAME}"
 
 
 
 
24
 
25
- # Generation configuration for metadata
26
- generation_config = {
27
- "vae": "SDXL VAE",
28
- "sampler": "DPM++ 2M Karras",
29
- "steps": 20
30
- }
31
-
32
- def clear_memory():
33
- """Clear GPU memory"""
34
- if torch.cuda.is_available():
35
- torch.cuda.empty_cache()
36
- gc.collect()
37
-
38
- def add_metadata_and_save(image: Image.Image, prompt: str, negative_prompt: str, seed: int, steps: int, guidance: float, strength: Optional[float] = None):
39
- """Embed generation metadata into a PNG and save it."""
40
- # Create temporary file with unique name
41
- temp_path = tempfile.mktemp(suffix=".png")
42
 
43
- meta = PngImagePlugin.PngInfo()
44
- meta.add_text("Prompt", prompt)
45
- meta.add_text("NegativePrompt", negative_prompt)
46
- meta.add_text("Model", model_id)
47
- meta.add_text("VAE", generation_config["vae"])
48
- meta.add_text("Sampler", generation_config["sampler"])
49
- meta.add_text("Steps", str(steps))
50
- meta.add_text("CFG_Scale", str(guidance))
51
- if strength is not None:
52
- meta.add_text("Strength", str(strength))
53
- meta.add_text("Seed", str(seed))
54
- meta.add_text("Date", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
 
 
 
 
 
 
 
 
 
55
 
56
- image.save(temp_path, "PNG", pnginfo=meta)
57
- return temp_path
58
-
59
- def load_models():
60
- """Load both text2img and img2img pipelines optimized for Spaces"""
61
- global txt2img_pipe, img2img_pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- try:
64
- print("Loading CyberRealistic Pony models...")
65
-
66
- # Download model file using huggingface_hub
67
- print(f"Downloading model from {MODEL_REPO}...")
68
- model_path = hf_hub_download(
69
- repo_id=MODEL_REPO,
70
- filename=MODEL_FILENAME,
71
- cache_dir="/tmp/hf_cache" # Use tmp for Spaces
72
- )
73
- print(f"Model downloaded to: {model_path}")
 
 
 
 
 
 
74
 
75
- # Load Text2Img pipeline
76
- if txt2img_pipe is None:
77
- txt2img_pipe = StableDiffusionXLPipeline.from_single_file(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  model_path,
79
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
80
  use_safetensors=True,
81
- variant="fp16" if device == "cuda" else None
 
 
82
  )
83
 
84
- # Aggressive memory optimizations for Spaces
85
- txt2img_pipe.enable_attention_slicing()
86
- txt2img_pipe.enable_vae_slicing()
87
 
88
- if device == "cuda":
89
- txt2img_pipe.enable_model_cpu_offload()
90
- txt2img_pipe.enable_sequential_cpu_offload()
91
- else:
92
- txt2img_pipe = txt2img_pipe.to(device)
93
-
94
- # Share components for Img2Img to save memory
95
- if img2img_pipe is None:
96
- img2img_pipe = StableDiffusionXLImg2ImgPipeline(
97
- vae=txt2img_pipe.vae,
98
- text_encoder=txt2img_pipe.text_encoder,
99
- text_encoder_2=txt2img_pipe.text_encoder_2,
100
- tokenizer=txt2img_pipe.tokenizer,
101
- tokenizer_2=txt2img_pipe.tokenizer_2,
102
- unet=txt2img_pipe.unet,
103
- scheduler=txt2img_pipe.scheduler,
104
  )
105
 
106
- # Same optimizations
107
- img2img_pipe.enable_attention_slicing()
108
- img2img_pipe.enable_vae_slicing()
109
 
110
- if device == "cuda":
111
- img2img_pipe.enable_model_cpu_offload()
112
- img2img_pipe.enable_sequential_cpu_offload()
113
-
114
- print("Models loaded successfully!")
115
- return True
 
 
 
 
 
 
 
116
 
117
- except Exception as e:
118
- print(f"Error loading models: {e}")
119
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- def enhance_prompt(prompt: str, add_quality_tags: bool = True) -> str:
122
- """Enhance prompt with Pony-style tags"""
 
 
 
 
 
 
 
 
 
 
 
123
  if not prompt.strip():
 
 
 
 
124
  return prompt
125
-
126
- if prompt.startswith("score_") or not add_quality_tags:
127
- return prompt
128
-
129
- quality_tags = "score_9, score_8_up, score_7_up, masterpiece, best quality, highly detailed"
130
- return f"{quality_tags}, {prompt}"
131
 
132
- def validate_dimensions(width: int, height: int) -> Tuple[int, int]:
133
- """Ensure dimensions are valid for SDXL"""
134
- width = ((width + 63) // 64) * 64
135
- height = ((height + 63) // 64) * 64
 
136
 
137
- # More conservative limits for Spaces
138
- width = max(512, min(1024, width))
139
- height = max(512, min(1024, height))
 
 
 
140
 
141
  return width, height
142
 
143
- def format_status_with_metadata(generation_time: float, width: int, height: int, prompt: str, negative_prompt: str, seed: int, steps: int, guidance: float, strength: Optional[float] = None):
144
- """Format status message with generation metadata"""
145
- status_parts = [
146
- f"βœ… Generated in {generation_time:.1f}s ({width}Γ—{height})",
147
- f"🎯 Prompt: {prompt[:50]}..." if len(prompt) > 50 else f"🎯 Prompt: {prompt}",
148
- f"🚫 Negative: {negative_prompt[:30]}..." if negative_prompt and len(negative_prompt) > 30 else f"🚫 Negative: {negative_prompt or 'None'}",
149
- f"🎲 Seed: {seed}",
150
- f"πŸ“ Steps: {steps}",
151
- f"πŸŽ›οΈ CFG: {guidance}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  ]
153
 
154
- if strength is not None:
155
- status_parts.append(f"πŸ’ͺ Strength: {strength}")
156
 
157
- return "\n".join(status_parts)
158
 
159
- @spaces.GPU(duration=60) # GPU decorator for Spaces
160
- def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, add_quality_tags):
161
- """Generate image from text prompt with Spaces GPU support"""
162
- global txt2img_pipe
163
 
164
  if not prompt.strip():
165
- return None, "Please enter a prompt"
166
 
167
  # Lazy load models
168
- if txt2img_pipe is None:
169
- if not load_models():
170
- return None, "Failed to load models. Please try again."
171
 
172
  try:
173
- clear_memory()
174
 
175
- # Validate dimensions
176
- width, height = validate_dimensions(width, height)
177
-
178
- # Handle seed
179
  if seed == -1:
180
- seed = random.randint(0, 2147483647)
181
 
182
- # Set seed
183
- generator = torch.Generator(device=device).manual_seed(int(seed))
184
 
185
- # Enhance prompt
186
- enhanced_prompt = enhance_prompt(prompt, add_quality_tags)
 
 
 
 
 
 
 
 
 
187
 
188
- print(f"Generating: {enhanced_prompt[:100]}...")
189
  start_time = time.time()
190
 
191
- # Generate with lower memory usage
192
- with torch.no_grad():
193
- result = txt2img_pipe(
194
- prompt=enhanced_prompt,
195
- negative_prompt=negative_prompt or "",
196
- num_inference_steps=min(int(num_steps), 30), # Limit steps for Spaces
197
- guidance_scale=float(guidance_scale),
198
- width=width,
199
- height=height,
200
- generator=generator
201
- )
202
 
203
  generation_time = time.time() - start_time
204
 
205
- # Save with metadata - returns file path
206
- png_path = add_metadata_and_save(
207
- result.images[0], enhanced_prompt, negative_prompt or "",
208
- seed, num_steps, guidance_scale
209
- )
210
 
211
- # Format status with metadata
212
- status = format_status_with_metadata(
213
- generation_time, width, height, enhanced_prompt,
214
- negative_prompt or "", seed, num_steps, guidance_scale
215
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
- return result.images[0], png_path, status
 
 
 
 
 
 
 
 
 
 
 
218
 
 
 
 
 
 
 
 
 
 
219
  except Exception as e:
220
- return None, f"Generation failed: {str(e)}"
 
221
  finally:
222
- clear_memory()
223
 
224
- @spaces.GPU(duration=60) # GPU decorator for Spaces
225
- def generate_img2img(input_image, prompt, negative_prompt, num_steps, guidance_scale, strength, seed, add_quality_tags):
226
- """Generate image from input image + text prompt with Spaces GPU support"""
227
- global img2img_pipe
 
228
 
229
  if input_image is None:
230
- return None, "Please upload an input image"
231
 
232
  if not prompt.strip():
233
- return None, "Please enter a prompt"
234
 
235
- # Lazy load models
236
- if img2img_pipe is None:
237
- if not load_models():
238
- return None, "Failed to load models. Please try again."
239
 
240
  try:
241
- clear_memory()
242
 
243
- # Handle seed
244
- if seed == -1:
245
- seed = random.randint(0, 2147483647)
246
 
247
- # Set seed
248
- generator = torch.Generator(device=device).manual_seed(int(seed))
 
249
 
250
- # Enhance prompt
251
- enhanced_prompt = enhance_prompt(prompt, add_quality_tags)
252
 
253
- # Process input image
254
- if isinstance(input_image, Image.Image):
255
- if input_image.mode != 'RGB':
256
- input_image = input_image.convert('RGB')
257
-
258
- # Conservative resize for Spaces
259
- max_size = 768
260
- input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
261
-
262
- w, h = input_image.size
263
- w, h = validate_dimensions(w, h)
264
- input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
265
 
266
- print(f"Transforming: {enhanced_prompt[:100]}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  start_time = time.time()
268
 
269
- with torch.no_grad():
270
- result = img2img_pipe(
271
- prompt=enhanced_prompt,
272
- negative_prompt=negative_prompt or "",
273
- image=input_image,
274
- num_inference_steps=min(int(num_steps), 30), # Limit steps
275
- guidance_scale=float(guidance_scale),
276
- strength=float(strength),
277
- generator=generator
278
- )
279
 
280
  generation_time = time.time() - start_time
281
 
282
- # Save with metadata - returns file path
283
- png_path = add_metadata_and_save(
284
- result.images[0], enhanced_prompt, negative_prompt or "",
285
- seed, num_steps, guidance_scale, strength
286
- )
287
 
288
- # Format status with metadata
289
- status = format_status_with_metadata(
290
- generation_time, w, h, enhanced_prompt,
291
- negative_prompt or "", seed, num_steps, guidance_scale, strength
292
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
- return result.images[0], png_path, status
 
295
 
 
 
 
 
 
296
  except Exception as e:
297
- return None, f"Transformation failed: {str(e)}"
 
298
  finally:
299
- clear_memory()
300
-
301
- # Example prompts for inspiration
302
- EXAMPLE_PROMPTS = [
303
- "beautiful anime girl with long flowing hair, cherry blossoms, soft lighting",
304
- "cyberpunk cityscape at night, neon lights, rain reflections, detailed architecture",
305
- "majestic dragon flying over mountains, fantasy landscape, dramatic clouds",
306
- "cute anthropomorphic fox character, forest background, magical atmosphere",
307
- "elegant woman in Victorian dress, portrait, ornate background, vintage style",
308
- "futuristic robot with glowing eyes, metallic surface, sci-fi environment",
309
- "mystical unicorn in enchanted forest, rainbow mane, sparkles, ethereal lighting",
310
- "steampunk airship floating in sky, gears and brass, adventure scene"
311
- ]
312
 
313
- def set_example_prompt():
314
- """Return a random example prompt"""
315
  return random.choice(EXAMPLE_PROMPTS)
316
 
317
- # Simplified negative prompt for better performance
318
- DEFAULT_NEGATIVE = """
319
- (low quality:1.3), (worst quality:1.3), (bad quality:1.2), blurry, noisy, ugly, deformed,
320
- (text, watermark:1.4), (extra limbs:1.3), (bad hands:1.3), (bad anatomy:1.2)
321
- """
322
-
323
- # Gradio interface optimized for Spaces
324
- with gr.Blocks(
325
- title="CyberRealistic Pony Generator",
326
- theme=gr.themes.Soft()
327
- ) as demo:
328
- gr.Markdown("""
329
- # 🎨 CyberRealistic Pony Image Generator
330
-
331
- Generate high-quality images using the CyberRealistic Pony SDXL model.
332
-
333
- ⚠️ **Note**: First generation may take longer as the model loads. GPU time is limited on Spaces.
334
- πŸ“‹ **Metadata**: All generated images include embedded metadata (prompt, settings, seed, etc.)
335
- """)
336
 
337
- with gr.Tabs():
338
- with gr.TabItem("🎨 Text to Image"):
339
- with gr.Row():
340
- with gr.Column():
341
- with gr.Row():
342
- txt2img_prompt = gr.Textbox(
343
- label="Prompt",
344
- placeholder="beautiful landscape, mountains, sunset",
345
- lines=2,
346
- scale=4
347
- )
348
- txt2img_example_btn = gr.Button("🎲 Random Example", scale=1)
349
-
350
- with gr.Accordion("Advanced Settings", open=False):
351
- txt2img_negative = gr.Textbox(
352
- label="Negative Prompt",
353
- value=DEFAULT_NEGATIVE,
354
- lines=2
355
- )
356
-
357
- txt2img_quality_tags = gr.Checkbox(
358
- label="Add Quality Tags",
359
- value=True
360
- )
361
-
362
- with gr.Row():
363
- txt2img_steps = gr.Slider(10, 30, 20, step=1, label="Steps")
364
- txt2img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="Guidance")
 
 
 
 
 
 
 
 
 
 
 
365
 
366
- with gr.Row():
367
- txt2img_width = gr.Slider(512, 1024, 768, step=64, label="Width")
368
- txt2img_height = gr.Slider(512, 1024, 768, step=64, label="Height")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
- txt2img_seed = gr.Slider(
371
- minimum=-1, maximum=2147483647, value=-1, step=1,
372
- label="Seed (-1 for random)"
 
 
373
  )
374
 
375
- txt2img_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
376
-
377
- with gr.Column():
378
- txt2img_preview = gr.Image(label="Preview", height=400)
379
- txt2img_output = gr.File(label="πŸ“₯ Download PNG with Metadata", file_types=[".png"])
380
- txt2img_status = gr.Textbox(label="Generation Info", interactive=False, lines=6)
381
-
382
- with gr.TabItem("πŸ–ΌοΈ Image to Image"):
383
- with gr.Row():
384
- with gr.Column():
385
- img2img_input = gr.Image(label="Input Image", type="pil", height=250)
386
-
387
- with gr.Row():
388
- img2img_prompt = gr.Textbox(
389
- label="Prompt",
390
- placeholder="digital painting style, vibrant colors",
391
- lines=2,
392
- scale=4
393
  )
394
- img2img_example_btn = gr.Button("🎲 Random Example", scale=1)
395
-
396
- with gr.Accordion("Advanced Settings", open=False):
397
- img2img_negative = gr.Textbox(
398
- label="Negative Prompt",
399
- value=DEFAULT_NEGATIVE,
400
- lines=2
401
  )
402
-
403
- img2img_quality_tags = gr.Checkbox(
404
- label="Add Quality Tags",
405
- value=True
 
 
 
 
 
 
 
 
 
 
 
406
  )
407
 
408
- with gr.Row():
409
- img2img_steps = gr.Slider(10, 30, 20, step=1, label="Steps")
410
- img2img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="Guidance")
 
 
 
 
 
 
 
411
 
412
- img2img_strength = gr.Slider(
413
- 0.1, 1.0, 0.75, step=0.05,
414
- label="Strength (Higher = more creative)"
415
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
 
417
- img2img_seed = gr.Slider(
418
- minimum=-1, maximum=2147483647, value=-1, step=1,
419
- label="Seed (-1 for random)"
 
 
420
  )
421
 
422
- img2img_btn = gr.Button("πŸ–ΌοΈ Transform", variant="primary", size="lg")
423
-
424
- with gr.Column():
425
- img2img_preview = gr.Image(label="Preview", height=400)
426
- img2img_output = gr.File(label="πŸ“₯ Download PNG with Metadata", file_types=[".png"])
427
- img2img_status = gr.Textbox(label="Generation Info", interactive=False, lines=6)
428
-
429
- # Event handlers
430
- txt2img_btn.click(
431
- fn=generate_txt2img,
432
- inputs=[txt2img_prompt, txt2img_negative, txt2img_steps, txt2img_guidance,
433
- txt2img_width, txt2img_height, txt2img_seed, txt2img_quality_tags],
434
- outputs=[txt2img_preview, txt2img_output, txt2img_status]
435
- )
436
-
437
- img2img_btn.click(
438
- fn=generate_img2img,
439
- inputs=[img2img_input, img2img_prompt, img2img_negative, img2img_steps, img2img_guidance,
440
- img2img_strength, img2img_seed, img2img_quality_tags],
441
- outputs=[img2img_preview, img2img_output, img2img_status]
442
- )
443
-
444
- # Example prompt buttons
445
- txt2img_example_btn.click(
446
- fn=set_example_prompt,
447
- outputs=[txt2img_prompt]
448
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
449
 
450
- img2img_example_btn.click(
451
- fn=set_example_prompt,
452
- outputs=[img2img_prompt]
453
- )
454
-
455
- print(f"πŸš€ CyberRealistic Pony Generator initialized on {device}")
456
 
 
457
  if __name__ == "__main__":
458
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, EulerAncestralDiscreteScheduler
4
+ from PIL import Image, PngImagePlugin, ImageFilter
5
  from datetime import datetime
6
  import os
7
  import gc
8
  import time
9
  import spaces
10
+ from typing import Optional, Tuple, Dict, Any
11
  from huggingface_hub import hf_hub_download
12
  import tempfile
13
  import random
14
+ import logging
15
+ import torch.nn.functional as F
16
+ from transformers import CLIPProcessor, CLIPModel
17
 
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
 
21
 
22
+ # Constants
23
  MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
24
  MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
25
+ NSFW_MODEL_ID = "openai/clip-vit-base-patch32" # CLIP model for NSFW detection
26
+ MAX_SEED = 2**32 - 1
27
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
29
+ NSFW_THRESHOLD = 0.25 # Threshold for NSFW detection
30
 
31
+ # Global pipeline state
32
class PipelineManager:
    """Hold the SDXL pipelines and the CLIP NSFW detector, loading each lazily.

    The text-to-image pipeline is loaded from a single safetensors checkpoint;
    the image-to-image pipeline is built from the same components so the
    weights are only held once.
    """

    # Scores reported by the prompt-keyword fallback. They are fixed
    # placeholders, not measured probabilities: the fallback never looks at
    # the image, so it has nothing to derive a real confidence from.
    _KEYWORD_MATCH_CONFIDENCE = 0.9
    _KEYWORD_MISS_CONFIDENCE = 0.0

    def __init__(self):
        self.txt2img_pipe = None
        self.img2img_pipe = None
        self.nsfw_detector_model = None
        self.nsfw_detector_processor = None
        self.model_loaded = False
        self.nsfw_detector_loaded = False

    def clear_memory(self):
        """Release cached CUDA memory (when CUDA is present) and run the GC."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

    def load_nsfw_detector(self) -> bool:
        """Load the CLIP model and processor used for NSFW detection.

        Returns:
            True if the detector is (or already was) loaded, False on failure.
        """
        if self.nsfw_detector_loaded:
            return True

        try:
            logger.info("Loading NSFW detector...")
            self.nsfw_detector_processor = CLIPProcessor.from_pretrained(NSFW_MODEL_ID)
            self.nsfw_detector_model = CLIPModel.from_pretrained(NSFW_MODEL_ID)

            if DEVICE == "cuda":
                self.nsfw_detector_model = self.nsfw_detector_model.to(DEVICE)

            self.nsfw_detector_loaded = True
            logger.info("NSFW detector loaded successfully!")
            return True

        except Exception as e:
            logger.error(f"Failed to load NSFW detector: {e}")
            self.nsfw_detector_loaded = False
            return False

    def is_nsfw(self, image: "Image.Image", prompt: str = "") -> Tuple[bool, float]:
        """Classify an image as NSFW using CLIP zero-shot similarity.

        The image embedding is compared against a set of "safe" and a set of
        "unsafe" text prompts; the image is flagged when the mean unsafe
        similarity exceeds both the mean safe similarity and NSFW_THRESHOLD.
        If the detector cannot be loaded or raises, the prompt-keyword
        fallback is used instead.

        Args:
            image: Generated image to classify.
            prompt: Prompt used for generation; only consulted by the fallback.

        Returns:
            (is_nsfw, confidence) where confidence is the winning mean cosine
            similarity, or the fallback's placeholder score.
        """
        try:
            # Load NSFW detector if not already loaded
            if not self.nsfw_detector_loaded:
                if not self.load_nsfw_detector():
                    return self._fallback_nsfw_detection(prompt)

            safe_prompts = [
                "a safe family-friendly image",
                "a general photo",
                "appropriate content",
                "artistic photography"
            ]
            unsafe_prompts = [
                "explicit adult content",
                "nudity",
                "inappropriate sexual content",
                "pornographic material"
            ]

            inputs = self.nsfw_detector_processor(images=image, return_tensors="pt").to(DEVICE)
            safe_inputs = self.nsfw_detector_processor(
                text=safe_prompts, return_tensors="pt", padding=True
            ).to(DEVICE)
            unsafe_inputs = self.nsfw_detector_processor(
                text=unsafe_prompts, return_tensors="pt", padding=True
            ).to(DEVICE)

            # All three encoder passes are inference-only, so none of them
            # should build an autograd graph.
            with torch.no_grad():
                image_features = self.nsfw_detector_model.get_image_features(**inputs)
                safe_features = self.nsfw_detector_model.get_text_features(**safe_inputs)
                unsafe_features = self.nsfw_detector_model.get_text_features(**unsafe_inputs)

            # Normalize features for cosine similarity
            image_features = F.normalize(image_features, p=2, dim=-1)
            safe_features = F.normalize(safe_features, p=2, dim=-1)
            unsafe_features = F.normalize(unsafe_features, p=2, dim=-1)

            safe_similarity = (image_features @ safe_features.T).mean().item()
            unsafe_similarity = (image_features @ unsafe_features.T).mean().item()

            is_nsfw_result = (
                unsafe_similarity > safe_similarity and
                unsafe_similarity > NSFW_THRESHOLD
            )

            confidence = unsafe_similarity if is_nsfw_result else safe_similarity

            if is_nsfw_result:
                logger.warning(f"🚨 NSFW content detected (CLIP-based: {unsafe_similarity:.3f} > {safe_similarity:.3f})")

            return is_nsfw_result, confidence

        except Exception as e:
            logger.error(f"NSFW detection error: {e}")
            return self._fallback_nsfw_detection(prompt)

    def _fallback_nsfw_detection(self, prompt: str = "") -> Tuple[bool, float]:
        """Flag a generation as NSFW from keywords in its prompt.

        Used only when the CLIP detector is unavailable. The check is a plain
        case-insensitive substring match, so it is deterministic: the earlier
        random "demo" false positive and randomised confidences were removed
        because they blurred clean images and wrote made-up scores into the
        PNG metadata.

        Returns:
            (True, _KEYWORD_MATCH_CONFIDENCE) when a keyword is found,
            otherwise (False, _KEYWORD_MISS_CONFIDENCE).
        """
        nsfw_keywords = [
            'nude', 'naked', 'nsfw', 'explicit', 'sexual', 'erotic', 'porn',
            'adult', 'xxx', 'sex', 'breast', 'nipple', 'genital', 'provocative'
        ]

        prompt_lower = prompt.lower()
        for keyword in nsfw_keywords:
            if keyword in prompt_lower:
                logger.warning(f"🚨 NSFW content detected (prompt-based: '{keyword}' found)")
                return True, self._KEYWORD_MATCH_CONFIDENCE

        return False, self._KEYWORD_MISS_CONFIDENCE

    def load_models(self) -> bool:
        """Download the checkpoint and build both SDXL pipelines.

        Returns:
            True if the pipelines are (or already were) loaded, False on failure.
        """
        if self.model_loaded:
            return True

        try:
            logger.info("Loading CyberRealistic Pony models...")

            model_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                cache_dir=os.environ.get("HF_CACHE_DIR", "/tmp/hf_cache"),
                resume_download=True
            )
            logger.info(f"Model downloaded to: {model_path}")

            # Load txt2img pipeline with optimizations
            self.txt2img_pipe = StableDiffusionXLPipeline.from_single_file(
                model_path,
                torch_dtype=DTYPE,
                use_safetensors=True,
                variant="fp16" if DEVICE == "cuda" else None,
                safety_checker=None,  # Disable for faster loading
                requires_safety_checker=False
            )

            self._optimize_pipeline(self.txt2img_pipe)

            # Create img2img pipeline sharing components. The SDXL pipeline
            # constructors take no safety_checker / requires_safety_checker
            # arguments, so they are not passed here.
            self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
                vae=self.txt2img_pipe.vae,
                text_encoder=self.txt2img_pipe.text_encoder,
                text_encoder_2=self.txt2img_pipe.text_encoder_2,
                tokenizer=self.txt2img_pipe.tokenizer,
                tokenizer_2=self.txt2img_pipe.tokenizer_2,
                unet=self.txt2img_pipe.unet,
                scheduler=self.txt2img_pipe.scheduler,
            )

            self._optimize_pipeline(self.img2img_pipe)

            self.model_loaded = True
            logger.info("Models loaded successfully!")
            return True

        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            self.model_loaded = False
            return False

    def _optimize_pipeline(self, pipeline):
        """Apply memory optimizations to a pipeline in place."""
        pipeline.enable_attention_slicing()
        pipeline.enable_vae_slicing()

        if DEVICE == "cuda":
            # Use sequential CPU offloading for better memory management
            pipeline.enable_sequential_cpu_offload()
            # xformers is optional; fall back to default attention without
            # swallowing KeyboardInterrupt/SystemExit as a bare except would.
            try:
                pipeline.enable_xformers_memory_efficient_attention()
            except Exception:
                logger.info("xformers not available, using default attention")
        else:
            pipeline.to(DEVICE)
225
+
226
+ # Global pipeline manager
227
+ pipe_manager = PipelineManager()
228
+
229
+ # Enhanced prompt templates
230
+ QUALITY_TAGS = "score_9, score_8_up, score_7_up, masterpiece, best quality, ultra detailed, 8k"
231
+
232
+ DEFAULT_NEGATIVE = """(worst quality:1.4), (low quality:1.4), (normal quality:1.2),
233
+ lowres, bad anatomy, bad hands, signature, watermarks, ugly, imperfect eyes,
234
+ skewed eyes, unnatural face, unnatural body, error, extra limb, missing limbs,
235
+ painting by bad-artist, 3d, render"""
236
 
237
+ EXAMPLE_PROMPTS = [
238
+ "beautiful anime girl with long flowing silver hair, sakura petals, soft morning light",
239
+ "cyberpunk street scene, neon lights reflecting on wet pavement, futuristic cityscape",
240
+ "majestic dragon soaring through storm clouds, lightning, epic fantasy scene",
241
+ "cute anthropomorphic fox girl, fluffy tail, forest clearing, magical sparkles",
242
+ "elegant Victorian lady in ornate dress, portrait, vintage photography style",
243
+ "futuristic mech suit, glowing energy core, sci-fi laboratory background",
244
+ "mystical unicorn with rainbow mane, enchanted forest, ethereal atmosphere",
245
+ "steampunk inventor's workshop, brass gears, mechanical contraptions, warm lighting"
246
+ ]
247
+
248
+ def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
249
+ """Smart prompt enhancement"""
250
  if not prompt.strip():
251
+ return ""
252
+
253
+ # Don't add quality tags if they're already present
254
+ if any(tag in prompt.lower() for tag in ["score_", "masterpiece", "best quality"]):
255
  return prompt
256
+
257
+ if add_quality:
258
+ return f"{QUALITY_TAGS}, {prompt}"
259
+ return prompt
 
 
260
 
261
def validate_and_fix_dimensions(width: int, height: int) -> Tuple[int, int]:
    """Snap a requested size to SDXL-compatible dimensions.

    Each side is rounded to a multiple of 64 (exact halves round down) and
    clamped to the range 512..1024.

    Args:
        width: Requested width in pixels; floats (e.g. from a UI slider) are
            rounded to the nearest integer first.
        height: Requested height in pixels; same handling as ``width``.

    Returns:
        ``(width, height)`` as ints, each a multiple of 64 in [512, 1024].
    """
    def _snap(value) -> int:
        # Coerce first so a float never leaks through to the pipeline.
        pixels = int(round(value))
        return max(512, min(1024, ((pixels + 31) // 64) * 64))

    # With both sides clamped to [512, 1024] the aspect ratio is always
    # within [0.5, 2.0], so no separate aspect-ratio correction is needed.
    return _snap(width), _snap(height)
275
 
276
def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
    """Save ``image`` as a temporary PNG with generation metadata embedded.

    Every non-None entry of ``params`` is written as a PNG text chunk, plus a
    "Generated" UTC timestamp and a "Model" chunk naming the checkpoint.

    Args:
        image: Image to save.
        params: Metadata key/value pairs; values are stringified.

    Returns:
        Path of the PNG file that was written. The caller owns the file.
    """
    from datetime import timezone  # module top only imports `datetime` itself

    meta = PngImagePlugin.PngInfo()
    for key, value in params.items():
        if value is not None:
            meta.add_text(key, str(value))

    # The label says UTC, so take the time in UTC rather than local time.
    meta.add_text("Generated", datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"))
    meta.add_text("Model", f"{MODEL_REPO}/{MODEL_FILENAME}")

    # mkstemp creates the file atomically; mktemp only returned a name that
    # another process could claim before the save happened.
    fd, temp_path = tempfile.mkstemp(suffix=".png", prefix="cyberrealistic_")
    os.close(fd)
    try:
        image.save(temp_path, "PNG", pnginfo=meta, optimize=True)
    except Exception:
        # Do not leave an empty placeholder behind if encoding fails.
        os.remove(temp_path)
        raise
    return temp_path
291
+
292
def format_generation_info(params: Dict[str, Any], generation_time: float) -> str:
    """Build the multi-line status text shown after a generation.

    Args:
        params: Generation metadata. Keys read: ``width``, ``height``,
            ``prompt``, ``negative_prompt``, ``seed``, ``steps``,
            ``guidance_scale`` and, optionally, ``strength``. Missing keys
            render as "N/A" (or "None" for the negative prompt).
        generation_time: Wall-clock generation time in seconds.

    Returns:
        One line per field, joined with newlines.
    """
    def _truncate(text: str, limit: int) -> str:
        # Append an ellipsis only when something was actually cut off.
        return f"{text[:limit]}..." if len(text) > limit else text

    # `or` also covers a key that is present but None/empty, which would
    # otherwise break slicing and len().
    prompt_text = str(params.get('prompt') or '')
    negative_text = str(params.get('negative_prompt') or 'None')

    info_lines = [
        f"✅ Generated in {generation_time:.1f}s",
        f"📐 Resolution: {params.get('width', 'N/A')}×{params.get('height', 'N/A')}",
        f"🎯 Prompt: {_truncate(prompt_text, 60)}",
        f"🚫 Negative: {_truncate(negative_text, 40)}",
        f"🎲 Seed: {params.get('seed', 'N/A')}",
        f"📊 Steps: {params.get('steps', 'N/A')} | CFG: {params.get('guidance_scale', 'N/A')}"
    ]

    if 'strength' in params:
        info_lines.append(f"💪 Strength: {params['strength']}")

    return "\n".join(info_lines)
307
 
308
@spaces.GPU(duration=120)  # Increased duration for model loading
def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_scale: float,
                     width: int, height: int, seed: int, add_quality: bool) -> Tuple:
    """Generate an image from a text prompt.

    Args:
        prompt: Positive prompt; must not be blank.
        negative_prompt: Negative prompt; DEFAULT_NEGATIVE is used when empty.
        steps: Inference steps, clamped to 10..50.
        guidance_scale: CFG scale, clamped to 1.0..20.0.
        width: Requested width, snapped by validate_and_fix_dimensions.
        height: Requested height, snapped by validate_and_fix_dimensions.
        seed: RNG seed; -1 picks a random seed in 0..MAX_SEED.
        add_quality: Whether enhance_prompt may prepend quality tags.

    Returns:
        ``(image, png_path, info_text)``. On any failure the first two are
        None and ``info_text`` carries the error message. When the NSFW check
        flags the result, the returned image and PNG are a blurred copy.
    """
    if not prompt.strip():
        return None, None, "❌ Please enter a prompt"

    # Lazy load models
    if not pipe_manager.load_models():
        return None, None, "❌ Failed to load model. Please try again."

    try:
        pipe_manager.clear_memory()

        # UI sliders may hand over floats; the pipeline and
        # Generator.manual_seed need real ints.
        width, height = validate_and_fix_dimensions(int(width), int(height))
        width, height = int(width), int(height)
        steps = int(steps)
        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, MAX_SEED)

        enhanced_prompt = enhance_prompt(prompt, add_quality)
        effective_negative = negative_prompt or DEFAULT_NEGATIVE
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        gen_params = {
            "prompt": enhanced_prompt,
            "negative_prompt": effective_negative,
            "num_inference_steps": min(max(steps, 10), 50),  # Clamp steps
            "guidance_scale": max(1.0, min(float(guidance_scale), 20.0)),  # Clamp guidance
            "width": width,
            "height": height,
            "generator": generator,
            "output_type": "pil"
        }

        logger.info(f"Generating: {enhanced_prompt[:50]}...")
        start_time = time.time()

        with torch.inference_mode():
            result = pipe_manager.txt2img_pipe(**gen_params)

        generation_time = time.time() - start_time
        image = result.images[0]

        is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(image, enhanced_prompt)

        # One metadata dict for both outcomes; the NSFW branch only adds to it.
        # NOTE(review): "sampler" is a fixed label; this function does not set
        # or inspect the pipeline's scheduler - confirm it matches.
        metadata = {
            "prompt": enhanced_prompt,
            "negative_prompt": effective_negative,
            "steps": gen_params["num_inference_steps"],
            "guidance_scale": gen_params["guidance_scale"],
            "width": width,
            "height": height,
            "seed": seed,
            "sampler": "Euler Ancestral",
            "model_hash": "cyberrealistic_pony_v110"
        }

        warning_msg = None
        if is_nsfw_result:
            # Hand back a blurred copy instead of the flagged image.
            image = image.filter(ImageFilter.GaussianBlur(radius=20))
            warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
            metadata["nsfw_filtered"] = "true"
            metadata["nsfw_confidence"] = f"{nsfw_confidence:.3f}"

        png_path = create_metadata_png(image, metadata)
        info_text = format_generation_info(metadata, generation_time)
        if warning_msg is not None:
            info_text = f"{warning_msg}\n\n{info_text}"

        return image, png_path, info_text

    except torch.cuda.OutOfMemoryError:
        pipe_manager.clear_memory()
        return None, None, "❌ GPU out of memory. Try smaller dimensions or fewer steps."
    except Exception as e:
        logger.error(f"Generation error: {e}")
        return None, None, f"❌ Generation failed: {str(e)}"
    finally:
        pipe_manager.clear_memory()
406
 
407
@spaces.GPU(duration=120)
def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str,
                     steps: int, guidance_scale: float, strength: float, seed: int,
                     add_quality: bool) -> Tuple:
    """Transform ``input_image`` with the img2img pipeline, guided by ``prompt``.

    Args:
        input_image: Source picture. It is never modified; a converted copy
            is resized and fed to the pipeline.
        prompt: Text describing the desired transformation.
        negative_prompt: Text to steer away from; ``DEFAULT_NEGATIVE`` is
            used when this is empty.
        steps: Requested inference steps, clamped to the range 10-50.
        guidance_scale: Requested CFG scale, clamped to the range 1.0-20.0.
        strength: Requested transformation strength, clamped to 0.1-1.0.
        seed: Generator seed; ``-1`` selects a random seed in
            ``[0, MAX_SEED]``.
        add_quality: Forwarded to ``enhance_prompt``.

    Returns:
        A ``(image, png_path, info_text)`` tuple. When the result is flagged
        by ``pipe_manager.is_nsfw`` the returned image and the saved PNG are
        a Gaussian-blurred version and ``info_text`` starts with a warning.
        On validation or generation failure the tuple is
        ``(None, None, error_message)``.
    """
    if input_image is None:
        return None, None, "❌ Please upload an input image"

    # Guard against a missing prompt as well as a blank one.
    if not prompt or not prompt.strip():
        return None, None, "❌ Please enter a prompt"

    if not pipe_manager.load_models():
        return None, None, "❌ Failed to load model. Please try again."

    try:
        pipe_manager.clear_memory()

        # Image.convert() always returns a new image, so the in-place
        # thumbnail() below can never shrink the caller's own object.
        source = input_image.convert('RGB')

        # Smart resizing maintaining aspect ratio
        max_dimension = 1024
        if max(source.size) > max_dimension:
            source.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)

        # Ensure SDXL compatible dimensions
        w, h = validate_and_fix_dimensions(*source.size)
        source = source.resize((w, h), Image.Resampling.LANCZOS)

        # UI sliders may deliver floats; the seed and the step count must be ints.
        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, MAX_SEED)

        enhanced_prompt = enhance_prompt(prompt, add_quality)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        # Generation parameters (user-supplied values are clamped to safe ranges)
        gen_params = {
            "prompt": enhanced_prompt,
            "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
            "image": source,
            "num_inference_steps": min(max(int(steps), 10), 50),
            "guidance_scale": max(1.0, min(guidance_scale, 20.0)),
            "strength": max(0.1, min(strength, 1.0)),
            "generator": generator,
            "output_type": "pil"
        }

        logger.info(f"Transforming: {enhanced_prompt[:50]}...")
        start_time = time.time()

        with torch.inference_mode():
            result = pipe_manager.img2img_pipe(**gen_params)

        generation_time = time.time() - start_time
        output_image = result.images[0]

        # NSFW Detection
        is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(output_image, enhanced_prompt)

        # One metadata record serves both outcomes; the flagged case only
        # appends two extra keys at the end.
        metadata = {
            "prompt": enhanced_prompt,
            "negative_prompt": gen_params["negative_prompt"],
            "steps": gen_params["num_inference_steps"],
            "guidance_scale": gen_params["guidance_scale"],
            "strength": gen_params["strength"],
            "width": w,
            "height": h,
            "seed": seed,
            "sampler": "Euler Ancestral",
            "model_hash": "cyberrealistic_pony_v110"
        }

        warning_msg = None
        if is_nsfw_result:
            # Only the blurred version is returned and written to disk.
            output_image = output_image.filter(ImageFilter.GaussianBlur(radius=20))
            warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
            metadata["nsfw_filtered"] = "true"
            metadata["nsfw_confidence"] = f"{nsfw_confidence:.3f}"

        png_path = create_metadata_png(output_image, metadata)
        info_text = format_generation_info(metadata, generation_time)
        if warning_msg is not None:
            info_text = f"{warning_msg}\n\n{info_text}"

        return output_image, png_path, info_text

    except torch.cuda.OutOfMemoryError:
        # Memory is released by the ``finally`` clause below.
        return None, None, "❌ GPU out of memory. Try lower strength or fewer steps."
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Generation error: %s", e)
        return None, None, f"❌ Generation failed: {str(e)}"
    finally:
        pipe_manager.clear_memory()
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
def get_random_prompt():
    """Return one entry of ``EXAMPLE_PROMPTS`` picked at random."""
    chosen_prompt = random.choice(EXAMPLE_PROMPTS)
    return chosen_prompt
526
 
527
+ # Enhanced Gradio interface
528
# NOTE(review): the container nesting in the helpers below was reconstructed
# from an indentation-stripped view of this function -- confirm the layout
# against the rendered app.

# Styling for the primary "generate" buttons (gradient fill, hover lift).
_INTERFACE_CSS = """
.generate-btn {
    background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(0,0,0,0.2);
}
"""

# Banner shown above the tabs.
_HEADER_MARKDOWN = """
# 🎨 CyberRealistic Pony Generator

**High-quality SDXL image generation** β€’ Optimized for HuggingFace Spaces β€’ **NSFW Content Filter Enabled**

> ⚑ **First generation takes longer** (model loading) β€’ πŸ“‹ **Metadata embedded** in all outputs β€’ πŸ›‘οΈ **Content filtered for safety**
"""


def _build_txt2img_tab():
    """Lay out the text-to-image tab and return its components keyed by role."""
    with gr.TabItem("🎨 Text to Image", id="txt2img"):
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    prompt = gr.Textbox(
                        label="✨ Prompt",
                        placeholder="A beautiful landscape with mountains and sunset...",
                        lines=3,
                        max_lines=5,
                    )

                    with gr.Row():
                        example_btn = gr.Button("🎲 Random", size="sm")
                        clear_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")

                with gr.Accordion("βš™οΈ Advanced Settings", open=False):
                    negative = gr.Textbox(
                        label="❌ Negative Prompt",
                        value=DEFAULT_NEGATIVE,
                        lines=2,
                        max_lines=3,
                    )

                    quality = gr.Checkbox(
                        label="✨ Add Quality Tags",
                        value=True,
                        info="Automatically enhance prompt with quality tags",
                    )

                    with gr.Row():
                        steps = gr.Slider(
                            10, 50, 25, step=1,
                            label="πŸ“Š Steps",
                            info="More steps = better quality, slower generation",
                        )
                        guidance = gr.Slider(
                            1.0, 15.0, 7.5, step=0.5,
                            label="πŸŽ›οΈ CFG Scale",
                            info="How closely to follow the prompt",
                        )

                    with gr.Row():
                        width = gr.Slider(
                            512, 1024, 768, step=64,
                            label="πŸ“ Width",
                        )
                        height = gr.Slider(
                            512, 1024, 768, step=64,
                            label="πŸ“ Height",
                        )

                    seed = gr.Slider(
                        -1, MAX_SEED, -1, step=1,
                        label="🎲 Seed (-1 = random)",
                        info="Use same seed for reproducible results",
                    )

                generate_btn = gr.Button(
                    "🎨 Generate Image",
                    variant="primary",
                    size="lg",
                    elem_classes=["generate-btn"],
                )

            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="πŸ–ΌοΈ Generated Image",
                    height=500,
                    show_download_button=True,
                )
                download_file = gr.File(
                    label="πŸ“₯ Download PNG (with metadata)",
                    file_types=[".png"],
                )
                info = gr.Textbox(
                    label="ℹ️ Generation Info",
                    lines=6,
                    max_lines=8,
                    interactive=False,
                )

    return {
        "prompt": prompt,
        "example_btn": example_btn,
        "clear_btn": clear_btn,
        "negative": negative,
        "quality": quality,
        "steps": steps,
        "guidance": guidance,
        "width": width,
        "height": height,
        "seed": seed,
        "generate_btn": generate_btn,
        "output_image": output_image,
        "download_file": download_file,
        "info": info,
    }


def _build_img2img_tab():
    """Lay out the image-to-image tab and return its components keyed by role."""
    with gr.TabItem("πŸ–ΌοΈ Image to Image", id="img2img"):
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="πŸ“€ Input Image",
                    type="pil",
                    height=300,
                )

                with gr.Group():
                    prompt = gr.Textbox(
                        label="✨ Transformation Prompt",
                        placeholder="digital art style, vibrant colors...",
                        lines=3,
                    )

                    with gr.Row():
                        example_btn = gr.Button("🎲 Random", size="sm")
                        clear_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")

                with gr.Accordion("βš™οΈ Advanced Settings", open=False):
                    negative = gr.Textbox(
                        label="❌ Negative Prompt",
                        value=DEFAULT_NEGATIVE,
                        lines=2,
                    )

                    quality = gr.Checkbox(
                        label="✨ Add Quality Tags",
                        value=True,
                    )

                    with gr.Row():
                        steps = gr.Slider(10, 50, 25, step=1, label="πŸ“Š Steps")
                        guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="πŸŽ›οΈ CFG")

                    strength = gr.Slider(
                        0.1, 1.0, 0.75, step=0.05,
                        label="πŸ’ͺ Transformation Strength",
                        info="Higher = more creative, lower = more faithful to input",
                    )

                    seed = gr.Slider(-1, MAX_SEED, -1, step=1, label="🎲 Seed")

                generate_btn = gr.Button(
                    "πŸ–ΌοΈ Transform Image",
                    variant="primary",
                    size="lg",
                    elem_classes=["generate-btn"],
                )

            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="πŸ–ΌοΈ Transformed Image",
                    height=500,
                    show_download_button=True,
                )
                download_file = gr.File(
                    label="πŸ“₯ Download PNG (with metadata)",
                    file_types=[".png"],
                )
                info = gr.Textbox(
                    label="ℹ️ Generation Info",
                    lines=6,
                    interactive=False,
                )

    return {
        "input_image": input_image,
        "prompt": prompt,
        "example_btn": example_btn,
        "clear_btn": clear_btn,
        "negative": negative,
        "quality": quality,
        "steps": steps,
        "guidance": guidance,
        "strength": strength,
        "seed": seed,
        "generate_btn": generate_btn,
        "output_image": output_image,
        "download_file": download_file,
        "info": info,
    }


def create_interface():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` app with a header and two tabs (text-to-image and
    image-to-image), then connects the generate, random-prompt and clear
    buttons to their handlers.

    Returns:
        The assembled ``gr.Blocks`` instance; the caller queues and launches it.
    """
    with gr.Blocks(
        title="CyberRealistic Pony - SDXL Generator",
        theme=gr.themes.Soft(primary_hue="blue"),
        css=_INTERFACE_CSS,
    ) as demo:

        gr.Markdown(_HEADER_MARKDOWN)

        with gr.Tabs():
            txt = _build_txt2img_tab()
            img = _build_img2img_tab()

        # Event handlers. The order of each ``inputs`` list must match the
        # positional parameters of the handler it feeds.
        txt["generate_btn"].click(
            fn=generate_txt2img,
            inputs=[txt["prompt"], txt["negative"], txt["steps"], txt["guidance"],
                    txt["width"], txt["height"], txt["seed"], txt["quality"]],
            outputs=[txt["output_image"], txt["download_file"], txt["info"]],
            show_progress=True,
        )

        img["generate_btn"].click(
            fn=generate_img2img,
            inputs=[img["input_image"], img["prompt"], img["negative"], img["steps"],
                    img["guidance"], img["strength"], img["seed"], img["quality"]],
            outputs=[img["output_image"], img["download_file"], img["info"]],
            show_progress=True,
        )

        # Example prompt buttons
        txt["example_btn"].click(fn=get_random_prompt, outputs=[txt["prompt"]])
        img["example_btn"].click(fn=get_random_prompt, outputs=[img["prompt"]])

        # Clear buttons
        txt["clear_btn"].click(lambda: "", outputs=[txt["prompt"]])
        img["clear_btn"].click(lambda: "", outputs=[img["prompt"]])

    return demo
 
 
 
 
 
731
 
732
# Initialize and launch
if __name__ == "__main__":
    # Start-up banner. Values are passed as lazy %-style logging arguments,
    # and the constant message is a plain string rather than an f-string
    # with no placeholders.
    logger.info("πŸš€ Initializing CyberRealistic Pony Generator on %s", DEVICE)
    logger.info("πŸ“± PyTorch version: %s", torch.__version__)
    logger.info("πŸ›‘οΈ NSFW Content Filter: Enabled")

    demo = create_interface()
    demo.queue(max_size=20)  # Enable queuing for better UX
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        show_error=True,
        share=False,  # Set to True if you want a public link
    )